mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Restructure model config
This commit is contained in:
parent
1a7c0b4b3a
commit
65bd0dc6ae
@ -282,18 +282,18 @@ class BatDetect2API:
|
||||
cls,
|
||||
config: BatDetect2Config,
|
||||
):
|
||||
targets = build_targets(config=config.targets)
|
||||
targets = build_targets(config=config.model.targets)
|
||||
|
||||
audio_loader = build_audio_loader(config=config.audio)
|
||||
|
||||
preprocessor = build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
config=config.preprocess,
|
||||
config=config.model.preprocess,
|
||||
)
|
||||
|
||||
postprocessor = build_postprocessor(
|
||||
preprocessor,
|
||||
config=config.postprocess,
|
||||
config=config.model.postprocess,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||
@ -301,18 +301,7 @@ class BatDetect2API:
|
||||
# NOTE: Better to have a separate instance of
|
||||
# preprocessor and postprocessor as these may be moved
|
||||
# to another device.
|
||||
model = build_model(
|
||||
config=config.model,
|
||||
targets=targets,
|
||||
preprocessor=build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
config=config.preprocess,
|
||||
),
|
||||
postprocessor=build_postprocessor(
|
||||
preprocessor,
|
||||
config=config.postprocess,
|
||||
),
|
||||
)
|
||||
model = build_model(config=config.model)
|
||||
|
||||
formatter = build_output_formatter(targets, config=config.output)
|
||||
|
||||
@ -333,24 +322,30 @@ class BatDetect2API:
|
||||
path: data.PathLike,
|
||||
config: BatDetect2Config | None = None,
|
||||
):
|
||||
model, stored_config = load_model_from_checkpoint(path)
|
||||
from batdetect2.audio import AudioConfig
|
||||
|
||||
config = (
|
||||
merge_configs(stored_config, config) if config else stored_config
|
||||
model, model_config = load_model_from_checkpoint(path)
|
||||
|
||||
# Reconstruct a full BatDetect2Config from the checkpoint's
|
||||
# ModelConfig, then overlay any caller-supplied overrides.
|
||||
base = BatDetect2Config(
|
||||
model=model_config,
|
||||
audio=AudioConfig(samplerate=model_config.samplerate),
|
||||
)
|
||||
config = merge_configs(base, config) if config else base
|
||||
|
||||
targets = build_targets(config=config.targets)
|
||||
targets = build_targets(config=config.model.targets)
|
||||
|
||||
audio_loader = build_audio_loader(config=config.audio)
|
||||
|
||||
preprocessor = build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
config=config.preprocess,
|
||||
config=config.model.preprocess,
|
||||
)
|
||||
|
||||
postprocessor = build_postprocessor(
|
||||
preprocessor,
|
||||
config=config.postprocess,
|
||||
config=config.model.postprocess,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||
|
||||
@ -12,10 +12,7 @@ from batdetect2.evaluate.config import (
|
||||
get_default_eval_config,
|
||||
)
|
||||
from batdetect2.inference.config import InferenceConfig
|
||||
from batdetect2.models.backbones import BackboneConfig, UNetBackboneConfig
|
||||
from batdetect2.postprocess.config import PostprocessConfig
|
||||
from batdetect2.preprocess.config import PreprocessingConfig
|
||||
from batdetect2.targets.config import TargetConfig
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
|
||||
__all__ = [
|
||||
@ -32,13 +29,8 @@ class BatDetect2Config(BaseConfig):
|
||||
evaluation: EvaluationConfig = Field(
|
||||
default_factory=get_default_eval_config
|
||||
)
|
||||
model: BackboneConfig = Field(default_factory=UNetBackboneConfig)
|
||||
preprocess: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||
model: ModelConfig = Field(default_factory=ModelConfig)
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
|
||||
|
||||
|
||||
@ -45,12 +45,8 @@ def evaluate(
|
||||
|
||||
audio_loader = audio_loader or build_audio_loader(config=config.audio)
|
||||
|
||||
preprocessor = preprocessor or build_preprocessor(
|
||||
config=config.preprocess,
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
)
|
||||
|
||||
targets = targets or build_targets(config=config.targets)
|
||||
preprocessor = preprocessor or model.preprocessor
|
||||
targets = targets or model.targets
|
||||
|
||||
loader = build_test_loader(
|
||||
test_annotations,
|
||||
|
||||
@ -27,7 +27,10 @@ is the ``build_model`` factory function exported from this module.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.models.backbones import (
|
||||
BackboneConfig,
|
||||
UNetBackbone,
|
||||
@ -59,6 +62,9 @@ from batdetect2.models.encoder import (
|
||||
build_encoder,
|
||||
)
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||
from batdetect2.postprocess.config import PostprocessConfig
|
||||
from batdetect2.preprocess.config import PreprocessingConfig
|
||||
from batdetect2.targets.config import TargetConfig
|
||||
from batdetect2.typing import (
|
||||
ClipDetectionsTensor,
|
||||
DetectionModel,
|
||||
@ -92,10 +98,50 @@ __all__ = [
|
||||
"build_detector",
|
||||
"load_backbone_config",
|
||||
"Model",
|
||||
"ModelConfig",
|
||||
"build_model",
|
||||
]
|
||||
|
||||
|
||||
class ModelConfig(BaseConfig):
|
||||
"""Complete configuration describing a BatDetect2 model.
|
||||
|
||||
Bundles every parameter that defines a model's behaviour: the input
|
||||
sample rate, backbone architecture, preprocessing pipeline,
|
||||
postprocessing pipeline, and detection targets.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
samplerate : int
|
||||
Expected input audio sample rate in Hz. Audio must be resampled
|
||||
to this rate before being passed to the model. Defaults to
|
||||
``TARGET_SAMPLERATE_HZ`` (256 000 Hz).
|
||||
architecture : BackboneConfig
|
||||
Configuration for the encoder-decoder backbone network. Defaults
|
||||
to ``UNetBackboneConfig()``.
|
||||
preprocess : PreprocessingConfig
|
||||
Parameters for the audio-to-spectrogram preprocessing pipeline
|
||||
(STFT, frequency crop, transforms, resize). Defaults to
|
||||
``PreprocessingConfig()``.
|
||||
postprocess : PostprocessConfig
|
||||
Parameters for converting raw model outputs into detections (NMS
|
||||
kernel, thresholds, top-k limit). Defaults to
|
||||
``PostprocessConfig()``.
|
||||
targets : TargetConfig
|
||||
Detection and classification target definitions (class list,
|
||||
detection target, bounding-box mapper). Defaults to
|
||||
``TargetConfig()``.
|
||||
"""
|
||||
|
||||
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
||||
architecture: BackboneConfig = Field(default_factory=UNetBackboneConfig)
|
||||
preprocess: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
|
||||
|
||||
@ -166,55 +212,61 @@ class Model(torch.nn.Module):
|
||||
|
||||
|
||||
def build_model(
|
||||
config: BackboneConfig | None = None,
|
||||
config: ModelConfig | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
postprocessor: PostprocessorProtocol | None = None,
|
||||
) -> "Model":
|
||||
) -> Model:
|
||||
"""Build a complete, ready-to-use BatDetect2 model.
|
||||
|
||||
Assembles a ``Model`` instance from optional configuration and component
|
||||
overrides. Any argument left as ``None`` will be replaced by a sensible
|
||||
default built with the project's own builder functions.
|
||||
Assembles a ``Model`` instance from a ``ModelConfig`` and optional
|
||||
component overrides. Any component argument left as ``None`` is built
|
||||
from the configuration. Passing a pre-built component overrides the
|
||||
corresponding config fields for that component only.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : BackboneConfig, optional
|
||||
Configuration describing the backbone architecture (encoder,
|
||||
bottleneck, decoder). Defaults to ``UNetBackboneConfig()`` if not
|
||||
config : ModelConfig, optional
|
||||
Full model configuration (samplerate, architecture, preprocessing,
|
||||
postprocessing, targets). Defaults to ``ModelConfig()`` if not
|
||||
provided.
|
||||
targets : TargetProtocol, optional
|
||||
Describes the target bat species or call types to detect. Determines
|
||||
the number of output classes. Defaults to the standard BatDetect2
|
||||
target set.
|
||||
Pre-built targets object. If given, overrides
|
||||
``config.targets``.
|
||||
preprocessor : PreprocessorProtocol, optional
|
||||
Converts raw audio waveforms to spectrograms. Defaults to the
|
||||
standard BatDetect2 preprocessor.
|
||||
Pre-built preprocessor. If given, overrides
|
||||
``config.preprocess`` and ``config.samplerate`` for the
|
||||
preprocessing step.
|
||||
postprocessor : PostprocessorProtocol, optional
|
||||
Converts raw model outputs to detection tensors. Defaults to the
|
||||
standard BatDetect2 postprocessor. If a custom ``preprocessor`` is
|
||||
given without a matching ``postprocessor``, the default postprocessor
|
||||
will be built using the provided preprocessor so that frequency and
|
||||
time scaling remain consistent.
|
||||
Pre-built postprocessor. If given, overrides
|
||||
``config.postprocess``. When omitted and a custom
|
||||
``preprocessor`` is supplied, the config-based postprocessor is built
|
||||
using that preprocessor so that frequency and time scaling remain
|
||||
consistent.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Model
|
||||
A fully assembled ``Model`` instance ready for inference or training.
|
||||
A fully assembled ``Model`` instance ready for inference or
|
||||
training.
|
||||
"""
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
|
||||
config = config or UNetBackboneConfig()
|
||||
targets = targets or build_targets()
|
||||
preprocessor = preprocessor or build_preprocessor()
|
||||
config = config or ModelConfig()
|
||||
targets = targets or build_targets(config=config.targets)
|
||||
preprocessor = preprocessor or build_preprocessor(
|
||||
config=config.preprocess,
|
||||
input_samplerate=config.samplerate,
|
||||
)
|
||||
postprocessor = postprocessor or build_postprocessor(
|
||||
preprocessor=preprocessor,
|
||||
config=config.postprocess,
|
||||
)
|
||||
detector = build_detector(
|
||||
num_classes=len(targets.class_names),
|
||||
config=config,
|
||||
config=config.architecture,
|
||||
)
|
||||
return Model(
|
||||
detector=detector,
|
||||
|
||||
@ -921,20 +921,6 @@ class StandardConvUpBlock(Block):
|
||||
)
|
||||
|
||||
|
||||
LayerConfig = Annotated[
|
||||
ConvConfig
|
||||
| BlockImportConfig
|
||||
| FreqCoordConvDownConfig
|
||||
| StandardConvDownConfig
|
||||
| FreqCoordConvUpConfig
|
||||
| StandardConvUpConfig
|
||||
| SelfAttentionConfig
|
||||
| "LayerGroupConfig",
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configuration models."""
|
||||
|
||||
|
||||
class LayerGroupConfig(BaseConfig):
|
||||
"""Configuration for a ``LayerGroup`` — a sequential chain of blocks.
|
||||
|
||||
@ -951,7 +937,20 @@ class LayerGroupConfig(BaseConfig):
|
||||
"""
|
||||
|
||||
name: Literal["LayerGroup"] = "LayerGroup"
|
||||
layers: list[LayerConfig]
|
||||
layers: list["LayerConfig"]
|
||||
|
||||
|
||||
LayerConfig = Annotated[
|
||||
ConvConfig
|
||||
| FreqCoordConvDownConfig
|
||||
| StandardConvDownConfig
|
||||
| FreqCoordConvUpConfig
|
||||
| StandardConvUpConfig
|
||||
| SelfAttentionConfig
|
||||
| LayerGroupConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configuration models."""
|
||||
|
||||
|
||||
class LayerGroup(nn.Module):
|
||||
|
||||
@ -169,7 +169,7 @@ def build_detector(
|
||||
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building model with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
lambda: config.to_yaml_string(), # type: ignore
|
||||
)
|
||||
backbone = build_backbone(config=config)
|
||||
classifier_head = ClassifierHead(
|
||||
|
||||
@ -6,15 +6,12 @@ from soundevent.data import PathLike
|
||||
from torch.optim.adam import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from batdetect2.models import Model, build_model
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.models import Model, ModelConfig, build_model
|
||||
from batdetect2.train.losses import build_loss
|
||||
from batdetect2.typing import ModelOutput, TrainExample
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
pass
|
||||
|
||||
__all__ = [
|
||||
"TrainingModule",
|
||||
@ -26,43 +23,25 @@ class TrainingModule(L.LightningModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: dict | None = None,
|
||||
model_config: dict | None = None,
|
||||
t_max: int = 100,
|
||||
model: Model | None = None,
|
||||
learning_rate: float = 1e-3,
|
||||
loss: torch.nn.Module | None = None,
|
||||
model: Model | None = None,
|
||||
):
|
||||
from batdetect2.config import validate_config
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.save_hyperparameters(logger=False)
|
||||
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
|
||||
|
||||
self.config = validate_config(config)
|
||||
self.input_samplerate = self.config.audio.samplerate
|
||||
self.learning_rate = self.config.train.optimizer.learning_rate
|
||||
self.model_config = ModelConfig.model_validate(model_config or {})
|
||||
self.learning_rate = learning_rate
|
||||
self.t_max = t_max
|
||||
|
||||
if loss is None:
|
||||
loss = build_loss(self.config.train.loss)
|
||||
loss = build_loss()
|
||||
|
||||
if model is None:
|
||||
targets = build_targets(self.config.targets)
|
||||
|
||||
preprocessor = build_preprocessor(
|
||||
config=self.config.preprocess,
|
||||
input_samplerate=self.input_samplerate,
|
||||
)
|
||||
|
||||
postprocessor = build_postprocessor(
|
||||
preprocessor, config=self.config.postprocess
|
||||
)
|
||||
|
||||
model = build_model(
|
||||
config=self.config.model,
|
||||
targets=targets,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
model = build_model(config=self.model_config)
|
||||
|
||||
self.loss = loss
|
||||
self.model = model
|
||||
@ -97,13 +76,39 @@ class TrainingModule(L.LightningModule):
|
||||
|
||||
def load_model_from_checkpoint(
|
||||
path: PathLike,
|
||||
) -> tuple[Model, "BatDetect2Config"]:
|
||||
) -> tuple[Model, ModelConfig]:
|
||||
"""Load a model and its configuration from a Lightning checkpoint.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to a ``.ckpt`` file produced by the BatDetect2 training
|
||||
pipeline.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[Model, ModelConfig]
|
||||
The restored ``Model`` instance and the ``ModelConfig`` that
|
||||
describes its architecture, preprocessing, postprocessing, and
|
||||
targets.
|
||||
"""
|
||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||
return module.model, module.config
|
||||
return module.model, module.model_config
|
||||
|
||||
|
||||
def build_training_module(
|
||||
config: dict | None = None,
|
||||
model_config: dict | None = None,
|
||||
t_max: int = 200,
|
||||
learning_rate: float = 1e-3,
|
||||
loss_config: dict | None = None,
|
||||
) -> TrainingModule:
|
||||
return TrainingModule(config=config, t_max=t_max)
|
||||
from batdetect2.train.config import LossConfig
|
||||
from batdetect2.train.losses import build_loss
|
||||
|
||||
loss = build_loss(LossConfig.model_validate(loss_config or {}))
|
||||
return TrainingModule(
|
||||
model_config=model_config,
|
||||
t_max=t_max,
|
||||
learning_rate=learning_rate,
|
||||
loss=loss,
|
||||
)
|
||||
|
||||
@ -58,13 +58,13 @@ def train(
|
||||
|
||||
config = config or BatDetect2Config()
|
||||
|
||||
targets = targets or build_targets(config=config.targets)
|
||||
targets = targets or build_targets(config=config.model.targets)
|
||||
|
||||
audio_loader = audio_loader or build_audio_loader(config=config.audio)
|
||||
|
||||
preprocessor = preprocessor or build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
config=config.preprocess,
|
||||
config=config.model.preprocess,
|
||||
)
|
||||
|
||||
labeller = labeller or build_clip_labeler(
|
||||
@ -97,8 +97,10 @@ def train(
|
||||
)
|
||||
|
||||
module = build_training_module(
|
||||
config.model_dump(mode="json"),
|
||||
model_config=config.model.model_dump(mode="json"),
|
||||
t_max=config.train.optimizer.t_max * len(train_dataloader),
|
||||
learning_rate=config.train.optimizer.learning_rate,
|
||||
loss_config=config.train.loss.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
trainer = trainer or build_trainer(
|
||||
|
||||
@ -317,9 +317,11 @@ def convert_results(
|
||||
]
|
||||
|
||||
# combine into final results dictionary
|
||||
results: RunResults = RunResults({ # type: ignore
|
||||
results: RunResults = RunResults( # type: ignore[missing-argument]
|
||||
{
|
||||
"pred_dict": pred_dict,
|
||||
})
|
||||
}
|
||||
)
|
||||
|
||||
# add spectrogram features if they exist
|
||||
if len(spec_feats) > 0 and params["spec_features"]:
|
||||
|
||||
@ -2,8 +2,10 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from batdetect2.models import UNetBackbone
|
||||
from batdetect2.models.backbones import UNetBackboneConfig
|
||||
from batdetect2.models.detectors import Detector, build_detector
|
||||
from batdetect2.models.encoder import Encoder
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||
from batdetect2.typing.models import ModelOutput
|
||||
|
||||
@ -34,6 +36,8 @@ def test_build_detector_custom_config():
|
||||
|
||||
assert isinstance(model, Detector)
|
||||
assert model.backbone.input_height == 128
|
||||
|
||||
assert isinstance(model.backbone.encoder, Encoder)
|
||||
assert model.backbone.encoder.in_channels == 2
|
||||
|
||||
|
||||
@ -80,6 +84,7 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
|
||||
)
|
||||
|
||||
# Check features shape: (B, out_channels, H, W)
|
||||
assert isinstance(model.backbone, UNetBackbone)
|
||||
out_channels = model.backbone.out_channels
|
||||
assert output.features.shape == (
|
||||
batch_size,
|
||||
|
||||
@ -12,7 +12,9 @@ from batdetect2.typing.preprocess import AudioLoader
|
||||
|
||||
def build_default_module():
|
||||
config = BatDetect2Config()
|
||||
return build_training_module(config=config.model_dump())
|
||||
return build_training_module(
|
||||
model_config=config.model.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
|
||||
def test_can_initialize_default_module():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user