mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Added weights_only argument to model loading function
This commit is contained in:
parent
c393c5c29b
commit
d085b3212c
@ -85,6 +85,7 @@ def load_model(
|
|||||||
model_path: str = DEFAULT_MODEL_PATH,
|
model_path: str = DEFAULT_MODEL_PATH,
|
||||||
load_weights: bool = True,
|
load_weights: bool = True,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
|
weights_only: bool = True,
|
||||||
) -> Tuple[DetectionModel, ModelParameters]:
|
) -> Tuple[DetectionModel, ModelParameters]:
|
||||||
"""Load model from file.
|
"""Load model from file.
|
||||||
|
|
||||||
@ -105,7 +106,11 @@ def load_model(
|
|||||||
if not os.path.isfile(model_path):
|
if not os.path.isfile(model_path):
|
||||||
raise FileNotFoundError("Model file not found.")
|
raise FileNotFoundError("Model file not found.")
|
||||||
|
|
||||||
net_params = torch.load(model_path, map_location=device)
|
net_params = torch.load(
|
||||||
|
model_path,
|
||||||
|
map_location=device,
|
||||||
|
weights_only=weights_only,
|
||||||
|
)
|
||||||
|
|
||||||
params = net_params["params"]
|
params = net_params["params"]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user