fix: rename detector heads and refresh bundled checkpoint

This commit is contained in:
mbsantiago 2026-05-06 12:50:53 +01:00
parent 2008c8000f
commit 5cc5767eff
12 changed files with 97 additions and 83 deletions

View File

@ -19,7 +19,8 @@ if TYPE_CHECKING:
LoggerConfig,
LoggingCallback,
)
from batdetect2.models import Model, ModelConfig
from batdetect2.models import ModelConfig
from batdetect2.models.types import ModelProtocol
from batdetect2.outputs import (
OutputFormatConfig,
OutputFormatterProtocol,
@ -88,7 +89,7 @@ class BatDetect2API:
evaluator: EvaluatorProtocol,
formatter: OutputFormatterProtocol,
output_transform: OutputTransformProtocol,
model: Model,
model: ModelProtocol,
):
"""Create a fully configured API instance.
@ -128,7 +129,7 @@ class BatDetect2API:
Default formatter used to save predictions.
output_transform : OutputTransformProtocol
Transform that converts model outputs into detections.
model : Model
model : ModelProtocol
Model instance.
"""
self.model_config = model_config
@ -1177,5 +1178,5 @@ class BatDetect2API:
parameter.requires_grad = True
if trainable in {"heads", "bbox_head"}:
for parameter in detector.bbox_head.parameters():
for parameter in detector.size_head.parameters():
parameter.requires_grad = True

View File

