fix: load checkpoints on cpu

Use CPU map_location when restoring Lightning checkpoints so packaged models load reliably without requiring accelerator-specific device state.
This commit is contained in:
mbsantiago 2026-05-05 21:49:09 +01:00
parent 84918086c8
commit 31054f64f6

View File

@ -1,6 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
import lightning as L import lightning as L
import torch
from soundevent.data import PathLike from soundevent.data import PathLike
from batdetect2.models import Model, ModelConfig, build_model from batdetect2.models import Model, ModelConfig, build_model
@ -148,7 +149,10 @@ def load_model_from_checkpoint(
describes its architecture, preprocessing, and postprocessing. describes its architecture, preprocessing, and postprocessing.
""" """
resolved_path = resolve_checkpoint_path(path) resolved_path = resolve_checkpoint_path(path)
module = TrainingModule.load_from_checkpoint(resolved_path) module = TrainingModule.load_from_checkpoint(
resolved_path,
map_location=torch.device("cpu"),
)
training_config = TrainingConfig.model_validate(module.train_config) training_config = TrainingConfig.model_validate(module.train_config)
model_config = ModelConfig.model_validate(module.model_config) model_config = ModelConfig.model_validate(module.model_config)
targets_config = TargetConfig.model_validate(module.targets_config) targets_config = TargetConfig.model_validate(module.targets_config)