Changed the signatures of api.process_file, au.load_audio and du.process_file so that users can pass the same arguments for processing data as librosa.load() accepts (see the usage sketch below).

Kavi 2025-02-25 14:24:48 +01:00
parent 2100a3e483
commit 66ac7e608f
4 changed files with 68 additions and 23 deletions
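
The change summarized above is easiest to see from the caller's side. A minimal usage sketch, assuming the batdetect2.api module changed below and a hypothetical recording.wav on disk:

import io
from batdetect2 import api

# A plain string path still works exactly as before.
results = api.process_file("recording.wav")

# Any librosa.load()-compatible input now works too, e.g. an in-memory buffer.
# A buffer has no basename, so pass file_id to label the results.
with open("recording.wav", "rb") as f:
    buffer = io.BytesIO(f.read())
results = api.process_file(buffer, file_id="recording.wav")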


@@ -97,7 +97,7 @@ consult the API documentation in the code.
"""
import warnings
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, BinaryIO, Any, Union
import numpy as np
import torch
@@ -120,6 +120,10 @@ from batdetect2.types import (
)
from batdetect2.utils.detector_utils import list_audio_files, load_model
import audioread
import os
import soundfile as sf
# Remove warnings from torch
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
@@ -238,32 +242,41 @@ def generate_spectrogram(
def process_file(
audio_file: str,
path: Union[
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
],
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE,
file_id: str | None = None
) -> du.RunResults:
"""Process audio file with model.
Parameters
----------
audio_file : str
Path to audio file.
path : Union[
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
]
Path to audio data.
model : DetectionModel, optional
Detection model. Uses default model if not specified.
config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters).
device : torch.device, optional
Device to use, by default tries to use GPU if available.
file_id : Optional[str], optional
Identifier for the audio data. If path is a string path to a file, this can be
omitted and the file_id defaults to the basename of the file; for other input
types it falls back to "unknown" unless set explicitly.
"""
if config is None:
config = CONFIG
return du.process_file(
audio_file,
path,
model,
config,
device,
file_id
)
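
To illustrate the file_id parameter documented above: only a plain str path gets an automatic basename id (see the fallback in detector_utils further down), so non-string inputs benefit from an explicit id. A short hedged sketch, again with a hypothetical recording.wav:

from pathlib import Path
from batdetect2 import api

# os.PathLike inputs are accepted, but a Path object is not a str, so without
# an explicit file_id the results would be labelled "unknown".
results = api.process_file(Path("recording.wav"), file_id="recording.wav")

# The same applies to an already-open binary file object.
with open("recording.wav", "rb") as handle:
    results = api.process_file(handle, file_id="recording.wav")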


@@ -1,11 +1,15 @@
import warnings
from typing import Optional, Tuple
from typing import Optional, Tuple, Union, Any, BinaryIO
import librosa
import librosa.core.spectrum
import numpy as np
import torch
import audioread
import os
import soundfile as sf
from batdetect2.detector import parameters
from . import wavfile
@@ -140,21 +144,29 @@ def generate_spectrogram(
return spec, spec_for_viz
def get_samplerate(
path: Union[
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
]):
with sf.SoundFile(path) as f:
return f.samplerate
def load_audio(
audio_file: str,
path: Union[
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
],
time_exp_fact: float,
target_samp_rate: int,
scale: bool = False,
max_duration: Optional[float] = None,
) -> Tuple[int, np.ndarray]:
) -> Tuple[int, np.ndarray ]:
"""Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
Only mono files are supported.
Args:
audio_file (str): Path to the audio file.
path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): Path to the input file.
target_samp_rate (int): Target sampling rate.
scale (bool): Whether to scale the audio to [-1, 1].
max_duration (float): Maximum duration of the audio in seconds.
@@ -170,16 +182,16 @@ def load_audio(
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
# sampling_rate, audio_raw = wavfile.read(audio_file)
audio_raw, sampling_rate = librosa.load(
audio_file,
audio_raw, file_sampling_rate = librosa.load(
path,
sr=None,
dtype=np.float32,
)
if len(audio_raw.shape) > 1:
raise ValueError("Currently does not handle stereo files")
sampling_rate = sampling_rate * time_exp_fact
sampling_rate = file_sampling_rate * time_exp_fact
# resample - need to do this after correcting for time expansion
sampling_rate_old = sampling_rate
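
A minimal sketch of the two helpers changed above, with a hypothetical recording.wav read into memory; the 256 kHz target rate is an assumption matching the expected sample rate in the new test at the end of this commit:

import io
from batdetect2.utils import audio_utils

with open("recording.wav", "rb") as f:
    buffer = io.BytesIO(f.read())

# get_samplerate works on anything soundfile can open, including file-like objects.
file_samp_rate = audio_utils.get_samplerate(buffer)

# load_audio takes the same librosa.load()-style input; the returned rate
# already includes the time-expansion factor.
buffer.seek(0)
sampling_rate, audio = audio_utils.load_audio(
    buffer,
    time_exp_fact=1,
    target_samp_rate=256000,
)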


@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Iterator, List, Optional, Tuple, Union
from typing import Any, Iterator, List, Optional, Tuple, Union, BinaryIO
import librosa
import numpy as np
@@ -31,6 +31,11 @@ from batdetect2.types import (
SpectrogramParameters,
)
import audioread
import os
import soundfile as sf
__all__ = [
"load_model",
"list_audio_files",
@@ -729,10 +734,13 @@ def process_audio_array(
def process_file(
audio_file: str,
path: Union[
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
],
model: DetectionModel,
config: ProcessingConfiguration,
device: torch.device,
file_id: str | None = None
) -> Union[RunResults, Any]:
"""Process a single audio file with detection model.
@@ -741,7 +749,7 @@ def process_file(
Parameters
----------
audio_file : str
path : str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
Path to audio file.
model : torch.nn.Module
@@ -762,18 +770,17 @@ def process_file(
cnn_feats = []
spec_slices = []
# Get original sampling rate
file_samp_rate = librosa.get_samplerate(audio_file)
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
# load audio file
sampling_rate, audio_full = au.load_audio(
audio_file,
path,
time_exp_fact=config.get("time_expansion", 1) or 1,
target_samp_rate=config["target_samp_rate"],
scale=config["scale_raw_audio"],
max_duration=config.get("max_duration"),
)
file_samp_rate = au.get_samplerate(path)
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
# loop through larger file and split into chunks
# TODO: fix so that it overlaps correctly and takes care of
@@ -823,9 +830,13 @@ def process_file(
spec_slices,
)
_file_id = file_id
if _file_id is None:
_file_id = os.path.basename(path) if isinstance(path, str) else "unknown"
# convert results to a dictionary in the right format
results = convert_results(
file_id=os.path.basename(audio_file),
file_id=_file_id,
time_exp=config.get("time_expansion", 1) or 1,
duration=audio_full.shape[0] / float(sampling_rate),
params=config,
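
The orig_samp_rate bookkeeping above now derives the file rate from whatever object was passed in and then multiplies by the time-expansion factor. A tiny worked example, assuming a hypothetical 10x time-expanded recording:

# A call recorded at 256 kHz but saved as a 10x time-expanded file is stored
# with a 25600 Hz sample rate; the config's time_expansion undoes that.
file_samp_rate = 25600
time_expansion = 10
orig_samp_rate = file_samp_rate * (time_expansion or 1)
assert orig_samp_rate == 256000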


@@ -6,7 +6,8 @@ from hypothesis import strategies as st
from batdetect2.detector import parameters
from batdetect2.utils import audio_utils, detector_utils
import io
import requests
@given(duration=st.floats(min_value=0.1, max_value=2))
def test_can_compute_correct_spectrogram_width(duration: float):
@@ -134,3 +135,11 @@ def test_pad_audio_with_fixed_width(duration: float, width: int):
resize_factor=params["resize_factor"],
)
assert expected_width == width
def test_get_samplerate_using_bytesio():
audio_url="https://anon.erda.au.dk/share_redirect/e5c7G2AWmg/F1/20240724/2MU02597/BIOBD01_20240626_231650.wav"
sample_rate = audio_utils.get_samplerate(io.BytesIO(requests.get(audio_url).content))
expected_sample_rate = 256000
assert expected_sample_rate == sample_rate
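
The new test fetches a recording over the network. A possible offline companion, sketched under the assumption that the test module can also import numpy and soundfile; the test name and the in-memory WAV are illustrative only:

import numpy as np
import soundfile as sf

def test_get_samplerate_from_in_memory_wav():
    expected_sample_rate = 256000
    # One second of silence written to an in-memory WAV stands in for the download.
    data = np.zeros(expected_sample_rate, dtype=np.float32)
    buffer = io.BytesIO()
    sf.write(buffer, data, expected_sample_rate, format="WAV")
    buffer.seek(0)
    assert audio_utils.get_samplerate(buffer) == expected_sample_rate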