fix: rename detector heads and refresh bundled checkpoint

This commit is contained in:
mbsantiago 2026-05-06 12:50:53 +01:00
parent 2008c8000f
commit 5cc5767eff
12 changed files with 97 additions and 83 deletions

View File

@ -19,7 +19,8 @@ if TYPE_CHECKING:
LoggerConfig, LoggerConfig,
LoggingCallback, LoggingCallback,
) )
from batdetect2.models import Model, ModelConfig from batdetect2.models import ModelConfig
from batdetect2.models.types import ModelProtocol
from batdetect2.outputs import ( from batdetect2.outputs import (
OutputFormatConfig, OutputFormatConfig,
OutputFormatterProtocol, OutputFormatterProtocol,
@ -88,7 +89,7 @@ class BatDetect2API:
evaluator: EvaluatorProtocol, evaluator: EvaluatorProtocol,
formatter: OutputFormatterProtocol, formatter: OutputFormatterProtocol,
output_transform: OutputTransformProtocol, output_transform: OutputTransformProtocol,
model: Model, model: ModelProtocol,
): ):
"""Create a fully configured API instance. """Create a fully configured API instance.
@ -128,7 +129,7 @@ class BatDetect2API:
Default formatter used to save predictions. Default formatter used to save predictions.
output_transform : OutputTransformProtocol output_transform : OutputTransformProtocol
Transform that converts model outputs into detections. Transform that converts model outputs into detections.
model : Model model : ModelProtocol
Model instance. Model instance.
""" """
self.model_config = model_config self.model_config = model_config
@ -1177,5 +1178,5 @@ class BatDetect2API:
parameter.requires_grad = True parameter.requires_grad = True
if trainable in {"heads", "bbox_head"}: if trainable in {"heads", "bbox_head"}:
for parameter in detector.bbox_head.parameters(): for parameter in detector.size_head.parameters():
parameter.requires_grad = True parameter.requires_grad = True

View File

