mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Merge pull request #37 from macaodha/fix/GH-30-torch-deprecation-warning-weights-only
fix: Address PyTorch Model Loading Deprecation Warning (GH-30)
This commit is contained in:
commit
4627ddd739
@ -85,6 +85,7 @@ def load_model(
|
|||||||
model_path: str = DEFAULT_MODEL_PATH,
|
model_path: str = DEFAULT_MODEL_PATH,
|
||||||
load_weights: bool = True,
|
load_weights: bool = True,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
|
weights_only: bool = True,
|
||||||
) -> Tuple[DetectionModel, ModelParameters]:
|
) -> Tuple[DetectionModel, ModelParameters]:
|
||||||
"""Load model from file.
|
"""Load model from file.
|
||||||
|
|
||||||
@ -105,7 +106,11 @@ def load_model(
|
|||||||
if not os.path.isfile(model_path):
|
if not os.path.isfile(model_path):
|
||||||
raise FileNotFoundError("Model file not found.")
|
raise FileNotFoundError("Model file not found.")
|
||||||
|
|
||||||
net_params = torch.load(model_path, map_location=device)
|
net_params = torch.load(
|
||||||
|
model_path,
|
||||||
|
map_location=device,
|
||||||
|
weights_only=weights_only,
|
||||||
|
)
|
||||||
|
|
||||||
params = net_params["params"]
|
params = net_params["params"]
|
||||||
|
|
||||||
|
@ -1,8 +1,31 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def example_data_dir() -> Path:
|
||||||
|
pkg_dir = Path(__file__).parent.parent
|
||||||
|
example_data_dir = pkg_dir / "example_data"
|
||||||
|
assert example_data_dir.exists()
|
||||||
|
return example_data_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def example_audio_dir(example_data_dir: Path) -> Path:
|
||||||
|
example_audio_dir = example_data_dir / "audio"
|
||||||
|
assert example_audio_dir.exists()
|
||||||
|
return example_audio_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def example_audio_files(example_audio_dir: Path) -> List[Path]:
|
||||||
|
audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
|
||||||
|
assert len(audio_files) == 3
|
||||||
|
return audio_files
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def data_dir() -> Path:
|
def data_dir() -> Path:
|
||||||
dir = Path(__file__).parent / "data"
|
dir = Path(__file__).parent / "data"
|
||||||
|
@ -1,14 +1,13 @@
|
|||||||
"""Test bat detect module API."""
|
"""Test bat detect module API."""
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import soundfile as sf
|
|
||||||
|
|
||||||
from batdetect2 import api
|
from batdetect2 import api
|
||||||
|
|
||||||
@ -267,7 +266,6 @@ def test_process_file_with_spec_slices():
|
|||||||
assert len(results["spec_slices"]) == len(detections)
|
assert len(results["spec_slices"]) == len(detections)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_file_with_empty_predictions_does_not_fail(
|
def test_process_file_with_empty_predictions_does_not_fail(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
):
|
):
|
||||||
|
78
tests/test_model.py
Normal file
78
tests/test_model.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
"""Test suite for model functions."""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from hypothesis import given, settings
|
||||||
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
|
from batdetect2 import api
|
||||||
|
from batdetect2.detector import parameters
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_import_model_without_warnings():
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("error")
|
||||||
|
api.load_model()
|
||||||
|
|
||||||
|
|
||||||
|
@settings(deadline=None, max_examples=5)
|
||||||
|
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||||
|
def test_can_import_model_without_pickle(duration: float):
|
||||||
|
# NOTE: remove this test once no other issues are found This is a temporary
|
||||||
|
# test to check that change in model loading did not impact model behaviour
|
||||||
|
# in any way.
|
||||||
|
|
||||||
|
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||||
|
audio = np.random.rand(int(duration * samplerate))
|
||||||
|
|
||||||
|
model_without_pickle, model_params_without_pickle = api.load_model(
|
||||||
|
weights_only=True
|
||||||
|
)
|
||||||
|
model_with_pickle, model_params_with_pickle = api.load_model(
|
||||||
|
weights_only=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert model_params_without_pickle == model_params_with_pickle
|
||||||
|
|
||||||
|
predictions_without_pickle, _, _ = api.process_audio(
|
||||||
|
audio,
|
||||||
|
model=model_without_pickle,
|
||||||
|
)
|
||||||
|
predictions_with_pickle, _, _ = api.process_audio(
|
||||||
|
audio,
|
||||||
|
model=model_with_pickle,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert predictions_without_pickle == predictions_with_pickle
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_import_model_without_pickle_on_test_data(
|
||||||
|
example_audio_files: List[Path],
|
||||||
|
):
|
||||||
|
# NOTE: remove this test once no other issues are found This is a temporary
|
||||||
|
# test to check that change in model loading did not impact model behaviour
|
||||||
|
# in any way.
|
||||||
|
|
||||||
|
model_without_pickle, model_params_without_pickle = api.load_model(
|
||||||
|
weights_only=True
|
||||||
|
)
|
||||||
|
model_with_pickle, model_params_with_pickle = api.load_model(
|
||||||
|
weights_only=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert model_params_without_pickle == model_params_with_pickle
|
||||||
|
|
||||||
|
for audio_file in example_audio_files:
|
||||||
|
audio = api.load_audio(str(audio_file))
|
||||||
|
predictions_without_pickle, _, _ = api.process_audio(
|
||||||
|
audio,
|
||||||
|
model=model_without_pickle,
|
||||||
|
)
|
||||||
|
predictions_with_pickle, _, _ = api.process_audio(
|
||||||
|
audio,
|
||||||
|
model=model_with_pickle,
|
||||||
|
)
|
||||||
|
assert predictions_without_pickle == predictions_with_pickle
|
Loading…
Reference in New Issue
Block a user