diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py index 66b9b19..63643b6 100644 --- a/batdetect2/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -85,6 +85,7 @@ def load_model( model_path: str = DEFAULT_MODEL_PATH, load_weights: bool = True, device: Optional[torch.device] = None, + weights_only: bool = True, ) -> Tuple[DetectionModel, ModelParameters]: """Load model from file. @@ -105,7 +106,11 @@ def load_model( if not os.path.isfile(model_path): raise FileNotFoundError("Model file not found.") - net_params = torch.load(model_path, map_location=device) + net_params = torch.load( + model_path, + map_location=device, + weights_only=weights_only, + ) params = net_params["params"]