mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
refactor: derive training config from the model
This commit is contained in:
parent
999dc93d88
commit
a27d1bbfd3
@ -186,7 +186,6 @@ class BatDetect2API:
|
|||||||
num_epochs: int | None = None,
|
num_epochs: int | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
seed: int | None = None,
|
seed: int | None = None,
|
||||||
model_config: ModelConfig | None = None,
|
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
@ -217,8 +216,6 @@ class BatDetect2API:
|
|||||||
Run name used by the configured logger.
|
Run name used by the configured logger.
|
||||||
seed : int | None, optional
|
seed : int | None, optional
|
||||||
Random seed for reproducibility.
|
Random seed for reproducibility.
|
||||||
model_config : ModelConfig | None, optional
|
|
||||||
Model config override. If omitted, the API model config is used.
|
|
||||||
audio_config : AudioConfig | None, optional
|
audio_config : AudioConfig | None, optional
|
||||||
Audio config override.
|
Audio config override.
|
||||||
train_config : TrainingConfig | None, optional
|
train_config : TrainingConfig | None, optional
|
||||||
@ -242,7 +239,6 @@ class BatDetect2API:
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
roi_mapper=self.roi_mapper,
|
roi_mapper=self.roi_mapper,
|
||||||
model_config=model_config or self.model_config,
|
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
@ -390,7 +386,6 @@ class BatDetect2API:
|
|||||||
model=api.model,
|
model=api.model,
|
||||||
targets=api.targets,
|
targets=api.targets,
|
||||||
roi_mapper=api.roi_mapper,
|
roi_mapper=api.roi_mapper,
|
||||||
model_config=api.model_config,
|
|
||||||
preprocessor=api.preprocessor,
|
preprocessor=api.preprocessor,
|
||||||
audio_loader=api.audio_loader,
|
audio_loader=api.audio_loader,
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
|
|||||||
@ -169,6 +169,7 @@ class Model(torch.nn.Module):
|
|||||||
postprocessor: PostprocessorProtocol
|
postprocessor: PostprocessorProtocol
|
||||||
class_names: list[str]
|
class_names: list[str]
|
||||||
dimension_names: list[str]
|
dimension_names: list[str]
|
||||||
|
_config: dict[str, object]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -177,6 +178,7 @@ class Model(torch.nn.Module):
|
|||||||
postprocessor: PostprocessorProtocol,
|
postprocessor: PostprocessorProtocol,
|
||||||
class_names: list[str],
|
class_names: list[str],
|
||||||
dimension_names: list[str],
|
dimension_names: list[str],
|
||||||
|
config: dict[str, object],
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.detector = detector
|
self.detector = detector
|
||||||
@ -184,6 +186,12 @@ class Model(torch.nn.Module):
|
|||||||
self.postprocessor = postprocessor
|
self.postprocessor = postprocessor
|
||||||
self.class_names = class_names
|
self.class_names = class_names
|
||||||
self.dimension_names = dimension_names
|
self.dimension_names = dimension_names
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, object]:
|
||||||
|
"""Return the model configuration as plain JSON-serializable data."""
|
||||||
|
|
||||||
|
return dict(self._config)
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
||||||
"""Run the full detection pipeline on a waveform tensor.
|
"""Run the full detection pipeline on a waveform tensor.
|
||||||
@ -287,6 +295,7 @@ def build_model(
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
class_names=class_names,
|
class_names=class_names,
|
||||||
dimension_names=dimension_names,
|
dimension_names=dimension_names,
|
||||||
|
config=config.model_dump(mode="json"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -308,4 +317,5 @@ def build_model_with_new_targets(
|
|||||||
preprocessor=model.preprocessor,
|
preprocessor=model.preprocessor,
|
||||||
class_names=targets.class_names,
|
class_names=targets.class_names,
|
||||||
dimension_names=roi_mapper.dimension_names,
|
dimension_names=roi_mapper.dimension_names,
|
||||||
|
config=model.get_config(),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -28,7 +30,7 @@ __all__ = [
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TrainLoggingContext:
|
class TrainLoggingContext:
|
||||||
model_config: ModelConfig
|
model_config: dict[str, Any]
|
||||||
train_config: TrainingConfig
|
train_config: TrainingConfig
|
||||||
audio_config: AudioConfig
|
audio_config: AudioConfig
|
||||||
targets: TargetProtocol
|
targets: TargetProtocol
|
||||||
@ -49,9 +51,10 @@ class ConfigHyperparameterLogging:
|
|||||||
artifact_path: Path,
|
artifact_path: Path,
|
||||||
context: TrainLoggingContext,
|
context: TrainLoggingContext,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
model_config = ModelConfig.model_validate(context.model_config)
|
||||||
logger.log_hyperparams(
|
logger.log_hyperparams(
|
||||||
{
|
{
|
||||||
"model": context.model_config.model_dump(
|
"model": model_config.model_dump(
|
||||||
mode="json",
|
mode="json",
|
||||||
exclude_none=True,
|
exclude_none=True,
|
||||||
),
|
),
|
||||||
|
|||||||
@ -57,7 +57,6 @@ def run_train(
|
|||||||
audio_loader: Optional["AudioLoader"] = None,
|
audio_loader: Optional["AudioLoader"] = None,
|
||||||
labeller: Optional["ClipLabeller"] = None,
|
labeller: Optional["ClipLabeller"] = None,
|
||||||
audio_config: Optional[AudioConfig] = None,
|
audio_config: Optional[AudioConfig] = None,
|
||||||
model_config: Optional[ModelConfig] = None,
|
|
||||||
targets_config: TargetConfig | None = None,
|
targets_config: TargetConfig | None = None,
|
||||||
train_config: Optional[TrainingConfig] = None,
|
train_config: Optional[TrainingConfig] = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
@ -75,7 +74,11 @@ def run_train(
|
|||||||
if seed is not None:
|
if seed is not None:
|
||||||
seed_everything(seed)
|
seed_everything(seed)
|
||||||
|
|
||||||
model_config = model_config or ModelConfig()
|
model_config = (
|
||||||
|
ModelConfig()
|
||||||
|
if model is None
|
||||||
|
else ModelConfig.model_validate(model.get_config())
|
||||||
|
)
|
||||||
targets_config = targets_config or TargetConfig()
|
targets_config = targets_config or TargetConfig()
|
||||||
audio_config = audio_config or AudioConfig()
|
audio_config = audio_config or AudioConfig()
|
||||||
train_config = train_config or TrainingConfig()
|
train_config = train_config or TrainingConfig()
|
||||||
@ -172,7 +175,7 @@ def run_train(
|
|||||||
root_artifact_path.mkdir(parents=True, exist_ok=True)
|
root_artifact_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
logging_context = TrainLoggingContext(
|
logging_context = TrainLoggingContext(
|
||||||
model_config=model_config,
|
model_config=model_config.model_dump(mode="json"),
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
audio_config=audio_config,
|
audio_config=audio_config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
|||||||
@ -451,7 +451,6 @@ def test_run_train_rejects_incompatible_model_config(
|
|||||||
model=incompatible_model,
|
model=incompatible_model,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
model_config=incompatible_config,
|
|
||||||
targets_config=targets_config,
|
targets_config=targets_config,
|
||||||
train_config=TrainingConfig(),
|
train_config=TrainingConfig(),
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user