mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-30 15:12:06 +02:00
Compare commits
No commits in common. "main" and "v1.2.0" have entirely different histories.
@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 1.3.0
|
current_version = 1.2.0
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
|
|
||||||
|
21
README.md
21
README.md
@ -96,27 +96,6 @@ detections, features = api.process_spectrogram(spec)
|
|||||||
|
|
||||||
You can integrate the detections or the extracted features to your custom analysis pipeline.
|
You can integrate the detections or the extracted features to your custom analysis pipeline.
|
||||||
|
|
||||||
#### Using the Python API with HTTP
|
|
||||||
|
|
||||||
```python
|
|
||||||
from batdetect2 import api
|
|
||||||
import io
|
|
||||||
import requests
|
|
||||||
|
|
||||||
AUDIO_URL = "<insert your audio url here>"
|
|
||||||
|
|
||||||
# Process a whole file from a url
|
|
||||||
results = api.process_url(AUDIO_URL)
|
|
||||||
|
|
||||||
# Or, load audio and compute spectrograms
|
|
||||||
# 'requests.get(AUDIO_URL).content' fetches the raw bytes. You are free to use other sources to fetch the raw bytes
|
|
||||||
audio = api.load_audio(io.BytesIO(requests.get(AUDIO_URL).content))
|
|
||||||
spec = api.generate_spectrogram(audio)
|
|
||||||
|
|
||||||
# And process the audio or the spectrogram with the model
|
|
||||||
detections, features, spec = api.process_audio(audio)
|
|
||||||
detections, features = api.process_spectrogram(spec)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Training the model on your own data
|
## Training the model on your own data
|
||||||
Take a look at the steps outlined in finetuning readme [here](batdetect2/finetune/readme.md) for a description of how to train your own model.
|
Take a look at the steps outlined in finetuning readme [here](batdetect2/finetune/readme.md) for a description of how to train your own model.
|
||||||
|
@ -3,4 +3,4 @@ import logging
|
|||||||
numba_logger = logging.getLogger("numba")
|
numba_logger = logging.getLogger("numba")
|
||||||
numba_logger.setLevel(logging.WARNING)
|
numba_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
__version__ = "1.3.0"
|
__version__ = "1.2.0"
|
||||||
|
@ -97,9 +97,8 @@ consult the API documentation in the code.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Tuple, BinaryIO, Any, Union
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from .types import AudioPath
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -121,12 +120,6 @@ from batdetect2.types import (
|
|||||||
)
|
)
|
||||||
from batdetect2.utils.detector_utils import list_audio_files, load_model
|
from batdetect2.utils.detector_utils import list_audio_files, load_model
|
||||||
|
|
||||||
import audioread
|
|
||||||
import os
|
|
||||||
import soundfile as sf
|
|
||||||
import requests
|
|
||||||
import io
|
|
||||||
|
|
||||||
# Remove warnings from torch
|
# Remove warnings from torch
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
|
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
|
||||||
|
|
||||||
@ -245,82 +238,34 @@ def generate_spectrogram(
|
|||||||
|
|
||||||
|
|
||||||
def process_file(
|
def process_file(
|
||||||
path: AudioPath,
|
audio_file: str,
|
||||||
model: DetectionModel = MODEL,
|
model: DetectionModel = MODEL,
|
||||||
config: Optional[ProcessingConfiguration] = None,
|
config: Optional[ProcessingConfiguration] = None,
|
||||||
device: torch.device = DEVICE,
|
device: torch.device = DEVICE,
|
||||||
file_id: Optional[str] = None
|
|
||||||
) -> du.RunResults:
|
) -> du.RunResults:
|
||||||
"""Process audio file with model.
|
"""Process audio file with model.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
path : AudioPath
|
audio_file : str
|
||||||
Path to audio data.
|
Path to audio file.
|
||||||
model : DetectionModel, optional
|
model : DetectionModel, optional
|
||||||
Detection model. Uses default model if not specified.
|
Detection model. Uses default model if not specified.
|
||||||
config : Optional[ProcessingConfiguration], optional
|
config : Optional[ProcessingConfiguration], optional
|
||||||
Processing configuration, by default None (uses default parameters).
|
Processing configuration, by default None (uses default parameters).
|
||||||
device : torch.device, optional
|
device : torch.device, optional
|
||||||
Device to use, by default tries to use GPU if available.
|
Device to use, by default tries to use GPU if available.
|
||||||
file_id: Optional[str],
|
|
||||||
Give the data an id. If path is a string path to a file this can be ignored and
|
|
||||||
the file_id will be the basename of the file.
|
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
config = CONFIG
|
config = CONFIG
|
||||||
|
|
||||||
return du.process_file(
|
return du.process_file(
|
||||||
path,
|
audio_file,
|
||||||
model,
|
model,
|
||||||
config,
|
config,
|
||||||
device,
|
device,
|
||||||
file_id
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_url(
|
|
||||||
url: str,
|
|
||||||
model: DetectionModel = MODEL,
|
|
||||||
config: Optional[ProcessingConfiguration] = None,
|
|
||||||
device: torch.device = DEVICE,
|
|
||||||
file_id: Optional[str] = None
|
|
||||||
) -> du.RunResults:
|
|
||||||
"""Process audio file with model.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
url : str
|
|
||||||
HTTP URL to load the audio data from
|
|
||||||
model : DetectionModel, optional
|
|
||||||
Detection model. Uses default model if not specified.
|
|
||||||
config : Optional[ProcessingConfiguration], optional
|
|
||||||
Processing configuration, by default None (uses default parameters).
|
|
||||||
device : torch.device, optional
|
|
||||||
Device to use, by default tries to use GPU if available.
|
|
||||||
file_id: Optional[str],
|
|
||||||
Give the data an id. Defaults to the URL
|
|
||||||
"""
|
|
||||||
if config is None:
|
|
||||||
config = CONFIG
|
|
||||||
|
|
||||||
if file_id is None:
|
|
||||||
file_id = url
|
|
||||||
|
|
||||||
response = requests.get(url)
|
|
||||||
|
|
||||||
# Raise exception on HTTP error
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
# Retrieve body as raw bytes
|
|
||||||
raw_audio_data = response.content
|
|
||||||
|
|
||||||
return du.process_file(
|
|
||||||
io.BytesIO(raw_audio_data),
|
|
||||||
model,
|
|
||||||
config,
|
|
||||||
device,
|
|
||||||
file_id
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_spectrogram(
|
def process_spectrogram(
|
||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
|
@ -1,10 +1,6 @@
|
|||||||
"""Types used in the code base."""
|
"""Types used in the code base."""
|
||||||
|
|
||||||
from typing import List, NamedTuple, Optional, Union, Any, BinaryIO
|
from typing import List, NamedTuple, Optional, Union
|
||||||
|
|
||||||
import audioread
|
|
||||||
import os
|
|
||||||
import soundfile as sf
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -44,9 +40,6 @@ __all__ = [
|
|||||||
"SpectrogramParameters",
|
"SpectrogramParameters",
|
||||||
]
|
]
|
||||||
|
|
||||||
AudioPath = Union[
|
|
||||||
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
|
|
||||||
]
|
|
||||||
|
|
||||||
class SpectrogramParameters(TypedDict):
|
class SpectrogramParameters(TypedDict):
|
||||||
"""Parameters for generating spectrograms."""
|
"""Parameters for generating spectrograms."""
|
||||||
|
@ -1,24 +1,17 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Tuple, Union, Any, BinaryIO
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from ..types import AudioPath
|
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import librosa.core.spectrum
|
import librosa.core.spectrum
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import audioread
|
|
||||||
import os
|
|
||||||
import soundfile as sf
|
|
||||||
|
|
||||||
from batdetect2.detector import parameters
|
from batdetect2.detector import parameters
|
||||||
|
|
||||||
from . import wavfile
|
from . import wavfile
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"load_audio",
|
"load_audio",
|
||||||
"load_audio_and_samplerate",
|
|
||||||
"generate_spectrogram",
|
"generate_spectrogram",
|
||||||
"pad_audio",
|
"pad_audio",
|
||||||
]
|
]
|
||||||
@ -147,20 +140,21 @@ def generate_spectrogram(
|
|||||||
|
|
||||||
return spec, spec_for_viz
|
return spec, spec_for_viz
|
||||||
|
|
||||||
|
|
||||||
def load_audio(
|
def load_audio(
|
||||||
path: AudioPath,
|
audio_file: str,
|
||||||
time_exp_fact: float,
|
time_exp_fact: float,
|
||||||
target_samp_rate: int,
|
target_samp_rate: int,
|
||||||
scale: bool = False,
|
scale: bool = False,
|
||||||
max_duration: Optional[float] = None,
|
max_duration: Optional[float] = None,
|
||||||
) -> Tuple[int, np.ndarray ]:
|
) -> Tuple[int, np.ndarray]:
|
||||||
"""Load an audio file and resample it to the target sampling rate.
|
"""Load an audio file and resample it to the target sampling rate.
|
||||||
|
|
||||||
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
||||||
Only mono files are supported.
|
Only mono files are supported.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): path to the input file.
|
audio_file (str): Path to the audio file.
|
||||||
target_samp_rate (int): Target sampling rate.
|
target_samp_rate (int): Target sampling rate.
|
||||||
scale (bool): Whether to scale the audio to [-1, 1].
|
scale (bool): Whether to scale the audio to [-1, 1].
|
||||||
max_duration (float): Maximum duration of the audio in seconds.
|
max_duration (float): Maximum duration of the audio in seconds.
|
||||||
@ -172,42 +166,12 @@ def load_audio(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the audio file is stereo.
|
ValueError: If the audio file is stereo.
|
||||||
|
|
||||||
"""
|
|
||||||
sample_rate, audio_data, _ = load_audio_and_samplerate(path, time_exp_fact, target_samp_rate, scale, max_duration)
|
|
||||||
return sample_rate, audio_data
|
|
||||||
|
|
||||||
def load_audio_and_samplerate(
|
|
||||||
path: AudioPath,
|
|
||||||
time_exp_fact: float,
|
|
||||||
target_samp_rate: int,
|
|
||||||
scale: bool = False,
|
|
||||||
max_duration: Optional[float] = None,
|
|
||||||
) -> Tuple[int, np.ndarray, Union[float, int]]:
|
|
||||||
"""Load an audio file and resample it to the target sampling rate.
|
|
||||||
|
|
||||||
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
|
||||||
Only mono files are supported.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): path to the input file.
|
|
||||||
target_samp_rate (int): Target sampling rate.
|
|
||||||
scale (bool): Whether to scale the audio to [-1, 1].
|
|
||||||
max_duration (float): Maximum duration of the audio in seconds.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
sampling_rate: The sampling rate of the audio.
|
|
||||||
audio_raw: The audio signal in a numpy array.
|
|
||||||
file_sampling_rate: The original sampling rate of the audio
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the audio file is stereo.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
|
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
|
||||||
# sampling_rate, audio_raw = wavfile.read(audio_file)
|
# sampling_rate, audio_raw = wavfile.read(audio_file)
|
||||||
audio_raw, file_sampling_rate = librosa.load(
|
audio_raw, sampling_rate = librosa.load(
|
||||||
path,
|
audio_file,
|
||||||
sr=None,
|
sr=None,
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
)
|
)
|
||||||
@ -215,7 +179,7 @@ def load_audio_and_samplerate(
|
|||||||
if len(audio_raw.shape) > 1:
|
if len(audio_raw.shape) > 1:
|
||||||
raise ValueError("Currently does not handle stereo files")
|
raise ValueError("Currently does not handle stereo files")
|
||||||
|
|
||||||
sampling_rate = file_sampling_rate * time_exp_fact
|
sampling_rate = sampling_rate * time_exp_fact
|
||||||
|
|
||||||
# resample - need to do this after correcting for time expansion
|
# resample - need to do this after correcting for time expansion
|
||||||
sampling_rate_old = sampling_rate
|
sampling_rate_old = sampling_rate
|
||||||
@ -243,7 +207,7 @@ def load_audio_and_samplerate(
|
|||||||
audio_raw = audio_raw - audio_raw.mean()
|
audio_raw = audio_raw - audio_raw.mean()
|
||||||
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
|
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
|
||||||
|
|
||||||
return sampling_rate, audio_raw, file_sampling_rate
|
return sampling_rate, audio_raw
|
||||||
|
|
||||||
|
|
||||||
def compute_spectrogram_width(
|
def compute_spectrogram_width(
|
||||||
|
@ -1,9 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, Iterator, List, Optional, Tuple, Union, BinaryIO
|
from typing import Any, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from ..types import AudioPath
|
|
||||||
|
|
||||||
|
import librosa
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@ -32,13 +31,6 @@ from batdetect2.types import (
|
|||||||
SpectrogramParameters,
|
SpectrogramParameters,
|
||||||
)
|
)
|
||||||
|
|
||||||
import audioread
|
|
||||||
import os
|
|
||||||
import io
|
|
||||||
import soundfile as sf
|
|
||||||
import hashlib
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"load_model",
|
"load_model",
|
||||||
"list_audio_files",
|
"list_audio_files",
|
||||||
@ -737,11 +729,10 @@ def process_audio_array(
|
|||||||
|
|
||||||
|
|
||||||
def process_file(
|
def process_file(
|
||||||
path: AudioPath,
|
audio_file: str,
|
||||||
model: DetectionModel,
|
model: DetectionModel,
|
||||||
config: ProcessingConfiguration,
|
config: ProcessingConfiguration,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
file_id: Optional[str] = None
|
|
||||||
) -> Union[RunResults, Any]:
|
) -> Union[RunResults, Any]:
|
||||||
"""Process a single audio file with detection model.
|
"""Process a single audio file with detection model.
|
||||||
|
|
||||||
@ -750,7 +741,7 @@ def process_file(
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
path : AudioPath
|
audio_file : str
|
||||||
Path to audio file.
|
Path to audio file.
|
||||||
|
|
||||||
model : torch.nn.Module
|
model : torch.nn.Module
|
||||||
@ -759,9 +750,6 @@ def process_file(
|
|||||||
config : ProcessingConfiguration
|
config : ProcessingConfiguration
|
||||||
Configuration for processing.
|
Configuration for processing.
|
||||||
|
|
||||||
file_id: Optional[str],
|
|
||||||
Give the data an id. Defaults to the filename if path is a string. Otherwise an md5 will be calculated from the binary data.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
results : Results or Any
|
results : Results or Any
|
||||||
@ -774,17 +762,19 @@ def process_file(
|
|||||||
cnn_feats = []
|
cnn_feats = []
|
||||||
spec_slices = []
|
spec_slices = []
|
||||||
|
|
||||||
|
# Get original sampling rate
|
||||||
|
file_samp_rate = librosa.get_samplerate(audio_file)
|
||||||
|
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
|
||||||
|
|
||||||
# load audio file
|
# load audio file
|
||||||
sampling_rate, audio_full, file_samp_rate = au.load_audio_and_samplerate(
|
sampling_rate, audio_full = au.load_audio(
|
||||||
path,
|
audio_file,
|
||||||
time_exp_fact=config.get("time_expansion", 1) or 1,
|
time_exp_fact=config.get("time_expansion", 1) or 1,
|
||||||
target_samp_rate=config["target_samp_rate"],
|
target_samp_rate=config["target_samp_rate"],
|
||||||
scale=config["scale_raw_audio"],
|
scale=config["scale_raw_audio"],
|
||||||
max_duration=config.get("max_duration"),
|
max_duration=config.get("max_duration"),
|
||||||
)
|
)
|
||||||
|
|
||||||
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
|
|
||||||
|
|
||||||
# loop through larger file and split into chunks
|
# loop through larger file and split into chunks
|
||||||
# TODO: fix so that it overlaps correctly and takes care of
|
# TODO: fix so that it overlaps correctly and takes care of
|
||||||
# duplicate detections at borders
|
# duplicate detections at borders
|
||||||
@ -833,13 +823,9 @@ def process_file(
|
|||||||
spec_slices,
|
spec_slices,
|
||||||
)
|
)
|
||||||
|
|
||||||
_file_id = file_id
|
|
||||||
if _file_id is None:
|
|
||||||
_file_id = _generate_id(path)
|
|
||||||
|
|
||||||
# convert results to a dictionary in the right format
|
# convert results to a dictionary in the right format
|
||||||
results = convert_results(
|
results = convert_results(
|
||||||
file_id=_file_id,
|
file_id=os.path.basename(audio_file),
|
||||||
time_exp=config.get("time_expansion", 1) or 1,
|
time_exp=config.get("time_expansion", 1) or 1,
|
||||||
duration=audio_full.shape[0] / float(sampling_rate),
|
duration=audio_full.shape[0] / float(sampling_rate),
|
||||||
params=config,
|
params=config,
|
||||||
@ -859,22 +845,6 @@ def process_file(
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
def _generate_id(path: AudioPath) -> str:
|
|
||||||
""" Generate an id based on the path.
|
|
||||||
|
|
||||||
If the path is a str or PathLike it will parsed as the basename.
|
|
||||||
This should ensure backwards compatibility with previous versions.
|
|
||||||
"""
|
|
||||||
if isinstance(path, str) or isinstance(path, os.PathLike):
|
|
||||||
return os.path.basename(path)
|
|
||||||
elif isinstance(path, (BinaryIO, io.BytesIO)):
|
|
||||||
path.seek(0)
|
|
||||||
md5 = hashlib.md5(path.read()).hexdigest()
|
|
||||||
path.seek(0)
|
|
||||||
return md5
|
|
||||||
else:
|
|
||||||
return str(uuid.uuid4())
|
|
||||||
|
|
||||||
|
|
||||||
def summarize_results(results, predictions, config):
|
def summarize_results(results, predictions, config):
|
||||||
"""Print summary of results."""
|
"""Print summary of results."""
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "batdetect2"
|
name = "batdetect2"
|
||||||
version = "1.3.0"
|
version = "1.2.0"
|
||||||
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
|
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
|
||||||
authors = [
|
authors = [
|
||||||
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
|
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
|
||||||
|
@ -10,13 +10,11 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2 import api
|
from batdetect2 import api
|
||||||
import io
|
|
||||||
|
|
||||||
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
|
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
|
||||||
TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
|
TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
|
||||||
|
|
||||||
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
|
||||||
|
|
||||||
def test_load_model_with_default_params():
|
def test_load_model_with_default_params():
|
||||||
"""Test loading model with default parameters."""
|
"""Test loading model with default parameters."""
|
||||||
@ -282,28 +280,3 @@ def test_process_file_with_empty_predictions_does_not_fail(
|
|||||||
|
|
||||||
assert results is not None
|
assert results is not None
|
||||||
assert len(results["pred_dict"]["annotation"]) == 0
|
assert len(results["pred_dict"]["annotation"]) == 0
|
||||||
|
|
||||||
def test_process_file_file_id_defaults_to_basename():
|
|
||||||
"""Test that process_file assigns basename as an id if no file_id is provided."""
|
|
||||||
# Recording donated by @@kdarras
|
|
||||||
basename = "20230322_172000_selec2.wav"
|
|
||||||
path = os.path.join(DATA_DIR, basename)
|
|
||||||
|
|
||||||
output = api.process_file(path)
|
|
||||||
predictions = output["pred_dict"]
|
|
||||||
id = predictions["id"]
|
|
||||||
assert id == basename
|
|
||||||
|
|
||||||
def test_bytesio_file_id_defaults_to_md5():
|
|
||||||
"""Test that process_file assigns an md5 sum as an id if no file_id is provided when using binary data."""
|
|
||||||
# Recording donated by @@kdarras
|
|
||||||
basename = "20230322_172000_selec2.wav"
|
|
||||||
path = os.path.join(DATA_DIR, basename)
|
|
||||||
|
|
||||||
with open(path, "rb") as f:
|
|
||||||
data = io.BytesIO(f.read())
|
|
||||||
|
|
||||||
output = api.process_file(data)
|
|
||||||
predictions = output["pred_dict"]
|
|
||||||
id = predictions["id"]
|
|
||||||
assert id == "7ade9ebf1a9fe5477ff3a2dc57001929"
|
|
||||||
|
@ -6,10 +6,7 @@ from hypothesis import strategies as st
|
|||||||
|
|
||||||
from batdetect2.detector import parameters
|
from batdetect2.detector import parameters
|
||||||
from batdetect2.utils import audio_utils, detector_utils
|
from batdetect2.utils import audio_utils, detector_utils
|
||||||
import io
|
|
||||||
import os
|
|
||||||
|
|
||||||
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
|
||||||
|
|
||||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||||
def test_can_compute_correct_spectrogram_width(duration: float):
|
def test_can_compute_correct_spectrogram_width(duration: float):
|
||||||
@ -137,20 +134,3 @@ def test_pad_audio_with_fixed_width(duration: float, width: int):
|
|||||||
resize_factor=params["resize_factor"],
|
resize_factor=params["resize_factor"],
|
||||||
)
|
)
|
||||||
assert expected_width == width
|
assert expected_width == width
|
||||||
|
|
||||||
|
|
||||||
def test_load_audio_using_bytesio():
|
|
||||||
basename = "20230322_172000_selec2.wav"
|
|
||||||
path = os.path.join(DATA_DIR, basename)
|
|
||||||
|
|
||||||
with open(path, "rb") as f:
|
|
||||||
data = io.BytesIO(f.read())
|
|
||||||
|
|
||||||
sample_rate, audio_data, file_sample_rate = audio_utils.load_audio_and_samplerate(data, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ)
|
|
||||||
|
|
||||||
expected_sample_rate, expected_audio_data, exp_file_sample_rate = audio_utils.load_audio_and_samplerate(path, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ)
|
|
||||||
|
|
||||||
assert expected_sample_rate == sample_rate
|
|
||||||
assert exp_file_sample_rate == file_sample_rate
|
|
||||||
|
|
||||||
assert np.array_equal(audio_data, expected_audio_data)
|
|
Loading…
Reference in New Issue
Block a user