mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-23 06:41:53 +02:00
fix: load checkpoints on cpu
Use CPU map_location when restoring Lightning checkpoints so packaged models load reliably without requiring accelerator-specific device state.
This commit is contained in:
parent
84918086c8
commit
31054f64f6
@ -1,6 +1,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
|
import torch
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.models import Model, ModelConfig, build_model
|
from batdetect2.models import Model, ModelConfig, build_model
|
||||||
@ -148,7 +149,10 @@ def load_model_from_checkpoint(
|
|||||||
describes its architecture, preprocessing, and postprocessing.
|
describes its architecture, preprocessing, and postprocessing.
|
||||||
"""
|
"""
|
||||||
resolved_path = resolve_checkpoint_path(path)
|
resolved_path = resolve_checkpoint_path(path)
|
||||||
module = TrainingModule.load_from_checkpoint(resolved_path)
|
module = TrainingModule.load_from_checkpoint(
|
||||||
|
resolved_path,
|
||||||
|
map_location=torch.device("cpu"),
|
||||||
|
)
|
||||||
training_config = TrainingConfig.model_validate(module.train_config)
|
training_config = TrainingConfig.model_validate(module.train_config)
|
||||||
model_config = ModelConfig.model_validate(module.model_config)
|
model_config = ModelConfig.model_validate(module.model_config)
|
||||||
targets_config = TargetConfig.model_validate(module.targets_config)
|
targets_config = TargetConfig.model_validate(module.targets_config)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user