Restructure model config

This commit is contained in:
mbsantiago 2026-03-17 13:33:13 +00:00
parent 1a7c0b4b3a
commit 65bd0dc6ae
11 changed files with 168 additions and 118 deletions

View File

@ -282,18 +282,18 @@ class BatDetect2API:
cls,
config: BatDetect2Config,
):
targets = build_targets(config=config.targets)
targets = build_targets(config=config.model.targets)
audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
config=config.model.preprocess,
)
postprocessor = build_postprocessor(
preprocessor,
config=config.postprocess,
config=config.model.postprocess,
)
evaluator = build_evaluator(config=config.evaluation, targets=targets)
@ -301,18 +301,7 @@ class BatDetect2API:
# NOTE: Better to have a separate instance of
# preprocessor and postprocessor as these may be moved
# to another device.
model = build_model(
config=config.model,
targets=targets,
preprocessor=build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
),
postprocessor=build_postprocessor(
preprocessor,
config=config.postprocess,
),
)
model = build_model(config=config.model)
formatter = build_output_formatter(targets, config=config.output)
@ -333,24 +322,30 @@ class BatDetect2API:
path: data.PathLike,
config: BatDetect2Config | None = None,
):
model, stored_config = load_model_from_checkpoint(path)
from batdetect2.audio import AudioConfig
config = (
merge_configs(stored_config, config) if config else stored_config
model, model_config = load_model_from_checkpoint(path)
# Reconstruct a full BatDetect2Config from the checkpoint's
# ModelConfig, then overlay any caller-supplied overrides.
base = BatDetect2Config(
model=model_config,
audio=AudioConfig(samplerate=model_config.samplerate),
)
config = merge_configs(base, config) if config else base
targets = build_targets(config=config.targets)
targets = build_targets(config=config.model.targets)
audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
config=config.model.preprocess,
)
postprocessor = build_postprocessor(
preprocessor,
config=config.postprocess,
config=config.model.postprocess,
)
evaluator = build_evaluator(config=config.evaluation, targets=targets)

View File

