Restructure model config

This commit is contained in:
mbsantiago 2026-03-17 13:33:13 +00:00
parent 1a7c0b4b3a
commit 65bd0dc6ae
11 changed files with 168 additions and 118 deletions

View File

@ -282,18 +282,18 @@ class BatDetect2API:
cls, cls,
config: BatDetect2Config, config: BatDetect2Config,
): ):
targets = build_targets(config=config.targets) targets = build_targets(config=config.model.targets)
audio_loader = build_audio_loader(config=config.audio) audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor( preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate, input_samplerate=audio_loader.samplerate,
config=config.preprocess, config=config.model.preprocess,
) )
postprocessor = build_postprocessor( postprocessor = build_postprocessor(
preprocessor, preprocessor,
config=config.postprocess, config=config.model.postprocess,
) )
evaluator = build_evaluator(config=config.evaluation, targets=targets) evaluator = build_evaluator(config=config.evaluation, targets=targets)
@ -301,18 +301,7 @@ class BatDetect2API:
# NOTE: Better to have a separate instance of # NOTE: Better to have a separate instance of
# preprocessor and postprocessor as these may be moved # preprocessor and postprocessor as these may be moved
# to another device. # to another device.
model = build_model( model = build_model(config=config.model)
config=config.model,
targets=targets,
preprocessor=build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
),
postprocessor=build_postprocessor(
preprocessor,
config=config.postprocess,
),
)
formatter = build_output_formatter(targets, config=config.output) formatter = build_output_formatter(targets, config=config.output)
@ -333,24 +322,30 @@ class BatDetect2API:
path: data.PathLike, path: data.PathLike,
config: BatDetect2Config | None = None, config: BatDetect2Config | None = None,
): ):
model, stored_config = load_model_from_checkpoint(path) from batdetect2.audio import AudioConfig
config = ( model, model_config = load_model_from_checkpoint(path)
merge_configs(stored_config, config) if config else stored_config
# Reconstruct a full BatDetect2Config from the checkpoint's
# ModelConfig, then overlay any caller-supplied overrides.
base = BatDetect2Config(
model=model_config,
audio=AudioConfig(samplerate=model_config.samplerate),
) )
config = merge_configs(base, config) if config else base
targets = build_targets(config=config.targets) targets = build_targets(config=config.model.targets)
audio_loader = build_audio_loader(config=config.audio) audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor( preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate, input_samplerate=audio_loader.samplerate,
config=config.preprocess, config=config.model.preprocess,
) )
postprocessor = build_postprocessor( postprocessor = build_postprocessor(
preprocessor, preprocessor,
config=config.postprocess, config=config.model.postprocess,
) )
evaluator = build_evaluator(config=config.evaluation, targets=targets) evaluator = build_evaluator(config=config.evaluation, targets=targets)

View File