@ -47,24 +47,24 @@ __all__ = ["finetune_command"]
@click.option( @click.option(
"--training-config", "--training-config",
type=click.Path(exists=True), type=click.Path(exists=True),
help="Path to a training config file.", help="Path to training config file.",
) )
@click.option( @click.option(
"--audio-config", "--audio-config",
type=click.Path(exists=True), type=click.Path(exists=True),
help="Path to an audio config file.", help="Path to audio config file.",
) )
@click.option( @click.option(
"--logging-config", "--logging-config",
type=click.Path(exists=True), type=click.Path(exists=True),
help="Path to a logging config file.", help="Path to logging config file.",
) )
@click.option( @click.option(
"--trainable", "--trainable",
type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]), type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
default="heads", default="heads",
show_default=True, show_default=True,
help="Which model parameters stay trainable during fine-tuning.", help="Which model parameters remain trainable during fine-tuning.",
) )
@click.option( @click.option(
"--ckpt-dir", "--ckpt-dir",
@ -127,11 +127,7 @@ def finetune_command(
experiment_name: str | None = None, experiment_name: str | None = None,
run_name: str | None = None, run_name: str | None = None,
): ):
"""Fine-tune a checkpoint on a new target definition. """Fine-tune a BatDetect2 checkpoint on a new target definition."""
Use this command when you want to adapt an existing model to a new class
list or ROI mapping.
"""
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig from batdetect2.audio import AudioConfig
from batdetect2.data import load_dataset, load_dataset_config from batdetect2.data import load_dataset, load_dataset_config

View File

@ -62,7 +62,7 @@ from batdetect2.models.encoder import (
build_encoder, build_encoder,
) )
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.models.types import DetectionModel from batdetect2.models.types import DetectorProtocol, ModelProtocol
from batdetect2.postprocess.config import PostprocessConfig from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.postprocess.types import ( from batdetect2.postprocess.types import (
ClipDetectionsTensor, ClipDetectionsTensor,
@ -149,7 +149,7 @@ class Model(torch.nn.Module):
Attributes Attributes
---------- ----------
detector : DetectionModel detector : DetectorProtocol
The neural network that processes spectrograms and produces raw The neural network that processes spectrograms and produces raw
detection, classification, and bounding-box outputs. detection, classification, and bounding-box outputs.
preprocessor : PreprocessorProtocol preprocessor : PreprocessorProtocol
@ -164,7 +164,7 @@ class Model(torch.nn.Module):
Size-dimension names corresponding to the model size outputs. Size-dimension names corresponding to the model size outputs.
""" """
detector: DetectionModel detector: DetectorProtocol
preprocessor: PreprocessorProtocol preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol postprocessor: PostprocessorProtocol
class_names: list[str] class_names: list[str]
@ -173,7 +173,7 @@ class Model(torch.nn.Module):
def __init__( def __init__(
self, self,
detector: DetectionModel, detector: DetectorProtocol,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol, postprocessor: PostprocessorProtocol,
class_names: list[str], class_names: list[str],
@ -224,7 +224,7 @@ def build_model(
dimension_names: list[str] | None = None, dimension_names: list[str] | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None, postprocessor: PostprocessorProtocol | None = None,
) -> Model: ) -> ModelProtocol:
"""Build a complete, ready-to-use BatDetect2 model. """Build a complete, ready-to-use BatDetect2 model.
Assembles a ``Model`` instance from a ``ModelConfig`` and optional Assembles a ``Model`` instance from a ``ModelConfig`` and optional
@ -256,7 +256,7 @@ def build_model(
Returns Returns
------- -------
Model ModelProtocol
A fully assembled ``Model`` instance ready for inference or A fully assembled ``Model`` instance ready for inference or
training. training.
""" """
@ -285,8 +285,8 @@ def build_model(
config=config.postprocess, config=config.postprocess,
) )
detector = build_detector( detector = build_detector(
num_classes=len(class_names), class_names=class_names,
num_sizes=len(dimension_names), dimension_names=dimension_names,
config=config.architecture, config=config.architecture,
) )
return Model( return Model(
@ -300,14 +300,14 @@ def build_model(
def build_model_with_new_targets( def build_model_with_new_targets(
model: Model, model: ModelProtocol,
targets: TargetProtocol, targets: TargetProtocol,
roi_mapper: ROIMapperProtocol, roi_mapper: ROIMapperProtocol,
) -> Model: ) -> ModelProtocol:
"""Build a new model with a different target set.""" """Build a new model with a different target set."""
detector = build_detector( detector = build_detector(
num_classes=len(targets.class_names), class_names=targets.class_names,
num_sizes=len(roi_mapper.dimension_names), dimension_names=roi_mapper.dimension_names,
backbone=model.detector.backbone, backbone=model.detector.backbone,
) )

View File

@ -6,8 +6,8 @@ bounding-box size regression.
Components Components
---------- ----------
- ``Detector`` the ``torch.nn.Module`` that wires together a backbone - ``Detector`` - the ``torch.nn.Module`` that wires together a backbone
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to (``BackboneProtocol``) with a ``ClassifierHead`` and a ``BBoxHead`` to
produce a ``ModelOutput`` tuple from an input spectrogram. produce a ``ModelOutput`` tuple from an input spectrogram.
- ``build_detector`` factory function that builds a ready-to-use - ``build_detector`` factory function that builds a ready-to-use
``Detector`` from a backbone configuration and a target class count. ``Detector`` from a backbone configuration and a target class count.
@ -18,15 +18,16 @@ preprocessing and output postprocessing are handled by
""" """
import torch import torch
from loguru import logger
from batdetect2.models.backbones import ( from batdetect2.models.backbones import BackboneConfig, build_backbone
BackboneConfig,
UNetBackboneConfig,
build_backbone,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput from batdetect2.models.types import (
BackboneProtocol,
ClassifierHeadProtocol,
DetectorProtocol,
ModelOutput,
SizeHeadProtocol,
)
__all__ = [ __all__ = [
"Detector", "Detector",
@ -34,7 +35,7 @@ __all__ = [
] ]
class Detector(DetectionModel): class Detector(torch.nn.Module):
"""Complete BatDetect2 detection and classification model. """Complete BatDetect2 detection and classification model.
Combines a backbone feature extractor with two prediction heads: Combines a backbone feature extractor with two prediction heads:
@ -51,7 +52,7 @@ class Detector(DetectionModel):
Attributes Attributes
---------- ----------
backbone : BackboneModel backbone : BackboneProtocol
The feature extraction backbone. The feature extraction backbone.
num_classes : int num_classes : int
Number of target classes (inferred from the classifier head). Number of target classes (inferred from the classifier head).
@ -61,13 +62,13 @@ class Detector(DetectionModel):
Produces duration and bandwidth predictions from backbone features. Produces duration and bandwidth predictions from backbone features.
""" """
backbone: BackboneModel backbone: BackboneProtocol
def __init__( def __init__(
self, self,
backbone: BackboneModel, backbone: BackboneProtocol,
classifier_head: ClassifierHead, classifier_head: ClassifierHeadProtocol,
bbox_head: BBoxHead, size_head: SizeHeadProtocol,
): ):
"""Initialise the Detector model. """Initialise the Detector model.
@ -76,7 +77,7 @@ class Detector(DetectionModel):
Parameters Parameters
---------- ----------
backbone : BackboneModel backbone : BackboneProtocol
An initialised backbone module (e.g. built by An initialised backbone module (e.g. built by
``build_backbone``). ``build_backbone``).
classifier_head : ClassifierHead classifier_head : ClassifierHead
@ -90,7 +91,7 @@ class Detector(DetectionModel):
self.backbone = backbone self.backbone = backbone
self.num_classes = classifier_head.num_classes self.num_classes = classifier_head.num_classes
self.classifier_head = classifier_head self.classifier_head = classifier_head
self.bbox_head = bbox_head self.size_head = size_head
def forward(self, spec: torch.Tensor) -> ModelOutput: def forward(self, spec: torch.Tensor) -> ModelOutput:
"""Run the complete detection model on an input spectrogram. """Run the complete detection model on an input spectrogram.
@ -125,7 +126,7 @@ class Detector(DetectionModel):
features = self.backbone(spec) features = self.backbone(spec)
classification = self.classifier_head(features) classification = self.classifier_head(features)
detection = classification.sum(dim=1, keepdim=True) detection = classification.sum(dim=1, keepdim=True)
size_preds = self.bbox_head(features) size_preds = self.size_head(features)
return ModelOutput( return ModelOutput(
detection_probs=detection, detection_probs=detection,
size_preds=size_preds, size_preds=size_preds,
@ -135,11 +136,11 @@ class Detector(DetectionModel):
def build_detector( def build_detector(
num_classes: int, class_names: list[str],
num_sizes: int = 2, dimension_names: list[str],
config: BackboneConfig | None = None, config: BackboneConfig | None = None,
backbone: BackboneModel | None = None, backbone: BackboneProtocol | None = None,
) -> DetectionModel: ) -> DetectorProtocol:
"""Build a complete BatDetect2 detection model. """Build a complete BatDetect2 detection model.
Constructs a backbone from ``config``, attaches a ``ClassifierHead`` Constructs a backbone from ``config``, attaches a ``ClassifierHead``
@ -158,7 +159,7 @@ def build_detector(
Returns Returns
------- -------
DetectionModel DetectorProtocol
An initialised ``Detector`` instance ready for training or An initialised ``Detector`` instance ready for training or
inference. inference.
@ -168,24 +169,18 @@ def build_detector(
If ``num_classes`` is not positive, or if the backbone If ``num_classes`` is not positive, or if the backbone
configuration is invalid. configuration is invalid.
""" """
if backbone is None: backbone = backbone or build_backbone(config=config)
config = config or UNetBackboneConfig()
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(), # type: ignore
)
backbone = build_backbone(config=config)
classifier_head = ClassifierHead( classifier_head = ClassifierHead(
num_classes=num_classes, class_names=class_names,
in_channels=backbone.out_channels, in_channels=backbone.out_channels,
) )
bbox_head = BBoxHead( bbox_head = BBoxHead(
in_channels=backbone.out_channels, in_channels=backbone.out_channels,
num_sizes=num_sizes, dimension_names=dimension_names,
) )
return Detector( return Detector(
backbone=backbone, backbone=backbone,
classifier_head=classifier_head, classifier_head=classifier_head,
bbox_head=bbox_head, size_head=bbox_head,
) )

View File

@ -54,12 +54,14 @@ class ClassifierHead(nn.Module):
1×1 convolution with ``num_classes + 1`` output channels. 1×1 convolution with ``num_classes + 1`` output channels.
""" """
def __init__(self, num_classes: int, in_channels: int): def __init__(self, class_names: list[str], in_channels: int):
"""Initialise the ClassifierHead.""" """Initialise the ClassifierHead."""
super().__init__() super().__init__()
self.num_classes = num_classes self.class_names = class_names
self.num_classes = len(class_names)
self.in_channels = in_channels self.in_channels = in_channels
self.classifier = nn.Conv2d( self.classifier = nn.Conv2d(
self.in_channels, self.in_channels,
self.num_classes + 1, self.num_classes + 1,
@ -165,11 +167,12 @@ class BBoxHead(nn.Module):
1×1 convolution with 2 output channels (duration, bandwidth). 1×1 convolution with 2 output channels (duration, bandwidth).
""" """
def __init__(self, in_channels: int, num_sizes: int = 2): def __init__(self, dimension_names: list[str], in_channels: int):
"""Initialise the BBoxHead.""" """Initialise the BBoxHead."""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.num_sizes = num_sizes self.dimension_names = dimension_names
self.num_sizes = len(dimension_names)
self.bbox = nn.Conv2d( self.bbox = nn.Conv2d(
in_channels=self.in_channels, in_channels=self.in_channels,

View File

@ -34,6 +34,7 @@ class CheckpointConfig(BaseConfig):
monitor: str | None = None monitor: str | None = None
mode: str = "max" mode: str = "max"
save_top_k: int = 1 save_top_k: int = 1
# Save distributable inference checkpoints by default.
save_weights_only: bool = True save_weights_only: bool = True
filename: str | None = None filename: str | None = None
save_last: bool | Literal["link"] = "link" save_last: bool | Literal["link"] = "link"

View File

@ -299,10 +299,10 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
value, value,
) )
for key, value in source_detector.bbox_head.state_dict().items(): for key, value in source_detector.size_head.state_dict().items():
assert key in detector.bbox_head.state_dict() assert key in detector.size_head.state_dict()
torch.testing.assert_close( torch.testing.assert_close(
detector.bbox_head.state_dict()[key], detector.size_head.state_dict()[key],
value, value,
) )

View File

@ -18,7 +18,7 @@ def test_user_can_finetune_only_heads(
api = BatDetect2API.from_config() api = BatDetect2API.from_config()
source_classifier_head = api.model.detector.classifier_head source_classifier_head = api.model.detector.classifier_head
source_bbox_head = api.model.detector.bbox_head source_size_head = api.model.detector.size_head
source_backbone = api.model.detector.backbone source_backbone = api.model.detector.backbone
finetune_dir = tmp_path / "heads_only" finetune_dir = tmp_path / "heads_only"
@ -39,7 +39,7 @@ def test_user_can_finetune_only_heads(
backbone_params = list(detector.backbone.parameters()) backbone_params = list(detector.backbone.parameters())
classifier_params = list(detector.classifier_head.parameters()) classifier_params = list(detector.classifier_head.parameters())
bbox_params = list(detector.bbox_head.parameters()) bbox_params = list(detector.size_head.parameters())
assert backbone_params assert backbone_params
assert classifier_params assert classifier_params
@ -50,7 +50,7 @@ def test_user_can_finetune_only_heads(
assert finetuned_api is not api assert finetuned_api is not api
assert detector.backbone is source_backbone assert detector.backbone is source_backbone
assert detector.classifier_head is not source_classifier_head assert detector.classifier_head is not source_classifier_head
assert detector.bbox_head is not source_bbox_head assert detector.size_head is not source_size_head
assert list(finetune_dir.rglob("*.ckpt")) assert list(finetune_dir.rglob("*.ckpt"))

View File

@ -1,6 +1,7 @@
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
from typing import cast
from batdetect2.models import UNetBackbone from batdetect2.models import UNetBackbone
from batdetect2.models.backbones import UNetBackboneConfig from batdetect2.models.backbones import UNetBackboneConfig
@ -19,12 +20,15 @@ def dummy_spectrogram() -> torch.Tensor:
def test_build_detector_default(): def test_build_detector_default():
"""Test building the default detector without a config.""" """Test building the default detector without a config."""
num_classes = 5 num_classes = 5
model = build_detector(num_classes=num_classes) model = build_detector(
class_names=[f"class_{i}" for i in range(num_classes)],
dimension_names=["width", "height"],
)
assert isinstance(model, Detector) assert isinstance(model, Detector)
assert model.num_classes == num_classes assert model.num_classes == num_classes
assert isinstance(model.classifier_head, ClassifierHead) assert isinstance(model.classifier_head, ClassifierHead)
assert isinstance(model.bbox_head, BBoxHead) assert isinstance(model.size_head, BBoxHead)
def test_build_detector_custom_config(): def test_build_detector_custom_config():
@ -32,13 +36,19 @@ def test_build_detector_custom_config():
num_classes = 3 num_classes = 3
config = UNetBackboneConfig(in_channels=2, input_height=128) config = UNetBackboneConfig(in_channels=2, input_height=128)
model = build_detector(num_classes=num_classes, config=config) model = build_detector(
class_names=[f"class_{i}" for i in range(num_classes)],
dimension_names=["width", "height"],
config=config,
)
assert isinstance(model, Detector) assert isinstance(model, Detector)
assert model.backbone.input_height == 128 assert model.backbone.input_height == 128
assert isinstance(model.backbone.encoder, Encoder) backbone = cast(UNetBackbone, model.backbone)
assert model.backbone.encoder.in_channels == 2
assert isinstance(backbone.encoder, Encoder)
assert backbone.encoder.in_channels == 2
def test_build_detector_custom_size_channels(): def test_build_detector_custom_size_channels():
@ -47,8 +57,8 @@ def test_build_detector_custom_size_channels():
config = UNetBackboneConfig(in_channels=1, input_height=128) config = UNetBackboneConfig(in_channels=1, input_height=128)
model = build_detector( model = build_detector(
num_classes=num_classes, class_names=[f"class_{i}" for i in range(num_classes)],
num_sizes=num_sizes, dimension_names=[f"size_{i}" for i in range(num_sizes)],
config=config, config=config,
) )
@ -62,7 +72,11 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
num_classes = 4 num_classes = 4
# Build model matching the dummy input shape # Build model matching the dummy input shape
config = UNetBackboneConfig(in_channels=1, input_height=256) config = UNetBackboneConfig(in_channels=1, input_height=256)
model = build_detector(num_classes=num_classes, config=config) model = build_detector(
class_names=[f"class_{i}" for i in range(num_classes)],
dimension_names=["width", "height"],
config=config,
)
# Process the spectrogram through the model # Process the spectrogram through the model
# PyTorch expects shape (Batch, Channels, Height, Width) # PyTorch expects shape (Batch, Channels, Height, Width)
@ -132,7 +146,11 @@ def test_detector_forward_pass_with_preprocessor(sample_preprocessor):
config = UNetBackboneConfig( config = UNetBackboneConfig(
in_channels=spec.shape[1], input_height=spec.shape[2] in_channels=spec.shape[1], input_height=spec.shape[2]
) )
model = build_detector(num_classes=3, config=config) model = build_detector(
class_names=["class_0", "class_1", "class_2"],
dimension_names=["width", "height"],
config=config,
)
# Process # Process
output = model(spec) output = model(spec)

View File

@ -8,7 +8,7 @@ from soundevent import data
from batdetect2.train import TrainingConfig, run_train from batdetect2.train import TrainingConfig, run_train
from batdetect2.train.checkpoints import ( from batdetect2.train.checkpoints import (
DEFAULT_BUNDLED_CHECKPOINT, DEFAULT_CHECKPOINT,
get_bundled_checkpoint_names, get_bundled_checkpoint_names,
resolve_checkpoint_path, resolve_checkpoint_path,
) )
@ -145,7 +145,7 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged(
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None: def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
assert get_bundled_checkpoint_names() == ( assert get_bundled_checkpoint_names() == (
DEFAULT_BUNDLED_CHECKPOINT, DEFAULT_CHECKPOINT,
"batdetect2_uk_same", "batdetect2_uk_same",
) )
@ -153,11 +153,11 @@ def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None: def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None:
resolved = resolve_checkpoint_path() resolved = resolve_checkpoint_path()
assert resolved == resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT) assert resolved == resolve_checkpoint_path(DEFAULT_CHECKPOINT)
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None: def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
resolved = resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT) resolved = resolve_checkpoint_path(DEFAULT_CHECKPOINT)
assert resolved.name == "batdetect2_uk_same.ckpt" assert resolved.name == "batdetect2_uk_same.ckpt"
assert resolved.exists() assert resolved.exists()
@ -227,6 +227,6 @@ def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None: def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
with pytest.raises( with pytest.raises(
FileNotFoundError, FileNotFoundError,
match="bundled checkpoint alias", match="checkpoint alias",
): ):
resolve_checkpoint_path("missing.ckpt") resolve_checkpoint_path("missing.ckpt")

View File

@ -368,7 +368,7 @@ def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
assert ( assert (
rebuilt_detector.classifier_head is not source_detector.classifier_head rebuilt_detector.classifier_head is not source_detector.classifier_head
) )
assert rebuilt_detector.bbox_head is not source_detector.bbox_head assert rebuilt_detector.size_head is not source_detector.size_head
assert rebuilt_model.class_names == ["single_class"] assert rebuilt_model.class_names == ["single_class"]
assert rebuilt_model.dimension_names == ["width", "height"] assert rebuilt_model.dimension_names == ["width", "height"]