From b0d9576a2444879576a7a4a1e824f9f276b2cfc7 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Sun, 26 Feb 2023 19:17:47 +0000 Subject: [PATCH] Changed public API to use trained model by default --- bat_detect/api.py | 83 +++++++++++++++++--------- bat_detect/cli.py | 16 ++--- bat_detect/utils/plot_utils.py | 2 - tests/test_api.py | 104 ++++++++++++++++++++------------- 4 files changed, 125 insertions(+), 80 deletions(-) diff --git a/bat_detect/api.py b/bat_detect/api.py index bf44670..f05748d 100644 --- a/bat_detect/api.py +++ b/bat_detect/api.py @@ -1,3 +1,4 @@ +import warnings from typing import List, Optional, Tuple import numpy as np @@ -6,6 +7,7 @@ import torch import bat_detect.utils.audio_utils as au import bat_detect.utils.detector_utils as du from bat_detect.detector.parameters import ( + DEFAULT_MODEL_PATH, DEFAULT_PROCESSING_CONFIGURATIONS, DEFAULT_SPECTROGRAM_PARAMETERS, TARGET_SAMPLERATE_HZ, @@ -18,8 +20,8 @@ from bat_detect.types import ( ) from bat_detect.utils.detector_utils import list_audio_files, load_model -# Use GPU if available -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# Remove warnings from torch +warnings.filterwarnings("ignore", category=UserWarning, module="torch") __all__ = [ "load_model", @@ -30,9 +32,19 @@ __all__ = [ "process_file", "process_spectrogram", "process_audio", + "model", + "config", ] +# Use GPU if available +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# Default model +MODEL, PARAMS = load_model(DEFAULT_MODEL_PATH, device=DEVICE) + + def get_config(**kwargs) -> ProcessingConfiguration: """Get default processing configuration. @@ -41,15 +53,22 @@ def get_config(**kwargs) -> ProcessingConfiguration: return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore +# Default processing configuration +CONFIG = get_config(**PARAMS) + + def load_audio( path: str, time_exp_fact: float = 1, target_samp_rate: int = TARGET_SAMPLERATE_HZ, scale: bool = False, max_duration: Optional[float] = None, -) -> Tuple[int, np.ndarray]: +) -> np.ndarray: """Load audio from file. + All audio will be resampled to the target sample rate. If the audio is + longer than max_duration, it will be truncated to max_duration. + Parameters ---------- path : str @@ -67,21 +86,20 @@ def load_audio( ------- np.ndarray Audio data. - int - Sample rate. """ - return au.load_audio( + _, audio = au.load_audio( path, time_exp_fact, target_samp_rate, scale, max_duration, ) + return audio def generate_spectrogram( audio: np.ndarray, - samp_rate: int, + samp_rate: int = TARGET_SAMPLERATE_HZ, config: Optional[SpectrogramParameters] = None, device: torch.device = DEVICE, ) -> torch.Tensor: @@ -91,8 +109,10 @@ def generate_spectrogram( ---------- audio : np.ndarray Audio data. - samp_rate : int - Sample rate. + samp_rate : int, optional + Sample rate. Defaults to 256000 which is the target sample rate of + the default model. Only change if you loaded the audio with a + different sample rate. config : Optional[SpectrogramParameters], optional Spectrogram parameters, by default None (uses default parameters). @@ -117,7 +137,7 @@ def generate_spectrogram( def process_file( audio_file: str, - model: DetectionModel, + model: DetectionModel = MODEL, config: Optional[ProcessingConfiguration] = None, device: torch.device = DEVICE, ) -> du.RunResults: @@ -127,15 +147,15 @@ def process_file( ---------- audio_file : str Path to audio file. - model : DetectionModel - Detection model. + model : DetectionModel, optional + Detection model. Uses default model if not specified. config : Optional[ProcessingConfiguration], optional Processing configuration, by default None (uses default parameters). device : torch.device, optional Device to use, by default tries to use GPU if available. """ if config is None: - config = DEFAULT_PROCESSING_CONFIGURATIONS + config = CONFIG return du.process_file( audio_file, @@ -147,8 +167,8 @@ def process_file( def process_spectrogram( spec: torch.Tensor, - samp_rate: int, - model: DetectionModel, + samp_rate: int = TARGET_SAMPLERATE_HZ, + model: DetectionModel = MODEL, config: Optional[ProcessingConfiguration] = None, ) -> Tuple[List[Annotation], List[np.ndarray]]: """Process spectrogram with model. @@ -157,10 +177,13 @@ def process_spectrogram( ---------- spec : torch.Tensor Spectrogram. - samp_rate : int + samp_rate : int, optional Sample rate of the audio from which the spectrogram was generated. - model : DetectionModel - Detection model. + Defaults to 256000 which is the target sample rate of the default + model. Only change if you generated the spectrogram with a different + sample rate. + model : DetectionModel, optional + Detection model. Uses default model if not specified. config : Optional[ProcessingConfiguration], optional Processing configuration, by default None (uses default parameters). @@ -169,7 +192,7 @@ def process_spectrogram( DetectionResult """ if config is None: - config = DEFAULT_PROCESSING_CONFIGURATIONS + config = CONFIG return du.process_spectrogram( spec, @@ -181,8 +204,8 @@ def process_spectrogram( def process_audio( audio: np.ndarray, - samp_rate: int, - model: DetectionModel, + samp_rate: int = TARGET_SAMPLERATE_HZ, + model: DetectionModel = MODEL, config: Optional[ProcessingConfiguration] = None, device: torch.device = DEVICE, ) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]: @@ -192,10 +215,11 @@ def process_audio( ---------- audio : np.ndarray Audio data. - samp_rate : int - Sample rate. - model : DetectionModel - Detection model. + samp_rate : int, optional + Sample rate, by default 256000. Only change if you loaded the audio + with a different sample rate. + model : DetectionModel, optional + Detection model. Uses default model if not specified. config : Optional[ProcessingConfiguration], optional Processing configuration, by default None (uses default parameters). device : torch.device, optional @@ -213,7 +237,7 @@ def process_audio( Spectrogram of the audio used for prediction. """ if config is None: - config = DEFAULT_PROCESSING_CONFIGURATIONS + config = CONFIG return du.process_audio_array( audio, @@ -222,3 +246,10 @@ def process_audio( config, device, ) + + +model: DetectionModel = MODEL +"""Base detection model.""" + +config: ProcessingConfiguration = CONFIG +"""Default processing configuration.""" diff --git a/bat_detect/cli.py b/bat_detect/cli.py index c9c34db..29f4142 100644 --- a/bat_detect/cli.py +++ b/bat_detect/cli.py @@ -1,14 +1,11 @@ """BatDetect2 command line interface.""" import os -import warnings -warnings.filterwarnings("ignore", category=UserWarning) +import click -import click # noqa: E402 - -from bat_detect import api # noqa: E402 -from bat_detect.detector.parameters import DEFAULT_MODEL_PATH # noqa: E402 -from bat_detect.utils.detector_utils import save_results_to_file # noqa: E402 +from bat_detect import api +from bat_detect.detector.parameters import DEFAULT_MODEL_PATH +from bat_detect.utils.detector_utils import save_results_to_file CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -124,15 +121,14 @@ def detect( results_path = audio_file.replace(audio_dir, ann_dir) save_results_to_file(results, results_path) except (RuntimeError, ValueError, LookupError) as err: - # TODO: Check what other errors can be thrown error_files.append(audio_file) - click.echo(f"Error processing file!: {err}") + click.secho(f"Error processing file!: {err}", fg="red") raise err click.echo(f"\nResults saved to: {ann_dir}") if len(error_files) > 0: - click.echo("\nUnable to process the follow files:") + click.secho("\nUnable to process the follow files:", fg="red") for err in error_files: click.echo(f" {err}") diff --git a/bat_detect/utils/plot_utils.py b/bat_detect/utils/plot_utils.py index 6d732ec..6fcb387 100644 --- a/bat_detect/utils/plot_utils.py +++ b/bat_detect/utils/plot_utils.py @@ -6,8 +6,6 @@ from matplotlib import patches from matplotlib.collections import PatchCollection from sklearn.metrics import confusion_matrix -from . import audio_utils as au - def create_box_image( spec, diff --git a/tests/test_api.py b/tests/test_api.py index 1ee3231..52ba40b 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -7,16 +7,7 @@ import numpy as np import torch from torch import nn -from bat_detect.api import ( - generate_spectrogram, - get_config, - list_audio_files, - load_audio, - load_model, - process_audio, - process_file, - process_spectrogram, -) +from bat_detect import api PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio") @@ -25,7 +16,7 @@ TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav")) def test_load_model_with_default_params(): """Test loading model with default parameters.""" - model, params = load_model() + model, params = api.load_model() assert model is not None assert isinstance(model, nn.Module) @@ -50,7 +41,7 @@ def test_load_model_with_default_params(): def test_list_audio_files(): """Test listing audio files.""" - audio_files = list_audio_files(TEST_DATA_DIR) + audio_files = api.list_audio_files(TEST_DATA_DIR) assert len(audio_files) == 3 assert all(path.endswith((".wav", ".WAV")) for path in audio_files) @@ -58,18 +49,17 @@ def test_list_audio_files(): def test_load_audio(): """Test loading audio.""" - samplerate, audio = load_audio(TEST_DATA[0]) + audio = api.load_audio(TEST_DATA[0]) assert audio is not None - assert samplerate == 256000 assert isinstance(audio, np.ndarray) assert audio.shape == (128000,) def test_generate_spectrogram(): """Test generating spectrogram.""" - samplerate, audio = load_audio(TEST_DATA[0]) - spectrogram = generate_spectrogram(audio, samplerate) + audio = api.load_audio(TEST_DATA[0]) + spectrogram = api.generate_spectrogram(audio) assert spectrogram is not None assert isinstance(spectrogram, torch.Tensor) @@ -78,7 +68,7 @@ def test_generate_spectrogram(): def test_get_default_config(): """Test getting default configuration.""" - config = get_config() + config = api.get_config() assert config is not None assert isinstance(config, dict) @@ -110,11 +100,55 @@ def test_get_default_config(): assert config["spec_slices"] is False -def test_process_file_with_model(): +def test_api_exposes_default_model(): + """Test that API exposes default model.""" + assert hasattr(api, "model") + assert isinstance(api.model, nn.Module) + assert type(api.model).__name__ == "Net2DFast" + + # Check that model has expected attributes + assert api.model.num_classes == 17 + assert api.model.num_filts == 128 + assert api.model.emb_dim == 0 + assert api.model.ip_height_rs == 128 + assert api.model.resize_factor == 0.5 + + +def test_api_exposes_default_config(): + """Test that API exposes default configuration.""" + assert hasattr(api, "config") + assert isinstance(api.config, dict) + + assert api.config["target_samp_rate"] == 256000 + assert api.config["fft_win_length"] == 0.002 + assert api.config["fft_overlap"] == 0.75 + assert api.config["resize_factor"] == 0.5 + assert api.config["spec_divide_factor"] == 32 + assert api.config["spec_height"] == 256 + assert api.config["spec_scale"] == "pcen" + assert api.config["denoise_spec_avg"] is True + assert api.config["max_scale_spec"] is False + assert api.config["scale_raw_audio"] is False + assert len(api.config["class_names"]) == 17 + assert api.config["detection_threshold"] == 0.01 + assert api.config["time_expansion"] == 1 + assert api.config["top_n"] == 3 + assert api.config["return_raw_preds"] is False + assert api.config["max_duration"] is None + assert api.config["nms_kernel_size"] == 9 + assert api.config["max_freq"] == 120000 + assert api.config["min_freq"] == 10000 + assert api.config["nms_top_k_per_sec"] == 200 + assert api.config["quiet"] is True + assert api.config["chunk_size"] == 3 + assert api.config["cnn_features"] is False + assert api.config["spec_features"] is False + assert api.config["spec_slices"] is False + + +def test_process_file_with_default_model(): """Test processing file with model.""" - model, params = load_model() - config = get_config(**params) - predictions = process_file(TEST_DATA[0], model, config=config) + predictions = api.process_file(TEST_DATA[0]) assert predictions is not None assert isinstance(predictions, dict) @@ -141,18 +175,11 @@ def test_process_file_with_model(): assert len(pred_dict["annotation"]) > 0 -def test_process_spectrogram_with_model(): +def test_process_spectrogram_with_default_model(): """Test processing spectrogram with model.""" - model, params = load_model() - config = get_config(**params) - samplerate, audio = load_audio(TEST_DATA[0]) - spectrogram = generate_spectrogram(audio, samplerate) - predictions, features = process_spectrogram( - spectrogram, - samplerate, - model, - config=config, - ) + audio = api.load_audio(TEST_DATA[0]) + spectrogram = api.generate_spectrogram(audio) + predictions, features = api.process_spectrogram(spectrogram) assert predictions is not None assert isinstance(predictions, list) @@ -172,17 +199,10 @@ def test_process_spectrogram_with_model(): assert len(features) == 1 -def test_process_audio_with_model(): +def test_process_audio_with_default_model(): """Test processing audio with model.""" - model, params = load_model() - config = get_config(**params) - samplerate, audio = load_audio(TEST_DATA[0]) - predictions, features, spec = process_audio( - audio, - samplerate, - model, - config=config, - ) + audio = api.load_audio(TEST_DATA[0]) + predictions, features, spec = api.process_audio(audio) assert predictions is not None assert isinstance(predictions, list)