mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
265 lines
8.6 KiB
Python
265 lines
8.6 KiB
Python
"""Test bat detect module API."""
|
|
|
|
import os
|
|
from glob import glob
|
|
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
|
|
from batdetect2 import api
|
|
|
|
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
|
|
TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
|
|
|
|
|
|
def test_load_model_with_default_params():
|
|
"""Test loading model with default parameters."""
|
|
model, params = api.load_model()
|
|
|
|
assert model is not None
|
|
assert isinstance(model, nn.Module)
|
|
|
|
assert params is not None
|
|
assert isinstance(params, dict)
|
|
|
|
assert "model_name" in params
|
|
assert "num_filters" in params
|
|
assert "emb_dim" in params
|
|
assert "ip_height" in params
|
|
|
|
assert params["model_name"] == "Net2DFast"
|
|
assert params["num_filters"] == 128
|
|
assert params["emb_dim"] == 0
|
|
assert params["ip_height"] == 128
|
|
assert params["resize_factor"] == 0.5
|
|
assert len(params["class_names"]) == 17
|
|
|
|
|
|
def test_list_audio_files():
|
|
"""Test listing audio files."""
|
|
audio_files = api.list_audio_files(TEST_DATA_DIR)
|
|
|
|
assert len(audio_files) == 3
|
|
assert all(path.endswith((".wav", ".WAV")) for path in audio_files)
|
|
|
|
|
|
def test_load_audio():
|
|
"""Test loading audio."""
|
|
audio = api.load_audio(TEST_DATA[0])
|
|
|
|
assert audio is not None
|
|
assert isinstance(audio, np.ndarray)
|
|
assert audio.shape == (128000,)
|
|
|
|
|
|
def test_generate_spectrogram():
|
|
"""Test generating spectrogram."""
|
|
audio = api.load_audio(TEST_DATA[0])
|
|
spectrogram = api.generate_spectrogram(audio)
|
|
|
|
assert spectrogram is not None
|
|
assert isinstance(spectrogram, torch.Tensor)
|
|
assert spectrogram.shape == (1, 1, 128, 512)
|
|
|
|
|
|
def test_get_default_config():
|
|
"""Test getting default configuration."""
|
|
config = api.get_config()
|
|
|
|
assert config is not None
|
|
assert isinstance(config, dict)
|
|
|
|
assert config["target_samp_rate"] == 256000
|
|
assert config["fft_win_length"] == 0.002
|
|
assert config["fft_overlap"] == 0.75
|
|
assert config["resize_factor"] == 0.5
|
|
assert config["spec_divide_factor"] == 32
|
|
assert config["spec_height"] == 256
|
|
assert config["spec_scale"] == "pcen"
|
|
assert config["denoise_spec_avg"] is True
|
|
assert config["max_scale_spec"] is False
|
|
assert config["scale_raw_audio"] is False
|
|
assert len(config["class_names"]) == 17
|
|
assert config["detection_threshold"] == 0.01
|
|
assert config["time_expansion"] == 1
|
|
assert config["top_n"] == 3
|
|
assert config["return_raw_preds"] is False
|
|
assert config["max_duration"] is None
|
|
assert config["nms_kernel_size"] == 9
|
|
assert config["max_freq"] == 120000
|
|
assert config["min_freq"] == 10000
|
|
assert config["nms_top_k_per_sec"] == 200
|
|
assert config["quiet"] is True
|
|
assert config["chunk_size"] == 3
|
|
assert config["cnn_features"] is False
|
|
assert config["spec_features"] is False
|
|
assert config["spec_slices"] is False
|
|
|
|
|
|
def test_api_exposes_default_model():
|
|
"""Test that API exposes default model."""
|
|
assert hasattr(api, "model")
|
|
assert isinstance(api.model, nn.Module)
|
|
assert type(api.model).__name__ == "Net2DFast"
|
|
|
|
# Check that model has expected attributes
|
|
assert api.model.num_classes == 17
|
|
assert api.model.num_filts == 128
|
|
assert api.model.emb_dim == 0
|
|
assert api.model.ip_height_rs == 128
|
|
assert api.model.resize_factor == 0.5
|
|
|
|
|
|
def test_api_exposes_default_config():
|
|
"""Test that API exposes default configuration."""
|
|
assert hasattr(api, "config")
|
|
assert isinstance(api.config, dict)
|
|
|
|
assert api.config["target_samp_rate"] == 256000
|
|
assert api.config["fft_win_length"] == 0.002
|
|
assert api.config["fft_overlap"] == 0.75
|
|
assert api.config["resize_factor"] == 0.5
|
|
assert api.config["spec_divide_factor"] == 32
|
|
assert api.config["spec_height"] == 256
|
|
assert api.config["spec_scale"] == "pcen"
|
|
assert api.config["denoise_spec_avg"] is True
|
|
assert api.config["max_scale_spec"] is False
|
|
assert api.config["scale_raw_audio"] is False
|
|
assert len(api.config["class_names"]) == 17
|
|
assert api.config["detection_threshold"] == 0.01
|
|
assert api.config["time_expansion"] == 1
|
|
assert api.config["top_n"] == 3
|
|
assert api.config["return_raw_preds"] is False
|
|
assert api.config["max_duration"] is None
|
|
assert api.config["nms_kernel_size"] == 9
|
|
assert api.config["max_freq"] == 120000
|
|
assert api.config["min_freq"] == 10000
|
|
assert api.config["nms_top_k_per_sec"] == 200
|
|
assert api.config["quiet"] is True
|
|
assert api.config["chunk_size"] == 3
|
|
assert api.config["cnn_features"] is False
|
|
assert api.config["spec_features"] is False
|
|
assert api.config["spec_slices"] is False
|
|
|
|
|
|
def test_process_file_with_default_model():
|
|
"""Test processing file with model."""
|
|
predictions = api.process_file(TEST_DATA[0])
|
|
|
|
assert predictions is not None
|
|
assert isinstance(predictions, dict)
|
|
|
|
assert "pred_dict" in predictions
|
|
|
|
# By default will not return other features
|
|
assert "spec_feats" not in predictions
|
|
assert "spec_feat_names" not in predictions
|
|
assert "cnn_feats" not in predictions
|
|
assert "cnn_feat_names" not in predictions
|
|
assert "spec_slices" not in predictions
|
|
|
|
# Check that predictions are returned
|
|
assert isinstance(predictions["pred_dict"], dict)
|
|
pred_dict = predictions["pred_dict"]
|
|
assert pred_dict["id"] == os.path.basename(TEST_DATA[0])
|
|
assert pred_dict["annotated"] is False
|
|
assert pred_dict["issues"] is False
|
|
assert pred_dict["notes"] == "Automatically generated."
|
|
assert pred_dict["time_exp"] == 1
|
|
assert pred_dict["duration"] == 0.5
|
|
assert pred_dict["class_name"] is not None
|
|
assert len(pred_dict["annotation"]) > 0
|
|
|
|
|
|
def test_process_spectrogram_with_default_model():
|
|
"""Test processing spectrogram with model."""
|
|
audio = api.load_audio(TEST_DATA[0])
|
|
spectrogram = api.generate_spectrogram(audio)
|
|
predictions, features = api.process_spectrogram(spectrogram)
|
|
|
|
assert predictions is not None
|
|
assert isinstance(predictions, list)
|
|
assert len(predictions) > 0
|
|
sample_pred = predictions[0]
|
|
assert isinstance(sample_pred, dict)
|
|
assert "class" in sample_pred
|
|
assert "class_prob" in sample_pred
|
|
assert "det_prob" in sample_pred
|
|
assert "start_time" in sample_pred
|
|
assert "end_time" in sample_pred
|
|
assert "low_freq" in sample_pred
|
|
assert "high_freq" in sample_pred
|
|
|
|
assert features is not None
|
|
assert isinstance(features, np.ndarray)
|
|
assert len(features) == len(predictions)
|
|
|
|
|
|
def test_process_audio_with_default_model():
|
|
"""Test processing audio with model."""
|
|
audio = api.load_audio(TEST_DATA[0])
|
|
predictions, features, spec = api.process_audio(audio)
|
|
|
|
assert predictions is not None
|
|
assert isinstance(predictions, list)
|
|
assert len(predictions) > 0
|
|
sample_pred = predictions[0]
|
|
assert isinstance(sample_pred, dict)
|
|
assert "class" in sample_pred
|
|
assert "class_prob" in sample_pred
|
|
assert "det_prob" in sample_pred
|
|
assert "start_time" in sample_pred
|
|
assert "end_time" in sample_pred
|
|
assert "low_freq" in sample_pred
|
|
assert "high_freq" in sample_pred
|
|
|
|
assert features is not None
|
|
assert isinstance(features, np.ndarray)
|
|
assert len(features) == len(predictions)
|
|
|
|
assert spec is not None
|
|
assert isinstance(spec, torch.Tensor)
|
|
assert spec.shape == (1, 1, 128, 512)
|
|
|
|
|
|
def test_postprocess_model_outputs():
|
|
"""Test postprocessing model outputs."""
|
|
# Load model outputs
|
|
audio = api.load_audio(TEST_DATA[1])
|
|
spec = api.generate_spectrogram(audio)
|
|
model_outputs = api.model(spec)
|
|
|
|
# Postprocess outputs
|
|
predictions, features = api.postprocess(model_outputs)
|
|
|
|
assert predictions is not None
|
|
assert isinstance(predictions, list)
|
|
assert len(predictions) > 0
|
|
sample_pred = predictions[0]
|
|
assert isinstance(sample_pred, dict)
|
|
assert "class" in sample_pred
|
|
assert "class_prob" in sample_pred
|
|
assert "det_prob" in sample_pred
|
|
assert "start_time" in sample_pred
|
|
assert "end_time" in sample_pred
|
|
assert "low_freq" in sample_pred
|
|
assert "high_freq" in sample_pred
|
|
|
|
assert features is not None
|
|
assert isinstance(features, np.ndarray)
|
|
assert features.shape[0] == len(predictions)
|
|
assert features.shape[1] == 32
|
|
|
|
|
|
def test_process_file_with_spec_slices():
|
|
"""Test process file returns spec slices."""
|
|
config = api.get_config(spec_slices=True)
|
|
results = api.process_file(TEST_DATA[0], config=config)
|
|
detections = results["pred_dict"]["annotation"]
|
|
|
|
assert "spec_slices" in results
|
|
assert isinstance(results["spec_slices"], list)
|
|
assert len(results["spec_slices"]) == len(detections)
|