mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
57 lines
1.6 KiB
Python
57 lines
1.6 KiB
Python
from pathlib import Path
|
|
|
|
import lightning as L
|
|
import torch
|
|
import xarray as xr
|
|
from soundevent import data
|
|
|
|
from batdetect2.models import build_model
|
|
from batdetect2.postprocess import build_postprocessor
|
|
from batdetect2.preprocess import build_preprocessor
|
|
from batdetect2.targets import build_targets
|
|
from batdetect2.train.lightning import TrainingModule
|
|
from batdetect2.train.losses import build_loss
|
|
|
|
|
|
def build_default_module():
|
|
loss = build_loss()
|
|
targets = build_targets()
|
|
detector = build_model(num_classes=len(targets.class_names))
|
|
preprocessor = build_preprocessor()
|
|
postprocessor = build_postprocessor(targets)
|
|
return TrainingModule(
|
|
detector=detector,
|
|
loss=loss,
|
|
targets=targets,
|
|
preprocessor=preprocessor,
|
|
postprocessor=postprocessor,
|
|
)
|
|
|
|
|
|
def test_can_initialize_default_module():
|
|
module = build_default_module()
|
|
assert isinstance(module, L.LightningModule)
|
|
|
|
|
|
def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip):
|
|
module = build_default_module()
|
|
trainer = L.Trainer()
|
|
path = tmp_path / "example.ckpt"
|
|
trainer.strategy.connect(module)
|
|
trainer.save_checkpoint(path)
|
|
|
|
recovered = TrainingModule.load_from_checkpoint(path)
|
|
|
|
spec1 = module.preprocessor.preprocess_clip(clip)
|
|
spec2 = recovered.preprocessor.preprocess_clip(clip)
|
|
|
|
xr.testing.assert_equal(spec1, spec2)
|
|
|
|
input1 = torch.tensor([spec1.values]).unsqueeze(0)
|
|
input2 = torch.tensor([spec2.values]).unsqueeze(0)
|
|
|
|
output1 = module(input1)
|
|
output2 = recovered(input2)
|
|
|
|
assert output1 == output2
|