batdetect2/bat_detect/api.py
2023-02-26 19:17:47 +00:00

256 lines
6.5 KiB
Python

import warnings
from typing import List, Optional, Tuple
import numpy as np
import torch
import bat_detect.utils.audio_utils as au
import bat_detect.utils.detector_utils as du
from bat_detect.detector.parameters import (
DEFAULT_MODEL_PATH,
DEFAULT_PROCESSING_CONFIGURATIONS,
DEFAULT_SPECTROGRAM_PARAMETERS,
TARGET_SAMPLERATE_HZ,
)
from bat_detect.types import (
Annotation,
DetectionModel,
ProcessingConfiguration,
SpectrogramParameters,
)
from bat_detect.utils.detector_utils import list_audio_files, load_model
# Silence UserWarning-category warnings whose originating module name matches
# "torch".
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
# Names that make up the public API of this module (and what a star-import
# of it exposes).
__all__ = [
"load_model",
"load_audio",
"list_audio_files",
"generate_spectrogram",
"get_config",
"process_file",
"process_spectrogram",
"process_audio",
"model",
"config",
]
# Select the CUDA device when torch reports one is available, otherwise fall
# back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the default model at import time. ``load_model`` returns a pair: the
# model itself (MODEL) and its accompanying parameters (PARAMS), which are
# later expanded into the default processing configuration.
MODEL, PARAMS = load_model(DEFAULT_MODEL_PATH, device=DEVICE)
def get_config(**kwargs) -> ProcessingConfiguration:
    """Build a processing configuration from the package defaults.

    A copy of ``DEFAULT_PROCESSING_CONFIGURATIONS`` is taken and every
    keyword argument is laid on top of it, so a passed value replaces the
    default stored under the same key.

    Parameters
    ----------
    **kwargs
        Configuration entries to override (or add to) the defaults.

    Returns
    -------
    ProcessingConfiguration
        A new dictionary; the defaults mapping itself is left unmodified.
    """
    merged = dict(DEFAULT_PROCESSING_CONFIGURATIONS)
    merged.update(kwargs)
    return merged  # type: ignore
# Default processing configuration: the package defaults overlaid with the
# parameters that ``load_model`` returned alongside the default model.
CONFIG = get_config(**PARAMS)
def load_audio(
    path: str,
    time_exp_fact: float = 1,
    target_samp_rate: int = TARGET_SAMPLERATE_HZ,
    scale: bool = False,
    max_duration: Optional[float] = None,
) -> np.ndarray:
    """Read an audio file and return its samples.

    Thin wrapper around ``bat_detect.utils.audio_utils.load_audio`` that
    forwards every argument unchanged and keeps only the audio array from
    the returned pair. According to the original documentation of this
    function, the audio is resampled to the target sample rate and is
    truncated when it is longer than ``max_duration``.

    Parameters
    ----------
    path : str
        Path to the audio file.
    time_exp_fact : float, optional
        Time expansion factor, by default 1.
    target_samp_rate : int, optional
        Target sample rate. Defaults to ``TARGET_SAMPLERATE_HZ``
        (documented as 256000).
    scale : bool, optional
        Scale audio to [-1, 1], by default False.
    max_duration : Optional[float], optional
        Maximum duration of the audio in seconds, by default None.

    Returns
    -------
    np.ndarray
        Audio data.
    """
    loader_args = (path, time_exp_fact, target_samp_rate, scale, max_duration)
    # The first element of the returned pair is not needed here
    # (presumably the sample rate -- confirm against audio_utils).
    _, waveform = au.load_audio(*loader_args)
    return waveform
def generate_spectrogram(
    audio: np.ndarray,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    config: Optional[SpectrogramParameters] = None,
    device: torch.device = DEVICE,
) -> torch.Tensor:
    """Compute the spectrogram of an audio array.

    Delegates to ``detector_utils.compute_spectrogram`` with
    ``return_np=False`` and returns only the middle element of the triple
    it produces.

    Parameters
    ----------
    audio : np.ndarray
        Audio data.
    samp_rate : int, optional
        Sample rate of ``audio``. Defaults to ``TARGET_SAMPLERATE_HZ``
        (documented as 256000, the target sample rate of the default
        model). Only change it if the audio was loaded with a different
        sample rate.
    config : Optional[SpectrogramParameters], optional
        Spectrogram parameters. When None, ``DEFAULT_SPECTROGRAM_PARAMETERS``
        is used.
    device : torch.device, optional
        Device forwarded to the spectrogram computation. Defaults to the
        module-level ``DEVICE``.

    Returns
    -------
    torch.Tensor
        Spectrogram.
    """
    spec_params = DEFAULT_SPECTROGRAM_PARAMETERS if config is None else config
    _, spectrogram, _ = du.compute_spectrogram(
        audio,
        samp_rate,
        spec_params,
        return_np=False,
        device=device,
    )
    return spectrogram
