mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
256 lines
6.5 KiB
Python
256 lines
6.5 KiB
Python
import warnings
|
|
from typing import List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
import bat_detect.utils.audio_utils as au
|
|
import bat_detect.utils.detector_utils as du
|
|
from bat_detect.detector.parameters import (
|
|
DEFAULT_MODEL_PATH,
|
|
DEFAULT_PROCESSING_CONFIGURATIONS,
|
|
DEFAULT_SPECTROGRAM_PARAMETERS,
|
|
TARGET_SAMPLERATE_HZ,
|
|
)
|
|
from bat_detect.types import (
|
|
Annotation,
|
|
DetectionModel,
|
|
ProcessingConfiguration,
|
|
SpectrogramParameters,
|
|
)
|
|
from bat_detect.utils.detector_utils import list_audio_files, load_model
|
|
|
|
# Silence UserWarning messages that originate from torch modules.
warnings.filterwarnings("ignore", category=UserWarning, module="torch")


# Names that make up the public interface of this module.
__all__ = [
    "load_model",
    "load_audio",
    "list_audio_files",
    "generate_spectrogram",
    "get_config",
    "process_file",
    "process_spectrogram",
    "process_audio",
    "model",
    "config",
]


# Run on CUDA when torch reports it as available, otherwise on the CPU.
DEVICE = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)


# Default model, loaded once at import time onto DEVICE. PARAMS is the second
# value returned by load_model and is used below to build the default
# processing configuration.
MODEL, PARAMS = load_model(DEFAULT_MODEL_PATH, device=DEVICE)
|
|
|
|
|
|
def get_config(**kwargs) -> ProcessingConfiguration:
    """Build a processing configuration from the defaults.

    Starts from a shallow copy of ``DEFAULT_PROCESSING_CONFIGURATIONS`` and
    applies the given keyword arguments on top, so any key passed here
    replaces the default value stored under the same key.

    Parameters
    ----------
    **kwargs
        Configuration entries that override (or extend) the defaults.

    Returns
    -------
    ProcessingConfiguration
        A new dictionary holding the defaults updated with ``kwargs``.
    """
    merged = dict(DEFAULT_PROCESSING_CONFIGURATIONS)
    merged.update(kwargs)
    return merged  # type: ignore
|
|
|
|
|
|
# Default processing configuration: the package defaults overridden by the
# parameters that were returned alongside the default model.
CONFIG: ProcessingConfiguration = get_config(**PARAMS)
|
|
|
|
|
|
def load_audio(
    path: str,
    time_exp_fact: float = 1,
    target_samp_rate: int = TARGET_SAMPLERATE_HZ,
    scale: bool = False,
    max_duration: Optional[float] = None,
) -> np.ndarray:
    """Read an audio file and return its samples.

    Thin wrapper around ``au.load_audio``: all arguments are forwarded
    positionally and only the audio array of the returned pair is kept.
    The audio is resampled to the target sample rate and, when it is longer
    than ``max_duration``, truncated to ``max_duration``.

    Parameters
    ----------
    path : str
        Path to audio file.
    time_exp_fact : float, optional
        Time expansion factor, by default 1.
    target_samp_rate : int, optional
        Target sample rate, by default ``TARGET_SAMPLERATE_HZ``.
    scale : bool, optional
        Scale audio to [-1, 1], by default False.
    max_duration : Optional[float], optional
        Maximum duration of audio in seconds, by default None.

    Returns
    -------
    np.ndarray
        Audio data.
    """
    loaded = au.load_audio(
        path,
        time_exp_fact,
        target_samp_rate,
        scale,
        max_duration,
    )
    # The first element of the pair is not needed by callers of this API
    # (presumably the sample rate -- confirm against au.load_audio).
    _unused, waveform = loaded
    return waveform
