diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index ecf0964..43c2c3f 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Sequence +from typing import Literal, Sequence, cast import numpy as np import torch @@ -19,7 +19,8 @@ from batdetect2.evaluate import ( ) from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.logging import DEFAULT_LOGS_DIR -from batdetect2.models import Model, build_model +from batdetect2.models import Model, build_model, build_model_with_new_targets +from batdetect2.models.detectors import Detector from batdetect2.outputs import ( OutputFormatConfig, OutputFormatterProtocol, @@ -109,6 +110,46 @@ class BatDetect2API: ) return self + def finetune( + self, + train_annotations: Sequence[data.ClipAnnotation], + val_annotations: Sequence[data.ClipAnnotation] | None = None, + trainable: Literal[ + "all", "heads", "classifier_head", "bbox_head" + ] = "heads", + train_workers: int = 0, + val_workers: int = 0, + checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR, + log_dir: Path | None = DEFAULT_LOGS_DIR, + experiment_name: str | None = None, + num_epochs: int | None = None, + run_name: str | None = None, + seed: int | None = None, + ) -> "BatDetect2API": + """Fine-tune the model with trainable-parameter selection.""" + + self._set_trainable_parameters(trainable) + + run_train( + train_annotations=train_annotations, + val_annotations=val_annotations, + model=self.model, + targets=self.targets, + model_config=self.config.model, + train_config=self.config.train, + preprocessor=self.preprocessor, + audio_loader=self.audio_loader, + train_workers=train_workers, + val_workers=val_workers, + checkpoint_dir=checkpoint_dir, + log_dir=log_dir, + experiment_name=experiment_name, + num_epochs=num_epochs, + run_name=run_name, + seed=seed, + ) + return self + def evaluate( self, test_annotations: Sequence[data.ClipAnnotation], @@ 
-410,6 +451,7 @@ class BatDetect2API: cls, path: data.PathLike, config: BatDetect2Config | None = None, + targets: TargetProtocol | None = None, ) -> "BatDetect2API": from batdetect2.audio import AudioConfig @@ -423,7 +465,12 @@ class BatDetect2API: ) config = merge_configs(base, config) if config else base - targets = build_targets(config=config.model.targets) + if targets is None: + targets = build_targets(config=config.model.targets) + else: + target_config = getattr(targets, "config", None) + if target_config is not None: + config.model.targets = target_config audio_loader = build_audio_loader(config=config.audio) @@ -452,6 +499,16 @@ class BatDetect2API: transform=output_transform, ) + targets_changed = ( + config.model.targets.model_dump(mode="json") + != model_config.targets.model_dump(mode="json") + ) + if targets_changed: + model = build_model_with_new_targets( + model=model, + targets=targets, + ) + model.preprocessor = preprocessor model.postprocessor = postprocessor model.targets = targets @@ -467,3 +524,25 @@ class BatDetect2API: formatter=formatter, output_transform=output_transform, ) + + def _set_trainable_parameters( + self, + trainable: Literal["all", "heads", "classifier_head", "bbox_head"], + ) -> None: + detector = cast(Detector, self.model.detector) + + for parameter in detector.parameters(): + parameter.requires_grad = False + + if trainable == "all": + for parameter in detector.parameters(): + parameter.requires_grad = True + return + + if trainable in {"heads", "classifier_head"}: + for parameter in detector.classifier_head.parameters(): + parameter.requires_grad = True + + if trainable in {"heads", "bbox_head"}: + for parameter in detector.bbox_head.parameters(): + parameter.requires_grad = True diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index 824479f..c09a0ba 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -100,6 +100,7 @@ __all__ = [
"Model", "ModelConfig", "build_model", + "build_model_with_new_targets", ] @@ -274,3 +275,21 @@ def build_model( preprocessor=preprocessor, targets=targets, ) + + +def build_model_with_new_targets( + model: Model, + targets: TargetProtocol, +) -> Model: + """Build a new model with a different target set.""" + detector = build_detector( + num_classes=len(targets.class_names), + backbone=model.detector.backbone, + ) + + return Model( + detector=detector, + postprocessor=model.postprocessor, + preprocessor=model.preprocessor, + targets=targets, + ) diff --git a/src/batdetect2/models/detectors.py b/src/batdetect2/models/detectors.py index ed49496..12a6a77 100644 --- a/src/batdetect2/models/detectors.py +++ b/src/batdetect2/models/detectors.py @@ -135,7 +135,9 @@ class Detector(DetectionModel): def build_detector( - num_classes: int, config: BackboneConfig | None = None + num_classes: int, + config: BackboneConfig | None = None, + backbone: BackboneModel | None = None, ) -> DetectionModel: """Build a complete BatDetect2 detection model. @@ -165,13 +167,14 @@ def build_detector( If ``num_classes`` is not positive, or if the backbone configuration is invalid. 
""" - config = config or UNetBackboneConfig() + if backbone is None: + config = config or UNetBackboneConfig() + logger.opt(lazy=True).debug( + "Building model with config: \n{}", + lambda: config.to_yaml_string(), # type: ignore + ) + backbone = build_backbone(config=config) - logger.opt(lazy=True).debug( - "Building model with config: \n{}", - lambda: config.to_yaml_string(), # type: ignore - ) - backbone = build_backbone(config=config) classifier_head = ClassifierHead( num_classes=num_classes, in_channels=backbone.out_channels, diff --git a/src/batdetect2/models/types.py b/src/batdetect2/models/types.py index 5cf57b3..eb4302a 100644 --- a/src/batdetect2/models/types.py +++ b/src/batdetect2/models/types.py @@ -82,5 +82,9 @@ class EncoderDecoderModel(BackboneModel): class DetectionModel(ABC, torch.nn.Module): + backbone: BackboneModel + classifier_head: torch.nn.Module + bbox_head: torch.nn.Module + @abstractmethod def forward(self, spec: torch.Tensor) -> ModelOutput: ... diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 0c4d6e2..97e1b81 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -68,8 +68,17 @@ class TrainingModule(L.LightningModule): return outputs def configure_optimizers(self): + trainable_parameters = [ + parameter + for parameter in self.parameters() + if parameter.requires_grad + ] + + if not trainable_parameters: + raise ValueError("No trainable parameters available.") + optimizer = build_optimizer( - self.parameters(), + trainable_parameters, config=self.train_config.optimizer, ) scheduler = build_scheduler( diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py index 72e2f3e..e4c2d25 100644 --- a/tests/test_api_v2/test_api_v2.py +++ b/tests/test_api_v2/test_api_v2.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import cast import lightning as L import numpy as np @@ -8,6 +9,9 @@ from soundevent.geometry import compute_bounds from 
batdetect2.api_v2 import BatDetect2API from batdetect2.config import BatDetect2Config +from batdetect2.models.detectors import Detector +from batdetect2.models.heads import ClassifierHead +from batdetect2.train import load_model_from_checkpoint from batdetect2.train.lightning import build_training_module @@ -213,6 +217,77 @@ def test_user_can_load_checkpoint_and_finetune( assert checkpoints +def test_user_can_load_checkpoint_with_new_targets( + tmp_path: Path, + sample_targets, +) -> None: + """User story: start from checkpoint with a new target definition.""" + + module = build_training_module(model_config=BatDetect2Config().model) + trainer = L.Trainer(enable_checkpointing=False, logger=False) + checkpoint_path = tmp_path / "base_transfer.ckpt" + trainer.strategy.connect(module) + trainer.save_checkpoint(checkpoint_path) + + source_model, _ = load_model_from_checkpoint(checkpoint_path) + api = BatDetect2API.from_checkpoint( + checkpoint_path, + targets=sample_targets, + ) + source_detector = cast(Detector, source_model.detector) + detector = cast(Detector, api.model.detector) + classifier_head = cast(ClassifierHead, detector.classifier_head) + + assert api.targets is sample_targets + assert detector.num_classes == len(sample_targets.class_names) + assert ( + classifier_head.classifier.out_channels + == len(sample_targets.class_names) + 1 + ) + + source_backbone = source_detector.backbone.state_dict() + target_backbone = detector.backbone.state_dict() + assert source_backbone + for key, value in source_backbone.items(): + assert key in target_backbone + torch.testing.assert_close(target_backbone[key], value) + + +def test_user_can_finetune_only_heads( + tmp_path: Path, + example_annotations, +) -> None: + """User story: fine-tune only prediction heads.""" + + api = BatDetect2API.from_config(BatDetect2Config()) + finetune_dir = tmp_path / "heads_only" + + api.finetune( + train_annotations=example_annotations[:1], + val_annotations=example_annotations[:1], + 
trainable="heads", + train_workers=0, + val_workers=0, + checkpoint_dir=finetune_dir, + log_dir=tmp_path / "logs", + num_epochs=1, + seed=0, + ) + detector = cast(Detector, api.model.detector) + + backbone_params = list(detector.backbone.parameters()) + classifier_params = list(detector.classifier_head.parameters()) + bbox_params = list(detector.bbox_head.parameters()) + + assert backbone_params + assert classifier_params + assert bbox_params + assert all(not parameter.requires_grad for parameter in backbone_params) + assert all(parameter.requires_grad for parameter in classifier_params) + assert all(parameter.requires_grad for parameter in bbox_params) + assert list(finetune_dir.rglob("*.ckpt")) + + def test_user_can_evaluate_small_dataset_and_get_metrics( api_v2: BatDetect2API, example_annotations, diff --git a/tests/test_postprocessing/test_decoding.py b/tests/test_postprocessing/test_decoding.py index 3152596..471c93a 100644 --- a/tests/test_postprocessing/test_decoding.py +++ b/tests/test_postprocessing/test_decoding.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import List, Optional import numpy as np import pytest @@ -57,10 +56,10 @@ def dummy_targets() -> TargetProtocol: def encode_class( self, sound_event: data.SoundEventAnnotation - ) -> Optional[str]: + ) -> str | None: return "bat" - def decode_class(self, class_label: str) -> List[data.Tag]: + def decode_class(self, class_label: str) -> list[data.Tag]: return tag_map.get(class_label.lower(), []) def encode_roi(self, sound_event: data.SoundEventAnnotation): @@ -70,7 +69,7 @@ def dummy_targets() -> TargetProtocol: self, position, size: np.ndarray, - class_name: Optional[str] = None, + class_name: str | None = None, ): time, freq = position width, height = size @@ -210,7 +209,7 @@ def empty_detection_dataset() -> xr.Dataset: @pytest.fixture -def sample_raw_predictions() -> List[Detection]: +def sample_raw_predictions() -> list[Detection]: """Manually crafted RawPrediction objects using the 
actual type.""" pred1_classes = xr.DataArray( @@ -282,7 +281,7 @@ def sample_raw_predictions() -> List[Detection]: def test_convert_raw_to_sound_event_basic( - sample_raw_predictions: List[Detection], + sample_raw_predictions: list[Detection], sample_recording: data.Recording, dummy_targets: TargetProtocol, ): @@ -325,7 +324,7 @@ def test_convert_raw_to_sound_event_basic( def test_convert_raw_to_sound_event_thresholding( - sample_raw_predictions: List[Detection], + sample_raw_predictions: list[Detection], sample_recording: data.Recording, dummy_targets: TargetProtocol, ): @@ -353,7 +352,7 @@ def test_convert_raw_to_sound_event_thresholding( def test_convert_raw_to_sound_event_no_threshold( - sample_raw_predictions: List[Detection], + sample_raw_predictions: list[Detection], sample_recording: data.Recording, dummy_targets: TargetProtocol, ): @@ -381,7 +380,7 @@ def test_convert_raw_to_sound_event_no_threshold( def test_convert_raw_to_sound_event_top_class( - sample_raw_predictions: List[Detection], + sample_raw_predictions: list[Detection], sample_recording: data.Recording, dummy_targets: TargetProtocol, ): @@ -408,7 +407,7 @@ def test_convert_raw_to_sound_event_top_class( def test_convert_raw_to_sound_event_all_below_threshold( - sample_raw_predictions: List[Detection], + sample_raw_predictions: list[Detection], sample_recording: data.Recording, dummy_targets: TargetProtocol, ): @@ -434,7 +433,7 @@ def test_convert_raw_to_sound_event_all_below_threshold( def test_convert_raw_list_to_clip_basic( - sample_raw_predictions: List[Detection], + sample_raw_predictions: list[Detection], sample_clip: data.Clip, dummy_targets: TargetProtocol, ): @@ -486,7 +485,7 @@ def test_convert_raw_list_to_clip_empty(sample_clip, dummy_targets): def test_convert_raw_list_to_clip_passes_args( - sample_raw_predictions: List[Detection], + sample_raw_predictions: list[Detection], sample_clip: data.Clip, dummy_targets: TargetProtocol, ):