refactor: decouple model metadata from target configs

This commit is contained in:
mbsantiago 2026-05-04 21:18:17 +01:00
parent e33053614a
commit 57236fc82a
17 changed files with 483 additions and 237 deletions

View File

@ -12,7 +12,6 @@ if TYPE_CHECKING:
import torch
from batdetect2.audio import AudioConfig, AudioLoader
from batdetect2.config import BatDetect2Config
from batdetect2.data import Dataset
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
from batdetect2.inference import InferenceConfig
@ -483,46 +482,70 @@ class BatDetect2API:
@classmethod
def from_config(
cls,
config: BatDetect2Config,
model_config: ModelConfig | None = None,
targets_config: TargetConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
inference_config: InferenceConfig | None = None,
outputs_config: OutputsConfig | None = None,
logging_config: AppLoggingConfig | None = None,
) -> "BatDetect2API":
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate import build_evaluator
from batdetect2.models import build_model
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.evaluate import EvaluationConfig, build_evaluator
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import ModelConfig, build_model
from batdetect2.outputs import (
OutputsConfig,
build_output_formatter,
build_output_transform,
)
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_roi_mapping, build_targets
from batdetect2.targets import (
TargetConfig,
build_roi_mapping,
build_targets,
)
from batdetect2.train import TrainingConfig
targets = build_targets(config=config.model.targets)
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
model_config = model_config or ModelConfig()
targets_config = targets_config or TargetConfig()
audio_config = audio_config or AudioConfig()
train_config = train_config or TrainingConfig()
evaluation_config = evaluation_config or EvaluationConfig()
inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig()
logging_config = logging_config or AppLoggingConfig()
audio_loader = build_audio_loader(config=config.audio)
targets = build_targets(config=targets_config)
roi_mapper = build_roi_mapping(config=targets_config.roi)
audio_loader = build_audio_loader(config=audio_config)
preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.model.preprocess,
config=model_config.preprocess,
)
postprocessor = build_postprocessor(
preprocessor,
config=config.model.postprocess,
config=model_config.postprocess,
)
formatter = build_output_formatter(
targets,
config=config.outputs.format,
config=outputs_config.format,
)
output_transform = build_output_transform(
config=config.outputs.transform,
config=outputs_config.transform,
targets=targets,
roi_mapper=roi_mapper,
)
evaluator = build_evaluator(
config=config.evaluation,
config=evaluation_config,
targets=targets,
roi_mapper=roi_mapper,
transform=output_transform,
@ -531,27 +554,27 @@ class BatDetect2API:
# NOTE: Build separate instances of preprocessor and postprocessor
# to avoid device mismatch errors
model = build_model(
config=config.model,
targets=targets,
roi_mapper=roi_mapper,
config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
preprocessor=build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.model.preprocess,
config=model_config.preprocess,
),
postprocessor=build_postprocessor(
preprocessor,
config=config.model.postprocess,
config=model_config.postprocess,
),
)
return cls(
model_config=config.model,
audio_config=config.audio,
train_config=config.train,
evaluation_config=config.evaluation,
inference_config=config.inference,
outputs_config=config.outputs,
logging_config=config.logging,
model_config=model_config,
audio_config=audio_config,
train_config=train_config,
evaluation_config=evaluation_config,
inference_config=inference_config,
outputs_config=outputs_config,
logging_config=logging_config,
targets=targets,
roi_mapper=roi_mapper,
audio_loader=audio_loader,
@ -579,7 +602,6 @@ class BatDetect2API:
from batdetect2.evaluate import EvaluationConfig, build_evaluator
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import build_model_with_new_targets
from batdetect2.outputs import (
OutputsConfig,
build_output_formatter,
@ -587,8 +609,16 @@ class BatDetect2API:
)
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_roi_mapping, build_targets
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
from batdetect2.targets import (
build_default_target_config,
build_roi_mapping,
build_targets,
check_target_compatibility,
)
from batdetect2.train import (
TrainingConfig,
load_model_from_checkpoint,
)
model, model_config = load_model_from_checkpoint(path)
@ -600,24 +630,24 @@ class BatDetect2API:
inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig()
logging_config = logging_config or AppLoggingConfig()
targets_config = targets_config or build_default_target_config(
class_names=model.class_names
)
if (
targets_config is not None
and targets_config != model_config.targets
):
targets = build_targets(config=targets_config)
roi_mapper = build_roi_mapping(config=targets_config.roi)
model = build_model_with_new_targets(
model=model,
targets=targets,
roi_mapper=roi_mapper,
)
model_config = model_config.model_copy(
update={"targets": targets_config}
targets = build_targets(config=targets_config)
roi_mapper = build_roi_mapping(config=targets_config.roi)
if not check_target_compatibility(targets, model.class_names):
raise ValueError(
"Provided targets_config is incompatible with the "
"checkpoint model: missing one or more model classes."
)
targets = build_targets(config=model_config.targets)
roi_mapper = build_roi_mapping(config=model_config.targets.roi)
if model.dimension_names != roi_mapper.dimension_names:
raise ValueError(
"Provided targets_config is incompatible with the "
"checkpoint model: mismatched dimension names."
)
audio_loader = build_audio_loader(config=audio_config)

View File

@ -12,6 +12,7 @@ from batdetect2.inference.config import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import TargetConfig
from batdetect2.train.config import TrainingConfig
__all__ = ["BatDetect2Config"]
@ -25,6 +26,7 @@ class BatDetect2Config(BaseConfig):
default_factory=get_default_eval_config
)
model: ModelConfig = Field(default_factory=ModelConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
audio: AudioConfig = Field(default_factory=AudioConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig)
outputs: OutputsConfig = Field(default_factory=OutputsConfig)

View File

@ -24,8 +24,8 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
def run_evaluate(
model: Model,
test_annotations: Sequence[data.ClipAnnotation],
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
audio_config: AudioConfig | None = None,
@ -46,8 +46,6 @@ def run_evaluate(
audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
roi_mapper = roi_mapper or model.roi_mapper
loader = build_test_loader(
test_annotations,

View File

@ -45,8 +45,16 @@ def run_batch_inference(
audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
roi_mapper = roi_mapper or model.roi_mapper
if targets is None:
raise ValueError(
"targets must be provided when running batch inference."
)
if roi_mapper is None:
raise ValueError(
"roi_mapper must be provided when running batch inference."
)
output_transform = output_transform or build_output_transform(
config=output_config.transform,

View File

@ -7,21 +7,37 @@ from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
class InferenceModule(LightningModule):
def __init__(
self,
model: Model,
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
output_transform: OutputTransformProtocol | None = None,
detection_threshold: float | None = None,
):
super().__init__()
self.model = model
self.detection_threshold = detection_threshold
if output_transform is None and targets is None:
raise ValueError(
"targets must be provided when building inference output "
"transforms."
)
if output_transform is None and roi_mapper is None:
raise ValueError(
"roi_mapper must be provided when building inference output "
"transforms."
)
self.output_transform = output_transform or build_output_transform(
targets=model.targets,
roi_mapper=model.roi_mapper,
targets=targets,
roi_mapper=roi_mapper,
)
def predict_step(

View File

@ -26,11 +26,8 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
is the ``build_model`` factory function exported from this module.
"""
from typing import Literal
import torch
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig
@ -73,7 +70,6 @@ from batdetect2.postprocess.types import (
)
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.config import TargetConfig
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
__all__ = [
@ -131,10 +127,6 @@ class ModelConfig(BaseConfig):
Parameters for converting raw model outputs into detections (NMS
kernel, thresholds, top-k limit). Defaults to
``PostprocessConfig()``.
targets : TargetConfig
Detection and classification target definitions (class list,
detection target, bounding-box mapper). Defaults to
``TargetConfig()``.
"""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
@ -143,23 +135,6 @@ class ModelConfig(BaseConfig):
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
@classmethod
def load(
cls,
path: PathLike,
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
targets: TargetConfig | None = None,
) -> "ModelConfig":
config = super().load(path, field, extra, strict)
if targets is None:
return config
return config.model_copy(update={"targets": targets})
class Model(torch.nn.Module):
@ -183,33 +158,32 @@ class Model(torch.nn.Module):
postprocessor : PostprocessorProtocol
Converts the raw ``ModelOutput`` from ``detector`` into a list of
per-clip detection tensors.
targets : TargetProtocol
Describes the set of target classes; used when building heads and
during training target construction.
roi_mapper : ROIMapperProtocol
Maps geometries to target-size channels and back.
class_names : list[str]
Class names corresponding to the model classification outputs.
dimension_names : list[str]
Size-dimension names corresponding to the model size outputs.
"""
detector: DetectionModel
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
targets: TargetProtocol
roi_mapper: ROIMapperProtocol
class_names: list[str]
dimension_names: list[str]
def __init__(
self,
detector: DetectionModel,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
class_names: list[str],
dimension_names: list[str],
):
super().__init__()
self.detector = detector
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.targets = targets
self.roi_mapper = roi_mapper
self.class_names = class_names
self.dimension_names = dimension_names
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
"""Run the full detection pipeline on a waveform tensor.
@ -238,8 +212,8 @@ class Model(torch.nn.Module):
def build_model(
config: ModelConfig | None = None,
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
class_names: list[str] | None = None,
dimension_names: list[str] | None = None,
preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None,
) -> Model:
@ -254,11 +228,13 @@ def build_model(
----------
config : ModelConfig, optional
Full model configuration (samplerate, architecture, preprocessing,
postprocessing, targets). Defaults to ``ModelConfig()`` if not
provided.
targets : TargetProtocol, optional
Pre-built targets object. If given, overrides
``config.targets``.
postprocessing). Defaults to ``ModelConfig()`` if not provided.
class_names : list[str], optional
Class names used to size the classifier head. Required when building
a new model.
dimension_names : list[str], optional
Dimension names used to size the bbox head. Required when building a
new model.
preprocessor : PreprocessorProtocol, optional
Pre-built preprocessor. If given, overrides
``config.preprocess`` and ``config.samplerate`` for the
@ -278,19 +254,17 @@ def build_model(
"""
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_roi_mapping, build_targets
config = config or ModelConfig()
targets = targets or build_targets(config=config.targets)
targets_config = getattr(targets, "config", None)
roi_config = (
targets_config.roi
if isinstance(targets_config, TargetConfig)
else config.targets.roi
)
if class_names is None:
raise ValueError("class_names must be provided when building a model.")
if dimension_names is None:
raise ValueError(
"dimension_names must be provided when building a model."
)
roi_mapper = roi_mapper or build_roi_mapping(config=roi_config)
preprocessor = preprocessor or build_preprocessor(
config=config.preprocess,
input_samplerate=config.samplerate,
@ -300,16 +274,16 @@ def build_model(
config=config.postprocess,
)
detector = build_detector(
num_classes=len(targets.class_names),
num_sizes=len(roi_mapper.dimension_names),
num_classes=len(class_names),
num_sizes=len(dimension_names),
config=config.architecture,
)
return Model(
detector=detector,
postprocessor=postprocessor,
preprocessor=preprocessor,
targets=targets,
roi_mapper=roi_mapper,
class_names=class_names,
dimension_names=dimension_names,
)
@ -329,6 +303,6 @@ def build_model_with_new_targets(
detector=detector,
postprocessor=model.postprocessor,
preprocessor=model.preprocessor,
targets=targets,
roi_mapper=roi_mapper,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
)

View File

@ -6,7 +6,7 @@ from batdetect2.targets.classes import (
build_sound_event_encoder,
get_class_names_from_config,
)
from batdetect2.targets.config import TargetConfig
from batdetect2.targets.config import TargetConfig, build_default_target_config
from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
ROIMapperConfig,
@ -36,13 +36,14 @@ from batdetect2.targets.types import (
SoundEventFilter,
TargetProtocol,
)
from batdetect2.targets.utils import check_target_compatibility
__all__ = [
"AnchorBBoxMapperConfig",
"Position",
"ROIMappingConfig",
"ROIMapperProtocol",
"ROIMapperConfig",
"ROIMapperProtocol",
"ROIMappingConfig",
"ROITargetMapper",
"Size",
"SoundEventDecoder",
@ -52,12 +53,14 @@ __all__ = [
"TargetConfig",
"TargetProtocol",
"Targets",
"build_roi_mapping",
"build_default_target_config",
"build_roi_mapper",
"build_roi_mapping",
"build_sound_event_decoder",
"build_sound_event_encoder",
"build_targets",
"call_type",
"check_target_compatibility",
"data_source",
"generic_class",
"get_class_names_from_config",

View File

@ -2,6 +2,7 @@ from collections import Counter
from typing import List
from pydantic import Field, field_validator
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.targets.classes import (
@ -13,6 +14,7 @@ from batdetect2.targets.rois import ROIMappingConfig
__all__ = [
"TargetConfig",
"build_default_target_config",
]
@ -42,3 +44,20 @@ class TargetConfig(BaseConfig):
f"{', '.join(duplicates)}"
)
return v
def build_default_target_config(class_names: list[str]) -> TargetConfig:
"""Build a default target configuration object."""
return TargetConfig(
detection_target=DEFAULT_DETECTION_CLASS,
classification_targets=[
TargetClassConfig(
name=class_name,
tags=[
data.Tag(key="class", value=class_name),
],
)
for class_name in class_names
],
roi=ROIMappingConfig(),
)

View File

@ -0,0 +1,29 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from batdetect2.targets.types import TargetProtocol
def check_target_compatibility(
targets: "TargetProtocol",
class_names: list[str],
) -> bool:
"""Check if a target definition can decode a model's outputs.
Parameters
----------
targets : TargetProtocol
Target definition that would be used with the model outputs.
class_names : list[str]
Class names produced by the model checkpoint.
Returns
-------
bool
True when every model class name exists in the provided targets,
False otherwise.
"""
target_class_names = set(targets.class_names)
model_class_names = set(class_names)
return model_class_names.issubset(target_class_names)

View File

@ -10,6 +10,7 @@ from batdetect2.logging import get_image_logger
from batdetect2.models.types import ModelOutput
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.types import TrainExample
@ -19,11 +20,15 @@ class ValidationMetrics(Callback):
def __init__(
self,
evaluator: EvaluatorProtocol,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
output_transform: OutputTransformProtocol | None = None,
):
super().__init__()
self.evaluator = evaluator
self.targets = targets
self.roi_mapper = roi_mapper
self.output_transform = output_transform
self._clip_annotations: List[data.ClipAnnotation] = []
@ -93,8 +98,8 @@ class ValidationMetrics(Callback):
model = pl_module.model
if self.output_transform is None:
self.output_transform = build_output_transform(
targets=model.targets,
roi_mapper=model.roi_mapper,
targets=self.targets,
roi_mapper=self.roi_mapper,
)
output_transform = self.output_transform

View File

@ -9,9 +9,7 @@ from batdetect2.train.optimizers import build_optimizer
from batdetect2.train.schedulers import build_scheduler
from batdetect2.train.types import LossProtocol, TrainExample
__all__ = [
"TrainingModule",
]
__all__ = ["TrainingModule"]
class TrainingModule(L.LightningModule):
@ -21,6 +19,8 @@ class TrainingModule(L.LightningModule):
def __init__(
self,
model_config: dict | None = None,
class_names: list[str] | None = None,
dimension_names: list[str] | None = None,
train_config: dict | None = None,
loss: LossProtocol | None = None,
model: Model | None = None,
@ -30,13 +30,31 @@ class TrainingModule(L.LightningModule):
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
self.model_config = ModelConfig.model_validate(model_config or {})
self.class_names = list(class_names or [])
self.dimension_names = list(dimension_names or [])
self.train_config = TrainingConfig.model_validate(train_config or {})
if loss is None:
loss = build_loss(config=self.train_config.loss)
if model is None:
model = build_model(config=self.model_config)
if not self.class_names:
raise ValueError(
"class_names must be provided when rebuilding a training "
"module without a model."
)
if not self.dimension_names:
raise ValueError(
"dimension_names must be provided when rebuilding a "
"training module without a model."
)
model = build_model(
config=self.model_config,
class_names=self.class_names,
dimension_names=self.dimension_names,
)
self.loss = loss
self.model = model
@ -110,8 +128,7 @@ def load_model_from_checkpoint(
-------
tuple[Model, ModelConfig]
The restored ``Model`` instance and the ``ModelConfig`` that
describes its architecture, preprocessing, postprocessing, and
targets.
describes its architecture, preprocessing, and postprocessing.
"""
module = TrainingModule.load_from_checkpoint(path) # type: ignore
return module.model, module.model_config
@ -119,6 +136,8 @@ def load_model_from_checkpoint(
def build_training_module(
model_config: ModelConfig | None = None,
class_names: list[str] | None = None,
dimension_names: list[str] | None = None,
train_config: TrainingConfig | None = None,
model: Model | None = None,
) -> TrainingModule:
@ -130,6 +149,8 @@ def build_training_module(
return TrainingModule(
model_config=model_config.model_dump(mode="json"),
class_names=class_names,
dimension_names=dimension_names,
train_config=train_config.model_dump(mode="json"),
model=model,
)

View File

@ -17,6 +17,7 @@ from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import (
ROIMapperProtocol,
TargetConfig,
TargetProtocol,
build_roi_mapping,
build_targets,
@ -46,6 +47,7 @@ def run_train(
labeller: Optional["ClipLabeller"] = None,
audio_config: Optional[AudioConfig] = None,
model_config: Optional[ModelConfig] = None,
targets_config: TargetConfig | None = None,
train_config: Optional[TrainingConfig] = None,
logger_config: LoggerConfig | None = None,
trainer: Trainer | None = None,
@ -62,23 +64,34 @@ def run_train(
seed_everything(seed)
model_config = model_config or ModelConfig()
targets_config = targets_config or TargetConfig()
audio_config = audio_config or AudioConfig()
train_config = train_config or TrainingConfig()
if model is not None:
_validate_model_compatibility(model=model, model_config=model_config)
if targets is None:
raise ValueError(
"targets must be provided when training with an existing "
"model."
)
if roi_mapper is None:
raise ValueError(
"roi_mapper must be provided when training with an existing "
"model."
)
targets = targets or build_targets(config=targets_config)
roi_mapper = roi_mapper or build_roi_mapping(config=targets_config.roi)
if model is not None:
targets = targets or model.targets
if roi_mapper is None and targets is model.targets:
roi_mapper = model.roi_mapper
targets = targets or build_targets(config=model_config.targets)
roi_mapper = roi_mapper or build_roi_mapping(
config=model_config.targets.roi
)
_validate_model_compatibility(
model=model,
model_config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
)
audio_loader = audio_loader or build_audio_loader(config=audio_config)
@ -119,18 +132,24 @@ def run_train(
module = build_training_module(
model_config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config,
model=model,
)
evaluator = build_evaluator(
train_config.validation,
targets=targets,
roi_mapper=roi_mapper,
)
trainer = trainer or build_trainer(
train_config,
logger_config=logger_config,
evaluator=build_evaluator(
train_config.validation,
targets=targets,
roi_mapper=roi_mapper,
),
evaluator=evaluator,
targets=targets,
roi_mapper=roi_mapper,
checkpoint_dir=checkpoint_dir,
num_epochs=num_epochs,
log_dir=log_dir,
@ -152,8 +171,14 @@ def run_train(
def _validate_model_compatibility(
model: Model,
model_config: ModelConfig,
class_names: list[str],
dimension_names: list[str],
) -> None:
reference_model = build_model(config=model_config)
reference_model = build_model(
config=model_config,
class_names=class_names,
dimension_names=dimension_names,
)
expected_shapes = {
key: tuple(value.shape)
@ -196,6 +221,8 @@ def build_trainer(
config: TrainingConfig,
logger_config: LoggerConfig | None,
evaluator: "EvaluatorProtocol",
targets: "TargetProtocol",
roi_mapper: "ROIMapperProtocol",
checkpoint_dir: Path | None = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
@ -234,6 +261,6 @@ def build_trainer(
experiment_name=experiment_name,
run_name=run_name,
),
ValidationMetrics(evaluator),
ValidationMetrics(evaluator, targets, roi_mapper),
],
)

View File

@ -13,13 +13,14 @@ from soundevent import data, terms
from batdetect2.audio import build_audio_loader
from batdetect2.audio.clips import build_clipper
from batdetect2.audio.types import AudioLoader, ClipperProtocol
from batdetect2.config import BatDetect2Config
from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import (
ROIMapperProtocol,
TargetConfig,
build_roi_mapping,
build_targets,
call_type,
)
@ -404,6 +405,13 @@ def sample_targets(
return build_targets(sample_target_config)
@pytest.fixture
def sample_roi_mapper(
sample_target_config: TargetConfig,
) -> ROIMapperProtocol:
return build_roi_mapping(sample_target_config.roi)
@pytest.fixture
def sample_labeller(
sample_targets: TargetProtocol,
@ -458,8 +466,15 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
@pytest.fixture
def tiny_checkpoint_path(tmp_path: Path) -> Path:
module = build_training_module(model_config=BatDetect2Config().model)
def tiny_checkpoint_path(
sample_targets: TargetProtocol,
sample_roi_mapper: ROIMapperProtocol,
tmp_path: Path,
) -> Path:
module = build_training_module(
class_names=sample_targets.class_names,
dimension_names=sample_roi_mapper.dimension_names,
)
trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "model.ckpt"
trainer.strategy.connect(module)

View File

@ -8,20 +8,43 @@ import torch
from soundevent.geometry import compute_bounds
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config
from batdetect2.inference import InferenceConfig
from batdetect2.models.detectors import Detector
from batdetect2.models.heads import ClassifierHead
from batdetect2.train import load_model_from_checkpoint
from batdetect2.targets import TargetConfig
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
from batdetect2.train.lightning import build_training_module
@pytest.fixture
def api_v2() -> BatDetect2API:
def train_config() -> TrainingConfig:
"""Train config with a small batch size for testing."""
return TrainingConfig.model_validate({"train_loader": {"batch_size": 2}})
@pytest.fixture
def inference_config() -> InferenceConfig:
"""Inference config with a small batch size for testing."""
return InferenceConfig.model_validate({"loader": {"batch_size": 2}})
@pytest.fixture
def example_targets_config(example_data_dir: Path) -> TargetConfig:
return TargetConfig.load(example_data_dir / "targets.yaml")
@pytest.fixture
def api_v2(
train_config: TrainingConfig,
inference_config: InferenceConfig,
) -> BatDetect2API:
"""User story: users can create a ready-to-use API from config."""
config = BatDetect2Config()
config.inference.loader.batch_size = 2
return BatDetect2API.from_config(config)
api = BatDetect2API.from_config(
train_config=train_config,
inference_config=inference_config,
)
assert api.inference_config.loader.batch_size == 2
return api
def test_process_file_returns_recording_level_predictions(
@ -30,8 +53,10 @@ def test_process_file_returns_recording_level_predictions(
) -> None:
"""User story: process a file and get detections in recording time."""
# When
prediction = api_v2.process_file(example_audio_files[0])
# Then
assert prediction.clip.recording.path == example_audio_files[0]
assert prediction.clip.start_time == 0
assert prediction.clip.end_time == prediction.clip.recording.duration
@ -53,9 +78,11 @@ def test_process_files_is_batch_size_invariant(
) -> None:
"""User story: changing batch size should not change predictions."""
# When
preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1)
preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3)
# Then
assert len(preds_batch_1) == len(preds_batch_3)
by_key_1 = {
@ -91,12 +118,14 @@ def test_process_audio_matches_process_spectrogram(
) -> None:
"""User story: users can call either audio or spectrogram entrypoint."""
# When
audio = api_v2.load_audio(example_audio_files[0])
from_audio = api_v2.process_audio(audio)
spec = api_v2.generate_spectrogram(audio)
from_spec = api_v2.process_spectrogram(spec)
# Then
assert len(from_audio) == len(from_spec)
for det_audio, det_spec in zip(from_audio, from_spec, strict=True):
@ -116,8 +145,10 @@ def test_process_spectrogram_rejects_batched_input(
) -> None:
"""User story: invalid batched input gives a clear error."""
# Given
spec = torch.zeros((2, 1, 128, 64), dtype=torch.float32)
# When/Then
with pytest.raises(ValueError, match="Batched spectrograms not supported"):
api_v2.process_spectrogram(spec)
@ -184,26 +215,34 @@ def test_user_can_read_extracted_features_per_detection(
@pytest.mark.slow
def test_user_can_load_checkpoint_and_finetune(
tmp_path: Path,
example_targets_config: TargetConfig,
example_annotations,
) -> None:
"""User story: load a checkpoint and continue training from it."""
module = build_training_module(model_config=BatDetect2Config().model)
api = BatDetect2API.from_config(
targets_config=example_targets_config,
)
module = build_training_module(
model_config=api.model_config,
class_names=api.targets.class_names,
dimension_names=api.roi_mapper.dimension_names,
)
trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "base.ckpt"
trainer.strategy.connect(module)
trainer.save_checkpoint(checkpoint_path)
config = BatDetect2Config()
config.train.trainer.limit_train_batches = 1
config.train.trainer.limit_val_batches = 1
config.train.trainer.log_every_n_steps = 1
config.train.train_loader.batch_size = 1
config.train.train_loader.augmentations.enabled = False
train_config = api.train_config.model_copy(deep=True)
train_config.trainer.limit_train_batches = 1
train_config.trainer.limit_val_batches = 1
train_config.trainer.log_every_n_steps = 1
train_config.train_loader.batch_size = 1
train_config.train_loader.augmentations.enabled = False
api = BatDetect2API.from_checkpoint(
checkpoint_path,
train_config=config.train,
train_config=train_config,
)
finetune_dir = tmp_path / "finetuned"
@ -222,62 +261,36 @@ def test_user_can_load_checkpoint_and_finetune(
assert checkpoints
def test_user_can_load_checkpoint_with_new_targets(
tmp_path: Path,
sample_targets,
) -> None:
"""User story: start from checkpoint with a new target definition."""
module = build_training_module(model_config=BatDetect2Config().model)
trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "base_transfer.ckpt"
trainer.strategy.connect(module)
trainer.save_checkpoint(checkpoint_path)
source_model, _ = load_model_from_checkpoint(checkpoint_path)
api = BatDetect2API.from_checkpoint(
checkpoint_path,
targets_config=sample_targets.config,
)
source_detector = cast(Detector, source_model.detector)
detector = cast(Detector, api.model.detector)
classifier_head = cast(ClassifierHead, detector.classifier_head)
assert api.targets.config == sample_targets.config # type: ignore
assert detector.num_classes == len(sample_targets.class_names)
assert (
classifier_head.classifier.out_channels
== len(sample_targets.class_names) + 1
)
source_backbone = source_detector.backbone.state_dict()
target_backbone = detector.backbone.state_dict()
assert source_backbone
for key, value in source_backbone.items():
assert key in target_backbone
torch.testing.assert_close(target_backbone[key], value)
def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
example_targets_config: TargetConfig,
tmp_path: Path,
) -> None:
"""User story: same targets config does not rebuild prediction heads."""
module = build_training_module(model_config=BatDetect2Config().model)
# Given
source_api = BatDetect2API.from_config(
targets_config=example_targets_config
)
module = build_training_module(
model_config=source_api.model_config,
class_names=source_api.targets.class_names,
dimension_names=source_api.roi_mapper.dimension_names,
)
trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "same_targets.ckpt"
trainer.strategy.connect(module)
trainer.save_checkpoint(checkpoint_path)
source_model, source_model_config = load_model_from_checkpoint(
checkpoint_path
)
source_model, _ = load_model_from_checkpoint(checkpoint_path)
source_detector = cast(Detector, source_model.detector)
# When
api = BatDetect2API.from_checkpoint(
checkpoint_path,
targets_config=source_model_config.targets,
targets_config=example_targets_config,
)
# Then
detector = cast(Detector, api.model.detector)
for key, value in source_detector.classifier_head.state_dict().items():
@ -302,7 +315,7 @@ def test_user_can_finetune_only_heads(
) -> None:
"""User story: fine-tune only prediction heads."""
api = BatDetect2API.from_config(BatDetect2Config())
api = BatDetect2API.from_config()
finetune_dir = tmp_path / "heads_only"
api.finetune(
@ -348,8 +361,6 @@ def test_user_can_evaluate_small_dataset_and_get_metrics(
assert isinstance(metrics, list)
assert len(metrics) == 1
assert isinstance(metrics[0], dict)
assert len(metrics[0]) > 0
assert isinstance(predictions, list)
assert len(predictions) == 1
@ -450,8 +461,17 @@ def test_detection_threshold_override_changes_spectrogram_results(
spec = api_v2.generate_spectrogram(audio)
default_detections = api_v2.process_spectrogram(spec)
strict_detections = api_v2.process_spectrogram(
spec,
detection_threshold=1.0,
spec, detection_threshold=1.0
)
assert len(strict_detections) <= len(default_detections)
def test_user_can_create_api_with_custom_targets_and_model_metadata_matches(
sample_targets,
) -> None:
"""User story: custom targets define model output names for a new API."""
api = BatDetect2API.from_config(targets_config=sample_targets.config)
assert api.model.class_names == sample_targets.class_names

View File

@ -5,7 +5,6 @@ import numpy as np
import pytest
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config
from batdetect2.outputs import build_output_formatter
from batdetect2.outputs.formats import (
BatDetect2OutputConfig,
@ -18,7 +17,7 @@ from batdetect2.postprocess.types import ClipDetections
def api_v2() -> BatDetect2API:
"""User story: API object manages prediction IO formats."""
return BatDetect2API.from_config(BatDetect2Config())
return BatDetect2API.from_config()
@pytest.fixture

View File

@ -0,0 +1,40 @@
from soundevent import data
from batdetect2.targets import (
TargetClassConfig,
TargetConfig,
build_targets,
check_target_compatibility,
)
def _target_class(name: str) -> TargetClassConfig:
return TargetClassConfig(
name=name,
tags=[data.Tag(key="class", value=name)],
)
def test_check_target_compatibility_accepts_superset_targets() -> None:
config = TargetConfig(
classification_targets=[
_target_class("pip35"),
_target_class("myo"),
_target_class("extra"),
]
)
targets = build_targets(config)
assert check_target_compatibility(targets, ["pip35", "myo"])
def test_check_target_compatibility_rejects_missing_model_classes() -> None:
config = TargetConfig(
classification_targets=[
_target_class("pip35"),
_target_class("myo"),
]
)
targets = build_targets(config)
assert not check_target_compatibility(targets, ["pip35", "nyc"])

View File

@ -10,9 +10,8 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio.types import AudioLoader
from batdetect2.config import BatDetect2Config
from batdetect2.models import ModelConfig, build_model
from batdetect2.targets.classes import TargetClassConfig
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
from batdetect2.train import (
TrainingConfig,
TrainingModule,
@ -24,11 +23,21 @@ from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
from batdetect2.train.train import build_training_module
def build_default_module(config: BatDetect2Config | None = None):
config = config or BatDetect2Config()
def build_default_module(
target_config: TargetConfig | None = None,
model_config: ModelConfig | None = None,
train_config: TrainingConfig | None = None,
):
target_config = target_config or TargetConfig()
model_config = model_config or ModelConfig()
train_config = train_config or TrainingConfig()
targets = build_targets(target_config)
roi_mapper = build_roi_mapping(target_config.roi)
return build_training_module(
model_config=config.model,
train_config=config.train,
model_config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config,
)
@ -72,8 +81,13 @@ def test_load_model_from_checkpoint_returns_model_and_config(
input_model_config.model_dump(mode="json")
)
train_config = TrainingConfig()
targets_config = TargetConfig()
targets = build_targets(targets_config)
roi_mapper = build_roi_mapping(targets_config.roi)
module = build_training_module(
model_config=input_model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config,
)
trainer = L.Trainer()
@ -87,6 +101,8 @@ def test_load_model_from_checkpoint_returns_model_and_config(
assert loaded_model_config.model_dump(
mode="json"
) == expected_model_config.model_dump(mode="json")
assert model.class_names == targets.class_names
assert model.dimension_names == roi_mapper.dimension_names
recovered = TrainingModule.load_from_checkpoint(path)
assert recovered.train_config.model_dump(
@ -100,6 +116,9 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
model_config.model_dump(mode="json")
)
train_config = TrainingConfig()
targets_config = TargetConfig()
targets = build_targets(targets_config)
roi_mapper = build_roi_mapping(targets_config.roi)
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123)
train_config.trainer.max_epochs = 3
@ -107,6 +126,8 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
module = build_training_module(
model_config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config,
)
trainer = L.Trainer()
@ -131,11 +152,16 @@ def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
model_config.model_dump(mode="json")
)
train_config = TrainingConfig()
targets_config = TargetConfig()
targets = build_targets(targets_config)
roi_mapper = build_roi_mapping(targets_config.roi)
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321)
module = build_training_module(
model_config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config,
)
@ -189,19 +215,26 @@ def test_train_smoke_produces_loadable_checkpoint(
example_annotations: list[data.ClipAnnotation],
sample_audio_loader: AudioLoader,
):
config = BatDetect2Config()
config.train.trainer.limit_train_batches = 1
config.train.trainer.limit_val_batches = 1
config.train.trainer.log_every_n_steps = 1
config.train.train_loader.batch_size = 1
config.train.train_loader.augmentations.enabled = False
# Given
train_config = TrainingConfig.model_validate(
{
"trainer": {
"limit_train_batches": 1,
"limit_val_batches": 1,
"log_every_n_steps": 1,
},
"train_loader": {
"batch_size": 1,
"augmentations": {"enabled": False},
},
}
)
# When
run_train(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
train_config=config.train,
model_config=config.model,
audio_config=config.audio,
train_config=train_config,
num_epochs=1,
train_workers=0,
val_workers=0,
@ -209,18 +242,11 @@ def test_train_smoke_produces_loadable_checkpoint(
seed=0,
)
# Then
checkpoints = list(tmp_path.rglob("*.ckpt"))
assert checkpoints
model, model_config = load_model_from_checkpoint(checkpoints[0])
assert model_config.samplerate == config.model.samplerate
assert model_config.architecture.name == config.model.architecture.name
assert model_config.preprocess.model_dump(
mode="json"
) == config.model.preprocess.model_dump(mode="json")
assert model_config.postprocess.model_dump(
mode="json"
) == config.model.postprocess.model_dump(mode="json")
wav = torch.tensor(
sample_audio_loader.load_clip(example_annotations[0].clip)
@ -230,10 +256,18 @@ def test_train_smoke_produces_loadable_checkpoint(
def test_build_training_module_uses_provided_model() -> None:
model = build_model(ModelConfig())
targets = build_targets(TargetConfig())
roi_mapper = build_roi_mapping(TargetConfig().roi)
model = build_model(
ModelConfig(),
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
)
module = build_training_module(
model_config=ModelConfig(),
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=TrainingConfig(),
model=model,
)
@ -244,15 +278,18 @@ def test_build_training_module_uses_provided_model() -> None:
def test_run_train_rejects_incompatible_model_config(
example_annotations: list[data.ClipAnnotation],
) -> None:
model = build_model(ModelConfig())
# Given
targets_config = TargetConfig()
targets = build_targets(targets_config)
roi_mapper = build_roi_mapping(targets_config.roi)
incompatible_config = ModelConfig()
incompatible_config.targets.classification_targets.append(
TargetClassConfig(
name="dummy_class",
tags=[data.Tag(key="class", value="Dummy class")],
)
incompatible_model = build_model(
incompatible_config,
class_names=targets.class_names,
dimension_names=[*roi_mapper.dimension_names, "extra_dim"],
)
# When/Then
with pytest.raises(
ValueError,
match="Provided model is incompatible with model_config",
@ -260,7 +297,10 @@ def test_run_train_rejects_incompatible_model_config(
run_train(
train_annotations=example_annotations[:1],
val_annotations=None,
model=model,
model=incompatible_model,
targets=targets,
roi_mapper=roi_mapper,
model_config=incompatible_config,
targets_config=targets_config,
train_config=TrainingConfig(),
)