From 3288f52bbd3fbc07105882535b68f4aa6149e383 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Thu, 3 Aug 2023 11:45:39 +0100 Subject: [PATCH] tests: added tests for feature computation --- batdetect2/detector/compute_features.py | 37 ++-- tests/test_features.py | 214 +++++++++++++++++++++++- 2 files changed, 235 insertions(+), 16 deletions(-) diff --git a/batdetect2/detector/compute_features.py b/batdetect2/detector/compute_features.py index e0c3a2d..b53b0cb 100644 --- a/batdetect2/detector/compute_features.py +++ b/batdetect2/detector/compute_features.py @@ -88,19 +88,28 @@ def compute_max_power_bb( return np.nan x_start = max(0, prediction["x_pos"]) - x_end = min(spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]) - - y_low = max(0, prediction["y_pos"]) - y_high = min( - spec.shape[0] - 1, prediction["y_pos"] + prediction["bb_height"] + x_end = min( + spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] ) - spec_bb = spec[y_low:y_high, x_start:x_end] + # y low is the lowest freq but it will have a higher value due to array + # starting at 0 at top + y_low = min(spec.shape[0] - 1, prediction["y_pos"]) + y_high = max(0, prediction["y_pos"] - prediction["bb_height"]) + + spec_bb = spec[y_high:y_low, x_start:x_end] power_per_freq_band = np.sum(spec_bb, axis=1) - max_power_ind = np.argmax(power_per_freq_band) + + try: + max_power_ind = np.argmax(power_per_freq_band) + except ValueError: + # If the call is too short, the bounding box might be empty. + # In this case, return NaN. + return np.nan + return int( convert_int_to_freq( - y_low - max_power_ind, + y_high + max_power_ind, spec.shape[0], min_freq, max_freq, @@ -120,7 +129,9 @@ def compute_max_power( return np.nan x_start = max(0, prediction["x_pos"]) - x_end = min(spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]) + x_end = min( + spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] + ) spec_call = spec[:, x_start:x_end] power_per_freq_band = np.sum(spec_call, axis=1) max_power_ind = np.argmax(power_per_freq_band) @@ -146,7 +157,9 @@ def compute_max_power_first( return np.nan x_start = max(0, prediction["x_pos"]) - x_end = min(spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]) + x_end = min( + spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] + ) spec_call = spec[:, x_start:x_end] first_half = spec_call[:, : int(spec_call.shape[1] / 2)] power_per_freq_band = np.sum(first_half, axis=1) @@ -173,7 +186,9 @@ def compute_max_power_second( return np.nan x_start = max(0, prediction["x_pos"]) - x_end = min(spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]) + x_end = min( + spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"] + ) spec_call = spec[:, x_start:x_end] second_half = spec_call[:, int(spec_call.shape[1] / 2) :] power_per_freq_band = np.sum(second_half, axis=1) diff --git a/tests/test_features.py b/tests/test_features.py index 0394337..1271fda 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -1,9 +1,17 @@ """Test suite for feature extraction functions.""" +import logging + +import librosa import numpy as np +import pytest import batdetect2.detector.compute_features as feats -from batdetect2 import types +from batdetect2 import api, types +from batdetect2.utils import audio_utils as au + +numba_logger = logging.getLogger("numba") +numba_logger.setLevel(logging.WARNING) def index_to_freq( @@ -29,6 +37,11 @@ def index_to_time( def test_get_feats_function_with_empty_spectrogram(): + """Test get_feats function with empty spectrogram. + + This tests that the overall flow of the function works, even if the + spectrogram is empty. + """ spec_duration = 3 spec_width = 100 spec_height = 100 @@ -43,12 +56,14 @@ def test_get_feats_function_with_empty_spectrogram(): start_time = index_to_time(x_pos, spec_width, spec_duration) end_time = index_to_time(x_pos + bb_width, spec_width, spec_duration) - high_freq = index_to_freq(y_pos, spec_height, min_freq, max_freq) - low_freq = index_to_freq(y_pos + bb_height, spec_height, min_freq, max_freq) + low_freq = index_to_freq(y_pos, spec_height, min_freq, max_freq) + high_freq = index_to_freq( + y_pos - bb_height, spec_height, min_freq, max_freq + ) pred_nms: types.PredictionResults = { "det_probs": np.array([1]), - "class_probs": np.array([1]), + "class_probs": np.array([[1]]), "x_pos": np.array([x_pos]), "y_pos": np.array([y_pos]), "bb_width": np.array([bb_width]), @@ -76,7 +91,7 @@ def test_get_feats_function_with_empty_spectrogram(): low_freq, high_freq, high_freq - low_freq, - max_freq, + high_freq, max_freq, max_freq, max_freq, @@ -85,3 +100,192 @@ def test_get_feats_function_with_empty_spectrogram(): ), equal_nan=True, ).all() + + +@pytest.mark.parametrize( + "max_power", + [ + 30_000, + 31_000, + 32_000, + 33_000, + 34_000, + 35_000, + 36_000, + 37_000, + 38_000, + 39_000, + 40_000, + ], +) +def test_compute_max_power_bb(max_power: int): + """Test compute_max_power_bb function.""" + duration = 1 + samplerate = 256_000 + min_freq = 0 + max_freq = 128_000 + + start_time = 0.3 + end_time = 0.6 + low_freq = 30_000 + high_freq = 40_000 + + audio = np.zeros((int(duration * samplerate),)) + + # Add a signal during the time and frequency range of interest + audio[ + int(start_time * samplerate) : int(end_time * samplerate) + ] = 0.5 * librosa.tone( + max_power, sr=samplerate, duration=end_time - start_time + ) + + # Add a more powerful signal outside frequency range of interest + audio[ + int(start_time * samplerate) : int(end_time * samplerate) + ] += 2 * librosa.tone( + 80_000, sr=samplerate, duration=end_time - start_time + ) + + params = api.get_config( + min_freq=min_freq, + max_freq=max_freq, + target_samp_rate=samplerate, + ) + + spec, _ = au.generate_spectrogram( + audio, + samplerate, + params, + ) + + x_start = int( + au.time_to_x_coords( + start_time, + samplerate, + params["fft_win_length"], + params["fft_overlap"], + ) + ) + + x_end = int( + au.time_to_x_coords( + end_time, + samplerate, + params["fft_win_length"], + params["fft_overlap"], + ) + ) + + num_freq_bins = spec.shape[0] + y_low = num_freq_bins - int(num_freq_bins * low_freq / max_freq) + y_high = num_freq_bins - int(num_freq_bins * high_freq / max_freq) + + prediction: types.Prediction = { + "det_prob": 1, + "class_prob": np.ones((1,)), + "x_pos": x_start, + "y_pos": int(y_low), + "bb_width": int(x_end - x_start), + "bb_height": int(y_low - y_high), + "start_time": start_time, + "end_time": end_time, + "low_freq": low_freq, + "high_freq": high_freq, + } + + print(prediction) + + max_power_bb = feats.compute_max_power_bb( + prediction, + spec, + min_freq=min_freq, + max_freq=max_freq, + ) + + assert abs(max_power_bb - max_power) <= 500 + + +def test_compute_max_power(): + """Test compute_max_power_bb function.""" + duration = 3 + samplerate = 16_000 + min_freq = 0 + max_freq = 8_000 + + start_time = 1 + end_time = 2 + low_freq = 3_000 + high_freq = 4_000 + max_power = 5_000 + + audio = np.zeros((int(duration * samplerate),)) + + # Add a signal during the time and frequency range of interest + audio[ + int(start_time * samplerate) : int(end_time * samplerate) + ] = 0.5 * librosa.tone( + 3_500, sr=samplerate, duration=end_time - start_time + ) + + # Add a more powerful signal outside frequency range of interest + audio[ + int(start_time * samplerate) : int(end_time * samplerate) + ] += 2 * librosa.tone( + max_power, sr=samplerate, duration=end_time - start_time + ) + + params = api.get_config( + min_freq=min_freq, + max_freq=max_freq, + target_samp_rate=samplerate, + ) + + spec, _ = au.generate_spectrogram( + audio, + samplerate, + params, + ) + + x_start = int( + au.time_to_x_coords( + start_time, + samplerate, + params["fft_win_length"], + params["fft_overlap"], + ) + ) + + x_end = int( + au.time_to_x_coords( + end_time, + samplerate, + params["fft_win_length"], + params["fft_overlap"], + ) + ) + + num_freq_bins = spec.shape[0] + y_low = int(num_freq_bins * low_freq / max_freq) + y_high = int(num_freq_bins * high_freq / max_freq) + + prediction: types.Prediction = { + "det_prob": 1, + "class_prob": np.ones((1,)), + "x_pos": x_start, + "y_pos": int(y_high), + "bb_width": int(x_end - x_start), + "bb_height": int(y_high - y_low), + "start_time": start_time, + "end_time": end_time, + "low_freq": low_freq, + "high_freq": high_freq, + } + + computed_max_power = feats.compute_max_power( + prediction, + spec, + min_freq=min_freq, + max_freq=max_freq, + ) + + assert abs(computed_max_power - max_power) < 100