mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
fix: rename detector heads and refresh bundled checkpoint
This commit is contained in:
parent
2008c8000f
commit
5cc5767eff
@ -19,7 +19,8 @@ if TYPE_CHECKING:
|
|||||||
LoggerConfig,
|
LoggerConfig,
|
||||||
LoggingCallback,
|
LoggingCallback,
|
||||||
)
|
)
|
||||||
from batdetect2.models import Model, ModelConfig
|
from batdetect2.models import ModelConfig
|
||||||
|
from batdetect2.models.types import ModelProtocol
|
||||||
from batdetect2.outputs import (
|
from batdetect2.outputs import (
|
||||||
OutputFormatConfig,
|
OutputFormatConfig,
|
||||||
OutputFormatterProtocol,
|
OutputFormatterProtocol,
|
||||||
@ -88,7 +89,7 @@ class BatDetect2API:
|
|||||||
evaluator: EvaluatorProtocol,
|
evaluator: EvaluatorProtocol,
|
||||||
formatter: OutputFormatterProtocol,
|
formatter: OutputFormatterProtocol,
|
||||||
output_transform: OutputTransformProtocol,
|
output_transform: OutputTransformProtocol,
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
):
|
):
|
||||||
"""Create a fully configured API instance.
|
"""Create a fully configured API instance.
|
||||||
|
|
||||||
@ -128,7 +129,7 @@ class BatDetect2API:
|
|||||||
Default formatter used to save predictions.
|
Default formatter used to save predictions.
|
||||||
output_transform : OutputTransformProtocol
|
output_transform : OutputTransformProtocol
|
||||||
Transform that converts model outputs into detections.
|
Transform that converts model outputs into detections.
|
||||||
model : Model
|
model : ModelProtocol
|
||||||
Model instance.
|
Model instance.
|
||||||
"""
|
"""
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
@ -1177,5 +1178,5 @@ class BatDetect2API:
|
|||||||
parameter.requires_grad = True
|
parameter.requires_grad = True
|
||||||
|
|
||||||
if trainable in {"heads", "bbox_head"}:
|
if trainable in {"heads", "bbox_head"}:
|
||||||
for parameter in detector.bbox_head.parameters():
|
for parameter in detector.size_head.parameters():
|
||||||
parameter.requires_grad = True
|
parameter.requires_grad = True
|
||||||
|
|||||||
@ -47,24 +47,24 @@ __all__ = ["finetune_command"]
|
|||||||
@click.option(
|
@click.option(
|
||||||
"--training-config",
|
"--training-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to a training config file.",
|
help="Path to training config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--audio-config",
|
"--audio-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to an audio config file.",
|
help="Path to audio config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--logging-config",
|
"--logging-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to a logging config file.",
|
help="Path to logging config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--trainable",
|
"--trainable",
|
||||||
type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
|
type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
|
||||||
default="heads",
|
default="heads",
|
||||||
show_default=True,
|
show_default=True,
|
||||||
help="Which model parameters stay trainable during fine-tuning.",
|
help="Which model parameters remain trainable during fine-tuning.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--ckpt-dir",
|
"--ckpt-dir",
|
||||||
@ -127,11 +127,7 @@ def finetune_command(
|
|||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
):
|
):
|
||||||
"""Fine-tune a checkpoint on a new target definition.
|
"""Fine-tune a BatDetect2 checkpoint on a new target definition."""
|
||||||
|
|
||||||
Use this command when you want to adapt an existing model to a new class
|
|
||||||
list or ROI mapping.
|
|
||||||
"""
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.audio import AudioConfig
|
from batdetect2.audio import AudioConfig
|
||||||
from batdetect2.data import load_dataset, load_dataset_config
|
from batdetect2.data import load_dataset, load_dataset_config
|
||||||
|
|||||||
@ -62,7 +62,7 @@ from batdetect2.models.encoder import (
|
|||||||
build_encoder,
|
build_encoder,
|
||||||
)
|
)
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||||
from batdetect2.models.types import DetectionModel
|
from batdetect2.models.types import DetectorProtocol, ModelProtocol
|
||||||
from batdetect2.postprocess.config import PostprocessConfig
|
from batdetect2.postprocess.config import PostprocessConfig
|
||||||
from batdetect2.postprocess.types import (
|
from batdetect2.postprocess.types import (
|
||||||
ClipDetectionsTensor,
|
ClipDetectionsTensor,
|
||||||
@ -149,7 +149,7 @@ class Model(torch.nn.Module):
|
|||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
detector : DetectionModel
|
detector : DetectorProtocol
|
||||||
The neural network that processes spectrograms and produces raw
|
The neural network that processes spectrograms and produces raw
|
||||||
detection, classification, and bounding-box outputs.
|
detection, classification, and bounding-box outputs.
|
||||||
preprocessor : PreprocessorProtocol
|
preprocessor : PreprocessorProtocol
|
||||||
@ -164,7 +164,7 @@ class Model(torch.nn.Module):
|
|||||||
Size-dimension names corresponding to the model size outputs.
|
Size-dimension names corresponding to the model size outputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
detector: DetectionModel
|
detector: DetectorProtocol
|
||||||
preprocessor: PreprocessorProtocol
|
preprocessor: PreprocessorProtocol
|
||||||
postprocessor: PostprocessorProtocol
|
postprocessor: PostprocessorProtocol
|
||||||
class_names: list[str]
|
class_names: list[str]
|
||||||
@ -173,7 +173,7 @@ class Model(torch.nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
detector: DetectionModel,
|
detector: DetectorProtocol,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
postprocessor: PostprocessorProtocol,
|
postprocessor: PostprocessorProtocol,
|
||||||
class_names: list[str],
|
class_names: list[str],
|
||||||
@ -224,7 +224,7 @@ def build_model(
|
|||||||
dimension_names: list[str] | None = None,
|
dimension_names: list[str] | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
postprocessor: PostprocessorProtocol | None = None,
|
postprocessor: PostprocessorProtocol | None = None,
|
||||||
) -> Model:
|
) -> ModelProtocol:
|
||||||
"""Build a complete, ready-to-use BatDetect2 model.
|
"""Build a complete, ready-to-use BatDetect2 model.
|
||||||
|
|
||||||
Assembles a ``Model`` instance from a ``ModelConfig`` and optional
|
Assembles a ``Model`` instance from a ``ModelConfig`` and optional
|
||||||
@ -256,7 +256,7 @@ def build_model(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Model
|
ModelProtocol
|
||||||
A fully assembled ``Model`` instance ready for inference or
|
A fully assembled ``Model`` instance ready for inference or
|
||||||
training.
|
training.
|
||||||
"""
|
"""
|
||||||
@ -285,8 +285,8 @@ def build_model(
|
|||||||
config=config.postprocess,
|
config=config.postprocess,
|
||||||
)
|
)
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(class_names),
|
class_names=class_names,
|
||||||
num_sizes=len(dimension_names),
|
dimension_names=dimension_names,
|
||||||
config=config.architecture,
|
config=config.architecture,
|
||||||
)
|
)
|
||||||
return Model(
|
return Model(
|
||||||
@ -300,14 +300,14 @@ def build_model(
|
|||||||
|
|
||||||
|
|
||||||
def build_model_with_new_targets(
|
def build_model_with_new_targets(
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
roi_mapper: ROIMapperProtocol,
|
roi_mapper: ROIMapperProtocol,
|
||||||
) -> Model:
|
) -> ModelProtocol:
|
||||||
"""Build a new model with a different target set."""
|
"""Build a new model with a different target set."""
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(targets.class_names),
|
class_names=targets.class_names,
|
||||||
num_sizes=len(roi_mapper.dimension_names),
|
dimension_names=roi_mapper.dimension_names,
|
||||||
backbone=model.detector.backbone,
|
backbone=model.detector.backbone,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
@ -6,8 +6,8 @@ bounding-box size regression.
|
|||||||
|
|
||||||
Components
|
Components
|
||||||
----------
|
----------
|
||||||
- ``Detector`` – the ``torch.nn.Module`` that wires together a backbone
|
- ``Detector`` - the ``torch.nn.Module`` that wires together a backbone
|
||||||
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
|
(``BackboneProtocol``) with a ``ClassifierHead`` and a ``BBoxHead`` to
|
||||||
produce a ``ModelOutput`` tuple from an input spectrogram.
|
produce a ``ModelOutput`` tuple from an input spectrogram.
|
||||||
- ``build_detector`` – factory function that builds a ready-to-use
|
- ``build_detector`` – factory function that builds a ready-to-use
|
||||||
``Detector`` from a backbone configuration and a target class count.
|
``Detector`` from a backbone configuration and a target class count.
|
||||||
@ -18,15 +18,16 @@ preprocessing and output postprocessing are handled by
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from batdetect2.models.backbones import (
|
from batdetect2.models.backbones import BackboneConfig, build_backbone
|
||||||
BackboneConfig,
|
|
||||||
UNetBackboneConfig,
|
|
||||||
build_backbone,
|
|
||||||
)
|
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||||
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
from batdetect2.models.types import (
|
||||||
|
BackboneProtocol,
|
||||||
|
ClassifierHeadProtocol,
|
||||||
|
DetectorProtocol,
|
||||||
|
ModelOutput,
|
||||||
|
SizeHeadProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Detector",
|
"Detector",
|
||||||
@ -34,7 +35,7 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class Detector(DetectionModel):
|
class Detector(torch.nn.Module):
|
||||||
"""Complete BatDetect2 detection and classification model.
|
"""Complete BatDetect2 detection and classification model.
|
||||||
|
|
||||||
Combines a backbone feature extractor with two prediction heads:
|
Combines a backbone feature extractor with two prediction heads:
|
||||||
@ -51,7 +52,7 @@ class Detector(DetectionModel):
|
|||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
backbone : BackboneModel
|
backbone : BackboneProtocol
|
||||||
The feature extraction backbone.
|
The feature extraction backbone.
|
||||||
num_classes : int
|
num_classes : int
|
||||||
Number of target classes (inferred from the classifier head).
|
Number of target classes (inferred from the classifier head).
|
||||||
@ -61,13 +62,13 @@ class Detector(DetectionModel):
|
|||||||
Produces duration and bandwidth predictions from backbone features.
|
Produces duration and bandwidth predictions from backbone features.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
backbone: BackboneModel
|
backbone: BackboneProtocol
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
backbone: BackboneModel,
|
backbone: BackboneProtocol,
|
||||||
classifier_head: ClassifierHead,
|
classifier_head: ClassifierHeadProtocol,
|
||||||
bbox_head: BBoxHead,
|
size_head: SizeHeadProtocol,
|
||||||
):
|
):
|
||||||
"""Initialise the Detector model.
|
"""Initialise the Detector model.
|
||||||
|
|
||||||
@ -76,7 +77,7 @@ class Detector(DetectionModel):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
backbone : BackboneModel
|
backbone : BackboneProtocol
|
||||||
An initialised backbone module (e.g. built by
|
An initialised backbone module (e.g. built by
|
||||||
``build_backbone``).
|
``build_backbone``).
|
||||||
classifier_head : ClassifierHead
|
classifier_head : ClassifierHead
|
||||||
@ -90,7 +91,7 @@ class Detector(DetectionModel):
|
|||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.num_classes = classifier_head.num_classes
|
self.num_classes = classifier_head.num_classes
|
||||||
self.classifier_head = classifier_head
|
self.classifier_head = classifier_head
|
||||||
self.bbox_head = bbox_head
|
self.size_head = size_head
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||||
"""Run the complete detection model on an input spectrogram.
|
"""Run the complete detection model on an input spectrogram.
|
||||||
@ -125,7 +126,7 @@ class Detector(DetectionModel):
|
|||||||
features = self.backbone(spec)
|
features = self.backbone(spec)
|
||||||
classification = self.classifier_head(features)
|
classification = self.classifier_head(features)
|
||||||
detection = classification.sum(dim=1, keepdim=True)
|
detection = classification.sum(dim=1, keepdim=True)
|
||||||
size_preds = self.bbox_head(features)
|
size_preds = self.size_head(features)
|
||||||
return ModelOutput(
|
return ModelOutput(
|
||||||
detection_probs=detection,
|
detection_probs=detection,
|
||||||
size_preds=size_preds,
|
size_preds=size_preds,
|
||||||
@ -135,11 +136,11 @@ class Detector(DetectionModel):
|
|||||||
|
|
||||||
|
|
||||||
def build_detector(
|
def build_detector(
|
||||||
num_classes: int,
|
class_names: list[str],
|
||||||
num_sizes: int = 2,
|
dimension_names: list[str],
|
||||||
config: BackboneConfig | None = None,
|
config: BackboneConfig | None = None,
|
||||||
backbone: BackboneModel | None = None,
|
backbone: BackboneProtocol | None = None,
|
||||||
) -> DetectionModel:
|
) -> DetectorProtocol:
|
||||||
"""Build a complete BatDetect2 detection model.
|
"""Build a complete BatDetect2 detection model.
|
||||||
|
|
||||||
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
|
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
|
||||||
@ -158,7 +159,7 @@ def build_detector(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
DetectionModel
|
DetectorProtocol
|
||||||
An initialised ``Detector`` instance ready for training or
|
An initialised ``Detector`` instance ready for training or
|
||||||
inference.
|
inference.
|
||||||
|
|
||||||
@ -168,24 +169,18 @@ def build_detector(
|
|||||||
If ``num_classes`` is not positive, or if the backbone
|
If ``num_classes`` is not positive, or if the backbone
|
||||||
configuration is invalid.
|
configuration is invalid.
|
||||||
"""
|
"""
|
||||||
if backbone is None:
|
backbone = backbone or build_backbone(config=config)
|
||||||
config = config or UNetBackboneConfig()
|
|
||||||
logger.opt(lazy=True).debug(
|
|
||||||
"Building model with config: \n{}",
|
|
||||||
lambda: config.to_yaml_string(), # type: ignore
|
|
||||||
)
|
|
||||||
backbone = build_backbone(config=config)
|
|
||||||
|
|
||||||
classifier_head = ClassifierHead(
|
classifier_head = ClassifierHead(
|
||||||
num_classes=num_classes,
|
class_names=class_names,
|
||||||
in_channels=backbone.out_channels,
|
in_channels=backbone.out_channels,
|
||||||
)
|
)
|
||||||
bbox_head = BBoxHead(
|
bbox_head = BBoxHead(
|
||||||
in_channels=backbone.out_channels,
|
in_channels=backbone.out_channels,
|
||||||
num_sizes=num_sizes,
|
dimension_names=dimension_names,
|
||||||
)
|
)
|
||||||
return Detector(
|
return Detector(
|
||||||
backbone=backbone,
|
backbone=backbone,
|
||||||
classifier_head=classifier_head,
|
classifier_head=classifier_head,
|
||||||
bbox_head=bbox_head,
|
size_head=bbox_head,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -54,12 +54,14 @@ class ClassifierHead(nn.Module):
|
|||||||
1×1 convolution with ``num_classes + 1`` output channels.
|
1×1 convolution with ``num_classes + 1`` output channels.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_classes: int, in_channels: int):
|
def __init__(self, class_names: list[str], in_channels: int):
|
||||||
"""Initialise the ClassifierHead."""
|
"""Initialise the ClassifierHead."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.num_classes = num_classes
|
self.class_names = class_names
|
||||||
|
self.num_classes = len(class_names)
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
self.classifier = nn.Conv2d(
|
self.classifier = nn.Conv2d(
|
||||||
self.in_channels,
|
self.in_channels,
|
||||||
self.num_classes + 1,
|
self.num_classes + 1,
|
||||||
@ -165,11 +167,12 @@ class BBoxHead(nn.Module):
|
|||||||
1×1 convolution with 2 output channels (duration, bandwidth).
|
1×1 convolution with 2 output channels (duration, bandwidth).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels: int, num_sizes: int = 2):
|
def __init__(self, dimension_names: list[str], in_channels: int):
|
||||||
"""Initialise the BBoxHead."""
|
"""Initialise the BBoxHead."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.num_sizes = num_sizes
|
self.dimension_names = dimension_names
|
||||||
|
self.num_sizes = len(dimension_names)
|
||||||
|
|
||||||
self.bbox = nn.Conv2d(
|
self.bbox = nn.Conv2d(
|
||||||
in_channels=self.in_channels,
|
in_channels=self.in_channels,
|
||||||
|
|||||||
@ -34,6 +34,7 @@ class CheckpointConfig(BaseConfig):
|
|||||||
monitor: str | None = None
|
monitor: str | None = None
|
||||||
mode: str = "max"
|
mode: str = "max"
|
||||||
save_top_k: int = 1
|
save_top_k: int = 1
|
||||||
|
# Save distributable inference checkpoints by default.
|
||||||
save_weights_only: bool = True
|
save_weights_only: bool = True
|
||||||
filename: str | None = None
|
filename: str | None = None
|
||||||
save_last: bool | Literal["link"] = "link"
|
save_last: bool | Literal["link"] = "link"
|
||||||
|
|||||||
@ -299,10 +299,10 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
|||||||
value,
|
value,
|
||||||
)
|
)
|
||||||
|
|
||||||
for key, value in source_detector.bbox_head.state_dict().items():
|
for key, value in source_detector.size_head.state_dict().items():
|
||||||
assert key in detector.bbox_head.state_dict()
|
assert key in detector.size_head.state_dict()
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
detector.bbox_head.state_dict()[key],
|
detector.size_head.state_dict()[key],
|
||||||
value,
|
value,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,7 @@ def test_user_can_finetune_only_heads(
|
|||||||
|
|
||||||
api = BatDetect2API.from_config()
|
api = BatDetect2API.from_config()
|
||||||
source_classifier_head = api.model.detector.classifier_head
|
source_classifier_head = api.model.detector.classifier_head
|
||||||
source_bbox_head = api.model.detector.bbox_head
|
source_size_head = api.model.detector.size_head
|
||||||
source_backbone = api.model.detector.backbone
|
source_backbone = api.model.detector.backbone
|
||||||
finetune_dir = tmp_path / "heads_only"
|
finetune_dir = tmp_path / "heads_only"
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ def test_user_can_finetune_only_heads(
|
|||||||
|
|
||||||
backbone_params = list(detector.backbone.parameters())
|
backbone_params = list(detector.backbone.parameters())
|
||||||
classifier_params = list(detector.classifier_head.parameters())
|
classifier_params = list(detector.classifier_head.parameters())
|
||||||
bbox_params = list(detector.bbox_head.parameters())
|
bbox_params = list(detector.size_head.parameters())
|
||||||
|
|
||||||
assert backbone_params
|
assert backbone_params
|
||||||
assert classifier_params
|
assert classifier_params
|
||||||
@ -50,7 +50,7 @@ def test_user_can_finetune_only_heads(
|
|||||||
assert finetuned_api is not api
|
assert finetuned_api is not api
|
||||||
assert detector.backbone is source_backbone
|
assert detector.backbone is source_backbone
|
||||||
assert detector.classifier_head is not source_classifier_head
|
assert detector.classifier_head is not source_classifier_head
|
||||||
assert detector.bbox_head is not source_bbox_head
|
assert detector.size_head is not source_size_head
|
||||||
assert list(finetune_dir.rglob("*.ckpt"))
|
assert list(finetune_dir.rglob("*.ckpt"))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
from batdetect2.models import UNetBackbone
|
from batdetect2.models import UNetBackbone
|
||||||
from batdetect2.models.backbones import UNetBackboneConfig
|
from batdetect2.models.backbones import UNetBackboneConfig
|
||||||
@ -19,12 +20,15 @@ def dummy_spectrogram() -> torch.Tensor:
|
|||||||
def test_build_detector_default():
|
def test_build_detector_default():
|
||||||
"""Test building the default detector without a config."""
|
"""Test building the default detector without a config."""
|
||||||
num_classes = 5
|
num_classes = 5
|
||||||
model = build_detector(num_classes=num_classes)
|
model = build_detector(
|
||||||
|
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||||
|
dimension_names=["width", "height"],
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(model, Detector)
|
assert isinstance(model, Detector)
|
||||||
assert model.num_classes == num_classes
|
assert model.num_classes == num_classes
|
||||||
assert isinstance(model.classifier_head, ClassifierHead)
|
assert isinstance(model.classifier_head, ClassifierHead)
|
||||||
assert isinstance(model.bbox_head, BBoxHead)
|
assert isinstance(model.size_head, BBoxHead)
|
||||||
|
|
||||||
|
|
||||||
def test_build_detector_custom_config():
|
def test_build_detector_custom_config():
|
||||||
@ -32,13 +36,19 @@ def test_build_detector_custom_config():
|
|||||||
num_classes = 3
|
num_classes = 3
|
||||||
config = UNetBackboneConfig(in_channels=2, input_height=128)
|
config = UNetBackboneConfig(in_channels=2, input_height=128)
|
||||||
|
|
||||||
model = build_detector(num_classes=num_classes, config=config)
|
model = build_detector(
|
||||||
|
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||||
|
dimension_names=["width", "height"],
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(model, Detector)
|
assert isinstance(model, Detector)
|
||||||
assert model.backbone.input_height == 128
|
assert model.backbone.input_height == 128
|
||||||
|
|
||||||
assert isinstance(model.backbone.encoder, Encoder)
|
backbone = cast(UNetBackbone, model.backbone)
|
||||||
assert model.backbone.encoder.in_channels == 2
|
|
||||||
|
assert isinstance(backbone.encoder, Encoder)
|
||||||
|
assert backbone.encoder.in_channels == 2
|
||||||
|
|
||||||
|
|
||||||
def test_build_detector_custom_size_channels():
|
def test_build_detector_custom_size_channels():
|
||||||
@ -47,8 +57,8 @@ def test_build_detector_custom_size_channels():
|
|||||||
config = UNetBackboneConfig(in_channels=1, input_height=128)
|
config = UNetBackboneConfig(in_channels=1, input_height=128)
|
||||||
|
|
||||||
model = build_detector(
|
model = build_detector(
|
||||||
num_classes=num_classes,
|
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||||
num_sizes=num_sizes,
|
dimension_names=[f"size_{i}" for i in range(num_sizes)],
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -62,7 +72,11 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
|
|||||||
num_classes = 4
|
num_classes = 4
|
||||||
# Build model matching the dummy input shape
|
# Build model matching the dummy input shape
|
||||||
config = UNetBackboneConfig(in_channels=1, input_height=256)
|
config = UNetBackboneConfig(in_channels=1, input_height=256)
|
||||||
model = build_detector(num_classes=num_classes, config=config)
|
model = build_detector(
|
||||||
|
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||||
|
dimension_names=["width", "height"],
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
# Process the spectrogram through the model
|
# Process the spectrogram through the model
|
||||||
# PyTorch expects shape (Batch, Channels, Height, Width)
|
# PyTorch expects shape (Batch, Channels, Height, Width)
|
||||||
@ -132,7 +146,11 @@ def test_detector_forward_pass_with_preprocessor(sample_preprocessor):
|
|||||||
config = UNetBackboneConfig(
|
config = UNetBackboneConfig(
|
||||||
in_channels=spec.shape[1], input_height=spec.shape[2]
|
in_channels=spec.shape[1], input_height=spec.shape[2]
|
||||||
)
|
)
|
||||||
model = build_detector(num_classes=3, config=config)
|
model = build_detector(
|
||||||
|
class_names=["class_0", "class_1", "class_2"],
|
||||||
|
dimension_names=["width", "height"],
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
# Process
|
# Process
|
||||||
output = model(spec)
|
output = model(spec)
|
||||||
|
|||||||
@ -8,7 +8,7 @@ from soundevent import data
|
|||||||
|
|
||||||
from batdetect2.train import TrainingConfig, run_train
|
from batdetect2.train import TrainingConfig, run_train
|
||||||
from batdetect2.train.checkpoints import (
|
from batdetect2.train.checkpoints import (
|
||||||
DEFAULT_BUNDLED_CHECKPOINT,
|
DEFAULT_CHECKPOINT,
|
||||||
get_bundled_checkpoint_names,
|
get_bundled_checkpoint_names,
|
||||||
resolve_checkpoint_path,
|
resolve_checkpoint_path,
|
||||||
)
|
)
|
||||||
@ -145,7 +145,7 @@ def test_resolve_checkpoint_path_returns_local_path_unchanged(
|
|||||||
|
|
||||||
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
|
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
|
||||||
assert get_bundled_checkpoint_names() == (
|
assert get_bundled_checkpoint_names() == (
|
||||||
DEFAULT_BUNDLED_CHECKPOINT,
|
DEFAULT_CHECKPOINT,
|
||||||
"batdetect2_uk_same",
|
"batdetect2_uk_same",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -153,11 +153,11 @@ def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
|
|||||||
def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None:
|
def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None:
|
||||||
resolved = resolve_checkpoint_path()
|
resolved = resolve_checkpoint_path()
|
||||||
|
|
||||||
assert resolved == resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
|
assert resolved == resolve_checkpoint_path(DEFAULT_CHECKPOINT)
|
||||||
|
|
||||||
|
|
||||||
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
|
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
|
||||||
resolved = resolve_checkpoint_path(DEFAULT_BUNDLED_CHECKPOINT)
|
resolved = resolve_checkpoint_path(DEFAULT_CHECKPOINT)
|
||||||
|
|
||||||
assert resolved.name == "batdetect2_uk_same.ckpt"
|
assert resolved.name == "batdetect2_uk_same.ckpt"
|
||||||
assert resolved.exists()
|
assert resolved.exists()
|
||||||
@ -227,6 +227,6 @@ def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
|
|||||||
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
|
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
FileNotFoundError,
|
FileNotFoundError,
|
||||||
match="bundled checkpoint alias",
|
match="checkpoint alias",
|
||||||
):
|
):
|
||||||
resolve_checkpoint_path("missing.ckpt")
|
resolve_checkpoint_path("missing.ckpt")
|
||||||
|
|||||||
@ -368,7 +368,7 @@ def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
|
|||||||
assert (
|
assert (
|
||||||
rebuilt_detector.classifier_head is not source_detector.classifier_head
|
rebuilt_detector.classifier_head is not source_detector.classifier_head
|
||||||
)
|
)
|
||||||
assert rebuilt_detector.bbox_head is not source_detector.bbox_head
|
assert rebuilt_detector.size_head is not source_detector.size_head
|
||||||
assert rebuilt_model.class_names == ["single_class"]
|
assert rebuilt_model.class_names == ["single_class"]
|
||||||
assert rebuilt_model.dimension_names == ["width", "height"]
|
assert rebuilt_model.dimension_names == ["width", "height"]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user