From 31054f64f6296b3362b78577c2dff10ee563affe Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 5 May 2026 21:49:09 +0100 Subject: [PATCH] fix: load checkpoints on cpu Use CPU map_location when restoring Lightning checkpoints so packaged models load reliably without requiring accelerator-specific device state. --- src/batdetect2/train/lightning.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 9f602f6..0ccf5b0 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import lightning as L +import torch from soundevent.data import PathLike from batdetect2.models import Model, ModelConfig, build_model @@ -148,7 +149,10 @@ def load_model_from_checkpoint( describes its architecture, preprocessing, and postprocessing. """ resolved_path = resolve_checkpoint_path(path) - module = TrainingModule.load_from_checkpoint(resolved_path) + module = TrainingModule.load_from_checkpoint( + resolved_path, + map_location=torch.device("cpu"), + ) training_config = TrainingConfig.model_validate(module.train_config) model_config = ModelConfig.model_validate(module.model_config) targets_config = TargetConfig.model_validate(module.targets_config)