mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-30 07:02:01 +02:00
Compare commits
19 Commits
Author | SHA1 | Date | |
---|---|---|---|
![]() |
4cd71497e7 | ||
![]() |
ba670932d5 | ||
![]() |
d3747c57f2 | ||
![]() |
42a838e9f2 | ||
![]() |
c10903a646 | ||
![]() |
4282e2ae70 | ||
![]() |
52570738f2 | ||
![]() |
cbd362d6ea | ||
![]() |
4b75e13fa2 | ||
![]() |
98bf506634 | ||
![]() |
b4c59f7de1 | ||
![]() |
54ca555587 | ||
![]() |
230b6167bc | ||
![]() |
f62bc99ab2 | ||
![]() |
47dbdc79c2 | ||
![]() |
e10e270de4 | ||
![]() |
6af7fef316 | ||
![]() |
838a1ade0d | ||
![]() |
66ac7e608f |
@ -1,5 +1,5 @@
|
|||||||
[bumpversion]
|
[bumpversion]
|
||||||
current_version = 1.1.1
|
current_version = 1.3.0
|
||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
|
|
||||||
|
21
README.md
21
README.md
@ -96,6 +96,27 @@ detections, features = api.process_spectrogram(spec)
|
|||||||
|
|
||||||
You can integrate the detections or the extracted features to your custom analysis pipeline.
|
You can integrate the detections or the extracted features to your custom analysis pipeline.
|
||||||
|
|
||||||
|
#### Using the Python API with HTTP
|
||||||
|
|
||||||
|
```python
|
||||||
|
from batdetect2 import api
|
||||||
|
import io
|
||||||
|
import requests
|
||||||
|
|
||||||
|
AUDIO_URL = "<insert your audio url here>"
|
||||||
|
|
||||||
|
# Process a whole file from a url
|
||||||
|
results = api.process_url(AUDIO_URL)
|
||||||
|
|
||||||
|
# Or, load audio and compute spectrograms
|
||||||
|
# 'requests.get(AUDIO_URL).content' fetches the raw bytes; you are free to obtain the raw bytes from any other source.
|
||||||
|
audio = api.load_audio(io.BytesIO(requests.get(AUDIO_URL).content))
|
||||||
|
spec = api.generate_spectrogram(audio)
|
||||||
|
|
||||||
|
# And process the audio or the spectrogram with the model
|
||||||
|
detections, features, spec = api.process_audio(audio)
|
||||||
|
detections, features = api.process_spectrogram(spec)
|
||||||
|
```
|
||||||
|
|
||||||
## Training the model on your own data
|
## Training the model on your own data
|
||||||
Take a look at the steps outlined in finetuning readme [here](batdetect2/finetune/readme.md) for a description of how to train your own model.
|
Take a look at the steps outlined in finetuning readme [here](batdetect2/finetune/readme.md) for a description of how to train your own model.
|
||||||
|
@ -3,4 +3,4 @@ import logging
|
|||||||
numba_logger = logging.getLogger("numba")
|
numba_logger = logging.getLogger("numba")
|
||||||
numba_logger.setLevel(logging.WARNING)
|
numba_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
__version__ = "1.1.1"
|
__version__ = "1.3.0"
|
||||||
|
@ -97,8 +97,9 @@ consult the API documentation in the code.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple, BinaryIO, Any, Union
|
||||||
|
|
||||||
|
from .types import AudioPath
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -120,6 +121,12 @@ from batdetect2.types import (
|
|||||||
)
|
)
|
||||||
from batdetect2.utils.detector_utils import list_audio_files, load_model
|
from batdetect2.utils.detector_utils import list_audio_files, load_model
|
||||||
|
|
||||||
|
import audioread
|
||||||
|
import os
|
||||||
|
import soundfile as sf
|
||||||
|
import requests
|
||||||
|
import io
|
||||||
|
|
||||||
# Remove warnings from torch
|
# Remove warnings from torch
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
|
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
|
||||||
|
|
||||||
@ -238,34 +245,82 @@ def generate_spectrogram(
|
|||||||
|
|
||||||
|
|
||||||
def process_file(
|
def process_file(
|
||||||
audio_file: str,
|
path: AudioPath,
|
||||||
model: DetectionModel = MODEL,
|
model: DetectionModel = MODEL,
|
||||||
config: Optional[ProcessingConfiguration] = None,
|
config: Optional[ProcessingConfiguration] = None,
|
||||||
device: torch.device = DEVICE,
|
device: torch.device = DEVICE,
|
||||||
|
file_id: Optional[str] = None
|
||||||
) -> du.RunResults:
|
) -> du.RunResults:
|
||||||
"""Process audio file with model.
|
"""Process audio file with model.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
audio_file : str
|
path : AudioPath
|
||||||
Path to audio file.
|
Path to audio data.
|
||||||
model : DetectionModel, optional
|
model : DetectionModel, optional
|
||||||
Detection model. Uses default model if not specified.
|
Detection model. Uses default model if not specified.
|
||||||
config : Optional[ProcessingConfiguration], optional
|
config : Optional[ProcessingConfiguration], optional
|
||||||
Processing configuration, by default None (uses default parameters).
|
Processing configuration, by default None (uses default parameters).
|
||||||
device : torch.device, optional
|
device : torch.device, optional
|
||||||
Device to use, by default tries to use GPU if available.
|
Device to use, by default tries to use GPU if available.
|
||||||
|
file_id: Optional[str],
|
||||||
|
Give the data an id. If path is a string path to a file this can be ignored and
|
||||||
|
the file_id will be the basename of the file.
|
||||||
"""
|
"""
|
||||||
if config is None:
|
if config is None:
|
||||||
config = CONFIG
|
config = CONFIG
|
||||||
|
|
||||||
return du.process_file(
|
return du.process_file(
|
||||||
audio_file,
|
path,
|
||||||
model,
|
model,
|
||||||
config,
|
config,
|
||||||
device,
|
device,
|
||||||
|
file_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def process_url(
|
||||||
|
url: str,
|
||||||
|
model: DetectionModel = MODEL,
|
||||||
|
config: Optional[ProcessingConfiguration] = None,
|
||||||
|
device: torch.device = DEVICE,
|
||||||
|
file_id: Optional[str] = None
|
||||||
|
) -> du.RunResults:
|
||||||
|
"""Process audio file with model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
url : str
|
||||||
|
HTTP URL to load the audio data from
|
||||||
|
model : DetectionModel, optional
|
||||||
|
Detection model. Uses default model if not specified.
|
||||||
|
config : Optional[ProcessingConfiguration], optional
|
||||||
|
Processing configuration, by default None (uses default parameters).
|
||||||
|
device : torch.device, optional
|
||||||
|
Device to use, by default tries to use GPU if available.
|
||||||
|
file_id: Optional[str],
|
||||||
|
Give the data an id. Defaults to the URL
|
||||||
|
"""
|
||||||
|
if config is None:
|
||||||
|
config = CONFIG
|
||||||
|
|
||||||
|
if file_id is None:
|
||||||
|
file_id = url
|
||||||
|
|
||||||
|
response = requests.get(url)
|
||||||
|
|
||||||
|
# Raise exception on HTTP error
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# Retrieve body as raw bytes
|
||||||
|
raw_audio_data = response.content
|
||||||
|
|
||||||
|
return du.process_file(
|
||||||
|
io.BytesIO(raw_audio_data),
|
||||||
|
model,
|
||||||
|
config,
|
||||||
|
device,
|
||||||
|
file_id
|
||||||
|
)
|
||||||
|
|
||||||
def process_spectrogram(
|
def process_spectrogram(
|
||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
|
@ -45,6 +45,12 @@ def cli():
|
|||||||
default=False,
|
default=False,
|
||||||
help="Extracts CNN call features",
|
help="Extracts CNN call features",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--chunk_size",
|
||||||
|
type=float,
|
||||||
|
default=2,
|
||||||
|
help="Specifies the duration of chunks in seconds. BatDetect2 will divide longer files into smaller chunks and process them independently. Larger chunks increase computation time and memory usage but may provide more contextual information for inference.",
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--spec_features",
|
"--spec_features",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
@ -80,6 +86,7 @@ def detect(
|
|||||||
ann_dir: str,
|
ann_dir: str,
|
||||||
detection_threshold: float,
|
detection_threshold: float,
|
||||||
time_expansion_factor: int,
|
time_expansion_factor: int,
|
||||||
|
chunk_size: float,
|
||||||
**args,
|
**args,
|
||||||
):
|
):
|
||||||
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
|
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
|
||||||
@ -108,7 +115,7 @@ def detect(
|
|||||||
**args,
|
**args,
|
||||||
"time_expansion": time_expansion_factor,
|
"time_expansion": time_expansion_factor,
|
||||||
"spec_slices": False,
|
"spec_slices": False,
|
||||||
"chunk_size": 2,
|
"chunk_size": chunk_size,
|
||||||
"detection_threshold": detection_threshold,
|
"detection_threshold": detection_threshold,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -147,6 +154,7 @@ def print_config(config: ProcessingConfiguration):
|
|||||||
click.echo("\nProcessing Configuration:")
|
click.echo("\nProcessing Configuration:")
|
||||||
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
|
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
|
||||||
click.echo(f"Detection Threshold: {config.get('detection_threshold')}")
|
click.echo(f"Detection Threshold: {config.get('detection_threshold')}")
|
||||||
|
click.echo(f"Chunk Size: {config.get('chunk_size')}s")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
"""Types used in the code base."""
|
"""Types used in the code base."""
|
||||||
|
|
||||||
from typing import List, NamedTuple, Optional, Union
|
from typing import List, NamedTuple, Optional, Union, Any, BinaryIO
|
||||||
|
|
||||||
|
import audioread
|
||||||
|
import os
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -40,6 +44,9 @@ __all__ = [
|
|||||||
"SpectrogramParameters",
|
"SpectrogramParameters",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
AudioPath = Union[
|
||||||
|
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
|
||||||
|
]
|
||||||
|
|
||||||
class SpectrogramParameters(TypedDict):
|
class SpectrogramParameters(TypedDict):
|
||||||
"""Parameters for generating spectrograms."""
|
"""Parameters for generating spectrograms."""
|
||||||
|
@ -1,17 +1,24 @@
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple, Union, Any, BinaryIO
|
||||||
|
|
||||||
|
from ..types import AudioPath
|
||||||
|
|
||||||
import librosa
|
import librosa
|
||||||
import librosa.core.spectrum
|
import librosa.core.spectrum
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
import audioread
|
||||||
|
import os
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
from batdetect2.detector import parameters
|
from batdetect2.detector import parameters
|
||||||
|
|
||||||
from . import wavfile
|
from . import wavfile
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"load_audio",
|
"load_audio",
|
||||||
|
"load_audio_and_samplerate",
|
||||||
"generate_spectrogram",
|
"generate_spectrogram",
|
||||||
"pad_audio",
|
"pad_audio",
|
||||||
]
|
]
|
||||||
@ -140,21 +147,20 @@ def generate_spectrogram(
|
|||||||
|
|
||||||
return spec, spec_for_viz
|
return spec, spec_for_viz
|
||||||
|
|
||||||
|
|
||||||
def load_audio(
|
def load_audio(
|
||||||
audio_file: str,
|
path: AudioPath,
|
||||||
time_exp_fact: float,
|
time_exp_fact: float,
|
||||||
target_samp_rate: int,
|
target_samp_rate: int,
|
||||||
scale: bool = False,
|
scale: bool = False,
|
||||||
max_duration: Optional[float] = None,
|
max_duration: Optional[float] = None,
|
||||||
) -> Tuple[int, np.ndarray]:
|
) -> Tuple[int, np.ndarray ]:
|
||||||
"""Load an audio file and resample it to the target sampling rate.
|
"""Load an audio file and resample it to the target sampling rate.
|
||||||
|
|
||||||
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
||||||
Only mono files are supported.
|
Only mono files are supported.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
audio_file (str): Path to the audio file.
|
path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): path to the input file.
|
||||||
target_samp_rate (int): Target sampling rate.
|
target_samp_rate (int): Target sampling rate.
|
||||||
scale (bool): Whether to scale the audio to [-1, 1].
|
scale (bool): Whether to scale the audio to [-1, 1].
|
||||||
max_duration (float): Maximum duration of the audio in seconds.
|
max_duration (float): Maximum duration of the audio in seconds.
|
||||||
@ -166,20 +172,50 @@ def load_audio(
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the audio file is stereo.
|
ValueError: If the audio file is stereo.
|
||||||
|
|
||||||
|
"""
|
||||||
|
sample_rate, audio_data, _ = load_audio_and_samplerate(path, time_exp_fact, target_samp_rate, scale, max_duration)
|
||||||
|
return sample_rate, audio_data
|
||||||
|
|
||||||
|
def load_audio_and_samplerate(
|
||||||
|
path: AudioPath,
|
||||||
|
time_exp_fact: float,
|
||||||
|
target_samp_rate: int,
|
||||||
|
scale: bool = False,
|
||||||
|
max_duration: Optional[float] = None,
|
||||||
|
) -> Tuple[int, np.ndarray, Union[float, int]]:
|
||||||
|
"""Load an audio file and resample it to the target sampling rate.
|
||||||
|
|
||||||
|
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
||||||
|
Only mono files are supported.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): path to the input file.
|
||||||
|
target_samp_rate (int): Target sampling rate.
|
||||||
|
scale (bool): Whether to scale the audio to [-1, 1].
|
||||||
|
max_duration (float): Maximum duration of the audio in seconds.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
sampling_rate: The sampling rate of the audio.
|
||||||
|
audio_raw: The audio signal in a numpy array.
|
||||||
|
file_sampling_rate: The original sampling rate of the audio
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the audio file is stereo.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
|
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
|
||||||
# sampling_rate, audio_raw = wavfile.read(audio_file)
|
# sampling_rate, audio_raw = wavfile.read(audio_file)
|
||||||
audio_raw, sampling_rate = librosa.load(
|
audio_raw, file_sampling_rate = librosa.load(
|
||||||
audio_file,
|
path,
|
||||||
sr=None,
|
sr=None,
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(audio_raw.shape) > 1:
|
if len(audio_raw.shape) > 1:
|
||||||
raise ValueError("Currently does not handle stereo files")
|
raise ValueError("Currently does not handle stereo files")
|
||||||
|
|
||||||
sampling_rate = sampling_rate * time_exp_fact
|
sampling_rate = file_sampling_rate * time_exp_fact
|
||||||
|
|
||||||
# resample - need to do this after correcting for time expansion
|
# resample - need to do this after correcting for time expansion
|
||||||
sampling_rate_old = sampling_rate
|
sampling_rate_old = sampling_rate
|
||||||
@ -207,7 +243,7 @@ def load_audio(
|
|||||||
audio_raw = audio_raw - audio_raw.mean()
|
audio_raw = audio_raw - audio_raw.mean()
|
||||||
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
|
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
|
||||||
|
|
||||||
return sampling_rate, audio_raw
|
return sampling_rate, audio_raw, file_sampling_rate
|
||||||
|
|
||||||
|
|
||||||
def compute_spectrogram_width(
|
def compute_spectrogram_width(
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from typing import Any, Iterator, List, Optional, Tuple, Union
|
from typing import Any, Iterator, List, Optional, Tuple, Union, BinaryIO
|
||||||
|
|
||||||
|
from ..types import AudioPath
|
||||||
|
|
||||||
import librosa
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import torch
|
import torch
|
||||||
@ -31,6 +32,13 @@ from batdetect2.types import (
|
|||||||
SpectrogramParameters,
|
SpectrogramParameters,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import audioread
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import soundfile as sf
|
||||||
|
import hashlib
|
||||||
|
import uuid
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"load_model",
|
"load_model",
|
||||||
"list_audio_files",
|
"list_audio_files",
|
||||||
@ -729,10 +737,11 @@ def process_audio_array(
|
|||||||
|
|
||||||
|
|
||||||
def process_file(
|
def process_file(
|
||||||
audio_file: str,
|
path: AudioPath,
|
||||||
model: DetectionModel,
|
model: DetectionModel,
|
||||||
config: ProcessingConfiguration,
|
config: ProcessingConfiguration,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
file_id: Optional[str] = None
|
||||||
) -> Union[RunResults, Any]:
|
) -> Union[RunResults, Any]:
|
||||||
"""Process a single audio file with detection model.
|
"""Process a single audio file with detection model.
|
||||||
|
|
||||||
@ -741,7 +750,7 @@ def process_file(
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
audio_file : str
|
path : AudioPath
|
||||||
Path to audio file.
|
Path to audio file.
|
||||||
|
|
||||||
model : torch.nn.Module
|
model : torch.nn.Module
|
||||||
@ -749,6 +758,9 @@ def process_file(
|
|||||||
|
|
||||||
config : ProcessingConfiguration
|
config : ProcessingConfiguration
|
||||||
Configuration for processing.
|
Configuration for processing.
|
||||||
|
|
||||||
|
file_id: Optional[str],
|
||||||
|
Give the data an id. Defaults to the basename if path is a str or PathLike. If path is a BinaryIO or io.BytesIO object, an md5 is calculated from the binary data. Otherwise a random UUID is generated.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -762,19 +774,17 @@ def process_file(
|
|||||||
cnn_feats = []
|
cnn_feats = []
|
||||||
spec_slices = []
|
spec_slices = []
|
||||||
|
|
||||||
# Get original sampling rate
|
|
||||||
file_samp_rate = librosa.get_samplerate(audio_file)
|
|
||||||
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
|
|
||||||
|
|
||||||
# load audio file
|
# load audio file
|
||||||
sampling_rate, audio_full = au.load_audio(
|
sampling_rate, audio_full, file_samp_rate = au.load_audio_and_samplerate(
|
||||||
audio_file,
|
path,
|
||||||
time_exp_fact=config.get("time_expansion", 1) or 1,
|
time_exp_fact=config.get("time_expansion", 1) or 1,
|
||||||
target_samp_rate=config["target_samp_rate"],
|
target_samp_rate=config["target_samp_rate"],
|
||||||
scale=config["scale_raw_audio"],
|
scale=config["scale_raw_audio"],
|
||||||
max_duration=config.get("max_duration"),
|
max_duration=config.get("max_duration"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
|
||||||
|
|
||||||
# loop through larger file and split into chunks
|
# loop through larger file and split into chunks
|
||||||
# TODO: fix so that it overlaps correctly and takes care of
|
# TODO: fix so that it overlaps correctly and takes care of
|
||||||
# duplicate detections at borders
|
# duplicate detections at borders
|
||||||
@ -823,9 +833,13 @@ def process_file(
|
|||||||
spec_slices,
|
spec_slices,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_file_id = file_id
|
||||||
|
if _file_id is None:
|
||||||
|
_file_id = _generate_id(path)
|
||||||
|
|
||||||
# convert results to a dictionary in the right format
|
# convert results to a dictionary in the right format
|
||||||
results = convert_results(
|
results = convert_results(
|
||||||
file_id=os.path.basename(audio_file),
|
file_id=_file_id,
|
||||||
time_exp=config.get("time_expansion", 1) or 1,
|
time_exp=config.get("time_expansion", 1) or 1,
|
||||||
duration=audio_full.shape[0] / float(sampling_rate),
|
duration=audio_full.shape[0] / float(sampling_rate),
|
||||||
params=config,
|
params=config,
|
||||||
@ -845,6 +859,22 @@ def process_file(
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
def _generate_id(path: AudioPath) -> str:
|
||||||
|
""" Generate an id based on the path.
|
||||||
|
|
||||||
|
If the path is a str or PathLike it will be parsed as the basename.
|
||||||
|
This should ensure backwards compatibility with previous versions.
|
||||||
|
"""
|
||||||
|
if isinstance(path, str) or isinstance(path, os.PathLike):
|
||||||
|
return os.path.basename(path)
|
||||||
|
elif isinstance(path, (BinaryIO, io.BytesIO)):
|
||||||
|
path.seek(0)
|
||||||
|
md5 = hashlib.md5(path.read()).hexdigest()
|
||||||
|
path.seek(0)
|
||||||
|
return md5
|
||||||
|
else:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
def summarize_results(results, predictions, config):
|
def summarize_results(results, predictions, config):
|
||||||
"""Print summary of results."""
|
"""Print summary of results."""
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "batdetect2"
|
name = "batdetect2"
|
||||||
version = "1.1.1"
|
version = "1.3.0"
|
||||||
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
|
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
|
||||||
authors = [
|
authors = [
|
||||||
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
|
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
|
||||||
|
@ -10,11 +10,13 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2 import api
|
from batdetect2 import api
|
||||||
|
import io
|
||||||
|
|
||||||
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
|
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
|
||||||
TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
|
TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
|
||||||
|
|
||||||
|
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
||||||
|
|
||||||
def test_load_model_with_default_params():
|
def test_load_model_with_default_params():
|
||||||
"""Test loading model with default parameters."""
|
"""Test loading model with default parameters."""
|
||||||
@ -280,3 +282,28 @@ def test_process_file_with_empty_predictions_does_not_fail(
|
|||||||
|
|
||||||
assert results is not None
|
assert results is not None
|
||||||
assert len(results["pred_dict"]["annotation"]) == 0
|
assert len(results["pred_dict"]["annotation"]) == 0
|
||||||
|
|
||||||
|
def test_process_file_file_id_defaults_to_basename():
|
||||||
|
"""Test that process_file assigns basename as an id if no file_id is provided."""
|
||||||
|
# Recording donated by @kdarras
|
||||||
|
basename = "20230322_172000_selec2.wav"
|
||||||
|
path = os.path.join(DATA_DIR, basename)
|
||||||
|
|
||||||
|
output = api.process_file(path)
|
||||||
|
predictions = output["pred_dict"]
|
||||||
|
id = predictions["id"]
|
||||||
|
assert id == basename
|
||||||
|
|
||||||
|
def test_bytesio_file_id_defaults_to_md5():
|
||||||
|
"""Test that process_file assigns an md5 sum as an id if no file_id is provided when using binary data."""
|
||||||
|
# Recording donated by @kdarras
|
||||||
|
basename = "20230322_172000_selec2.wav"
|
||||||
|
path = os.path.join(DATA_DIR, basename)
|
||||||
|
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
data = io.BytesIO(f.read())
|
||||||
|
|
||||||
|
output = api.process_file(data)
|
||||||
|
predictions = output["pred_dict"]
|
||||||
|
id = predictions["id"]
|
||||||
|
assert id == "7ade9ebf1a9fe5477ff3a2dc57001929"
|
||||||
|
@ -6,7 +6,10 @@ from hypothesis import strategies as st
|
|||||||
|
|
||||||
from batdetect2.detector import parameters
|
from batdetect2.detector import parameters
|
||||||
from batdetect2.utils import audio_utils, detector_utils
|
from batdetect2.utils import audio_utils, detector_utils
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
|
||||||
|
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
||||||
|
|
||||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||||
def test_can_compute_correct_spectrogram_width(duration: float):
|
def test_can_compute_correct_spectrogram_width(duration: float):
|
||||||
@ -134,3 +137,20 @@ def test_pad_audio_with_fixed_width(duration: float, width: int):
|
|||||||
resize_factor=params["resize_factor"],
|
resize_factor=params["resize_factor"],
|
||||||
)
|
)
|
||||||
assert expected_width == width
|
assert expected_width == width
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_audio_using_bytesio():
|
||||||
|
basename = "20230322_172000_selec2.wav"
|
||||||
|
path = os.path.join(DATA_DIR, basename)
|
||||||
|
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
data = io.BytesIO(f.read())
|
||||||
|
|
||||||
|
sample_rate, audio_data, file_sample_rate = audio_utils.load_audio_and_samplerate(data, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ)
|
||||||
|
|
||||||
|
expected_sample_rate, expected_audio_data, exp_file_sample_rate = audio_utils.load_audio_and_samplerate(path, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ)
|
||||||
|
|
||||||
|
assert expected_sample_rate == sample_rate
|
||||||
|
assert exp_file_sample_rate == file_sample_rate
|
||||||
|
|
||||||
|
assert np.array_equal(audio_data, expected_audio_data)
|
@ -130,3 +130,29 @@ def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
|
|||||||
)
|
)
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert f"Error processing file {empty_file}" in result.output
|
assert f"Error processing file {empty_file}" in result.output
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_set_chunk_size(tmp_path: Path):
|
||||||
|
results_dir = tmp_path / "results"
|
||||||
|
|
||||||
|
# Remove results dir if it exists
|
||||||
|
if results_dir.exists():
|
||||||
|
results_dir.rmdir()
|
||||||
|
|
||||||
|
result = runner.invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"detect",
|
||||||
|
"example_data/audio",
|
||||||
|
str(results_dir),
|
||||||
|
"0.3",
|
||||||
|
"--chunk_size",
|
||||||
|
"1",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "Chunk Size: 1.0s" in result.output
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert results_dir.exists()
|
||||||
|
assert len(list(results_dir.glob("*.csv"))) == 3
|
||||||
|
assert len(list(results_dir.glob("*.json"))) == 3
|
||||||
|
Loading…
Reference in New Issue
Block a user