refactor: decouple model metadata from target configs

This commit is contained in:
mbsantiago 2026-05-04 21:18:17 +01:00
parent e33053614a
commit 57236fc82a
17 changed files with 483 additions and 237 deletions

View File

@ -12,7 +12,6 @@ if TYPE_CHECKING:
import torch import torch
from batdetect2.audio import AudioConfig, AudioLoader from batdetect2.audio import AudioConfig, AudioLoader
from batdetect2.config import BatDetect2Config
from batdetect2.data import Dataset from batdetect2.data import Dataset
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
from batdetect2.inference import InferenceConfig from batdetect2.inference import InferenceConfig
@ -483,46 +482,70 @@ class BatDetect2API:
@classmethod @classmethod
def from_config( def from_config(
cls, cls,
config: BatDetect2Config, model_config: ModelConfig | None = None,
targets_config: TargetConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
inference_config: InferenceConfig | None = None,
outputs_config: OutputsConfig | None = None,
logging_config: AppLoggingConfig | None = None,
) -> "BatDetect2API": ) -> "BatDetect2API":
from batdetect2.audio import build_audio_loader from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.evaluate import build_evaluator from batdetect2.evaluate import EvaluationConfig, build_evaluator
from batdetect2.models import build_model from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import ModelConfig, build_model
from batdetect2.outputs import ( from batdetect2.outputs import (
OutputsConfig,
build_output_formatter, build_output_formatter,
build_output_transform, build_output_transform,
) )
from batdetect2.postprocess import build_postprocessor from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_roi_mapping, build_targets from batdetect2.targets import (
TargetConfig,
build_roi_mapping,
build_targets,
)
from batdetect2.train import TrainingConfig
targets = build_targets(config=config.model.targets) model_config = model_config or ModelConfig()
roi_mapper = build_roi_mapping(config=config.model.targets.roi) targets_config = targets_config or TargetConfig()
audio_config = audio_config or AudioConfig()
train_config = train_config or TrainingConfig()
evaluation_config = evaluation_config or EvaluationConfig()
inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig()
logging_config = logging_config or AppLoggingConfig()
audio_loader = build_audio_loader(config=config.audio) targets = build_targets(config=targets_config)
roi_mapper = build_roi_mapping(config=targets_config.roi)
audio_loader = build_audio_loader(config=audio_config)
preprocessor = build_preprocessor( preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate, input_samplerate=audio_loader.samplerate,
config=config.model.preprocess, config=model_config.preprocess,
) )
postprocessor = build_postprocessor( postprocessor = build_postprocessor(
preprocessor, preprocessor,
config=config.model.postprocess, config=model_config.postprocess,
) )
formatter = build_output_formatter( formatter = build_output_formatter(
targets, targets,
config=config.outputs.format, config=outputs_config.format,
) )
output_transform = build_output_transform( output_transform = build_output_transform(
config=config.outputs.transform, config=outputs_config.transform,
targets=targets, targets=targets,
roi_mapper=roi_mapper, roi_mapper=roi_mapper,
) )
evaluator = build_evaluator( evaluator = build_evaluator(
config=config.evaluation, config=evaluation_config,
targets=targets, targets=targets,
roi_mapper=roi_mapper, roi_mapper=roi_mapper,
transform=output_transform, transform=output_transform,
@ -531,27 +554,27 @@ class BatDetect2API:
# NOTE: Build separate instances of preprocessor and postprocessor # NOTE: Build separate instances of preprocessor and postprocessor
# to avoid device mismatch errors # to avoid device mismatch errors
model = build_model( model = build_model(
config=config.model, config=model_config,
targets=targets, class_names=targets.class_names,
roi_mapper=roi_mapper, dimension_names=roi_mapper.dimension_names,
preprocessor=build_preprocessor( preprocessor=build_preprocessor(
input_samplerate=audio_loader.samplerate, input_samplerate=audio_loader.samplerate,
config=config.model.preprocess, config=model_config.preprocess,
), ),
postprocessor=build_postprocessor( postprocessor=build_postprocessor(
preprocessor, preprocessor,
config=config.model.postprocess, config=model_config.postprocess,
), ),
) )
return cls( return cls(
model_config=config.model, model_config=model_config,
audio_config=config.audio, audio_config=audio_config,
train_config=config.train, train_config=train_config,
evaluation_config=config.evaluation, evaluation_config=evaluation_config,
inference_config=config.inference, inference_config=inference_config,
outputs_config=config.outputs, outputs_config=outputs_config,
logging_config=config.logging, logging_config=logging_config,
targets=targets, targets=targets,
roi_mapper=roi_mapper, roi_mapper=roi_mapper,
audio_loader=audio_loader, audio_loader=audio_loader,
@ -579,7 +602,6 @@ class BatDetect2API:
from batdetect2.evaluate import EvaluationConfig, build_evaluator from batdetect2.evaluate import EvaluationConfig, build_evaluator
from batdetect2.inference import InferenceConfig from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig from batdetect2.logging import AppLoggingConfig
from batdetect2.models import build_model_with_new_targets
from batdetect2.outputs import ( from batdetect2.outputs import (
OutputsConfig, OutputsConfig,
build_output_formatter, build_output_formatter,
@ -587,8 +609,16 @@ class BatDetect2API:
) )
from batdetect2.postprocess import build_postprocessor from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_roi_mapping, build_targets from batdetect2.targets import (
from batdetect2.train import TrainingConfig, load_model_from_checkpoint build_default_target_config,
build_roi_mapping,
build_targets,
check_target_compatibility,
)
from batdetect2.train import (
TrainingConfig,
load_model_from_checkpoint,
)
model, model_config = load_model_from_checkpoint(path) model, model_config = load_model_from_checkpoint(path)
@ -600,24 +630,24 @@ class BatDetect2API:
inference_config = inference_config or InferenceConfig() inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig() outputs_config = outputs_config or OutputsConfig()
logging_config = logging_config or AppLoggingConfig() logging_config = logging_config or AppLoggingConfig()
targets_config = targets_config or build_default_target_config(
class_names=model.class_names
)
if ( targets = build_targets(config=targets_config)
targets_config is not None roi_mapper = build_roi_mapping(config=targets_config.roi)
and targets_config != model_config.targets
): if not check_target_compatibility(targets, model.class_names):
targets = build_targets(config=targets_config) raise ValueError(
roi_mapper = build_roi_mapping(config=targets_config.roi) "Provided targets_config is incompatible with the "
model = build_model_with_new_targets( "checkpoint model: missing one or more model classes."
model=model,
targets=targets,
roi_mapper=roi_mapper,
)
model_config = model_config.model_copy(
update={"targets": targets_config}
) )
targets = build_targets(config=model_config.targets) if model.dimension_names != roi_mapper.dimension_names:
roi_mapper = build_roi_mapping(config=model_config.targets.roi) raise ValueError(
"Provided targets_config is incompatible with the "
"checkpoint model: mismatched dimension names."
)
audio_loader = build_audio_loader(config=audio_config) audio_loader = build_audio_loader(config=audio_config)

View File

@ -12,6 +12,7 @@ from batdetect2.inference.config import InferenceConfig
from batdetect2.logging import AppLoggingConfig from batdetect2.logging import AppLoggingConfig
from batdetect2.models import ModelConfig from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig from batdetect2.outputs import OutputsConfig
from batdetect2.targets import TargetConfig
from batdetect2.train.config import TrainingConfig from batdetect2.train.config import TrainingConfig
__all__ = ["BatDetect2Config"] __all__ = ["BatDetect2Config"]
@ -25,6 +26,7 @@ class BatDetect2Config(BaseConfig):
default_factory=get_default_eval_config default_factory=get_default_eval_config
) )
model: ModelConfig = Field(default_factory=ModelConfig) model: ModelConfig = Field(default_factory=ModelConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
audio: AudioConfig = Field(default_factory=AudioConfig) audio: AudioConfig = Field(default_factory=AudioConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig) inference: InferenceConfig = Field(default_factory=InferenceConfig)
outputs: OutputsConfig = Field(default_factory=OutputsConfig) outputs: OutputsConfig = Field(default_factory=OutputsConfig)

View File

@ -24,8 +24,8 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
def run_evaluate( def run_evaluate(
model: Model, model: Model,
test_annotations: Sequence[data.ClipAnnotation], test_annotations: Sequence[data.ClipAnnotation],
targets: TargetProtocol | None = None, targets: TargetProtocol,
roi_mapper: ROIMapperProtocol | None = None, roi_mapper: ROIMapperProtocol,
audio_loader: AudioLoader | None = None, audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
@ -46,8 +46,6 @@ def run_evaluate(
audio_loader = audio_loader or build_audio_loader(config=audio_config) audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or model.preprocessor preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
roi_mapper = roi_mapper or model.roi_mapper
loader = build_test_loader( loader = build_test_loader(
test_annotations, test_annotations,

View File

@ -45,8 +45,16 @@ def run_batch_inference(
audio_loader = audio_loader or build_audio_loader(config=audio_config) audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or model.preprocessor preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
roi_mapper = roi_mapper or model.roi_mapper if targets is None:
raise ValueError(
"targets must be provided when running batch inference."
)
if roi_mapper is None:
raise ValueError(
"roi_mapper must be provided when running batch inference."
)
output_transform = output_transform or build_output_transform( output_transform = output_transform or build_output_transform(
config=output_config.transform, config=output_config.transform,

View File

@ -7,21 +7,37 @@ from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
class InferenceModule(LightningModule): class InferenceModule(LightningModule):
def __init__( def __init__(
self, self,
model: Model, model: Model,
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
output_transform: OutputTransformProtocol | None = None, output_transform: OutputTransformProtocol | None = None,
detection_threshold: float | None = None, detection_threshold: float | None = None,
): ):
super().__init__() super().__init__()
self.model = model self.model = model
self.detection_threshold = detection_threshold self.detection_threshold = detection_threshold
if output_transform is None and targets is None:
raise ValueError(
"targets must be provided when building inference output "
"transforms."
)
if output_transform is None and roi_mapper is None:
raise ValueError(
"roi_mapper must be provided when building inference output "
"transforms."
)
self.output_transform = output_transform or build_output_transform( self.output_transform = output_transform or build_output_transform(
targets=model.targets, targets=targets,
roi_mapper=model.roi_mapper, roi_mapper=roi_mapper,
) )
def predict_step( def predict_step(

View File

@ -26,11 +26,8 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
is the ``build_model`` factory function exported from this module. is the ``build_model`` factory function exported from this module.
""" """
from typing import Literal
import torch import torch
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
@ -73,7 +70,6 @@ from batdetect2.postprocess.types import (
) )
from batdetect2.preprocess.config import PreprocessingConfig from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.config import TargetConfig
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
__all__ = [ __all__ = [
@ -131,10 +127,6 @@ class ModelConfig(BaseConfig):
Parameters for converting raw model outputs into detections (NMS Parameters for converting raw model outputs into detections (NMS
kernel, thresholds, top-k limit). Defaults to kernel, thresholds, top-k limit). Defaults to
``PostprocessConfig()``. ``PostprocessConfig()``.
targets : TargetConfig
Detection and classification target definitions (class list,
detection target, bounding-box mapper). Defaults to
``TargetConfig()``.
""" """
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0) samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
@ -143,23 +135,6 @@ class ModelConfig(BaseConfig):
default_factory=PreprocessingConfig default_factory=PreprocessingConfig
) )
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
@classmethod
def load(
cls,
path: PathLike,
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
targets: TargetConfig | None = None,
) -> "ModelConfig":
config = super().load(path, field, extra, strict)
if targets is None:
return config
return config.model_copy(update={"targets": targets})
class Model(torch.nn.Module): class Model(torch.nn.Module):
@ -183,33 +158,32 @@ class Model(torch.nn.Module):
postprocessor : PostprocessorProtocol postprocessor : PostprocessorProtocol
Converts the raw ``ModelOutput`` from ``detector`` into a list of Converts the raw ``ModelOutput`` from ``detector`` into a list of
per-clip detection tensors. per-clip detection tensors.
targets : TargetProtocol class_names : list[str]
Describes the set of target classes; used when building heads and Class names corresponding to the model classification outputs.
during training target construction. dimension_names : list[str]
roi_mapper : ROIMapperProtocol Size-dimension names corresponding to the model size outputs.
Maps geometries to target-size channels and back.
""" """
detector: DetectionModel detector: DetectionModel
preprocessor: PreprocessorProtocol preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol postprocessor: PostprocessorProtocol
targets: TargetProtocol class_names: list[str]
roi_mapper: ROIMapperProtocol dimension_names: list[str]
def __init__( def __init__(
self, self,
detector: DetectionModel, detector: DetectionModel,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol, postprocessor: PostprocessorProtocol,
targets: TargetProtocol, class_names: list[str],
roi_mapper: ROIMapperProtocol, dimension_names: list[str],
): ):
super().__init__() super().__init__()
self.detector = detector self.detector = detector
self.preprocessor = preprocessor self.preprocessor = preprocessor
self.postprocessor = postprocessor self.postprocessor = postprocessor
self.targets = targets self.class_names = class_names
self.roi_mapper = roi_mapper self.dimension_names = dimension_names
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]: def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
"""Run the full detection pipeline on a waveform tensor. """Run the full detection pipeline on a waveform tensor.
@ -238,8 +212,8 @@ class Model(torch.nn.Module):
def build_model( def build_model(
config: ModelConfig | None = None, config: ModelConfig | None = None,
targets: TargetProtocol | None = None, class_names: list[str] | None = None,
roi_mapper: ROIMapperProtocol | None = None, dimension_names: list[str] | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None, postprocessor: PostprocessorProtocol | None = None,
) -> Model: ) -> Model:
@ -254,11 +228,13 @@ def build_model(
---------- ----------
config : ModelConfig, optional config : ModelConfig, optional
Full model configuration (samplerate, architecture, preprocessing, Full model configuration (samplerate, architecture, preprocessing,
postprocessing, targets). Defaults to ``ModelConfig()`` if not postprocessing). Defaults to ``ModelConfig()`` if not provided.
provided. class_names : list[str], optional
targets : TargetProtocol, optional Class names used to size the classifier head. Required when building
Pre-built targets object. If given, overrides a new model.
``config.targets``. dimension_names : list[str], optional
Dimension names used to size the bbox head. Required when building a
new model.
preprocessor : PreprocessorProtocol, optional preprocessor : PreprocessorProtocol, optional
Pre-built preprocessor. If given, overrides Pre-built preprocessor. If given, overrides
``config.preprocess`` and ``config.samplerate`` for the ``config.preprocess`` and ``config.samplerate`` for the
@ -278,19 +254,17 @@ def build_model(
""" """
from batdetect2.postprocess import build_postprocessor from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_roi_mapping, build_targets
config = config or ModelConfig() config = config or ModelConfig()
targets = targets or build_targets(config=config.targets)
targets_config = getattr(targets, "config", None) if class_names is None:
roi_config = ( raise ValueError("class_names must be provided when building a model.")
targets_config.roi
if isinstance(targets_config, TargetConfig) if dimension_names is None:
else config.targets.roi raise ValueError(
) "dimension_names must be provided when building a model."
)
roi_mapper = roi_mapper or build_roi_mapping(config=roi_config)
preprocessor = preprocessor or build_preprocessor( preprocessor = preprocessor or build_preprocessor(
config=config.preprocess, config=config.preprocess,
input_samplerate=config.samplerate, input_samplerate=config.samplerate,
@ -300,16 +274,16 @@ def build_model(
config=config.postprocess, config=config.postprocess,
) )
detector = build_detector( detector = build_detector(
num_classes=len(targets.class_names), num_classes=len(class_names),
num_sizes=len(roi_mapper.dimension_names), num_sizes=len(dimension_names),
config=config.architecture, config=config.architecture,
) )
return Model( return Model(
detector=detector, detector=detector,
postprocessor=postprocessor, postprocessor=postprocessor,
preprocessor=preprocessor, preprocessor=preprocessor,
targets=targets, class_names=class_names,
roi_mapper=roi_mapper, dimension_names=dimension_names,
) )
@ -329,6 +303,6 @@ def build_model_with_new_targets(
detector=detector, detector=detector,
postprocessor=model.postprocessor, postprocessor=model.postprocessor,
preprocessor=model.preprocessor, preprocessor=model.preprocessor,
targets=targets, class_names=targets.class_names,
roi_mapper=roi_mapper, dimension_names=roi_mapper.dimension_names,
) )

View File

@ -6,7 +6,7 @@ from batdetect2.targets.classes import (
build_sound_event_encoder, build_sound_event_encoder,
get_class_names_from_config, get_class_names_from_config,
) )
from batdetect2.targets.config import TargetConfig from batdetect2.targets.config import TargetConfig, build_default_target_config
from batdetect2.targets.rois import ( from batdetect2.targets.rois import (
AnchorBBoxMapperConfig, AnchorBBoxMapperConfig,
ROIMapperConfig, ROIMapperConfig,
@ -36,13 +36,14 @@ from batdetect2.targets.types import (
SoundEventFilter, SoundEventFilter,
TargetProtocol, TargetProtocol,
) )
from batdetect2.targets.utils import check_target_compatibility
__all__ = [ __all__ = [
"AnchorBBoxMapperConfig", "AnchorBBoxMapperConfig",
"Position", "Position",
"ROIMappingConfig",
"ROIMapperProtocol",
"ROIMapperConfig", "ROIMapperConfig",
"ROIMapperProtocol",
"ROIMappingConfig",
"ROITargetMapper", "ROITargetMapper",
"Size", "Size",
"SoundEventDecoder", "SoundEventDecoder",
@ -52,12 +53,14 @@ __all__ = [
"TargetConfig", "TargetConfig",
"TargetProtocol", "TargetProtocol",
"Targets", "Targets",
"build_roi_mapping", "build_default_target_config",
"build_roi_mapper", "build_roi_mapper",
"build_roi_mapping",
"build_sound_event_decoder", "build_sound_event_decoder",
"build_sound_event_encoder", "build_sound_event_encoder",
"build_targets", "build_targets",
"call_type", "call_type",
"check_target_compatibility",
"data_source", "data_source",
"generic_class", "generic_class",
"get_class_names_from_config", "get_class_names_from_config",

View File

@ -2,6 +2,7 @@ from collections import Counter
from typing import List from typing import List
from pydantic import Field, field_validator from pydantic import Field, field_validator
from soundevent import data
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.targets.classes import ( from batdetect2.targets.classes import (
@ -13,6 +14,7 @@ from batdetect2.targets.rois import ROIMappingConfig
__all__ = [ __all__ = [
"TargetConfig", "TargetConfig",
"build_default_target_config",
] ]
@ -42,3 +44,20 @@ class TargetConfig(BaseConfig):
f"{', '.join(duplicates)}" f"{', '.join(duplicates)}"
) )
return v return v
def build_default_target_config(class_names: list[str]) -> TargetConfig:
    """Create a ``TargetConfig`` with one classification target per name.

    Parameters
    ----------
    class_names : list[str]
        Names of the classes to declare. Each name becomes a
        ``TargetClassConfig`` carrying a single tag with key ``"class"``
        and the class name as its value.

    Returns
    -------
    TargetConfig
        Configuration that uses ``DEFAULT_DETECTION_CLASS`` as the
        detection target, the generated classification targets (in the
        same order as ``class_names``), and a default-constructed
        ``ROIMappingConfig``.
    """
    classification_targets = []
    for name in class_names:
        # One tag per class: the class is identified by its own name.
        class_tag = data.Tag(key="class", value=name)
        classification_targets.append(
            TargetClassConfig(name=name, tags=[class_tag])
        )

    return TargetConfig(
        detection_target=DEFAULT_DETECTION_CLASS,
        classification_targets=classification_targets,
        roi=ROIMappingConfig(),
    )

View File

@ -0,0 +1,29 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from batdetect2.targets.types import TargetProtocol
def check_target_compatibility(
    targets: "TargetProtocol",
    class_names: list[str],
) -> bool:
    """Report whether ``targets`` covers every class a model outputs.

    Only membership is compared: ordering and duplicates in either list
    are ignored, and ``targets`` may define additional classes.

    Parameters
    ----------
    targets : TargetProtocol
        Target definition that would be used with the model outputs. Only
        its ``class_names`` attribute is read.
    class_names : list[str]
        Class names produced by the model checkpoint.

    Returns
    -------
    bool
        ``True`` if each entry of ``class_names`` is also present in
        ``targets.class_names`` (trivially ``True`` for an empty
        ``class_names``), ``False`` as soon as one is missing.
    """
    known_names = frozenset(targets.class_names)
    return all(name in known_names for name in class_names)

View File

@ -10,6 +10,7 @@ from batdetect2.logging import get_image_logger
from batdetect2.models.types import ModelOutput from batdetect2.models.types import ModelOutput
from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
from batdetect2.train.dataset import ValidationDataset from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
from batdetect2.train.types import TrainExample from batdetect2.train.types import TrainExample
@ -19,11 +20,15 @@ class ValidationMetrics(Callback):
def __init__( def __init__(
self, self,
evaluator: EvaluatorProtocol, evaluator: EvaluatorProtocol,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
output_transform: OutputTransformProtocol | None = None, output_transform: OutputTransformProtocol | None = None,
): ):
super().__init__() super().__init__()
self.evaluator = evaluator self.evaluator = evaluator
self.targets = targets
self.roi_mapper = roi_mapper
self.output_transform = output_transform self.output_transform = output_transform
self._clip_annotations: List[data.ClipAnnotation] = [] self._clip_annotations: List[data.ClipAnnotation] = []
@ -93,8 +98,8 @@ class ValidationMetrics(Callback):
model = pl_module.model model = pl_module.model
if self.output_transform is None: if self.output_transform is None:
self.output_transform = build_output_transform( self.output_transform = build_output_transform(
targets=model.targets, targets=self.targets,
roi_mapper=model.roi_mapper, roi_mapper=self.roi_mapper,
) )
output_transform = self.output_transform output_transform = self.output_transform

View File

@ -9,9 +9,7 @@ from batdetect2.train.optimizers import build_optimizer
from batdetect2.train.schedulers import build_scheduler from batdetect2.train.schedulers import build_scheduler
from batdetect2.train.types import LossProtocol, TrainExample from batdetect2.train.types import LossProtocol, TrainExample
__all__ = [ __all__ = ["TrainingModule"]
"TrainingModule",
]
class TrainingModule(L.LightningModule): class TrainingModule(L.LightningModule):
@ -21,6 +19,8 @@ class TrainingModule(L.LightningModule):
def __init__( def __init__(
self, self,
model_config: dict | None = None, model_config: dict | None = None,
class_names: list[str] | None = None,
dimension_names: list[str] | None = None,
train_config: dict | None = None, train_config: dict | None = None,
loss: LossProtocol | None = None, loss: LossProtocol | None = None,
model: Model | None = None, model: Model | None = None,
@ -30,13 +30,31 @@ class TrainingModule(L.LightningModule):
self.save_hyperparameters(ignore=["model", "loss"], logger=False) self.save_hyperparameters(ignore=["model", "loss"], logger=False)
self.model_config = ModelConfig.model_validate(model_config or {}) self.model_config = ModelConfig.model_validate(model_config or {})
self.class_names = list(class_names or [])
self.dimension_names = list(dimension_names or [])
self.train_config = TrainingConfig.model_validate(train_config or {}) self.train_config = TrainingConfig.model_validate(train_config or {})
if loss is None: if loss is None:
loss = build_loss(config=self.train_config.loss) loss = build_loss(config=self.train_config.loss)
if model is None: if model is None:
model = build_model(config=self.model_config) if not self.class_names:
raise ValueError(
"class_names must be provided when rebuilding a training "
"module without a model."
)
if not self.dimension_names:
raise ValueError(
"dimension_names must be provided when rebuilding a "
"training module without a model."
)
model = build_model(
config=self.model_config,
class_names=self.class_names,
dimension_names=self.dimension_names,
)
self.loss = loss self.loss = loss
self.model = model self.model = model
@ -110,8 +128,7 @@ def load_model_from_checkpoint(
------- -------
tuple[Model, ModelConfig] tuple[Model, ModelConfig]
The restored ``Model`` instance and the ``ModelConfig`` that The restored ``Model`` instance and the ``ModelConfig`` that
describes its architecture, preprocessing, postprocessing, and describes its architecture, preprocessing, and postprocessing.
targets.
""" """
module = TrainingModule.load_from_checkpoint(path) # type: ignore module = TrainingModule.load_from_checkpoint(path) # type: ignore
return module.model, module.model_config return module.model, module.model_config
@ -119,6 +136,8 @@ def load_model_from_checkpoint(
def build_training_module( def build_training_module(
model_config: ModelConfig | None = None, model_config: ModelConfig | None = None,
class_names: list[str] | None = None,
dimension_names: list[str] | None = None,
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
model: Model | None = None, model: Model | None = None,
) -> TrainingModule: ) -> TrainingModule:
@ -130,6 +149,8 @@ def build_training_module(
return TrainingModule( return TrainingModule(
model_config=model_config.model_dump(mode="json"), model_config=model_config.model_dump(mode="json"),
class_names=class_names,
dimension_names=dimension_names,
train_config=train_config.model_dump(mode="json"), train_config=train_config.model_dump(mode="json"),
model=model, model=model,
) )

View File

@ -17,6 +17,7 @@ from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import ( from batdetect2.targets import (
ROIMapperProtocol, ROIMapperProtocol,
TargetConfig,
TargetProtocol, TargetProtocol,
build_roi_mapping, build_roi_mapping,
build_targets, build_targets,
@ -46,6 +47,7 @@ def run_train(
labeller: Optional["ClipLabeller"] = None, labeller: Optional["ClipLabeller"] = None,
audio_config: Optional[AudioConfig] = None, audio_config: Optional[AudioConfig] = None,
model_config: Optional[ModelConfig] = None, model_config: Optional[ModelConfig] = None,
targets_config: TargetConfig | None = None,
train_config: Optional[TrainingConfig] = None, train_config: Optional[TrainingConfig] = None,
logger_config: LoggerConfig | None = None, logger_config: LoggerConfig | None = None,
trainer: Trainer | None = None, trainer: Trainer | None = None,
@ -62,23 +64,34 @@ def run_train(
seed_everything(seed) seed_everything(seed)
model_config = model_config or ModelConfig() model_config = model_config or ModelConfig()
targets_config = targets_config or TargetConfig()
audio_config = audio_config or AudioConfig() audio_config = audio_config or AudioConfig()
train_config = train_config or TrainingConfig() train_config = train_config or TrainingConfig()
if model is not None: if model is not None:
_validate_model_compatibility(model=model, model_config=model_config) if targets is None:
raise ValueError(
"targets must be provided when training with an existing "
"model."
)
if roi_mapper is None:
raise ValueError(
"roi_mapper must be provided when training with an existing "
"model."
)
targets = targets or build_targets(config=targets_config)
roi_mapper = roi_mapper or build_roi_mapping(config=targets_config.roi)
if model is not None: if model is not None:
targets = targets or model.targets _validate_model_compatibility(
model=model,
if roi_mapper is None and targets is model.targets: model_config=model_config,
roi_mapper = model.roi_mapper class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
targets = targets or build_targets(config=model_config.targets) )
roi_mapper = roi_mapper or build_roi_mapping(
config=model_config.targets.roi
)
audio_loader = audio_loader or build_audio_loader(config=audio_config) audio_loader = audio_loader or build_audio_loader(config=audio_config)
@ -119,18 +132,24 @@ def run_train(
module = build_training_module( module = build_training_module(
model_config=model_config, model_config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config, train_config=train_config,
model=model, model=model,
) )
evaluator = build_evaluator(
train_config.validation,
targets=targets,
roi_mapper=roi_mapper,
)
trainer = trainer or build_trainer( trainer = trainer or build_trainer(
train_config, train_config,
logger_config=logger_config, logger_config=logger_config,
evaluator=build_evaluator( evaluator=evaluator,
train_config.validation, targets=targets,
targets=targets, roi_mapper=roi_mapper,
roi_mapper=roi_mapper,
),
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
num_epochs=num_epochs, num_epochs=num_epochs,
log_dir=log_dir, log_dir=log_dir,
@ -152,8 +171,14 @@ def run_train(
def _validate_model_compatibility( def _validate_model_compatibility(
model: Model, model: Model,
model_config: ModelConfig, model_config: ModelConfig,
class_names: list[str],
dimension_names: list[str],
) -> None: ) -> None:
reference_model = build_model(config=model_config) reference_model = build_model(
config=model_config,
class_names=class_names,
dimension_names=dimension_names,
)
expected_shapes = { expected_shapes = {
key: tuple(value.shape) key: tuple(value.shape)
@ -196,6 +221,8 @@ def build_trainer(
config: TrainingConfig, config: TrainingConfig,
logger_config: LoggerConfig | None, logger_config: LoggerConfig | None,
evaluator: "EvaluatorProtocol", evaluator: "EvaluatorProtocol",
targets: "TargetProtocol",
roi_mapper: "ROIMapperProtocol",
checkpoint_dir: Path | None = None, checkpoint_dir: Path | None = None,
log_dir: Path | None = None, log_dir: Path | None = None,
experiment_name: str | None = None, experiment_name: str | None = None,
@ -234,6 +261,6 @@ def build_trainer(
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
), ),
ValidationMetrics(evaluator), ValidationMetrics(evaluator, targets, roi_mapper),
], ],
) )

View File

@ -13,13 +13,14 @@ from soundevent import data, terms
from batdetect2.audio import build_audio_loader from batdetect2.audio import build_audio_loader
from batdetect2.audio.clips import build_clipper from batdetect2.audio.clips import build_clipper
from batdetect2.audio.types import AudioLoader, ClipperProtocol from batdetect2.audio.types import AudioLoader, ClipperProtocol
from batdetect2.config import BatDetect2Config
from batdetect2.data import DatasetConfig, load_dataset from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import ( from batdetect2.targets import (
ROIMapperProtocol,
TargetConfig, TargetConfig,
build_roi_mapping,
build_targets, build_targets,
call_type, call_type,
) )
@ -404,6 +405,13 @@ def sample_targets(
return build_targets(sample_target_config) return build_targets(sample_target_config)
@pytest.fixture
def sample_roi_mapper(
    sample_target_config: TargetConfig,
) -> ROIMapperProtocol:
    """Build an ROI mapper from the ``roi`` section of the sample target config."""
    roi_config = sample_target_config.roi
    return build_roi_mapping(roi_config)
@pytest.fixture @pytest.fixture
def sample_labeller( def sample_labeller(
sample_targets: TargetProtocol, sample_targets: TargetProtocol,
@ -458,8 +466,15 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
@pytest.fixture @pytest.fixture
def tiny_checkpoint_path(tmp_path: Path) -> Path: def tiny_checkpoint_path(
module = build_training_module(model_config=BatDetect2Config().model) sample_targets: TargetProtocol,
sample_roi_mapper: ROIMapperProtocol,
tmp_path: Path,
) -> Path:
module = build_training_module(
class_names=sample_targets.class_names,
dimension_names=sample_roi_mapper.dimension_names,
)
trainer = L.Trainer(enable_checkpointing=False, logger=False) trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "model.ckpt" checkpoint_path = tmp_path / "model.ckpt"
trainer.strategy.connect(module) trainer.strategy.connect(module)

View File

@ -8,20 +8,43 @@ import torch
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config from batdetect2.inference import InferenceConfig
from batdetect2.models.detectors import Detector from batdetect2.models.detectors import Detector
from batdetect2.models.heads import ClassifierHead from batdetect2.targets import TargetConfig
from batdetect2.train import load_model_from_checkpoint from batdetect2.train import TrainingConfig, load_model_from_checkpoint
from batdetect2.train.lightning import build_training_module from batdetect2.train.lightning import build_training_module
@pytest.fixture @pytest.fixture
def api_v2() -> BatDetect2API: def train_config() -> TrainingConfig:
"""Train config with a small batch size for testing."""
return TrainingConfig.model_validate({"train_loader": {"batch_size": 2}})
@pytest.fixture
def inference_config() -> InferenceConfig:
    """Return an inference config whose loader batch size is set to 2."""
    # Validate from a plain payload so only the overridden field is spelled out.
    payload = {"loader": {"batch_size": 2}}
    return InferenceConfig.model_validate(payload)
@pytest.fixture
def example_targets_config(example_data_dir: Path) -> TargetConfig:
    """Load the target config stored as ``targets.yaml`` in the example data directory."""
    targets_path = example_data_dir / "targets.yaml"
    return TargetConfig.load(targets_path)
@pytest.fixture
def api_v2(
train_config: TrainingConfig,
inference_config: InferenceConfig,
) -> BatDetect2API:
"""User story: users can create a ready-to-use API from config.""" """User story: users can create a ready-to-use API from config."""
config = BatDetect2Config() api = BatDetect2API.from_config(
config.inference.loader.batch_size = 2 train_config=train_config,
return BatDetect2API.from_config(config) inference_config=inference_config,
)
assert api.inference_config.loader.batch_size == 2
return api
def test_process_file_returns_recording_level_predictions( def test_process_file_returns_recording_level_predictions(
@ -30,8 +53,10 @@ def test_process_file_returns_recording_level_predictions(
) -> None: ) -> None:
"""User story: process a file and get detections in recording time.""" """User story: process a file and get detections in recording time."""
# When
prediction = api_v2.process_file(example_audio_files[0]) prediction = api_v2.process_file(example_audio_files[0])
# Then
assert prediction.clip.recording.path == example_audio_files[0] assert prediction.clip.recording.path == example_audio_files[0]
assert prediction.clip.start_time == 0 assert prediction.clip.start_time == 0
assert prediction.clip.end_time == prediction.clip.recording.duration assert prediction.clip.end_time == prediction.clip.recording.duration
@ -53,9 +78,11 @@ def test_process_files_is_batch_size_invariant(
) -> None: ) -> None:
"""User story: changing batch size should not change predictions.""" """User story: changing batch size should not change predictions."""
# When
preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1) preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1)
preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3) preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3)
# Then
assert len(preds_batch_1) == len(preds_batch_3) assert len(preds_batch_1) == len(preds_batch_3)
by_key_1 = { by_key_1 = {
@ -91,12 +118,14 @@ def test_process_audio_matches_process_spectrogram(
) -> None: ) -> None:
"""User story: users can call either audio or spectrogram entrypoint.""" """User story: users can call either audio or spectrogram entrypoint."""
# When
audio = api_v2.load_audio(example_audio_files[0]) audio = api_v2.load_audio(example_audio_files[0])
from_audio = api_v2.process_audio(audio) from_audio = api_v2.process_audio(audio)
spec = api_v2.generate_spectrogram(audio) spec = api_v2.generate_spectrogram(audio)
from_spec = api_v2.process_spectrogram(spec) from_spec = api_v2.process_spectrogram(spec)
# Then
assert len(from_audio) == len(from_spec) assert len(from_audio) == len(from_spec)
for det_audio, det_spec in zip(from_audio, from_spec, strict=True): for det_audio, det_spec in zip(from_audio, from_spec, strict=True):
@ -116,8 +145,10 @@ def test_process_spectrogram_rejects_batched_input(
) -> None: ) -> None:
"""User story: invalid batched input gives a clear error.""" """User story: invalid batched input gives a clear error."""
# Given
spec = torch.zeros((2, 1, 128, 64), dtype=torch.float32) spec = torch.zeros((2, 1, 128, 64), dtype=torch.float32)
# When/Then
with pytest.raises(ValueError, match="Batched spectrograms not supported"): with pytest.raises(ValueError, match="Batched spectrograms not supported"):
api_v2.process_spectrogram(spec) api_v2.process_spectrogram(spec)
@ -184,26 +215,34 @@ def test_user_can_read_extracted_features_per_detection(
@pytest.mark.slow @pytest.mark.slow
def test_user_can_load_checkpoint_and_finetune( def test_user_can_load_checkpoint_and_finetune(
tmp_path: Path, tmp_path: Path,
example_targets_config: TargetConfig,
example_annotations, example_annotations,
) -> None: ) -> None:
"""User story: load a checkpoint and continue training from it.""" """User story: load a checkpoint and continue training from it."""
module = build_training_module(model_config=BatDetect2Config().model) api = BatDetect2API.from_config(
targets_config=example_targets_config,
)
module = build_training_module(
model_config=api.model_config,
class_names=api.targets.class_names,
dimension_names=api.roi_mapper.dimension_names,
)
trainer = L.Trainer(enable_checkpointing=False, logger=False) trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "base.ckpt" checkpoint_path = tmp_path / "base.ckpt"
trainer.strategy.connect(module) trainer.strategy.connect(module)
trainer.save_checkpoint(checkpoint_path) trainer.save_checkpoint(checkpoint_path)
config = BatDetect2Config() train_config = api.train_config.model_copy(deep=True)
config.train.trainer.limit_train_batches = 1 train_config.trainer.limit_train_batches = 1
config.train.trainer.limit_val_batches = 1 train_config.trainer.limit_val_batches = 1
config.train.trainer.log_every_n_steps = 1 train_config.trainer.log_every_n_steps = 1
config.train.train_loader.batch_size = 1 train_config.train_loader.batch_size = 1
config.train.train_loader.augmentations.enabled = False train_config.train_loader.augmentations.enabled = False
api = BatDetect2API.from_checkpoint( api = BatDetect2API.from_checkpoint(
checkpoint_path, checkpoint_path,
train_config=config.train, train_config=train_config,
) )
finetune_dir = tmp_path / "finetuned" finetune_dir = tmp_path / "finetuned"
@ -222,62 +261,36 @@ def test_user_can_load_checkpoint_and_finetune(
assert checkpoints assert checkpoints
def test_user_can_load_checkpoint_with_new_targets(
tmp_path: Path,
sample_targets,
) -> None:
"""User story: start from checkpoint with a new target definition."""
module = build_training_module(model_config=BatDetect2Config().model)
trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "base_transfer.ckpt"
trainer.strategy.connect(module)
trainer.save_checkpoint(checkpoint_path)
source_model, _ = load_model_from_checkpoint(checkpoint_path)
api = BatDetect2API.from_checkpoint(
checkpoint_path,
targets_config=sample_targets.config,
)
source_detector = cast(Detector, source_model.detector)
detector = cast(Detector, api.model.detector)
classifier_head = cast(ClassifierHead, detector.classifier_head)
assert api.targets.config == sample_targets.config # type: ignore
assert detector.num_classes == len(sample_targets.class_names)
assert (
classifier_head.classifier.out_channels
== len(sample_targets.class_names) + 1
)
source_backbone = source_detector.backbone.state_dict()
target_backbone = detector.backbone.state_dict()
assert source_backbone
for key, value in source_backbone.items():
assert key in target_backbone
torch.testing.assert_close(target_backbone[key], value)
def test_checkpoint_with_same_targets_config_keeps_heads_unchanged( def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
example_targets_config: TargetConfig,
tmp_path: Path, tmp_path: Path,
) -> None: ) -> None:
"""User story: same targets config does not rebuild prediction heads.""" """User story: same targets config does not rebuild prediction heads."""
module = build_training_module(model_config=BatDetect2Config().model) # Given
source_api = BatDetect2API.from_config(
targets_config=example_targets_config
)
module = build_training_module(
model_config=source_api.model_config,
class_names=source_api.targets.class_names,
dimension_names=source_api.roi_mapper.dimension_names,
)
trainer = L.Trainer(enable_checkpointing=False, logger=False) trainer = L.Trainer(enable_checkpointing=False, logger=False)
checkpoint_path = tmp_path / "same_targets.ckpt" checkpoint_path = tmp_path / "same_targets.ckpt"
trainer.strategy.connect(module) trainer.strategy.connect(module)
trainer.save_checkpoint(checkpoint_path) trainer.save_checkpoint(checkpoint_path)
source_model, source_model_config = load_model_from_checkpoint( source_model, _ = load_model_from_checkpoint(checkpoint_path)
checkpoint_path
)
source_detector = cast(Detector, source_model.detector) source_detector = cast(Detector, source_model.detector)
# When
api = BatDetect2API.from_checkpoint( api = BatDetect2API.from_checkpoint(
checkpoint_path, checkpoint_path,
targets_config=source_model_config.targets, targets_config=example_targets_config,
) )
# Then
detector = cast(Detector, api.model.detector) detector = cast(Detector, api.model.detector)
for key, value in source_detector.classifier_head.state_dict().items(): for key, value in source_detector.classifier_head.state_dict().items():
@ -302,7 +315,7 @@ def test_user_can_finetune_only_heads(
) -> None: ) -> None:
"""User story: fine-tune only prediction heads.""" """User story: fine-tune only prediction heads."""
api = BatDetect2API.from_config(BatDetect2Config()) api = BatDetect2API.from_config()
finetune_dir = tmp_path / "heads_only" finetune_dir = tmp_path / "heads_only"
api.finetune( api.finetune(
@ -348,8 +361,6 @@ def test_user_can_evaluate_small_dataset_and_get_metrics(
assert isinstance(metrics, list) assert isinstance(metrics, list)
assert len(metrics) == 1 assert len(metrics) == 1
assert isinstance(metrics[0], dict)
assert len(metrics[0]) > 0
assert isinstance(predictions, list) assert isinstance(predictions, list)
assert len(predictions) == 1 assert len(predictions) == 1
@ -450,8 +461,17 @@ def test_detection_threshold_override_changes_spectrogram_results(
spec = api_v2.generate_spectrogram(audio) spec = api_v2.generate_spectrogram(audio)
default_detections = api_v2.process_spectrogram(spec) default_detections = api_v2.process_spectrogram(spec)
strict_detections = api_v2.process_spectrogram( strict_detections = api_v2.process_spectrogram(
spec, spec, detection_threshold=1.0
detection_threshold=1.0,
) )
assert len(strict_detections) <= len(default_detections) assert len(strict_detections) <= len(default_detections)
def test_user_can_create_api_with_custom_targets_and_model_metadata_matches(
    sample_targets,
) -> None:
    """User story: an API built from custom targets exposes their class names.

    The API is created from the config of ``sample_targets`` and the model's
    ``class_names`` must equal the class names of those targets.
    """
    # Given
    custom_targets_config = sample_targets.config

    # When
    api = BatDetect2API.from_config(targets_config=custom_targets_config)

    # Then
    model_class_names = api.model.class_names
    assert model_class_names == sample_targets.class_names

View File

@ -5,7 +5,6 @@ import numpy as np
import pytest import pytest
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config
from batdetect2.outputs import build_output_formatter from batdetect2.outputs import build_output_formatter
from batdetect2.outputs.formats import ( from batdetect2.outputs.formats import (
BatDetect2OutputConfig, BatDetect2OutputConfig,
@ -18,7 +17,7 @@ from batdetect2.postprocess.types import ClipDetections
def api_v2() -> BatDetect2API: def api_v2() -> BatDetect2API:
"""User story: API object manages prediction IO formats.""" """User story: API object manages prediction IO formats."""
return BatDetect2API.from_config(BatDetect2Config()) return BatDetect2API.from_config()
@pytest.fixture @pytest.fixture

View File

@ -0,0 +1,40 @@
from soundevent import data
from batdetect2.targets import (
TargetClassConfig,
TargetConfig,
build_targets,
check_target_compatibility,
)
def _target_class(name: str) -> TargetClassConfig:
    """Build a target class config called ``name``.

    The class is matched by a single tag with key ``"class"`` whose value is
    the same ``name``.
    """
    class_tag = data.Tag(key="class", value=name)
    return TargetClassConfig(name=name, tags=[class_tag])
def test_check_target_compatibility_accepts_superset_targets() -> None:
    """Targets defining ``pip35``, ``myo`` and ``extra`` are compatible with
    the class names ``pip35`` and ``myo``."""
    # Given
    target_names = ["pip35", "myo", "extra"]
    config = TargetConfig(
        classification_targets=[_target_class(name) for name in target_names]
    )
    targets = build_targets(config)

    # When
    is_compatible = check_target_compatibility(targets, ["pip35", "myo"])

    # Then
    assert is_compatible
def test_check_target_compatibility_rejects_missing_model_classes() -> None:
    """Targets defining only ``pip35`` and ``myo`` are not compatible with
    class names that include ``nyc``."""
    # Given
    target_names = ["pip35", "myo"]
    config = TargetConfig(
        classification_targets=[_target_class(name) for name in target_names]
    )
    targets = build_targets(config)

    # When
    is_compatible = check_target_compatibility(targets, ["pip35", "nyc"])

    # Then
    assert not is_compatible

View File

@ -10,9 +10,8 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio.types import AudioLoader from batdetect2.audio.types import AudioLoader
from batdetect2.config import BatDetect2Config
from batdetect2.models import ModelConfig, build_model from batdetect2.models import ModelConfig, build_model
from batdetect2.targets.classes import TargetClassConfig from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
from batdetect2.train import ( from batdetect2.train import (
TrainingConfig, TrainingConfig,
TrainingModule, TrainingModule,
@ -24,11 +23,21 @@ from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
from batdetect2.train.train import build_training_module from batdetect2.train.train import build_training_module
def build_default_module(config: BatDetect2Config | None = None): def build_default_module(
config = config or BatDetect2Config() target_config: TargetConfig | None = None,
model_config: ModelConfig | None = None,
train_config: TrainingConfig | None = None,
):
target_config = target_config or TargetConfig()
model_config = model_config or ModelConfig()
train_config = train_config or TrainingConfig()
targets = build_targets(target_config)
roi_mapper = build_roi_mapping(target_config.roi)
return build_training_module( return build_training_module(
model_config=config.model, model_config=model_config,
train_config=config.train, class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config,
) )
@ -72,8 +81,13 @@ def test_load_model_from_checkpoint_returns_model_and_config(
input_model_config.model_dump(mode="json") input_model_config.model_dump(mode="json")
) )
train_config = TrainingConfig() train_config = TrainingConfig()
targets_config = TargetConfig()
targets = build_targets(targets_config)
roi_mapper = build_roi_mapping(targets_config.roi)
module = build_training_module( module = build_training_module(
model_config=input_model_config, model_config=input_model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config, train_config=train_config,
) )
trainer = L.Trainer() trainer = L.Trainer()
@ -87,6 +101,8 @@ def test_load_model_from_checkpoint_returns_model_and_config(
assert loaded_model_config.model_dump( assert loaded_model_config.model_dump(
mode="json" mode="json"
) == expected_model_config.model_dump(mode="json") ) == expected_model_config.model_dump(mode="json")
assert model.class_names == targets.class_names
assert model.dimension_names == roi_mapper.dimension_names
recovered = TrainingModule.load_from_checkpoint(path) recovered = TrainingModule.load_from_checkpoint(path)
assert recovered.train_config.model_dump( assert recovered.train_config.model_dump(
@ -100,6 +116,9 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
model_config.model_dump(mode="json") model_config.model_dump(mode="json")
) )
train_config = TrainingConfig() train_config = TrainingConfig()
targets_config = TargetConfig()
targets = build_targets(targets_config)
roi_mapper = build_roi_mapping(targets_config.roi)
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4) train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123) train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123)
train_config.trainer.max_epochs = 3 train_config.trainer.max_epochs = 3
@ -107,6 +126,8 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
module = build_training_module( module = build_training_module(
model_config=model_config, model_config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config, train_config=train_config,
) )
trainer = L.Trainer() trainer = L.Trainer()
@ -131,11 +152,16 @@ def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
model_config.model_dump(mode="json") model_config.model_dump(mode="json")
) )
train_config = TrainingConfig() train_config = TrainingConfig()
targets_config = TargetConfig()
targets = build_targets(targets_config)
roi_mapper = build_roi_mapping(targets_config.roi)
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4) train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321) train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321)
module = build_training_module( module = build_training_module(
model_config=model_config, model_config=model_config,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=train_config, train_config=train_config,
) )
@ -189,19 +215,26 @@ def test_train_smoke_produces_loadable_checkpoint(
example_annotations: list[data.ClipAnnotation], example_annotations: list[data.ClipAnnotation],
sample_audio_loader: AudioLoader, sample_audio_loader: AudioLoader,
): ):
config = BatDetect2Config() # Given
config.train.trainer.limit_train_batches = 1 train_config = TrainingConfig.model_validate(
config.train.trainer.limit_val_batches = 1 {
config.train.trainer.log_every_n_steps = 1 "trainer": {
config.train.train_loader.batch_size = 1 "limit_train_batches": 1,
config.train.train_loader.augmentations.enabled = False "limit_val_batches": 1,
"log_every_n_steps": 1,
},
"train_loader": {
"batch_size": 1,
"augmentations": {"enabled": False},
},
}
)
# When
run_train( run_train(
train_annotations=example_annotations[:1], train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1], val_annotations=example_annotations[:1],
train_config=config.train, train_config=train_config,
model_config=config.model,
audio_config=config.audio,
num_epochs=1, num_epochs=1,
train_workers=0, train_workers=0,
val_workers=0, val_workers=0,
@ -209,18 +242,11 @@ def test_train_smoke_produces_loadable_checkpoint(
seed=0, seed=0,
) )
# Then
checkpoints = list(tmp_path.rglob("*.ckpt")) checkpoints = list(tmp_path.rglob("*.ckpt"))
assert checkpoints assert checkpoints
model, model_config = load_model_from_checkpoint(checkpoints[0]) model, model_config = load_model_from_checkpoint(checkpoints[0])
assert model_config.samplerate == config.model.samplerate
assert model_config.architecture.name == config.model.architecture.name
assert model_config.preprocess.model_dump(
mode="json"
) == config.model.preprocess.model_dump(mode="json")
assert model_config.postprocess.model_dump(
mode="json"
) == config.model.postprocess.model_dump(mode="json")
wav = torch.tensor( wav = torch.tensor(
sample_audio_loader.load_clip(example_annotations[0].clip) sample_audio_loader.load_clip(example_annotations[0].clip)
@ -230,10 +256,18 @@ def test_train_smoke_produces_loadable_checkpoint(
def test_build_training_module_uses_provided_model() -> None: def test_build_training_module_uses_provided_model() -> None:
model = build_model(ModelConfig()) targets = build_targets(TargetConfig())
roi_mapper = build_roi_mapping(TargetConfig().roi)
model = build_model(
ModelConfig(),
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
)
module = build_training_module( module = build_training_module(
model_config=ModelConfig(), model_config=ModelConfig(),
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
train_config=TrainingConfig(), train_config=TrainingConfig(),
model=model, model=model,
) )
@ -244,15 +278,18 @@ def test_build_training_module_uses_provided_model() -> None:
def test_run_train_rejects_incompatible_model_config( def test_run_train_rejects_incompatible_model_config(
example_annotations: list[data.ClipAnnotation], example_annotations: list[data.ClipAnnotation],
) -> None: ) -> None:
model = build_model(ModelConfig()) # Given
targets_config = TargetConfig()
targets = build_targets(targets_config)
roi_mapper = build_roi_mapping(targets_config.roi)
incompatible_config = ModelConfig() incompatible_config = ModelConfig()
incompatible_config.targets.classification_targets.append( incompatible_model = build_model(
TargetClassConfig( incompatible_config,
name="dummy_class", class_names=targets.class_names,
tags=[data.Tag(key="class", value="Dummy class")], dimension_names=[*roi_mapper.dimension_names, "extra_dim"],
)
) )
# When/Then
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match="Provided model is incompatible with model_config", match="Provided model is incompatible with model_config",
@ -260,7 +297,10 @@ def test_run_train_rejects_incompatible_model_config(
run_train( run_train(
train_annotations=example_annotations[:1], train_annotations=example_annotations[:1],
val_annotations=None, val_annotations=None,
model=model, model=incompatible_model,
targets=targets,
roi_mapper=roi_mapper,
model_config=incompatible_config, model_config=incompatible_config,
targets_config=targets_config,
train_config=TrainingConfig(), train_config=TrainingConfig(),
) )