@ -47,24 +47,24 @@ __all__ = ["finetune_command"]
@click.option(
"--training-config",
type=click.Path(exists=True),
help="Path to a training config file.",
help="Path to training config file.",
)
@click.option(
"--audio-config",
type=click.Path(exists=True),
help="Path to an audio config file.",
help="Path to audio config file.",
)
@click.option(
"--logging-config",
type=click.Path(exists=True),
help="Path to a logging config file.",
help="Path to logging config file.",
)
@click.option(
"--trainable",
type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
default="heads",
show_default=True,
help="Which model parameters stay trainable during fine-tuning.",
help="Which model parameters remain trainable during fine-tuning.",
)
@click.option(
"--ckpt-dir",
@ -127,11 +127,7 @@ def finetune_command(
experiment_name: str | None = None,
run_name: str | None = None,
):
"""Fine-tune a checkpoint on a new target definition.
Use this command when you want to adapt an existing model to a new class
list or ROI mapping.
"""
"""Fine-tune a BatDetect2 checkpoint on a new target definition."""
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig
from batdetect2.data import load_dataset, load_dataset_config

View File

@ -62,7 +62,7 @@ from batdetect2.models.encoder import (
build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.models.types import DetectionModel
from batdetect2.models.types import DetectorProtocol, ModelProtocol
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.postprocess.types import (
ClipDetectionsTensor,
@ -149,7 +149,7 @@ class Model(torch.nn.Module):
Attributes
----------
detector : DetectionModel
detector : DetectorProtocol
The neural network that processes spectrograms and produces raw
detection, classification, and bounding-box outputs.
preprocessor : PreprocessorProtocol
@ -164,7 +164,7 @@ class Model(torch.nn.Module):
Size-dimension names corresponding to the model size outputs.
"""
detector: DetectionModel
detector: DetectorProtocol
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
class_names: list[str]
@ -173,7 +173,7 @@ class Model(torch.nn.Module):
def __init__(
self,
detector: DetectionModel,
detector: DetectorProtocol,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
class_names: list[str],
@ -224,7 +224,7 @@ def build_model(
dimension_names: list[str] | None = None,
preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None,
) -> Model:
) -> ModelProtocol:
"""Build a complete, ready-to-use BatDetect2 model.
Assembles a ``Model`` instance from a ``ModelConfig`` and optional
@ -256,7 +256,7 @@ def build_model(
Returns
-------
Model
ModelProtocol
A fully assembled ``Model`` instance ready for inference or
training.
"""
@ -285,8 +285,8 @@ def build_model(
config=config.postprocess,
)
detector = build_detector(
num_classes=len(class_names),
num_sizes=len(dimension_names),
class_names=class_names,
dimension_names=dimension_names,
config=config.architecture,
)
return Model(
@ -300,14 +300,14 @@ def build_model(
def build_model_with_new_targets(
model: Model,
model: ModelProtocol,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
) -> Model:
) -> ModelProtocol:
"""Build a new model with a different target set."""
detector = build_detector(
num_classes=len(targets.class_names),
num_sizes=len(roi_mapper.dimension_names),
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
backbone=model.detector.backbone,
)

View File

@ -6,8 +6,8 @@ bounding-box size regression.
Components
----------
- ``Detector`` the ``torch.nn.Module`` that wires together a backbone
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
- ``Detector`` - the ``torch.nn.Module`` that wires together a backbone
(``BackboneProtocol``) with a ``ClassifierHead`` and a ``BBoxHead`` to
produce a ``ModelOutput`` tuple from an input spectrogram.
- ``build_detector`` factory function that builds a ready-to-use
``Detector`` from a backbone configuration and a target class count.
@ -18,15 +18,16 @@ preprocessing and output postprocessing are handled by
"""
import torch
from loguru import logger
from batdetect2.models.backbones import (
BackboneConfig,
UNetBackboneConfig,
build_backbone,
)
from batdetect2.models.backbones import BackboneConfig, build_backbone
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
from batdetect2.models.types import (
BackboneProtocol,
ClassifierHeadProtocol,
DetectorProtocol,
ModelOutput,
SizeHeadProtocol,
)
__all__ = [
"Detector",
@ -34,7 +35,7 @@ __all__ = [
]
class Detector(DetectionModel):
class Detector(torch.nn.Module):
"""Complete BatDetect2 detection and classification model.
Combines a backbone feature extractor with two prediction heads:
@ -51,7 +52,7 @@ class Detector(DetectionModel):
Attributes
----------
backbone : BackboneModel
backbone : BackboneProtocol
The feature extraction backbone.
num_classes : int
Number of target classes (inferred from the classifier head).
@ -61,13 +62,13 @@ class Detector(DetectionModel):
Produces duration and bandwidth predictions from backbone features.
"""
backbone: BackboneModel
backbone: BackboneProtocol
def __init__(
self,
backbone: BackboneModel,
classifier_head: ClassifierHead,
bbox_head: BBoxHead,
backbone: BackboneProtocol,
classifier_head: ClassifierHeadProtocol,
size_head: SizeHeadProtocol,
):
"""Initialise the Detector model.
@ -76,7 +77,7 @@ class Detector(DetectionModel):
Parameters
----------
backbone : BackboneModel
backbone : BackboneProtocol
An initialised backbone module (e.g. built by
``build_backbone``).
classifier_head : ClassifierHead
@ -90,7 +91,7 @@ class Detector(DetectionModel):
self.backbone = backbone
self.num_classes = classifier_head.num_classes
self.classifier_head = classifier_head
self.bbox_head = bbox_head
self.size_head = size_head
def forward(self, spec: torch.Tensor) -> ModelOutput:
"""Run the complete detection model on an input spectrogram.
@ -125,7 +126,7 @@ class Detector(DetectionModel):
features = self.backbone(spec)
classification = self.classifier_head(features)
detection = classification.sum(dim=1, keepdim=True)
size_preds = self.bbox_head(features)
size_preds = self.size_head(features)
return ModelOutput(
detection_probs=detection,
size_preds=size_preds,
@ -135,11 +136,11 @@ class Detector(DetectionModel):
def build_detector(
num_classes: int,
num_sizes: int = 2,
class_names: list[str],
dimension_names: list[str],
config: BackboneConfig | None = None,
backbone: BackboneModel | None = None,
) -> DetectionModel:
backbone: BackboneProtocol | None = None,
) -> DetectorProtocol:
"""Build a complete BatDetect2 detection model.
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
@ -158,7 +159,7 @@ def build_detector(
Returns
-------
DetectionModel
DetectorProtocol
An initialised ``Detector`` instance ready for training or
inference.
@ -168,24 +169,18 @@ def build_detector(
If ``num_classes`` is not positive, or if the backbone
configuration is invalid.
"""
if backbone is None:
config = config or UNetBackboneConfig()
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(), # type: ignore
)
backbone = build_backbone(config=config)
backbone = backbone or build_backbone(config=config)
classifier_head = ClassifierHead(
num_classes=num_classes,
class_names=class_names,
in_channels=backbone.out_channels,
)
bbox_head = BBoxHead(
in_channels=backbone.out_channels,
num_sizes=num_sizes,
dimension_names=dimension_names,
)
return Detector(
backbone=backbone,
classifier_head=classifier_head,
bbox_head=bbox_head,
size_head=bbox_head,
)

View File

@ -54,12 +54,14 @@ class ClassifierHead(nn.Module):
1×1 convolution with ``num_classes + 1`` output channels.
"""
def __init__(self, num_classes: int, in_channels: int):
def __init__(self, class_names: list[str], in_channels: int):
"""Initialise the ClassifierHead."""
super().__init__()
self.num_classes = num_classes
self.class_names = class_names
self.num_classes = len(class_names)
self.in_channels = in_channels
self.classifier = nn.Conv2d(
self.in_channels,
self.num_classes + 1,
@ -165,11 +167,12 @@ class BBoxHead(nn.Module):
1×1 convolution with 2 output channels (duration, bandwidth).
"""
def __init__(self, in_channels: int, num_sizes: int = 2):
def __init__(self, dimension_names: list[str], in_channels: int):
"""Initialise the BBoxHead."""
super().__init__()
self.in_channels = in_channels
self.num_sizes = num_sizes
self.dimension_names = dimension_names
self.num_sizes = len(dimension_names)
self.bbox = nn.Conv2d(
in_channels=self.in_channels,

View File

@ -34,6 +34,7 @@ class CheckpointConfig(BaseConfig):
monitor: str | None = None
mode: str = "max"
save_top_k: int = 1
# Save distributable inference checkpoints by default.
save_weights_only: bool = True
filename: str | None = None
save_last: bool | Literal["link"] = "link"

View File

@ -299,10 +299,10 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
value,
)
for key, value in source_detector.bbox_head.state_dict().items():
assert key in detector.bbox_head.state_dict()
for key, value in source_detector.size_head.state_dict().items():
assert key in detector.size_head.state_dict()
torch.testing.assert_close(
detector.bbox_head.state_dict()[key],
detector.size_head.state_dict()[key],
value,
)

View File

@ -18,7 +18,7 @@ def test_user_can_finetune_only_heads(
api = BatDetect2API.from_config()
source_classifier_head = api.model.detector.classifier_head
source_bbox_head = api.model.detector.bbox_head
source_size_head = api.model.detector.size_head
source_backbone = api.model.detector.backbone
finetune_dir = tmp_path / "heads_only"
@ -39,7 +39,7 @@ def test_user_can_finetune_only_heads(
backbone_params = list(detector.backbone.parameters())
classifier_params = list(detector.classifier_head.parameters())
bbox_params = list(detector.bbox_head.parameters())
bbox_params = list(detector.size_head.parameters())
assert backbone_params
assert classifier_params
@ -50,7 +50,7 @@ def test_user_can_finetune_only_heads(
assert finetuned_api is not api
assert detector.backbone is source_backbone
assert detector.classifier_head is not source_classifier_head
assert detector.bbox_head is not source_bbox_head
assert detector.size_head is not source_size_head
assert list(finetune_dir.rglob("*.ckpt"))

View File

@ -1,6 +1,7 @@
import numpy as np
import pytest
import torch
from typing import cast
from batdetect2.models import UNetBackbone
from batdetect2.models.backbones import UNetBackboneConfig
@ -19,12 +20,15 @@ def dummy_spectrogram() -> torch.Tensor:
def test_build_detector_default():
"""Test building the default detector without a config."""
num_classes = 5
model = build_detector(num_classes=num_classes)
model = build_detector(
class_names=[f"class_{i}" for i in range(num_classes)],
dimension_names=["width", "height"],
)
assert isinstance(model, Detector)
assert model.num_classes == num_classes
assert isinstance(model.classifier_head, ClassifierHead)
assert isinstance(model.bbox_head, BBoxHead)
assert isinstance(model.size_head, BBoxHead)
def test_build_detector_custom_config():
@ -32,13 +36,19 @@ def test_build_detector_custom_config():
num_classes = 3
config = UNetBackboneConfig(in_channels=2, input_height=128)
model = build_detector(num_classes=num_classes, config=config)
model = build_detector(
class_names=[f"class_{i}" for i in range(num_classes)],
dimension_names=["width", "height"],
config=config,
)
assert isinstance(model, Detector)
assert model.backbone.input_height == 128
assert isinstance(model.backbone.encoder, Encoder)
assert model.backbone.encoder.in_channels == 2
backbone = cast(UNetBackbone, model.backbone)
assert isinstance(backbone.encoder, Encoder)
assert backbone.encoder.in_channels == 2
def test_build_detector_custom_size_channels():
@ -47,8 +57,8 @@ def test_build_detector_custom_size_channels():
config = UNetBackboneConfig(in_channels=1, input_height=128)
model = build_detector(
num_classes=num_classes,
num_sizes=num_sizes,
class_names=[f"class_{i}" for i in range(num_classes)],
dimension_names=[f"size_{i}" for i in range(num_sizes)],
config=config,
)
@ -62,7 +72,11 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
num_classes = 4
# Build model matching the dummy input shape
config = UNetBackboneConfig(in_channels=1, input_height=256)
model = build_detector(num_classes=num_classes, config=config)
model = build_detector(
class_names=[f"class_{i}" for i in range(num_classes)],
dimension_names=["width", "height"],
config=config,
)
# Process the spectrogram through the model
# PyTorch expects shape (Batch, Channels, Height, Width)
@ -132,7 +146,11 @@ def test_detector_forward_pass_with_preprocessor(sample_preprocessor):
config = UNetBackboneConfig(
in_channels=spec.shape[1], input_height=spec.shape[2]
)
model = build_detector(num_classes=3, config=config)
model = build_detector(
class_names=["class_0", "class_1", "class_2"],
dimension_names=["width", "height"],
config=config,
)
# Process
output = model(spec)

View File

@ -8,7 +8,7 @@ from soundevent import data
from batdetect2.train import TrainingConfig, run_train
from batdetect2.train.checkpoints import (
DEFAULT_BUNDLED_CHECKPOINT,
DEFAULT_CHECKPOINT,
get_bundled_checkpoint_names,
resolve_checkpoint_path,
)
@ -145,7 +145,7 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged(
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
assert get_bundled_checkpoint_names() == (
DEFAULT_BUNDLED_CHECKPOINT,
DEFAULT_CHECKPOINT,
"batdetect2_uk_same",
)
@ -153,11 +153,11 @@ def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None:
resolved = resolve_checkpoint_path()
assert resolved == resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
assert resolved == resolve_checkpoint_path(DEFAULT_CHECKPOINT)
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
resolved = resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
resolved = resolve_checkpoint_path(DEFAULT_CHECKPOINT)
assert resolved.name == "batdetect2_uk_same.ckpt"
assert resolved.exists()
@ -227,6 +227,6 @@ def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
with pytest.raises(
FileNotFoundError,
match="bundled checkpoint alias",
match="checkpoint alias",
):
resolve_checkpoint_path("missing.ckpt")

View File

@ -368,7 +368,7 @@ def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
assert (
rebuilt_detector.classifier_head is not source_detector.classifier_head
)
assert rebuilt_detector.bbox_head is not source_detector.bbox_head
assert rebuilt_detector.size_head is not source_detector.size_head
assert rebuilt_model.class_names == ["single_class"]
assert rebuilt_model.dimension_names == ["width", "height"]