Added weights_only argument to model loading function

This commit is contained in:
mbsantiago 2024-11-11 11:46:06 +00:00
parent c393c5c29b
commit d085b3212c

View File

@ -85,6 +85,7 @@ def load_model(
model_path: str = DEFAULT_MODEL_PATH, model_path: str = DEFAULT_MODEL_PATH,
load_weights: bool = True, load_weights: bool = True,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
weights_only: bool = True,
) -> Tuple[DetectionModel, ModelParameters]: ) -> Tuple[DetectionModel, ModelParameters]:
"""Load model from file. """Load model from file.
@ -105,7 +106,11 @@ def load_model(
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
raise FileNotFoundError("Model file not found.") raise FileNotFoundError("Model file not found.")
net_params = torch.load(model_path, map_location=device) net_params = torch.load(
model_path,
map_location=device,
weights_only=weights_only,
)
params = net_params["params"] params = net_params["params"]