Add scheduler and optimizer module
This commit is contained in: parent 615c7d78fb, commit feee2bdfa3
@@ -64,7 +64,11 @@ model:
train:
  optimizer:
    name: adam
    learning_rate: 0.001

  scheduler:
    name: cosine_annealing
    t_max: 100

  labels:
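For orientation, here is a minimal sketch of how this new YAML section is parsed, assuming the `from_yaml` helper on `BaseConfig` and the `TrainingConfig` fields shown further down in this diff:

from batdetect2.train.config import TrainingConfig

# Parse just the optimizer/scheduler section above; the remaining
# TrainingConfig fields fall back to their default factories.
config = TrainingConfig.from_yaml(
    """
    optimizer:
      name: adam
      learning_rate: 0.001
    scheduler:
      name: cosine_annealing
      t_max: 100
    """
)
assert config.optimizer.name == "adam"
assert config.scheduler.t_max == 100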
@@ -76,9 +80,7 @@ train:

  train_loader:
    batch_size: 8
    num_workers: 2
    shuffle: True

  clipping_strategy:
@@ -86,6 +86,7 @@ dev = [
    "rust-just>=1.40.0",
    "pandas-stubs>=2.2.2.240807",
    "python-lsp-server>=1.13.0",
    "deepdiff>=8.6.1",
]
dvclive = ["dvclive>=3.48.2"]
mlflow = ["mlflow>=3.1.1"]
@@ -8,6 +8,7 @@ configuration data from files, with optional support for accessing nested
configuration sections.
"""

import sys
from typing import Any, Type, TypeVar, overload

import yaml
@@ -15,6 +16,11 @@ from deepmerge.merger import Merger
from pydantic import BaseModel, ConfigDict, TypeAdapter
from soundevent.data import PathLike

if sys.version_info < (3, 11):
    from typing_extensions import Self
else:
    from typing import Self

__all__ = [
    "BaseConfig",
    "load_config",
@@ -66,6 +72,10 @@ class BaseConfig(BaseModel):
    def from_yaml(cls, yaml_str: str):
        return cls.model_validate(yaml.safe_load(yaml_str))

    @classmethod
    def load(cls: Self, path: PathLike, field: str | None = None) -> Self:
        return load_config(path, schema=cls, field=field)  # type: ignore


T = TypeVar("T")
T_Model = TypeVar("T_Model", bound=BaseModel)
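The new `load` classmethod gives every config model a one-line file loader. A minimal sketch, assuming a hypothetical `batdetect2.yaml` file and that `field` selects a nested section before validation, as the module docstring above describes:

from batdetect2.train.config import TrainingConfig

# `field="train"` picks out the nested `train:` section of the file;
# the file path here is hypothetical.
config = TrainingConfig.load("batdetect2.yaml", field="train")
print(config.optimizer.learning_rate)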
@@ -10,8 +10,6 @@ from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger
from batdetect2.models import Model
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.typing import Detection

if TYPE_CHECKING:
@@ -1,8 +1,5 @@
from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR
from batdetect2.train.config import (
    TrainingConfig,
    load_train_config,
)
from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.lightning import (
    TrainingModule,
    load_model_from_checkpoint,
@@ -16,5 +13,5 @@ __all__ = [
    "build_trainer",
    "load_model_from_checkpoint",
    "load_train_config",
    "train",
    "run_train",
]
@@ -8,6 +8,11 @@ from batdetect2.train.checkpoints import CheckpointConfig
from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import LabelConfig
from batdetect2.train.losses import LossConfig
from batdetect2.train.optimizers import AdamOptimizerConfig, OptimizerConfig
from batdetect2.train.schedulers import (
    CosineAnnealingSchedulerConfig,
    SchedulerConfig,
)

__all__ = [
    "TrainingConfig",
@@ -36,15 +41,13 @@ class PLTrainerConfig(BaseConfig):
    val_check_interval: int | float | None = None


class OptimizerConfig(BaseConfig):
    learning_rate: float = 1e-3
    t_max: int = 100


class TrainingConfig(BaseConfig):
    train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
    val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
    optimizer: OptimizerConfig = Field(default_factory=AdamOptimizerConfig)
    scheduler: SchedulerConfig = Field(
        default_factory=CosineAnnealingSchedulerConfig
    )
    loss: LossConfig = Field(default_factory=LossConfig)
    trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
    logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
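Note that the old combined `OptimizerConfig` (learning rate plus `t_max`) is split here: `learning_rate` now lives on the optimizer config and `t_max` on the scheduler config. A small sketch of constructing the new defaults directly, using only names that appear in this diff:

from batdetect2.train.config import TrainingConfig
from batdetect2.train.optimizers import AdamOptimizerConfig
from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig

# learning_rate belongs to the optimizer config, t_max to the scheduler.
config = TrainingConfig(
    optimizer=AdamOptimizerConfig(learning_rate=5e-4),
    scheduler=CosineAnnealingSchedulerConfig(t_max=50),
)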
@@ -1,11 +1,11 @@
import lightning as L
from soundevent.data import PathLike
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.train.config import TrainingConfig
from batdetect2.train.losses import LossFunction, build_loss
from batdetect2.train.losses import build_loss
from batdetect2.train.optimizers import build_optimizer
from batdetect2.train.schedulers import build_scheduler
from batdetect2.typing import LossProtocol, ModelOutput, TrainExample

__all__ = [
@@ -21,7 +21,7 @@ class TrainingModule(L.LightningModule):
        self,
        model_config: dict | None = None,
        train_config: dict | None = None,
        loss: LossFunction | None = None,
        loss: LossProtocol | None = None,
        model: Model | None = None,
    ):
        super().__init__()
@@ -67,10 +67,22 @@ class TrainingModule(L.LightningModule):
        return outputs

    def configure_optimizers(self):
        config = self.train_config.optimizer
        optimizer = Adam(self.parameters(), lr=config.learning_rate)
        scheduler = CosineAnnealingLR(optimizer, T_max=config.t_max)
        return [optimizer], [scheduler]
        optimizer = build_optimizer(
            self.parameters(),
            config=self.train_config.optimizer,
        )
        scheduler = build_scheduler(
            optimizer,
            config=self.train_config.scheduler,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
            },
        }


def load_model_from_checkpoint(
@@ -96,7 +108,16 @@ def load_model_from_checkpoint(


def build_training_module(
    model_config: dict | None = None,
    train_config: dict | None = None,
    model_config: ModelConfig | None = None,
    train_config: TrainingConfig | None = None,
) -> TrainingModule:
    return TrainingModule(model_config=model_config, train_config=train_config)
    if model_config is None:
        model_config = ModelConfig()

    if train_config is None:
        train_config = TrainingConfig()

    return TrainingModule(
        model_config=model_config.model_dump(mode="json"),
        train_config=train_config.model_dump(mode="json"),
    )
src/batdetect2/train/optimizers.py (new file, 87 lines)
@@ -0,0 +1,87 @@
"""Optimizer configuration and factory utilities for training."""

from collections.abc import Iterable
from typing import Annotated, Literal

from pydantic import Field
from torch import nn
from torch.optim import Adam, Optimizer

from batdetect2.core import (
    BaseConfig,
    ImportConfig,
    Registry,
    add_import_config,
)

__all__ = [
    "AdamOptimizerConfig",
    "OptimizerConfig",
    "OptimizerImportConfig",
    "build_optimizer",
    "optimizer_registry",
]


class AdamOptimizerConfig(BaseConfig):
    """Configuration for the Adam optimizer.

    Attributes
    ----------
    name : Literal["adam"]
        Discriminator field used by the optimizer registry.
    learning_rate : float
        Learning rate used by ``torch.optim.Adam``.
    """

    name: Literal["adam"] = "adam"
    learning_rate: float = 1e-3


optimizer_registry: Registry[Optimizer, [Iterable[nn.Parameter]]] = Registry(
    "optimizer"
)


@add_import_config(optimizer_registry)
class OptimizerImportConfig(ImportConfig):
    """Use any callable as an optimizer.

    Set ``name="import"`` and provide a ``target`` pointing to any callable
    that returns an optimizer. The training parameters are passed as the
    ``params`` keyword argument.
    """

    name: Literal["import"] = "import"


@optimizer_registry.register(AdamOptimizerConfig)
def build_adam(
    config: AdamOptimizerConfig,
    params: Iterable[nn.Parameter],
) -> Optimizer:
    """Build an Adam optimizer from configuration."""
    return Adam(params, lr=config.learning_rate)


OptimizerConfig = Annotated[
    AdamOptimizerConfig | OptimizerImportConfig,
    Field(discriminator="name"),
]


def build_optimizer(
    parameters: Iterable[nn.Parameter],
    config: OptimizerConfig | None = None,
) -> Optimizer:
    """Build an optimizer from configuration.

    Parameters
    ----------
    parameters : Iterable[nn.Parameter]
        Model parameters to optimize.
    config : OptimizerConfig, optional
        Optimizer configuration. Defaults to ``AdamOptimizerConfig``.
    """
    config = config or AdamOptimizerConfig()
    return optimizer_registry.build(config, params=parameters)
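Beyond the built-in `adam` entry, the `OptimizerImportConfig` path accepts any importable callable. A sketch assuming `arguments` are forwarded as keyword arguments, as the tests near the end of this diff suggest; the choice of `AdamW` here is illustrative, not part of the commit:

from torch import nn
from batdetect2.train.optimizers import OptimizerImportConfig, build_optimizer

model = nn.Linear(4, 2)
# Hypothetical swap-in: torch's AdamW via the import mechanism; the model
# parameters are forwarded as the `params` keyword argument.
config = OptimizerImportConfig(
    target="torch.optim.AdamW",
    arguments={"lr": 3e-4, "weight_decay": 1e-2},
)
optimizer = build_optimizer(model.parameters(), config=config)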
src/batdetect2/train/schedulers.py (new file, 81 lines)
@@ -0,0 +1,81 @@
"""Scheduler configuration and factory utilities for training."""

from typing import Annotated, Literal

from pydantic import Field
from torch.optim import Optimizer
from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler

from batdetect2.core import (
    BaseConfig,
    ImportConfig,
    Registry,
    add_import_config,
)

__all__ = [
    "CosineAnnealingSchedulerConfig",
    "SchedulerConfig",
    "SchedulerImportConfig",
    "build_scheduler",
    "scheduler_registry",
]


class CosineAnnealingSchedulerConfig(BaseConfig):
    """Configuration for ``CosineAnnealingLR``.

    Attributes
    ----------
    name : Literal["cosine_annealing"]
        Discriminator field used by the scheduler registry.
    t_max : int
        Number of epochs to complete one cosine cycle.
    """

    name: Literal["cosine_annealing"] = "cosine_annealing"
    t_max: int = 100


scheduler_registry: Registry[LRScheduler, [Optimizer]] = Registry("scheduler")


@add_import_config(scheduler_registry)
class SchedulerImportConfig(ImportConfig):
    """Use any callable as a scheduler.

    Set ``name="import"`` and provide a ``target`` pointing to any callable
    that returns a scheduler. The optimizer instance is passed as the
    ``optimizer`` keyword argument.
    """

    name: Literal["import"] = "import"


@scheduler_registry.register(CosineAnnealingSchedulerConfig)
def build_cosine_scheduler(
    config: CosineAnnealingSchedulerConfig,
    optimizer: Optimizer,
) -> LRScheduler:
    """Build a cosine annealing scheduler.

    ``t_max`` is interpreted in epochs because Lightning steps the scheduler
    once per epoch when ``interval="epoch"`` is used.
    """
    return CosineAnnealingLR(optimizer, T_max=config.t_max)


SchedulerConfig = Annotated[
    CosineAnnealingSchedulerConfig | SchedulerImportConfig,
    Field(discriminator="name"),
]


def build_scheduler(
    optimizer: Optimizer,
    config: SchedulerConfig | None = None,
) -> LRScheduler:
    """Build a scheduler from configuration."""
    config = config or CosineAnnealingSchedulerConfig()

    return scheduler_registry.build(config, optimizer=optimizer)
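The same import mechanism works for schedulers. A sketch assuming the `target`/`arguments` fields behave as in the scheduler tests below, with the optimizer passed through as the `optimizer` keyword argument; the choice of `ExponentialLR` is illustrative:

from torch import nn
from torch.optim import SGD
from batdetect2.train.schedulers import SchedulerImportConfig, build_scheduler

optimizer = SGD(nn.Linear(4, 2).parameters(), lr=1e-2)
# Hypothetical swap-in: exponential decay instead of cosine annealing.
config = SchedulerImportConfig(
    target="torch.optim.lr_scheduler.ExponentialLR",
    arguments={"gamma": 0.95},
)
scheduler = build_scheduler(optimizer, config=config)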
@@ -6,11 +6,13 @@ from lightning import Trainer, seed_everything
from loguru import logger
from soundevent import data

from batdetect2.audio import build_audio_loader
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.evaluate import build_evaluator
from batdetect2.logging import build_logger
from batdetect2.models import ModelConfig
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train import TrainingConfig
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.checkpoints import build_checkpoint_callback
from batdetect2.train.dataset import build_train_loader, build_val_loader
@@ -18,7 +20,6 @@ from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import build_training_module

if TYPE_CHECKING:
    from batdetect2.config import BatDetect2Config
    from batdetect2.typing import (
        AudioLoader,
        ClipLabeller,
@@ -29,7 +30,7 @@ if TYPE_CHECKING:

__all__ = [
    "build_trainer",
    "train",
    "run_train",
]

@@ -40,7 +41,9 @@ def run_train(
    preprocessor: Optional["PreprocessorProtocol"] = None,
    audio_loader: Optional["AudioLoader"] = None,
    labeller: Optional["ClipLabeller"] = None,
    config: Optional["BatDetect2Config"] = None,
    audio_config: Optional[AudioConfig] = None,
    model_config: Optional[ModelConfig] = None,
    train_config: Optional[TrainingConfig] = None,
    trainer: Trainer | None = None,
    train_workers: int | None = None,
    val_workers: int | None = None,
@@ -51,27 +54,27 @@ def run_train(
    run_name: str | None = None,
    seed: int | None = None,
):
    from batdetect2.config import BatDetect2Config

    if seed is not None:
        seed_everything(seed)

    config = config or BatDetect2Config()
    model_config = model_config or ModelConfig()
    audio_config = audio_config or AudioConfig()
    train_config = train_config or TrainingConfig()

    targets = targets or build_targets(config=config.model.targets)
    targets = targets or build_targets(config=model_config.targets)

    audio_loader = audio_loader or build_audio_loader(config=config.audio)
    audio_loader = audio_loader or build_audio_loader(config=audio_config)

    preprocessor = preprocessor or build_preprocessor(
        input_samplerate=audio_loader.samplerate,
        config=config.model.preprocess,
        config=model_config.preprocess,
    )

    labeller = labeller or build_clip_labeler(
        targets,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        config=config.train.labels,
        config=train_config.labels,
    )

    train_dataloader = build_train_loader(
@@ -79,7 +82,7 @@ def run_train(
        audio_loader=audio_loader,
        labeller=labeller,
        preprocessor=preprocessor,
        config=config.train.train_loader,
        config=train_config.train_loader,
        num_workers=train_workers,
    )

@@ -89,26 +92,22 @@ def run_train(
            audio_loader=audio_loader,
            labeller=labeller,
            preprocessor=preprocessor,
            config=config.train.val_loader,
            config=train_config.val_loader,
            num_workers=val_workers,
        )
        if val_annotations is not None
        else None
    )

    train_config_dict = config.train.model_dump(mode="json")
    if "optimizer" in train_config_dict:
        train_config_dict["optimizer"]["t_max"] *= len(train_dataloader)

    module = build_training_module(
        model_config=config.model.model_dump(mode="json"),
        train_config=train_config_dict,
        model_config=model_config,
        train_config=train_config,
    )

    trainer = trainer or build_trainer(
        config,
        train_config,
        evaluator=build_evaluator(
            config.train.validation,
            train_config.validation,
            targets=targets,
        ),
        checkpoint_dir=checkpoint_dir,
@@ -130,7 +129,7 @@ def run_train(


def build_trainer(
    config: "BatDetect2Config",
    config: TrainingConfig,
    evaluator: "EvaluatorProtocol",
    checkpoint_dir: Path | None = None,
    log_dir: Path | None = None,
@@ -138,14 +137,14 @@ def build_trainer(
    run_name: str | None = None,
    num_epochs: int | None = None,
) -> Trainer:
    trainer_conf = config.train.trainer
    trainer_conf = config.trainer
    logger.opt(lazy=True).debug(
        "Building trainer with config: \n{config}",
        config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
    )

    train_logger = build_logger(
        config.train.logger,
        config.logger,
        log_dir=log_dir,
        experiment_name=experiment_name,
        run_name=run_name,
@@ -168,7 +167,7 @@ def build_trainer(
        logger=train_logger,
        callbacks=[
            build_checkpoint_callback(
                config=config.train.checkpoints,
                config=config.checkpoints,
                checkpoint_dir=checkpoint_dir,
                experiment_name=experiment_name,
                run_name=run_name,
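After this refactor, `run_train` takes the decomposed configs directly instead of a single `BatDetect2Config`. A minimal sketch of the new call shape, with `annotations` as a placeholder for real clip annotations:

from batdetect2.audio import AudioConfig
from batdetect2.models import ModelConfig
from batdetect2.train import TrainingConfig, run_train

# Placeholder: supply real soundevent clip annotations here.
annotations = []

# All three configs are optional and default to their factory values.
run_train(
    train_annotations=annotations,
    model_config=ModelConfig(),
    train_config=TrainingConfig(),
    audio_config=AudioConfig(),
)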
@@ -2,15 +2,22 @@ from pathlib import Path

import lightning as L
import torch
from deepdiff import DeepDiff
from soundevent import data
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config
from batdetect2.models import ModelConfig
from batdetect2.train import (
    TrainingConfig,
    TrainingModule,
    load_model_from_checkpoint,
    run_train,
)
from batdetect2.train.optimizers import AdamOptimizerConfig
from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
from batdetect2.train.train import build_training_module
from batdetect2.typing.preprocess import AudioLoader

@@ -18,8 +25,8 @@ from batdetect2.typing.preprocess import AudioLoader
def build_default_module(config: BatDetect2Config | None = None):
    config = config or BatDetect2Config()
    return build_training_module(
        model_config=config.model.model_dump(mode="json"),
        train_config=config.train.model_dump(mode="json"),
        model_config=config.model,
        train_config=config.train,
    )

@@ -57,53 +64,105 @@ def test_can_save_checkpoint(
def test_load_model_from_checkpoint_returns_model_and_config(
    tmp_path: Path,
):
    module = build_default_module()
    input_model_config = ModelConfig(samplerate=192_000)
    expected_model_config = ModelConfig.model_validate(
        input_model_config.model_dump(mode="json")
    )
    train_config = TrainingConfig()
    module = build_training_module(
        model_config=input_model_config,
        train_config=train_config,
    )
    trainer = L.Trainer()
    path = tmp_path / "example.ckpt"
    trainer.strategy.connect(module)
    trainer.save_checkpoint(path)

    model, model_config = load_model_from_checkpoint(path)
    model, loaded_model_config = load_model_from_checkpoint(path)

    assert model is not None
    assert model_config.model_dump(
    assert loaded_model_config.model_dump(
        mode="json"
    ) == module.model_config.model_dump(mode="json")
    ) == expected_model_config.model_dump(mode="json")

    recovered = TrainingModule.load_from_checkpoint(path)
    assert recovered.train_config.model_dump(
        mode="json"
    ) == train_config.model_dump(mode="json")


def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
    config = BatDetect2Config()
    config.train.optimizer.learning_rate = 7e-4
    config.train.optimizer.t_max = 123
    model_config = ModelConfig(samplerate=384_000)
    expected_model_config = ModelConfig.model_validate(
        model_config.model_dump(mode="json")
    )
    train_config = TrainingConfig()
    train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
    train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123)
    train_config.trainer.max_epochs = 3
    train_config.train_loader.batch_size = 2

    module = build_default_module(config=config)
    module = build_training_module(
        model_config=model_config,
        train_config=train_config,
    )
    trainer = L.Trainer()
    path = tmp_path / "example.ckpt"
    trainer.strategy.connect(module)
    trainer.save_checkpoint(path)

    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    hyper_parameters = checkpoint["hyper_parameters"]

    assert (
        hyper_parameters["train_config"]["optimizer"]["learning_rate"] == 7e-4
    recovered = TrainingModule.load_from_checkpoint(path)
    assert not DeepDiff(
        recovered.model_config.model_dump(mode="json"),
        expected_model_config.model_dump(mode="json"),
    )
    assert not DeepDiff(
        recovered.train_config.model_dump(mode="json"),
        train_config.model_dump(mode="json"),
    )
    assert hyper_parameters["train_config"]["optimizer"]["t_max"] == 123
    assert "learning_rate" not in hyper_parameters
    assert "t_max" not in hyper_parameters

def test_configure_optimizers_uses_train_config_values():
    config = BatDetect2Config()
    config.train.optimizer.learning_rate = 5e-4
    config.train.optimizer.t_max = 321
def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
    model_config = ModelConfig()
    expected_model_config = ModelConfig.model_validate(
        model_config.model_dump(mode="json")
    )
    train_config = TrainingConfig()
    train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
    train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321)

    module = build_default_module(config=config)
    module = build_training_module(
        model_config=model_config,
        train_config=train_config,
    )

    optimizers, schedulers = module.configure_optimizers()
    optimization_config = module.configure_optimizers()
    optimizer = optimization_config["optimizer"]
    scheduler = optimization_config["lr_scheduler"]["scheduler"]

    assert optimizers[0].param_groups[0]["lr"] == 5e-4
    assert schedulers[0].T_max == 321
    assert isinstance(optimizer, Adam)
    assert isinstance(scheduler, CosineAnnealingLR)
    assert optimizer.param_groups[0]["lr"] == 5e-4
    assert scheduler.T_max == 321

    trainer = L.Trainer()
    path = tmp_path / "example.ckpt"
    trainer.strategy.connect(module)
    trainer.save_checkpoint(path)

    recovered = TrainingModule.load_from_checkpoint(path)
    assert recovered.model_config.model_dump(
        mode="json"
    ) == expected_model_config.model_dump(mode="json")
    assert recovered.train_config.model_dump(
        mode="json"
    ) == train_config.model_dump(mode="json")

    loaded_optimization_config = recovered.configure_optimizers()
    loaded_optimizer = loaded_optimization_config["optimizer"]
    loaded_scheduler = loaded_optimization_config["lr_scheduler"]["scheduler"]
    assert loaded_optimizer.param_groups[0]["lr"] == 5e-4
    assert loaded_scheduler.T_max == 321


def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
@@ -136,7 +195,9 @@ def test_train_smoke_produces_loadable_checkpoint(
    run_train(
        train_annotations=example_annotations[:1],
        val_annotations=example_annotations[:1],
        config=config,
        train_config=config.train,
        model_config=config.model,
        audio_config=config.audio,
        num_epochs=1,
        train_workers=0,
        val_workers=0,
tests/test_train/test_optimizers.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from torch import nn
from torch.optim import SGD, Adam

from batdetect2.train.optimizers import OptimizerImportConfig, build_optimizer


def test_build_optimizer_defaults_to_adam():
    model = nn.Linear(4, 2)
    optimizer = build_optimizer(model.parameters())

    assert isinstance(optimizer, Adam)


def test_build_optimizer_supports_import_config():
    model = nn.Linear(4, 2)
    config = OptimizerImportConfig(
        target="torch.optim.SGD",
        arguments={"lr": 1e-3},
    )

    optimizer = build_optimizer(model.parameters(), config=config)
    assert isinstance(optimizer, SGD)
tests/test_train/test_schedulers.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR

from batdetect2.train.schedulers import (
    CosineAnnealingSchedulerConfig,
    SchedulerImportConfig,
    build_scheduler,
)


def test_build_scheduler_uses_epoch_t_max_directly():
    model = nn.Linear(4, 2)
    optimizer = SGD(model.parameters(), lr=1e-3)
    scheduler = build_scheduler(
        optimizer,
        config=CosineAnnealingSchedulerConfig(t_max=7),
    )

    assert isinstance(scheduler, CosineAnnealingLR)
    assert scheduler.T_max == 7


def test_build_scheduler_supports_import_config():
    model = nn.Linear(4, 2)
    optimizer = SGD(model.parameters(), lr=1e-3)
    scheduler = build_scheduler(
        optimizer,
        config=SchedulerImportConfig(
            target="torch.optim.lr_scheduler.StepLR",
            arguments={"step_size": 2},
        ),
    )

    assert isinstance(scheduler, StepLR)