refactor: derive training config from the model

This commit is contained in:
mbsantiago 2026-05-06 12:48:40 +01:00
parent 999dc93d88
commit a27d1bbfd3
5 changed files with 21 additions and 11 deletions

View File

@ -186,7 +186,6 @@ class BatDetect2API:
num_epochs: int | None = None, num_epochs: int | None = None,
run_name: str | None = None, run_name: str | None = None,
seed: int | None = None, seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None, logger_config: LoggerConfig | None = None,
@ -217,8 +216,6 @@ class BatDetect2API:
Run name used by the configured logger. Run name used by the configured logger.
seed : int | None, optional seed : int | None, optional
Random seed for reproducibility. Random seed for reproducibility.
model_config : ModelConfig | None, optional
Model config override. If omitted, the API model config is used.
audio_config : AudioConfig | None, optional audio_config : AudioConfig | None, optional
Audio config override. Audio config override.
train_config : TrainingConfig | None, optional train_config : TrainingConfig | None, optional
@ -242,7 +239,6 @@ class BatDetect2API:
model=self.model, model=self.model,
targets=self.targets, targets=self.targets,
roi_mapper=self.roi_mapper, roi_mapper=self.roi_mapper,
model_config=model_config or self.model_config,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
train_workers=train_workers, train_workers=train_workers,
@ -390,7 +386,6 @@ class BatDetect2API:
model=api.model, model=api.model,
targets=api.targets, targets=api.targets,
roi_mapper=api.roi_mapper, roi_mapper=api.roi_mapper,
model_config=api.model_config,
preprocessor=api.preprocessor, preprocessor=api.preprocessor,
audio_loader=api.audio_loader, audio_loader=api.audio_loader,
train_workers=train_workers, train_workers=train_workers,

View File

@ -169,6 +169,7 @@ class Model(torch.nn.Module):
postprocessor: PostprocessorProtocol postprocessor: PostprocessorProtocol
class_names: list[str] class_names: list[str]
dimension_names: list[str] dimension_names: list[str]
_config: dict[str, object]
def __init__( def __init__(
self, self,
@ -177,6 +178,7 @@ class Model(torch.nn.Module):
postprocessor: PostprocessorProtocol, postprocessor: PostprocessorProtocol,
class_names: list[str], class_names: list[str],
dimension_names: list[str], dimension_names: list[str],
config: dict[str, object],
): ):
super().__init__() super().__init__()
self.detector = detector self.detector = detector
@ -184,6 +186,12 @@ class Model(torch.nn.Module):
self.postprocessor = postprocessor self.postprocessor = postprocessor
self.class_names = class_names self.class_names = class_names
self.dimension_names = dimension_names self.dimension_names = dimension_names
self._config = config
def get_config(self) -> dict[str, object]:
    """Return an independent copy of the model configuration.

    The stored configuration is plain JSON-style data and may contain
    nested dictionaries and lists. A deep copy is returned so that
    callers can freely modify the result without altering the
    configuration held by this model, or by any other object that was
    built from an earlier call to this method.

    Returns
    -------
    dict[str, object]
        A deep copy of the configuration passed at construction time.
    """
    # Local import keeps this block self-contained regardless of the
    # module-level imports.
    import copy

    # ``dict(...)`` alone only copies the top level; nested sections
    # would still be shared with ``self._config``.
    return copy.deepcopy(dict(self._config))
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]: def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
"""Run the full detection pipeline on a waveform tensor. """Run the full detection pipeline on a waveform tensor.
@ -287,6 +295,7 @@ def build_model(
preprocessor=preprocessor, preprocessor=preprocessor,
class_names=class_names, class_names=class_names,
dimension_names=dimension_names, dimension_names=dimension_names,
config=config.model_dump(mode="json"),
) )
@ -308,4 +317,5 @@ def build_model_with_new_targets(
preprocessor=model.preprocessor, preprocessor=model.preprocessor,
class_names=targets.class_names, class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names, dimension_names=roi_mapper.dimension_names,
config=model.get_config(),
) )

View File

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
from typing import Any
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
@ -28,7 +30,7 @@ __all__ = [
@dataclass(frozen=True) @dataclass(frozen=True)
class TrainLoggingContext: class TrainLoggingContext:
model_config: ModelConfig model_config: dict[str, Any]
train_config: TrainingConfig train_config: TrainingConfig
audio_config: AudioConfig audio_config: AudioConfig
targets: TargetProtocol targets: TargetProtocol
@ -49,9 +51,10 @@ class ConfigHyperparameterLogging:
artifact_path: Path, artifact_path: Path,
context: TrainLoggingContext, context: TrainLoggingContext,
) -> None: ) -> None:
model_config = ModelConfig.model_validate(context.model_config)
logger.log_hyperparams( logger.log_hyperparams(
{ {
"model": context.model_config.model_dump( "model": model_config.model_dump(
mode="json", mode="json",
exclude_none=True, exclude_none=True,
), ),

View File

@ -57,7 +57,6 @@ def run_train(
audio_loader: Optional["AudioLoader"] = None, audio_loader: Optional["AudioLoader"] = None,
labeller: Optional["ClipLabeller"] = None, labeller: Optional["ClipLabeller"] = None,
audio_config: Optional[AudioConfig] = None, audio_config: Optional[AudioConfig] = None,
model_config: Optional[ModelConfig] = None,
targets_config: TargetConfig | None = None, targets_config: TargetConfig | None = None,
train_config: Optional[TrainingConfig] = None, train_config: Optional[TrainingConfig] = None,
logger_config: LoggerConfig | None = None, logger_config: LoggerConfig | None = None,
@ -75,7 +74,11 @@ def run_train(
if seed is not None: if seed is not None:
seed_everything(seed) seed_everything(seed)
model_config = model_config or ModelConfig() model_config = (
ModelConfig()
if model is None
else ModelConfig.model_validate(model.get_config())
)
targets_config = targets_config or TargetConfig() targets_config = targets_config or TargetConfig()
audio_config = audio_config or AudioConfig() audio_config = audio_config or AudioConfig()
train_config = train_config or TrainingConfig() train_config = train_config or TrainingConfig()
@ -172,7 +175,7 @@ def run_train(
root_artifact_path.mkdir(parents=True, exist_ok=True) root_artifact_path.mkdir(parents=True, exist_ok=True)
logging_context = TrainLoggingContext( logging_context = TrainLoggingContext(
model_config=model_config, model_config=model_config.model_dump(mode="json"),
train_config=train_config, train_config=train_config,
audio_config=audio_config, audio_config=audio_config,
targets=targets, targets=targets,

View File

@ -451,7 +451,6 @@ def test_run_train_rejects_incompatible_model_config(
model=incompatible_model, model=incompatible_model,
targets=targets, targets=targets,
roi_mapper=roi_mapper, roi_mapper=roi_mapper,
model_config=incompatible_config,
targets_config=targets_config, targets_config=targets_config,
train_config=TrainingConfig(), train_config=TrainingConfig(),
) )