mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
refactor: derive training config from the model
This commit is contained in:
parent
999dc93d88
commit
a27d1bbfd3
@ -186,7 +186,6 @@ class BatDetect2API:
|
||||
num_epochs: int | None = None,
|
||||
run_name: str | None = None,
|
||||
seed: int | None = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
@ -217,8 +216,6 @@ class BatDetect2API:
|
||||
Run name used by the configured logger.
|
||||
seed : int | None, optional
|
||||
Random seed for reproducibility.
|
||||
model_config : ModelConfig | None, optional
|
||||
Model config override. If omitted, the API model config is used.
|
||||
audio_config : AudioConfig | None, optional
|
||||
Audio config override.
|
||||
train_config : TrainingConfig | None, optional
|
||||
@ -242,7 +239,6 @@ class BatDetect2API:
|
||||
model=self.model,
|
||||
targets=self.targets,
|
||||
roi_mapper=self.roi_mapper,
|
||||
model_config=model_config or self.model_config,
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
train_workers=train_workers,
|
||||
@ -390,7 +386,6 @@ class BatDetect2API:
|
||||
model=api.model,
|
||||
targets=api.targets,
|
||||
roi_mapper=api.roi_mapper,
|
||||
model_config=api.model_config,
|
||||
preprocessor=api.preprocessor,
|
||||
audio_loader=api.audio_loader,
|
||||
train_workers=train_workers,
|
||||
|
||||
@ -169,6 +169,7 @@ class Model(torch.nn.Module):
|
||||
postprocessor: PostprocessorProtocol
|
||||
class_names: list[str]
|
||||
dimension_names: list[str]
|
||||
_config: dict[str, object]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -177,6 +178,7 @@ class Model(torch.nn.Module):
|
||||
postprocessor: PostprocessorProtocol,
|
||||
class_names: list[str],
|
||||
dimension_names: list[str],
|
||||
config: dict[str, object],
|
||||
):
|
||||
super().__init__()
|
||||
self.detector = detector
|
||||
@ -184,6 +186,12 @@ class Model(torch.nn.Module):
|
||||
self.postprocessor = postprocessor
|
||||
self.class_names = class_names
|
||||
self.dimension_names = dimension_names
|
||||
self._config = config
|
||||
|
||||
def get_config(self) -> dict[str, object]:
|
||||
"""Return the model configuration as plain JSON-serializable data."""
|
||||
|
||||
return dict(self._config)
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
||||
"""Run the full detection pipeline on a waveform tensor.
|
||||
@ -287,6 +295,7 @@ def build_model(
|
||||
preprocessor=preprocessor,
|
||||
class_names=class_names,
|
||||
dimension_names=dimension_names,
|
||||
config=config.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
|
||||
@ -308,4 +317,5 @@ def build_model_with_new_targets(
|
||||
preprocessor=model.preprocessor,
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
config=model.get_config(),
|
||||
)
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@ -28,7 +30,7 @@ __all__ = [
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TrainLoggingContext:
|
||||
model_config: ModelConfig
|
||||
model_config: dict[str, Any]
|
||||
train_config: TrainingConfig
|
||||
audio_config: AudioConfig
|
||||
targets: TargetProtocol
|
||||
@ -49,9 +51,10 @@ class ConfigHyperparameterLogging:
|
||||
artifact_path: Path,
|
||||
context: TrainLoggingContext,
|
||||
) -> None:
|
||||
model_config = ModelConfig.model_validate(context.model_config)
|
||||
logger.log_hyperparams(
|
||||
{
|
||||
"model": context.model_config.model_dump(
|
||||
"model": model_config.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
|
||||
@ -57,7 +57,6 @@ def run_train(
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
labeller: Optional["ClipLabeller"] = None,
|
||||
audio_config: Optional[AudioConfig] = None,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
targets_config: TargetConfig | None = None,
|
||||
train_config: Optional[TrainingConfig] = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
@ -75,7 +74,11 @@ def run_train(
|
||||
if seed is not None:
|
||||
seed_everything(seed)
|
||||
|
||||
model_config = model_config or ModelConfig()
|
||||
model_config = (
|
||||
ModelConfig()
|
||||
if model is None
|
||||
else ModelConfig.model_validate(model.get_config())
|
||||
)
|
||||
targets_config = targets_config or TargetConfig()
|
||||
audio_config = audio_config or AudioConfig()
|
||||
train_config = train_config or TrainingConfig()
|
||||
@ -172,7 +175,7 @@ def run_train(
|
||||
root_artifact_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging_context = TrainLoggingContext(
|
||||
model_config=model_config,
|
||||
model_config=model_config.model_dump(mode="json"),
|
||||
train_config=train_config,
|
||||
audio_config=audio_config,
|
||||
targets=targets,
|
||||
|
||||
@ -451,7 +451,6 @@ def test_run_train_rejects_incompatible_model_config(
|
||||
model=incompatible_model,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
model_config=incompatible_config,
|
||||
targets_config=targets_config,
|
||||
train_config=TrainingConfig(),
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user