diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 9f602f6..0ccf5b0 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import lightning as L +import torch from soundevent.data import PathLike from batdetect2.models import Model, ModelConfig, build_model @@ -148,7 +149,10 @@ def load_model_from_checkpoint( describes its architecture, preprocessing, and postprocessing. """ resolved_path = resolve_checkpoint_path(path) - module = TrainingModule.load_from_checkpoint(resolved_path) + module = TrainingModule.load_from_checkpoint( + resolved_path, + map_location=torch.device("cpu"), + ) training_config = TrainingConfig.model_validate(module.train_config) model_config = ModelConfig.model_validate(module.model_config) targets_config = TargetConfig.model_validate(module.targets_config)