diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 53230c7..ec068b1 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -19,7 +19,8 @@ if TYPE_CHECKING: LoggerConfig, LoggingCallback, ) - from batdetect2.models import Model, ModelConfig + from batdetect2.models import ModelConfig + from batdetect2.models.types import ModelProtocol from batdetect2.outputs import ( OutputFormatConfig, OutputFormatterProtocol, @@ -88,7 +89,7 @@ class BatDetect2API: evaluator: EvaluatorProtocol, formatter: OutputFormatterProtocol, output_transform: OutputTransformProtocol, - model: Model, + model: ModelProtocol, ): """Create a fully configured API instance. @@ -128,7 +129,7 @@ class BatDetect2API: Default formatter used to save predictions. output_transform : OutputTransformProtocol Transform that converts model outputs into detections. - model : Model + model : ModelProtocol Model instance. """ self.model_config = model_config @@ -1177,5 +1178,5 @@ class BatDetect2API: parameter.requires_grad = True if trainable in {"heads", "bbox_head"}: - for parameter in detector.bbox_head.parameters(): + for parameter in detector.size_head.parameters(): parameter.requires_grad = True diff --git a/src/batdetect2/cli/finetune.py b/src/batdetect2/cli/finetune.py index 33b1c8b..f0f4791 100644 --- a/src/batdetect2/cli/finetune.py +++ b/src/batdetect2/cli/finetune.py @@ -47,24 +47,24 @@ __all__ = ["finetune_command"] @click.option( "--training-config", type=click.Path(exists=True), - help="Path to a training config file.", + help="Path to training config file.", ) @click.option( "--audio-config", type=click.Path(exists=True), - help="Path to an audio config file.", + help="Path to audio config file.", ) @click.option( "--logging-config", type=click.Path(exists=True), - help="Path to a logging config file.", + help="Path to logging config file.", ) @click.option( "--trainable", type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]), default="heads", show_default=True, - help="Which model parameters stay trainable during fine-tuning.", + help="Which model parameters remain trainable during fine-tuning.", ) @click.option( "--ckpt-dir", @@ -127,11 +127,7 @@ def finetune_command( experiment_name: str | None = None, run_name: str | None = None, ): - """Fine-tune a checkpoint on a new target definition. - - Use this command when you want to adapt an existing model to a new class - list or ROI mapping. - """ + """Fine-tune a BatDetect2 checkpoint on a new target definition.""" from batdetect2.api_v2 import BatDetect2API from batdetect2.audio import AudioConfig from batdetect2.data import load_dataset, load_dataset_config diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index cc3ab69..ee96d93 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -62,7 +62,7 @@ from batdetect2.models.encoder import ( build_encoder, ) from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead -from batdetect2.models.types import DetectionModel +from batdetect2.models.types import DetectorProtocol, ModelProtocol from batdetect2.postprocess.config import PostprocessConfig from batdetect2.postprocess.types import ( ClipDetectionsTensor, @@ -149,7 +149,7 @@ class Model(torch.nn.Module): Attributes ---------- - detector : DetectionModel + detector : DetectorProtocol The neural network that processes spectrograms and produces raw detection, classification, and bounding-box outputs. preprocessor : PreprocessorProtocol @@ -164,7 +164,7 @@ class Model(torch.nn.Module): Size-dimension names corresponding to the model size outputs. """ - detector: DetectionModel + detector: DetectorProtocol preprocessor: PreprocessorProtocol postprocessor: PostprocessorProtocol class_names: list[str] @@ -173,7 +173,7 @@ class Model(torch.nn.Module): def __init__( self, - detector: DetectionModel, + detector: DetectorProtocol, preprocessor: PreprocessorProtocol, postprocessor: PostprocessorProtocol, class_names: list[str], @@ -224,7 +224,7 @@ def build_model( dimension_names: list[str] | None = None, preprocessor: PreprocessorProtocol | None = None, postprocessor: PostprocessorProtocol | None = None, -) -> Model: +) -> ModelProtocol: """Build a complete, ready-to-use BatDetect2 model. Assembles a ``Model`` instance from a ``ModelConfig`` and optional @@ -256,7 +256,7 @@ def build_model( Returns ------- - Model + ModelProtocol A fully assembled ``Model`` instance ready for inference or training. """ @@ -285,8 +285,8 @@ def build_model( config=config.postprocess, ) detector = build_detector( - num_classes=len(class_names), - num_sizes=len(dimension_names), + class_names=class_names, + dimension_names=dimension_names, config=config.architecture, ) return Model( @@ -300,14 +300,14 @@ def build_model( def build_model_with_new_targets( - model: Model, + model: ModelProtocol, targets: TargetProtocol, roi_mapper: ROIMapperProtocol, -) -> Model: +) -> ModelProtocol: """Build a new model with a different target set.""" detector = build_detector( - num_classes=len(targets.class_names), - num_sizes=len(roi_mapper.dimension_names), + class_names=targets.class_names, + dimension_names=roi_mapper.dimension_names, backbone=model.detector.backbone, ) diff --git a/src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt b/src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt index 49b64a4..b849167 100644 Binary files a/src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt and b/src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt differ diff --git a/src/batdetect2/models/detectors.py b/src/batdetect2/models/detectors.py index a3894ce..586beca 100644 --- a/src/batdetect2/models/detectors.py +++ b/src/batdetect2/models/detectors.py @@ -6,8 +6,8 @@ bounding-box size regression. Components ---------- -- ``Detector`` – the ``torch.nn.Module`` that wires together a backbone - (``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to +- ``Detector`` - the ``torch.nn.Module`` that wires together a backbone + (``BackboneProtocol``) with a ``ClassifierHead`` and a ``BBoxHead`` to produce a ``ModelOutput`` tuple from an input spectrogram. - ``build_detector`` – factory function that builds a ready-to-use ``Detector`` from a backbone configuration and a target class count. @@ -18,15 +18,16 @@ preprocessing and output postprocessing are handled by """ import torch -from loguru import logger -from batdetect2.models.backbones import ( - BackboneConfig, - UNetBackboneConfig, - build_backbone, -) +from batdetect2.models.backbones import BackboneConfig, build_backbone from batdetect2.models.heads import BBoxHead, ClassifierHead -from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput +from batdetect2.models.types import ( + BackboneProtocol, + ClassifierHeadProtocol, + DetectorProtocol, + ModelOutput, + SizeHeadProtocol, +) __all__ = [ "Detector", @@ -34,7 +35,7 @@ __all__ = [ ] -class Detector(DetectionModel): +class Detector(torch.nn.Module): """Complete BatDetect2 detection and classification model. Combines a backbone feature extractor with two prediction heads: @@ -51,7 +52,7 @@ class Detector(DetectionModel): Attributes ---------- - backbone : BackboneModel + backbone : BackboneProtocol The feature extraction backbone. num_classes : int Number of target classes (inferred from the classifier head). @@ -61,13 +62,13 @@ class Detector(DetectionModel): Produces duration and bandwidth predictions from backbone features. """ - backbone: BackboneModel + backbone: BackboneProtocol def __init__( self, - backbone: BackboneModel, - classifier_head: ClassifierHead, - bbox_head: BBoxHead, + backbone: BackboneProtocol, + classifier_head: ClassifierHeadProtocol, + size_head: SizeHeadProtocol, ): """Initialise the Detector model. @@ -76,7 +77,7 @@ class Detector(DetectionModel): Parameters ---------- - backbone : BackboneModel + backbone : BackboneProtocol An initialised backbone module (e.g. built by ``build_backbone``). classifier_head : ClassifierHead @@ -90,7 +91,7 @@ class Detector(DetectionModel): self.backbone = backbone self.num_classes = classifier_head.num_classes self.classifier_head = classifier_head - self.bbox_head = bbox_head + self.size_head = size_head def forward(self, spec: torch.Tensor) -> ModelOutput: """Run the complete detection model on an input spectrogram. @@ -125,7 +126,7 @@ class Detector(DetectionModel): features = self.backbone(spec) classification = self.classifier_head(features) detection = classification.sum(dim=1, keepdim=True) - size_preds = self.bbox_head(features) + size_preds = self.size_head(features) return ModelOutput( detection_probs=detection, size_preds=size_preds, @@ -135,11 +136,11 @@ class Detector(DetectionModel): def build_detector( - num_classes: int, - num_sizes: int = 2, + class_names: list[str], + dimension_names: list[str], config: BackboneConfig | None = None, - backbone: BackboneModel | None = None, -) -> DetectionModel: + backbone: BackboneProtocol | None = None, +) -> DetectorProtocol: """Build a complete BatDetect2 detection model. Constructs a backbone from ``config``, attaches a ``ClassifierHead`` @@ -158,7 +159,7 @@ def build_detector( Returns ------- - DetectionModel + DetectorProtocol An initialised ``Detector`` instance ready for training or inference. @@ -168,24 +169,18 @@ def build_detector( If ``num_classes`` is not positive, or if the backbone configuration is invalid. """ - if backbone is None: - config = config or UNetBackboneConfig() - logger.opt(lazy=True).debug( - "Building model with config: \n{}", - lambda: config.to_yaml_string(), # type: ignore - ) - backbone = build_backbone(config=config) + backbone = backbone or build_backbone(config=config) classifier_head = ClassifierHead( - num_classes=num_classes, + class_names=class_names, in_channels=backbone.out_channels, ) bbox_head = BBoxHead( in_channels=backbone.out_channels, - num_sizes=num_sizes, + dimension_names=dimension_names, ) return Detector( backbone=backbone, classifier_head=classifier_head, - bbox_head=bbox_head, + size_head=bbox_head, ) diff --git a/src/batdetect2/models/heads.py b/src/batdetect2/models/heads.py index ba7b437..250ddb6 100644 --- a/src/batdetect2/models/heads.py +++ b/src/batdetect2/models/heads.py @@ -54,12 +54,14 @@ class ClassifierHead(nn.Module): 1×1 convolution with ``num_classes + 1`` output channels. """ - def __init__(self, num_classes: int, in_channels: int): + def __init__(self, class_names: list[str], in_channels: int): """Initialise the ClassifierHead.""" super().__init__() - self.num_classes = num_classes + self.class_names = class_names + self.num_classes = len(class_names) self.in_channels = in_channels + self.classifier = nn.Conv2d( self.in_channels, self.num_classes + 1, @@ -165,11 +167,12 @@ class BBoxHead(nn.Module): 1×1 convolution with 2 output channels (duration, bandwidth). """ - def __init__(self, in_channels: int, num_sizes: int = 2): + def __init__(self, dimension_names: list[str], in_channels: int): """Initialise the BBoxHead.""" super().__init__() self.in_channels = in_channels - self.num_sizes = num_sizes + self.dimension_names = dimension_names + self.num_sizes = len(dimension_names) self.bbox = nn.Conv2d( in_channels=self.in_channels, diff --git a/src/batdetect2/train/checkpoints.py b/src/batdetect2/train/checkpoints.py index a443743..be1c165 100644 --- a/src/batdetect2/train/checkpoints.py +++ b/src/batdetect2/train/checkpoints.py @@ -34,6 +34,7 @@ class CheckpointConfig(BaseConfig): monitor: str | None = None mode: str = "max" save_top_k: int = 1 + # Save distributable inference checkpoints by default. save_weights_only: bool = True filename: str | None = None save_last: bool | Literal["link"] = "link" diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py index a4b2758..9f7a109 100644 --- a/tests/test_api_v2/test_api_v2.py +++ b/tests/test_api_v2/test_api_v2.py @@ -299,10 +299,10 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged( value, ) - for key, value in source_detector.bbox_head.state_dict().items(): - assert key in detector.bbox_head.state_dict() + for key, value in source_detector.size_head.state_dict().items(): + assert key in detector.size_head.state_dict() torch.testing.assert_close( - detector.bbox_head.state_dict()[key], + detector.size_head.state_dict()[key], value, ) diff --git a/tests/test_api_v2/test_finetune.py b/tests/test_api_v2/test_finetune.py index 8d8c6a2..5d8b223 100644 --- a/tests/test_api_v2/test_finetune.py +++ b/tests/test_api_v2/test_finetune.py @@ -18,7 +18,7 @@ def test_user_can_finetune_only_heads( api = BatDetect2API.from_config() source_classifier_head = api.model.detector.classifier_head - source_bbox_head = api.model.detector.bbox_head + source_size_head = api.model.detector.size_head source_backbone = api.model.detector.backbone finetune_dir = tmp_path / "heads_only" @@ -39,7 +39,7 @@ def test_user_can_finetune_only_heads( backbone_params = list(detector.backbone.parameters()) classifier_params = list(detector.classifier_head.parameters()) - bbox_params = list(detector.bbox_head.parameters()) + bbox_params = list(detector.size_head.parameters()) assert backbone_params assert classifier_params @@ -50,7 +50,7 @@ def test_user_can_finetune_only_heads( assert finetuned_api is not api assert detector.backbone is source_backbone assert detector.classifier_head is not source_classifier_head - assert detector.bbox_head is not source_bbox_head + assert detector.size_head is not source_size_head assert list(finetune_dir.rglob("*.ckpt")) diff --git a/tests/test_models/test_detectors.py b/tests/test_models/test_detectors.py index f5ce769..35d39a9 100644 --- a/tests/test_models/test_detectors.py +++ b/tests/test_models/test_detectors.py @@ -1,6 +1,7 @@ import numpy as np import pytest import torch +from typing import cast from batdetect2.models import UNetBackbone from batdetect2.models.backbones import UNetBackboneConfig @@ -19,12 +20,15 @@ def dummy_spectrogram() -> torch.Tensor: def test_build_detector_default(): """Test building the default detector without a config.""" num_classes = 5 - model = build_detector(num_classes=num_classes) + model = build_detector( + class_names=[f"class_{i}" for i in range(num_classes)], + dimension_names=["width", "height"], + ) assert isinstance(model, Detector) assert model.num_classes == num_classes assert isinstance(model.classifier_head, ClassifierHead) - assert isinstance(model.bbox_head, BBoxHead) + assert isinstance(model.size_head, BBoxHead) def test_build_detector_custom_config(): @@ -32,13 +36,19 @@ def test_build_detector_custom_config(): num_classes = 3 config = UNetBackboneConfig(in_channels=2, input_height=128) - model = build_detector(num_classes=num_classes, config=config) + model = build_detector( + class_names=[f"class_{i}" for i in range(num_classes)], + dimension_names=["width", "height"], + config=config, + ) assert isinstance(model, Detector) assert model.backbone.input_height == 128 - assert isinstance(model.backbone.encoder, Encoder) - assert model.backbone.encoder.in_channels == 2 + backbone = cast(UNetBackbone, model.backbone) + + assert isinstance(backbone.encoder, Encoder) + assert backbone.encoder.in_channels == 2 def test_build_detector_custom_size_channels(): @@ -47,8 +57,8 @@ def test_build_detector_custom_size_channels(): config = UNetBackboneConfig(in_channels=1, input_height=128) model = build_detector( - num_classes=num_classes, - num_sizes=num_sizes, + class_names=[f"class_{i}" for i in range(num_classes)], + dimension_names=[f"size_{i}" for i in range(num_sizes)], config=config, ) @@ -62,7 +72,11 @@ def test_detector_forward_pass_shapes(dummy_spectrogram): num_classes = 4 # Build model matching the dummy input shape config = UNetBackboneConfig(in_channels=1, input_height=256) - model = build_detector(num_classes=num_classes, config=config) + model = build_detector( + class_names=[f"class_{i}" for i in range(num_classes)], + dimension_names=["width", "height"], + config=config, + ) # Process the spectrogram through the model # PyTorch expects shape (Batch, Channels, Height, Width) @@ -132,7 +146,11 @@ def test_detector_forward_pass_with_preprocessor(sample_preprocessor): config = UNetBackboneConfig( in_channels=spec.shape[1], input_height=spec.shape[2] ) - model = build_detector(num_classes=3, config=config) + model = build_detector( + class_names=["class_0", "class_1", "class_2"], + dimension_names=["width", "height"], + config=config, + ) # Process output = model(spec) diff --git a/tests/test_train/test_checkpoints.py b/tests/test_train/test_checkpoints.py index 77a2856..4cffa48 100644 --- a/tests/test_train/test_checkpoints.py +++ b/tests/test_train/test_checkpoints.py @@ -8,7 +8,7 @@ from soundevent import data from batdetect2.train import TrainingConfig, run_train from batdetect2.train.checkpoints import ( - DEFAULT_BUNDLED_CHECKPOINT, + DEFAULT_CHECKPOINT, get_bundled_checkpoint_names, resolve_checkpoint_path, ) @@ -145,7 +145,7 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged( def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None: assert get_bundled_checkpoint_names() == ( - DEFAULT_BUNDLED_CHECKPOINT, + DEFAULT_CHECKPOINT, "batdetect2_uk_same", ) @@ -153,11 +153,11 @@ def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None: def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None: resolved = resolve_checkpoint_path() - assert resolved == resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT) + assert resolved == resolve_checkpoint_path(DEFAULT_CHECKPOINT) def test_resolve_checkpoint_path_accepts_bundled_alias() -> None: - resolved = resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT) + resolved = resolve_checkpoint_path(DEFAULT_CHECKPOINT) assert resolved.name == "batdetect2_uk_same.ckpt" assert resolved.exists() @@ -227,6 +227,6 @@ def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None: def test_resolve_checkpoint_path_rejects_missing_local_path() -> None: with pytest.raises( FileNotFoundError, - match="bundled checkpoint alias", + match="checkpoint alias", ): resolve_checkpoint_path("missing.ckpt") diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index 9ab9e02..756329d 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -368,7 +368,7 @@ def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> ( assert ( rebuilt_detector.classifier_head is not source_detector.classifier_head ) - assert rebuilt_detector.bbox_head is not source_detector.bbox_head + assert rebuilt_detector.size_head is not source_detector.size_head assert rebuilt_model.class_names == ["single_class"] assert rebuilt_model.dimension_names == ["width", "height"]