refactor: derive training config from the model

This commit is contained in:
mbsantiago 2026-05-06 12:48:40 +01:00
parent 999dc93d88
commit a27d1bbfd3
5 changed files with 21 additions and 11 deletions

View File

@ -186,7 +186,6 @@ class BatDetect2API:
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
@ -217,8 +216,6 @@ class BatDetect2API:
Run name used by the configured logger.
seed : int | None, optional
Random seed for reproducibility.
model_config : ModelConfig | None, optional
Model config override. If omitted, the API model config is used.
audio_config : AudioConfig | None, optional
Audio config override.
train_config : TrainingConfig | None, optional
@ -242,7 +239,6 @@ class BatDetect2API:
model=self.model,
targets=self.targets,
roi_mapper=self.roi_mapper,
model_config=model_config or self.model_config,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
train_workers=train_workers,
@ -390,7 +386,6 @@ class BatDetect2API:
model=api.model,
targets=api.targets,
roi_mapper=api.roi_mapper,
model_config=api.model_config,
preprocessor=api.preprocessor,
audio_loader=api.audio_loader,
train_workers=train_workers,

View File

@ -169,6 +169,7 @@ class Model(torch.nn.Module):
postprocessor: PostprocessorProtocol
class_names: list[str]
dimension_names: list[str]
_config: dict[str, object]
def __init__(
self,
@ -177,6 +178,7 @@ class Model(torch.nn.Module):
postprocessor: PostprocessorProtocol,
class_names: list[str],
dimension_names: list[str],
config: dict[str, object],
):
super().__init__()
self.detector = detector
@ -184,6 +186,12 @@ class Model(torch.nn.Module):
self.postprocessor = postprocessor
self.class_names = class_names
self.dimension_names = dimension_names
self._config = config
def get_config(self) -> dict[str, object]:
"""Return the model configuration as plain JSON-serializable data."""
return dict(self._config)
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
"""Run the full detection pipeline on a waveform tensor.
@ -287,6 +295,7 @@ def build_model(
preprocessor=preprocessor,
class_names=class_names,
dimension_names=dimension_names,
config=config.model_dump(mode="json"),
)
@ -308,4 +317,5 @@ def build_model_with_new_targets(
preprocessor=model.preprocessor,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
config=model.get_config(),
)

View File

@ -1,5 +1,7 @@
from __future__ import annotations
from typing import Any
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
@ -28,7 +30,7 @@ __all__ = [
@dataclass(frozen=True)
class TrainLoggingContext:
model_config: ModelConfig
model_config: dict[str, Any]
train_config: TrainingConfig
audio_config: AudioConfig
targets: TargetProtocol
@ -49,9 +51,10 @@ class ConfigHyperparameterLogging:
artifact_path: Path,
context: TrainLoggingContext,
) -> None:
model_config = ModelConfig.model_validate(context.model_config)
logger.log_hyperparams(
{
"model": context.model_config.model_dump(
"model": model_config.model_dump(
mode="json",
exclude_none=True,
),

View File

@ -57,7 +57,6 @@ def run_train(
audio_loader: Optional["AudioLoader"] = None,
labeller: Optional["ClipLabeller"] = None,
audio_config: Optional[AudioConfig] = None,
model_config: Optional[ModelConfig] = None,
targets_config: TargetConfig | None = None,
train_config: Optional[TrainingConfig] = None,
logger_config: LoggerConfig | None = None,
@ -75,7 +74,11 @@ def run_train(
if seed is not None:
seed_everything(seed)
model_config = model_config or ModelConfig()
model_config = (
ModelConfig()
if model is None
else ModelConfig.model_validate(model.get_config())
)
targets_config = targets_config or TargetConfig()
audio_config = audio_config or AudioConfig()
train_config = train_config or TrainingConfig()
@ -172,7 +175,7 @@ def run_train(
root_artifact_path.mkdir(parents=True, exist_ok=True)
logging_context = TrainLoggingContext(
model_config=model_config,
model_config=model_config.model_dump(mode="json"),
train_config=train_config,
audio_config=audio_config,
targets=targets,

View File

@ -451,7 +451,6 @@ def test_run_train_rejects_incompatible_model_config(
model=incompatible_model,
targets=targets,
roi_mapper=roi_mapper,
model_config=incompatible_config,
targets_config=targets_config,
train_config=TrainingConfig(),
)