mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Added tests for api and load_audio
This commit is contained in:
parent
e10e270de4
commit
47dbdc79c2
@ -10,11 +10,13 @@ import torch
|
||||
from torch import nn
|
||||
|
||||
from batdetect2 import api
|
||||
import io
|
||||
|
||||
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
|
||||
TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
|
||||
|
||||
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
||||
|
||||
def test_load_model_with_default_params():
|
||||
"""Test loading model with default parameters."""
|
||||
@ -280,3 +282,28 @@ def test_process_file_with_empty_predictions_does_not_fail(
|
||||
|
||||
assert results is not None
|
||||
assert len(results["pred_dict"]["annotation"]) == 0
|
||||
|
||||
def test_process_file_file_id_defaults_to_basename():
|
||||
"""Test that no detections are made above the nyquist frequency."""
|
||||
# Recording donated by @@kdarras
|
||||
basename = "20230322_172000_selec2.wav"
|
||||
path = os.path.join(DATA_DIR, basename)
|
||||
|
||||
output = api.process_file(path)
|
||||
predictions = output["pred_dict"]
|
||||
id = predictions["id"]
|
||||
assert id == basename
|
||||
|
||||
def test_bytesio_file_id_defaults_to_md5():
|
||||
"""Test that no detections are made above the nyquist frequency."""
|
||||
# Recording donated by @@kdarras
|
||||
basename = "20230322_172000_selec2.wav"
|
||||
path = os.path.join(DATA_DIR, basename)
|
||||
|
||||
with open(path, "rb") as f:
|
||||
data = io.BytesIO(f.read())
|
||||
|
||||
output = api.process_file(data)
|
||||
predictions = output["pred_dict"]
|
||||
id = predictions["id"]
|
||||
assert id == "7ade9ebf1a9fe5477ff3a2dc57001929"
|
||||
|
@ -7,7 +7,9 @@ from hypothesis import strategies as st
|
||||
from batdetect2.detector import parameters
|
||||
from batdetect2.utils import audio_utils, detector_utils
|
||||
import io
|
||||
import requests
|
||||
import os
|
||||
|
||||
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
||||
|
||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||
def test_can_compute_correct_spectrogram_width(duration: float):
|
||||
@ -144,3 +146,48 @@ def test_get_samplerate_using_bytesio():
|
||||
|
||||
expected_sample_rate = 500000
|
||||
assert expected_sample_rate == sample_rate
|
||||
|
||||
|
||||
|
||||
def test_load_audio_using_bytes():
|
||||
filename = "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav"
|
||||
|
||||
with open(filename, "rb") as f:
|
||||
audio_bytes = io.BytesIO(f.read())
|
||||
|
||||
sample_rate, audio_data = audio_utils.load_audio(audio_bytes, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ)
|
||||
|
||||
expected_sample_rate, expected_audio_data = audio_utils.load_audio(filename, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ)
|
||||
|
||||
assert expected_sample_rate == sample_rate
|
||||
|
||||
assert np.array_equal(audio_data, expected_audio_data)
|
||||
|
||||
|
||||
|
||||
def test_get_samplerate_using_bytesio_2():
|
||||
basename = "20230322_172000_selec2.wav"
|
||||
path = os.path.join(DATA_DIR, basename)
|
||||
|
||||
with open(path, "rb") as f:
|
||||
audio_bytes = io.BytesIO(f.read())
|
||||
|
||||
sample_rate = audio_utils.get_samplerate(audio_bytes)
|
||||
|
||||
expected_sample_rate = 192_000
|
||||
assert expected_sample_rate == sample_rate
|
||||
|
||||
def test_load_audio_using_bytes_2():
|
||||
basename = "20230322_172000_selec2.wav"
|
||||
path = os.path.join(DATA_DIR, basename)
|
||||
|
||||
with open(path, "rb") as f:
|
||||
data = io.BytesIO(f.read())
|
||||
|
||||
sample_rate, audio_data = audio_utils.load_audio(data, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ)
|
||||
|
||||
expected_sample_rate, expected_audio_data = audio_utils.load_audio(path, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ)
|
||||
|
||||
assert expected_sample_rate == sample_rate
|
||||
|
||||
assert np.array_equal(audio_data, expected_audio_data)
|
Loading…
Reference in New Issue
Block a user