mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 23:30:21 +02:00
Added a full test of training and saving
This commit is contained in:
parent
56f6affc72
commit
615c7d78fb
@ -28,7 +28,7 @@ from batdetect2.targets import build_targets
|
|||||||
from batdetect2.train import (
|
from batdetect2.train import (
|
||||||
DEFAULT_CHECKPOINT_DIR,
|
DEFAULT_CHECKPOINT_DIR,
|
||||||
load_model_from_checkpoint,
|
load_model_from_checkpoint,
|
||||||
train,
|
run_train,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
AudioLoader,
|
AudioLoader,
|
||||||
@ -84,7 +84,7 @@ class BatDetect2API:
|
|||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
seed: int | None = None,
|
seed: int | None = None,
|
||||||
):
|
):
|
||||||
train(
|
run_train(
|
||||||
train_annotations=train_annotations,
|
train_annotations=train_annotations,
|
||||||
val_annotations=val_annotations,
|
val_annotations=val_annotations,
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
|
|||||||
@ -7,7 +7,7 @@ from batdetect2.train.lightning import (
|
|||||||
TrainingModule,
|
TrainingModule,
|
||||||
load_model_from_checkpoint,
|
load_model_from_checkpoint,
|
||||||
)
|
)
|
||||||
from batdetect2.train.train import build_trainer, train
|
from batdetect2.train.train import build_trainer, run_train
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DEFAULT_CHECKPOINT_DIR",
|
"DEFAULT_CHECKPOINT_DIR",
|
||||||
|
|||||||
@ -33,7 +33,7 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def run_train(
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||||
targets: Optional["TargetProtocol"] = None,
|
targets: Optional["TargetProtocol"] = None,
|
||||||
@ -126,6 +126,8 @@ def train(
|
|||||||
)
|
)
|
||||||
logger.info("Training complete.")
|
logger.info("Training complete.")
|
||||||
|
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
def build_trainer(
|
def build_trainer(
|
||||||
config: "BatDetect2Config",
|
config: "BatDetect2Config",
|
||||||
|
|||||||
@ -4,14 +4,19 @@ import lightning as L
|
|||||||
import torch
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.train import TrainingModule
|
from batdetect2.train import (
|
||||||
|
TrainingModule,
|
||||||
|
load_model_from_checkpoint,
|
||||||
|
run_train,
|
||||||
|
)
|
||||||
from batdetect2.train.train import build_training_module
|
from batdetect2.train.train import build_training_module
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
from batdetect2.typing.preprocess import AudioLoader
|
||||||
|
|
||||||
|
|
||||||
def build_default_module():
|
def build_default_module(config: BatDetect2Config | None = None):
|
||||||
config = BatDetect2Config()
|
config = config or BatDetect2Config()
|
||||||
return build_training_module(
|
return build_training_module(
|
||||||
model_config=config.model.model_dump(mode="json"),
|
model_config=config.model.model_dump(mode="json"),
|
||||||
train_config=config.train.model_dump(mode="json"),
|
train_config=config.train.model_dump(mode="json"),
|
||||||
@ -47,3 +52,113 @@ def test_can_save_checkpoint(
|
|||||||
output2 = recovered.model(wav.unsqueeze(0))
|
output2 = recovered.model(wav.unsqueeze(0))
|
||||||
|
|
||||||
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
torch.testing.assert_close(output1, output2, rtol=0, atol=0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_model_from_checkpoint_returns_model_and_config(
|
||||||
|
tmp_path: Path,
|
||||||
|
):
|
||||||
|
module = build_default_module()
|
||||||
|
trainer = L.Trainer()
|
||||||
|
path = tmp_path / "example.ckpt"
|
||||||
|
trainer.strategy.connect(module)
|
||||||
|
trainer.save_checkpoint(path)
|
||||||
|
|
||||||
|
model, model_config = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
|
assert model is not None
|
||||||
|
assert model_config.model_dump(
|
||||||
|
mode="json"
|
||||||
|
) == module.model_config.model_dump(mode="json")
|
||||||
|
|
||||||
|
|
||||||
|
def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
|
||||||
|
config = BatDetect2Config()
|
||||||
|
config.train.optimizer.learning_rate = 7e-4
|
||||||
|
config.train.optimizer.t_max = 123
|
||||||
|
|
||||||
|
module = build_default_module(config=config)
|
||||||
|
trainer = L.Trainer()
|
||||||
|
path = tmp_path / "example.ckpt"
|
||||||
|
trainer.strategy.connect(module)
|
||||||
|
trainer.save_checkpoint(path)
|
||||||
|
|
||||||
|
checkpoint = torch.load(path, map_location="cpu", weights_only=False)
|
||||||
|
hyper_parameters = checkpoint["hyper_parameters"]
|
||||||
|
|
||||||
|
assert (
|
||||||
|
hyper_parameters["train_config"]["optimizer"]["learning_rate"] == 7e-4
|
||||||
|
)
|
||||||
|
assert hyper_parameters["train_config"]["optimizer"]["t_max"] == 123
|
||||||
|
assert "learning_rate" not in hyper_parameters
|
||||||
|
assert "t_max" not in hyper_parameters
|
||||||
|
|
||||||
|
|
||||||
|
def test_configure_optimizers_uses_train_config_values():
|
||||||
|
config = BatDetect2Config()
|
||||||
|
config.train.optimizer.learning_rate = 5e-4
|
||||||
|
config.train.optimizer.t_max = 321
|
||||||
|
|
||||||
|
module = build_default_module(config=config)
|
||||||
|
|
||||||
|
optimizers, schedulers = module.configure_optimizers()
|
||||||
|
|
||||||
|
assert optimizers[0].param_groups[0]["lr"] == 5e-4
|
||||||
|
assert schedulers[0].T_max == 321
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
|
||||||
|
module = build_default_module()
|
||||||
|
trainer = L.Trainer()
|
||||||
|
path = tmp_path / "example.ckpt"
|
||||||
|
trainer.strategy.connect(module)
|
||||||
|
trainer.save_checkpoint(path)
|
||||||
|
|
||||||
|
api = BatDetect2API.from_checkpoint(path)
|
||||||
|
|
||||||
|
assert api.config.model.model_dump(
|
||||||
|
mode="json"
|
||||||
|
) == module.model_config.model_dump(mode="json")
|
||||||
|
assert api.config.audio.samplerate == module.model_config.samplerate
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_smoke_produces_loadable_checkpoint(
|
||||||
|
tmp_path: Path,
|
||||||
|
example_annotations: list[data.ClipAnnotation],
|
||||||
|
sample_audio_loader: AudioLoader,
|
||||||
|
):
|
||||||
|
config = BatDetect2Config()
|
||||||
|
config.train.trainer.limit_train_batches = 1
|
||||||
|
config.train.trainer.limit_val_batches = 1
|
||||||
|
config.train.trainer.log_every_n_steps = 1
|
||||||
|
config.train.train_loader.batch_size = 1
|
||||||
|
config.train.train_loader.augmentations.enabled = False
|
||||||
|
|
||||||
|
run_train(
|
||||||
|
train_annotations=example_annotations[:1],
|
||||||
|
val_annotations=example_annotations[:1],
|
||||||
|
config=config,
|
||||||
|
num_epochs=1,
|
||||||
|
train_workers=0,
|
||||||
|
val_workers=0,
|
||||||
|
checkpoint_dir=tmp_path,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoints = list(tmp_path.rglob("*.ckpt"))
|
||||||
|
assert checkpoints
|
||||||
|
|
||||||
|
model, model_config = load_model_from_checkpoint(checkpoints[0])
|
||||||
|
assert model_config.samplerate == config.model.samplerate
|
||||||
|
assert model_config.architecture.name == config.model.architecture.name
|
||||||
|
assert model_config.preprocess.model_dump(
|
||||||
|
mode="json"
|
||||||
|
) == config.model.preprocess.model_dump(mode="json")
|
||||||
|
assert model_config.postprocess.model_dump(
|
||||||
|
mode="json"
|
||||||
|
) == config.model.postprocess.model_dump(mode="json")
|
||||||
|
|
||||||
|
wav = torch.tensor(
|
||||||
|
sample_audio_loader.load_clip(example_annotations[0].clip)
|
||||||
|
).unsqueeze(0)
|
||||||
|
outputs = model(wav.unsqueeze(0))
|
||||||
|
assert outputs is not None
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user