diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index f21d473..9f377a1 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -282,18 +282,18 @@ class BatDetect2API: cls, config: BatDetect2Config, ): - targets = build_targets(config=config.targets) + targets = build_targets(config=config.model.targets) audio_loader = build_audio_loader(config=config.audio) preprocessor = build_preprocessor( input_samplerate=audio_loader.samplerate, - config=config.preprocess, + config=config.model.preprocess, ) postprocessor = build_postprocessor( preprocessor, - config=config.postprocess, + config=config.model.postprocess, ) evaluator = build_evaluator(config=config.evaluation, targets=targets) @@ -301,18 +301,7 @@ class BatDetect2API: # NOTE: Better to have a separate instance of # preprocessor and postprocessor as these may be moved # to another device. - model = build_model( - config=config.model, - targets=targets, - preprocessor=build_preprocessor( - input_samplerate=audio_loader.samplerate, - config=config.preprocess, - ), - postprocessor=build_postprocessor( - preprocessor, - config=config.postprocess, - ), - ) + model = build_model(config=config.model) formatter = build_output_formatter(targets, config=config.output) @@ -333,24 +322,30 @@ class BatDetect2API: path: data.PathLike, config: BatDetect2Config | None = None, ): - model, stored_config = load_model_from_checkpoint(path) + from batdetect2.audio import AudioConfig - config = ( - merge_configs(stored_config, config) if config else stored_config + model, model_config = load_model_from_checkpoint(path) + + # Reconstruct a full BatDetect2Config from the checkpoint's + # ModelConfig, then overlay any caller-supplied overrides. + base = BatDetect2Config( + model=model_config, + audio=AudioConfig(samplerate=model_config.samplerate), ) + config = merge_configs(base, config) if config else base - targets = build_targets(config=config.targets) + targets = build_targets(config=config.model.targets) audio_loader = build_audio_loader(config=config.audio) preprocessor = build_preprocessor( input_samplerate=audio_loader.samplerate, - config=config.preprocess, + config=config.model.preprocess, ) postprocessor = build_postprocessor( preprocessor, - config=config.postprocess, + config=config.model.postprocess, ) evaluator = build_evaluator(config=config.evaluation, targets=targets) diff --git a/src/batdetect2/config.py b/src/batdetect2/config.py index 9b51dfa..6382b30 100644 --- a/src/batdetect2/config.py +++ b/src/batdetect2/config.py @@ -12,10 +12,7 @@ from batdetect2.evaluate.config import ( get_default_eval_config, ) from batdetect2.inference.config import InferenceConfig -from batdetect2.models.backbones import BackboneConfig, UNetBackboneConfig -from batdetect2.postprocess.config import PostprocessConfig -from batdetect2.preprocess.config import PreprocessingConfig -from batdetect2.targets.config import TargetConfig +from batdetect2.models import ModelConfig from batdetect2.train.config import TrainingConfig __all__ = [ @@ -32,13 +29,8 @@ class BatDetect2Config(BaseConfig): evaluation: EvaluationConfig = Field( default_factory=get_default_eval_config ) - model: BackboneConfig = Field(default_factory=UNetBackboneConfig) - preprocess: PreprocessingConfig = Field( - default_factory=PreprocessingConfig - ) - postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) + model: ModelConfig = Field(default_factory=ModelConfig) audio: AudioConfig = Field(default_factory=AudioConfig) - targets: TargetConfig = Field(default_factory=TargetConfig) inference: InferenceConfig = Field(default_factory=InferenceConfig) output: OutputFormatConfig = Field(default_factory=RawOutputConfig) diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index e69bb2b..223e135 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -45,12 +45,8 @@ def evaluate( audio_loader = audio_loader or build_audio_loader(config=config.audio) - preprocessor = preprocessor or build_preprocessor( - config=config.preprocess, - input_samplerate=audio_loader.samplerate, - ) - - targets = targets or build_targets(config=config.targets) + preprocessor = preprocessor or model.preprocessor + targets = targets or model.targets loader = build_test_loader( test_annotations, diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index ff99c77..e65ddef 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -27,7 +27,10 @@ is the ``build_model`` factory function exported from this module. """ import torch +from pydantic import Field +from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ +from batdetect2.core.configs import BaseConfig from batdetect2.models.backbones import ( BackboneConfig, UNetBackbone, @@ -59,6 +62,9 @@ from batdetect2.models.encoder import ( build_encoder, ) from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead +from batdetect2.postprocess.config import PostprocessConfig +from batdetect2.preprocess.config import PreprocessingConfig +from batdetect2.targets.config import TargetConfig from batdetect2.typing import ( ClipDetectionsTensor, DetectionModel, @@ -92,10 +98,50 @@ __all__ = [ "build_detector", "load_backbone_config", "Model", + "ModelConfig", "build_model", ] +class ModelConfig(BaseConfig): + """Complete configuration describing a BatDetect2 model. + + Bundles every parameter that defines a model's behaviour: the input + sample rate, backbone architecture, preprocessing pipeline, + postprocessing pipeline, and detection targets. + + Attributes + ---------- + samplerate : int + Expected input audio sample rate in Hz. Audio must be resampled + to this rate before being passed to the model. Defaults to + ``TARGET_SAMPLERATE_HZ`` (256 000 Hz). + architecture : BackboneConfig + Configuration for the encoder-decoder backbone network. Defaults + to ``UNetBackboneConfig()``. + preprocess : PreprocessingConfig + Parameters for the audio-to-spectrogram preprocessing pipeline + (STFT, frequency crop, transforms, resize). Defaults to + ``PreprocessingConfig()``. + postprocess : PostprocessConfig + Parameters for converting raw model outputs into detections (NMS + kernel, thresholds, top-k limit). Defaults to + ``PostprocessConfig()``. + targets : TargetConfig + Detection and classification target definitions (class list, + detection target, bounding-box mapper). Defaults to + ``TargetConfig()``. + """ + + samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0) + architecture: BackboneConfig = Field(default_factory=UNetBackboneConfig) + preprocess: PreprocessingConfig = Field( + default_factory=PreprocessingConfig + ) + postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) + targets: TargetConfig = Field(default_factory=TargetConfig) + + class Model(torch.nn.Module): """End-to-end BatDetect2 model wrapping preprocessing and postprocessing. @@ -166,55 +212,61 @@ class Model(torch.nn.Module): def build_model( - config: BackboneConfig | None = None, + config: ModelConfig | None = None, targets: TargetProtocol | None = None, preprocessor: PreprocessorProtocol | None = None, postprocessor: PostprocessorProtocol | None = None, -) -> "Model": +) -> Model: """Build a complete, ready-to-use BatDetect2 model. - Assembles a ``Model`` instance from optional configuration and component - overrides. Any argument left as ``None`` will be replaced by a sensible - default built with the project's own builder functions. + Assembles a ``Model`` instance from a ``ModelConfig`` and optional + component overrides. Any component argument left as ``None`` is built + from the configuration. Passing a pre-built component overrides the + corresponding config fields for that component only. Parameters ---------- - config : BackboneConfig, optional - Configuration describing the backbone architecture (encoder, - bottleneck, decoder). Defaults to ``UNetBackboneConfig()`` if not + config : ModelConfig, optional + Full model configuration (samplerate, architecture, preprocessing, + postprocessing, targets). Defaults to ``ModelConfig()`` if not provided. targets : TargetProtocol, optional - Describes the target bat species or call types to detect. Determines - the number of output classes. Defaults to the standard BatDetect2 - target set. + Pre-built targets object. If given, overrides + ``config.targets``. preprocessor : PreprocessorProtocol, optional - Converts raw audio waveforms to spectrograms. Defaults to the - standard BatDetect2 preprocessor. + Pre-built preprocessor. If given, overrides + ``config.preprocess`` and ``config.samplerate`` for the + preprocessing step. postprocessor : PostprocessorProtocol, optional - Converts raw model outputs to detection tensors. Defaults to the - standard BatDetect2 postprocessor. If a custom ``preprocessor`` is - given without a matching ``postprocessor``, the default postprocessor - will be built using the provided preprocessor so that frequency and - time scaling remain consistent. + Pre-built postprocessor. If given, overrides + ``config.postprocess``. When omitted and a custom + ``preprocessor`` is supplied, the default postprocessor is built + using that preprocessor so that frequency and time scaling remain + consistent. Returns ------- Model - A fully assembled ``Model`` instance ready for inference or training. + A fully assembled ``Model`` instance ready for inference or + training. """ from batdetect2.postprocess import build_postprocessor from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets - config = config or UNetBackboneConfig() - targets = targets or build_targets() - preprocessor = preprocessor or build_preprocessor() + config = config or ModelConfig() + targets = targets or build_targets(config=config.targets) + preprocessor = preprocessor or build_preprocessor( + config=config.preprocess, + input_samplerate=config.samplerate, + ) postprocessor = postprocessor or build_postprocessor( preprocessor=preprocessor, + config=config.postprocess, ) detector = build_detector( num_classes=len(targets.class_names), - config=config, + config=config.architecture, ) return Model( detector=detector, diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py index f7f96fc..8bf1c69 100644 --- a/src/batdetect2/models/blocks.py +++ b/src/batdetect2/models/blocks.py @@ -921,20 +921,6 @@ class StandardConvUpBlock(Block): ) -LayerConfig = Annotated[ - ConvConfig - | BlockImportConfig - | FreqCoordConvDownConfig - | StandardConvDownConfig - | FreqCoordConvUpConfig - | StandardConvUpConfig - | SelfAttentionConfig - | "LayerGroupConfig", - Field(discriminator="name"), -] -"""Type alias for the discriminated union of block configuration models.""" - - class LayerGroupConfig(BaseConfig): """Configuration for a ``LayerGroup`` — a sequential chain of blocks. @@ -951,7 +937,20 @@ class LayerGroupConfig(BaseConfig): """ name: Literal["LayerGroup"] = "LayerGroup" - layers: list[LayerConfig] + layers: list["LayerConfig"] + + +LayerConfig = Annotated[ + ConvConfig + | FreqCoordConvDownConfig + | StandardConvDownConfig + | FreqCoordConvUpConfig + | StandardConvUpConfig + | SelfAttentionConfig + | LayerGroupConfig, + Field(discriminator="name"), +] +"""Type alias for the discriminated union of block configuration models.""" class LayerGroup(nn.Module): diff --git a/src/batdetect2/models/detectors.py b/src/batdetect2/models/detectors.py index 55fbabf..97238f6 100644 --- a/src/batdetect2/models/detectors.py +++ b/src/batdetect2/models/detectors.py @@ -169,7 +169,7 @@ def build_detector( logger.opt(lazy=True).debug( "Building model with config: \n{}", - lambda: config.to_yaml_string(), + lambda: config.to_yaml_string(), # type: ignore ) backbone = build_backbone(config=config) classifier_head = ClassifierHead( diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index be357fe..a9412a0 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -6,15 +6,12 @@ from soundevent.data import PathLike from torch.optim.adam import Adam from torch.optim.lr_scheduler import CosineAnnealingLR -from batdetect2.models import Model, build_model -from batdetect2.postprocess import build_postprocessor -from batdetect2.preprocess import build_preprocessor -from batdetect2.targets import build_targets +from batdetect2.models import Model, ModelConfig, build_model from batdetect2.train.losses import build_loss from batdetect2.typing import ModelOutput, TrainExample if TYPE_CHECKING: - from batdetect2.config import BatDetect2Config + pass __all__ = [ "TrainingModule", @@ -26,43 +23,25 @@ class TrainingModule(L.LightningModule): def __init__( self, - config: dict | None = None, + model_config: dict | None = None, t_max: int = 100, - model: Model | None = None, + learning_rate: float = 1e-3, loss: torch.nn.Module | None = None, + model: Model | None = None, ): - from batdetect2.config import validate_config - super().__init__() - self.save_hyperparameters(logger=False) + self.save_hyperparameters(ignore=["model", "loss"], logger=False) - self.config = validate_config(config) - self.input_samplerate = self.config.audio.samplerate - self.learning_rate = self.config.train.optimizer.learning_rate + self.model_config = ModelConfig.model_validate(model_config or {}) + self.learning_rate = learning_rate self.t_max = t_max if loss is None: - loss = build_loss(self.config.train.loss) + loss = build_loss() if model is None: - targets = build_targets(self.config.targets) - - preprocessor = build_preprocessor( - config=self.config.preprocess, - input_samplerate=self.input_samplerate, - ) - - postprocessor = build_postprocessor( - preprocessor, config=self.config.postprocess - ) - - model = build_model( - config=self.config.model, - targets=targets, - preprocessor=preprocessor, - postprocessor=postprocessor, - ) + model = build_model(config=self.model_config) self.loss = loss self.model = model @@ -97,13 +76,39 @@ class TrainingModule(L.LightningModule): def load_model_from_checkpoint( path: PathLike, -) -> tuple[Model, "BatDetect2Config"]: +) -> tuple[Model, ModelConfig]: + """Load a model and its configuration from a Lightning checkpoint. + + Parameters + ---------- + path : PathLike + Path to a ``.ckpt`` file produced by the BatDetect2 training + pipeline. + + Returns + ------- + tuple[Model, ModelConfig] + The restored ``Model`` instance and the ``ModelConfig`` that + describes its architecture, preprocessing, postprocessing, and + targets. + """ module = TrainingModule.load_from_checkpoint(path) # type: ignore - return module.model, module.config + return module.model, module.model_config def build_training_module( - config: dict | None = None, + model_config: dict | None = None, t_max: int = 200, + learning_rate: float = 1e-3, + loss_config: dict | None = None, ) -> TrainingModule: - return TrainingModule(config=config, t_max=t_max) + from batdetect2.train.config import LossConfig + from batdetect2.train.losses import build_loss + + loss = build_loss(LossConfig.model_validate(loss_config or {})) + return TrainingModule( + model_config=model_config, + t_max=t_max, + learning_rate=learning_rate, + loss=loss, + ) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 70ff131..66f59c1 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -58,13 +58,13 @@ def train( config = config or BatDetect2Config() - targets = targets or build_targets(config=config.targets) + targets = targets or build_targets(config=config.model.targets) audio_loader = audio_loader or build_audio_loader(config=config.audio) preprocessor = preprocessor or build_preprocessor( input_samplerate=audio_loader.samplerate, - config=config.preprocess, + config=config.model.preprocess, ) labeller = labeller or build_clip_labeler( @@ -97,8 +97,10 @@ def train( ) module = build_training_module( - config.model_dump(mode="json"), + model_config=config.model.model_dump(mode="json"), t_max=config.train.optimizer.t_max * len(train_dataloader), + learning_rate=config.train.optimizer.learning_rate, + loss_config=config.train.loss.model_dump(mode="json"), ) trainer = trainer or build_trainer( diff --git a/src/batdetect2/utils/detector_utils.py b/src/batdetect2/utils/detector_utils.py index fc448c3..9c838b2 100644 --- a/src/batdetect2/utils/detector_utils.py +++ b/src/batdetect2/utils/detector_utils.py @@ -317,9 +317,11 @@ def convert_results( ] # combine into final results dictionary - results: RunResults = RunResults({ # type: ignore - "pred_dict": pred_dict, - }) + results: RunResults = RunResults( # type: ignore[missing-argument] + { + "pred_dict": pred_dict, + } + ) # add spectrogram features if they exist if len(spec_feats) > 0 and params["spec_features"]: diff --git a/tests/test_models/test_detectors.py b/tests/test_models/test_detectors.py index 0edb694..823e07d 100644 --- a/tests/test_models/test_detectors.py +++ b/tests/test_models/test_detectors.py @@ -2,8 +2,10 @@ import numpy as np import pytest import torch +from batdetect2.models import UNetBackbone from batdetect2.models.backbones import UNetBackboneConfig from batdetect2.models.detectors import Detector, build_detector +from batdetect2.models.encoder import Encoder from batdetect2.models.heads import BBoxHead, ClassifierHead from batdetect2.typing.models import ModelOutput @@ -34,6 +36,8 @@ def test_build_detector_custom_config(): assert isinstance(model, Detector) assert model.backbone.input_height == 128 + + assert isinstance(model.backbone.encoder, Encoder) assert model.backbone.encoder.in_channels == 2 @@ -80,6 +84,7 @@ def test_detector_forward_pass_shapes(dummy_spectrogram): ) # Check features shape: (B, out_channels, H, W) + assert isinstance(model.backbone, UNetBackbone) out_channels = model.backbone.out_channels assert output.features.shape == ( batch_size, diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index d467e6b..928961e 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -12,7 +12,9 @@ from batdetect2.typing.preprocess import AudioLoader def build_default_module(): config = BatDetect2Config() - return build_training_module(config=config.model_dump()) + return build_training_module( + model_config=config.model.model_dump(mode="json"), + ) def test_can_initialize_default_module():