batdetect2/tests/test_api.py
2023-04-07 14:46:38 -06:00

254 lines
8.2 KiB
Python

"""Test bat detect module API."""
import os
from glob import glob
import numpy as np
import torch
from torch import nn
from batdetect2 import api
# Parent of the directory that contains this test file.
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Directory searched for the example ``.wav`` recordings used by the tests.
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")

# glob() yields matches in arbitrary, filesystem-dependent order. The tests
# index into this list (TEST_DATA[0], TEST_DATA[1]), so sort it to make the
# selected recordings the same on every machine and every run.
TEST_DATA = sorted(glob(os.path.join(TEST_DATA_DIR, "*.wav")))
def test_load_model_with_default_params():
    """Check what ``api.load_model()`` returns when called with no arguments.

    The call must yield a ``torch.nn.Module`` together with a parameter
    dictionary whose entries match the values asserted below.
    """
    network, settings = api.load_model()

    assert network is not None
    assert isinstance(network, nn.Module)

    assert settings is not None
    assert isinstance(settings, dict)

    # These keys must be present before their values are inspected.
    for required_key in ("model_name", "num_filters", "emb_dim", "ip_height"):
        assert required_key in settings

    expected_values = {
        "model_name": "Net2DFast",
        "num_filters": 128,
        "emb_dim": 0,
        "ip_height": 128,
        "resize_factor": 0.5,
    }
    for key, expected in expected_values.items():
        assert settings[key] == expected

    assert len(settings["class_names"]) == 17
def test_list_audio_files():
    """Check that ``api.list_audio_files`` reports three wav paths."""
    found = api.list_audio_files(TEST_DATA_DIR)
    wav_suffixes = (".wav", ".WAV")

    assert len(found) == 3
    # Every returned path must carry a wav extension (either letter case).
    for audio_path in found:
        assert audio_path.endswith(wav_suffixes)
def test_load_audio():
    """Check that ``api.load_audio`` returns a 1-D array of length 128000."""
    first_recording = TEST_DATA[0]
    samples = api.load_audio(first_recording)

    assert samples is not None
    assert isinstance(samples, np.ndarray)
    assert samples.shape == (128000,)
def test_generate_spectrogram():
    """Check the type and shape of ``api.generate_spectrogram`` output."""
    samples = api.load_audio(TEST_DATA[0])
    spec = api.generate_spectrogram(samples)

    assert spec is not None
    assert isinstance(spec, torch.Tensor)

    expected_shape = (1, 1, 128, 512)
    assert spec.shape == expected_shape
def test_get_default_config():
    """Check the dictionary returned by ``api.get_config()`` with no arguments."""
    defaults = api.get_config()

    assert defaults is not None
    assert isinstance(defaults, dict)

    # Entries compared by equality.
    expected_values = {
        "target_samp_rate": 256000,
        "fft_win_length": 0.002,
        "fft_overlap": 0.75,
        "resize_factor": 0.5,
        "spec_divide_factor": 32,
        "spec_height": 256,
        "spec_scale": "pcen",
        "detection_threshold": 0.01,
        "time_expansion": 1,
        "top_n": 3,
        "nms_kernel_size": 9,
        "max_freq": 120000,
        "min_freq": 10000,
        "nms_top_k_per_sec": 200,
        "chunk_size": 3,
    }
    for key, expected in expected_values.items():
        assert defaults[key] == expected

    # Entries that must be exactly the True / False / None singletons, so
    # they are compared by identity rather than equality.
    expected_singletons = {
        "denoise_spec_avg": True,
        "max_scale_spec": False,
        "scale_raw_audio": False,
        "return_raw_preds": False,
        "max_duration": None,
        "quiet": True,
        "cnn_features": False,
        "spec_features": False,
        "spec_slices": False,
    }
    for key, expected in expected_singletons.items():
        assert defaults[key] is expected

    assert len(defaults["class_names"]) == 17
def test_api_exposes_default_model():
    """Check that the ``api`` module carries a ready-made ``model`` attribute."""
    assert hasattr(api, "model")

    default_model = api.model
    assert isinstance(default_model, nn.Module)
    assert type(default_model).__name__ == "Net2DFast"

    # Attributes the exposed model is expected to carry, with their values.
    expected_attributes = {
        "num_classes": 17,
        "num_filts": 128,
        "emb_dim": 0,
        "ip_height_rs": 128,
        "resize_factor": 0.5,
    }
    for attr_name, expected in expected_attributes.items():
        assert getattr(default_model, attr_name) == expected
