mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Changed the signature of api.process_file, au.load_audio and du.process_file. This allows users to use the same args for processing data as librosa.load()
This commit is contained in:
parent
2100a3e483
commit
66ac7e608f
@ -97,7 +97,7 @@ consult the API documentation in the code.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple, BinaryIO, Any, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -120,6 +120,10 @@ from batdetect2.types import (
|
|||||||
)
|
)
|
||||||
from batdetect2.utils.detector_utils import list_audio_files, load_model
|
from batdetect2.utils.detector_utils import list_audio_files, load_model
|
||||||
|
|
||||||
|
import audioread
|
||||||
|
import os
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
# Remove warnings from torch
|
# Remove warnings from torch
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
|
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
|
||||||
|
|
||||||
@ -238,32 +242,41 @@ def generate_spectrogram(
|
|||||||
|
|
||||||
|
|
||||||
def process_file(
|
def process_file(
|
||||||
audio_file: str,
|
path: Union[
|
||||||
|
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
|
||||||
|
],
|
||||||
model: DetectionModel = MODEL,
|
model: DetectionModel = MODEL,
|
||||||
config: Optional[ProcessingConfiguration] = None,
|
config: Optional[ProcessingConfiguration] = None,
|
||||||
device: torch.device = DEVICE,
|
device: torch.device = DEVICE,
|
||||||
|
file_id: str | None = None
|
||||||
) -> du.RunResults:
|
) -> du.RunResults:
|
||||||
"""Process audio file with model.
|
"""Process audio file with model.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
audio_file : str
|
path : Union[
|
||||||
Path to audio file.
|
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
|
||||||
|
]
|
||||||
|
Path to audio data.
|
||||||
model : DetectionModel, optional
|
model : DetectionModel, optional
|
||||||
Detection model. Uses default model if not specified.
|
Detection model. Uses default model if not specified.
|
||||||
config : Optional[ProcessingConfiguration], optional
|
config : Optional[ProcessingConfiguration], optional
|
||||||
Processing configuration, by default None (uses default parameters).
|
Processing configuration, by default None (uses default parameters).
|
||||||
device : torch.device, optional
|
device : torch.device, optional
|
||||||
Device to use, by default tries to use GPU if available.
|
Device to use, by default tries to use GPU if available.
|
||||||
|
file_id: Optional[str],
|
||||||
|
Give the data an id. If path is a string path to a file this can be ignored and
|
||||||
|
the file_id will be the basename of the file.
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
config = CONFIG
|
config = CONFIG
|
||||||
|
|
||||||
return du.process_file(
|
return du.process_file(
|
||||||
audio_file,
|
path,
|
||||||
model,
|
model,
|
||||||
config,
|
config,
|
||||||
device,
|
device,
|
||||||
|
file_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,11 +1,15 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple, Union, Any, BinaryIO
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import librosa.core.spectrum
|
import librosa.core.spectrum
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import audioread
|
||||||
|
import os
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
from batdetect2.detector import parameters
|
from batdetect2.detector import parameters
|
||||||
|
|
||||||
from . import wavfile
|
from . import wavfile
|
||||||
@ -140,21 +144,29 @@ def generate_spectrogram(
|
|||||||
|
|
||||||
return spec, spec_for_viz
|
return spec, spec_for_viz
|
||||||
|
|
||||||
|
def get_samplerate(
|
||||||
|
path: Union[
|
||||||
|
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
|
||||||
|
]):
|
||||||
|
with sf.SoundFile(path) as f:
|
||||||
|
return f.samplerate
|
||||||
|
|
||||||
def load_audio(
|
def load_audio(
|
||||||
audio_file: str,
|
path: Union[
|
||||||
|
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
|
||||||
|
],
|
||||||
time_exp_fact: float,
|
time_exp_fact: float,
|
||||||
target_samp_rate: int,
|
target_samp_rate: int,
|
||||||
scale: bool = False,
|
scale: bool = False,
|
||||||
max_duration: Optional[float] = None,
|
max_duration: Optional[float] = None,
|
||||||
) -> Tuple[int, np.ndarray]:
|
) -> Tuple[int, np.ndarray ]:
|
||||||
"""Load an audio file and resample it to the target sampling rate.
|
"""Load an audio file and resample it to the target sampling rate.
|
||||||
|
|
||||||
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
||||||
Only mono files are supported.
|
Only mono files are supported.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio_file (str): Path to the audio file.
|
path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): path to the input file.
|
||||||
target_samp_rate (int): Target sampling rate.
|
target_samp_rate (int): Target sampling rate.
|
||||||
scale (bool): Whether to scale the audio to [-1, 1].
|
scale (bool): Whether to scale the audio to [-1, 1].
|
||||||
max_duration (float): Maximum duration of the audio in seconds.
|
max_duration (float): Maximum duration of the audio in seconds.
|
||||||
@ -170,16 +182,16 @@ def load_audio(
|
|||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
|
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
|
||||||
# sampling_rate, audio_raw = wavfile.read(audio_file)
|
# sampling_rate, audio_raw = wavfile.read(audio_file)
|
||||||
audio_raw, sampling_rate = librosa.load(
|
audio_raw, file_sampling_rate = librosa.load(
|
||||||
audio_file,
|
path,
|
||||||
sr=None,
|
sr=None,
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(audio_raw.shape) > 1:
|
if len(audio_raw.shape) > 1:
|
||||||
raise ValueError("Currently does not handle stereo files")
|
raise ValueError("Currently does not handle stereo files")
|
||||||
|
|
||||||
sampling_rate = sampling_rate * time_exp_fact
|
sampling_rate = file_sampling_rate * time_exp_fact
|
||||||
|
|
||||||
# resample - need to do this after correcting for time expansion
|
# resample - need to do this after correcting for time expansion
|
||||||
sampling_rate_old = sampling_rate
|
sampling_rate_old = sampling_rate
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, Iterator, List, Optional, Tuple, Union
|
from typing import Any, Iterator, List, Optional, Tuple, Union, BinaryIO
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -31,6 +31,11 @@ from batdetect2.types import (
|
|||||||
SpectrogramParameters,
|
SpectrogramParameters,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import audioread
|
||||||
|
import os
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"load_model",
|
"load_model",
|
||||||
"list_audio_files",
|
"list_audio_files",
|
||||||
@ -729,10 +734,13 @@ def process_audio_array(
|
|||||||
|
|
||||||
|
|
||||||
def process_file(
|
def process_file(
|
||||||
audio_file: str,
|
path: Union[
|
||||||
|
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
|
||||||
|
],
|
||||||
model: DetectionModel,
|
model: DetectionModel,
|
||||||
config: ProcessingConfiguration,
|
config: ProcessingConfiguration,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
file_id: str | None = None
|
||||||
) -> Union[RunResults, Any]:
|
) -> Union[RunResults, Any]:
|
||||||
"""Process a single audio file with detection model.
|
"""Process a single audio file with detection model.
|
||||||
|
|
||||||
@ -741,7 +749,7 @@ def process_file(
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
audio_file : str
|
path : str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
|
||||||
Path to audio file.
|
Path to audio file.
|
||||||
|
|
||||||
model : torch.nn.Module
|
model : torch.nn.Module
|
||||||
@ -762,18 +770,17 @@ def process_file(
|
|||||||
cnn_feats = []
|
cnn_feats = []
|
||||||
spec_slices = []
|
spec_slices = []
|
||||||
|
|
||||||
# Get original sampling rate
|
|
||||||
file_samp_rate = librosa.get_samplerate(audio_file)
|
|
||||||
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
|
|
||||||
|
|
||||||
# load audio file
|
# load audio file
|
||||||
sampling_rate, audio_full = au.load_audio(
|
sampling_rate, audio_full = au.load_audio(
|
||||||
audio_file,
|
path,
|
||||||
time_exp_fact=config.get("time_expansion", 1) or 1,
|
time_exp_fact=config.get("time_expansion", 1) or 1,
|
||||||
target_samp_rate=config["target_samp_rate"],
|
target_samp_rate=config["target_samp_rate"],
|
||||||
scale=config["scale_raw_audio"],
|
scale=config["scale_raw_audio"],
|
||||||
max_duration=config.get("max_duration"),
|
max_duration=config.get("max_duration"),
|
||||||
)
|
)
|
||||||
|
file_samp_rate = au.get_samplerate(path)
|
||||||
|
|
||||||
|
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
|
||||||
|
|
||||||
# loop through larger file and split into chunks
|
# loop through larger file and split into chunks
|
||||||
# TODO: fix so that it overlaps correctly and takes care of
|
# TODO: fix so that it overlaps correctly and takes care of
|
||||||
@ -823,9 +830,13 @@ def process_file(
|
|||||||
spec_slices,
|
spec_slices,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_file_id = file_id
|
||||||
|
if _file_id is None:
|
||||||
|
_file_id = os.path.basename(path) if isinstance(path, str) else "unknown"
|
||||||
|
|
||||||
# convert results to a dictionary in the right format
|
# convert results to a dictionary in the right format
|
||||||
results = convert_results(
|
results = convert_results(
|
||||||
file_id=os.path.basename(audio_file),
|
file_id=_file_id,
|
||||||
time_exp=config.get("time_expansion", 1) or 1,
|
time_exp=config.get("time_expansion", 1) or 1,
|
||||||
duration=audio_full.shape[0] / float(sampling_rate),
|
duration=audio_full.shape[0] / float(sampling_rate),
|
||||||
params=config,
|
params=config,
|
||||||
|
@ -6,7 +6,8 @@ from hypothesis import strategies as st
|
|||||||
|
|
||||||
from batdetect2.detector import parameters
|
from batdetect2.detector import parameters
|
||||||
from batdetect2.utils import audio_utils, detector_utils
|
from batdetect2.utils import audio_utils, detector_utils
|
||||||
|
import io
|
||||||
|
import requests
|
||||||
|
|
||||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||||
def test_can_compute_correct_spectrogram_width(duration: float):
|
def test_can_compute_correct_spectrogram_width(duration: float):
|
||||||
@ -134,3 +135,11 @@ def test_pad_audio_with_fixed_width(duration: float, width: int):
|
|||||||
resize_factor=params["resize_factor"],
|
resize_factor=params["resize_factor"],
|
||||||
)
|
)
|
||||||
assert expected_width == width
|
assert expected_width == width
|
||||||
|
|
||||||
|
def test_get_samplerate_using_bytesio():
|
||||||
|
audio_url="https://anon.erda.au.dk/share_redirect/e5c7G2AWmg/F1/20240724/2MU02597/BIOBD01_20240626_231650.wav"
|
||||||
|
|
||||||
|
sample_rate = audio_utils.get_samplerate(io.BytesIO(requests.get(audio_url).content))
|
||||||
|
|
||||||
|
expected_sample_rate = 256000
|
||||||
|
assert expected_sample_rate == sample_rate
|
||||||
|
Loading…
Reference in New Issue
Block a user