From a27d1bbfd37349c9494ce1401c069187b938d82e Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 6 May 2026 12:48:40 +0100 Subject: [PATCH] refactor: derive training config from the model --- src/batdetect2/api_v2.py | 5 ----- src/batdetect2/models/__init__.py | 10 ++++++++++ src/batdetect2/train/logging.py | 7 +++++-- src/batdetect2/train/train.py | 9 ++++++--- tests/test_train/test_lightning.py | 1 - 5 files changed, 21 insertions(+), 11 deletions(-) diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 4044e10..53230c7 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -186,7 +186,6 @@ class BatDetect2API: num_epochs: int | None = None, run_name: str | None = None, seed: int | None = None, - model_config: ModelConfig | None = None, audio_config: AudioConfig | None = None, train_config: TrainingConfig | None = None, logger_config: LoggerConfig | None = None, @@ -217,8 +216,6 @@ class BatDetect2API: Run name used by the configured logger. seed : int | None, optional Random seed for reproducibility. - model_config : ModelConfig | None, optional - Model config override. If omitted, the API model config is used. audio_config : AudioConfig | None, optional Audio config override. train_config : TrainingConfig | None, optional @@ -242,7 +239,6 @@ class BatDetect2API: model=self.model, targets=self.targets, roi_mapper=self.roi_mapper, - model_config=model_config or self.model_config, audio_loader=self.audio_loader, preprocessor=self.preprocessor, train_workers=train_workers, @@ -390,7 +386,6 @@ class BatDetect2API: model=api.model, targets=api.targets, roi_mapper=api.roi_mapper, - model_config=api.model_config, preprocessor=api.preprocessor, audio_loader=api.audio_loader, train_workers=train_workers, diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index 3060532..cc3ab69 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -169,6 +169,7 @@ class Model(torch.nn.Module): postprocessor: PostprocessorProtocol class_names: list[str] dimension_names: list[str] + _config: dict[str, object] def __init__( self, @@ -177,6 +178,7 @@ class Model(torch.nn.Module): postprocessor: PostprocessorProtocol, class_names: list[str], dimension_names: list[str], + config: dict[str, object], ): super().__init__() self.detector = detector @@ -184,6 +186,12 @@ class Model(torch.nn.Module): self.postprocessor = postprocessor self.class_names = class_names self.dimension_names = dimension_names + self._config = config + + def get_config(self) -> dict[str, object]: + """Return the model configuration as plain JSON-serializable data.""" + + return dict(self._config) def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]: """Run the full detection pipeline on a waveform tensor. @@ -287,6 +295,7 @@ def build_model( preprocessor=preprocessor, class_names=class_names, dimension_names=dimension_names, + config=config.model_dump(mode="json"), ) @@ -308,4 +317,5 @@ def build_model_with_new_targets( preprocessor=model.preprocessor, class_names=targets.class_names, dimension_names=roi_mapper.dimension_names, + config=model.get_config(), ) diff --git a/src/batdetect2/train/logging.py b/src/batdetect2/train/logging.py index 75829f4..7e6377f 100644 --- a/src/batdetect2/train/logging.py +++ b/src/batdetect2/train/logging.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + from collections.abc import Sequence from dataclasses import dataclass from pathlib import Path @@ -28,7 +30,7 @@ __all__ = [ @dataclass(frozen=True) class TrainLoggingContext: - model_config: ModelConfig + model_config: dict[str, Any] train_config: TrainingConfig audio_config: AudioConfig targets: TargetProtocol @@ -49,9 +51,10 @@ class ConfigHyperparameterLogging: artifact_path: Path, context: TrainLoggingContext, ) -> None: + model_config = ModelConfig.model_validate(context.model_config) logger.log_hyperparams( { - "model": context.model_config.model_dump( + "model": model_config.model_dump( mode="json", exclude_none=True, ), diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 0387e80..c56d370 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -57,7 +57,6 @@ def run_train( audio_loader: Optional["AudioLoader"] = None, labeller: Optional["ClipLabeller"] = None, audio_config: Optional[AudioConfig] = None, - model_config: Optional[ModelConfig] = None, targets_config: TargetConfig | None = None, train_config: Optional[TrainingConfig] = None, logger_config: LoggerConfig | None = None, @@ -75,7 +74,11 @@ def run_train( if seed is not None: seed_everything(seed) - model_config = model_config or ModelConfig() + model_config = ( + ModelConfig() + if model is None + else ModelConfig.model_validate(model.get_config()) + ) targets_config = targets_config or TargetConfig() audio_config = audio_config or AudioConfig() train_config = train_config or TrainingConfig() @@ -172,7 +175,7 @@ def run_train( root_artifact_path.mkdir(parents=True, exist_ok=True) logging_context = TrainLoggingContext( - model_config=model_config, + model_config=model_config.model_dump(mode="json"), train_config=train_config, audio_config=audio_config, targets=targets, diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index c6a9ccd..9ab9e02 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -451,7 +451,6 @@ def test_run_train_rejects_incompatible_model_config( model=incompatible_model, targets=targets, roi_mapper=roi_mapper, - model_config=incompatible_config, targets_config=targets_config, train_config=TrainingConfig(), )