diff --git a/example_data/config.yaml b/example_data/config.yaml index 42a35d6..4a274e5 100644 --- a/example_data/config.yaml +++ b/example_data/config.yaml @@ -64,7 +64,11 @@ model: train: optimizer: + name: adam learning_rate: 0.001 + + scheduler: + name: cosine_annealing t_max: 100 labels: @@ -76,9 +80,7 @@ train: train_loader: batch_size: 8 - num_workers: 2 - shuffle: True clipping_strategy: diff --git a/pyproject.toml b/pyproject.toml index 42ac02a..60b3ae4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,6 +86,7 @@ dev = [ "rust-just>=1.40.0", "pandas-stubs>=2.2.2.240807", "python-lsp-server>=1.13.0", + "deepdiff>=8.6.1", ] dvclive = ["dvclive>=3.48.2"] mlflow = ["mlflow>=3.1.1"] diff --git a/src/batdetect2/core/configs.py b/src/batdetect2/core/configs.py index 5a3a524..118d366 100644 --- a/src/batdetect2/core/configs.py +++ b/src/batdetect2/core/configs.py @@ -8,6 +8,7 @@ configuration data from files, with optional support for accessing nested configuration sections. """ +import sys from typing import Any, Type, TypeVar, overload import yaml @@ -15,6 +16,11 @@ from deepmerge.merger import Merger from pydantic import BaseModel, ConfigDict, TypeAdapter from soundevent.data import PathLike +if sys.version_info < (3, 11): + from typing_extensions import Self +else: + from typing import Self + __all__ = [ "BaseConfig", "load_config", @@ -66,6 +72,10 @@ class BaseConfig(BaseModel): def from_yaml(cls, yaml_str: str): return cls.model_validate(yaml.safe_load(yaml_str)) + @classmethod + def load(cls: Self, path: PathLike, field: str | None = None) -> Self: + return load_config(path, schema=cls, field=field) # type: ignore + T = TypeVar("T") T_Model = TypeVar("T_Model", bound=BaseModel) diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 223e135..b700d7d 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -10,8 +10,6 @@ from batdetect2.evaluate.evaluator import build_evaluator from batdetect2.evaluate.lightning import EvaluationModule from batdetect2.logging import build_logger from batdetect2.models import Model -from batdetect2.preprocess import build_preprocessor -from batdetect2.targets import build_targets from batdetect2.typing import Detection if TYPE_CHECKING: diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py index 1f2e9ec..8590bac 100644 --- a/src/batdetect2/train/__init__.py +++ b/src/batdetect2/train/__init__.py @@ -1,8 +1,5 @@ from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR -from batdetect2.train.config import ( - TrainingConfig, - load_train_config, -) +from batdetect2.train.config import TrainingConfig, load_train_config from batdetect2.train.lightning import ( TrainingModule, load_model_from_checkpoint, @@ -16,5 +13,5 @@ __all__ = [ "build_trainer", "load_model_from_checkpoint", "load_train_config", - "train", + "run_train", ] diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index 083a63e..40e998b 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -8,6 +8,11 @@ from batdetect2.train.checkpoints import CheckpointConfig from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig from batdetect2.train.labels import LabelConfig from batdetect2.train.losses import LossConfig +from batdetect2.train.optimizers import AdamOptimizerConfig, OptimizerConfig +from batdetect2.train.schedulers import ( + CosineAnnealingSchedulerConfig, + SchedulerConfig, +) __all__ = [ "TrainingConfig", @@ -36,15 +41,13 @@ class PLTrainerConfig(BaseConfig): val_check_interval: int | float | None = None -class OptimizerConfig(BaseConfig): - learning_rate: float = 1e-3 - t_max: int = 100 - - class TrainingConfig(BaseConfig): train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig) val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig) - optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) + optimizer: OptimizerConfig = Field(default_factory=AdamOptimizerConfig) + scheduler: SchedulerConfig = Field( + default_factory=CosineAnnealingSchedulerConfig + ) loss: LossConfig = Field(default_factory=LossConfig) trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig) logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig) diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index a7fde38..76e20c4 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -1,11 +1,11 @@ import lightning as L from soundevent.data import PathLike -from torch.optim.adam import Adam -from torch.optim.lr_scheduler import CosineAnnealingLR from batdetect2.models import Model, ModelConfig, build_model from batdetect2.train.config import TrainingConfig -from batdetect2.train.losses import LossFunction, build_loss +from batdetect2.train.losses import build_loss +from batdetect2.train.optimizers import build_optimizer +from batdetect2.train.schedulers import build_scheduler from batdetect2.typing import LossProtocol, ModelOutput, TrainExample __all__ = [ @@ -21,7 +21,7 @@ class TrainingModule(L.LightningModule): self, model_config: dict | None = None, train_config: dict | None = None, - loss: LossFunction | None = None, + loss: LossProtocol | None = None, model: Model | None = None, ): super().__init__() @@ -67,10 +67,22 @@ class TrainingModule(L.LightningModule): return outputs def configure_optimizers(self): - config = self.train_config.optimizer - optimizer = Adam(self.parameters(), lr=config.learning_rate) - scheduler = CosineAnnealingLR(optimizer, T_max=config.t_max) - return [optimizer], [scheduler] + optimizer = build_optimizer( + self.parameters(), + config=self.train_config.optimizer, + ) + scheduler = build_scheduler( + optimizer, + config=self.train_config.scheduler, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "epoch", + "frequency": 1, + }, + } def load_model_from_checkpoint( @@ -96,7 +108,16 @@ def load_model_from_checkpoint( def build_training_module( - model_config: dict | None = None, - train_config: dict | None = None, + model_config: ModelConfig | None = None, + train_config: TrainingConfig | None = None, ) -> TrainingModule: - return TrainingModule(model_config=model_config, train_config=train_config) + if model_config is None: + model_config = ModelConfig() + + if train_config is None: + train_config = TrainingConfig() + + return TrainingModule( + model_config=model_config.model_dump(mode="json"), + train_config=train_config.model_dump(mode="json"), + ) diff --git a/src/batdetect2/train/optimizers.py b/src/batdetect2/train/optimizers.py new file mode 100644 index 0000000..54543b7 --- /dev/null +++ b/src/batdetect2/train/optimizers.py @@ -0,0 +1,87 @@ +"""Optimizer configuration and factory utilities for training.""" + +from collections.abc import Iterable +from typing import Annotated, Literal + +from pydantic import Field +from torch import nn +from torch.optim import Adam, Optimizer + +from batdetect2.core import ( + BaseConfig, + ImportConfig, + Registry, + add_import_config, +) + +__all__ = [ + "AdamOptimizerConfig", + "OptimizerConfig", + "OptimizerImportConfig", + "build_optimizer", + "optimizer_registry", +] + + +class AdamOptimizerConfig(BaseConfig): + """Configuration for the Adam optimizer. + + Attributes + ---------- + name : Literal["adam"] + Discriminator field used by the optimizer registry. + learning_rate : float + Learning rate used by ``torch.optim.Adam``. + """ + + name: Literal["adam"] = "adam" + learning_rate: float = 1e-3 + + +optimizer_registry: Registry[Optimizer, [Iterable[nn.Parameter]]] = Registry( + "optimizer" +) + + +@add_import_config(optimizer_registry) +class OptimizerImportConfig(ImportConfig): + """Use any callable as an optimizer. + + Set ``name="import"`` and provide a ``target`` pointing to any callable + that returns an optimizer. The training parameters are passed as the + ``params`` keyword argument. + """ + + name: Literal["import"] = "import" + + +@optimizer_registry.register(AdamOptimizerConfig) +def build_adam( + config: AdamOptimizerConfig, + params: Iterable[nn.Parameter], +) -> Optimizer: + """Build an Adam optimizer from configuration.""" + return Adam(params, lr=config.learning_rate) + + +OptimizerConfig = Annotated[ + AdamOptimizerConfig | OptimizerImportConfig, + Field(discriminator="name"), +] + + +def build_optimizer( + parameters: Iterable[nn.Parameter], + config: OptimizerConfig | None = None, +) -> Optimizer: + """Build an optimizer from configuration. + + Parameters + ---------- + parameters : Iterable[nn.Parameter] + Model parameters to optimize. + config : OptimizerConfig, optional + Optimizer configuration. Defaults to ``AdamOptimizerConfig``. + """ + config = config or AdamOptimizerConfig() + return optimizer_registry.build(config, params=parameters) diff --git a/src/batdetect2/train/schedulers.py b/src/batdetect2/train/schedulers.py new file mode 100644 index 0000000..ae1c742 --- /dev/null +++ b/src/batdetect2/train/schedulers.py @@ -0,0 +1,81 @@ +"""Scheduler configuration and factory utilities for training.""" + +from typing import Annotated, Literal + +from pydantic import Field +from torch.optim import Optimizer +from torch.optim.lr_scheduler import CosineAnnealingLR, LRScheduler + +from batdetect2.core import ( + BaseConfig, + ImportConfig, + Registry, + add_import_config, +) + +__all__ = [ + "CosineAnnealingSchedulerConfig", + "SchedulerConfig", + "SchedulerImportConfig", + "build_scheduler", + "scheduler_registry", +] + + +class CosineAnnealingSchedulerConfig(BaseConfig): + """Configuration for ``CosineAnnealingLR``. + + Attributes + ---------- + name : Literal["cosine_annealing"] + Discriminator field used by the scheduler registry. + t_max : int + Number of epochs to complete one cosine cycle. + """ + + name: Literal["cosine_annealing"] = "cosine_annealing" + t_max: int = 100 + + +scheduler_registry: Registry[LRScheduler, [Optimizer]] = Registry("scheduler") + + +@add_import_config(scheduler_registry) +class SchedulerImportConfig(ImportConfig): + """Use any callable as a scheduler. + + Set ``name="import"`` and provide a ``target`` pointing to any callable + that returns a scheduler. The optimizer instance is passed as the + ``optimizer`` keyword argument. + """ + + name: Literal["import"] = "import" + + +@scheduler_registry.register(CosineAnnealingSchedulerConfig) +def build_cosine_scheduler( + config: CosineAnnealingSchedulerConfig, + optimizer: Optimizer, +) -> LRScheduler: + """Build a cosine annealing scheduler. + + ``t_max`` is interpreted in epochs because Lightning steps the scheduler + once per epoch when ``interval="epoch"`` is used. + """ + return CosineAnnealingLR(optimizer, T_max=config.t_max) + + +SchedulerConfig = Annotated[ + CosineAnnealingSchedulerConfig | SchedulerImportConfig, + Field(discriminator="name"), +] + + +def build_scheduler( + optimizer: Optimizer, + config: SchedulerConfig | None = None, +) -> LRScheduler: + """Build a scheduler from configuration.""" + config = config or CosineAnnealingSchedulerConfig() + + return scheduler_registry.build(config, optimizer=optimizer) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index d11abf2..bec3cf1 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -6,11 +6,13 @@ from lightning import Trainer, seed_everything from loguru import logger from soundevent import data -from batdetect2.audio import build_audio_loader +from batdetect2.audio import AudioConfig, build_audio_loader from batdetect2.evaluate import build_evaluator from batdetect2.logging import build_logger +from batdetect2.models import ModelConfig from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets +from batdetect2.train import TrainingConfig from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.checkpoints import build_checkpoint_callback from batdetect2.train.dataset import build_train_loader, build_val_loader @@ -18,7 +20,6 @@ from batdetect2.train.labels import build_clip_labeler from batdetect2.train.lightning import build_training_module if TYPE_CHECKING: - from batdetect2.config import BatDetect2Config from batdetect2.typing import ( AudioLoader, ClipLabeller, @@ -29,7 +30,7 @@ if TYPE_CHECKING: __all__ = [ "build_trainer", - "train", + "run_train", ] @@ -40,7 +41,9 @@ def run_train( preprocessor: Optional["PreprocessorProtocol"] = None, audio_loader: Optional["AudioLoader"] = None, labeller: Optional["ClipLabeller"] = None, - config: Optional["BatDetect2Config"] = None, + audio_config: Optional[AudioConfig] = None, + model_config: Optional[ModelConfig] = None, + train_config: Optional[TrainingConfig] = None, trainer: Trainer | None = None, train_workers: int | None = None, val_workers: int | None = None, @@ -51,27 +54,27 @@ def run_train( run_name: str | None = None, seed: int | None = None, ): - from batdetect2.config import BatDetect2Config - if seed is not None: seed_everything(seed) - config = config or BatDetect2Config() + model_config = model_config or ModelConfig() + audio_config = audio_config or AudioConfig() + train_config = train_config or TrainingConfig() - targets = targets or build_targets(config=config.model.targets) + targets = targets or build_targets(config=model_config.targets) - audio_loader = audio_loader or build_audio_loader(config=config.audio) + audio_loader = audio_loader or build_audio_loader(config=audio_config) preprocessor = preprocessor or build_preprocessor( input_samplerate=audio_loader.samplerate, - config=config.model.preprocess, + config=model_config.preprocess, ) labeller = labeller or build_clip_labeler( targets, min_freq=preprocessor.min_freq, max_freq=preprocessor.max_freq, - config=config.train.labels, + config=train_config.labels, ) train_dataloader = build_train_loader( @@ -79,7 +82,7 @@ def run_train( audio_loader=audio_loader, labeller=labeller, preprocessor=preprocessor, - config=config.train.train_loader, + config=train_config.train_loader, num_workers=train_workers, ) @@ -89,26 +92,22 @@ def run_train( audio_loader=audio_loader, labeller=labeller, preprocessor=preprocessor, - config=config.train.val_loader, + config=train_config.val_loader, num_workers=val_workers, ) if val_annotations is not None else None ) - train_config_dict = config.train.model_dump(mode="json") - if "optimizer" in train_config_dict: - train_config_dict["optimizer"]["t_max"] *= len(train_dataloader) - module = build_training_module( - model_config=config.model.model_dump(mode="json"), - train_config=train_config_dict, + model_config=model_config, + train_config=train_config, ) trainer = trainer or build_trainer( - config, + train_config, evaluator=build_evaluator( - config.train.validation, + train_config.validation, targets=targets, ), checkpoint_dir=checkpoint_dir, @@ -130,7 +129,7 @@ def run_train( def build_trainer( - config: "BatDetect2Config", + config: TrainingConfig, evaluator: "EvaluatorProtocol", checkpoint_dir: Path | None = None, log_dir: Path | None = None, @@ -138,14 +137,14 @@ def build_trainer( run_name: str | None = None, num_epochs: int | None = None, ) -> Trainer: - trainer_conf = config.train.trainer + trainer_conf = config.trainer logger.opt(lazy=True).debug( "Building trainer with config: \n{config}", config=lambda: trainer_conf.to_yaml_string(exclude_none=True), ) train_logger = build_logger( - config.train.logger, + config.logger, log_dir=log_dir, experiment_name=experiment_name, run_name=run_name, @@ -168,7 +167,7 @@ def build_trainer( logger=train_logger, callbacks=[ build_checkpoint_callback( - config=config.train.checkpoints, + config=config.checkpoints, checkpoint_dir=checkpoint_dir, experiment_name=experiment_name, run_name=run_name, diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index ce4f537..e6aabc3 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -2,15 +2,22 @@ from pathlib import Path import lightning as L import torch +from deepdiff import DeepDiff from soundevent import data +from torch.optim import Adam +from torch.optim.lr_scheduler import CosineAnnealingLR from batdetect2.api_v2 import BatDetect2API from batdetect2.config import BatDetect2Config +from batdetect2.models import ModelConfig from batdetect2.train import ( + TrainingConfig, TrainingModule, load_model_from_checkpoint, run_train, ) +from batdetect2.train.optimizers import AdamOptimizerConfig +from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig from batdetect2.train.train import build_training_module from batdetect2.typing.preprocess import AudioLoader @@ -18,8 +25,8 @@ from batdetect2.typing.preprocess import AudioLoader def build_default_module(config: BatDetect2Config | None = None): config = config or BatDetect2Config() return build_training_module( - model_config=config.model.model_dump(mode="json"), - train_config=config.train.model_dump(mode="json"), + model_config=config.model, + train_config=config.train, ) @@ -57,53 +64,105 @@ def test_can_save_checkpoint( def test_load_model_from_checkpoint_returns_model_and_config( tmp_path: Path, ): - module = build_default_module() + input_model_config = ModelConfig(samplerate=192_000) + expected_model_config = ModelConfig.model_validate( + input_model_config.model_dump(mode="json") + ) + train_config = TrainingConfig() + module = build_training_module( + model_config=input_model_config, + train_config=train_config, + ) trainer = L.Trainer() path = tmp_path / "example.ckpt" trainer.strategy.connect(module) trainer.save_checkpoint(path) - model, model_config = load_model_from_checkpoint(path) + model, loaded_model_config = load_model_from_checkpoint(path) assert model is not None - assert model_config.model_dump( + assert loaded_model_config.model_dump( mode="json" - ) == module.model_config.model_dump(mode="json") + ) == expected_model_config.model_dump(mode="json") + + recovered = TrainingModule.load_from_checkpoint(path) + assert recovered.train_config.model_dump( + mode="json" + ) == train_config.model_dump(mode="json") def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path): - config = BatDetect2Config() - config.train.optimizer.learning_rate = 7e-4 - config.train.optimizer.t_max = 123 + model_config = ModelConfig(samplerate=384_000) + expected_model_config = ModelConfig.model_validate( + model_config.model_dump(mode="json") + ) + train_config = TrainingConfig() + train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4) + train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123) + train_config.trainer.max_epochs = 3 + train_config.train_loader.batch_size = 2 - module = build_default_module(config=config) + module = build_training_module( + model_config=model_config, + train_config=train_config, + ) trainer = L.Trainer() path = tmp_path / "example.ckpt" trainer.strategy.connect(module) trainer.save_checkpoint(path) - checkpoint = torch.load(path, map_location="cpu", weights_only=False) - hyper_parameters = checkpoint["hyper_parameters"] - - assert ( - hyper_parameters["train_config"]["optimizer"]["learning_rate"] == 7e-4 + recovered = TrainingModule.load_from_checkpoint(path) + assert not DeepDiff( + recovered.model_config.model_dump(mode="json"), + expected_model_config.model_dump(mode="json"), + ) + assert not DeepDiff( + recovered.train_config.model_dump(mode="json"), + train_config.model_dump(mode="json"), ) - assert hyper_parameters["train_config"]["optimizer"]["t_max"] == 123 - assert "learning_rate" not in hyper_parameters - assert "t_max" not in hyper_parameters -def test_configure_optimizers_uses_train_config_values(): - config = BatDetect2Config() - config.train.optimizer.learning_rate = 5e-4 - config.train.optimizer.t_max = 321 +def test_configure_optimizers_uses_train_config_values(tmp_path: Path): + model_config = ModelConfig() + expected_model_config = ModelConfig.model_validate( + model_config.model_dump(mode="json") + ) + train_config = TrainingConfig() + train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4) + train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321) - module = build_default_module(config=config) + module = build_training_module( + model_config=model_config, + train_config=train_config, + ) - optimizers, schedulers = module.configure_optimizers() + optimization_config = module.configure_optimizers() + optimizer = optimization_config["optimizer"] + scheduler = optimization_config["lr_scheduler"]["scheduler"] - assert optimizers[0].param_groups[0]["lr"] == 5e-4 - assert schedulers[0].T_max == 321 + assert isinstance(optimizer, Adam) + assert isinstance(scheduler, CosineAnnealingLR) + assert optimizer.param_groups[0]["lr"] == 5e-4 + assert scheduler.T_max == 321 + + trainer = L.Trainer() + path = tmp_path / "example.ckpt" + trainer.strategy.connect(module) + trainer.save_checkpoint(path) + + recovered = TrainingModule.load_from_checkpoint(path) + assert recovered.model_config.model_dump( + mode="json" + ) == expected_model_config.model_dump(mode="json") + assert recovered.train_config.model_dump( + mode="json" + ) == train_config.model_dump(mode="json") + + loaded_optimization_config = recovered.configure_optimizers() + loaded_optimizer = loaded_optimization_config["optimizer"] + loaded_scheduler = loaded_optimization_config["lr_scheduler"]["scheduler"] + assert loaded_optimizer.param_groups[0]["lr"] == 5e-4 + assert loaded_scheduler.T_max == 321 def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path): @@ -136,7 +195,9 @@ def test_train_smoke_produces_loadable_checkpoint( run_train( train_annotations=example_annotations[:1], val_annotations=example_annotations[:1], - config=config, + train_config=config.train, + model_config=config.model, + audio_config=config.audio, num_epochs=1, train_workers=0, val_workers=0, diff --git a/tests/test_train/test_optimizers.py b/tests/test_train/test_optimizers.py new file mode 100644 index 0000000..89c84bc --- /dev/null +++ b/tests/test_train/test_optimizers.py @@ -0,0 +1,22 @@ +from torch import nn +from torch.optim import SGD, Adam + +from batdetect2.train.optimizers import OptimizerImportConfig, build_optimizer + + +def test_build_optimizer_defaults_to_adam(): + model = nn.Linear(4, 2) + optimizer = build_optimizer(model.parameters()) + + assert isinstance(optimizer, Adam) + + +def test_build_optimizer_supports_import_config(): + model = nn.Linear(4, 2) + config = OptimizerImportConfig( + target="torch.optim.SGD", + arguments={"lr": 1e-3}, + ) + + optimizer = build_optimizer(model.parameters(), config=config) + assert isinstance(optimizer, SGD) diff --git a/tests/test_train/test_schedulers.py b/tests/test_train/test_schedulers.py new file mode 100644 index 0000000..a489b54 --- /dev/null +++ b/tests/test_train/test_schedulers.py @@ -0,0 +1,35 @@ +from torch import nn +from torch.optim import SGD +from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR + +from batdetect2.train.schedulers import ( + CosineAnnealingSchedulerConfig, + SchedulerImportConfig, + build_scheduler, +) + + +def test_build_scheduler_uses_epoch_t_max_directly(): + model = nn.Linear(4, 2) + optimizer = SGD(model.parameters(), lr=1e-3) + scheduler = build_scheduler( + optimizer, + config=CosineAnnealingSchedulerConfig(t_max=7), + ) + + assert isinstance(scheduler, CosineAnnealingLR) + assert scheduler.T_max == 7 + + +def test_build_scheduler_supports_import_config(): + model = nn.Linear(4, 2) + optimizer = SGD(model.parameters(), lr=1e-3) + scheduler = build_scheduler( + optimizer, + config=SchedulerImportConfig( + target="torch.optim.lr_scheduler.StepLR", + arguments={"step_size": 2}, + ), + ) + + assert isinstance(scheduler, StepLR)