Allow building new head

This commit is contained in:
mbsantiago 2026-03-18 17:44:35 +00:00
parent 8e35956007
commit ebe7e134e9
7 changed files with 211 additions and 23 deletions

View File

@ -1,5 +1,5 @@
from pathlib import Path
from typing import Sequence
from typing import Literal, Sequence, cast
import numpy as np
import torch
@ -19,7 +19,8 @@ from batdetect2.evaluate import (
)
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model
from batdetect2.models import Model, build_model, build_model_with_new_targets
from batdetect2.models.detectors import Detector
from batdetect2.outputs import (
OutputFormatConfig,
OutputFormatterProtocol,
@ -109,6 +110,46 @@ class BatDetect2API:
)
return self
def finetune(
self,
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Sequence[data.ClipAnnotation] | None = None,
trainable: Literal[
"all", "heads", "classifier_head", "bbox_head"
] = "heads",
train_workers: int = 0,
val_workers: int = 0,
checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
log_dir: Path | None = DEFAULT_LOGS_DIR,
experiment_name: str | None = None,
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
) -> "BatDetect2API":
"""Fine-tune the model with trainable-parameter selection."""
self._set_trainable_parameters(trainable)
run_train(
train_annotations=train_annotations,
val_annotations=val_annotations,
model=self.model,
targets=self.targets,
model_config=self.config.model,
train_config=self.config.train,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
train_workers=train_workers,
val_workers=val_workers,
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
experiment_name=experiment_name,
num_epochs=num_epochs,
run_name=run_name,
seed=seed,
)
return self
def evaluate(
self,
test_annotations: Sequence[data.ClipAnnotation],
@ -410,6 +451,7 @@ class BatDetect2API:
cls,
path: data.PathLike,
config: BatDetect2Config | None = None,
targets: TargetProtocol | None = None,
) -> "BatDetect2API":
from batdetect2.audio import AudioConfig
@ -423,7 +465,12 @@ class BatDetect2API:
)
config = merge_configs(base, config) if config else base
if targets is None:
targets = build_targets(config=config.model.targets)
else:
target_config = getattr(targets, "config", None)
if target_config is not None:
config.model.targets = target_config
audio_loader = build_audio_loader(config=config.audio)
@ -452,6 +499,16 @@ class BatDetect2API:
transform=output_transform,
)
targets_changed = targets is not None or (
config.model.targets.model_dump(mode="json")
!= model_config.targets.model_dump(mode="json")
)
if targets_changed:
model = build_model_with_new_targets(
model=model,
targets=targets,
)
model.preprocessor = preprocessor
model.postprocessor = postprocessor
model.targets = targets
@ -467,3 +524,25 @@ class BatDetect2API:
formatter=formatter,
output_transform=output_transform,
)
def _set_trainable_parameters(
self,
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
) -> None:
detector = cast(Detector, self.model.detector)
for parameter in detector.parameters():
parameter.requires_grad = False
if trainable == "all":
for parameter in detector.parameters():
parameter.requires_grad = True
return
if trainable in {"heads", "classifier_head"}:
for parameter in detector.classifier_head.parameters():
parameter.requires_grad = True
if trainable in {"heads", "bbox_head"}:
for parameter in detector.bbox_head.parameters():
parameter.requires_grad = True

View File

@ -100,6 +100,7 @@ __all__ = [
"Model",
"ModelConfig",
"build_model",
"build_model_with_new_targets",
]
@ -274,3 +275,21 @@ def build_model(
preprocessor=preprocessor,
targets=targets,
)
def build_model_with_new_targets(
model: Model,
targets: TargetProtocol,
) -> Model:
"""Build a new model with a different target set."""
detector = build_detector(
num_classes=len(targets.class_names),
backbone=model.detector.backbone,
)
return Model(
detector=detector,
postprocessor=model.postprocessor,
preprocessor=model.preprocessor,
targets=targets,
)

View File

