"""Test bat detect module API.""" import os from glob import glob import numpy as np import torch from torch import nn from batdetect2 import api PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio") TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav")) def test_load_model_with_default_params(): """Test loading model with default parameters.""" model, params = api.load_model() assert model is not None assert isinstance(model, nn.Module) assert params is not None assert isinstance(params, dict) assert "model_name" in params assert "num_filters" in params assert "emb_dim" in params assert "ip_height" in params assert params["model_name"] == "Net2DFast" assert params["num_filters"] == 128 assert params["emb_dim"] == 0 assert params["ip_height"] == 128 assert params["resize_factor"] == 0.5 assert len(params["class_names"]) == 17 def test_list_audio_files(): """Test listing audio files.""" audio_files = api.list_audio_files(TEST_DATA_DIR) assert len(audio_files) == 3 assert all(path.endswith((".wav", ".WAV")) for path in audio_files) def test_load_audio(): """Test loading audio.""" audio = api.load_audio(TEST_DATA[0]) assert audio is not None assert isinstance(audio, np.ndarray) assert audio.shape == (128000,) def test_generate_spectrogram(): """Test generating spectrogram.""" audio = api.load_audio(TEST_DATA[0]) spectrogram = api.generate_spectrogram(audio) assert spectrogram is not None assert isinstance(spectrogram, torch.Tensor) assert spectrogram.shape == (1, 1, 128, 512) def test_get_default_config(): """Test getting default configuration.""" config = api.get_config() assert config is not None assert isinstance(config, dict) assert config["target_samp_rate"] == 256000 assert config["fft_win_length"] == 0.002 assert config["fft_overlap"] == 0.75 assert config["resize_factor"] == 0.5 assert config["spec_divide_factor"] == 32 assert config["spec_height"] == 256 assert config["spec_scale"] == "pcen" assert config["denoise_spec_avg"] is True assert config["max_scale_spec"] is False assert config["scale_raw_audio"] is False assert len(config["class_names"]) == 0 assert config["detection_threshold"] == 0.01 assert config["time_expansion"] == 1 assert config["top_n"] == 3 assert config["return_raw_preds"] is False assert config["max_duration"] is None assert config["nms_kernel_size"] == 9 assert config["max_freq"] == 120000 assert config["min_freq"] == 10000 assert config["nms_top_k_per_sec"] == 200 assert config["quiet"] is True assert config["chunk_size"] == 3 assert config["cnn_features"] is False assert config["spec_features"] is False assert config["spec_slices"] is False def test_api_exposes_default_model(): """Test that API exposes default model.""" assert hasattr(api, "model") assert isinstance(api.model, nn.Module) assert type(api.model).__name__ == "Net2DFast" # Check that model has expected attributes assert api.model.num_classes == 17 assert api.model.num_filts == 128 assert api.model.emb_dim == 0 assert api.model.ip_height_rs == 128 assert api.model.resize_factor == 0.5 def test_api_exposes_default_config(): """Test that API exposes default configuration.""" assert hasattr(api, "config") assert isinstance(api.config, dict) assert api.config["target_samp_rate"] == 256000 assert api.config["fft_win_length"] == 0.002 assert api.config["fft_overlap"] == 0.75 assert api.config["resize_factor"] == 0.5 assert api.config["spec_divide_factor"] == 32 assert api.config["spec_height"] == 256 assert api.config["spec_scale"] == "pcen" assert api.config["denoise_spec_avg"] is True assert api.config["max_scale_spec"] is False assert api.config["scale_raw_audio"] is False assert len(api.config["class_names"]) == 17 assert api.config["detection_threshold"] == 0.01 assert api.config["time_expansion"] == 1 assert api.config["top_n"] == 3 assert api.config["return_raw_preds"] is False assert api.config["max_duration"] is None assert api.config["nms_kernel_size"] == 9 assert api.config["max_freq"] == 120000 assert api.config["min_freq"] == 10000 assert api.config["nms_top_k_per_sec"] == 200 assert api.config["quiet"] is True assert api.config["chunk_size"] == 3 assert api.config["cnn_features"] is False assert api.config["spec_features"] is False assert api.config["spec_slices"] is False def test_process_file_with_default_model(): """Test processing file with model.""" predictions = api.process_file(TEST_DATA[0]) assert predictions is not None assert isinstance(predictions, dict) assert "pred_dict" in predictions # By default will not return other features assert "spec_feats" not in predictions assert "spec_feat_names" not in predictions assert "cnn_feats" not in predictions assert "cnn_feat_names" not in predictions assert "spec_slices" not in predictions # Check that predictions are returned assert isinstance(predictions["pred_dict"], dict) pred_dict = predictions["pred_dict"] assert pred_dict["id"] == os.path.basename(TEST_DATA[0]) assert pred_dict["annotated"] is False assert pred_dict["issues"] is False assert pred_dict["notes"] == "Automatically generated." assert pred_dict["time_exp"] == 1 assert pred_dict["duration"] == 0.5 assert pred_dict["class_name"] is not None assert len(pred_dict["annotation"]) > 0 def test_process_spectrogram_with_default_model(): """Test processing spectrogram with model.""" audio = api.load_audio(TEST_DATA[0]) spectrogram = api.generate_spectrogram(audio) predictions, features = api.process_spectrogram(spectrogram) assert predictions is not None assert isinstance(predictions, list) assert len(predictions) > 0 sample_pred = predictions[0] assert isinstance(sample_pred, dict) assert "class" in sample_pred assert "class_prob" in sample_pred assert "det_prob" in sample_pred assert "start_time" in sample_pred assert "end_time" in sample_pred assert "low_freq" in sample_pred assert "high_freq" in sample_pred assert features is not None assert isinstance(features, list) assert len(features) == 1 def test_process_audio_with_default_model(): """Test processing audio with model.""" audio = api.load_audio(TEST_DATA[0]) predictions, features, spec = api.process_audio(audio) assert predictions is not None assert isinstance(predictions, list) assert len(predictions) > 0 sample_pred = predictions[0] assert isinstance(sample_pred, dict) assert "class" in sample_pred assert "class_prob" in sample_pred assert "det_prob" in sample_pred assert "start_time" in sample_pred assert "end_time" in sample_pred assert "low_freq" in sample_pred assert "high_freq" in sample_pred assert features is not None assert isinstance(features, list) assert len(features) == 1 assert spec is not None assert isinstance(spec, torch.Tensor) assert spec.shape == (1, 1, 128, 512) def test_postprocess_model_outputs(): """Test postprocessing model outputs.""" # Load model outputs audio = api.load_audio(TEST_DATA[1]) spec = api.generate_spectrogram(audio) model_outputs = api.model(spec) # Postprocess outputs predictions, features = api.postprocess(model_outputs) assert predictions is not None assert isinstance(predictions, list) assert len(predictions) > 0 sample_pred = predictions[0] assert isinstance(sample_pred, dict) assert "class" in sample_pred assert "class_prob" in sample_pred assert "det_prob" in sample_pred assert "start_time" in sample_pred assert "end_time" in sample_pred assert "low_freq" in sample_pred assert "high_freq" in sample_pred assert features is not None assert isinstance(features, np.ndarray) assert features.shape[0] == len(predictions) assert features.shape[1] == 32