Changed public API to use trained model by default

This commit is contained in:
Santiago Martinez 2023-02-26 19:17:47 +00:00
parent a2deab9f3f
commit b0d9576a24
4 changed files with 125 additions and 80 deletions

View File

@ -1,3 +1,4 @@
import warnings
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import numpy as np import numpy as np
@ -6,6 +7,7 @@ import torch
import bat_detect.utils.audio_utils as au import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du import bat_detect.utils.detector_utils as du
from bat_detect.detector.parameters import ( from bat_detect.detector.parameters import (
DEFAULT_MODEL_PATH,
DEFAULT_PROCESSING_CONFIGURATIONS, DEFAULT_PROCESSING_CONFIGURATIONS,
DEFAULT_SPECTROGRAM_PARAMETERS, DEFAULT_SPECTROGRAM_PARAMETERS,
TARGET_SAMPLERATE_HZ, TARGET_SAMPLERATE_HZ,
@ -18,8 +20,8 @@ from bat_detect.types import (
) )
from bat_detect.utils.detector_utils import list_audio_files, load_model from bat_detect.utils.detector_utils import list_audio_files, load_model
# Use GPU if available # Remove warnings from torch
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") warnings.filterwarnings("ignore", category=UserWarning, module="torch")
__all__ = [ __all__ = [
"load_model", "load_model",
@ -30,9 +32,19 @@ __all__ = [
"process_file", "process_file",
"process_spectrogram", "process_spectrogram",
"process_audio", "process_audio",
"model",
"config",
] ]
# Use GPU if available
# DEVICE is also the default ``device`` argument of generate_spectrogram,
# process_file and process_audio.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Default model
# Loaded by a module-level statement, i.e. when this module is imported.
# MODEL is the default ``model`` argument of process_file, process_spectrogram
# and process_audio; PARAMS is passed to get_config to build CONFIG.
MODEL, PARAMS = load_model(DEFAULT_MODEL_PATH, device=DEVICE)
def get_config(**kwargs) -> ProcessingConfiguration: def get_config(**kwargs) -> ProcessingConfiguration:
"""Get default processing configuration. """Get default processing configuration.
@ -41,15 +53,22 @@ def get_config(**kwargs) -> ProcessingConfiguration:
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore
# Default processing configuration
# DEFAULT_PROCESSING_CONFIGURATIONS overlaid with the PARAMS returned by
# load_model for the default model (PARAMS wins on duplicate keys). The
# process_* functions fall back to this when called with config=None.
CONFIG = get_config(**PARAMS)
def load_audio( def load_audio(
path: str, path: str,
time_exp_fact: float = 1, time_exp_fact: float = 1,
target_samp_rate: int = TARGET_SAMPLERATE_HZ, target_samp_rate: int = TARGET_SAMPLERATE_HZ,
scale: bool = False, scale: bool = False,
max_duration: Optional[float] = None, max_duration: Optional[float] = None,
) -> Tuple[int, np.ndarray]: ) -> np.ndarray:
"""Load audio from file. """Load audio from file.
All audio will be resampled to the target sample rate. If the audio is
longer than max_duration, it will be truncated to max_duration.
Parameters Parameters
---------- ----------
path : str path : str
@ -67,21 +86,20 @@ def load_audio(
------- -------
np.ndarray np.ndarray
Audio data. Audio data.
int
Sample rate.
""" """
return au.load_audio( _, audio = au.load_audio(
path, path,
time_exp_fact, time_exp_fact,
target_samp_rate, target_samp_rate,
scale, scale,
max_duration, max_duration,
) )
return audio
def generate_spectrogram( def generate_spectrogram(
audio: np.ndarray, audio: np.ndarray,
samp_rate: int, samp_rate: int = TARGET_SAMPLERATE_HZ,
config: Optional[SpectrogramParameters] = None, config: Optional[SpectrogramParameters] = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
) -> torch.Tensor: ) -> torch.Tensor:
@ -91,8 +109,10 @@ def generate_spectrogram(
---------- ----------
audio : np.ndarray audio : np.ndarray
Audio data. Audio data.
samp_rate : int samp_rate : int, optional
Sample rate. Sample rate. Defaults to 256000 which is the target sample rate of
the default model. Only change if you loaded the audio with a
different sample rate.
config : Optional[SpectrogramParameters], optional config : Optional[SpectrogramParameters], optional
Spectrogram parameters, by default None (uses default parameters). Spectrogram parameters, by default None (uses default parameters).
@ -117,7 +137,7 @@ def generate_spectrogram(
def process_file( def process_file(
audio_file: str, audio_file: str,
model: DetectionModel, model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None, config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
) -> du.RunResults: ) -> du.RunResults:
@ -127,15 +147,15 @@ def process_file(
---------- ----------
audio_file : str audio_file : str
Path to audio file. Path to audio file.
model : DetectionModel model : DetectionModel, optional
Detection model. Detection model. Uses default model if not specified.
config : Optional[ProcessingConfiguration], optional config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters). Processing configuration, by default None (uses default parameters).
device : torch.device, optional device : torch.device, optional
Device to use, by default tries to use GPU if available. Device to use, by default tries to use GPU if available.
""" """
if config is None: if config is None:
config = DEFAULT_PROCESSING_CONFIGURATIONS config = CONFIG
return du.process_file( return du.process_file(
audio_file, audio_file,
@ -147,8 +167,8 @@ def process_file(
def process_spectrogram( def process_spectrogram(
spec: torch.Tensor, spec: torch.Tensor,
samp_rate: int, samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel, model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None, config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], List[np.ndarray]]: ) -> Tuple[List[Annotation], List[np.ndarray]]:
"""Process spectrogram with model. """Process spectrogram with model.
@ -157,10 +177,13 @@ def process_spectrogram(
---------- ----------
spec : torch.Tensor spec : torch.Tensor
Spectrogram. Spectrogram.
samp_rate : int samp_rate : int, optional
Sample rate of the audio from which the spectrogram was generated. Sample rate of the audio from which the spectrogram was generated.
model : DetectionModel Defaults to 256000 which is the target sample rate of the default
Detection model. model. Only change if you generated the spectrogram with a different
sample rate.
model : DetectionModel, optional
Detection model. Uses default model if not specified.
config : Optional[ProcessingConfiguration], optional config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters). Processing configuration, by default None (uses default parameters).
@ -169,7 +192,7 @@ def process_spectrogram(
DetectionResult DetectionResult
""" """
if config is None: if config is None:
config = DEFAULT_PROCESSING_CONFIGURATIONS config = CONFIG
return du.process_spectrogram( return du.process_spectrogram(
spec, spec,
@ -181,8 +204,8 @@ def process_spectrogram(
def process_audio( def process_audio(
audio: np.ndarray, audio: np.ndarray,
samp_rate: int, samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel, model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None, config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]: ) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
@ -192,10 +215,11 @@ def process_audio(
---------- ----------
audio : np.ndarray audio : np.ndarray
Audio data. Audio data.
samp_rate : int samp_rate : int, optional
Sample rate. Sample rate, by default 256000. Only change if you loaded the audio
model : DetectionModel with a different sample rate.
Detection model. model : DetectionModel, optional
Detection model. Uses default model if not specified.
config : Optional[ProcessingConfiguration], optional config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters). Processing configuration, by default None (uses default parameters).
device : torch.device, optional device : torch.device, optional
@ -213,7 +237,7 @@ def process_audio(
Spectrogram of the audio used for prediction. Spectrogram of the audio used for prediction.
""" """
if config is None: if config is None:
config = DEFAULT_PROCESSING_CONFIGURATIONS config = CONFIG
return du.process_audio_array( return du.process_audio_array(
audio, audio,
@ -222,3 +246,10 @@ def process_audio(
config, config,
device, device,
) )
# Lower-case public alias of MODEL, the detection model loaded from
# DEFAULT_MODEL_PATH; exported through __all__ as "model".
model: DetectionModel = MODEL
"""Base detection model."""
# Lower-case public alias of CONFIG, the configuration built from the default
# model's parameters; exported through __all__ as "config".
config: ProcessingConfiguration = CONFIG
"""Default processing configuration."""

View File

@ -1,14 +1,11 @@
"""BatDetect2 command line interface.""" """BatDetect2 command line interface."""
import os import os
import warnings
warnings.filterwarnings("ignore", category=UserWarning) import click
import click # noqa: E402 from bat_detect import api
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
from bat_detect import api # noqa: E402 from bat_detect.utils.detector_utils import save_results_to_file
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH # noqa: E402
from bat_detect.utils.detector_utils import save_results_to_file # noqa: E402
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
@ -124,15 +121,14 @@ def detect(
results_path = audio_file.replace(audio_dir, ann_dir) results_path = audio_file.replace(audio_dir, ann_dir)
save_results_to_file(results, results_path) save_results_to_file(results, results_path)
except (RuntimeError, ValueError, LookupError) as err: except (RuntimeError, ValueError, LookupError) as err:
# TODO: Check what other errors can be thrown
error_files.append(audio_file) error_files.append(audio_file)
click.echo(f"Error processing file!: {err}") click.secho(f"Error processing file!: {err}", fg="red")
raise err raise err
click.echo(f"\nResults saved to: {ann_dir}") click.echo(f"\nResults saved to: {ann_dir}")
if len(error_files) > 0: if len(error_files) > 0:
click.echo("\nUnable to process the follow files:") click.secho("\nUnable to process the follow files:", fg="red")
for err in error_files: for err in error_files:
click.echo(f" {err}") click.echo(f" {err}")

View File

@ -6,8 +6,6 @@ from matplotlib import patches
from matplotlib.collections import PatchCollection from matplotlib.collections import PatchCollection
from sklearn.metrics import confusion_matrix from sklearn.metrics import confusion_matrix
from . import audio_utils as au
def create_box_image( def create_box_image(
spec, spec,

View File

@ -7,16 +7,7 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
from bat_detect.api import ( from bat_detect import api
generate_spectrogram,
get_config,
list_audio_files,
load_audio,
load_model,
process_audio,
process_file,
process_spectrogram,
)
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio") TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
@ -25,7 +16,7 @@ TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
def test_load_model_with_default_params(): def test_load_model_with_default_params():
"""Test loading model with default parameters.""" """Test loading model with default parameters."""
model, params = load_model() model, params = api.load_model()
assert model is not None assert model is not None
assert isinstance(model, nn.Module) assert isinstance(model, nn.Module)
@ -50,7 +41,7 @@ def test_load_model_with_default_params():
def test_list_audio_files(): def test_list_audio_files():
"""Test listing audio files.""" """Test listing audio files."""
audio_files = list_audio_files(TEST_DATA_DIR) audio_files = api.list_audio_files(TEST_DATA_DIR)
assert len(audio_files) == 3 assert len(audio_files) == 3
assert all(path.endswith((".wav", ".WAV")) for path in audio_files) assert all(path.endswith((".wav", ".WAV")) for path in audio_files)
@ -58,18 +49,17 @@ def test_list_audio_files():
def test_load_audio(): def test_load_audio():
"""Test loading audio.""" """Test loading audio."""
samplerate, audio = load_audio(TEST_DATA[0]) audio = api.load_audio(TEST_DATA[0])
assert audio is not None assert audio is not None
assert samplerate == 256000
assert isinstance(audio, np.ndarray) assert isinstance(audio, np.ndarray)
assert audio.shape == (128000,) assert audio.shape == (128000,)
def test_generate_spectrogram(): def test_generate_spectrogram():
"""Test generating spectrogram.""" """Test generating spectrogram."""
samplerate, audio = load_audio(TEST_DATA[0]) audio = api.load_audio(TEST_DATA[0])
spectrogram = generate_spectrogram(audio, samplerate) spectrogram = api.generate_spectrogram(audio)
assert spectrogram is not None assert spectrogram is not None
assert isinstance(spectrogram, torch.Tensor) assert isinstance(spectrogram, torch.Tensor)
@ -78,7 +68,7 @@ def test_generate_spectrogram():
def test_get_default_config(): def test_get_default_config():
"""Test getting default configuration.""" """Test getting default configuration."""
config = get_config() config = api.get_config()
assert config is not None assert config is not None
assert isinstance(config, dict) assert isinstance(config, dict)
@ -110,11 +100,55 @@ def test_get_default_config():
assert config["spec_slices"] is False assert config["spec_slices"] is False
def test_process_file_with_model(): def test_api_exposes_default_model():
"""Test that API exposes default model."""
assert hasattr(api, "model")
assert isinstance(api.model, nn.Module)
assert type(api.model).__name__ == "Net2DFast"
# Check that model has expected attributes
assert api.model.num_classes == 17
assert api.model.num_filts == 128
assert api.model.emb_dim == 0
assert api.model.ip_height_rs == 128
assert api.model.resize_factor == 0.5
def test_api_exposes_default_config():
    """Check that ``api.config`` exists and holds the expected defaults."""
    assert hasattr(api, "config")
    settings = api.config
    assert isinstance(settings, dict)

    def matches(actual, expected):
        # Booleans and None are compared by identity, everything else by
        # equality.
        if expected is None or isinstance(expected, bool):
            return actual is expected
        return actual == expected

    # Spectrogram / audio settings.
    spectrogram_defaults = (
        ("target_samp_rate", 256000),
        ("fft_win_length", 0.002),
        ("fft_overlap", 0.75),
        ("resize_factor", 0.5),
        ("spec_divide_factor", 32),
        ("spec_height", 256),
        ("spec_scale", "pcen"),
        ("denoise_spec_avg", True),
        ("max_scale_spec", False),
        ("scale_raw_audio", False),
    )
    for key, expected in spectrogram_defaults:
        assert matches(settings[key], expected)

    assert len(settings["class_names"]) == 17

    # Detection / post-processing settings.
    processing_defaults = (
        ("detection_threshold", 0.01),
        ("time_expansion", 1),
        ("top_n", 3),
        ("return_raw_preds", False),
        ("max_duration", None),
        ("nms_kernel_size", 9),
        ("max_freq", 120000),
        ("min_freq", 10000),
        ("nms_top_k_per_sec", 200),
        ("quiet", True),
        ("chunk_size", 3),
        ("cnn_features", False),
        ("spec_features", False),
        ("spec_slices", False),
    )
    for key, expected in processing_defaults:
        assert matches(settings[key], expected)
def test_process_file_with_default_model():
"""Test processing file with model.""" """Test processing file with model."""
model, params = load_model() predictions = api.process_file(TEST_DATA[0])
config = get_config(**params)
predictions = process_file(TEST_DATA[0], model, config=config)
assert predictions is not None assert predictions is not None
assert isinstance(predictions, dict) assert isinstance(predictions, dict)
@ -141,18 +175,11 @@ def test_process_file_with_model():
assert len(pred_dict["annotation"]) > 0 assert len(pred_dict["annotation"]) > 0
def test_process_spectrogram_with_model(): def test_process_spectrogram_with_default_model():
"""Test processing spectrogram with model.""" """Test processing spectrogram with model."""
model, params = load_model() audio = api.load_audio(TEST_DATA[0])
config = get_config(**params) spectrogram = api.generate_spectrogram(audio)
samplerate, audio = load_audio(TEST_DATA[0]) predictions, features = api.process_spectrogram(spectrogram)
spectrogram = generate_spectrogram(audio, samplerate)
predictions, features = process_spectrogram(
spectrogram,
samplerate,
model,
config=config,
)
assert predictions is not None assert predictions is not None
assert isinstance(predictions, list) assert isinstance(predictions, list)
@ -172,17 +199,10 @@ def test_process_spectrogram_with_model():
assert len(features) == 1 assert len(features) == 1
def test_process_audio_with_model(): def test_process_audio_with_default_model():
"""Test processing audio with model.""" """Test processing audio with model."""
model, params = load_model() audio = api.load_audio(TEST_DATA[0])
config = get_config(**params) predictions, features, spec = api.process_audio(audio)
samplerate, audio = load_audio(TEST_DATA[0])
predictions, features, spec = process_audio(
audio,
samplerate,
model,
config=config,
)
assert predictions is not None assert predictions is not None
assert isinstance(predictions, list) assert isinstance(predictions, list)