diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index 830e1e4..7fa8b39 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -4,16 +4,15 @@ import lightning as L import torch from soundevent import data -from batdetect2.models import build_model -from batdetect2.train import FullTrainingConfig, TrainingModule +from batdetect2.config import BatDetect2Config +from batdetect2.train import TrainingModule from batdetect2.train.train import build_training_module from batdetect2.typing.preprocess import AudioLoader def build_default_module(): - model = build_model() - config = FullTrainingConfig() - return build_training_module(model, config=config) + config = BatDetect2Config() + return build_training_module(config=config.model_dump()) def test_can_initialize_default_module():