def process_file(
    audio_file: str,
    model: DetectionModel = MODEL,
    config: Optional[ProcessingConfiguration] = None,
    device: torch.device = DEVICE,
) -> du.RunResults:
    """Run a detection model over an audio file.

    Parameters
    ----------
    audio_file : str
        Path to the audio file.
    model : DetectionModel, optional
        Detection model. The default model is used when not specified.
    config : Optional[ProcessingConfiguration], optional
        Processing configuration. When None, the module-level default
        configuration ``CONFIG`` is used.
    device : torch.device, optional
        Device to run on. Defaults to the module-level ``DEVICE`` (CUDA when
        available, otherwise CPU).

    Returns
    -------
    du.RunResults
        Whatever ``detector_utils.process_file`` returns for this file.
    """
    if config is not None:
        return du.process_file(audio_file, model, config, device)

    # No configuration supplied: fall back to the module default.
    return du.process_file(audio_file, model, CONFIG, device)
def process_spectrogram(
    spec: torch.Tensor,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    model: DetectionModel = MODEL,
    config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], List[np.ndarray]]:
    """Run a detection model over a precomputed spectrogram.

    Parameters
    ----------
    spec : torch.Tensor
        Spectrogram.
    samp_rate : int, optional
        Sample rate of the audio from which the spectrogram was generated.
        Defaults to ``TARGET_SAMPLERATE_HZ`` (documented as 256000, the
        target sample rate of the default model). Only change it if the
        spectrogram was generated with a different sample rate.
    model : DetectionModel, optional
        Detection model. The default model is used when not specified.
    config : Optional[ProcessingConfiguration], optional
        Processing configuration. When None, the module-level default
        configuration ``CONFIG`` is used.

    Returns
    -------
    Tuple[List[Annotation], List[np.ndarray]]
        The result of ``detector_utils.process_spectrogram``; per the
        return annotation, a list of annotations and a list of arrays.
    """
    effective_config = CONFIG if config is None else config
    return du.process_spectrogram(spec, samp_rate, model, effective_config)
def process_audio(
    audio: np.ndarray,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    model: DetectionModel = MODEL,
    config: Optional[ProcessingConfiguration] = None,
    device: torch.device = DEVICE,
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
    """Run a detection model over an in-memory audio array.

    Forwards to ``detector_utils.process_audio_array``.

    Parameters
    ----------
    audio : np.ndarray
        Audio data.
    samp_rate : int, optional
        Sample rate of ``audio``. Defaults to ``TARGET_SAMPLERATE_HZ``
        (documented as 256000). Only change it if the audio was loaded with
        a different sample rate.
    model : DetectionModel, optional
        Detection model. The default model is used when not specified.
    config : Optional[ProcessingConfiguration], optional
        Processing configuration. When None, the module-level default
        configuration ``CONFIG`` is used.
    device : torch.device, optional
        Device to run on. Defaults to the module-level ``DEVICE`` (CUDA when
        available, otherwise CPU).

    Returns
    -------
    annotations : List[Annotation]
        List of predicted annotations.
    features : List[np.ndarray]
        List of extracted features for each annotation.
    spec : torch.Tensor
        Spectrogram of the audio used for prediction.
    """
    chosen_config = config
    if chosen_config is None:
        chosen_config = CONFIG

    call_args = (audio, samp_rate, model, chosen_config, device)
    return du.process_audio_array(*call_args)
# Lowercase module-level aliases of MODEL and CONFIG. Both names are listed
# in ``__all__``, so they are part of this module's public API.
model: DetectionModel = MODEL
"""Base detection model."""
config: ProcessingConfiguration = CONFIG
"""Default processing configuration."""