def test_api_exposes_default_config():
    """Check that the ``api`` module carries a ``config`` dictionary.

    The values asserted here are those expected of the module-level
    ``api.config`` attribute.
    """
    assert hasattr(api, "config")

    exposed = api.config
    assert isinstance(exposed, dict)

    # Entries compared by equality.
    expected_values = {
        "target_samp_rate": 256000,
        "fft_win_length": 0.002,
        "fft_overlap": 0.75,
        "resize_factor": 0.5,
        "spec_divide_factor": 32,
        "spec_height": 256,
        "spec_scale": "pcen",
        "detection_threshold": 0.01,
        "time_expansion": 1,
        "top_n": 3,
        "nms_kernel_size": 9,
        "max_freq": 120000,
        "min_freq": 10000,
        "nms_top_k_per_sec": 200,
        "chunk_size": 3,
    }
    for key, expected in expected_values.items():
        assert exposed[key] == expected

    # Entries that must be exactly the True / False / None singletons, so
    # they are compared by identity rather than equality.
    expected_singletons = {
        "denoise_spec_avg": True,
        "max_scale_spec": False,
        "scale_raw_audio": False,
        "return_raw_preds": False,
        "max_duration": None,
        "quiet": True,
        "cnn_features": False,
        "spec_features": False,
        "spec_slices": False,
    }
    for key, expected in expected_singletons.items():
        assert exposed[key] is expected

    assert len(exposed["class_names"]) == 17
def test_process_file_with_default_model():
    """Check the result of ``api.process_file`` called with only a file path."""
    recording = TEST_DATA[0]
    results = api.process_file(recording)

    assert results is not None
    assert isinstance(results, dict)
    assert "pred_dict" in results

    # With the default call none of the extra feature outputs are expected.
    extra_outputs = (
        "spec_feats",
        "spec_feat_names",
        "cnn_feats",
        "cnn_feat_names",
        "spec_slices",
    )
    for extra_key in extra_outputs:
        assert extra_key not in results

    # The prediction summary itself must be populated.
    summary = results["pred_dict"]
    assert isinstance(summary, dict)
    assert summary["id"] == os.path.basename(recording)
    assert summary["annotated"] is False
    assert summary["issues"] is False
    assert summary["notes"] == "Automatically generated."
    assert summary["time_exp"] == 1
    assert summary["duration"] == 0.5
    assert summary["class_name"] is not None
    assert len(summary["annotation"]) > 0
def test_process_spectrogram_with_default_model():
    """Check the predictions and features from ``api.process_spectrogram``."""
    samples = api.load_audio(TEST_DATA[0])
    spec = api.generate_spectrogram(samples)
    detections, feats = api.process_spectrogram(spec)

    assert detections is not None
    assert isinstance(detections, list)
    assert len(detections) > 0

    # Inspect the first prediction as a representative sample.
    first = detections[0]
    assert isinstance(first, dict)
    expected_fields = (
        "class",
        "class_prob",
        "det_prob",
        "start_time",
        "end_time",
        "low_freq",
        "high_freq",
    )
    for field in expected_fields:
        assert field in first

    # One feature entry is expected per prediction.
    assert feats is not None
    assert isinstance(feats, np.ndarray)
    assert len(feats) == len(detections)
def test_process_audio_with_default_model():
    """Check the three values returned by ``api.process_audio``."""
    samples = api.load_audio(TEST_DATA[0])
    detections, feats, spectrogram = api.process_audio(samples)

    assert detections is not None
    assert isinstance(detections, list)
    assert len(detections) > 0

    # Inspect the first prediction as a representative sample.
    first = detections[0]
    assert isinstance(first, dict)
    expected_fields = (
        "class",
        "class_prob",
        "det_prob",
        "start_time",
        "end_time",
        "low_freq",
        "high_freq",
    )
    for field in expected_fields:
        assert field in first

    # One feature entry is expected per prediction.
    assert feats is not None
    assert isinstance(feats, np.ndarray)
    assert len(feats) == len(detections)

    # The spectrogram is returned alongside the predictions.
    assert spectrogram is not None
    assert isinstance(spectrogram, torch.Tensor)
    expected_shape = (1, 1, 128, 512)
    assert spectrogram.shape == expected_shape
def test_postprocess_model_outputs():
    """Check ``api.postprocess`` applied to the output of ``api.model``."""
    # Produce the model outputs from the second example recording.
    samples = api.load_audio(TEST_DATA[1])
    spectrogram = api.generate_spectrogram(samples)
    raw_outputs = api.model(spectrogram)

    # Turn the model outputs into predictions and features.
    detections, feats = api.postprocess(raw_outputs)

    assert detections is not None
    assert isinstance(detections, list)
    assert len(detections) > 0

    # Inspect the first prediction as a representative sample.
    first = detections[0]
    assert isinstance(first, dict)
    expected_fields = (
        "class",
        "class_prob",
        "det_prob",
        "start_time",
        "end_time",
        "low_freq",
        "high_freq",
    )
    for field in expected_fields:
        assert field in first

    # Features: one row per prediction, 32 columns.
    assert feats is not None
    assert isinstance(feats, np.ndarray)
    assert feats.shape[0] == len(detections)
    assert feats.shape[1] == 32