Mirror of https://github.com/macaodha/batdetect2.git
Synced 2025-06-29 22:51:58 +02:00

Compare commits: 62 commits
SHA1 of the 62 compared commits:
4cd71497e7 ba670932d5 d3747c57f2 42a838e9f2 c10903a646 4282e2ae70 52570738f2
cbd362d6ea 4b75e13fa2 98bf506634 b4c59f7de1 54ca555587 230b6167bc f62bc99ab2
47dbdc79c2 e10e270de4 6af7fef316 838a1ade0d 66ac7e608f 2100a3e483 1d3cd2e305
d5753b95bb 69f59ff559 1a11174bc4 c5c9476e52 270b3f212d f61d1d8c72 4627ddd739
3477d7b5b4 394c66a2ee d085b3212c c393c5c29b 3c22ff28a7 7dc28695b2 505cca2dea
7906842a16 a4b22d6590 25e0a53ad1 039c002796 c97a87b2a4 d93d8284d0 697b5dbddb
d5bf8f5ad8 fcbccbe012 4917641e2c 39c3918103 1ac3808fee 9e0ad7fd78 95bb0985e7
cb088359ae c5030123aa 1c1fbd8019 c65fe1c9f9 d05bec880a 8597ef0a1c 2d8a7b67f8
68351d2224 3f34164028 3744709c97 875a581044 c40197bf1c bbba4625be
.bumpversion.cfg (new file, 8 lines)
@@ -0,0 +1,8 @@
+[bumpversion]
+current_version = 1.3.0
+commit = True
+tag = True
+
+[bumpversion:file:batdetect2/__init__.py]
+
+[bumpversion:file:pyproject.toml]
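With this configuration, a bump2version run rewrites `current_version` here and the version strings in `batdetect2/__init__.py` and `pyproject.toml`, then commits the change and tags it (`commit = True`, `tag = True`).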
.github/workflows/python-package.yml (vendored, 35 lines changed)
@@ -1,34 +1,29 @@
 # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
 # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python

 name: Python package

 on:
   push:
-    branches: [ "main" ]
+    branches: ["main"]
   pull_request:
-    branches: [ "main" ]
+    branches: ["main"]

 jobs:
   build:

     runs-on: ubuntu-latest
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]

     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
-        with:
-          python-version: ${{ matrix.python-version }}
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          python -m pip install pytest
-          pip install .
-      - name: Test with pytest
-        run: |
-          pytest
+      - uses: actions/checkout@v4
+      - name: Install uv
+        uses: astral-sh/setup-uv@v3
+        with:
+          enable-cache: true
+          cache-dependency-glob: "uv.lock"
+      - name: Set up Python ${{ matrix.python-version }}
+        run: uv python install ${{ matrix.python-version }}
+      - name: Install the project
+        run: uv sync --all-extras --dev
+      - name: Test with pytest
+        run: uv run pytest
.github/workflows/python-publish.yml (vendored, 41 lines changed)
@@ -1,11 +1,3 @@
 # This workflow will upload a Python Package using Twine when a release is created
 # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
-
-# This workflow uses actions that are not certified by GitHub.
-# They are provided by a third-party and are governed by
-# separate terms of service, privacy policy, and support
-# documentation.

 name: Upload Python Package

 on:
@@ -17,23 +9,22 @@ permissions:

 jobs:
   deploy:

     runs-on: ubuntu-latest

     steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python
-        uses: actions/setup-python@v3
-        with:
-          python-version: '3.x'
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install build
-      - name: Build package
-        run: python -m build
-      - name: Publish package
-        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
-        with:
-          user: __token__
-          password: ${{ secrets.PYPI_API_TOKEN }}
+      - uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v3
+        with:
+          python-version: "3.x"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install build
+      - name: Build package
+        run: python -m build
+      - name: Publish package
+        uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+        with:
+          user: __token__
+          password: ${{ secrets.PYPI_API_TOKEN }}
.gitignore (vendored, 6 lines changed)
@@ -103,10 +103,10 @@ experiments/*
 .ipynb_checkpoints
 *.ipynb

+# Bump2version
+.bumpversion.cfg
+
 # DO Include
 !batdetect2_notebook.ipynb
 !batdetect2/models/*.pth.tar
-!tests/data/*.wav
+!tests/data/**/*.wav
 notebooks/lightning_logs
 example_data/preprocessed
README.md (21 lines changed)
@@ -96,6 +96,27 @@ detections, features = api.process_spectrogram(spec)

 You can integrate the detections or the extracted features into your custom analysis pipeline.

+#### Using the Python API with HTTP
+
+```python
+from batdetect2 import api
+import io
+import requests
+
+AUDIO_URL = "<insert your audio url here>"
+
+# Process a whole file from a URL
+results = api.process_url(AUDIO_URL)
+
+# Or load the audio and compute the spectrogram yourself.
+# 'requests.get(AUDIO_URL).content' fetches the raw bytes; you are free to
+# fetch the raw bytes from other sources.
+audio = api.load_audio(io.BytesIO(requests.get(AUDIO_URL).content))
+spec = api.generate_spectrogram(audio)
+
+# And process the audio or the spectrogram with the model
+detections, features, spec = api.process_audio(audio)
+detections, features = api.process_spectrogram(spec)
+```

 ## Training the model on your own data
 Take a look at the steps outlined in the finetuning readme [here](batdetect2/finetune/readme.md) for a description of how to train your own model.
batdetect2/__init__.py
@@ -1 +1,6 @@
-__version__ = '1.0.7'
+import logging
+
+numba_logger = logging.getLogger("numba")
+numba_logger.setLevel(logging.WARNING)
+
+__version__ = "1.3.0"
batdetect2/api.py
@@ -97,8 +97,9 @@ consult the API documentation in the code.

 """
 import warnings
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, BinaryIO, Any, Union
+
+from .types import AudioPath

 import numpy as np
 import torch

@@ -120,6 +121,12 @@ from batdetect2.types import (
 )
 from batdetect2.utils.detector_utils import list_audio_files, load_model

+import audioread
+import os
+import soundfile as sf
+import requests
+import io

 # Remove warnings from torch
 warnings.filterwarnings("ignore", category=UserWarning, module="torch")
@@ -238,34 +245,82 @@ def generate_spectrogram(


 def process_file(
-    audio_file: str,
+    path: AudioPath,
     model: DetectionModel = MODEL,
     config: Optional[ProcessingConfiguration] = None,
     device: torch.device = DEVICE,
+    file_id: Optional[str] = None,
 ) -> du.RunResults:
     """Process audio file with model.

     Parameters
     ----------
-    audio_file : str
-        Path to audio file.
+    path : AudioPath
+        Path to audio data.
     model : DetectionModel, optional
         Detection model. Uses default model if not specified.
     config : Optional[ProcessingConfiguration], optional
         Processing configuration, by default None (uses default parameters).
     device : torch.device, optional
         Device to use, by default tries to use GPU if available.
+    file_id : Optional[str]
+        Give the data an id. If path is a string path to a file, this can be
+        ignored and the file_id will be the basename of the file.
     """
     if config is None:
         config = CONFIG

     return du.process_file(
-        audio_file,
+        path,
         model,
         config,
         device,
+        file_id,
     )


+def process_url(
+    url: str,
+    model: DetectionModel = MODEL,
+    config: Optional[ProcessingConfiguration] = None,
+    device: torch.device = DEVICE,
+    file_id: Optional[str] = None,
+) -> du.RunResults:
+    """Process audio from a URL with the model.
+
+    Parameters
+    ----------
+    url : str
+        HTTP URL to load the audio data from.
+    model : DetectionModel, optional
+        Detection model. Uses default model if not specified.
+    config : Optional[ProcessingConfiguration], optional
+        Processing configuration, by default None (uses default parameters).
+    device : torch.device, optional
+        Device to use, by default tries to use GPU if available.
+    file_id : Optional[str]
+        Give the data an id. Defaults to the URL.
+    """
+    if config is None:
+        config = CONFIG
+
+    if file_id is None:
+        file_id = url
+
+    response = requests.get(url)
+
+    # Raise exception on HTTP error
+    response.raise_for_status()
+
+    # Retrieve body as raw bytes
+    raw_audio_data = response.content
+
+    return du.process_file(
+        io.BytesIO(raw_audio_data),
+        model,
+        config,
+        device,
+        file_id,
+    )
+
+
 def process_spectrogram(
     spec: torch.Tensor,
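A minimal usage sketch for the new entry point; the URL and file_id below are illustrative, and the "pred_dict" access follows the result shape used in the test suite later in this diff:

```python
from batdetect2 import api

AUDIO_URL = "https://example.com/recording.wav"  # hypothetical URL

# file_id is optional; it defaults to the URL itself.
results = api.process_url(AUDIO_URL, file_id="site42_night1")

predictions = results["pred_dict"]
print(predictions["id"])  # -> "site42_night1"
for annotation in predictions["annotation"]:
    print(annotation)
```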
batdetect2/cli.py
@@ -1,4 +1,5 @@
 """BatDetect2 command line interface."""

 import os

 import click

@@ -44,6 +45,12 @@ def cli():
     default=False,
     help="Extracts CNN call features",
 )
+@click.option(
+    "--chunk_size",
+    type=float,
+    default=2,
+    help="Specifies the duration of chunks in seconds. BatDetect2 will divide longer files into smaller chunks and process them independently. Larger chunks increase computation time and memory usage but may provide more contextual information for inference.",
+)
 @click.option(
     "--spec_features",
     is_flag=True,

@@ -79,6 +86,7 @@ def detect(
     ann_dir: str,
     detection_threshold: float,
     time_expansion_factor: int,
+    chunk_size: float,
     **args,
 ):
     """Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.

@@ -107,7 +115,7 @@
             **args,
             "time_expansion": time_expansion_factor,
             "spec_slices": False,
-            "chunk_size": 2,
+            "chunk_size": chunk_size,
             "detection_threshold": detection_threshold,
         }
     )

@@ -129,10 +137,9 @@
         ):
             results_path = audio_file.replace(audio_dir, ann_dir)
             save_results_to_file(results, results_path)
-        except (RuntimeError, ValueError, LookupError) as err:
+        except (RuntimeError, ValueError, LookupError, EOFError) as err:
             error_files.append(audio_file)
-            click.secho(f"Error processing file!: {err}", fg="red")
-            raise err
+            click.secho(f"Error processing file {audio_file}: {err}", fg="red")

     click.echo(f"\nResults saved to: {ann_dir}")

@@ -147,6 +154,7 @@ def print_config(config: ProcessingConfiguration):
     click.echo("\nProcessing Configuration:")
     click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
     click.echo(f"Detection Threshold: {config.get('detection_threshold')}")
+    click.echo(f"Chunk Size: {config.get('chunk_size')}s")


 if __name__ == "__main__":
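As exercised in the CLI tests further down, the new flag is passed as, for example, `batdetect2 detect example_data/audio results/ 0.3 --chunk_size 1`, and the configuration summary then reports `Chunk Size: 1.0s`.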
@@ -1,7 +1,5 @@
 import glob
 import json
 import os
 import random

 import numpy as np
batdetect2/types.py
@@ -1,5 +1,10 @@
 """Types used in the code base."""
-from typing import List, NamedTuple, Optional, Union
+from typing import List, NamedTuple, Optional, Union, Any, BinaryIO
+
+import audioread
+import os
+import soundfile as sf

 import numpy as np
 import torch

@@ -17,7 +22,7 @@ except ImportError:

 try:
-    from typing import NotRequired
+    from typing import NotRequired  # type: ignore
 except ImportError:
     from typing_extensions import NotRequired

@@ -39,6 +44,9 @@ __all__ = [
     "SpectrogramParameters",
 ]

+AudioPath = Union[
+    str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
+]

 class SpectrogramParameters(TypedDict):
     """Parameters for generating spectrograms."""
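In practice this union means the loaders accept filesystem paths and in-memory objects interchangeably. A small sketch, assuming a hypothetical local file recording.wav and using the load_audio_and_samplerate function introduced further down:

```python
import io

from batdetect2.detector import parameters
from batdetect2.utils import audio_utils

# "recording.wav" is a hypothetical local file used only for illustration.
with open("recording.wav", "rb") as f:
    buffer = io.BytesIO(f.read())

# A filesystem path and an in-memory buffer both satisfy AudioPath and go
# through the same loading code path.
for source in ["recording.wav", buffer]:
    sample_rate, audio, file_sample_rate = (
        audio_utils.load_audio_and_samplerate(
            source,
            time_exp_fact=1,
            target_samp_rate=parameters.TARGET_SAMPLERATE_HZ,
        )
    )
```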
batdetect2/utils/audio_utils.py
@@ -1,34 +1,67 @@
 import warnings
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union, Any, BinaryIO
+
+from ..types import AudioPath

 import librosa
 import librosa.core.spectrum
 import numpy as np
 import torch

+import audioread
+import os
+import soundfile as sf
+
 from batdetect2.detector import parameters

 from . import wavfile

 __all__ = [
     "load_audio",
+    "load_audio_and_samplerate",
     "generate_spectrogram",
     "pad_audio",
 ]
-def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
-    nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
-    noverlap = np.floor(fft_overlap * nfft)
-    return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
+def time_to_x_coords(
+    time_in_file: float,
+    samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
+    window_duration: float = parameters.FFT_WIN_LENGTH_S,
+    window_overlap: float = parameters.FFT_OVERLAP,
+) -> float:
+    nfft = np.floor(window_duration * samplerate)  # int() uses floor
+    noverlap = np.floor(window_overlap * nfft)
+    return (time_in_file * samplerate - noverlap) / (nfft - noverlap)


 # NOTE this is also defined in post_process
-def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
-    nfft = np.floor(fft_win_length * sampling_rate)
-    noverlap = np.floor(fft_overlap * nfft)
-    return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
+def x_coords_to_time(
+    x_pos: int,
+    samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
+    window_duration: float = parameters.FFT_WIN_LENGTH_S,
+    window_overlap: float = parameters.FFT_OVERLAP,
+) -> float:
+    n_fft = np.floor(window_duration * samplerate)
+    n_overlap = np.floor(window_overlap * n_fft)
+    n_step = n_fft - n_overlap
+    return ((x_pos * n_step) + n_overlap) / samplerate
+    # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5)  # 0.5 is for center of temporal window
+
+
+def x_coord_to_sample(
+    x_pos: int,
+    samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
+    window_duration: float = parameters.FFT_WIN_LENGTH_S,
+    window_overlap: float = parameters.FFT_OVERLAP,
+    resize_factor: float = parameters.RESIZE_FACTOR,
+) -> int:
+    n_fft = np.floor(window_duration * samplerate)
+    n_overlap = np.floor(window_overlap * n_fft)
+    n_step = n_fft - n_overlap
+    x_pos = int(x_pos / resize_factor)
+    return int((x_pos * n_step) + n_overlap)
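A worked example of the arithmetic behind these transforms; the parameter values below are assumptions chosen to give round numbers, not necessarily the package defaults:

```python
import numpy as np

samplerate = 256_000     # assumed target samplerate (Hz)
window_duration = 0.002  # assumed FFT window length (s)
window_overlap = 0.75    # assumed overlap fraction

nfft = np.floor(window_duration * samplerate)  # 512 samples per window
noverlap = np.floor(window_overlap * nfft)     # 384 overlapping samples
step = nfft - noverlap                         # 128-sample hop

# 0.1 s into the file maps to this spectrogram column...
x = (0.1 * samplerate - noverlap) / step            # time_to_x_coords -> 197.0
# ...and mapping the column back recovers the same time.
t = ((round(x) * step) + noverlap) / samplerate     # x_coords_to_time -> 0.1
print(x, t)
```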
 def generate_spectrogram(
     audio,
     sampling_rate,
@@ -114,21 +147,20 @@

     return spec, spec_for_viz


 def load_audio(
-    audio_file: str,
+    path: AudioPath,
     time_exp_fact: float,
     target_samp_rate: int,
     scale: bool = False,
     max_duration: Optional[float] = None,
 ) -> Tuple[int, np.ndarray]:
     """Load an audio file and resample it to the target sampling rate.

     The audio is also scaled to [-1, 1] and clipped to the maximum duration.
     Only mono files are supported.

     Args:
-        audio_file (str): Path to the audio file.
+        path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): Path to the input file.
         target_samp_rate (int): Target sampling rate.
         scale (bool): Whether to scale the audio to [-1, 1].
         max_duration (float): Maximum duration of the audio in seconds.
@@ -140,20 +172,50 @@ def load_audio(
     Raises:
         ValueError: If the audio file is stereo.

     """
+    sample_rate, audio_data, _ = load_audio_and_samplerate(
+        path, time_exp_fact, target_samp_rate, scale, max_duration
+    )
+    return sample_rate, audio_data
+
+
+def load_audio_and_samplerate(
+    path: AudioPath,
+    time_exp_fact: float,
+    target_samp_rate: int,
+    scale: bool = False,
+    max_duration: Optional[float] = None,
+) -> Tuple[int, np.ndarray, Union[float, int]]:
+    """Load an audio file and resample it to the target sampling rate.
+
+    The audio is also scaled to [-1, 1] and clipped to the maximum duration.
+    Only mono files are supported.
+
+    Args:
+        path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): Path to the input file.
+        target_samp_rate (int): Target sampling rate.
+        scale (bool): Whether to scale the audio to [-1, 1].
+        max_duration (float): Maximum duration of the audio in seconds.
+
+    Returns:
+        sampling_rate: The sampling rate of the audio.
+        audio_raw: The audio signal in a numpy array.
+        file_sampling_rate: The original sampling rate of the audio.
+
+    Raises:
+        ValueError: If the audio file is stereo.
+    """
     with warnings.catch_warnings():
         warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
-        # sampling_rate, audio_raw = wavfile.read(audio_file)
-        audio_raw, sampling_rate = librosa.load(
-            audio_file,
+        audio_raw, file_sampling_rate = librosa.load(
+            path,
             sr=None,
             dtype=np.float32,
         )

     if len(audio_raw.shape) > 1:
         raise ValueError("Currently does not handle stereo files")

-    sampling_rate = sampling_rate * time_exp_fact
+    sampling_rate = file_sampling_rate * time_exp_fact

     # resample - need to do this after correcting for time expansion
     sampling_rate_old = sampling_rate
@@ -181,58 +243,121 @@ def load_audio(
     audio_raw = audio_raw - audio_raw.mean()
     audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)

-    return sampling_rate, audio_raw
+    return sampling_rate, audio_raw, file_sampling_rate
+
+
+def compute_spectrogram_width(
+    length: int,
+    samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
+    window_duration: float = parameters.FFT_WIN_LENGTH_S,
+    window_overlap: float = parameters.FFT_OVERLAP,
+    resize_factor: float = parameters.RESIZE_FACTOR,
+) -> int:
+    n_fft = int(window_duration * samplerate)
+    n_overlap = int(window_overlap * n_fft)
+    n_step = n_fft - n_overlap
+    width = (length - n_overlap) // n_step
+    return int(width * resize_factor)


 def pad_audio(
-    audio_raw,
-    fs,
-    ms,
-    overlap_perc,
-    resize_factor,
-    divide_factor,
-    fixed_width=None,
+    audio: np.ndarray,
+    samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
+    window_duration: float = parameters.FFT_WIN_LENGTH_S,
+    window_overlap: float = parameters.FFT_OVERLAP,
+    resize_factor: float = parameters.RESIZE_FACTOR,
+    divide_factor: int = parameters.SPEC_DIVIDE_FACTOR,
+    fixed_width: Optional[int] = None,
 ):
-    # Adds zeros to the end of the raw data so that the generated spectrogram
-    # will be evenly divisible by `divide_factor`
-    # Also deals with very short audio clips and fixed_width during training
-
-    # This code could be clearer, clean up
-    nfft = int(ms * fs)
-    noverlap = int(overlap_perc * nfft)
-    step = nfft - noverlap
-    min_size = int(divide_factor * (1.0 / resize_factor))
-    spec_width = (audio_raw.shape[0] - noverlap) // step
-    spec_width_rs = spec_width * resize_factor
-
-    if fixed_width is not None and spec_width < fixed_width:
-        # too small
-        # used during training to ensure all the batches are the same size
-        diff = fixed_width * step + noverlap - audio_raw.shape[0]
-        audio_raw = np.hstack(
-            (audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
-        )
-
-    elif fixed_width is not None and spec_width > fixed_width:
-        # too big
-        # used during training to ensure all the batches are the same size
-        diff = fixed_width * step + noverlap - audio_raw.shape[0]
-        audio_raw = audio_raw[:diff]
-
-    elif (
-        spec_width_rs < min_size
-        or (np.floor(spec_width_rs) % divide_factor) != 0
-    ):
-        # need to be at least min_size
-        div_amt = np.ceil(spec_width_rs / float(divide_factor))
-        div_amt = np.maximum(1, div_amt)
-        target_size = int(div_amt * divide_factor * (1.0 / resize_factor))
-        diff = target_size * step + noverlap - audio_raw.shape[0]
-        audio_raw = np.hstack(
-            (audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
-        )
-
-    return audio_raw
+    """Pad audio to be evenly divisible by `divide_factor`.
+
+    This function pads the audio signal with zeros to ensure that the
+    generated spectrogram length will be evenly divisible by `divide_factor`.
+    This is important for the model to work correctly.
+
+    This `divide_factor` comes from the model architecture as it downscales
+    the spectrogram by this factor, so the input must be divisible by this
+    integer number.
+
+    Parameters
+    ----------
+    audio : np.ndarray
+        The audio signal.
+    samplerate : int
+        The sampling rate of the audio signal.
+    window_duration : float
+        The window size in seconds used for the spectrogram computation.
+    window_overlap : float
+        The overlap between windows in the spectrogram computation.
+    resize_factor : float
+        This factor is used to resize the spectrogram after the STFT
+        computation. Default is 0.5, which means that the spectrogram will be
+        reduced by half. Important to take into account for the final size of
+        the spectrogram.
+    divide_factor : int
+        The factor by which the spectrogram will be divided.
+    fixed_width : int, optional
+        If provided, the audio will be padded or cut so that the resulting
+        spectrogram width will be equal to this value.
+
+    Returns
+    -------
+    np.ndarray
+        The padded audio signal.
+    """
+    spec_width = compute_spectrogram_width(
+        audio.shape[0],
+        samplerate=samplerate,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+        resize_factor=resize_factor,
+    )
+
+    if fixed_width:
+        target_samples = x_coord_to_sample(
+            fixed_width,
+            samplerate=samplerate,
+            window_duration=window_duration,
+            window_overlap=window_overlap,
+            resize_factor=resize_factor,
+        )
+
+        if spec_width < fixed_width:
+            # need to be at least min_size
+            diff = target_samples - audio.shape[0]
+            return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
+
+        if spec_width > fixed_width:
+            return audio[:target_samples]
+
+        return audio
+
+    min_width = int(divide_factor / resize_factor)
+
+    if spec_width < min_width:
+        target_samples = x_coord_to_sample(
+            min_width,
+            samplerate=samplerate,
+            window_duration=window_duration,
+            window_overlap=window_overlap,
+            resize_factor=resize_factor,
+        )
+        diff = target_samples - audio.shape[0]
+        return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
+
+    if (spec_width % divide_factor) == 0:
+        return audio
+
+    target_width = int(np.ceil(spec_width / divide_factor)) * divide_factor
+    target_samples = x_coord_to_sample(
+        target_width,
+        samplerate=samplerate,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+        resize_factor=resize_factor,
+    )
+    diff = target_samples - audio.shape[0]
+    return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
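A sketch exercising the contract documented above: after `pad_audio`, the computed spectrogram width is a multiple of `divide_factor`. It mirrors the hypothesis tests added later in this diff and pulls the real defaults from `parameters`:

```python
import numpy as np

from batdetect2.detector import parameters
from batdetect2.utils import audio_utils

params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
samplerate = parameters.TARGET_SAMPLERATE_HZ

# An awkward, non-round clip length to force padding.
audio = np.random.rand(int(0.37 * samplerate))

padded = audio_utils.pad_audio(
    audio,
    samplerate=samplerate,
    window_duration=params["fft_win_length"],
    window_overlap=params["fft_overlap"],
    resize_factor=params["resize_factor"],
    divide_factor=params["spec_divide_factor"],
)

width = audio_utils.compute_spectrogram_width(
    len(padded),
    samplerate=samplerate,
    window_duration=params["fft_win_length"],
    window_overlap=params["fft_overlap"],
    resize_factor=params["resize_factor"],
)
assert width % params["spec_divide_factor"] == 0
```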
@@ -247,7 +372,11 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc):

     # compute spec
     spec, _ = librosa.core.spectrum._spectrogram(
-        y=x, power=1, n_fft=nfft, hop_length=step, center=False
+        y=x,
+        power=1,
+        n_fft=nfft,
+        hop_length=step,
+        center=False,
     )

     # remove DC component and flip vertical orientation
batdetect2/utils/detector_utils.py
@@ -1,13 +1,19 @@
 import json
 import os
-from typing import Any, Iterator, List, Optional, Tuple, Union
+from typing import Any, Iterator, List, Optional, Tuple, Union, BinaryIO
+
+from ..types import AudioPath

 import librosa
 import numpy as np
 import pandas as pd
 import torch
 import torch.nn.functional as F

+try:
+    from numpy.exceptions import AxisError
+except ImportError:
+    from numpy import AxisError  # type: ignore
+
 import batdetect2.detector.compute_features as feats
 import batdetect2.detector.post_process as pp
 import batdetect2.utils.audio_utils as au

@@ -26,6 +32,13 @@ from batdetect2.types import (
     SpectrogramParameters,
 )

+import audioread
+import os
+import io
+import soundfile as sf
+import hashlib
+import uuid
 __all__ = [
     "load_model",
     "list_audio_files",

@@ -80,6 +93,7 @@ def load_model(
     model_path: str = DEFAULT_MODEL_PATH,
     load_weights: bool = True,
     device: Optional[torch.device] = None,
+    weights_only: bool = True,
 ) -> Tuple[DetectionModel, ModelParameters]:
     """Load model from file.

@@ -100,7 +114,11 @@
     if not os.path.isfile(model_path):
         raise FileNotFoundError("Model file not found.")

-    net_params = torch.load(model_path, map_location=device)
+    net_params = torch.load(
+        model_path,
+        map_location=device,
+        weights_only=weights_only,
+    )

     params = net_params["params"]

@@ -242,7 +260,7 @@ def format_single_result(
         )
         class_name = class_names[np.argmax(class_overall)]
         annotations = get_annotations_from_preds(predictions, class_names)
-    except (np.AxisError, ValueError):
+    except (AxisError, ValueError):
         # No detections
         class_overall = np.zeros(len(class_names))
         class_name = "None"
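Passing `weights_only=True` to `torch.load` restricts deserialization to plain tensors and primitive containers rather than arbitrary pickled objects, a safeguard against untrusted checkpoint files; this is why the model tests at the end of this diff load the model both ways and assert that the predictions are identical.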
@@ -719,10 +737,11 @@


 def process_file(
-    audio_file: str,
+    path: AudioPath,
     model: DetectionModel,
     config: ProcessingConfiguration,
     device: torch.device,
+    file_id: Optional[str] = None,
 ) -> Union[RunResults, Any]:
     """Process a single audio file with detection model.

@@ -731,7 +750,7 @@

     Parameters
     ----------
-    audio_file : str
+    path : AudioPath
         Path to audio file.

     model : torch.nn.Module
@@ -739,6 +758,9 @@

     config : ProcessingConfiguration
         Configuration for processing.

+    file_id : Optional[str]
+        Give the data an id. Defaults to the filename if path is a string.
+        Otherwise an md5 hash will be calculated from the binary data.

     Returns
     -------
@@ -752,19 +774,17 @@
     cnn_feats = []
     spec_slices = []

-    # Get original sampling rate
-    file_samp_rate = librosa.get_samplerate(audio_file)
-    orig_samp_rate = file_samp_rate * config.get("time_expansion", 1) or 1
-
     # load audio file
-    sampling_rate, audio_full = au.load_audio(
-        audio_file,
+    sampling_rate, audio_full, file_samp_rate = au.load_audio_and_samplerate(
+        path,
         time_exp_fact=config.get("time_expansion", 1) or 1,
         target_samp_rate=config["target_samp_rate"],
         scale=config["scale_raw_audio"],
         max_duration=config.get("max_duration"),
     )

+    orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
+
     # loop through larger file and split into chunks
     # TODO: fix so that it overlaps correctly and takes care of
     # duplicate detections at borders
@@ -813,9 +833,13 @@
         spec_slices,
     )

+    _file_id = file_id
+    if _file_id is None:
+        _file_id = _generate_id(path)
+
     # convert results to a dictionary in the right format
     results = convert_results(
-        file_id=os.path.basename(audio_file),
+        file_id=_file_id,
         time_exp=config.get("time_expansion", 1) or 1,
         duration=audio_full.shape[0] / float(sampling_rate),
         params=config,
@@ -835,6 +859,22 @@

     return results


+def _generate_id(path: AudioPath) -> str:
+    """Generate an id based on the path.
+
+    If the path is a str or PathLike, it will be parsed as the basename.
+    This should ensure backwards compatibility with previous versions.
+    """
+    if isinstance(path, str) or isinstance(path, os.PathLike):
+        return os.path.basename(path)
+    elif isinstance(path, (BinaryIO, io.BytesIO)):
+        path.seek(0)
+        md5 = hashlib.md5(path.read()).hexdigest()
+        path.seek(0)
+        return md5
+    else:
+        return str(uuid.uuid4())


 def summarize_results(results, predictions, config):
     """Print summary of results."""
pyproject.toml (110 lines changed)
@@ -1,82 +1,82 @@
-[tool.pdm]
-[tool.pdm.dev-dependencies]
-dev = [
-    "pytest>=7.2.2",
-]
-
 [project]
 name = "batdetect2"
-version = "1.0.7"
+version = "1.3.0"
 description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
 authors = [
     { "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
-    { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" }
+    { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
 ]
 dependencies = [
-    "librosa",
-    "matplotlib",
-    "numpy",
-    "pandas",
-    "scikit-learn",
-    "scipy",
-    "torch>=1.13.1,<2",
-    "torchaudio",
-    "torchvision",
-    "click",
+    "click>=8.1.7",
+    "librosa>=0.10.1",
+    "matplotlib>=3.7.1",
+    "numpy>=1.23.5",
+    "pandas>=1.5.3",
+    "scikit-learn>=1.2.2",
+    "scipy>=1.10.1",
+    "torch>=1.13.1,<2.5.0",
+    "torchaudio>=1.13.1,<2.5.0",
+    "torchvision>=0.14.0",
 ]
-requires-python = ">=3.8,<3.11"
+requires-python = ">=3.9,<3.13"
 readme = "README.md"
 license = { text = "CC-by-nc-4" }
 classifiers = [
     "Development Status :: 4 - Beta",
     "Intended Audience :: Science/Research",
     "Natural Language :: English",
     "Operating System :: OS Independent",
-    "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
     "Topic :: Scientific/Engineering :: Artificial Intelligence",
     "Topic :: Software Development :: Libraries :: Python Modules",
     "Topic :: Multimedia :: Sound/Audio :: Analysis",
 ]
 keywords = [
     "bat",
     "echolocation",
     "deep learning",
     "audio",
     "machine learning",
     "classification",
     "detection",
 ]

 [build-system]
-requires = ["pdm-pep517>=1.0.0"]
-build-backend = "pdm.pep517.api"
+requires = ["hatchling"]
+build-backend = "hatchling.build"

 [project.scripts]
 batdetect2 = "batdetect2.cli:cli"

 [tool.black]
 line-length = 79

-[[tool.mypy.overrides]]
-module = [
-    "librosa",
-    "pandas",
-]
-ignore_missing_imports = true
-
-[tool.pylsp-mypy]
-enabled = false
-live_mode = true
-strict = true
-
-[tool.pydocstyle]
+[tool.uv]
+dev-dependencies = [
+    "debugpy>=1.8.8",
+    "hypothesis>=6.118.7",
+    "pyright>=1.1.388",
+    "pytest>=7.2.2",
+    "ruff>=0.7.3",
+]
+
+[tool.ruff]
+line-length = 79
+target-version = "py39"
+
+[tool.ruff.format]
+docstring-code-format = true
+docstring-code-line-length = 79
+
+[tool.ruff.lint]
+select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
+
+[tool.ruff.lint.pydocstyle]
 convention = "numpy"

 [tool.pyright]
-include = [
-    "bat_detect",
-    "tests",
-]
+include = ["batdetect2", "tests"]
+venvPath = "."
+venv = ".venv"
+pythonVersion = "3.9"
+pythonPlatform = "All"
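The new `[tool.uv]` dev-dependency group mirrors the CI workflow earlier in this diff, where `uv sync --all-extras --dev` installs the project and its dev tools and `uv run pytest` runs the suite.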
tests/conftest.py (new file, 40 lines)
@@ -0,0 +1,40 @@
+from pathlib import Path
+from typing import List
+
+import pytest
+
+
+@pytest.fixture
+def example_data_dir() -> Path:
+    pkg_dir = Path(__file__).parent.parent
+    example_data_dir = pkg_dir / "example_data"
+    assert example_data_dir.exists()
+    return example_data_dir
+
+
+@pytest.fixture
+def example_audio_dir(example_data_dir: Path) -> Path:
+    example_audio_dir = example_data_dir / "audio"
+    assert example_audio_dir.exists()
+    return example_audio_dir
+
+
+@pytest.fixture
+def example_audio_files(example_audio_dir: Path) -> List[Path]:
+    audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
+    assert len(audio_files) == 3
+    return audio_files
+
+
+@pytest.fixture
+def data_dir() -> Path:
+    dir = Path(__file__).parent / "data"
+    assert dir.exists()
+    return dir
+
+
+@pytest.fixture
+def contrib_dir(data_dir) -> Path:
+    dir = data_dir / "contrib"
+    assert dir.exists()
+    return dir
Binary files added (content not shown):
tests/data/contrib/jeff37/0166_20240531_223911.wav (executable)
tests/data/contrib/jeff37/0166_20240602_225340.wav (executable)
tests/data/contrib/jeff37/0166_20240603_033731.wav (executable)
tests/data/contrib/jeff37/0166_20240603_033937.wav (executable)
tests/data/contrib/jeff37/0166_20240604_233500.wav (executable)
tests/data/contrib/padpadpadpad/Audiomoth.WAV
tests/data/contrib/padpadpadpad/AudiomothNoBatCalls.WAV
tests/data/contrib/padpadpadpad/Echometer.wav
tests/test_api.py
@@ -1,21 +1,22 @@
 """Test bat detect module API."""

+import io
 import os
 from glob import glob
 from pathlib import Path

 import numpy as np
+import soundfile as sf
 import torch
 from torch import nn
-import soundfile as sf

 from batdetect2 import api

 PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
 TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))

+DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
+

 def test_load_model_with_default_params():
     """Test loading model with default parameters."""
@@ -267,7 +268,6 @@
     assert len(results["spec_slices"]) == len(detections)


 def test_process_file_with_empty_predictions_does_not_fail(
     tmp_path: Path,
 ):
@@ -282,3 +282,28 @@

     assert results is not None
     assert len(results["pred_dict"]["annotation"]) == 0
+
+
+def test_process_file_file_id_defaults_to_basename():
+    """Test that process_file assigns the basename as an id if no file_id is provided."""
+    # Recording donated by @kdarras
+    basename = "20230322_172000_selec2.wav"
+    path = os.path.join(DATA_DIR, basename)
+
+    output = api.process_file(path)
+    predictions = output["pred_dict"]
+    id = predictions["id"]
+    assert id == basename
+
+
+def test_bytesio_file_id_defaults_to_md5():
+    """Test that process_file assigns an md5 sum as an id if no file_id is provided when using binary data."""
+    # Recording donated by @kdarras
+    basename = "20230322_172000_selec2.wav"
+    path = os.path.join(DATA_DIR, basename)
+
+    with open(path, "rb") as f:
+        data = io.BytesIO(f.read())
+
+    output = api.process_file(data)
+    predictions = output["pred_dict"]
+    id = predictions["id"]
+    assert id == "7ade9ebf1a9fe5477ff3a2dc57001929"
tests/test_audio_utils.py (new file, 156 lines)
@@ -0,0 +1,156 @@
+import io
+import os
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from hypothesis import given
+from hypothesis import strategies as st
+
+from batdetect2.detector import parameters
+from batdetect2.utils import audio_utils, detector_utils
+
+DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
+
+
+@given(duration=st.floats(min_value=0.1, max_value=2))
+def test_can_compute_correct_spectrogram_width(duration: float):
+    samplerate = parameters.TARGET_SAMPLERATE_HZ
+    params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
+
+    length = int(duration * samplerate)
+    audio = np.random.rand(length)
+
+    spectrogram, _ = audio_utils.generate_spectrogram(
+        audio,
+        samplerate,
+        params,
+    )
+
+    # convert to pytorch
+    spectrogram = torch.from_numpy(spectrogram)
+
+    # add batch and channel dimensions
+    spectrogram = spectrogram.unsqueeze(0).unsqueeze(0)
+
+    # resize the spec
+    resize_factor = params["resize_factor"]
+    spec_op_shape = (
+        int(params["spec_height"] * resize_factor),
+        int(spectrogram.shape[-1] * resize_factor),
+    )
+    spectrogram = F.interpolate(
+        spectrogram,
+        size=spec_op_shape,
+        mode="bilinear",
+        align_corners=False,
+    )
+
+    expected_width = audio_utils.compute_spectrogram_width(
+        length,
+        samplerate=parameters.TARGET_SAMPLERATE_HZ,
+        window_duration=params["fft_win_length"],
+        window_overlap=params["fft_overlap"],
+        resize_factor=params["resize_factor"],
+    )
+
+    assert spectrogram.shape[-1] == expected_width
+
+
+@given(duration=st.floats(min_value=0.1, max_value=2))
+def test_pad_audio_without_fixed_size(duration: float):
+    # Test the pad_audio function
+    # This function is used to pad audio with zeros to a specific length
+    # It is used in the generate_spectrogram function
+    samplerate = parameters.TARGET_SAMPLERATE_HZ
+    params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
+
+    length = int(duration * samplerate)
+    audio = np.random.rand(length)
+
+    # pad the audio to be divisible by the divide factor
+    padded_audio = audio_utils.pad_audio(
+        audio,
+        samplerate=samplerate,
+        window_duration=params["fft_win_length"],
+        window_overlap=params["fft_overlap"],
+        resize_factor=params["resize_factor"],
+        divide_factor=params["spec_divide_factor"],
+    )
+
+    # check that the padded audio is divisible by the divide factor
+    expected_width = audio_utils.compute_spectrogram_width(
+        len(padded_audio),
+        samplerate=parameters.TARGET_SAMPLERATE_HZ,
+        window_duration=params["fft_win_length"],
+        window_overlap=params["fft_overlap"],
+        resize_factor=params["resize_factor"],
+    )
+
+    assert expected_width % params["spec_divide_factor"] == 0
+
+
+@given(duration=st.floats(min_value=0.1, max_value=2))
+def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
+    duration: float,
+):
+    samplerate = parameters.TARGET_SAMPLERATE_HZ
+    params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
+    length = int(duration * samplerate)
+    audio = np.random.rand(length)
+    _, spectrogram, _ = detector_utils.compute_spectrogram(
+        audio,
+        samplerate,
+        params,
+        torch.device("cpu"),
+    )
+    assert spectrogram.shape[-1] % params["spec_divide_factor"] == 0
+
+
+@given(
+    duration=st.floats(min_value=0.1, max_value=2),
+    width=st.integers(min_value=128, max_value=1024),
+)
+def test_pad_audio_with_fixed_width(duration: float, width: int):
+    samplerate = parameters.TARGET_SAMPLERATE_HZ
+    params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
+
+    length = int(duration * samplerate)
+    audio = np.random.rand(length)
+
+    # pad (or cut) the audio to match the fixed width
+    padded_audio = audio_utils.pad_audio(
+        audio,
+        samplerate=samplerate,
+        window_duration=params["fft_win_length"],
+        window_overlap=params["fft_overlap"],
+        resize_factor=params["resize_factor"],
+        divide_factor=params["spec_divide_factor"],
+        fixed_width=width,
+    )
+
+    # check that the resulting spectrogram width matches the fixed width
+    expected_width = audio_utils.compute_spectrogram_width(
+        len(padded_audio),
+        samplerate=parameters.TARGET_SAMPLERATE_HZ,
+        window_duration=params["fft_win_length"],
+        window_overlap=params["fft_overlap"],
+        resize_factor=params["resize_factor"],
+    )
+    assert expected_width == width
+
+
+def test_load_audio_using_bytesio():
+    basename = "20230322_172000_selec2.wav"
+    path = os.path.join(DATA_DIR, basename)
+
+    with open(path, "rb") as f:
+        data = io.BytesIO(f.read())
+
+    sample_rate, audio_data, file_sample_rate = (
+        audio_utils.load_audio_and_samplerate(
+            data,
+            time_exp_fact=1,
+            target_samp_rate=parameters.TARGET_SAMPLERATE_HZ,
+        )
+    )
+
+    expected_sample_rate, expected_audio_data, exp_file_sample_rate = (
+        audio_utils.load_audio_and_samplerate(
+            path,
+            time_exp_fact=1,
+            target_samp_rate=parameters.TARGET_SAMPLERATE_HZ,
+        )
+    )
+
+    assert expected_sample_rate == sample_rate
+    assert exp_file_sample_rate == file_sample_rate
+    assert np.array_equal(audio_data, expected_audio_data)
tests/test_cli.py
@@ -1,22 +1,26 @@
 """Test the command line interface."""

 from pathlib import Path

+import pandas as pd
 from click.testing import CliRunner

 from batdetect2.cli import cli

+runner = CliRunner()
+

 def test_cli_base_command():
     """Test the base command."""
-    runner = CliRunner()
     result = runner.invoke(cli, ["--help"])
     assert result.exit_code == 0
-    assert "BatDetect2 - Bat Call Detection and Classification" in result.output
+    assert (
+        "BatDetect2 - Bat Call Detection and Classification" in result.output
+    )


 def test_cli_detect_command_help():
     """Test the detect command help."""
-    runner = CliRunner()
     result = runner.invoke(cli, ["detect", "--help"])
     assert result.exit_code == 0
     assert "Detect bat calls in files in AUDIO_DIR" in result.output
@@ -30,7 +34,6 @@ def test_cli_detect_command_on_test_audio(tmp_path):
     if results_dir.exists():
         results_dir.rmdir()

-    runner = CliRunner()
     result = runner.invoke(
         cli,
         [
@@ -54,7 +57,6 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
     if results_dir.exists():
         results_dir.rmdir()

-    runner = CliRunner()
     result = runner.invoke(
         cli,
         [
@@ -68,8 +70,7 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
     )

     assert result.exit_code == 0
-    assert 'Time Expansion Factor: 10' in result.stdout
-
+    assert "Time Expansion Factor: 10" in result.stdout


 def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
@@ -80,7 +81,6 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
     if results_dir.exists():
         results_dir.rmdir()

-    runner = CliRunner()
     result = runner.invoke(
         cli,
         [
@@ -94,13 +94,12 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
     assert result.exit_code == 0
     assert results_dir.exists()

     csv_files = [path.name for path in results_dir.glob("*.csv")]

     expected_files = [
         "20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
         "20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
-        "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv"
+        "20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
     ]

     for expected_file in expected_files:
@@ -108,3 +107,52 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):

         df = pd.read_csv(results_dir / expected_file)
         assert not (df.duration == -1).any()
+
+
+def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
+    results_dir = tmp_path / "results"
+    target = tmp_path / "audio"
+    target.mkdir()
+
+    # Create an empty file with the .wav extension
+    empty_file = target / "empty.wav"
+    empty_file.touch()
+
+    result = runner.invoke(
+        cli,
+        args=[
+            "detect",
+            str(target),
+            str(results_dir),
+            "0.3",
+            "--spec_features",
+        ],
+    )
+    assert result.exit_code == 0
+    assert f"Error processing file {empty_file}" in result.output
+
+
+def test_can_set_chunk_size(tmp_path: Path):
+    results_dir = tmp_path / "results"
+
+    # Remove results dir if it exists
+    if results_dir.exists():
+        results_dir.rmdir()
+
+    result = runner.invoke(
+        cli,
+        [
+            "detect",
+            "example_data/audio",
+            str(results_dir),
+            "0.3",
+            "--chunk_size",
+            "1",
+        ],
+    )
+
+    assert "Chunk Size: 1.0s" in result.output
+    assert result.exit_code == 0
+    assert results_dir.exists()
+    assert len(list(results_dir.glob("*.csv"))) == 3
+    assert len(list(results_dir.glob("*.json"))) == 3
tests/test_contrib.py (new file, 73 lines)
@@ -0,0 +1,73 @@
+"""Test suite to ensure user-provided files are correctly processed."""
+
+from pathlib import Path
+
+from click.testing import CliRunner
+
+from batdetect2.cli import cli
+
+runner = CliRunner()
+
+
+def test_can_process_jeff37_files(
+    contrib_dir: Path,
+    tmp_path: Path,
+):
+    """This test stems from issue #31.
+
+    A user provided a set of files on which the batdetect2 cli failed and
+    generated the following error message:
+
+        "Error processing file!: negative dimensions are not allowed"
+
+    This test ensures that the error message is not generated when running
+    the batdetect2 cli with the same set of files.
+    """
+    path = contrib_dir / "jeff37"
+    assert path.exists()
+
+    results_dir = tmp_path / "results"
+    result = runner.invoke(
+        cli,
+        [
+            "detect",
+            str(path),
+            str(results_dir),
+            "0.3",
+        ],
+    )
+    assert result.exit_code == 0
+    assert results_dir.exists()
+    assert len(list(results_dir.glob("*.csv"))) == 5
+    assert len(list(results_dir.glob("*.json"))) == 5
+
+
+def test_can_process_padpadpadpad_files(
+    contrib_dir: Path,
+    tmp_path: Path,
+):
+    """This test stems from issue #29.
+
+    The batdetect2 cli failed on the files provided by the user @padpadpadpad
+    with the following error message:
+
+        AttributeError: module 'numpy' has no attribute 'AxisError'
+
+    This test ensures that the files are processed without any error.
+    """
+    path = contrib_dir / "padpadpadpad"
+    assert path.exists()
+    results_dir = tmp_path / "results"
+    result = runner.invoke(
+        cli,
+        [
+            "detect",
+            str(path),
+            str(results_dir),
+            "0.3",
+        ],
+    )
+    assert result.exit_code == 0
+    assert results_dir.exists()
+    assert len(list(results_dir.glob("*.csv"))) == 2
+    assert len(list(results_dir.glob("*.json"))) == 2
tests/test_model.py (new file, 78 lines)
@@ -0,0 +1,78 @@
+"""Test suite for model functions."""
+
+import warnings
+from pathlib import Path
+from typing import List
+
+import numpy as np
+from hypothesis import given, settings
+from hypothesis import strategies as st
+
+from batdetect2 import api
+from batdetect2.detector import parameters
+
+
+def test_can_import_model_without_warnings():
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        api.load_model()
+
+
+@settings(deadline=None, max_examples=5)
+@given(duration=st.floats(min_value=0.1, max_value=2))
+def test_can_import_model_without_pickle(duration: float):
+    # NOTE: remove this test once no other issues are found. This is a
+    # temporary test to check that the change in model loading did not impact
+    # model behaviour in any way.
+    samplerate = parameters.TARGET_SAMPLERATE_HZ
+    audio = np.random.rand(int(duration * samplerate))
+
+    model_without_pickle, model_params_without_pickle = api.load_model(
+        weights_only=True
+    )
+    model_with_pickle, model_params_with_pickle = api.load_model(
+        weights_only=False
+    )
+
+    assert model_params_without_pickle == model_params_with_pickle
+
+    predictions_without_pickle, _, _ = api.process_audio(
+        audio,
+        model=model_without_pickle,
+    )
+    predictions_with_pickle, _, _ = api.process_audio(
+        audio,
+        model=model_with_pickle,
+    )
+
+    assert predictions_without_pickle == predictions_with_pickle
+
+
+def test_can_import_model_without_pickle_on_test_data(
+    example_audio_files: List[Path],
+):
+    # NOTE: remove this test once no other issues are found. This is a
+    # temporary test to check that the change in model loading did not impact
+    # model behaviour in any way.
+    model_without_pickle, model_params_without_pickle = api.load_model(
+        weights_only=True
+    )
+    model_with_pickle, model_params_with_pickle = api.load_model(
+        weights_only=False
+    )
+
+    assert model_params_without_pickle == model_params_with_pickle
+
+    for audio_file in example_audio_files:
+        audio = api.load_audio(str(audio_file))
+        predictions_without_pickle, _, _ = api.process_audio(
+            audio,
+            model=model_without_pickle,
+        )
+        predictions_with_pickle, _, _ = api.process_audio(
+            audio,
+            model=model_with_pickle,
+        )
+        assert predictions_without_pickle == predictions_with_pickle