Add scheduler and optimizer module

mbsantiago 2026-03-17 21:16:41 +00:00
parent 615c7d78fb
commit feee2bdfa3
13 changed files with 395 additions and 78 deletions

View File

@@ -64,7 +64,11 @@ model:
 train:
   optimizer:
+    name: adam
     learning_rate: 0.001
+  scheduler:
+    name: cosine_annealing
     t_max: 100
   labels:
@@ -76,9 +80,7 @@ train:
   train_loader:
     batch_size: 8
     num_workers: 2
     shuffle: True
   clipping_strategy:
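
The new `scheduler:` block mirrors the code changes below: the optimizer keeps `name` and `learning_rate`, while `t_max` moves under the scheduler. A minimal parsing sketch, assuming `TrainingConfig` fills the omitted sections from defaults (as the diffs further down suggest):

from batdetect2.train import TrainingConfig

yaml_str = """
optimizer:
  name: adam
  learning_rate: 0.001
scheduler:
  name: cosine_annealing
  t_max: 100
"""

config = TrainingConfig.from_yaml(yaml_str)
assert config.optimizer.name == "adam"
assert config.scheduler.t_max == 100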

View File

@@ -86,6 +86,7 @@ dev = [
     "rust-just>=1.40.0",
     "pandas-stubs>=2.2.2.240807",
     "python-lsp-server>=1.13.0",
+    "deepdiff>=8.6.1",
 ]
 dvclive = ["dvclive>=3.48.2"]
 mlflow = ["mlflow>=3.1.1"]
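
`deepdiff` joins the dev group because the reworked tests below compare dumped config dicts with it. `DeepDiff` returns an empty, falsy result when two nested structures match, which is what makes the `assert not DeepDiff(...)` pattern work:

from deepdiff import DeepDiff

# Empty diff (falsy) when the nested dicts are identical.
assert not DeepDiff(
    {"optimizer": {"name": "adam", "learning_rate": 1e-3}},
    {"optimizer": {"name": "adam", "learning_rate": 1e-3}},
)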

View File

@@ -8,6 +8,7 @@ configuration data from files, with optional support for accessing nested
 configuration sections.
 """

+import sys
 from typing import Any, Type, TypeVar, overload

 import yaml
@@ -15,6 +16,11 @@ from deepmerge.merger import Merger
 from pydantic import BaseModel, ConfigDict, TypeAdapter
 from soundevent.data import PathLike

+if sys.version_info < (3, 11):
+    from typing_extensions import Self
+else:
+    from typing import Self
+
 __all__ = [
     "BaseConfig",
     "load_config",
@@ -66,6 +72,10 @@ class BaseConfig(BaseModel):
     def from_yaml(cls, yaml_str: str):
         return cls.model_validate(yaml.safe_load(yaml_str))

+    @classmethod
+    def load(cls, path: PathLike, field: str | None = None) -> Self:
+        return load_config(path, schema=cls, field=field)  # type: ignore
+

 T = TypeVar("T")
 T_Model = TypeVar("T_Model", bound=BaseModel)
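
The new `load` classmethod is a thin wrapper over `load_config`, letting any config subclass read itself straight from a YAML file. A usage sketch (the file path and the `field` value here are hypothetical):

from batdetect2.train import TrainingConfig

# Validates the "train" section of config.yaml as a TrainingConfig.
train_config = TrainingConfig.load("config.yaml", field="train")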

View File

@@ -10,8 +10,6 @@ from batdetect2.evaluate.evaluator import build_evaluator
 from batdetect2.evaluate.lightning import EvaluationModule
 from batdetect2.logging import build_logger
 from batdetect2.models import Model
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets import build_targets
 from batdetect2.typing import Detection

 if TYPE_CHECKING:

View File

@@ -1,8 +1,5 @@
 from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR
-from batdetect2.train.config import (
-    TrainingConfig,
-    load_train_config,
-)
+from batdetect2.train.config import TrainingConfig, load_train_config
 from batdetect2.train.lightning import (
     TrainingModule,
     load_model_from_checkpoint,
@@ -16,5 +13,5 @@ __all__ = [
     "build_trainer",
     "load_model_from_checkpoint",
     "load_train_config",
-    "train",
+    "run_train",
 ]

View File

@@ -8,6 +8,11 @@ from batdetect2.train.checkpoints import CheckpointConfig
 from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
 from batdetect2.train.labels import LabelConfig
 from batdetect2.train.losses import LossConfig
+from batdetect2.train.optimizers import AdamOptimizerConfig, OptimizerConfig
+from batdetect2.train.schedulers import (
+    CosineAnnealingSchedulerConfig,
+    SchedulerConfig,
+)

 __all__ = [
     "TrainingConfig",
@@ -36,15 +41,13 @@ class PLTrainerConfig(BaseConfig):
     val_check_interval: int | float | None = None


-class OptimizerConfig(BaseConfig):
-    learning_rate: float = 1e-3
-    t_max: int = 100
-
-
 class TrainingConfig(BaseConfig):
     train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
     val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
-    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
+    optimizer: OptimizerConfig = Field(default_factory=AdamOptimizerConfig)
+    scheduler: SchedulerConfig = Field(
+        default_factory=CosineAnnealingSchedulerConfig
+    )
     loss: LossConfig = Field(default_factory=LossConfig)
     trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
     logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
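
Since `OptimizerConfig` and `SchedulerConfig` are discriminated unions keyed on `name`, pydantic selects the concrete config class during validation. A small sketch of that dispatch, relying on defaults for the omitted sections:

from batdetect2.train.config import TrainingConfig

config = TrainingConfig.model_validate(
    {"scheduler": {"name": "cosine_annealing", "t_max": 50}}
)
assert config.scheduler.t_max == 50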

View File

@@ -1,11 +1,11 @@
 import lightning as L
 from soundevent.data import PathLike
-from torch.optim.adam import Adam
-from torch.optim.lr_scheduler import CosineAnnealingLR

 from batdetect2.models import Model, ModelConfig, build_model
 from batdetect2.train.config import TrainingConfig
-from batdetect2.train.losses import LossFunction, build_loss
+from batdetect2.train.losses import build_loss
+from batdetect2.train.optimizers import build_optimizer
+from batdetect2.train.schedulers import build_scheduler
 from batdetect2.typing import LossProtocol, ModelOutput, TrainExample

 __all__ = [
@@ -21,7 +21,7 @@ class TrainingModule(L.LightningModule):
         self,
         model_config: dict | None = None,
         train_config: dict | None = None,
-        loss: LossFunction | None = None,
+        loss: LossProtocol | None = None,
         model: Model | None = None,
     ):
         super().__init__()
@@ -67,10 +67,22 @@ class TrainingModule(L.LightningModule):
         return outputs

     def configure_optimizers(self):
-        config = self.train_config.optimizer
-        optimizer = Adam(self.parameters(), lr=config.learning_rate)
-        scheduler = CosineAnnealingLR(optimizer, T_max=config.t_max)
-        return [optimizer], [scheduler]
+        optimizer = build_optimizer(
+            self.parameters(),
+            config=self.train_config.optimizer,
+        )
+        scheduler = build_scheduler(
+            optimizer,
+            config=self.train_config.scheduler,
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "interval": "epoch",
+                "frequency": 1,
+            },
+        }


 def load_model_from_checkpoint(
@@ -96,7 +108,16 @@ def load_model_from_checkpoint(

 def build_training_module(
-    model_config: dict | None = None,
-    train_config: dict | None = None,
+    model_config: ModelConfig | None = None,
+    train_config: TrainingConfig | None = None,
 ) -> TrainingModule:
-    return TrainingModule(model_config=model_config, train_config=train_config)
+    if model_config is None:
+        model_config = ModelConfig()
+
+    if train_config is None:
+        train_config = TrainingConfig()
+
+    return TrainingModule(
+        model_config=model_config.model_dump(mode="json"),
+        train_config=train_config.model_dump(mode="json"),
+    )
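
`configure_optimizers` now returns Lightning's dictionary format, which lets the scheduler carry explicit `interval` and `frequency` metadata instead of relying on the implicit per-epoch stepping of the old tuple form. A quick sketch against the defaults:

from batdetect2.train.lightning import build_training_module

module = build_training_module()  # ModelConfig() and TrainingConfig() defaults
opt_config = module.configure_optimizers()

optimizer = opt_config["optimizer"]  # Adam by default
scheduler = opt_config["lr_scheduler"]["scheduler"]  # CosineAnnealingLR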

View File

@@ -0,0 +1,87 @@
"""Optimizer configuration and factory utilities for training."""

from collections.abc import Iterable
from typing import Annotated, Literal

from pydantic import Field
from torch import nn
from torch.optim import Adam, Optimizer

from batdetect2.core import (
    BaseConfig,
    ImportConfig,
    Registry,
    add_import_config,
)

__all__ = [
    "AdamOptimizerConfig",
    "OptimizerConfig",
    "OptimizerImportConfig",
    "build_optimizer",
    "optimizer_registry",
]


class AdamOptimizerConfig(BaseConfig):
    """Configuration for the Adam optimizer.

    Attributes
    ----------
    name : Literal["adam"]
        Discriminator field used by the optimizer registry.
    learning_rate : float
        Learning rate used by ``torch.optim.Adam``.
    """

    name: Literal["adam"] = "adam"
    learning_rate: float = 1e-3


optimizer_registry: Registry[Optimizer, [Iterable[nn.Parameter]]] = Registry(
    "optimizer"
)


@add_import_config(optimizer_registry)
class OptimizerImportConfig(ImportConfig):
    """Use any callable as an optimizer.

    Set ``name="import"`` and provide a ``target`` pointing to any callable
    that returns an optimizer. The training parameters are passed as the
    ``params`` keyword argument.
    """

    name: Literal["import"] = "import"


@optimizer_registry.register(AdamOptimizerConfig)
def build_adam(
    config: AdamOptimizerConfig,
    params: Iterable[nn.Parameter],
) -> Optimizer:
    """Build an Adam optimizer from configuration."""
    return Adam(params, lr=config.learning_rate)


OptimizerConfig = Annotated[
    AdamOptimizerConfig | OptimizerImportConfig,
    Field(discriminator="name"),
]


def build_optimizer(
    parameters: Iterable[nn.Parameter],
    config: OptimizerConfig | None = None,
) -> Optimizer:
    """Build an optimizer from configuration.

    Parameters
    ----------
    parameters : Iterable[nn.Parameter]
        Model parameters to optimize.
    config : OptimizerConfig, optional
        Optimizer configuration. Defaults to ``AdamOptimizerConfig``.
    """
    config = config or AdamOptimizerConfig()
    return optimizer_registry.build(config, params=parameters)
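
With `OptimizerImportConfig`, any optimizer reachable by dotted path works without a dedicated config class. A sketch using `torch.optim.AdamW`, assuming `ImportConfig` takes the `target` and `arguments` fields shown in the tests further down:

from torch import nn

from batdetect2.train.optimizers import OptimizerImportConfig, build_optimizer

model = nn.Linear(4, 2)
optimizer = build_optimizer(
    model.parameters(),
    config=OptimizerImportConfig(
        target="torch.optim.AdamW",
        arguments={"lr": 3e-4, "weight_decay": 1e-2},
    ),
)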

View File

@@ -0,0 +1,81 @@
"""Scheduler configuration and factory utilities for training."""

from typing import Annotated, Literal

from pydantic import Field
from torch.optim import Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler

from batdetect2.core import (
    BaseConfig,
    ImportConfig,
    Registry,
    add_import_config,
)

__all__ = [
    "CosineAnnealingSchedulerConfig",
    "SchedulerConfig",
    "SchedulerImportConfig",
    "build_scheduler",
    "scheduler_registry",
]


class CosineAnnealingSchedulerConfig(BaseConfig):
    """Configuration for ``CosineAnnealingLR``.

    Attributes
    ----------
    name : Literal["cosine_annealing"]
        Discriminator field used by the scheduler registry.
    t_max : int
        Number of epochs to complete one cosine cycle.
    """

    name: Literal["cosine_annealing"] = "cosine_annealing"
    t_max: int = 100


scheduler_registry: Registry[LRScheduler, [Optimizer]] = Registry("scheduler")


@add_import_config(scheduler_registry)
class SchedulerImportConfig(ImportConfig):
    """Use any callable as a scheduler.

    Set ``name="import"`` and provide a ``target`` pointing to any callable
    that returns a scheduler. The optimizer instance is passed as the
    ``optimizer`` keyword argument.
    """

    name: Literal["import"] = "import"


@scheduler_registry.register(CosineAnnealingSchedulerConfig)
def build_cosine_scheduler(
    config: CosineAnnealingSchedulerConfig,
    optimizer: Optimizer,
) -> LRScheduler:
    """Build a cosine annealing scheduler.

    ``t_max`` is interpreted in epochs because Lightning steps the scheduler
    once per epoch when ``interval="epoch"`` is used.
    """
    return CosineAnnealingLR(optimizer, T_max=config.t_max)


SchedulerConfig = Annotated[
    CosineAnnealingSchedulerConfig | SchedulerImportConfig,
    Field(discriminator="name"),
]


def build_scheduler(
    optimizer: Optimizer,
    config: SchedulerConfig | None = None,
) -> LRScheduler:
    """Build a scheduler from configuration."""
    config = config or CosineAnnealingSchedulerConfig()
    return scheduler_registry.build(config, optimizer=optimizer)
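
The same escape hatch exists for schedulers: `build_scheduler` forwards the optimizer as the `optimizer` keyword, so any `torch.optim.lr_scheduler` class can be named by path. A sketch with `ExponentialLR`, again assuming the `target`/`arguments` fields demonstrated in the tests below:

from torch import nn
from torch.optim import SGD

from batdetect2.train.schedulers import SchedulerImportConfig, build_scheduler

model = nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=1e-3)
scheduler = build_scheduler(
    optimizer,
    config=SchedulerImportConfig(
        target="torch.optim.lr_scheduler.ExponentialLR",
        arguments={"gamma": 0.95},
    ),
)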

View File

@@ -6,11 +6,13 @@ from lightning import Trainer, seed_everything
 from loguru import logger
 from soundevent import data

-from batdetect2.audio import build_audio_loader
+from batdetect2.audio import AudioConfig, build_audio_loader
 from batdetect2.evaluate import build_evaluator
 from batdetect2.logging import build_logger
+from batdetect2.models import ModelConfig
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
+from batdetect2.train import TrainingConfig
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.checkpoints import build_checkpoint_callback
 from batdetect2.train.dataset import build_train_loader, build_val_loader
@@ -18,7 +20,6 @@ from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import build_training_module

 if TYPE_CHECKING:
-    from batdetect2.config import BatDetect2Config
     from batdetect2.typing import (
         AudioLoader,
         ClipLabeller,
@@ -29,7 +30,7 @@ if TYPE_CHECKING:

 __all__ = [
     "build_trainer",
-    "train",
+    "run_train",
 ]
@@ -40,7 +41,9 @@ def run_train(
     preprocessor: Optional["PreprocessorProtocol"] = None,
     audio_loader: Optional["AudioLoader"] = None,
     labeller: Optional["ClipLabeller"] = None,
-    config: Optional["BatDetect2Config"] = None,
+    audio_config: Optional[AudioConfig] = None,
+    model_config: Optional[ModelConfig] = None,
+    train_config: Optional[TrainingConfig] = None,
     trainer: Trainer | None = None,
     train_workers: int | None = None,
     val_workers: int | None = None,
@@ -51,27 +54,27 @@ def run_train(
     run_name: str | None = None,
     seed: int | None = None,
 ):
-    from batdetect2.config import BatDetect2Config
-
     if seed is not None:
         seed_everything(seed)

-    config = config or BatDetect2Config()
+    model_config = model_config or ModelConfig()
+    audio_config = audio_config or AudioConfig()
+    train_config = train_config or TrainingConfig()

-    targets = targets or build_targets(config=config.model.targets)
-    audio_loader = audio_loader or build_audio_loader(config=config.audio)
+    targets = targets or build_targets(config=model_config.targets)
+    audio_loader = audio_loader or build_audio_loader(config=audio_config)
     preprocessor = preprocessor or build_preprocessor(
         input_samplerate=audio_loader.samplerate,
-        config=config.model.preprocess,
+        config=model_config.preprocess,
     )
     labeller = labeller or build_clip_labeler(
         targets,
         min_freq=preprocessor.min_freq,
         max_freq=preprocessor.max_freq,
-        config=config.train.labels,
+        config=train_config.labels,
     )

     train_dataloader = build_train_loader(
@@ -79,7 +82,7 @@ def run_train(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train.train_loader,
+        config=train_config.train_loader,
         num_workers=train_workers,
     )
@@ -89,26 +92,22 @@ def run_train(
             audio_loader=audio_loader,
             labeller=labeller,
             preprocessor=preprocessor,
-            config=config.train.val_loader,
+            config=train_config.val_loader,
             num_workers=val_workers,
         )
         if val_annotations is not None
         else None
     )

-    train_config_dict = config.train.model_dump(mode="json")
-    if "optimizer" in train_config_dict:
-        train_config_dict["optimizer"]["t_max"] *= len(train_dataloader)
-
     module = build_training_module(
-        model_config=config.model.model_dump(mode="json"),
-        train_config=train_config_dict,
+        model_config=model_config,
+        train_config=train_config,
     )

     trainer = trainer or build_trainer(
-        config,
+        train_config,
         evaluator=build_evaluator(
-            config.train.validation,
+            train_config.validation,
             targets=targets,
         ),
         checkpoint_dir=checkpoint_dir,
@@ -130,7 +129,7 @@ def run_train(

 def build_trainer(
-    config: "BatDetect2Config",
+    config: TrainingConfig,
     evaluator: "EvaluatorProtocol",
     checkpoint_dir: Path | None = None,
     log_dir: Path | None = None,
@@ -138,14 +137,14 @@ def build_trainer(
     run_name: str | None = None,
     num_epochs: int | None = None,
 ) -> Trainer:
-    trainer_conf = config.train.trainer
+    trainer_conf = config.trainer

     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )

     train_logger = build_logger(
-        config.train.logger,
+        config.logger,
         log_dir=log_dir,
         experiment_name=experiment_name,
         run_name=run_name,
@@ -168,7 +167,7 @@ def build_trainer(
         logger=train_logger,
         callbacks=[
             build_checkpoint_callback(
-                config=config.train.checkpoints,
+                config=config.checkpoints,
                 checkpoint_dir=checkpoint_dir,
                 experiment_name=experiment_name,
                 run_name=run_name,
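
Callers that previously passed a whole `BatDetect2Config` now hand `run_train` its sections separately, as the smoke test in the next file does. A sketch of the migration (the annotations argument is a placeholder):

from batdetect2.config import BatDetect2Config
from batdetect2.train import run_train

def train_from_full_config(annotations):
    # Pass the audio, model, and train sections explicitly.
    config = BatDetect2Config()
    run_train(
        train_annotations=annotations,
        val_annotations=annotations,
        audio_config=config.audio,
        model_config=config.model,
        train_config=config.train,
    )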

View File

@@ -2,15 +2,22 @@ from pathlib import Path
 import lightning as L
 import torch
+from deepdiff import DeepDiff
 from soundevent import data
+from torch.optim import Adam
+from torch.optim.lr_scheduler import CosineAnnealingLR

 from batdetect2.api_v2 import BatDetect2API
 from batdetect2.config import BatDetect2Config
+from batdetect2.models import ModelConfig
 from batdetect2.train import (
+    TrainingConfig,
     TrainingModule,
     load_model_from_checkpoint,
     run_train,
 )
+from batdetect2.train.optimizers import AdamOptimizerConfig
+from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
 from batdetect2.train.train import build_training_module
 from batdetect2.typing.preprocess import AudioLoader
@@ -18,8 +25,8 @@ from batdetect2.typing.preprocess import AudioLoader
 def build_default_module(config: BatDetect2Config | None = None):
     config = config or BatDetect2Config()
     return build_training_module(
-        model_config=config.model.model_dump(mode="json"),
-        train_config=config.train.model_dump(mode="json"),
+        model_config=config.model,
+        train_config=config.train,
     )
@@ -57,53 +64,105 @@ def test_can_save_checkpoint(
 def test_load_model_from_checkpoint_returns_model_and_config(
     tmp_path: Path,
 ):
-    module = build_default_module()
+    input_model_config = ModelConfig(samplerate=192_000)
+    expected_model_config = ModelConfig.model_validate(
+        input_model_config.model_dump(mode="json")
+    )
+    train_config = TrainingConfig()
+    module = build_training_module(
+        model_config=input_model_config,
+        train_config=train_config,
+    )
     trainer = L.Trainer()
     path = tmp_path / "example.ckpt"
     trainer.strategy.connect(module)
     trainer.save_checkpoint(path)

-    model, model_config = load_model_from_checkpoint(path)
+    model, loaded_model_config = load_model_from_checkpoint(path)

     assert model is not None
-    assert model_config.model_dump(
+    assert loaded_model_config.model_dump(
         mode="json"
-    ) == module.model_config.model_dump(mode="json")
+    ) == expected_model_config.model_dump(mode="json")
+
+    recovered = TrainingModule.load_from_checkpoint(path)
+    assert recovered.train_config.model_dump(
+        mode="json"
+    ) == train_config.model_dump(mode="json")


 def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
-    config = BatDetect2Config()
-    config.train.optimizer.learning_rate = 7e-4
-    config.train.optimizer.t_max = 123
+    model_config = ModelConfig(samplerate=384_000)
+    expected_model_config = ModelConfig.model_validate(
+        model_config.model_dump(mode="json")
+    )
+    train_config = TrainingConfig()
+    train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
+    train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123)
+    train_config.trainer.max_epochs = 3
+    train_config.train_loader.batch_size = 2

-    module = build_default_module(config=config)
+    module = build_training_module(
+        model_config=model_config,
+        train_config=train_config,
+    )
     trainer = L.Trainer()
     path = tmp_path / "example.ckpt"
     trainer.strategy.connect(module)
     trainer.save_checkpoint(path)

-    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
-    hyper_parameters = checkpoint["hyper_parameters"]
-
-    assert (
-        hyper_parameters["train_config"]["optimizer"]["learning_rate"] == 7e-4
-    )
-    assert hyper_parameters["train_config"]["optimizer"]["t_max"] == 123
-    assert "learning_rate" not in hyper_parameters
-    assert "t_max" not in hyper_parameters
+    recovered = TrainingModule.load_from_checkpoint(path)
+    assert not DeepDiff(
+        recovered.model_config.model_dump(mode="json"),
+        expected_model_config.model_dump(mode="json"),
+    )
+    assert not DeepDiff(
+        recovered.train_config.model_dump(mode="json"),
+        train_config.model_dump(mode="json"),
+    )


-def test_configure_optimizers_uses_train_config_values():
-    config = BatDetect2Config()
-    config.train.optimizer.learning_rate = 5e-4
-    config.train.optimizer.t_max = 321
+def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
+    model_config = ModelConfig()
+    expected_model_config = ModelConfig.model_validate(
+        model_config.model_dump(mode="json")
+    )
+    train_config = TrainingConfig()
+    train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
+    train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321)

-    module = build_default_module(config=config)
+    module = build_training_module(
+        model_config=model_config,
+        train_config=train_config,
+    )

-    optimizers, schedulers = module.configure_optimizers()
+    optimization_config = module.configure_optimizers()
+    optimizer = optimization_config["optimizer"]
+    scheduler = optimization_config["lr_scheduler"]["scheduler"]

-    assert optimizers[0].param_groups[0]["lr"] == 5e-4
-    assert schedulers[0].T_max == 321
+    assert isinstance(optimizer, Adam)
+    assert isinstance(scheduler, CosineAnnealingLR)
+    assert optimizer.param_groups[0]["lr"] == 5e-4
+    assert scheduler.T_max == 321
+
+    trainer = L.Trainer()
+    path = tmp_path / "example.ckpt"
+    trainer.strategy.connect(module)
+    trainer.save_checkpoint(path)
+
+    recovered = TrainingModule.load_from_checkpoint(path)
+    assert recovered.model_config.model_dump(
+        mode="json"
+    ) == expected_model_config.model_dump(mode="json")
+    assert recovered.train_config.model_dump(
+        mode="json"
+    ) == train_config.model_dump(mode="json")
+
+    loaded_optimization_config = recovered.configure_optimizers()
+    loaded_optimizer = loaded_optimization_config["optimizer"]
+    loaded_scheduler = loaded_optimization_config["lr_scheduler"]["scheduler"]
+
+    assert loaded_optimizer.param_groups[0]["lr"] == 5e-4
+    assert loaded_scheduler.T_max == 321


 def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
@@ -136,7 +195,9 @@ def test_train_smoke_produces_loadable_checkpoint(
     run_train(
         train_annotations=example_annotations[:1],
         val_annotations=example_annotations[:1],
-        config=config,
+        train_config=config.train,
+        model_config=config.model,
+        audio_config=config.audio,
         num_epochs=1,
         train_workers=0,
         val_workers=0,

View File

@@ -0,0 +1,22 @@
from torch import nn
from torch.optim import SGD, Adam

from batdetect2.train.optimizers import OptimizerImportConfig, build_optimizer


def test_build_optimizer_defaults_to_adam():
    model = nn.Linear(4, 2)
    optimizer = build_optimizer(model.parameters())
    assert isinstance(optimizer, Adam)


def test_build_optimizer_supports_import_config():
    model = nn.Linear(4, 2)
    config = OptimizerImportConfig(
        target="torch.optim.SGD",
        arguments={"lr": 1e-3},
    )
    optimizer = build_optimizer(model.parameters(), config=config)
    assert isinstance(optimizer, SGD)

View File

@@ -0,0 +1,35 @@
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR

from batdetect2.train.schedulers import (
    CosineAnnealingSchedulerConfig,
    SchedulerImportConfig,
    build_scheduler,
)


def test_build_scheduler_uses_epoch_t_max_directly():
    model = nn.Linear(4, 2)
    optimizer = SGD(model.parameters(), lr=1e-3)
    scheduler = build_scheduler(
        optimizer,
        config=CosineAnnealingSchedulerConfig(t_max=7),
    )
    assert isinstance(scheduler, CosineAnnealingLR)
    assert scheduler.T_max == 7


def test_build_scheduler_supports_import_config():
    model = nn.Linear(4, 2)
    optimizer = SGD(model.parameters(), lr=1e-3)
    scheduler = build_scheduler(
        optimizer,
        config=SchedulerImportConfig(
            target="torch.optim.lr_scheduler.StepLR",
            arguments={"step_size": 2},
        ),
    )
    assert isinstance(scheduler, StepLR)