# batdetect2/tests/test_train/test_lightning.py
# Snapshot metadata: retrieved 2026-03-17 15:38:07 +00:00; ~165 lines, 4.9 KiB, Python.

from pathlib import Path
import lightning as L
import torch
from soundevent import data
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config
from batdetect2.train import (
TrainingModule,
load_model_from_checkpoint,
run_train,
)
from batdetect2.train.train import build_training_module
from batdetect2.typing.preprocess import AudioLoader
def build_default_module(config: BatDetect2Config | None = None):
    """Build a TrainingModule, falling back to a fresh default config."""
    cfg = config or BatDetect2Config()
    model_cfg = cfg.model.model_dump(mode="json")
    train_cfg = cfg.train.model_dump(mode="json")
    return build_training_module(model_config=model_cfg, train_config=train_cfg)
def test_can_initialize_default_module():
    """A module built from default settings is a LightningModule."""
    default_module = build_default_module()
    assert isinstance(default_module, L.LightningModule)
def test_can_save_checkpoint(
    tmp_path: Path,
    clip: data.Clip,
    sample_audio_loader: AudioLoader,
):
    """Saving and reloading a checkpoint preserves preprocessing and outputs."""
    original = build_default_module()
    trainer = L.Trainer()
    ckpt_path = tmp_path / "example.ckpt"
    trainer.strategy.connect(original)
    trainer.save_checkpoint(ckpt_path)

    restored = TrainingModule.load_from_checkpoint(ckpt_path)

    waveform = torch.tensor(sample_audio_loader.load_clip(clip)).unsqueeze(0)

    # Spectrograms must match bitwise between the original and restored modules.
    torch.testing.assert_close(
        original.model.preprocessor(waveform),
        restored.model.preprocessor(waveform),
        rtol=0,
        atol=0,
    )

    # Full forward passes must also agree exactly.
    torch.testing.assert_close(
        original.model(waveform.unsqueeze(0)),
        restored.model(waveform.unsqueeze(0)),
        rtol=0,
        atol=0,
    )
def test_load_model_from_checkpoint_returns_model_and_config(
    tmp_path: Path,
):
    """load_model_from_checkpoint yields a model plus the saved model config."""
    source_module = build_default_module()
    trainer = L.Trainer()
    ckpt_path = tmp_path / "example.ckpt"
    trainer.strategy.connect(source_module)
    trainer.save_checkpoint(ckpt_path)

    model, model_config = load_model_from_checkpoint(ckpt_path)

    assert model is not None
    expected = source_module.model_config.model_dump(mode="json")
    assert model_config.model_dump(mode="json") == expected
def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
    """Optimizer settings live nested under train_config, never at the top level."""
    config = BatDetect2Config()
    config.train.optimizer.learning_rate = 7e-4
    config.train.optimizer.t_max = 123

    module = build_default_module(config=config)
    trainer = L.Trainer()
    ckpt_path = tmp_path / "example.ckpt"
    trainer.strategy.connect(module)
    trainer.save_checkpoint(ckpt_path)

    # Inspect the raw checkpoint payload as written to disk.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    hparams = checkpoint["hyper_parameters"]
    optimizer_hparams = hparams["train_config"]["optimizer"]

    assert optimizer_hparams["learning_rate"] == 7e-4
    assert optimizer_hparams["t_max"] == 123

    # The same keys must not leak into the top-level hyperparameter namespace.
    assert "learning_rate" not in hparams
    assert "t_max" not in hparams
def test_configure_optimizers_uses_train_config_values():
    """configure_optimizers honours the learning rate and T_max from the config."""
    config = BatDetect2Config()
    config.train.optimizer.learning_rate = 5e-4
    config.train.optimizer.t_max = 321

    module = build_default_module(config=config)
    optimizers, schedulers = module.configure_optimizers()

    assert optimizers[0].param_groups[0]["lr"] == 5e-4
    assert schedulers[0].T_max == 321
def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
    """BatDetect2API.from_checkpoint rebuilds the saved model configuration."""
    module = build_default_module()
    trainer = L.Trainer()
    ckpt_path = tmp_path / "example.ckpt"
    trainer.strategy.connect(module)
    trainer.save_checkpoint(ckpt_path)

    api = BatDetect2API.from_checkpoint(ckpt_path)

    saved = module.model_config.model_dump(mode="json")
    assert api.config.model.model_dump(mode="json") == saved
    assert api.config.audio.samplerate == module.model_config.samplerate
def test_train_smoke_produces_loadable_checkpoint(
    tmp_path: Path,
    example_annotations: list[data.ClipAnnotation],
    sample_audio_loader: AudioLoader,
):
    """A one-batch training run writes a checkpoint that loads and runs inference."""
    config = BatDetect2Config()
    # Keep the run as small as possible: one batch, one clip, no augmentations.
    config.train.trainer.limit_train_batches = 1
    config.train.trainer.limit_val_batches = 1
    config.train.trainer.log_every_n_steps = 1
    config.train.train_loader.batch_size = 1
    config.train.train_loader.augmentations.enabled = False

    run_train(
        train_annotations=example_annotations[:1],
        val_annotations=example_annotations[:1],
        config=config,
        num_epochs=1,
        train_workers=0,
        val_workers=0,
        checkpoint_dir=tmp_path,
        seed=0,
    )

    checkpoints = list(tmp_path.rglob("*.ckpt"))
    assert checkpoints

    model, model_config = load_model_from_checkpoint(checkpoints[0])

    # The reloaded configuration must agree with what we trained with.
    assert model_config.samplerate == config.model.samplerate
    assert model_config.architecture.name == config.model.architecture.name
    for section in ("preprocess", "postprocess"):
        reloaded = getattr(model_config, section).model_dump(mode="json")
        expected = getattr(config.model, section).model_dump(mode="json")
        assert reloaded == expected

    # The restored model can run a forward pass on real audio.
    waveform = torch.tensor(
        sample_audio_loader.load_clip(example_annotations[0].clip)
    ).unsqueeze(0)
    assert model(waveform.unsqueeze(0)) is not None