mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Added weights_only argument to model loading function
This commit is contained in:
parent
c393c5c29b
commit
d085b3212c
@ -85,6 +85,7 @@ def load_model(
|
||||
model_path: str = DEFAULT_MODEL_PATH,
|
||||
load_weights: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
weights_only: bool = True,
|
||||
) -> Tuple[DetectionModel, ModelParameters]:
|
||||
"""Load model from file.
|
||||
|
||||
@ -105,7 +106,11 @@ def load_model(
|
||||
if not os.path.isfile(model_path):
|
||||
raise FileNotFoundError("Model file not found.")
|
||||
|
||||
net_params = torch.load(model_path, map_location=device)
|
||||
net_params = torch.load(
|
||||
model_path,
|
||||
map_location=device,
|
||||
weights_only=weights_only,
|
||||
)
|
||||
|
||||
params = net_params["params"]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user