mirror of https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add scheduler and optimizer module
This commit is contained in:
parent 615c7d78fb
commit feee2bdfa3
@@ -64,7 +64,11 @@ model:
 
 train:
   optimizer:
+    name: adam
     learning_rate: 0.001
+
+  scheduler:
+    name: cosine_annealing
     t_max: 100
 
   labels:
@@ -76,9 +80,7 @@ train:
 
   train_loader:
     batch_size: 8
-
     num_workers: 2
-
     shuffle: True
 
   clipping_strategy:
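
Note: the defaults should already reflect the new sections above; a minimal sketch for inspecting them (``to_yaml_string`` comes from ``BaseConfig``, as used in ``build_trainer`` below):

    from batdetect2.train import TrainingConfig

    # With no arguments, TrainingConfig now carries an AdamOptimizerConfig and
    # a CosineAnnealingSchedulerConfig, so the printed YAML should contain
    # "optimizer: {name: adam, ...}" and "scheduler: {name: cosine_annealing, ...}".
    print(TrainingConfig().to_yaml_string(exclude_none=True))
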
@@ -86,6 +86,7 @@ dev = [
     "rust-just>=1.40.0",
     "pandas-stubs>=2.2.2.240807",
     "python-lsp-server>=1.13.0",
+    "deepdiff>=8.6.1",
 ]
 dvclive = ["dvclive>=3.48.2"]
 mlflow = ["mlflow>=3.1.1"]
@@ -8,6 +8,7 @@ configuration data from files, with optional support for accessing nested
 configuration sections.
 """
 
+import sys
 from typing import Any, Type, TypeVar, overload
 
 import yaml
@@ -15,6 +16,11 @@ from deepmerge.merger import Merger
 from pydantic import BaseModel, ConfigDict, TypeAdapter
 from soundevent.data import PathLike
 
+if sys.version_info < (3, 11):
+    from typing_extensions import Self
+else:
+    from typing import Self
+
 __all__ = [
     "BaseConfig",
     "load_config",
@@ -66,6 +72,10 @@ class BaseConfig(BaseModel):
     def from_yaml(cls, yaml_str: str):
         return cls.model_validate(yaml.safe_load(yaml_str))
 
+    @classmethod
+    def load(cls: Self, path: PathLike, field: str | None = None) -> Self:
+        return load_config(path, schema=cls, field=field)  # type: ignore
+
 
 T = TypeVar("T")
 T_Model = TypeVar("T_Model", bound=BaseModel)
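
Note: a short sketch of the new ``BaseConfig.load`` classmethod in use; the file name and section name here are hypothetical:

    from batdetect2.train import TrainingConfig

    # Equivalent to load_config("batdetect2.yaml", schema=TrainingConfig, field="train"):
    # read the YAML file, descend into its "train" section, and validate it.
    train_config = TrainingConfig.load("batdetect2.yaml", field="train")
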
@@ -10,8 +10,6 @@ from batdetect2.evaluate.evaluator import build_evaluator
 from batdetect2.evaluate.lightning import EvaluationModule
 from batdetect2.logging import build_logger
 from batdetect2.models import Model
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets import build_targets
 from batdetect2.typing import Detection
 
 if TYPE_CHECKING:
@@ -1,8 +1,5 @@
 from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR
-from batdetect2.train.config import (
-    TrainingConfig,
-    load_train_config,
-)
+from batdetect2.train.config import TrainingConfig, load_train_config
 from batdetect2.train.lightning import (
     TrainingModule,
     load_model_from_checkpoint,
@@ -16,5 +13,5 @@ __all__ = [
     "build_trainer",
     "load_model_from_checkpoint",
     "load_train_config",
-    "train",
+    "run_train",
 ]
@@ -8,6 +8,11 @@ from batdetect2.train.checkpoints import CheckpointConfig
 from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
 from batdetect2.train.labels import LabelConfig
 from batdetect2.train.losses import LossConfig
+from batdetect2.train.optimizers import AdamOptimizerConfig, OptimizerConfig
+from batdetect2.train.schedulers import (
+    CosineAnnealingSchedulerConfig,
+    SchedulerConfig,
+)
 
 __all__ = [
     "TrainingConfig",
@@ -36,15 +41,13 @@ class PLTrainerConfig(BaseConfig):
     val_check_interval: int | float | None = None
 
 
-class OptimizerConfig(BaseConfig):
-    learning_rate: float = 1e-3
-    t_max: int = 100
-
-
 class TrainingConfig(BaseConfig):
     train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
     val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
-    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
+    optimizer: OptimizerConfig = Field(default_factory=AdamOptimizerConfig)
+    scheduler: SchedulerConfig = Field(
+        default_factory=CosineAnnealingSchedulerConfig
+    )
     loss: LossConfig = Field(default_factory=LossConfig)
     trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
     logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
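
Note: because ``OptimizerConfig`` and ``SchedulerConfig`` are discriminated unions keyed on ``name``, plain config data selects the concrete class during validation. A sketch swapping in an imported optimizer (``torch.optim.AdamW`` is only an illustration, not a default of this commit):

    from batdetect2.train import TrainingConfig

    config = TrainingConfig.model_validate(
        {
            "optimizer": {
                "name": "import",
                "target": "torch.optim.AdamW",
                "arguments": {"lr": 1e-4},
            },
            "scheduler": {"name": "cosine_annealing", "t_max": 200},
        }
    )
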
@@ -1,11 +1,11 @@
 import lightning as L
 from soundevent.data import PathLike
-from torch.optim.adam import Adam
-from torch.optim.lr_scheduler import CosineAnnealingLR
 
 from batdetect2.models import Model, ModelConfig, build_model
 from batdetect2.train.config import TrainingConfig
-from batdetect2.train.losses import LossFunction, build_loss
+from batdetect2.train.losses import build_loss
+from batdetect2.train.optimizers import build_optimizer
+from batdetect2.train.schedulers import build_scheduler
 from batdetect2.typing import LossProtocol, ModelOutput, TrainExample
 
 __all__ = [
@@ -21,7 +21,7 @@ class TrainingModule(L.LightningModule):
         self,
         model_config: dict | None = None,
         train_config: dict | None = None,
-        loss: LossFunction | None = None,
+        loss: LossProtocol | None = None,
         model: Model | None = None,
     ):
         super().__init__()
@@ -67,10 +67,22 @@ class TrainingModule(L.LightningModule):
         return outputs
 
     def configure_optimizers(self):
-        config = self.train_config.optimizer
-        optimizer = Adam(self.parameters(), lr=config.learning_rate)
-        scheduler = CosineAnnealingLR(optimizer, T_max=config.t_max)
-        return [optimizer], [scheduler]
+        optimizer = build_optimizer(
+            self.parameters(),
+            config=self.train_config.optimizer,
+        )
+        scheduler = build_scheduler(
+            optimizer,
+            config=self.train_config.scheduler,
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {
+                "scheduler": scheduler,
+                "interval": "epoch",
+                "frequency": 1,
+            },
+        }
 
 
 def load_model_from_checkpoint(
@@ -96,7 +108,16 @@ def load_model_from_checkpoint(
 
 
 def build_training_module(
-    model_config: dict | None = None,
-    train_config: dict | None = None,
+    model_config: ModelConfig | None = None,
+    train_config: TrainingConfig | None = None,
 ) -> TrainingModule:
-    return TrainingModule(model_config=model_config, train_config=train_config)
+    if model_config is None:
+        model_config = ModelConfig()
+
+    if train_config is None:
+        train_config = TrainingConfig()
+
+    return TrainingModule(
+        model_config=model_config.model_dump(mode="json"),
+        train_config=train_config.model_dump(mode="json"),
+    )
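
Note: ``configure_optimizers`` now returns the dict format Lightning accepts, instead of the old ``([optimizer], [scheduler])`` pair. A sketch of inspecting it, mirroring the updated test further down (defaults assumed):

    from batdetect2.train.lightning import build_training_module

    module = build_training_module()  # default ModelConfig / TrainingConfig
    opt_config = module.configure_optimizers()

    optimizer = opt_config["optimizer"]  # Adam with lr=1e-3 under the defaults
    scheduler = opt_config["lr_scheduler"]["scheduler"]  # CosineAnnealingLR, T_max=100
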
src/batdetect2/train/optimizers.py (new file, 87 lines)
@@ -0,0 +1,87 @@
+"""Optimizer configuration and factory utilities for training."""
+
+from collections.abc import Iterable
+from typing import Annotated, Literal
+
+from pydantic import Field
+from torch import nn
+from torch.optim import Adam, Optimizer
+
+from batdetect2.core import (
+    BaseConfig,
+    ImportConfig,
+    Registry,
+    add_import_config,
+)
+
+__all__ = [
+    "AdamOptimizerConfig",
+    "OptimizerConfig",
+    "OptimizerImportConfig",
+    "build_optimizer",
+    "optimizer_registry",
+]
+
+
+class AdamOptimizerConfig(BaseConfig):
+    """Configuration for the Adam optimizer.
+
+    Attributes
+    ----------
+    name : Literal["adam"]
+        Discriminator field used by the optimizer registry.
+    learning_rate : float
+        Learning rate used by ``torch.optim.Adam``.
+    """
+
+    name: Literal["adam"] = "adam"
+    learning_rate: float = 1e-3
+
+
+optimizer_registry: Registry[Optimizer, [Iterable[nn.Parameter]]] = Registry(
+    "optimizer"
+)
+
+
+@add_import_config(optimizer_registry)
+class OptimizerImportConfig(ImportConfig):
+    """Use any callable as an optimizer.
+
+    Set ``name="import"`` and provide a ``target`` pointing to any callable
+    that returns an optimizer. The training parameters are passed as the
+    ``params`` keyword argument.
+    """
+
+    name: Literal["import"] = "import"
+
+
+@optimizer_registry.register(AdamOptimizerConfig)
+def build_adam(
+    config: AdamOptimizerConfig,
+    params: Iterable[nn.Parameter],
+) -> Optimizer:
+    """Build an Adam optimizer from configuration."""
+    return Adam(params, lr=config.learning_rate)
+
+
+OptimizerConfig = Annotated[
+    AdamOptimizerConfig | OptimizerImportConfig,
+    Field(discriminator="name"),
+]
+
+
+def build_optimizer(
+    parameters: Iterable[nn.Parameter],
+    config: OptimizerConfig | None = None,
+) -> Optimizer:
+    """Build an optimizer from configuration.
+
+    Parameters
+    ----------
+    parameters : Iterable[nn.Parameter]
+        Model parameters to optimize.
+    config : OptimizerConfig, optional
+        Optimizer configuration. Defaults to ``AdamOptimizerConfig``.
+    """
+    config = config or AdamOptimizerConfig()
+    return optimizer_registry.build(config, params=parameters)
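
Note: the registry is meant to be extensible; a hypothetical sketch of registering a further optimizer (``SGDOptimizerConfig`` and ``build_sgd`` are not part of this commit, and the ``OptimizerConfig`` union would also have to include the new class before it is reachable from ``TrainingConfig``):

    from collections.abc import Iterable
    from typing import Literal

    from torch import nn
    from torch.optim import SGD, Optimizer

    from batdetect2.core import BaseConfig
    from batdetect2.train.optimizers import optimizer_registry


    class SGDOptimizerConfig(BaseConfig):
        name: Literal["sgd"] = "sgd"
        learning_rate: float = 1e-2
        momentum: float = 0.9


    @optimizer_registry.register(SGDOptimizerConfig)
    def build_sgd(
        config: SGDOptimizerConfig,
        params: Iterable[nn.Parameter],
    ) -> Optimizer:
        # Same builder signature as build_adam above: config first, then params.
        return SGD(params, lr=config.learning_rate, momentum=config.momentum)
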
src/batdetect2/train/schedulers.py (new file, 81 lines)
@@ -0,0 +1,81 @@
+"""Scheduler configuration and factory utilities for training."""
+
+from typing import Annotated, Literal
+
+from pydantic import Field
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler
+
+from batdetect2.core import (
+    BaseConfig,
+    ImportConfig,
+    Registry,
+    add_import_config,
+)
+
+__all__ = [
+    "CosineAnnealingSchedulerConfig",
+    "SchedulerConfig",
+    "SchedulerImportConfig",
+    "build_scheduler",
+    "scheduler_registry",
+]
+
+
+class CosineAnnealingSchedulerConfig(BaseConfig):
+    """Configuration for ``CosineAnnealingLR``.
+
+    Attributes
+    ----------
+    name : Literal["cosine_annealing"]
+        Discriminator field used by the scheduler registry.
+    t_max : int
+        Number of epochs to complete one cosine cycle.
+    """
+
+    name: Literal["cosine_annealing"] = "cosine_annealing"
+    t_max: int = 100
+
+
+scheduler_registry: Registry[LRScheduler, [Optimizer]] = Registry("scheduler")
+
+
+@add_import_config(scheduler_registry)
+class SchedulerImportConfig(ImportConfig):
+    """Use any callable as a scheduler.
+
+    Set ``name="import"`` and provide a ``target`` pointing to any callable
+    that returns a scheduler. The optimizer instance is passed as the
+    ``optimizer`` keyword argument.
+    """
+
+    name: Literal["import"] = "import"
+
+
+@scheduler_registry.register(CosineAnnealingSchedulerConfig)
+def build_cosine_scheduler(
+    config: CosineAnnealingSchedulerConfig,
+    optimizer: Optimizer,
+) -> LRScheduler:
+    """Build a cosine annealing scheduler.
+
+    ``t_max`` is interpreted in epochs because Lightning steps the scheduler
+    once per epoch when ``interval="epoch"`` is used.
+    """
+    return CosineAnnealingLR(optimizer, T_max=config.t_max)
+
+
+SchedulerConfig = Annotated[
+    CosineAnnealingSchedulerConfig | SchedulerImportConfig,
+    Field(discriminator="name"),
+]
+
+
+def build_scheduler(
+    optimizer: Optimizer,
+    config: SchedulerConfig | None = None,
+) -> LRScheduler:
+    """Build a scheduler from configuration."""
+    config = config or CosineAnnealingSchedulerConfig()
+
+    return scheduler_registry.build(config, optimizer=optimizer)
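
Note: ``SchedulerConfig`` is an annotated union rather than a model, so a plain mapping can be validated with pydantic's ``TypeAdapter``; a small sketch:

    from pydantic import TypeAdapter
    from torch import nn
    from torch.optim import SGD

    from batdetect2.train.schedulers import SchedulerConfig, build_scheduler

    # The "name" discriminator picks CosineAnnealingSchedulerConfig here.
    config = TypeAdapter(SchedulerConfig).validate_python(
        {"name": "cosine_annealing", "t_max": 50}
    )

    optimizer = SGD(nn.Linear(4, 2).parameters(), lr=1e-3)
    scheduler = build_scheduler(optimizer, config=config)
    assert scheduler.T_max == 50
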
@@ -6,11 +6,13 @@ from lightning import Trainer, seed_everything
 from loguru import logger
 from soundevent import data
 
-from batdetect2.audio import build_audio_loader
+from batdetect2.audio import AudioConfig, build_audio_loader
 from batdetect2.evaluate import build_evaluator
 from batdetect2.logging import build_logger
+from batdetect2.models import ModelConfig
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
+from batdetect2.train import TrainingConfig
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.checkpoints import build_checkpoint_callback
 from batdetect2.train.dataset import build_train_loader, build_val_loader
@@ -18,7 +20,6 @@ from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import build_training_module
 
 if TYPE_CHECKING:
-    from batdetect2.config import BatDetect2Config
     from batdetect2.typing import (
         AudioLoader,
         ClipLabeller,
@@ -29,7 +30,7 @@ if TYPE_CHECKING:
 
 __all__ = [
     "build_trainer",
-    "train",
+    "run_train",
 ]
 
 
@@ -40,7 +41,9 @@ def run_train(
     preprocessor: Optional["PreprocessorProtocol"] = None,
     audio_loader: Optional["AudioLoader"] = None,
     labeller: Optional["ClipLabeller"] = None,
-    config: Optional["BatDetect2Config"] = None,
+    audio_config: Optional[AudioConfig] = None,
+    model_config: Optional[ModelConfig] = None,
+    train_config: Optional[TrainingConfig] = None,
     trainer: Trainer | None = None,
     train_workers: int | None = None,
     val_workers: int | None = None,
@@ -51,27 +54,27 @@ def run_train(
     run_name: str | None = None,
     seed: int | None = None,
 ):
-    from batdetect2.config import BatDetect2Config
-
     if seed is not None:
         seed_everything(seed)
 
-    config = config or BatDetect2Config()
+    model_config = model_config or ModelConfig()
+    audio_config = audio_config or AudioConfig()
+    train_config = train_config or TrainingConfig()
 
-    targets = targets or build_targets(config=config.model.targets)
+    targets = targets or build_targets(config=model_config.targets)
 
-    audio_loader = audio_loader or build_audio_loader(config=config.audio)
+    audio_loader = audio_loader or build_audio_loader(config=audio_config)
 
     preprocessor = preprocessor or build_preprocessor(
         input_samplerate=audio_loader.samplerate,
-        config=config.model.preprocess,
+        config=model_config.preprocess,
     )
 
     labeller = labeller or build_clip_labeler(
         targets,
         min_freq=preprocessor.min_freq,
         max_freq=preprocessor.max_freq,
-        config=config.train.labels,
+        config=train_config.labels,
     )
 
     train_dataloader = build_train_loader(
@@ -79,7 +82,7 @@ def run_train(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train.train_loader,
+        config=train_config.train_loader,
         num_workers=train_workers,
     )
 
@@ -89,26 +92,22 @@ def run_train(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train.val_loader,
+        config=train_config.val_loader,
         num_workers=val_workers,
     )
         if val_annotations is not None
         else None
     )
 
-    train_config_dict = config.train.model_dump(mode="json")
-    if "optimizer" in train_config_dict:
-        train_config_dict["optimizer"]["t_max"] *= len(train_dataloader)
-
     module = build_training_module(
-        model_config=config.model.model_dump(mode="json"),
-        train_config=train_config_dict,
+        model_config=model_config,
+        train_config=train_config,
    )
 
     trainer = trainer or build_trainer(
-        config,
+        train_config,
         evaluator=build_evaluator(
-            config.train.validation,
+            train_config.validation,
             targets=targets,
         ),
         checkpoint_dir=checkpoint_dir,
@@ -130,7 +129,7 @@ def run_train(
 
 
 def build_trainer(
-    config: "BatDetect2Config",
+    config: TrainingConfig,
     evaluator: "EvaluatorProtocol",
     checkpoint_dir: Path | None = None,
     log_dir: Path | None = None,
@@ -138,14 +137,14 @@ def build_trainer(
     run_name: str | None = None,
     num_epochs: int | None = None,
 ) -> Trainer:
-    trainer_conf = config.train.trainer
+    trainer_conf = config.trainer
     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )
 
     train_logger = build_logger(
-        config.train.logger,
+        config.logger,
         log_dir=log_dir,
         experiment_name=experiment_name,
         run_name=run_name,
@@ -168,7 +167,7 @@ def build_trainer(
         logger=train_logger,
         callbacks=[
             build_checkpoint_callback(
-                config=config.train.checkpoints,
+                config=config.checkpoints,
                 checkpoint_dir=checkpoint_dir,
                 experiment_name=experiment_name,
                 run_name=run_name,
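
Note: ``run_train`` no longer takes a single ``BatDetect2Config``; a call-shape sketch matching the updated smoke test below (the annotation lists are placeholders that would come from your dataset):

    from batdetect2.config import BatDetect2Config
    from batdetect2.train import run_train

    config = BatDetect2Config()
    run_train(
        train_annotations=train_annotations,  # placeholder: annotated clips
        val_annotations=val_annotations,      # placeholder
        audio_config=config.audio,
        model_config=config.model,
        train_config=config.train,
        num_epochs=1,
    )
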
@@ -2,15 +2,22 @@ from pathlib import Path
 
 import lightning as L
 import torch
+from deepdiff import DeepDiff
 from soundevent import data
+from torch.optim import Adam
+from torch.optim.lr_scheduler import CosineAnnealingLR
 
 from batdetect2.api_v2 import BatDetect2API
 from batdetect2.config import BatDetect2Config
+from batdetect2.models import ModelConfig
 from batdetect2.train import (
+    TrainingConfig,
     TrainingModule,
     load_model_from_checkpoint,
     run_train,
 )
+from batdetect2.train.optimizers import AdamOptimizerConfig
+from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
 from batdetect2.train.train import build_training_module
 from batdetect2.typing.preprocess import AudioLoader
 
@@ -18,8 +25,8 @@ from batdetect2.typing.preprocess import AudioLoader
 def build_default_module(config: BatDetect2Config | None = None):
     config = config or BatDetect2Config()
     return build_training_module(
-        model_config=config.model.model_dump(mode="json"),
-        train_config=config.train.model_dump(mode="json"),
+        model_config=config.model,
+        train_config=config.train,
     )
 
 
@@ -57,53 +64,105 @@ def test_can_save_checkpoint(
 def test_load_model_from_checkpoint_returns_model_and_config(
     tmp_path: Path,
 ):
-    module = build_default_module()
+    input_model_config = ModelConfig(samplerate=192_000)
+    expected_model_config = ModelConfig.model_validate(
+        input_model_config.model_dump(mode="json")
+    )
+    train_config = TrainingConfig()
+    module = build_training_module(
+        model_config=input_model_config,
+        train_config=train_config,
+    )
     trainer = L.Trainer()
     path = tmp_path / "example.ckpt"
     trainer.strategy.connect(module)
     trainer.save_checkpoint(path)
 
-    model, model_config = load_model_from_checkpoint(path)
+    model, loaded_model_config = load_model_from_checkpoint(path)
 
     assert model is not None
-    assert model_config.model_dump(
+    assert loaded_model_config.model_dump(
         mode="json"
-    ) == module.model_config.model_dump(mode="json")
+    ) == expected_model_config.model_dump(mode="json")
+
+    recovered = TrainingModule.load_from_checkpoint(path)
+    assert recovered.train_config.model_dump(
+        mode="json"
+    ) == train_config.model_dump(mode="json")
 
 
 def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
-    config = BatDetect2Config()
-    config.train.optimizer.learning_rate = 7e-4
-    config.train.optimizer.t_max = 123
+    model_config = ModelConfig(samplerate=384_000)
+    expected_model_config = ModelConfig.model_validate(
+        model_config.model_dump(mode="json")
+    )
+    train_config = TrainingConfig()
+    train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
+    train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123)
+    train_config.trainer.max_epochs = 3
+    train_config.train_loader.batch_size = 2
 
-    module = build_default_module(config=config)
+    module = build_training_module(
+        model_config=model_config,
+        train_config=train_config,
+    )
     trainer = L.Trainer()
     path = tmp_path / "example.ckpt"
     trainer.strategy.connect(module)
     trainer.save_checkpoint(path)
 
-    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
-    hyper_parameters = checkpoint["hyper_parameters"]
-    assert (
-        hyper_parameters["train_config"]["optimizer"]["learning_rate"] == 7e-4
+    recovered = TrainingModule.load_from_checkpoint(path)
+    assert not DeepDiff(
+        recovered.model_config.model_dump(mode="json"),
+        expected_model_config.model_dump(mode="json"),
+    )
+    assert not DeepDiff(
+        recovered.train_config.model_dump(mode="json"),
+        train_config.model_dump(mode="json"),
     )
-    assert hyper_parameters["train_config"]["optimizer"]["t_max"] == 123
-    assert "learning_rate" not in hyper_parameters
-    assert "t_max" not in hyper_parameters
 
 
-def test_configure_optimizers_uses_train_config_values():
-    config = BatDetect2Config()
-    config.train.optimizer.learning_rate = 5e-4
-    config.train.optimizer.t_max = 321
+def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
+    model_config = ModelConfig()
+    expected_model_config = ModelConfig.model_validate(
+        model_config.model_dump(mode="json")
+    )
+    train_config = TrainingConfig()
+    train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
+    train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321)
 
-    module = build_default_module(config=config)
+    module = build_training_module(
+        model_config=model_config,
+        train_config=train_config,
+    )
 
-    optimizers, schedulers = module.configure_optimizers()
+    optimization_config = module.configure_optimizers()
+    optimizer = optimization_config["optimizer"]
+    scheduler = optimization_config["lr_scheduler"]["scheduler"]
 
-    assert optimizers[0].param_groups[0]["lr"] == 5e-4
-    assert schedulers[0].T_max == 321
+    assert isinstance(optimizer, Adam)
+    assert isinstance(scheduler, CosineAnnealingLR)
+    assert optimizer.param_groups[0]["lr"] == 5e-4
+    assert scheduler.T_max == 321
+
+    trainer = L.Trainer()
+    path = tmp_path / "example.ckpt"
+    trainer.strategy.connect(module)
+    trainer.save_checkpoint(path)
+
+    recovered = TrainingModule.load_from_checkpoint(path)
+    assert recovered.model_config.model_dump(
+        mode="json"
+    ) == expected_model_config.model_dump(mode="json")
+    assert recovered.train_config.model_dump(
+        mode="json"
+    ) == train_config.model_dump(mode="json")
+
+    loaded_optimization_config = recovered.configure_optimizers()
+    loaded_optimizer = loaded_optimization_config["optimizer"]
+    loaded_scheduler = loaded_optimization_config["lr_scheduler"]["scheduler"]
+    assert loaded_optimizer.param_groups[0]["lr"] == 5e-4
+    assert loaded_scheduler.T_max == 321
 
 
 def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
@@ -136,7 +195,9 @@ def test_train_smoke_produces_loadable_checkpoint(
     run_train(
         train_annotations=example_annotations[:1],
        val_annotations=example_annotations[:1],
-        config=config,
+        train_config=config.train,
+        model_config=config.model,
+        audio_config=config.audio,
         num_epochs=1,
         train_workers=0,
         val_workers=0,
tests/test_train/test_optimizers.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+from torch import nn
+from torch.optim import SGD, Adam
+
+from batdetect2.train.optimizers import OptimizerImportConfig, build_optimizer
+
+
+def test_build_optimizer_defaults_to_adam():
+    model = nn.Linear(4, 2)
+    optimizer = build_optimizer(model.parameters())
+
+    assert isinstance(optimizer, Adam)
+
+
+def test_build_optimizer_supports_import_config():
+    model = nn.Linear(4, 2)
+    config = OptimizerImportConfig(
+        target="torch.optim.SGD",
+        arguments={"lr": 1e-3},
+    )
+
+    optimizer = build_optimizer(model.parameters(), config=config)
+    assert isinstance(optimizer, SGD)
tests/test_train/test_schedulers.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+from torch import nn
+from torch.optim import SGD
+from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
+
+from batdetect2.train.schedulers import (
+    CosineAnnealingSchedulerConfig,
+    SchedulerImportConfig,
+    build_scheduler,
+)
+
+
+def test_build_scheduler_uses_epoch_t_max_directly():
+    model = nn.Linear(4, 2)
+    optimizer = SGD(model.parameters(), lr=1e-3)
+    scheduler = build_scheduler(
+        optimizer,
+        config=CosineAnnealingSchedulerConfig(t_max=7),
+    )
+
+    assert isinstance(scheduler, CosineAnnealingLR)
+    assert scheduler.T_max == 7
+
+
+def test_build_scheduler_supports_import_config():
+    model = nn.Linear(4, 2)
+    optimizer = SGD(model.parameters(), lr=1e-3)
+    scheduler = build_scheduler(
+        optimizer,
+        config=SchedulerImportConfig(
+            target="torch.optim.lr_scheduler.StepLR",
+            arguments={"step_size": 2},
+        ),
+    )
+
+    assert isinstance(scheduler, StepLR)