mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
fix: rename detector heads and refresh bundled checkpoint
This commit is contained in:
parent
2008c8000f
commit
5cc5767eff
@ -19,7 +19,8 @@ if TYPE_CHECKING:
|
||||
LoggerConfig,
|
||||
LoggingCallback,
|
||||
)
|
||||
from batdetect2.models import Model, ModelConfig
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.models.types import ModelProtocol
|
||||
from batdetect2.outputs import (
|
||||
OutputFormatConfig,
|
||||
OutputFormatterProtocol,
|
||||
@ -88,7 +89,7 @@ class BatDetect2API:
|
||||
evaluator: EvaluatorProtocol,
|
||||
formatter: OutputFormatterProtocol,
|
||||
output_transform: OutputTransformProtocol,
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
):
|
||||
"""Create a fully configured API instance.
|
||||
|
||||
@ -128,7 +129,7 @@ class BatDetect2API:
|
||||
Default formatter used to save predictions.
|
||||
output_transform : OutputTransformProtocol
|
||||
Transform that converts model outputs into detections.
|
||||
model : Model
|
||||
model : ModelProtocol
|
||||
Model instance.
|
||||
"""
|
||||
self.model_config = model_config
|
||||
@ -1177,5 +1178,5 @@ class BatDetect2API:
|
||||
parameter.requires_grad = True
|
||||
|
||||
if trainable in {"heads", "bbox_head"}:
|
||||
for parameter in detector.bbox_head.parameters():
|
||||
for parameter in detector.size_head.parameters():
|
||||
parameter.requires_grad = True
|
||||
|
||||
@ -47,24 +47,24 @@ __all__ = ["finetune_command"]
|
||||
@click.option(
|
||||
"--training-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to a training config file.",
|
||||
help="Path to training config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--audio-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to an audio config file.",
|
||||
help="Path to audio config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--logging-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to a logging config file.",
|
||||
help="Path to logging config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--trainable",
|
||||
type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
|
||||
default="heads",
|
||||
show_default=True,
|
||||
help="Which model parameters stay trainable during fine-tuning.",
|
||||
help="Which model parameters remain trainable during fine-tuning.",
|
||||
)
|
||||
@click.option(
|
||||
"--ckpt-dir",
|
||||
@ -127,11 +127,7 @@ def finetune_command(
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
):
|
||||
"""Fine-tune a checkpoint on a new target definition.
|
||||
|
||||
Use this command when you want to adapt an existing model to a new class
|
||||
list or ROI mapping.
|
||||
"""
|
||||
"""Fine-tune a BatDetect2 checkpoint on a new target definition."""
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.data import load_dataset, load_dataset_config
|
||||
|
||||
@ -62,7 +62,7 @@ from batdetect2.models.encoder import (
|
||||
build_encoder,
|
||||
)
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||
from batdetect2.models.types import DetectionModel
|
||||
from batdetect2.models.types import DetectorProtocol, ModelProtocol
|
||||
from batdetect2.postprocess.config import PostprocessConfig
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetectionsTensor,
|
||||
@ -149,7 +149,7 @@ class Model(torch.nn.Module):
|
||||
|
||||
Attributes
|
||||
----------
|
||||
detector : DetectionModel
|
||||
detector : DetectorProtocol
|
||||
The neural network that processes spectrograms and produces raw
|
||||
detection, classification, and bounding-box outputs.
|
||||
preprocessor : PreprocessorProtocol
|
||||
@ -164,7 +164,7 @@ class Model(torch.nn.Module):
|
||||
Size-dimension names corresponding to the model size outputs.
|
||||
"""
|
||||
|
||||
detector: DetectionModel
|
||||
detector: DetectorProtocol
|
||||
preprocessor: PreprocessorProtocol
|
||||
postprocessor: PostprocessorProtocol
|
||||
class_names: list[str]
|
||||
@ -173,7 +173,7 @@ class Model(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
detector: DetectionModel,
|
||||
detector: DetectorProtocol,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
class_names: list[str],
|
||||
@ -224,7 +224,7 @@ def build_model(
|
||||
dimension_names: list[str] | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
postprocessor: PostprocessorProtocol | None = None,
|
||||
) -> Model:
|
||||
) -> ModelProtocol:
|
||||
"""Build a complete, ready-to-use BatDetect2 model.
|
||||
|
||||
Assembles a ``Model`` instance from a ``ModelConfig`` and optional
|
||||
@ -256,7 +256,7 @@ def build_model(
|
||||
|
||||
Returns
|
||||
-------
|
||||
Model
|
||||
ModelProtocol
|
||||
A fully assembled ``Model`` instance ready for inference or
|
||||
training.
|
||||
"""
|
||||
@ -285,8 +285,8 @@ def build_model(
|
||||
config=config.postprocess,
|
||||
)
|
||||
detector = build_detector(
|
||||
num_classes=len(class_names),
|
||||
num_sizes=len(dimension_names),
|
||||
class_names=class_names,
|
||||
dimension_names=dimension_names,
|
||||
config=config.architecture,
|
||||
)
|
||||
return Model(
|
||||
@ -300,14 +300,14 @@ def build_model(
|
||||
|
||||
|
||||
def build_model_with_new_targets(
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
) -> Model:
|
||||
) -> ModelProtocol:
|
||||
"""Build a new model with a different target set."""
|
||||
detector = build_detector(
|
||||
num_classes=len(targets.class_names),
|
||||
num_sizes=len(roi_mapper.dimension_names),
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
backbone=model.detector.backbone,
|
||||
)
|
||||
|
||||
|
||||
Binary file not shown.
@ -6,8 +6,8 @@ bounding-box size regression.
|
||||
|
||||
Components
|
||||
----------
|
||||
- ``Detector`` – the ``torch.nn.Module`` that wires together a backbone
|
||||
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
|
||||
- ``Detector`` - the ``torch.nn.Module`` that wires together a backbone
|
||||
(``BackboneProtocol``) with a ``ClassifierHead`` and a ``BBoxHead`` to
|
||||
produce a ``ModelOutput`` tuple from an input spectrogram.
|
||||
- ``build_detector`` – factory function that builds a ready-to-use
|
||||
``Detector`` from a backbone configuration and a target class count.
|
||||
@ -18,15 +18,16 @@ preprocessing and output postprocessing are handled by
|
||||
"""
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.models.backbones import (
|
||||
BackboneConfig,
|
||||
UNetBackboneConfig,
|
||||
build_backbone,
|
||||
)
|
||||
from batdetect2.models.backbones import BackboneConfig, build_backbone
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
||||
from batdetect2.models.types import (
|
||||
BackboneProtocol,
|
||||
ClassifierHeadProtocol,
|
||||
DetectorProtocol,
|
||||
ModelOutput,
|
||||
SizeHeadProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Detector",
|
||||
@ -34,7 +35,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
class Detector(DetectionModel):
|
||||
class Detector(torch.nn.Module):
|
||||
"""Complete BatDetect2 detection and classification model.
|
||||
|
||||
Combines a backbone feature extractor with two prediction heads:
|
||||
@ -51,7 +52,7 @@ class Detector(DetectionModel):
|
||||
|
||||
Attributes
|
||||
----------
|
||||
backbone : BackboneModel
|
||||
backbone : BackboneProtocol
|
||||
The feature extraction backbone.
|
||||
num_classes : int
|
||||
Number of target classes (inferred from the classifier head).
|
||||
@ -61,13 +62,13 @@ class Detector(DetectionModel):
|
||||
Produces duration and bandwidth predictions from backbone features.
|
||||
"""
|
||||
|
||||
backbone: BackboneModel
|
||||
backbone: BackboneProtocol
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backbone: BackboneModel,
|
||||
classifier_head: ClassifierHead,
|
||||
bbox_head: BBoxHead,
|
||||
backbone: BackboneProtocol,
|
||||
classifier_head: ClassifierHeadProtocol,
|
||||
size_head: SizeHeadProtocol,
|
||||
):
|
||||
"""Initialise the Detector model.
|
||||
|
||||
@ -76,7 +77,7 @@ class Detector(DetectionModel):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backbone : BackboneModel
|
||||
backbone : BackboneProtocol
|
||||
An initialised backbone module (e.g. built by
|
||||
``build_backbone``).
|
||||
classifier_head : ClassifierHead
|
||||
@ -90,7 +91,7 @@ class Detector(DetectionModel):
|
||||
self.backbone = backbone
|
||||
self.num_classes = classifier_head.num_classes
|
||||
self.classifier_head = classifier_head
|
||||
self.bbox_head = bbox_head
|
||||
self.size_head = size_head
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
"""Run the complete detection model on an input spectrogram.
|
||||
@ -125,7 +126,7 @@ class Detector(DetectionModel):
|
||||
features = self.backbone(spec)
|
||||
classification = self.classifier_head(features)
|
||||
detection = classification.sum(dim=1, keepdim=True)
|
||||
size_preds = self.bbox_head(features)
|
||||
size_preds = self.size_head(features)
|
||||
return ModelOutput(
|
||||
detection_probs=detection,
|
||||
size_preds=size_preds,
|
||||
@ -135,11 +136,11 @@ class Detector(DetectionModel):
|
||||
|
||||
|
||||
def build_detector(
|
||||
num_classes: int,
|
||||
num_sizes: int = 2,
|
||||
class_names: list[str],
|
||||
dimension_names: list[str],
|
||||
config: BackboneConfig | None = None,
|
||||
backbone: BackboneModel | None = None,
|
||||
) -> DetectionModel:
|
||||
backbone: BackboneProtocol | None = None,
|
||||
) -> DetectorProtocol:
|
||||
"""Build a complete BatDetect2 detection model.
|
||||
|
||||
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
|
||||
@ -158,7 +159,7 @@ def build_detector(
|
||||
|
||||
Returns
|
||||
-------
|
||||
DetectionModel
|
||||
DetectorProtocol
|
||||
An initialised ``Detector`` instance ready for training or
|
||||
inference.
|
||||
|
||||
@ -168,24 +169,18 @@ def build_detector(
|
||||
If ``num_classes`` is not positive, or if the backbone
|
||||
configuration is invalid.
|
||||
"""
|
||||
if backbone is None:
|
||||
config = config or UNetBackboneConfig()
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building model with config: \n{}",
|
||||
lambda: config.to_yaml_string(), # type: ignore
|
||||
)
|
||||
backbone = build_backbone(config=config)
|
||||
backbone = backbone or build_backbone(config=config)
|
||||
|
||||
classifier_head = ClassifierHead(
|
||||
num_classes=num_classes,
|
||||
class_names=class_names,
|
||||
in_channels=backbone.out_channels,
|
||||
)
|
||||
bbox_head = BBoxHead(
|
||||
in_channels=backbone.out_channels,
|
||||
num_sizes=num_sizes,
|
||||
dimension_names=dimension_names,
|
||||
)
|
||||
return Detector(
|
||||
backbone=backbone,
|
||||
classifier_head=classifier_head,
|
||||
bbox_head=bbox_head,
|
||||
size_head=bbox_head,
|
||||
)
|
||||
|
||||
@ -54,12 +54,14 @@ class ClassifierHead(nn.Module):
|
||||
1×1 convolution with ``num_classes + 1`` output channels.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes: int, in_channels: int):
|
||||
def __init__(self, class_names: list[str], in_channels: int):
|
||||
"""Initialise the ClassifierHead."""
|
||||
super().__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.class_names = class_names
|
||||
self.num_classes = len(class_names)
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.classifier = nn.Conv2d(
|
||||
self.in_channels,
|
||||
self.num_classes + 1,
|
||||
@ -165,11 +167,12 @@ class BBoxHead(nn.Module):
|
||||
1×1 convolution with 2 output channels (duration, bandwidth).
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, num_sizes: int = 2):
|
||||
def __init__(self, dimension_names: list[str], in_channels: int):
|
||||
"""Initialise the BBoxHead."""
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.num_sizes = num_sizes
|
||||
self.dimension_names = dimension_names
|
||||
self.num_sizes = len(dimension_names)
|
||||
|
||||
self.bbox = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
|
||||
@ -34,6 +34,7 @@ class CheckpointConfig(BaseConfig):
|
||||
monitor: str | None = None
|
||||
mode: str = "max"
|
||||
save_top_k: int = 1
|
||||
# Save distributable inference checkpoints by default.
|
||||
save_weights_only: bool = True
|
||||
filename: str | None = None
|
||||
save_last: bool | Literal["link"] = "link"
|
||||
|
||||
@ -299,10 +299,10 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
||||
value,
|
||||
)
|
||||
|
||||
for key, value in source_detector.bbox_head.state_dict().items():
|
||||
assert key in detector.bbox_head.state_dict()
|
||||
for key, value in source_detector.size_head.state_dict().items():
|
||||
assert key in detector.size_head.state_dict()
|
||||
torch.testing.assert_close(
|
||||
detector.bbox_head.state_dict()[key],
|
||||
detector.size_head.state_dict()[key],
|
||||
value,
|
||||
)
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ def test_user_can_finetune_only_heads(
|
||||
|
||||
api = BatDetect2API.from_config()
|
||||
source_classifier_head = api.model.detector.classifier_head
|
||||
source_bbox_head = api.model.detector.bbox_head
|
||||
source_size_head = api.model.detector.size_head
|
||||
source_backbone = api.model.detector.backbone
|
||||
finetune_dir = tmp_path / "heads_only"
|
||||
|
||||
@ -39,7 +39,7 @@ def test_user_can_finetune_only_heads(
|
||||
|
||||
backbone_params = list(detector.backbone.parameters())
|
||||
classifier_params = list(detector.classifier_head.parameters())
|
||||
bbox_params = list(detector.bbox_head.parameters())
|
||||
bbox_params = list(detector.size_head.parameters())
|
||||
|
||||
assert backbone_params
|
||||
assert classifier_params
|
||||
@ -50,7 +50,7 @@ def test_user_can_finetune_only_heads(
|
||||
assert finetuned_api is not api
|
||||
assert detector.backbone is source_backbone
|
||||
assert detector.classifier_head is not source_classifier_head
|
||||
assert detector.bbox_head is not source_bbox_head
|
||||
assert detector.size_head is not source_size_head
|
||||
assert list(finetune_dir.rglob("*.ckpt"))
|
||||
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from typing import cast
|
||||
|
||||
from batdetect2.models import UNetBackbone
|
||||
from batdetect2.models.backbones import UNetBackboneConfig
|
||||
@ -19,12 +20,15 @@ def dummy_spectrogram() -> torch.Tensor:
|
||||
def test_build_detector_default():
|
||||
"""Test building the default detector without a config."""
|
||||
num_classes = 5
|
||||
model = build_detector(num_classes=num_classes)
|
||||
model = build_detector(
|
||||
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||
dimension_names=["width", "height"],
|
||||
)
|
||||
|
||||
assert isinstance(model, Detector)
|
||||
assert model.num_classes == num_classes
|
||||
assert isinstance(model.classifier_head, ClassifierHead)
|
||||
assert isinstance(model.bbox_head, BBoxHead)
|
||||
assert isinstance(model.size_head, BBoxHead)
|
||||
|
||||
|
||||
def test_build_detector_custom_config():
|
||||
@ -32,13 +36,19 @@ def test_build_detector_custom_config():
|
||||
num_classes = 3
|
||||
config = UNetBackboneConfig(in_channels=2, input_height=128)
|
||||
|
||||
model = build_detector(num_classes=num_classes, config=config)
|
||||
model = build_detector(
|
||||
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||
dimension_names=["width", "height"],
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert isinstance(model, Detector)
|
||||
assert model.backbone.input_height == 128
|
||||
|
||||
assert isinstance(model.backbone.encoder, Encoder)
|
||||
assert model.backbone.encoder.in_channels == 2
|
||||
backbone = cast(UNetBackbone, model.backbone)
|
||||
|
||||
assert isinstance(backbone.encoder, Encoder)
|
||||
assert backbone.encoder.in_channels == 2
|
||||
|
||||
|
||||
def test_build_detector_custom_size_channels():
|
||||
@ -47,8 +57,8 @@ def test_build_detector_custom_size_channels():
|
||||
config = UNetBackboneConfig(in_channels=1, input_height=128)
|
||||
|
||||
model = build_detector(
|
||||
num_classes=num_classes,
|
||||
num_sizes=num_sizes,
|
||||
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||
dimension_names=[f"size_{i}" for i in range(num_sizes)],
|
||||
config=config,
|
||||
)
|
||||
|
||||
@ -62,7 +72,11 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
|
||||
num_classes = 4
|
||||
# Build model matching the dummy input shape
|
||||
config = UNetBackboneConfig(in_channels=1, input_height=256)
|
||||
model = build_detector(num_classes=num_classes, config=config)
|
||||
model = build_detector(
|
||||
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||
dimension_names=["width", "height"],
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Process the spectrogram through the model
|
||||
# PyTorch expects shape (Batch, Channels, Height, Width)
|
||||
@ -132,7 +146,11 @@ def test_detector_forward_pass_with_preprocessor(sample_preprocessor):
|
||||
config = UNetBackboneConfig(
|
||||
in_channels=spec.shape[1], input_height=spec.shape[2]
|
||||
)
|
||||
model = build_detector(num_classes=3, config=config)
|
||||
model = build_detector(
|
||||
class_names=["class_0", "class_1", "class_2"],
|
||||
dimension_names=["width", "height"],
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Process
|
||||
output = model(spec)
|
||||
|
||||
@ -8,7 +8,7 @@ from soundevent import data
|
||||
|
||||
from batdetect2.train import TrainingConfig, run_train
|
||||
from batdetect2.train.checkpoints import (
|
||||
DEFAULT_BUNDLED_CHECKPOINT,
|
||||
DEFAULT_CHECKPOINT,
|
||||
get_bundled_checkpoint_names,
|
||||
resolve_checkpoint_path,
|
||||
)
|
||||
@ -145,7 +145,7 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged(
|
||||
|
||||
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
|
||||
assert get_bundled_checkpoint_names() == (
|
||||
DEFAULT_BUNDLED_CHECKPOINT,
|
||||
DEFAULT_CHECKPOINT,
|
||||
"batdetect2_uk_same",
|
||||
)
|
||||
|
||||
@ -153,11 +153,11 @@ def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
|
||||
def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None:
|
||||
resolved = resolve_checkpoint_path()
|
||||
|
||||
assert resolved == resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
|
||||
assert resolved == resolve_checkpoint_path(DEFAULT_CHECKPOINT)
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
|
||||
resolved = resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
|
||||
resolved = resolve_checkpoint_path(DEFAULT_CHECKPOINT)
|
||||
|
||||
assert resolved.name == "batdetect2_uk_same.ckpt"
|
||||
assert resolved.exists()
|
||||
@ -227,6 +227,6 @@ def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
|
||||
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
|
||||
with pytest.raises(
|
||||
FileNotFoundError,
|
||||
match="bundled checkpoint alias",
|
||||
match="checkpoint alias",
|
||||
):
|
||||
resolve_checkpoint_path("missing.ckpt")
|
||||
|
||||
@ -368,7 +368,7 @@ def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
|
||||
assert (
|
||||
rebuilt_detector.classifier_head is not source_detector.classifier_head
|
||||
)
|
||||
assert rebuilt_detector.bbox_head is not source_detector.bbox_head
|
||||
assert rebuilt_detector.size_head is not source_detector.size_head
|
||||
assert rebuilt_model.class_names == ["single_class"]
|
||||
assert rebuilt_model.dimension_names == ["width", "height"]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user