From 57236fc82a701cc286546b60e41b9b39974dcaac Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 4 May 2026 21:18:17 +0100 Subject: [PATCH] refactor: decouple model metadata from target configs --- src/batdetect2/api_v2.py | 118 +++++++++++++-------- src/batdetect2/config.py | 2 + src/batdetect2/evaluate/evaluate.py | 6 +- src/batdetect2/inference/batch.py | 12 ++- src/batdetect2/inference/lightning.py | 20 +++- src/batdetect2/models/__init__.py | 90 ++++++---------- src/batdetect2/targets/__init__.py | 11 +- src/batdetect2/targets/config.py | 19 ++++ src/batdetect2/targets/utils.py | 29 ++++++ src/batdetect2/train/callbacks.py | 9 +- src/batdetect2/train/lightning.py | 33 ++++-- src/batdetect2/train/train.py | 63 ++++++++---- tests/conftest.py | 21 +++- tests/test_api_v2/test_api_v2.py | 142 +++++++++++++++----------- tests/test_api_v2/test_outputs_io.py | 3 +- tests/test_targets/test_utils.py | 40 ++++++++ tests/test_train/test_lightning.py | 102 ++++++++++++------ 17 files changed, 483 insertions(+), 237 deletions(-) create mode 100644 src/batdetect2/targets/utils.py create mode 100644 tests/test_targets/test_utils.py diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index e989f72..c6a5492 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -12,7 +12,6 @@ if TYPE_CHECKING: import torch from batdetect2.audio import AudioConfig, AudioLoader - from batdetect2.config import BatDetect2Config from batdetect2.data import Dataset from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol from batdetect2.inference import InferenceConfig @@ -483,46 +482,70 @@ class BatDetect2API: @classmethod def from_config( cls, - config: BatDetect2Config, + model_config: ModelConfig | None = None, + targets_config: TargetConfig | None = None, + audio_config: AudioConfig | None = None, + train_config: TrainingConfig | None = None, + evaluation_config: EvaluationConfig | None = None, + inference_config: InferenceConfig | None = None, + outputs_config: OutputsConfig | None = None, + logging_config: AppLoggingConfig | None = None, ) -> "BatDetect2API": - from batdetect2.audio import build_audio_loader - from batdetect2.evaluate import build_evaluator - from batdetect2.models import build_model + from batdetect2.audio import AudioConfig, build_audio_loader + from batdetect2.evaluate import EvaluationConfig, build_evaluator + from batdetect2.inference import InferenceConfig + from batdetect2.logging import AppLoggingConfig + from batdetect2.models import ModelConfig, build_model from batdetect2.outputs import ( + OutputsConfig, build_output_formatter, build_output_transform, ) from batdetect2.postprocess import build_postprocessor from batdetect2.preprocess import build_preprocessor - from batdetect2.targets import build_roi_mapping, build_targets + from batdetect2.targets import ( + TargetConfig, + build_roi_mapping, + build_targets, + ) + from batdetect2.train import TrainingConfig - targets = build_targets(config=config.model.targets) - roi_mapper = build_roi_mapping(config=config.model.targets.roi) + model_config = model_config or ModelConfig() + targets_config = targets_config or TargetConfig() + audio_config = audio_config or AudioConfig() + train_config = train_config or TrainingConfig() + evaluation_config = evaluation_config or EvaluationConfig() + inference_config = inference_config or InferenceConfig() + outputs_config = outputs_config or OutputsConfig() + logging_config = logging_config or AppLoggingConfig() - audio_loader = build_audio_loader(config=config.audio) + targets = build_targets(config=targets_config) + roi_mapper = build_roi_mapping(config=targets_config.roi) + + audio_loader = build_audio_loader(config=audio_config) preprocessor = build_preprocessor( input_samplerate=audio_loader.samplerate, - config=config.model.preprocess, + config=model_config.preprocess, ) postprocessor = build_postprocessor( preprocessor, - config=config.model.postprocess, + config=model_config.postprocess, ) formatter = build_output_formatter( targets, - config=config.outputs.format, + config=outputs_config.format, ) output_transform = build_output_transform( - config=config.outputs.transform, + config=outputs_config.transform, targets=targets, roi_mapper=roi_mapper, ) evaluator = build_evaluator( - config=config.evaluation, + config=evaluation_config, targets=targets, roi_mapper=roi_mapper, transform=output_transform, @@ -531,27 +554,27 @@ class BatDetect2API: # NOTE: Build separate instances of preprocessor and postprocessor # to avoid device mismatch errors model = build_model( - config=config.model, - targets=targets, - roi_mapper=roi_mapper, + config=model_config, + class_names=targets.class_names, + dimension_names=roi_mapper.dimension_names, preprocessor=build_preprocessor( input_samplerate=audio_loader.samplerate, - config=config.model.preprocess, + config=model_config.preprocess, ), postprocessor=build_postprocessor( preprocessor, - config=config.model.postprocess, + config=model_config.postprocess, ), ) return cls( - model_config=config.model, - audio_config=config.audio, - train_config=config.train, - evaluation_config=config.evaluation, - inference_config=config.inference, - outputs_config=config.outputs, - logging_config=config.logging, + model_config=model_config, + audio_config=audio_config, + train_config=train_config, + evaluation_config=evaluation_config, + inference_config=inference_config, + outputs_config=outputs_config, + logging_config=logging_config, targets=targets, roi_mapper=roi_mapper, audio_loader=audio_loader, @@ -579,7 +602,6 @@ class BatDetect2API: from batdetect2.evaluate import EvaluationConfig, build_evaluator from batdetect2.inference import InferenceConfig from batdetect2.logging import AppLoggingConfig - from batdetect2.models import build_model_with_new_targets from batdetect2.outputs import ( OutputsConfig, build_output_formatter, @@ -587,8 +609,16 @@ class BatDetect2API: ) from batdetect2.postprocess import build_postprocessor from batdetect2.preprocess import build_preprocessor - from batdetect2.targets import build_roi_mapping, build_targets - from batdetect2.train import TrainingConfig, load_model_from_checkpoint + from batdetect2.targets import ( + build_default_target_config, + build_roi_mapping, + build_targets, + check_target_compatibility, + ) + from batdetect2.train import ( + TrainingConfig, + load_model_from_checkpoint, + ) model, model_config = load_model_from_checkpoint(path) @@ -600,24 +630,24 @@ class BatDetect2API: inference_config = inference_config or InferenceConfig() outputs_config = outputs_config or OutputsConfig() logging_config = logging_config or AppLoggingConfig() + targets_config = targets_config or build_default_target_config( + class_names=model.class_names + ) - if ( - targets_config is not None - and targets_config != model_config.targets - ): - targets = build_targets(config=targets_config) - roi_mapper = build_roi_mapping(config=targets_config.roi) - model = build_model_with_new_targets( - model=model, - targets=targets, - roi_mapper=roi_mapper, - ) - model_config = model_config.model_copy( - update={"targets": targets_config} + targets = build_targets(config=targets_config) + roi_mapper = build_roi_mapping(config=targets_config.roi) + + if not check_target_compatibility(targets, model.class_names): + raise ValueError( + "Provided targets_config is incompatible with the " + "checkpoint model: missing one or more model classes." ) - targets = build_targets(config=model_config.targets) - roi_mapper = build_roi_mapping(config=model_config.targets.roi) + if model.dimension_names != roi_mapper.dimension_names: + raise ValueError( + "Provided targets_config is incompatible with the " + "checkpoint model: mismatched dimension names." + ) audio_loader = build_audio_loader(config=audio_config) diff --git a/src/batdetect2/config.py b/src/batdetect2/config.py index dddd48e..88b23ee 100644 --- a/src/batdetect2/config.py +++ b/src/batdetect2/config.py @@ -12,6 +12,7 @@ from batdetect2.inference.config import InferenceConfig from batdetect2.logging import AppLoggingConfig from batdetect2.models import ModelConfig from batdetect2.outputs import OutputsConfig +from batdetect2.targets import TargetConfig from batdetect2.train.config import TrainingConfig __all__ = ["BatDetect2Config"] @@ -25,6 +26,7 @@ class BatDetect2Config(BaseConfig): default_factory=get_default_eval_config ) model: ModelConfig = Field(default_factory=ModelConfig) + targets: TargetConfig = Field(default_factory=TargetConfig) audio: AudioConfig = Field(default_factory=AudioConfig) inference: InferenceConfig = Field(default_factory=InferenceConfig) outputs: OutputsConfig = Field(default_factory=OutputsConfig) diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 0a0ac5b..5eecb4d 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -24,8 +24,8 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations" def run_evaluate( model: Model, test_annotations: Sequence[data.ClipAnnotation], - targets: TargetProtocol | None = None, - roi_mapper: ROIMapperProtocol | None = None, + targets: TargetProtocol, + roi_mapper: ROIMapperProtocol, audio_loader: AudioLoader | None = None, preprocessor: PreprocessorProtocol | None = None, audio_config: AudioConfig | None = None, @@ -46,8 +46,6 @@ def run_evaluate( audio_loader = audio_loader or build_audio_loader(config=audio_config) preprocessor = preprocessor or model.preprocessor - targets = targets or model.targets - roi_mapper = roi_mapper or model.roi_mapper loader = build_test_loader( test_annotations, diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index 2c43bde..4be9dab 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -45,8 +45,16 @@ def run_batch_inference( audio_loader = audio_loader or build_audio_loader(config=audio_config) preprocessor = preprocessor or model.preprocessor - targets = targets or model.targets - roi_mapper = roi_mapper or model.roi_mapper + + if targets is None: + raise ValueError( + "targets must be provided when running batch inference." + ) + + if roi_mapper is None: + raise ValueError( + "roi_mapper must be provided when running batch inference." + ) output_transform = output_transform or build_output_transform( config=output_config.transform, diff --git a/src/batdetect2/inference/lightning.py b/src/batdetect2/inference/lightning.py index 7ae010e..02cba66 100644 --- a/src/batdetect2/inference/lightning.py +++ b/src/batdetect2/inference/lightning.py @@ -7,21 +7,37 @@ from batdetect2.inference.dataset import DatasetItem, InferenceDataset from batdetect2.models import Model from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.postprocess.types import ClipDetections +from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol class InferenceModule(LightningModule): def __init__( self, model: Model, + targets: TargetProtocol | None = None, + roi_mapper: ROIMapperProtocol | None = None, output_transform: OutputTransformProtocol | None = None, detection_threshold: float | None = None, ): super().__init__() self.model = model self.detection_threshold = detection_threshold + + if output_transform is None and targets is None: + raise ValueError( + "targets must be provided when building inference output " + "transforms." + ) + + if output_transform is None and roi_mapper is None: + raise ValueError( + "roi_mapper must be provided when building inference output " + "transforms." + ) + self.output_transform = output_transform or build_output_transform( - targets=model.targets, - roi_mapper=model.roi_mapper, + targets=targets, + roi_mapper=roi_mapper, ) def predict_step( diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index 9b53004..f25a221 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -26,11 +26,8 @@ The primary entry point for building a full, ready-to-use BatDetect2 model is the ``build_model`` factory function exported from this module. """ -from typing import Literal - import torch from pydantic import Field -from soundevent.data import PathLike from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ from batdetect2.core.configs import BaseConfig @@ -73,7 +70,6 @@ from batdetect2.postprocess.types import ( ) from batdetect2.preprocess.config import PreprocessingConfig from batdetect2.preprocess.types import PreprocessorProtocol -from batdetect2.targets.config import TargetConfig from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol __all__ = [ @@ -131,10 +127,6 @@ class ModelConfig(BaseConfig): Parameters for converting raw model outputs into detections (NMS kernel, thresholds, top-k limit). Defaults to ``PostprocessConfig()``. - targets : TargetConfig - Detection and classification target definitions (class list, - detection target, bounding-box mapper). Defaults to - ``TargetConfig()``. """ samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0) @@ -143,23 +135,6 @@ class ModelConfig(BaseConfig): default_factory=PreprocessingConfig ) postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) - targets: TargetConfig = Field(default_factory=TargetConfig) - - @classmethod - def load( - cls, - path: PathLike, - field: str | None = None, - extra: Literal["ignore", "allow", "forbid"] | None = None, - strict: bool | None = None, - targets: TargetConfig | None = None, - ) -> "ModelConfig": - config = super().load(path, field, extra, strict) - - if targets is None: - return config - - return config.model_copy(update={"targets": targets}) class Model(torch.nn.Module): @@ -183,33 +158,32 @@ class Model(torch.nn.Module): postprocessor : PostprocessorProtocol Converts the raw ``ModelOutput`` from ``detector`` into a list of per-clip detection tensors. - targets : TargetProtocol - Describes the set of target classes; used when building heads and - during training target construction. - roi_mapper : ROIMapperProtocol - Maps geometries to target-size channels and back. + class_names : list[str] + Class names corresponding to the model classification outputs. + dimension_names : list[str] + Size-dimension names corresponding to the model size outputs. """ detector: DetectionModel preprocessor: PreprocessorProtocol postprocessor: PostprocessorProtocol - targets: TargetProtocol - roi_mapper: ROIMapperProtocol + class_names: list[str] + dimension_names: list[str] def __init__( self, detector: DetectionModel, preprocessor: PreprocessorProtocol, postprocessor: PostprocessorProtocol, - targets: TargetProtocol, - roi_mapper: ROIMapperProtocol, + class_names: list[str], + dimension_names: list[str], ): super().__init__() self.detector = detector self.preprocessor = preprocessor self.postprocessor = postprocessor - self.targets = targets - self.roi_mapper = roi_mapper + self.class_names = class_names + self.dimension_names = dimension_names def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]: """Run the full detection pipeline on a waveform tensor. @@ -238,8 +212,8 @@ class Model(torch.nn.Module): def build_model( config: ModelConfig | None = None, - targets: TargetProtocol | None = None, - roi_mapper: ROIMapperProtocol | None = None, + class_names: list[str] | None = None, + dimension_names: list[str] | None = None, preprocessor: PreprocessorProtocol | None = None, postprocessor: PostprocessorProtocol | None = None, ) -> Model: @@ -254,11 +228,13 @@ def build_model( ---------- config : ModelConfig, optional Full model configuration (samplerate, architecture, preprocessing, - postprocessing, targets). Defaults to ``ModelConfig()`` if not - provided. - targets : TargetProtocol, optional - Pre-built targets object. If given, overrides - ``config.targets``. + postprocessing). Defaults to ``ModelConfig()`` if not provided. + class_names : list[str], optional + Class names used to size the classifier head. Required when building + a new model. + dimension_names : list[str], optional + Dimension names used to size the bbox head. Required when building a + new model. preprocessor : PreprocessorProtocol, optional Pre-built preprocessor. If given, overrides ``config.preprocess`` and ``config.samplerate`` for the @@ -278,19 +254,17 @@ def build_model( """ from batdetect2.postprocess import build_postprocessor from batdetect2.preprocess import build_preprocessor - from batdetect2.targets import build_roi_mapping, build_targets config = config or ModelConfig() - targets = targets or build_targets(config=config.targets) - targets_config = getattr(targets, "config", None) - roi_config = ( - targets_config.roi - if isinstance(targets_config, TargetConfig) - else config.targets.roi - ) + if class_names is None: + raise ValueError("class_names must be provided when building a model.") + + if dimension_names is None: + raise ValueError( + "dimension_names must be provided when building a model." + ) - roi_mapper = roi_mapper or build_roi_mapping(config=roi_config) preprocessor = preprocessor or build_preprocessor( config=config.preprocess, input_samplerate=config.samplerate, @@ -300,16 +274,16 @@ def build_model( config=config.postprocess, ) detector = build_detector( - num_classes=len(targets.class_names), - num_sizes=len(roi_mapper.dimension_names), + num_classes=len(class_names), + num_sizes=len(dimension_names), config=config.architecture, ) return Model( detector=detector, postprocessor=postprocessor, preprocessor=preprocessor, - targets=targets, - roi_mapper=roi_mapper, + class_names=class_names, + dimension_names=dimension_names, ) @@ -329,6 +303,6 @@ def build_model_with_new_targets( detector=detector, postprocessor=model.postprocessor, preprocessor=model.preprocessor, - targets=targets, - roi_mapper=roi_mapper, + class_names=targets.class_names, + dimension_names=roi_mapper.dimension_names, ) diff --git a/src/batdetect2/targets/__init__.py b/src/batdetect2/targets/__init__.py index 5bba3d1..f4fc12d 100644 --- a/src/batdetect2/targets/__init__.py +++ b/src/batdetect2/targets/__init__.py @@ -6,7 +6,7 @@ from batdetect2.targets.classes import ( build_sound_event_encoder, get_class_names_from_config, ) -from batdetect2.targets.config import TargetConfig +from batdetect2.targets.config import TargetConfig, build_default_target_config from batdetect2.targets.rois import ( AnchorBBoxMapperConfig, ROIMapperConfig, @@ -36,13 +36,14 @@ from batdetect2.targets.types import ( SoundEventFilter, TargetProtocol, ) +from batdetect2.targets.utils import check_target_compatibility __all__ = [ "AnchorBBoxMapperConfig", "Position", - "ROIMappingConfig", - "ROIMapperProtocol", "ROIMapperConfig", + "ROIMapperProtocol", + "ROIMappingConfig", "ROITargetMapper", "Size", "SoundEventDecoder", @@ -52,12 +53,14 @@ __all__ = [ "TargetConfig", "TargetProtocol", "Targets", - "build_roi_mapping", + "build_default_target_config", "build_roi_mapper", + "build_roi_mapping", "build_sound_event_decoder", "build_sound_event_encoder", "build_targets", "call_type", + "check_target_compatibility", "data_source", "generic_class", "get_class_names_from_config", diff --git a/src/batdetect2/targets/config.py b/src/batdetect2/targets/config.py index aa7cca9..b254dec 100644 --- a/src/batdetect2/targets/config.py +++ b/src/batdetect2/targets/config.py @@ -2,6 +2,7 @@ from collections import Counter from typing import List from pydantic import Field, field_validator +from soundevent import data from batdetect2.core.configs import BaseConfig from batdetect2.targets.classes import ( @@ -13,6 +14,7 @@ from batdetect2.targets.rois import ROIMappingConfig __all__ = [ "TargetConfig", + "build_default_target_config", ] @@ -42,3 +44,20 @@ class TargetConfig(BaseConfig): f"{', '.join(duplicates)}" ) return v + + +def build_default_target_config(class_names: list[str]) -> TargetConfig: + """Build a default target configuration object.""" + return TargetConfig( + detection_target=DEFAULT_DETECTION_CLASS, + classification_targets=[ + TargetClassConfig( + name=class_name, + tags=[ + data.Tag(key="class", value=class_name), + ], + ) + for class_name in class_names + ], + roi=ROIMappingConfig(), + ) diff --git a/src/batdetect2/targets/utils.py b/src/batdetect2/targets/utils.py new file mode 100644 index 0000000..679a6c6 --- /dev/null +++ b/src/batdetect2/targets/utils.py @@ -0,0 +1,29 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from batdetect2.targets.types import TargetProtocol + + +def check_target_compatibility( + targets: "TargetProtocol", + class_names: list[str], +) -> bool: + """Check if a target definition can decode a model's outputs. + + Parameters + ---------- + targets : TargetProtocol + Target definition that would be used with the model outputs. + class_names : list[str] + Class names produced by the model checkpoint. + + Returns + ------- + bool + True when every model class name exists in the provided targets, + False otherwise. + """ + target_class_names = set(targets.class_names) + model_class_names = set(class_names) + + return model_class_names.issubset(target_class_names) diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index e4a9881..0f9a062 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -10,6 +10,7 @@ from batdetect2.logging import get_image_logger from batdetect2.models.types import ModelOutput from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.postprocess.types import ClipDetections +from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol from batdetect2.train.dataset import ValidationDataset from batdetect2.train.lightning import TrainingModule from batdetect2.train.types import TrainExample @@ -19,11 +20,15 @@ class ValidationMetrics(Callback): def __init__( self, evaluator: EvaluatorProtocol, + targets: TargetProtocol, + roi_mapper: ROIMapperProtocol, output_transform: OutputTransformProtocol | None = None, ): super().__init__() self.evaluator = evaluator + self.targets = targets + self.roi_mapper = roi_mapper self.output_transform = output_transform self._clip_annotations: List[data.ClipAnnotation] = [] @@ -93,8 +98,8 @@ class ValidationMetrics(Callback): model = pl_module.model if self.output_transform is None: self.output_transform = build_output_transform( - targets=model.targets, - roi_mapper=model.roi_mapper, + targets=self.targets, + roi_mapper=self.roi_mapper, ) output_transform = self.output_transform diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 97e1b81..280dadd 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -9,9 +9,7 @@ from batdetect2.train.optimizers import build_optimizer from batdetect2.train.schedulers import build_scheduler from batdetect2.train.types import LossProtocol, TrainExample -__all__ = [ - "TrainingModule", -] +__all__ = ["TrainingModule"] class TrainingModule(L.LightningModule): @@ -21,6 +19,8 @@ class TrainingModule(L.LightningModule): def __init__( self, model_config: dict | None = None, + class_names: list[str] | None = None, + dimension_names: list[str] | None = None, train_config: dict | None = None, loss: LossProtocol | None = None, model: Model | None = None, @@ -30,13 +30,31 @@ class TrainingModule(L.LightningModule): self.save_hyperparameters(ignore=["model", "loss"], logger=False) self.model_config = ModelConfig.model_validate(model_config or {}) + self.class_names = list(class_names or []) + self.dimension_names = list(dimension_names or []) self.train_config = TrainingConfig.model_validate(train_config or {}) if loss is None: loss = build_loss(config=self.train_config.loss) if model is None: - model = build_model(config=self.model_config) + if not self.class_names: + raise ValueError( + "class_names must be provided when rebuilding a training " + "module without a model." + ) + + if not self.dimension_names: + raise ValueError( + "dimension_names must be provided when rebuilding a " + "training module without a model." + ) + + model = build_model( + config=self.model_config, + class_names=self.class_names, + dimension_names=self.dimension_names, + ) self.loss = loss self.model = model @@ -110,8 +128,7 @@ def load_model_from_checkpoint( ------- tuple[Model, ModelConfig] The restored ``Model`` instance and the ``ModelConfig`` that - describes its architecture, preprocessing, postprocessing, and - targets. + describes its architecture, preprocessing, and postprocessing. """ module = TrainingModule.load_from_checkpoint(path) # type: ignore return module.model, module.model_config @@ -119,6 +136,8 @@ def load_model_from_checkpoint( def build_training_module( model_config: ModelConfig | None = None, + class_names: list[str] | None = None, + dimension_names: list[str] | None = None, train_config: TrainingConfig | None = None, model: Model | None = None, ) -> TrainingModule: @@ -130,6 +149,8 @@ def build_training_module( return TrainingModule( model_config=model_config.model_dump(mode="json"), + class_names=class_names, + dimension_names=dimension_names, train_config=train_config.model_dump(mode="json"), model=model, ) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 83d8060..fb4743f 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -17,6 +17,7 @@ from batdetect2.models import Model, ModelConfig, build_model from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor from batdetect2.targets import ( ROIMapperProtocol, + TargetConfig, TargetProtocol, build_roi_mapping, build_targets, @@ -46,6 +47,7 @@ def run_train( labeller: Optional["ClipLabeller"] = None, audio_config: Optional[AudioConfig] = None, model_config: Optional[ModelConfig] = None, + targets_config: TargetConfig | None = None, train_config: Optional[TrainingConfig] = None, logger_config: LoggerConfig | None = None, trainer: Trainer | None = None, @@ -62,23 +64,34 @@ def run_train( seed_everything(seed) model_config = model_config or ModelConfig() + targets_config = targets_config or TargetConfig() audio_config = audio_config or AudioConfig() train_config = train_config or TrainingConfig() if model is not None: - _validate_model_compatibility(model=model, model_config=model_config) + if targets is None: + raise ValueError( + "targets must be provided when training with an existing " + "model." + ) + + if roi_mapper is None: + raise ValueError( + "roi_mapper must be provided when training with an existing " + "model." + ) + + targets = targets or build_targets(config=targets_config) + + roi_mapper = roi_mapper or build_roi_mapping(config=targets_config.roi) if model is not None: - targets = targets or model.targets - - if roi_mapper is None and targets is model.targets: - roi_mapper = model.roi_mapper - - targets = targets or build_targets(config=model_config.targets) - - roi_mapper = roi_mapper or build_roi_mapping( - config=model_config.targets.roi - ) + _validate_model_compatibility( + model=model, + model_config=model_config, + class_names=targets.class_names, + dimension_names=roi_mapper.dimension_names, + ) audio_loader = audio_loader or build_audio_loader(config=audio_config) @@ -119,18 +132,24 @@ def run_train( module = build_training_module( model_config=model_config, + class_names=targets.class_names, + dimension_names=roi_mapper.dimension_names, train_config=train_config, model=model, ) + evaluator = build_evaluator( + train_config.validation, + targets=targets, + roi_mapper=roi_mapper, + ) + trainer = trainer or build_trainer( train_config, logger_config=logger_config, - evaluator=build_evaluator( - train_config.validation, - targets=targets, - roi_mapper=roi_mapper, - ), + evaluator=evaluator, + targets=targets, + roi_mapper=roi_mapper, checkpoint_dir=checkpoint_dir, num_epochs=num_epochs, log_dir=log_dir, @@ -152,8 +171,14 @@ def run_train( def _validate_model_compatibility( model: Model, model_config: ModelConfig, + class_names: list[str], + dimension_names: list[str], ) -> None: - reference_model = build_model(config=model_config) + reference_model = build_model( + config=model_config, + class_names=class_names, + dimension_names=dimension_names, + ) expected_shapes = { key: tuple(value.shape) @@ -196,6 +221,8 @@ def build_trainer( config: TrainingConfig, logger_config: LoggerConfig | None, evaluator: "EvaluatorProtocol", + targets: "TargetProtocol", + roi_mapper: "ROIMapperProtocol", checkpoint_dir: Path | None = None, log_dir: Path | None = None, experiment_name: str | None = None, @@ -234,6 +261,6 @@ def build_trainer( experiment_name=experiment_name, run_name=run_name, ), - ValidationMetrics(evaluator), + ValidationMetrics(evaluator, targets, roi_mapper), ], ) diff --git a/tests/conftest.py b/tests/conftest.py index f3d864d..9a65b91 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,13 +13,14 @@ from soundevent import data, terms from batdetect2.audio import build_audio_loader from batdetect2.audio.clips import build_clipper from batdetect2.audio.types import AudioLoader, ClipperProtocol -from batdetect2.config import BatDetect2Config from batdetect2.data import DatasetConfig, load_dataset from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets import ( + ROIMapperProtocol, TargetConfig, + build_roi_mapping, build_targets, call_type, ) @@ -404,6 +405,13 @@ def sample_targets( return build_targets(sample_target_config) +@pytest.fixture +def sample_roi_mapper( + sample_target_config: TargetConfig, +) -> ROIMapperProtocol: + return build_roi_mapping(sample_target_config.roi) + + @pytest.fixture def sample_labeller( sample_targets: TargetProtocol, @@ -458,8 +466,15 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]: @pytest.fixture -def tiny_checkpoint_path(tmp_path: Path) -> Path: - module = build_training_module(model_config=BatDetect2Config().model) +def tiny_checkpoint_path( + sample_targets: TargetProtocol, + sample_roi_mapper: ROIMapperProtocol, + tmp_path: Path, +) -> Path: + module = build_training_module( + class_names=sample_targets.class_names, + dimension_names=sample_roi_mapper.dimension_names, + ) trainer = L.Trainer(enable_checkpointing=False, logger=False) checkpoint_path = tmp_path / "model.ckpt" trainer.strategy.connect(module) diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py index 42ad1ec..dd29a24 100644 --- a/tests/test_api_v2/test_api_v2.py +++ b/tests/test_api_v2/test_api_v2.py @@ -8,20 +8,43 @@ import torch from soundevent.geometry import compute_bounds from batdetect2.api_v2 import BatDetect2API -from batdetect2.config import BatDetect2Config +from batdetect2.inference import InferenceConfig from batdetect2.models.detectors import Detector -from batdetect2.models.heads import ClassifierHead -from batdetect2.train import load_model_from_checkpoint +from batdetect2.targets import TargetConfig +from batdetect2.train import TrainingConfig, load_model_from_checkpoint from batdetect2.train.lightning import build_training_module @pytest.fixture -def api_v2() -> BatDetect2API: +def train_config() -> TrainingConfig: + """Train config with a small batch size for testing.""" + return TrainingConfig.model_validate({"train_loader": {"batch_size": 2}}) + + +@pytest.fixture +def inference_config() -> InferenceConfig: + """Inference config with a small batch size for testing.""" + return InferenceConfig.model_validate({"loader": {"batch_size": 2}}) + + +@pytest.fixture +def example_targets_config(example_data_dir: Path) -> TargetConfig: + return TargetConfig.load(example_data_dir / "targets.yaml") + + +@pytest.fixture +def api_v2( + train_config: TrainingConfig, + inference_config: InferenceConfig, +) -> BatDetect2API: """User story: users can create a ready-to-use API from config.""" - config = BatDetect2Config() - config.inference.loader.batch_size = 2 - return BatDetect2API.from_config(config) + api = BatDetect2API.from_config( + train_config=train_config, + inference_config=inference_config, + ) + assert api.inference_config.loader.batch_size == 2 + return api def test_process_file_returns_recording_level_predictions( @@ -30,8 +53,10 @@ def test_process_file_returns_recording_level_predictions( ) -> None: """User story: process a file and get detections in recording time.""" + # When prediction = api_v2.process_file(example_audio_files[0]) + # Then assert prediction.clip.recording.path == example_audio_files[0] assert prediction.clip.start_time == 0 assert prediction.clip.end_time == prediction.clip.recording.duration @@ -53,9 +78,11 @@ def test_process_files_is_batch_size_invariant( ) -> None: """User story: changing batch size should not change predictions.""" + # When preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1) preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3) + # Then assert len(preds_batch_1) == len(preds_batch_3) by_key_1 = { @@ -91,12 +118,14 @@ def test_process_audio_matches_process_spectrogram( ) -> None: """User story: users can call either audio or spectrogram entrypoint.""" + # When audio = api_v2.load_audio(example_audio_files[0]) from_audio = api_v2.process_audio(audio) spec = api_v2.generate_spectrogram(audio) from_spec = api_v2.process_spectrogram(spec) + # Then assert len(from_audio) == len(from_spec) for det_audio, det_spec in zip(from_audio, from_spec, strict=True): @@ -116,8 +145,10 @@ def test_process_spectrogram_rejects_batched_input( ) -> None: """User story: invalid batched input gives a clear error.""" + # Given spec = torch.zeros((2, 1, 128, 64), dtype=torch.float32) + # When/Then with pytest.raises(ValueError, match="Batched spectrograms not supported"): api_v2.process_spectrogram(spec) @@ -184,26 +215,34 @@ def test_user_can_read_extracted_features_per_detection( @pytest.mark.slow def test_user_can_load_checkpoint_and_finetune( tmp_path: Path, + example_targets_config: TargetConfig, example_annotations, ) -> None: """User story: load a checkpoint and continue training from it.""" - module = build_training_module(model_config=BatDetect2Config().model) + api = BatDetect2API.from_config( + targets_config=example_targets_config, + ) + module = build_training_module( + model_config=api.model_config, + class_names=api.targets.class_names, + dimension_names=api.roi_mapper.dimension_names, + ) trainer = L.Trainer(enable_checkpointing=False, logger=False) checkpoint_path = tmp_path / "base.ckpt" trainer.strategy.connect(module) trainer.save_checkpoint(checkpoint_path) - config = BatDetect2Config() - config.train.trainer.limit_train_batches = 1 - config.train.trainer.limit_val_batches = 1 - config.train.trainer.log_every_n_steps = 1 - config.train.train_loader.batch_size = 1 - config.train.train_loader.augmentations.enabled = False + train_config = api.train_config.model_copy(deep=True) + train_config.trainer.limit_train_batches = 1 + train_config.trainer.limit_val_batches = 1 + train_config.trainer.log_every_n_steps = 1 + train_config.train_loader.batch_size = 1 + train_config.train_loader.augmentations.enabled = False api = BatDetect2API.from_checkpoint( checkpoint_path, - train_config=config.train, + train_config=train_config, ) finetune_dir = tmp_path / "finetuned" @@ -222,62 +261,36 @@ def test_user_can_load_checkpoint_and_finetune( assert checkpoints -def test_user_can_load_checkpoint_with_new_targets( - tmp_path: Path, - sample_targets, -) -> None: - """User story: start from checkpoint with a new target definition.""" - - module = build_training_module(model_config=BatDetect2Config().model) - trainer = L.Trainer(enable_checkpointing=False, logger=False) - checkpoint_path = tmp_path / "base_transfer.ckpt" - trainer.strategy.connect(module) - trainer.save_checkpoint(checkpoint_path) - - source_model, _ = load_model_from_checkpoint(checkpoint_path) - api = BatDetect2API.from_checkpoint( - checkpoint_path, - targets_config=sample_targets.config, - ) - source_detector = cast(Detector, source_model.detector) - detector = cast(Detector, api.model.detector) - classifier_head = cast(ClassifierHead, detector.classifier_head) - - assert api.targets.config == sample_targets.config # type: ignore - assert detector.num_classes == len(sample_targets.class_names) - assert ( - classifier_head.classifier.out_channels - == len(sample_targets.class_names) + 1 - ) - - source_backbone = source_detector.backbone.state_dict() - target_backbone = detector.backbone.state_dict() - assert source_backbone - for key, value in source_backbone.items(): - assert key in target_backbone - torch.testing.assert_close(target_backbone[key], value) - - def test_checkpoint_with_same_targets_config_keeps_heads_unchanged( + example_targets_config: TargetConfig, tmp_path: Path, ) -> None: """User story: same targets config does not rebuild prediction heads.""" - module = build_training_module(model_config=BatDetect2Config().model) + # Given + source_api = BatDetect2API.from_config( + targets_config=example_targets_config + ) + module = build_training_module( + model_config=source_api.model_config, + class_names=source_api.targets.class_names, + dimension_names=source_api.roi_mapper.dimension_names, + ) trainer = L.Trainer(enable_checkpointing=False, logger=False) checkpoint_path = tmp_path / "same_targets.ckpt" trainer.strategy.connect(module) trainer.save_checkpoint(checkpoint_path) - source_model, source_model_config = load_model_from_checkpoint( - checkpoint_path - ) + source_model, _ = load_model_from_checkpoint(checkpoint_path) source_detector = cast(Detector, source_model.detector) + # When api = BatDetect2API.from_checkpoint( checkpoint_path, - targets_config=source_model_config.targets, + targets_config=example_targets_config, ) + + # Then detector = cast(Detector, api.model.detector) for key, value in source_detector.classifier_head.state_dict().items(): @@ -302,7 +315,7 @@ def test_user_can_finetune_only_heads( ) -> None: """User story: fine-tune only prediction heads.""" - api = BatDetect2API.from_config(BatDetect2Config()) + api = BatDetect2API.from_config() finetune_dir = tmp_path / "heads_only" api.finetune( @@ -348,8 +361,6 @@ def test_user_can_evaluate_small_dataset_and_get_metrics( assert isinstance(metrics, list) assert len(metrics) == 1 - assert isinstance(metrics[0], dict) - assert len(metrics[0]) > 0 assert isinstance(predictions, list) assert len(predictions) == 1 @@ -450,8 +461,17 @@ def test_detection_threshold_override_changes_spectrogram_results( spec = api_v2.generate_spectrogram(audio) default_detections = api_v2.process_spectrogram(spec) strict_detections = api_v2.process_spectrogram( - spec, - detection_threshold=1.0, + spec, detection_threshold=1.0 ) assert len(strict_detections) <= len(default_detections) + + +def test_user_can_create_api_with_custom_targets_and_model_metadata_matches( + sample_targets, +) -> None: + """User story: custom targets define model output names for a new API.""" + + api = BatDetect2API.from_config(targets_config=sample_targets.config) + + assert api.model.class_names == sample_targets.class_names diff --git a/tests/test_api_v2/test_outputs_io.py b/tests/test_api_v2/test_outputs_io.py index 7cb5062..5a62fb8 100644 --- a/tests/test_api_v2/test_outputs_io.py +++ b/tests/test_api_v2/test_outputs_io.py @@ -5,7 +5,6 @@ import numpy as np import pytest from batdetect2.api_v2 import BatDetect2API -from batdetect2.config import BatDetect2Config from batdetect2.outputs import build_output_formatter from batdetect2.outputs.formats import ( BatDetect2OutputConfig, @@ -18,7 +17,7 @@ from batdetect2.postprocess.types import ClipDetections def api_v2() -> BatDetect2API: """User story: API object manages prediction IO formats.""" - return BatDetect2API.from_config(BatDetect2Config()) + return BatDetect2API.from_config() @pytest.fixture diff --git a/tests/test_targets/test_utils.py b/tests/test_targets/test_utils.py new file mode 100644 index 0000000..c4b5d1c --- /dev/null +++ b/tests/test_targets/test_utils.py @@ -0,0 +1,40 @@ +from soundevent import data + +from batdetect2.targets import ( + TargetClassConfig, + TargetConfig, + build_targets, + check_target_compatibility, +) + + +def _target_class(name: str) -> TargetClassConfig: + return TargetClassConfig( + name=name, + tags=[data.Tag(key="class", value=name)], + ) + + +def test_check_target_compatibility_accepts_superset_targets() -> None: + config = TargetConfig( + classification_targets=[ + _target_class("pip35"), + _target_class("myo"), + _target_class("extra"), + ] + ) + targets = build_targets(config) + + assert check_target_compatibility(targets, ["pip35", "myo"]) + + +def test_check_target_compatibility_rejects_missing_model_classes() -> None: + config = TargetConfig( + classification_targets=[ + _target_class("pip35"), + _target_class("myo"), + ] + ) + targets = build_targets(config) + + assert not check_target_compatibility(targets, ["pip35", "nyc"]) diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index 284fea4..566ae62 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -10,9 +10,8 @@ from torch.optim.lr_scheduler import CosineAnnealingLR from batdetect2.api_v2 import BatDetect2API from batdetect2.audio.types import AudioLoader -from batdetect2.config import BatDetect2Config from batdetect2.models import ModelConfig, build_model -from batdetect2.targets.classes import TargetClassConfig +from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets from batdetect2.train import ( TrainingConfig, TrainingModule, @@ -24,11 +23,21 @@ from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig from batdetect2.train.train import build_training_module -def build_default_module(config: BatDetect2Config | None = None): - config = config or BatDetect2Config() +def build_default_module( + target_config: TargetConfig | None = None, + model_config: ModelConfig | None = None, + train_config: TrainingConfig | None = None, +): + target_config = target_config or TargetConfig() + model_config = model_config or ModelConfig() + train_config = train_config or TrainingConfig() + targets = build_targets(target_config) + roi_mapper = build_roi_mapping(target_config.roi) return build_training_module( - model_config=config.model, - train_config=config.train, + model_config=model_config, + class_names=targets.class_names, + dimension_names=roi_mapper.dimension_names, + train_config=train_config, ) @@ -72,8 +81,13 @@ def test_load_model_from_checkpoint_returns_model_and_config( input_model_config.model_dump(mode="json") ) train_config = TrainingConfig() + targets_config = TargetConfig() + targets = build_targets(targets_config) + roi_mapper = build_roi_mapping(targets_config.roi) module = build_training_module( model_config=input_model_config, + class_names=targets.class_names, + dimension_names=roi_mapper.dimension_names, train_config=train_config, ) trainer = L.Trainer() @@ -87,6 +101,8 @@ def test_load_model_from_checkpoint_returns_model_and_config( assert loaded_model_config.model_dump( mode="json" ) == expected_model_config.model_dump(mode="json") + assert model.class_names == targets.class_names + assert model.dimension_names == roi_mapper.dimension_names recovered = TrainingModule.load_from_checkpoint(path) assert recovered.train_config.model_dump( @@ -100,6 +116,9 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path): model_config.model_dump(mode="json") ) train_config = TrainingConfig() + targets_config = TargetConfig() + targets = build_targets(targets_config) + roi_mapper = build_roi_mapping(targets_config.roi) train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4) train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123) train_config.trainer.max_epochs = 3 @@ -107,6 +126,8 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path): module = build_training_module( model_config=model_config, + class_names=targets.class_names, + dimension_names=roi_mapper.dimension_names, train_config=train_config, ) trainer = L.Trainer() @@ -131,11 +152,16 @@ def test_configure_optimizers_uses_train_config_values(tmp_path: Path): model_config.model_dump(mode="json") ) train_config = TrainingConfig() + targets_config = TargetConfig() + targets = build_targets(targets_config) + roi_mapper = build_roi_mapping(targets_config.roi) train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4) train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321) module = build_training_module( model_config=model_config, + class_names=targets.class_names, + dimension_names=roi_mapper.dimension_names, train_config=train_config, ) @@ -189,19 +215,26 @@ def test_train_smoke_produces_loadable_checkpoint( example_annotations: list[data.ClipAnnotation], sample_audio_loader: AudioLoader, ): - config = BatDetect2Config() - config.train.trainer.limit_train_batches = 1 - config.train.trainer.limit_val_batches = 1 - config.train.trainer.log_every_n_steps = 1 - config.train.train_loader.batch_size = 1 - config.train.train_loader.augmentations.enabled = False + # Given + train_config = TrainingConfig.model_validate( + { + "trainer": { + "limit_train_batches": 1, + "limit_val_batches": 1, + "log_every_n_steps": 1, + }, + "train_loader": { + "batch_size": 1, + "augmentations": {"enabled": False}, + }, + } + ) + # When run_train( train_annotations=example_annotations[:1], val_annotations=example_annotations[:1], - train_config=config.train, - model_config=config.model, - audio_config=config.audio, + train_config=train_config, num_epochs=1, train_workers=0, val_workers=0, @@ -209,18 +242,11 @@ def test_train_smoke_produces_loadable_checkpoint( seed=0, ) + # Then checkpoints = list(tmp_path.rglob("*.ckpt")) assert checkpoints model, model_config = load_model_from_checkpoint(checkpoints[0]) - assert model_config.samplerate == config.model.samplerate - assert model_config.architecture.name == config.model.architecture.name - assert model_config.preprocess.model_dump( - mode="json" - ) == config.model.preprocess.model_dump(mode="json") - assert model_config.postprocess.model_dump( - mode="json" - ) == config.model.postprocess.model_dump(mode="json") wav = torch.tensor( sample_audio_loader.load_clip(example_annotations[0].clip) @@ -230,10 +256,18 @@ def test_train_smoke_produces_loadable_checkpoint( def test_build_training_module_uses_provided_model() -> None: - model = build_model(ModelConfig()) + targets = build_targets(TargetConfig()) + roi_mapper = build_roi_mapping(TargetConfig().roi) + model = build_model( + ModelConfig(), + class_names=targets.class_names, + dimension_names=roi_mapper.dimension_names, + ) module = build_training_module( model_config=ModelConfig(), + class_names=targets.class_names, + dimension_names=roi_mapper.dimension_names, train_config=TrainingConfig(), model=model, ) @@ -244,15 +278,18 @@ def test_build_training_module_uses_provided_model() -> None: def test_run_train_rejects_incompatible_model_config( example_annotations: list[data.ClipAnnotation], ) -> None: - model = build_model(ModelConfig()) + # Given + targets_config = TargetConfig() + targets = build_targets(targets_config) + roi_mapper = build_roi_mapping(targets_config.roi) incompatible_config = ModelConfig() - incompatible_config.targets.classification_targets.append( - TargetClassConfig( - name="dummy_class", - tags=[data.Tag(key="class", value="Dummy class")], - ) + incompatible_model = build_model( + incompatible_config, + class_names=targets.class_names, + dimension_names=[*roi_mapper.dimension_names, "extra_dim"], ) + # When/Then with pytest.raises( ValueError, match="Provided model is incompatible with model_config", @@ -260,7 +297,10 @@ def test_run_train_rejects_incompatible_model_config( run_train( train_annotations=example_annotations[:1], val_annotations=None, - model=model, + model=incompatible_model, + targets=targets, + roi_mapper=roi_mapper, model_config=incompatible_config, + targets_config=targets_config, train_config=TrainingConfig(), )