|
|
|
|
|
|
def generate_spectrogram(
    audio: np.ndarray,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    config: Optional[SpectrogramParameters] = None,
    device: torch.device = DEVICE,
) -> torch.Tensor:
    """Compute the spectrogram of an audio array.

    Delegates to ``du.compute_spectrogram`` with ``return_np=False`` and
    returns the middle element of the three values it yields.

    Parameters
    ----------
    audio : np.ndarray
        Audio data.
    samp_rate : int, optional
        Sample rate of ``audio``. Defaults to the target sample rate of the
        default model; only change it if the audio was loaded with a
        different sample rate.
    config : Optional[SpectrogramParameters], optional
        Spectrogram parameters. When None, ``DEFAULT_SPECTROGRAM_PARAMETERS``
        is used.
    device : torch.device, optional
        Device forwarded to ``du.compute_spectrogram``, by default the
        module-level ``DEVICE``.

    Returns
    -------
    torch.Tensor
        Spectrogram.
    """
    spec_params = DEFAULT_SPECTROGRAM_PARAMETERS if config is None else config

    # Only the spectrogram itself is exposed; the other two outputs are
    # dropped.
    _, spectrogram, _ = du.compute_spectrogram(
        audio,
        samp_rate,
        spec_params,
        return_np=False,
        device=device,
    )
    return spectrogram
|
|
|
|
|
|
def process_file(
    audio_file: str,
    model: DetectionModel = MODEL,
    config: Optional[ProcessingConfiguration] = None,
    device: torch.device = DEVICE,
) -> du.RunResults:
    """Run the detection model on an audio file.

    Parameters
    ----------
    audio_file : str
        Path to audio file.
    model : DetectionModel, optional
        Detection model. The default model is used if not specified.
    config : Optional[ProcessingConfiguration], optional
        Processing configuration. When None, the module-level ``CONFIG`` is
        used.
    device : torch.device, optional
        Device to use, by default the module-level ``DEVICE`` (GPU if
        available).

    Returns
    -------
    du.RunResults
        Whatever ``du.process_file`` returns for this file.
    """
    run_config = CONFIG if config is None else config
    return du.process_file(audio_file, model, run_config, device)
|
|
|
|
|
|
def process_spectrogram(
    spec: torch.Tensor,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    model: DetectionModel = MODEL,
    config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], List[np.ndarray]]:
    """Run the detection model on a precomputed spectrogram.

    Parameters
    ----------
    spec : torch.Tensor
        Spectrogram.
    samp_rate : int, optional
        Sample rate of the audio from which the spectrogram was generated.
        Defaults to the target sample rate of the default model; only change
        it if the spectrogram was generated with a different sample rate.
    model : DetectionModel, optional
        Detection model. The default model is used if not specified.
    config : Optional[ProcessingConfiguration], optional
        Processing configuration. When None, the module-level ``CONFIG`` is
        used.

    Returns
    -------
    Tuple[List[Annotation], List[np.ndarray]]
        The result of ``du.process_spectrogram``: per the return annotation,
        a list of annotations and a list of arrays.
    """
    run_config = CONFIG if config is None else config
    return du.process_spectrogram(spec, samp_rate, model, run_config)
|
|
|
|
|
|
def process_audio(
    audio: np.ndarray,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    model: DetectionModel = MODEL,
    config: Optional[ProcessingConfiguration] = None,
    device: torch.device = DEVICE,
) -> Tuple[List[Annotation], List[np.ndarray], torch.Tensor]:
    """Run the detection model on an in-memory audio array.

    Parameters
    ----------
    audio : np.ndarray
        Audio data.
    samp_rate : int, optional
        Sample rate of ``audio``, by default the target sample rate. Only
        change it if the audio was loaded with a different sample rate.
    model : DetectionModel, optional
        Detection model. The default model is used if not specified.
    config : Optional[ProcessingConfiguration], optional
        Processing configuration. When None, the module-level ``CONFIG`` is
        used.
    device : torch.device, optional
        Device to use, by default the module-level ``DEVICE`` (GPU if
        available).

    Returns
    -------
    annotations : List[Annotation]
        List of predicted annotations.

    features : List[np.ndarray]
        List of extracted features for each annotation.

    spec : torch.Tensor
        Spectrogram of the audio used for prediction.
    """
    run_config = CONFIG if config is None else config
    return du.process_audio_array(audio, samp_rate, model, run_config, device)
|
|
|
|
|
|
# Public lowercase alias for MODEL, exported through __all__.
model: DetectionModel = MODEL
"""Base detection model."""

# Public lowercase alias for CONFIG, exported through __all__.
config: ProcessingConfiguration = CONFIG
"""Default processing configuration."""
|