diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py
index 9f377a1..d3c1e95 100644
--- a/src/batdetect2/api_v2.py
+++ b/src/batdetect2/api_v2.py
@@ -28,7 +28,7 @@ from batdetect2.targets import build_targets
 from batdetect2.train import (
     DEFAULT_CHECKPOINT_DIR,
     load_model_from_checkpoint,
-    train,
+    run_train,
 )
 from batdetect2.typing import (
     AudioLoader,
@@ -84,7 +84,7 @@ class BatDetect2API:
         run_name: str | None = None,
         seed: int | None = None,
     ):
-        train(
+        run_train(
             train_annotations=train_annotations,
             val_annotations=val_annotations,
             targets=self.targets,
diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py
index 3a44092..1f2e9ec 100644
--- a/src/batdetect2/train/__init__.py
+++ b/src/batdetect2/train/__init__.py
@@ -7,7 +7,7 @@ from batdetect2.train.lightning import (
     TrainingModule,
     load_model_from_checkpoint,
 )
-from batdetect2.train.train import build_trainer, train
+from batdetect2.train.train import build_trainer, run_train
 
 __all__ = [
     "DEFAULT_CHECKPOINT_DIR",
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index 6f1ef01..d11abf2 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -33,7 +33,7 @@ __all__ = [
 ]
 
 
-def train(
+def run_train(
     train_annotations: Sequence[data.ClipAnnotation],
     val_annotations: Sequence[data.ClipAnnotation] | None = None,
     targets: Optional["TargetProtocol"] = None,
@@ -126,6 +126,8 @@ def train(
     )
 
     logger.info("Training complete.")
+    return module
+
 
 def build_trainer(
     config: "BatDetect2Config",
diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py
index b3ebfb0..ce4f537 100644
--- a/tests/test_train/test_lightning.py
+++ b/tests/test_train/test_lightning.py
@@ -4,14 +4,19 @@
 import lightning as L
 import torch
 from soundevent import data
+from batdetect2.api_v2 import BatDetect2API
 from batdetect2.config import BatDetect2Config
-from batdetect2.train import TrainingModule
+from batdetect2.train import (
+    TrainingModule,
+    load_model_from_checkpoint,
+    run_train,
+)
 from batdetect2.train.train import build_training_module
 from batdetect2.typing.preprocess import AudioLoader
 
 
-def build_default_module():
-    config = BatDetect2Config()
+def build_default_module(config: BatDetect2Config | None = None):
+    config = config or BatDetect2Config()
     return build_training_module(
         model_config=config.model.model_dump(mode="json"),
         train_config=config.train.model_dump(mode="json"),
@@ -47,3 +52,113 @@
     output2 = recovered.model(wav.unsqueeze(0))
 
     torch.testing.assert_close(output1, output2, rtol=0, atol=0)
+
+
+def test_load_model_from_checkpoint_returns_model_and_config(
+    tmp_path: Path,
+):
+    module = build_default_module()
+    trainer = L.Trainer()
+    path = tmp_path / "example.ckpt"
+    trainer.strategy.connect(module)
+    trainer.save_checkpoint(path)
+
+    model, model_config = load_model_from_checkpoint(path)
+
+    assert model is not None
+    assert model_config.model_dump(
+        mode="json"
+    ) == module.model_config.model_dump(mode="json")
+
+
+def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
+    config = BatDetect2Config()
+    config.train.optimizer.learning_rate = 7e-4
+    config.train.optimizer.t_max = 123
+
+    module = build_default_module(config=config)
+    trainer = L.Trainer()
+    path = tmp_path / "example.ckpt"
+    trainer.strategy.connect(module)
+    trainer.save_checkpoint(path)
+
+    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
+    hyper_parameters = checkpoint["hyper_parameters"]
+
+    assert (
+        hyper_parameters["train_config"]["optimizer"]["learning_rate"] == 7e-4
+    )
+    assert hyper_parameters["train_config"]["optimizer"]["t_max"] == 123
+    assert "learning_rate" not in hyper_parameters
+    assert "t_max" not in hyper_parameters
+
+
+def test_configure_optimizers_uses_train_config_values():
+    config = BatDetect2Config()
+    config.train.optimizer.learning_rate = 5e-4
+    config.train.optimizer.t_max = 321
+
+    module = build_default_module(config=config)
+
+    optimizers, schedulers = module.configure_optimizers()
+
+    assert optimizers[0].param_groups[0]["lr"] == 5e-4
+    assert schedulers[0].T_max == 321
+
+
+def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
+    module = build_default_module()
+    trainer = L.Trainer()
+    path = tmp_path / "example.ckpt"
+    trainer.strategy.connect(module)
+    trainer.save_checkpoint(path)
+
+    api = BatDetect2API.from_checkpoint(path)
+
+    assert api.config.model.model_dump(
+        mode="json"
+    ) == module.model_config.model_dump(mode="json")
+    assert api.config.audio.samplerate == module.model_config.samplerate
+
+
+def test_train_smoke_produces_loadable_checkpoint(
+    tmp_path: Path,
+    example_annotations: list[data.ClipAnnotation],
+    sample_audio_loader: AudioLoader,
+):
+    config = BatDetect2Config()
+    config.train.trainer.limit_train_batches = 1
+    config.train.trainer.limit_val_batches = 1
+    config.train.trainer.log_every_n_steps = 1
+    config.train.train_loader.batch_size = 1
+    config.train.train_loader.augmentations.enabled = False
+
+    run_train(
+        train_annotations=example_annotations[:1],
+        val_annotations=example_annotations[:1],
+        config=config,
+        num_epochs=1,
+        train_workers=0,
+        val_workers=0,
+        checkpoint_dir=tmp_path,
+        seed=0,
+    )
+
+    checkpoints = list(tmp_path.rglob("*.ckpt"))
+    assert checkpoints
+
+    model, model_config = load_model_from_checkpoint(checkpoints[0])
+    assert model_config.samplerate == config.model.samplerate
+    assert model_config.architecture.name == config.model.architecture.name
+    assert model_config.preprocess.model_dump(
+        mode="json"
+    ) == config.model.preprocess.model_dump(mode="json")
+    assert model_config.postprocess.model_dump(
+        mode="json"
+    ) == config.model.postprocess.model_dump(mode="json")
+
+    wav = torch.tensor(
+        sample_audio_loader.load_clip(example_annotations[0].clip)
+    ).unsqueeze(0)
+    outputs = model(wav.unsqueeze(0))
+    assert outputs is not None