@ -135,7 +135,9 @@ class Detector(DetectionModel):
def build_detector(
num_classes: int, config: BackboneConfig | None = None
num_classes: int,
config: BackboneConfig | None = None,
backbone: BackboneModel | None = None,
) -> DetectionModel:
"""Build a complete BatDetect2 detection model.
@ -165,13 +167,14 @@ def build_detector(
If ``num_classes`` is not positive, or if the backbone
configuration is invalid.
"""
if backbone is None:
config = config or UNetBackboneConfig()
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(), # type: ignore
)
backbone = build_backbone(config=config)
classifier_head = ClassifierHead(
num_classes=num_classes,
in_channels=backbone.out_channels,

View File

@ -82,5 +82,9 @@ class EncoderDecoderModel(BackboneModel):
class DetectionModel(ABC, torch.nn.Module):
backbone: BackboneModel
classifier_head: torch.nn.Module
bbox_head: torch.nn.Module
@abstractmethod
def forward(self, spec: torch.Tensor) -> ModelOutput: ...

View File

@ -68,8 +68,17 @@ class TrainingModule(L.LightningModule):
return outputs
def configure_optimizers(self):
trainable_parameters = [
parameter
for parameter in self.parameters()
if parameter.requires_grad
]
if not trainable_parameters:
raise ValueError("No trainable parameters available.")
optimizer = build_optimizer(
self.parameters(),
trainable_parameters,
config=self.train_config.optimizer,
)
scheduler = build_scheduler(

View File

@ -1,4 +1,5 @@
from pathlib import Path
from typing import cast
import lightning as L
import numpy as np
@ -8,6 +9,9 @@ from soundevent.geometry import compute_bounds
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config
from batdetect2.models.detectors import Detector
from batdetect2.models.heads import ClassifierHead
from batdetect2.train import load_model_from_checkpoint
from batdetect2.train.lightning import build_training_module
@ -213,6 +217,77 @@ def test_user_can_load_checkpoint_and_finetune(
assert checkpoints
def test_user_can_load_checkpoint_with_new_targets(
tmp_path: Path,
sample_targets,
) -> None:
"""User story: start from checkpoint with a new target definition."""
module = build_training_module(model_config=BatDetect2Config().model)
trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "base_transfer.ckpt"
trainer.strategy.connect(module)
trainer.save_checkpoint(checkpoint_path)
source_model, _ = load_model_from_checkpoint(checkpoint_path)
api = BatDetect2API.from_checkpoint(
checkpoint_path,
targets=sample_targets,
)
source_detector = cast(Detector, source_model.detector)
detector = cast(Detector, api.model.detector)
classifier_head = cast(ClassifierHead, detector.classifier_head)
assert api.targets is sample_targets
assert detector.num_classes == len(sample_targets.class_names)
assert (
classifier_head.classifier.out_channels
== len(sample_targets.class_names) + 1
)
source_backbone = source_detector.backbone.state_dict()
target_backbone = detector.backbone.state_dict()
assert source_backbone
for key, value in source_backbone.items():
assert key in target_backbone
torch.testing.assert_close(target_backbone[key], value)
def test_user_can_finetune_only_heads(
tmp_path: Path,
example_annotations,
) -> None:
"""User story: fine-tune only prediction heads."""
api = BatDetect2API.from_config(BatDetect2Config())
finetune_dir = tmp_path / "heads_only"
api.finetune(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
trainable="heads",
train_workers=0,
val_workers=0,
checkpoint_dir=finetune_dir,
log_dir=tmp_path / "logs",
num_epochs=1,
seed=0,
)
detector = cast(Detector, api.model.detector)
backbone_params = list(detector.backbone.parameters())
classifier_params = list(detector.classifier_head.parameters())
bbox_params = list(detector.bbox_head.parameters())
assert backbone_params
assert classifier_params
assert bbox_params
assert all(not parameter.requires_grad for parameter in backbone_params)
assert all(parameter.requires_grad for parameter in classifier_params)
assert all(parameter.requires_grad for parameter in bbox_params)
assert list(finetune_dir.rglob("*.ckpt"))
def test_user_can_evaluate_small_dataset_and_get_metrics(
api_v2: BatDetect2API,
example_annotations,

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import List, Optional
import numpy as np
import pytest
@ -57,10 +56,10 @@ def dummy_targets() -> TargetProtocol:
def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
) -> str | None:
return "bat"
def decode_class(self, class_label: str) -> List[data.Tag]:
def decode_class(self, class_label: str) -> list[data.Tag]:
return tag_map.get(class_label.lower(), [])
def encode_roi(self, sound_event: data.SoundEventAnnotation):
@ -70,7 +69,7 @@ def dummy_targets() -> TargetProtocol:
self,
position,
size: np.ndarray,
class_name: Optional[str] = None,
class_name: str | None = None,
):
time, freq = position
width, height = size
@ -210,7 +209,7 @@ def empty_detection_dataset() -> xr.Dataset:
@pytest.fixture
def sample_raw_predictions() -> List[Detection]:
def sample_raw_predictions() -> list[Detection]:
"""Manually crafted RawPrediction objects using the actual type."""
pred1_classes = xr.DataArray(
@ -282,7 +281,7 @@ def sample_raw_predictions() -> List[Detection]:
def test_convert_raw_to_sound_event_basic(
sample_raw_predictions: List[Detection],
sample_raw_predictions: list[Detection],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
@ -325,7 +324,7 @@ def test_convert_raw_to_sound_event_basic(
def test_convert_raw_to_sound_event_thresholding(
sample_raw_predictions: List[Detection],
sample_raw_predictions: list[Detection],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
@ -353,7 +352,7 @@ def test_convert_raw_to_sound_event_thresholding(
def test_convert_raw_to_sound_event_no_threshold(
sample_raw_predictions: List[Detection],
sample_raw_predictions: list[Detection],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
@ -381,7 +380,7 @@ def test_convert_raw_to_sound_event_no_threshold(
def test_convert_raw_to_sound_event_top_class(
sample_raw_predictions: List[Detection],
sample_raw_predictions: list[Detection],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
@ -408,7 +407,7 @@ def test_convert_raw_to_sound_event_top_class(
def test_convert_raw_to_sound_event_all_below_threshold(
sample_raw_predictions: List[Detection],
sample_raw_predictions: list[Detection],
sample_recording: data.Recording,
dummy_targets: TargetProtocol,
):
@ -434,7 +433,7 @@ def test_convert_raw_to_sound_event_all_below_threshold(
def test_convert_raw_list_to_clip_basic(
sample_raw_predictions: List[Detection],
sample_raw_predictions: list[Detection],
sample_clip: data.Clip,
dummy_targets: TargetProtocol,
):
@ -486,7 +485,7 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets):
def test_convert_raw_list_to_clip_passes_args(
sample_raw_predictions: List[Detection],
sample_raw_predictions: list[Detection],
sample_clip: data.Clip,
dummy_targets: TargetProtocol,
):