@ -12,10 +12,7 @@ from batdetect2.evaluate.config import (
get_default_eval_config, get_default_eval_config,
) )
from batdetect2.inference.config import InferenceConfig from batdetect2.inference.config import InferenceConfig
from batdetect2.models.backbones import BackboneConfig, UNetBackboneConfig from batdetect2.models import ModelConfig
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.targets.config import TargetConfig
from batdetect2.train.config import TrainingConfig from batdetect2.train.config import TrainingConfig
__all__ = [ __all__ = [
@ -32,13 +29,8 @@ class BatDetect2Config(BaseConfig):
evaluation: EvaluationConfig = Field( evaluation: EvaluationConfig = Field(
default_factory=get_default_eval_config default_factory=get_default_eval_config
) )
model: BackboneConfig = Field(default_factory=UNetBackboneConfig) model: ModelConfig = Field(default_factory=ModelConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
audio: AudioConfig = Field(default_factory=AudioConfig) audio: AudioConfig = Field(default_factory=AudioConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig) inference: InferenceConfig = Field(default_factory=InferenceConfig)
output: OutputFormatConfig = Field(default_factory=RawOutputConfig) output: OutputFormatConfig = Field(default_factory=RawOutputConfig)

View File

@ -45,12 +45,8 @@ def evaluate(
audio_loader = audio_loader or build_audio_loader(config=config.audio) audio_loader = audio_loader or build_audio_loader(config=config.audio)
preprocessor = preprocessor or build_preprocessor( preprocessor = preprocessor or model.preprocessor
config=config.preprocess, targets = targets or model.targets
input_samplerate=audio_loader.samplerate,
)
targets = targets or build_targets(config=config.targets)
loader = build_test_loader( loader = build_test_loader(
test_annotations, test_annotations,

View File

@ -27,7 +27,10 @@ is the ``build_model`` factory function exported from this module.
""" """
import torch import torch
from pydantic import Field
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig
from batdetect2.models.backbones import ( from batdetect2.models.backbones import (
BackboneConfig, BackboneConfig,
UNetBackbone, UNetBackbone,
@ -59,6 +62,9 @@ from batdetect2.models.encoder import (
build_encoder, build_encoder,
) )
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.targets.config import TargetConfig
from batdetect2.typing import ( from batdetect2.typing import (
ClipDetectionsTensor, ClipDetectionsTensor,
DetectionModel, DetectionModel,
@ -92,10 +98,50 @@ __all__ = [
"build_detector", "build_detector",
"load_backbone_config", "load_backbone_config",
"Model", "Model",
"ModelConfig",
"build_model", "build_model",
] ]
class ModelConfig(BaseConfig):
"""Complete configuration describing a BatDetect2 model.
Bundles every parameter that defines a model's behaviour: the input
sample rate, backbone architecture, preprocessing pipeline,
postprocessing pipeline, and detection targets.
Attributes
----------
samplerate : int
Expected input audio sample rate in Hz. Audio must be resampled
to this rate before being passed to the model. Defaults to
``TARGET_SAMPLERATE_HZ`` (256 000 Hz).
architecture : BackboneConfig
Configuration for the encoder-decoder backbone network. Defaults
to ``UNetBackboneConfig()``.
preprocess : PreprocessingConfig
Parameters for the audio-to-spectrogram preprocessing pipeline
(STFT, frequency crop, transforms, resize). Defaults to
``PreprocessingConfig()``.
postprocess : PostprocessConfig
Parameters for converting raw model outputs into detections (NMS
kernel, thresholds, top-k limit). Defaults to
``PostprocessConfig()``.
targets : TargetConfig
Detection and classification target definitions (class list,
detection target, bounding-box mapper). Defaults to
``TargetConfig()``.
"""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
architecture: BackboneConfig = Field(default_factory=UNetBackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
class Model(torch.nn.Module): class Model(torch.nn.Module):
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing. """End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
@ -166,55 +212,61 @@ class Model(torch.nn.Module):
def build_model( def build_model(
config: BackboneConfig | None = None, config: ModelConfig | None = None,
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None, postprocessor: PostprocessorProtocol | None = None,
) -> "Model": ) -> Model:
"""Build a complete, ready-to-use BatDetect2 model. """Build a complete, ready-to-use BatDetect2 model.
Assembles a ``Model`` instance from optional configuration and component Assembles a ``Model`` instance from a ``ModelConfig`` and optional
overrides. Any argument left as ``None`` will be replaced by a sensible component overrides. Any component argument left as ``None`` is built
default built with the project's own builder functions. from the configuration. Passing a pre-built component overrides the
corresponding config fields for that component only.
Parameters Parameters
---------- ----------
config : BackboneConfig, optional config : ModelConfig, optional
Configuration describing the backbone architecture (encoder, Full model configuration (samplerate, architecture, preprocessing,
bottleneck, decoder). Defaults to ``UNetBackboneConfig()`` if not postprocessing, targets). Defaults to ``ModelConfig()`` if not
provided. provided.
targets : TargetProtocol, optional targets : TargetProtocol, optional
Describes the target bat species or call types to detect. Determines Pre-built targets object. If given, overrides
the number of output classes. Defaults to the standard BatDetect2 ``config.targets``.
target set.
preprocessor : PreprocessorProtocol, optional preprocessor : PreprocessorProtocol, optional
Converts raw audio waveforms to spectrograms. Defaults to the Pre-built preprocessor. If given, overrides
standard BatDetect2 preprocessor. ``config.preprocess`` and ``config.samplerate`` for the
preprocessing step.
postprocessor : PostprocessorProtocol, optional postprocessor : PostprocessorProtocol, optional
Converts raw model outputs to detection tensors. Defaults to the Pre-built postprocessor. If given, overrides
standard BatDetect2 postprocessor. If a custom ``preprocessor`` is ``config.postprocess``. When omitted and a custom
given without a matching ``postprocessor``, the default postprocessor ``preprocessor`` is supplied, the default postprocessor is built
will be built using the provided preprocessor so that frequency and using that preprocessor so that frequency and time scaling remain
time scaling remain consistent. consistent.
Returns Returns
------- -------
Model Model
A fully assembled ``Model`` instance ready for inference or training. A fully assembled ``Model`` instance ready for inference or
training.
""" """
from batdetect2.postprocess import build_postprocessor from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
config = config or UNetBackboneConfig() config = config or ModelConfig()
targets = targets or build_targets() targets = targets or build_targets(config=config.targets)
preprocessor = preprocessor or build_preprocessor() preprocessor = preprocessor or build_preprocessor(
config=config.preprocess,
input_samplerate=config.samplerate,
)
postprocessor = postprocessor or build_postprocessor( postprocessor = postprocessor or build_postprocessor(
preprocessor=preprocessor, preprocessor=preprocessor,
config=config.postprocess,
) )
detector = build_detector( detector = build_detector(
num_classes=len(targets.class_names), num_classes=len(targets.class_names),
config=config, config=config.architecture,
) )
return Model( return Model(
detector=detector, detector=detector,

View File

@ -921,20 +921,6 @@ class StandardConvUpBlock(Block):
) )
LayerConfig = Annotated[
ConvConfig
| BlockImportConfig
| FreqCoordConvDownConfig
| StandardConvDownConfig
| FreqCoordConvUpConfig
| StandardConvUpConfig
| SelfAttentionConfig
| "LayerGroupConfig",
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configuration models."""
class LayerGroupConfig(BaseConfig): class LayerGroupConfig(BaseConfig):
"""Configuration for a ``LayerGroup`` — a sequential chain of blocks. """Configuration for a ``LayerGroup`` — a sequential chain of blocks.
@ -951,7 +937,20 @@ class LayerGroupConfig(BaseConfig):
""" """
name: Literal["LayerGroup"] = "LayerGroup" name: Literal["LayerGroup"] = "LayerGroup"
layers: list[LayerConfig] layers: list["LayerConfig"]
LayerConfig = Annotated[
ConvConfig
| FreqCoordConvDownConfig
| StandardConvDownConfig
| FreqCoordConvUpConfig
| StandardConvUpConfig
| SelfAttentionConfig
| LayerGroupConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configuration models."""
class LayerGroup(nn.Module): class LayerGroup(nn.Module):

View File

@ -169,7 +169,7 @@ def build_detector(
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
"Building model with config: \n{}", "Building model with config: \n{}",
lambda: config.to_yaml_string(), lambda: config.to_yaml_string(), # type: ignore
) )
backbone = build_backbone(config=config) backbone = build_backbone(config=config)
classifier_head = ClassifierHead( classifier_head = ClassifierHead(

View File

@ -6,15 +6,12 @@ from soundevent.data import PathLike
from torch.optim.adam import Adam from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.models import Model, build_model from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.losses import build_loss from batdetect2.train.losses import build_loss
from batdetect2.typing import ModelOutput, TrainExample from batdetect2.typing import ModelOutput, TrainExample
if TYPE_CHECKING: if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config pass
__all__ = [ __all__ = [
"TrainingModule", "TrainingModule",
@ -26,43 +23,25 @@ class TrainingModule(L.LightningModule):
def __init__( def __init__(
self, self,
config: dict | None = None, model_config: dict | None = None,
t_max: int = 100, t_max: int = 100,
model: Model | None = None, learning_rate: float = 1e-3,
loss: torch.nn.Module | None = None, loss: torch.nn.Module | None = None,
model: Model | None = None,
): ):
from batdetect2.config import validate_config
super().__init__() super().__init__()
self.save_hyperparameters(logger=False) self.save_hyperparameters(ignore=["model", "loss"], logger=False)
self.config = validate_config(config) self.model_config = ModelConfig.model_validate(model_config or {})
self.input_samplerate = self.config.audio.samplerate self.learning_rate = learning_rate
self.learning_rate = self.config.train.optimizer.learning_rate
self.t_max = t_max self.t_max = t_max
if loss is None: if loss is None:
loss = build_loss(self.config.train.loss) loss = build_loss()
if model is None: if model is None:
targets = build_targets(self.config.targets) model = build_model(config=self.model_config)
preprocessor = build_preprocessor(
config=self.config.preprocess,
input_samplerate=self.input_samplerate,
)
postprocessor = build_postprocessor(
preprocessor, config=self.config.postprocess
)
model = build_model(
config=self.config.model,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
self.loss = loss self.loss = loss
self.model = model self.model = model
@ -97,13 +76,39 @@ class TrainingModule(L.LightningModule):
def load_model_from_checkpoint( def load_model_from_checkpoint(
path: PathLike, path: PathLike,
) -> tuple[Model, "BatDetect2Config"]: ) -> tuple[Model, ModelConfig]:
"""Load a model and its configuration from a Lightning checkpoint.
Parameters
----------
path : PathLike
Path to a ``.ckpt`` file produced by the BatDetect2 training
pipeline.
Returns
-------
tuple[Model, ModelConfig]
The restored ``Model`` instance and the ``ModelConfig`` that
describes its architecture, preprocessing, postprocessing, and
targets.
"""
module = TrainingModule.load_from_checkpoint(path) # type: ignore module = TrainingModule.load_from_checkpoint(path) # type: ignore
return module.model, module.config return module.model, module.model_config
def build_training_module( def build_training_module(
config: dict | None = None, model_config: dict | None = None,
t_max: int = 200, t_max: int = 200,
learning_rate: float = 1e-3,
loss_config: dict | None = None,
) -> TrainingModule: ) -> TrainingModule:
return TrainingModule(config=config, t_max=t_max) from batdetect2.train.config import LossConfig
from batdetect2.train.losses import build_loss
loss = build_loss(LossConfig.model_validate(loss_config or {}))
return TrainingModule(
model_config=model_config,
t_max=t_max,
learning_rate=learning_rate,
loss=loss,
)

View File

@ -58,13 +58,13 @@ def train(
config = config or BatDetect2Config() config = config or BatDetect2Config()
targets = targets or build_targets(config=config.targets) targets = targets or build_targets(config=config.model.targets)
audio_loader = audio_loader or build_audio_loader(config=config.audio) audio_loader = audio_loader or build_audio_loader(config=config.audio)
preprocessor = preprocessor or build_preprocessor( preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate, input_samplerate=audio_loader.samplerate,
config=config.preprocess, config=config.model.preprocess,
) )
labeller = labeller or build_clip_labeler( labeller = labeller or build_clip_labeler(
@ -97,8 +97,10 @@ def train(
) )
module = build_training_module( module = build_training_module(
config.model_dump(mode="json"), model_config=config.model.model_dump(mode="json"),
t_max=config.train.optimizer.t_max * len(train_dataloader), t_max=config.train.optimizer.t_max * len(train_dataloader),
learning_rate=config.train.optimizer.learning_rate,
loss_config=config.train.loss.model_dump(mode="json"),
) )
trainer = trainer or build_trainer( trainer = trainer or build_trainer(

View File

@ -317,9 +317,11 @@ def convert_results(
] ]
# combine into final results dictionary # combine into final results dictionary
results: RunResults = RunResults({ # type: ignore results: RunResults = RunResults( # type: ignore[missing-argument]
{
"pred_dict": pred_dict, "pred_dict": pred_dict,
}) }
)
# add spectrogram features if they exist # add spectrogram features if they exist
if len(spec_feats) > 0 and params["spec_features"]: if len(spec_feats) > 0 and params["spec_features"]:

View File

@ -2,8 +2,10 @@ import numpy as np
import pytest import pytest
import torch import torch
from batdetect2.models import UNetBackbone
from batdetect2.models.backbones import UNetBackboneConfig from batdetect2.models.backbones import UNetBackboneConfig
from batdetect2.models.detectors import Detector, build_detector from batdetect2.models.detectors import Detector, build_detector
from batdetect2.models.encoder import Encoder
from batdetect2.models.heads import BBoxHead, ClassifierHead from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.typing.models import ModelOutput from batdetect2.typing.models import ModelOutput
@ -34,6 +36,8 @@ def test_build_detector_custom_config():
assert isinstance(model, Detector) assert isinstance(model, Detector)
assert model.backbone.input_height == 128 assert model.backbone.input_height == 128
assert isinstance(model.backbone.encoder, Encoder)
assert model.backbone.encoder.in_channels == 2 assert model.backbone.encoder.in_channels == 2
@ -80,6 +84,7 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
) )
# Check features shape: (B, out_channels, H, W) # Check features shape: (B, out_channels, H, W)
assert isinstance(model.backbone, UNetBackbone)
out_channels = model.backbone.out_channels out_channels = model.backbone.out_channels
assert output.features.shape == ( assert output.features.shape == (
batch_size, batch_size,

View File

@ -12,7 +12,9 @@ from batdetect2.typing.preprocess import AudioLoader
def build_default_module(): def build_default_module():
config = BatDetect2Config() config = BatDetect2Config()
return build_training_module(config=config.model_dump()) return build_training_module(
model_config=config.model.model_dump(mode="json"),
)
def test_can_initialize_default_module(): def test_can_initialize_default_module():