mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Changed public API to use trained model by default
This commit is contained in:
parent
a2deab9f3f
commit
b0d9576a24
@ -1,3 +1,4 @@
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
@ -6,6 +7,7 @@ import torch
|
||||
import bat_detect.utils.audio_utils as au
|
||||
import bat_detect.utils.detector_utils as du
|
||||
from bat_detect.detector.parameters import (
|
||||
DEFAULT_MODEL_PATH,
|
||||
DEFAULT_PROCESSING_CONFIGURATIONS,
|
||||
DEFAULT_SPECTROGRAM_PARAMETERS,
|
||||
TARGET_SAMPLERATE_HZ,
|
||||
@ -18,8 +20,8 @@ from bat_detect.types import (
|
||||
)
|
||||
from bat_detect.utils.detector_utils import list_audio_files, load_model
|
||||
|
||||
# Use GPU if available
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# Remove warnings from torch
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
|
||||
|
||||
__all__ = [
|
||||
"load_model",
|
||||
@ -30,9 +32,19 @@ __all__ = [
|
||||
"process_file",
|
||||
"process_spectrogram",
|
||||
"process_audio",
|
||||
"model",
|
||||
"config",
|
||||
]
|
||||
|
||||
|
||||
# Use GPU if available
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
# Default model
|
||||
MODEL, PARAMS = load_model(DEFAULT_MODEL_PATH, device=DEVICE)
|
||||
|
||||
|
||||
def get_config(**kwargs) -> ProcessingConfiguration:
|
||||
"""Get default processing configuration.
|
||||
|
||||
@ -41,15 +53,22 @@ def get_config(**kwargs) -> ProcessingConfiguration:
|
||||
return {**DEFAULT_PROCESSING_CONFIGURATIONS, **kwargs} # type: ignore
|
||||
|
||||
|
||||
# Default processing configuration
|
||||
CONFIG = get_config(**PARAMS)
|
||||
|
||||
|
||||
def load_audio(
|
||||
path: str,
|
||||
time_exp_fact: float = 1,
|
||||
target_samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
scale: bool = False,
|
||||
max_duration: Optional[float] = None,
|
||||
) -> Tuple[int, np.ndarray]:
|
||||
) -> np.ndarray:
|
||||
"""Load audio from file.
|
||||
|
||||
All audio will be resampled to the target sample rate. If the audio is
|
||||
longer than max_duration, it will be truncated to max_duration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : str
|
||||
@ -67,21 +86,20 @@ def load_audio(
|
||||
-------
|
||||
np.ndarray
|
||||
Audio data.
|
||||
int
|
||||
Sample rate.
|
||||
"""
|
||||
return au.load_audio(
|
||||
_, audio = au.load_audio(
|
||||
path,
|
||||
time_exp_fact,
|
||||
target_samp_rate,
|
||||
scale,
|
||||
max_duration,
|
||||
)
|
||||
return audio
|
||||
|
||||
|
||||
def generate_spectrogram(
|
||||
audio: np.ndarray,
|
||||
samp_rate: int,
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
config: Optional[SpectrogramParameters] = None,
|
||||
device: torch.device = DEVICE,
|
||||
) -> torch.Tensor:
|
||||
@ -91,8 +109,10 @@ def generate_spectrogram(
|
||||
----------
|
||||
audio : np.ndarray
|
||||
Audio data.
|
||||
samp_rate : int
|
||||
Sample rate.
|
||||
samp_rate : int, optional
|
||||
Sample rate. Defaults to 256000 which is the target sample rate of
|
||||
the default model. Only change if you loaded the audio with a
|
||||
different sample rate.
|
||||
config : Optional[SpectrogramParameters], optional
|
||||
Spectrogram parameters, by default None (uses default parameters).
|
||||
|
||||
@ -117,7 +137,7 @@ def generate_spectrogram(
|
||||
|
||||
def process_file(
|
||||
audio_file: str,
|
||||
model: DetectionModel,
|
||||
model: DetectionModel = MODEL,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
device: torch.device = DEVICE,
|
||||
) -> du.RunResults:
|
||||
@ -127,15 +147,15 @@ def process_file(
|
||||
----------
|
||||
audio_file : str
|
||||
Path to audio file.
|
||||
model : DetectionModel
|
||||
Detection model.
|
||||
model : DetectionModel, optional
|
||||
Detection model. Uses default model if not specified.
|
||||
config : Optional[ProcessingConfiguration], optional
|
||||
Processing configuration, by default None (uses default parameters).
|
||||
device : torch.device, optional
|
||||
Device to use, by default tries to use GPU if available.
|
||||
"""
|
||||
if config is None:
|
||||
config = DEFAULT_PROCESSING_CONFIGURATIONS
|
||||
config = CONFIG
|
||||
|
||||
return du.process_file(
|
||||
audio_file,
|
||||
@ -147,8 +167,8 @@ def process_file(
|
||||
|
||||
def process_spectrogram(
|
||||
spec: torch.Tensor,
|
||||
samp_rate: int,
|
||||
model: DetectionModel,
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
model: DetectionModel = MODEL,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
) -> Tuple[List[Annotation], List[np.ndarray]]:
|
||||
"""Process spectrogram with model.
|
||||
@ -157,10 +177,13 @@ def process_spectrogram(
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Spectrogram.
|
||||
samp_rate : int
|
||||
samp_rate : int, optional
|
||||
Sample rate of the audio from which the spectrogram was generated.
|
||||
model : DetectionModel
|
||||
Detection model.
|
||||
Defaults to 256000 which is the target sample rate of the default
|
||||
model. Only change if you generated the spectrogram with a different
|
||||
sample rate.
|
||||
model : DetectionModel, optional
|
||||
Detection model. Uses default model if not specified.
|
||||
config : Optional[ProcessingConfiguration], optional
|
||||
Processing configuration, by default None (uses default parameters).
|
||||
|
||||
@ -169,7 +192,7 @@ def process_spectrogram(
|
||||
DetectionResult
|
||||
"""
|
||||
if config is None:
|
||||
config = DEFAULT_PROCESSING_CONFIGURATIONS
|
||||
config = CONFIG
|
||||
|
||||
return du.process_spectrogram(
|
||||
spec,
|
||||
@ -181,8 +204,8 @@ def process_spectrogram(
|
||||
|
||||
def process_audio(
|
||||
audio: np.ndarray,
|
||||
samp_rate: int,
|
||||
model: DetectionModel,
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
model: DetectionModel = MODEL,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
device: torch.device = DEVICE,
|
||||
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
|
||||
@ -192,10 +215,11 @@ def process_audio(
|
||||
----------
|
||||
audio : np.ndarray
|
||||
Audio data.
|
||||
samp_rate : int
|
||||
Sample rate.
|
||||
model : DetectionModel
|
||||
Detection model.
|
||||
samp_rate : int, optional
|
||||
Sample rate, by default 256000. Only change if you loaded the audio
|
||||
with a different sample rate.
|
||||
model : DetectionModel, optional
|
||||
Detection model. Uses default model if not specified.
|
||||
config : Optional[ProcessingConfiguration], optional
|
||||
Processing configuration, by default None (uses default parameters).
|
||||
device : torch.device, optional
|
||||
@ -213,7 +237,7 @@ def process_audio(
|
||||
Spectrogram of the audio used for prediction.
|
||||
"""
|
||||
if config is None:
|
||||
config = DEFAULT_PROCESSING_CONFIGURATIONS
|
||||
config = CONFIG
|
||||
|
||||
return du.process_audio_array(
|
||||
audio,
|
||||
@ -222,3 +246,10 @@ def process_audio(
|
||||
config,
|
||||
device,
|
||||
)
|
||||
|
||||
|
||||
model: DetectionModel = MODEL
|
||||
"""Base detection model."""
|
||||
|
||||
config: ProcessingConfiguration = CONFIG
|
||||
"""Default processing configuration."""
|
||||
|
@ -1,14 +1,11 @@
|
||||
"""BatDetect2 command line interface."""
|
||||
import os
|
||||
import warnings
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning)
|
||||
import click
|
||||
|
||||
import click # noqa: E402
|
||||
|
||||
from bat_detect import api # noqa: E402
|
||||
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH # noqa: E402
|
||||
from bat_detect.utils.detector_utils import save_results_to_file # noqa: E402
|
||||
from bat_detect import api
|
||||
from bat_detect.detector.parameters import DEFAULT_MODEL_PATH
|
||||
from bat_detect.utils.detector_utils import save_results_to_file
|
||||
|
||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
@ -124,15 +121,14 @@ def detect(
|
||||
results_path = audio_file.replace(audio_dir, ann_dir)
|
||||
save_results_to_file(results, results_path)
|
||||
except (RuntimeError, ValueError, LookupError) as err:
|
||||
# TODO: Check what other errors can be thrown
|
||||
error_files.append(audio_file)
|
||||
click.echo(f"Error processing file!: {err}")
|
||||
click.secho(f"Error processing file!: {err}", fg="red")
|
||||
raise err
|
||||
|
||||
click.echo(f"\nResults saved to: {ann_dir}")
|
||||
|
||||
if len(error_files) > 0:
|
||||
click.echo("\nUnable to process the follow files:")
|
||||
click.secho("\nUnable to process the follow files:", fg="red")
|
||||
for err in error_files:
|
||||
click.echo(f" {err}")
|
||||
|
||||
|
@ -6,8 +6,6 @@ from matplotlib import patches
|
||||
from matplotlib.collections import PatchCollection
|
||||
from sklearn.metrics import confusion_matrix
|
||||
|
||||
from . import audio_utils as au
|
||||
|
||||
|
||||
def create_box_image(
|
||||
spec,
|
||||
|
@ -7,16 +7,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from bat_detect.api import (
|
||||
generate_spectrogram,
|
||||
get_config,
|
||||
list_audio_files,
|
||||
load_audio,
|
||||
load_model,
|
||||
process_audio,
|
||||
process_file,
|
||||
process_spectrogram,
|
||||
)
|
||||
from bat_detect import api
|
||||
|
||||
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
|
||||
@ -25,7 +16,7 @@ TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
|
||||
|
||||
def test_load_model_with_default_params():
|
||||
"""Test loading model with default parameters."""
|
||||
model, params = load_model()
|
||||
model, params = api.load_model()
|
||||
|
||||
assert model is not None
|
||||
assert isinstance(model, nn.Module)
|
||||
@ -50,7 +41,7 @@ def test_load_model_with_default_params():
|
||||
|
||||
def test_list_audio_files():
|
||||
"""Test listing audio files."""
|
||||
audio_files = list_audio_files(TEST_DATA_DIR)
|
||||
audio_files = api.list_audio_files(TEST_DATA_DIR)
|
||||
|
||||
assert len(audio_files) == 3
|
||||
assert all(path.endswith((".wav", ".WAV")) for path in audio_files)
|
||||
@ -58,18 +49,17 @@ def test_list_audio_files():
|
||||
|
||||
def test_load_audio():
|
||||
"""Test loading audio."""
|
||||
samplerate, audio = load_audio(TEST_DATA[0])
|
||||
audio = api.load_audio(TEST_DATA[0])
|
||||
|
||||
assert audio is not None
|
||||
assert samplerate == 256000
|
||||
assert isinstance(audio, np.ndarray)
|
||||
assert audio.shape == (128000,)
|
||||
|
||||
|
||||
def test_generate_spectrogram():
|
||||
"""Test generating spectrogram."""
|
||||
samplerate, audio = load_audio(TEST_DATA[0])
|
||||
spectrogram = generate_spectrogram(audio, samplerate)
|
||||
audio = api.load_audio(TEST_DATA[0])
|
||||
spectrogram = api.generate_spectrogram(audio)
|
||||
|
||||
assert spectrogram is not None
|
||||
assert isinstance(spectrogram, torch.Tensor)
|
||||
@ -78,7 +68,7 @@ def test_generate_spectrogram():
|
||||
|
||||
def test_get_default_config():
|
||||
"""Test getting default configuration."""
|
||||
config = get_config()
|
||||
config = api.get_config()
|
||||
|
||||
assert config is not None
|
||||
assert isinstance(config, dict)
|
||||
@ -110,11 +100,55 @@ def test_get_default_config():
|
||||
assert config["spec_slices"] is False
|
||||
|
||||
|
||||
def test_process_file_with_model():
|
||||
def test_api_exposes_default_model():
|
||||
"""Test that API exposes default model."""
|
||||
assert hasattr(api, "model")
|
||||
assert isinstance(api.model, nn.Module)
|
||||
assert type(api.model).__name__ == "Net2DFast"
|
||||
|
||||
# Check that model has expected attributes
|
||||
assert api.model.num_classes == 17
|
||||
assert api.model.num_filts == 128
|
||||
assert api.model.emb_dim == 0
|
||||
assert api.model.ip_height_rs == 128
|
||||
assert api.model.resize_factor == 0.5
|
||||
|
||||
|
||||
def test_api_exposes_default_config():
|
||||
"""Test that API exposes default configuration."""
|
||||
assert hasattr(api, "config")
|
||||
assert isinstance(api.config, dict)
|
||||
|
||||
assert api.config["target_samp_rate"] == 256000
|
||||
assert api.config["fft_win_length"] == 0.002
|
||||
assert api.config["fft_overlap"] == 0.75
|
||||
assert api.config["resize_factor"] == 0.5
|
||||
assert api.config["spec_divide_factor"] == 32
|
||||
assert api.config["spec_height"] == 256
|
||||
assert api.config["spec_scale"] == "pcen"
|
||||
assert api.config["denoise_spec_avg"] is True
|
||||
assert api.config["max_scale_spec"] is False
|
||||
assert api.config["scale_raw_audio"] is False
|
||||
assert len(api.config["class_names"]) == 17
|
||||
assert api.config["detection_threshold"] == 0.01
|
||||
assert api.config["time_expansion"] == 1
|
||||
assert api.config["top_n"] == 3
|
||||
assert api.config["return_raw_preds"] is False
|
||||
assert api.config["max_duration"] is None
|
||||
assert api.config["nms_kernel_size"] == 9
|
||||
assert api.config["max_freq"] == 120000
|
||||
assert api.config["min_freq"] == 10000
|
||||
assert api.config["nms_top_k_per_sec"] == 200
|
||||
assert api.config["quiet"] is True
|
||||
assert api.config["chunk_size"] == 3
|
||||
assert api.config["cnn_features"] is False
|
||||
assert api.config["spec_features"] is False
|
||||
assert api.config["spec_slices"] is False
|
||||
|
||||
|
||||
def test_process_file_with_default_model():
|
||||
"""Test processing file with model."""
|
||||
model, params = load_model()
|
||||
config = get_config(**params)
|
||||
predictions = process_file(TEST_DATA[0], model, config=config)
|
||||
predictions = api.process_file(TEST_DATA[0])
|
||||
|
||||
assert predictions is not None
|
||||
assert isinstance(predictions, dict)
|
||||
@ -141,18 +175,11 @@ def test_process_file_with_model():
|
||||
assert len(pred_dict["annotation"]) > 0
|
||||
|
||||
|
||||
def test_process_spectrogram_with_model():
|
||||
def test_process_spectrogram_with_default_model():
|
||||
"""Test processing spectrogram with model."""
|
||||
model, params = load_model()
|
||||
config = get_config(**params)
|
||||
samplerate, audio = load_audio(TEST_DATA[0])
|
||||
spectrogram = generate_spectrogram(audio, samplerate)
|
||||
predictions, features = process_spectrogram(
|
||||
spectrogram,
|
||||
samplerate,
|
||||
model,
|
||||
config=config,
|
||||
)
|
||||
audio = api.load_audio(TEST_DATA[0])
|
||||
spectrogram = api.generate_spectrogram(audio)
|
||||
predictions, features = api.process_spectrogram(spectrogram)
|
||||
|
||||
assert predictions is not None
|
||||
assert isinstance(predictions, list)
|
||||
@ -172,17 +199,10 @@ def test_process_spectrogram_with_model():
|
||||
assert len(features) == 1
|
||||
|
||||
|
||||
def test_process_audio_with_model():
|
||||
def test_process_audio_with_default_model():
|
||||
"""Test processing audio with model."""
|
||||
model, params = load_model()
|
||||
config = get_config(**params)
|
||||
samplerate, audio = load_audio(TEST_DATA[0])
|
||||
predictions, features, spec = process_audio(
|
||||
audio,
|
||||
samplerate,
|
||||
model,
|
||||
config=config,
|
||||
)
|
||||
audio = api.load_audio(TEST_DATA[0])
|
||||
predictions, features, spec = api.process_audio(audio)
|
||||
|
||||
assert predictions is not None
|
||||
assert isinstance(predictions, list)
|
||||
|
Loading…
Reference in New Issue
Block a user