@ -12,10 +12,7 @@ from batdetect2.evaluate.config import (
get_default_eval_config,
)
from batdetect2.inference.config import InferenceConfig
from batdetect2.models.backbones import BackboneConfig, UNetBackboneConfig
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.targets.config import TargetConfig
from batdetect2.models import ModelConfig
from batdetect2.train.config import TrainingConfig
__all__ = [
@ -32,13 +29,8 @@ class BatDetect2Config(BaseConfig):
evaluation: EvaluationConfig = Field(
default_factory=get_default_eval_config
)
model: BackboneConfig = Field(default_factory=UNetBackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
model: ModelConfig = Field(default_factory=ModelConfig)
audio: AudioConfig = Field(default_factory=AudioConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig)
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)

View File

@ -45,12 +45,8 @@ def evaluate(
audio_loader = audio_loader or build_audio_loader(config=config.audio)
preprocessor = preprocessor or build_preprocessor(
config=config.preprocess,
input_samplerate=audio_loader.samplerate,
)
targets = targets or build_targets(config=config.targets)
preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
loader = build_test_loader(
test_annotations,

View File

@ -27,7 +27,10 @@ is the ``build_model`` factory function exported from this module.
"""
import torch
from pydantic import Field
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig
from batdetect2.models.backbones import (
BackboneConfig,
UNetBackbone,
@ -59,6 +62,9 @@ from batdetect2.models.encoder import (
build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.targets.config import TargetConfig
from batdetect2.typing import (
ClipDetectionsTensor,
DetectionModel,
@ -92,10 +98,50 @@ __all__ = [
"build_detector",
"load_backbone_config",
"Model",
"ModelConfig",
"build_model",
]
class ModelConfig(BaseConfig):
    """Complete configuration describing a BatDetect2 model.

    Bundles every parameter that defines a model's behaviour: the input
    sample rate, backbone architecture, preprocessing pipeline,
    postprocessing pipeline, and detection targets.

    Attributes
    ----------
    samplerate : int
        Expected input audio sample rate in Hz. Audio must be resampled
        to this rate before being passed to the model. Defaults to
        ``TARGET_SAMPLERATE_HZ`` (256 000 Hz).
    architecture : BackboneConfig
        Configuration for the encoder-decoder backbone network. Defaults
        to ``UNetBackboneConfig()``.
    preprocess : PreprocessingConfig
        Parameters for the audio-to-spectrogram preprocessing pipeline
        (STFT, frequency crop, transforms, resize). Defaults to
        ``PreprocessingConfig()``.
    postprocess : PostprocessConfig
        Parameters for converting raw model outputs into detections (NMS
        kernel, thresholds, top-k limit). Defaults to
        ``PostprocessConfig()``.
    targets : TargetConfig
        Detection and classification target definitions (class list,
        detection target, bounding-box mapper). Defaults to
        ``TargetConfig()``.
    """

    # ``gt=0`` lets pydantic reject non-positive sample rates at
    # validation time rather than failing later in the audio pipeline.
    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
    # Each sub-config uses ``default_factory`` (not a shared default
    # instance) so every ModelConfig gets its own mutable sub-model.
    architecture: BackboneConfig = Field(default_factory=UNetBackboneConfig)
    preprocess: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
    targets: TargetConfig = Field(default_factory=TargetConfig)
class Model(torch.nn.Module):
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
@ -166,55 +212,61 @@ class Model(torch.nn.Module):
def build_model(
config: BackboneConfig | None = None,
config: ModelConfig | None = None,
targets: TargetProtocol | None = None,
preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None,
) -> "Model":
) -> Model:
"""Build a complete, ready-to-use BatDetect2 model.
Assembles a ``Model`` instance from optional configuration and component
overrides. Any argument left as ``None`` will be replaced by a sensible
default built with the project's own builder functions.
Assembles a ``Model`` instance from a ``ModelConfig`` and optional
component overrides. Any component argument left as ``None`` is built
from the configuration. Passing a pre-built component overrides the
corresponding config fields for that component only.
Parameters
----------
config : BackboneConfig, optional
Configuration describing the backbone architecture (encoder,
bottleneck, decoder). Defaults to ``UNetBackboneConfig()`` if not
config : ModelConfig, optional
Full model configuration (samplerate, architecture, preprocessing,
postprocessing, targets). Defaults to ``ModelConfig()`` if not
provided.
targets : TargetProtocol, optional
Describes the target bat species or call types to detect. Determines
the number of output classes. Defaults to the standard BatDetect2
target set.
Pre-built targets object. If given, overrides
``config.targets``.
preprocessor : PreprocessorProtocol, optional
Converts raw audio waveforms to spectrograms. Defaults to the
standard BatDetect2 preprocessor.
Pre-built preprocessor. If given, overrides
``config.preprocess`` and ``config.samplerate`` for the
preprocessing step.
postprocessor : PostprocessorProtocol, optional
Converts raw model outputs to detection tensors. Defaults to the
standard BatDetect2 postprocessor. If a custom ``preprocessor`` is
given without a matching ``postprocessor``, the default postprocessor
will be built using the provided preprocessor so that frequency and
time scaling remain consistent.
Pre-built postprocessor. If given, overrides
``config.postprocess``. When omitted and a custom
``preprocessor`` is supplied, the default postprocessor is built
using that preprocessor so that frequency and time scaling remain
consistent.
Returns
-------
Model
A fully assembled ``Model`` instance ready for inference or training.
A fully assembled ``Model`` instance ready for inference or
training.
"""
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
config = config or UNetBackboneConfig()
targets = targets or build_targets()
preprocessor = preprocessor or build_preprocessor()
config = config or ModelConfig()
targets = targets or build_targets(config=config.targets)
preprocessor = preprocessor or build_preprocessor(
config=config.preprocess,
input_samplerate=config.samplerate,
)
postprocessor = postprocessor or build_postprocessor(
preprocessor=preprocessor,
config=config.postprocess,
)
detector = build_detector(
num_classes=len(targets.class_names),
config=config,
config=config.architecture,
)
return Model(
detector=detector,

View File

@ -921,20 +921,6 @@ class StandardConvUpBlock(Block):
)
LayerConfig = Annotated[
ConvConfig
| BlockImportConfig
| FreqCoordConvDownConfig
| StandardConvDownConfig
| FreqCoordConvUpConfig
| StandardConvUpConfig
| SelfAttentionConfig
| "LayerGroupConfig",
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configuration models."""
class LayerGroupConfig(BaseConfig):
"""Configuration for a ``LayerGroup`` — a sequential chain of blocks.
@ -951,7 +937,20 @@ class LayerGroupConfig(BaseConfig):
"""
name: Literal["LayerGroup"] = "LayerGroup"
layers: list[LayerConfig]
layers: list["LayerConfig"]
LayerConfig = Annotated[
ConvConfig
| FreqCoordConvDownConfig
| StandardConvDownConfig
| FreqCoordConvUpConfig
| StandardConvUpConfig
| SelfAttentionConfig
| LayerGroupConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configuration models."""
class LayerGroup(nn.Module):

View File

@ -169,7 +169,7 @@ def build_detector(
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(),
lambda: config.to_yaml_string(), # type: ignore
)
backbone = build_backbone(config=config)
classifier_head = ClassifierHead(

View File

@ -6,15 +6,12 @@ from soundevent.data import PathLike
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.train.losses import build_loss
from batdetect2.typing import ModelOutput, TrainExample
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
pass
__all__ = [
"TrainingModule",
@ -26,43 +23,25 @@ class TrainingModule(L.LightningModule):
def __init__(
self,
config: dict | None = None,
model_config: dict | None = None,
t_max: int = 100,
model: Model | None = None,
learning_rate: float = 1e-3,
loss: torch.nn.Module | None = None,
model: Model | None = None,
):
from batdetect2.config import validate_config
super().__init__()
self.save_hyperparameters(logger=False)
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
self.config = validate_config(config)
self.input_samplerate = self.config.audio.samplerate
self.learning_rate = self.config.train.optimizer.learning_rate
self.model_config = ModelConfig.model_validate(model_config or {})
self.learning_rate = learning_rate
self.t_max = t_max
if loss is None:
loss = build_loss(self.config.train.loss)
loss = build_loss()
if model is None:
targets = build_targets(self.config.targets)
preprocessor = build_preprocessor(
config=self.config.preprocess,
input_samplerate=self.input_samplerate,
)
postprocessor = build_postprocessor(
preprocessor, config=self.config.postprocess
)
model = build_model(
config=self.config.model,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
model = build_model(config=self.model_config)
self.loss = loss
self.model = model
@ -97,13 +76,39 @@ class TrainingModule(L.LightningModule):
def load_model_from_checkpoint(
path: PathLike,
) -> tuple[Model, "BatDetect2Config"]:
) -> tuple[Model, ModelConfig]:
"""Load a model and its configuration from a Lightning checkpoint.
Parameters
----------
path : PathLike
Path to a ``.ckpt`` file produced by the BatDetect2 training
pipeline.
Returns
-------
tuple[Model, ModelConfig]
The restored ``Model`` instance and the ``ModelConfig`` that
describes its architecture, preprocessing, postprocessing, and
targets.
"""
module = TrainingModule.load_from_checkpoint(path) # type: ignore
return module.model, module.config
return module.model, module.model_config
def build_training_module(
config: dict | None = None,
model_config: dict | None = None,
t_max: int = 200,
learning_rate: float = 1e-3,
loss_config: dict | None = None,
) -> TrainingModule:
return TrainingModule(config=config, t_max=t_max)
from batdetect2.train.config import LossConfig
from batdetect2.train.losses import build_loss
loss = build_loss(LossConfig.model_validate(loss_config or {}))
return TrainingModule(
model_config=model_config,
t_max=t_max,
learning_rate=learning_rate,
loss=loss,
)

View File

@ -58,13 +58,13 @@ def train(
config = config or BatDetect2Config()
targets = targets or build_targets(config=config.targets)
targets = targets or build_targets(config=config.model.targets)
audio_loader = audio_loader or build_audio_loader(config=config.audio)
preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
config=config.model.preprocess,
)
labeller = labeller or build_clip_labeler(
@ -97,8 +97,10 @@ def train(
)
module = build_training_module(
config.model_dump(mode="json"),
model_config=config.model.model_dump(mode="json"),
t_max=config.train.optimizer.t_max * len(train_dataloader),
learning_rate=config.train.optimizer.learning_rate,
loss_config=config.train.loss.model_dump(mode="json"),
)
trainer = trainer or build_trainer(

View File

@ -317,9 +317,11 @@ def convert_results(
]
# combine into final results dictionary
results: RunResults = RunResults({ # type: ignore
results: RunResults = RunResults( # type: ignore[missing-argument]
{
"pred_dict": pred_dict,
})
}
)
# add spectrogram features if they exist
if len(spec_feats) > 0 and params["spec_features"]:

View File

@ -2,8 +2,10 @@ import numpy as np
import pytest
import torch
from batdetect2.models import UNetBackbone
from batdetect2.models.backbones import UNetBackboneConfig
from batdetect2.models.detectors import Detector, build_detector
from batdetect2.models.encoder import Encoder
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.typing.models import ModelOutput
@ -34,6 +36,8 @@ def test_build_detector_custom_config():
assert isinstance(model, Detector)
assert model.backbone.input_height == 128
assert isinstance(model.backbone.encoder, Encoder)
assert model.backbone.encoder.in_channels == 2
@ -80,6 +84,7 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
)
# Check features shape: (B, out_channels, H, W)
assert isinstance(model.backbone, UNetBackbone)
out_channels = model.backbone.out_channels
assert output.features.shape == (
batch_size,

View File

@ -12,7 +12,9 @@ from batdetect2.typing.preprocess import AudioLoader
def build_default_module():
config = BatDetect2Config()
return build_training_module(config=config.model_dump())
return build_training_module(
model_config=config.model.model_dump(mode="json"),
)
def test_can_initialize_default_module():