From d085b3212c8f48e959aee4b1dd860d7bac43ac0c Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 11 Nov 2024 11:46:06 +0000 Subject: [PATCH] Added weights_only argument to model loading function --- batdetect2/utils/detector_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py index 66b9b19..63643b6 100644 --- a/batdetect2/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -85,6 +85,7 @@ def load_model( model_path: str = DEFAULT_MODEL_PATH, load_weights: bool = True, device: Optional[torch.device] = None, + weights_only: bool = True, ) -> Tuple[DetectionModel, ModelParameters]: """Load model from file. @@ -105,7 +106,11 @@ def load_model( if not os.path.isfile(model_path): raise FileNotFoundError("Model file not found.") - net_params = torch.load(model_path, map_location=device) + net_params = torch.load( + model_path, + map_location=device, + weights_only=weights_only, + ) params = net_params["params"]