From 47dbdc79c236867f0d5c25daa5b86e0a1737eb2a Mon Sep 17 00:00:00 2001 From: Kavi Date: Wed, 26 Feb 2025 14:12:42 +0100 Subject: [PATCH] Added tests for api and load_audio --- tests/test_api.py | 27 +++++++++++++++++++++ tests/test_audio_utils.py | 49 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/tests/test_api.py b/tests/test_api.py index e828c9e..51149e1 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -10,11 +10,13 @@ import torch from torch import nn from batdetect2 import api +import io PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio") TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav")) +DATA_DIR = os.path.join(os.path.dirname(__file__), "data") def test_load_model_with_default_params(): """Test loading model with default parameters.""" @@ -280,3 +282,28 @@ def test_process_file_with_empty_predictions_does_not_fail( assert results is not None assert len(results["pred_dict"]["annotation"]) == 0 + +def test_process_file_file_id_defaults_to_basename(): + """Test that no detections are made above the nyquist frequency.""" + # Recording donated by @@kdarras + basename = "20230322_172000_selec2.wav" + path = os.path.join(DATA_DIR, basename) + + output = api.process_file(path) + predictions = output["pred_dict"] + id = predictions["id"] + assert id == basename + +def test_bytesio_file_id_defaults_to_md5(): + """Test that no detections are made above the nyquist frequency.""" + # Recording donated by @@kdarras + basename = "20230322_172000_selec2.wav" + path = os.path.join(DATA_DIR, basename) + + with open(path, "rb") as f: + data = io.BytesIO(f.read()) + + output = api.process_file(data) + predictions = output["pred_dict"] + id = predictions["id"] + assert id == "7ade9ebf1a9fe5477ff3a2dc57001929" diff --git a/tests/test_audio_utils.py b/tests/test_audio_utils.py index c223ecf..9a2afc0 100644 --- a/tests/test_audio_utils.py +++ b/tests/test_audio_utils.py @@ -7,7 +7,9 @@ from hypothesis import strategies as st from batdetect2.detector import parameters from batdetect2.utils import audio_utils, detector_utils import io -import requests +import os + +DATA_DIR = os.path.join(os.path.dirname(__file__), "data") @given(duration=st.floats(min_value=0.1, max_value=2)) def test_can_compute_correct_spectrogram_width(duration: float): @@ -144,3 +146,48 @@ def test_get_samplerate_using_bytesio(): expected_sample_rate = 500000 assert expected_sample_rate == sample_rate + + + +def test_load_audio_using_bytes(): + filename = "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav" + + with open(filename, "rb") as f: + audio_bytes = io.BytesIO(f.read()) + + sample_rate, audio_data = audio_utils.load_audio(audio_bytes, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ) + + expected_sample_rate, expected_audio_data = audio_utils.load_audio(filename, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ) + + assert expected_sample_rate == sample_rate + + assert np.array_equal(audio_data, expected_audio_data) + + + +def test_get_samplerate_using_bytesio_2(): + basename = "20230322_172000_selec2.wav" + path = os.path.join(DATA_DIR, basename) + + with open(path, "rb") as f: + audio_bytes = io.BytesIO(f.read()) + + sample_rate = audio_utils.get_samplerate(audio_bytes) + + expected_sample_rate = 192_000 + assert expected_sample_rate == sample_rate + +def test_load_audio_using_bytes_2(): + basename = "20230322_172000_selec2.wav" + path = os.path.join(DATA_DIR, basename) + + with open(path, "rb") as f: + data = io.BytesIO(f.read()) + + sample_rate, audio_data = audio_utils.load_audio(data, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ) + + expected_sample_rate, expected_audio_data = audio_utils.load_audio(path, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ) + + assert expected_sample_rate == sample_rate + + assert np.array_equal(audio_data, expected_audio_data) \ No newline at end of file