mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Merge branch 'main' into train
This commit is contained in:
commit
6a9e33c729
8
.bumpversion.cfg
Normal file
8
.bumpversion.cfg
Normal file
@ -0,0 +1,8 @@
|
||||
[bumpversion]
|
||||
current_version = 1.1.1
|
||||
commit = True
|
||||
tag = True
|
||||
|
||||
[bumpversion:file:batdetect2/__init__.py]
|
||||
|
||||
[bumpversion:file:pyproject.toml]
|
27
.github/workflows/python-package.yml
vendored
27
.github/workflows/python-package.yml
vendored
@ -1,6 +1,3 @@
|
||||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
||||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
|
||||
|
||||
name: Python package
|
||||
|
||||
on:
|
||||
@ -11,24 +8,22 @@ on:
|
||||
|
||||
jobs:
|
||||
build:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v3
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install pytest
|
||||
pip install .
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "uv.lock"
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
run: uv python install ${{ matrix.python-version }}
|
||||
- name: Install the project
|
||||
run: uv sync --all-extras --dev
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
pytest
|
||||
run: uv run pytest
|
||||
|
13
.github/workflows/python-publish.yml
vendored
13
.github/workflows/python-publish.yml
vendored
@ -1,11 +1,3 @@
|
||||
# This workflow will upload a Python Package using Twine when a release is created
|
||||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
|
||||
|
||||
# This workflow uses actions that are not certified by GitHub.
|
||||
# They are provided by a third-party and are governed by
|
||||
# separate terms of service, privacy policy, and support
|
||||
# documentation.
|
||||
|
||||
name: Upload Python Package
|
||||
|
||||
on:
|
||||
@ -17,15 +9,14 @@ permissions:
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: '3.x'
|
||||
python-version: "3.x"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -103,13 +103,11 @@ experiments/*
|
||||
.ipynb_checkpoints
|
||||
*.ipynb
|
||||
|
||||
# Bump2version
|
||||
.bumpversion.cfg
|
||||
|
||||
# DO Include
|
||||
!batdetect2_notebook.ipynb
|
||||
!batdetect2/models/checkpoints/*.pth.tar
|
||||
!tests/data/*.wav
|
||||
!notebooks/*.ipynb
|
||||
!tests/data/**/*.wav
|
||||
notebooks/lightning_logs
|
||||
example_data/preprocessed
|
||||
|
@ -1 +1,6 @@
|
||||
__version__ = '1.0.8'
|
||||
import logging
|
||||
|
||||
numba_logger = logging.getLogger("numba")
|
||||
numba_logger.setLevel(logging.WARNING)
|
||||
|
||||
__version__ = "1.1.1"
|
||||
|
@ -1,5 +1,3 @@
|
||||
"""BatDetect2 command line interface."""
|
||||
|
||||
import click
|
||||
|
||||
from batdetect2 import api
|
||||
@ -114,10 +112,9 @@ def detect(
|
||||
):
|
||||
results_path = audio_file.replace(audio_dir, ann_dir)
|
||||
save_results_to_file(results, results_path)
|
||||
except (RuntimeError, ValueError, LookupError) as err:
|
||||
except (RuntimeError, ValueError, LookupError, EOFError) as err:
|
||||
error_files.append(audio_file)
|
||||
click.secho(f"Error processing file!: {err}", fg="red")
|
||||
raise err
|
||||
click.secho(f"Error processing file {audio_file}: {err}", fg="red")
|
||||
|
||||
click.echo(f"\nResults saved to: {ann_dir}")
|
||||
|
||||
|
@ -1,13 +1,11 @@
|
||||
"""Types used in the code base."""
|
||||
from typing import Any, List, NamedTuple, Optional
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
from typing import TypedDict
|
||||
except ImportError:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
try:
|
||||
@ -15,9 +13,8 @@ try:
|
||||
except ImportError:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
|
||||
try:
|
||||
from typing import NotRequired
|
||||
from typing import NotRequired # type: ignore
|
||||
except ImportError:
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
|
@ -6,6 +6,8 @@ import librosa.core.spectrum
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from batdetect2.detector import parameters
|
||||
|
||||
from . import wavfile
|
||||
|
||||
__all__ = [
|
||||
@ -15,20 +17,44 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
|
||||
nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor
|
||||
noverlap = np.floor(fft_overlap * nfft)
|
||||
return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
|
||||
def time_to_x_coords(
|
||||
time_in_file: float,
|
||||
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
|
||||
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||
window_overlap: float = parameters.FFT_OVERLAP,
|
||||
) -> float:
|
||||
nfft = np.floor(window_duration * samplerate) # int() uses floor
|
||||
noverlap = np.floor(window_overlap * nfft)
|
||||
return (time_in_file * samplerate - noverlap) / (nfft - noverlap)
|
||||
|
||||
|
||||
# NOTE this is also defined in post_process
|
||||
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
|
||||
nfft = np.floor(fft_win_length * sampling_rate)
|
||||
noverlap = np.floor(fft_overlap * nfft)
|
||||
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
|
||||
def x_coords_to_time(
|
||||
x_pos: int,
|
||||
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
|
||||
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||
window_overlap: float = parameters.FFT_OVERLAP,
|
||||
) -> float:
|
||||
n_fft = np.floor(window_duration * samplerate)
|
||||
n_overlap = np.floor(window_overlap * n_fft)
|
||||
n_step = n_fft - n_overlap
|
||||
return ((x_pos * n_step) + n_overlap) / samplerate
|
||||
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
|
||||
|
||||
|
||||
def x_coord_to_sample(
|
||||
x_pos: int,
|
||||
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
|
||||
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||
window_overlap: float = parameters.FFT_OVERLAP,
|
||||
resize_factor: float = parameters.RESIZE_FACTOR,
|
||||
) -> int:
|
||||
n_fft = np.floor(window_duration * samplerate)
|
||||
n_overlap = np.floor(window_overlap * n_fft)
|
||||
n_step = n_fft - n_overlap
|
||||
x_pos = int(x_pos / resize_factor)
|
||||
return int((x_pos * n_step) + n_overlap)
|
||||
|
||||
|
||||
def generate_spectrogram(
|
||||
audio,
|
||||
sampling_rate,
|
||||
@ -194,55 +220,118 @@ def load_audio(
|
||||
return sampling_rate, audio_raw
|
||||
|
||||
|
||||
def compute_spectrogram_width(
|
||||
length: int,
|
||||
samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
|
||||
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||
window_overlap: float = parameters.FFT_OVERLAP,
|
||||
resize_factor: float = parameters.RESIZE_FACTOR,
|
||||
) -> int:
|
||||
n_fft = int(window_duration * samplerate)
|
||||
n_overlap = int(window_overlap * n_fft)
|
||||
n_step = n_fft - n_overlap
|
||||
width = (length - n_overlap) // n_step
|
||||
return int(width * resize_factor)
|
||||
|
||||
|
||||
def pad_audio(
|
||||
audio_raw,
|
||||
fs,
|
||||
ms,
|
||||
overlap_perc,
|
||||
resize_factor,
|
||||
divide_factor,
|
||||
fixed_width=None,
|
||||
audio: np.ndarray,
|
||||
samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
|
||||
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||
window_overlap: float = parameters.FFT_OVERLAP,
|
||||
resize_factor: float = parameters.RESIZE_FACTOR,
|
||||
divide_factor: int = parameters.SPEC_DIVIDE_FACTOR,
|
||||
fixed_width: Optional[int] = None,
|
||||
):
|
||||
# Adds zeros to the end of the raw data so that the generated sepctrogram
|
||||
# will be evenly divisible by `divide_factor`
|
||||
# Also deals with very short audio clips and fixed_width during training
|
||||
"""Pad audio to be evenly divisible by `divide_factor`.
|
||||
|
||||
# This code could be clearer, clean up
|
||||
nfft = int(ms * fs)
|
||||
noverlap = int(overlap_perc * nfft)
|
||||
step = nfft - noverlap
|
||||
min_size = int(divide_factor * (1.0 / resize_factor))
|
||||
spec_width = (audio_raw.shape[0] - noverlap) // step
|
||||
spec_width_rs = spec_width * resize_factor
|
||||
This function pads the audio signal with zeros to ensure that the
|
||||
generated spectrogram length will be evenly divisible by `divide_factor`.
|
||||
This is important for the model to work correctly.
|
||||
|
||||
if fixed_width is not None and spec_width < fixed_width:
|
||||
# too small
|
||||
# used during training to ensure all the batches are the same size
|
||||
diff = fixed_width * step + noverlap - audio_raw.shape[0]
|
||||
audio_raw = np.hstack(
|
||||
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
|
||||
This `divide_factor` comes from the model architecture as it downscales
|
||||
the spectrogram by this factor, so the input must be divisible by this
|
||||
integer number.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio : np.ndarray
|
||||
The audio signal.
|
||||
samplerate : int
|
||||
The sampling rate of the audio signal.
|
||||
window_size : float
|
||||
The window size in seconds used for the spectrogram computation.
|
||||
window_overlap : float
|
||||
The overlap between windows in the spectrogram computation.
|
||||
resize_factor : float
|
||||
This factor is used to resize the spectrogram after the STFT
|
||||
computation. Default is 0.5 which means that the spectrogram will be
|
||||
reduced by half. Important to take into account for the final size of
|
||||
the spectrogram.
|
||||
divide_factor : int
|
||||
The factor by which the spectrogram will be divided.
|
||||
fixed_width : int, optional
|
||||
If provided, the audio will be padded or cut so that the resulting
|
||||
spectrogram width will be equal to this value.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
The padded audio signal.
|
||||
"""
|
||||
spec_width = compute_spectrogram_width(
|
||||
audio.shape[0],
|
||||
samplerate=samplerate,
|
||||
window_duration=window_duration,
|
||||
window_overlap=window_overlap,
|
||||
resize_factor=resize_factor,
|
||||
)
|
||||
|
||||
elif fixed_width is not None and spec_width > fixed_width:
|
||||
# too big
|
||||
# used during training to ensure all the batches are the same size
|
||||
diff = fixed_width * step + noverlap - audio_raw.shape[0]
|
||||
audio_raw = audio_raw[:diff]
|
||||
if fixed_width:
|
||||
target_samples = x_coord_to_sample(
|
||||
fixed_width,
|
||||
samplerate=samplerate,
|
||||
window_duration=window_duration,
|
||||
window_overlap=window_overlap,
|
||||
resize_factor=resize_factor,
|
||||
)
|
||||
|
||||
elif (
|
||||
spec_width_rs < min_size
|
||||
or (np.floor(spec_width_rs) % divide_factor) != 0
|
||||
):
|
||||
if spec_width < fixed_width:
|
||||
# need to be at least min_size
|
||||
div_amt = np.ceil(spec_width_rs / float(divide_factor))
|
||||
div_amt = np.maximum(1, div_amt)
|
||||
target_size = int(div_amt * divide_factor * (1.0 / resize_factor))
|
||||
diff = target_size * step + noverlap - audio_raw.shape[0]
|
||||
audio_raw = np.hstack(
|
||||
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
|
||||
)
|
||||
diff = target_samples - audio.shape[0]
|
||||
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
|
||||
|
||||
return audio_raw
|
||||
if spec_width > fixed_width:
|
||||
return audio[:target_samples]
|
||||
|
||||
return audio
|
||||
|
||||
min_width = int(divide_factor / resize_factor)
|
||||
|
||||
if spec_width < min_width:
|
||||
target_samples = x_coord_to_sample(
|
||||
min_width,
|
||||
samplerate=samplerate,
|
||||
window_duration=window_duration,
|
||||
window_overlap=window_overlap,
|
||||
resize_factor=resize_factor,
|
||||
)
|
||||
diff = target_samples - audio.shape[0]
|
||||
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
|
||||
|
||||
if (spec_width % divide_factor) == 0:
|
||||
return audio
|
||||
|
||||
target_width = int(np.ceil(spec_width / divide_factor)) * divide_factor
|
||||
target_samples = x_coord_to_sample(
|
||||
target_width,
|
||||
samplerate=samplerate,
|
||||
window_duration=window_duration,
|
||||
window_overlap=window_overlap,
|
||||
resize_factor=resize_factor,
|
||||
)
|
||||
diff = target_samples - audio.shape[0]
|
||||
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
|
||||
|
||||
|
||||
def gen_mag_spectrogram(x, fs, ms, overlap_perc):
|
||||
|
@ -8,6 +8,11 @@ import pandas as pd
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
try:
|
||||
from numpy.exceptions import AxisError
|
||||
except ImportError:
|
||||
from numpy import AxisError # type: ignore
|
||||
|
||||
import batdetect2.detector.compute_features as feats
|
||||
import batdetect2.detector.post_process as pp
|
||||
import batdetect2.utils.audio_utils as au
|
||||
@ -80,6 +85,7 @@ def load_model(
|
||||
model_path: str = DEFAULT_MODEL_PATH,
|
||||
load_weights: bool = True,
|
||||
device: Union[torch.device, str, None] = None,
|
||||
weights_only: bool = True,
|
||||
) -> Tuple[DetectionModel, ModelParameters]:
|
||||
"""Load model from file.
|
||||
|
||||
@ -100,7 +106,11 @@ def load_model(
|
||||
if not os.path.isfile(model_path):
|
||||
raise FileNotFoundError("Model file not found.")
|
||||
|
||||
net_params = torch.load(model_path, map_location=device)
|
||||
net_params = torch.load(
|
||||
model_path,
|
||||
map_location=device,
|
||||
weights_only=weights_only,
|
||||
)
|
||||
|
||||
params = net_params["params"]
|
||||
|
||||
@ -242,7 +252,7 @@ def format_single_result(
|
||||
)
|
||||
class_name = class_names[np.argmax(class_overall)]
|
||||
annotations = get_annotations_from_preds(predictions, class_names)
|
||||
except (np.AxisError, ValueError):
|
||||
except (AxisError, ValueError):
|
||||
# No detections
|
||||
class_overall = np.zeros(len(class_names))
|
||||
class_name = "None"
|
||||
@ -738,9 +748,7 @@ def process_file(
|
||||
|
||||
# Get original sampling rate
|
||||
file_samp_rate = librosa.get_samplerate(audio_file)
|
||||
orig_samp_rate = file_samp_rate * float(
|
||||
config.get("time_expansion", 1.0) or 1.0
|
||||
)
|
||||
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
|
||||
|
||||
# load audio file
|
||||
sampling_rate, audio_full = au.load_audio(
|
||||
|
@ -1,33 +1,22 @@
|
||||
[tool]
|
||||
uv = { dev-dependencies = [
|
||||
"ipykernel>=6.29.4",
|
||||
"setuptools>=69.5.1",
|
||||
"pytest>=8.1.1",
|
||||
] }
|
||||
[tool.pdm]
|
||||
[tool.pdm.dev-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.2.2",
|
||||
]
|
||||
|
||||
[project]
|
||||
name = "batdetect2"
|
||||
version = "1.0.8"
|
||||
version = "1.1.1"
|
||||
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
|
||||
authors = [
|
||||
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
|
||||
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" }
|
||||
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
|
||||
]
|
||||
dependencies = [
|
||||
"click>=8.1.7",
|
||||
"librosa>=0.10.1",
|
||||
"matplotlib>=3.7.1",
|
||||
"numpy>=1.23.5",
|
||||
"pandas>=1.5.3",
|
||||
"scikit-learn>=1.2.2",
|
||||
"scipy>=1.10.1",
|
||||
"torch>=1.13.1",
|
||||
"torchaudio",
|
||||
"torchvision",
|
||||
"torch>=1.13.1,<2.5.0",
|
||||
"torchaudio>=1.13.1,<2.5.0",
|
||||
"torchvision>=0.14.0",
|
||||
"soundevent[audio,geometry,plot]>=2.0.1",
|
||||
"click>=8.1.7",
|
||||
"netcdf4>=1.6.5",
|
||||
@ -38,7 +27,7 @@ dependencies = [
|
||||
"lightning[extra]>=2.2.2",
|
||||
"tensorboard>=2.16.2",
|
||||
]
|
||||
requires-python = ">=3.9"
|
||||
requires-python = ">=3.9,<3.13"
|
||||
readme = "README.md"
|
||||
license = { text = "CC-by-nc-4" }
|
||||
classifiers = [
|
||||
@ -46,8 +35,10 @@ classifiers = [
|
||||
"Intended Audience :: Science/Research",
|
||||
"Natural Language :: English",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Multimedia :: Sound/Audio :: Analysis",
|
||||
@ -63,41 +54,40 @@ keywords = [
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["pdm-pep517>=1.0.0"]
|
||||
build-backend = "pdm.pep517.api"
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project.scripts]
|
||||
batdetect2 = "batdetect2.cli:cli"
|
||||
|
||||
[tool.black]
|
||||
line-length = 79
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
line_length = 79
|
||||
[tool.uv]
|
||||
dev-dependencies = [
|
||||
"debugpy>=1.8.8",
|
||||
"hypothesis>=6.118.7",
|
||||
"pyright>=1.1.388",
|
||||
"pytest>=7.2.2",
|
||||
"ruff>=0.7.3",
|
||||
"ipykernel>=6.29.4",
|
||||
"setuptools>=69.5.1",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 79
|
||||
target-version = "py39"
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"librosa",
|
||||
"pandas",
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
docstring-code-line-length = 79
|
||||
|
||||
[tool.pylsp-mypy]
|
||||
enabled = false
|
||||
live_mode = true
|
||||
strict = true
|
||||
[tool.ruff.lint]
|
||||
select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
|
||||
|
||||
[tool.pydocstyle]
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "numpy"
|
||||
|
||||
[tool.pyright]
|
||||
include = [
|
||||
"bat_detect",
|
||||
"tests",
|
||||
]
|
||||
include = ["batdetect2", "tests"]
|
||||
venvPath = "."
|
||||
venv = ".venv"
|
||||
pythonVersion = "3.9"
|
||||
pythonPlatform = "All"
|
||||
|
40
tests/conftest.py
Normal file
40
tests/conftest.py
Normal file
@ -0,0 +1,40 @@
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_data_dir() -> Path:
|
||||
pkg_dir = Path(__file__).parent.parent
|
||||
example_data_dir = pkg_dir / "example_data"
|
||||
assert example_data_dir.exists()
|
||||
return example_data_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_audio_dir(example_data_dir: Path) -> Path:
|
||||
example_audio_dir = example_data_dir / "audio"
|
||||
assert example_audio_dir.exists()
|
||||
return example_audio_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_audio_files(example_audio_dir: Path) -> List[Path]:
|
||||
audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
|
||||
assert len(audio_files) == 3
|
||||
return audio_files
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data_dir() -> Path:
|
||||
dir = Path(__file__).parent / "data"
|
||||
assert dir.exists()
|
||||
return dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def contrib_dir(data_dir) -> Path:
|
||||
dir = data_dir / "contrib"
|
||||
assert dir.exists()
|
||||
return dir
|
BIN
tests/data/contrib/jeff37/0166_20240531_223911.wav
Executable file
BIN
tests/data/contrib/jeff37/0166_20240531_223911.wav
Executable file
Binary file not shown.
BIN
tests/data/contrib/jeff37/0166_20240602_225340.wav
Executable file
BIN
tests/data/contrib/jeff37/0166_20240602_225340.wav
Executable file
Binary file not shown.
BIN
tests/data/contrib/jeff37/0166_20240603_033731.wav
Executable file
BIN
tests/data/contrib/jeff37/0166_20240603_033731.wav
Executable file
Binary file not shown.
BIN
tests/data/contrib/jeff37/0166_20240603_033937.wav
Executable file
BIN
tests/data/contrib/jeff37/0166_20240603_033937.wav
Executable file
Binary file not shown.
BIN
tests/data/contrib/jeff37/0166_20240604_233500.wav
Executable file
BIN
tests/data/contrib/jeff37/0166_20240604_233500.wav
Executable file
Binary file not shown.
BIN
tests/data/contrib/padpadpadpad/Audiomoth.WAV
Normal file
BIN
tests/data/contrib/padpadpadpad/Audiomoth.WAV
Normal file
Binary file not shown.
BIN
tests/data/contrib/padpadpadpad/AudiomothNoBatCalls.WAV
Normal file
BIN
tests/data/contrib/padpadpadpad/AudiomothNoBatCalls.WAV
Normal file
Binary file not shown.
BIN
tests/data/contrib/padpadpadpad/Echometer.wav
Normal file
BIN
tests/data/contrib/padpadpadpad/Echometer.wav
Normal file
Binary file not shown.
@ -1,14 +1,13 @@
|
||||
"""Test bat detect module API."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import os
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from torch import nn
|
||||
import soundfile as sf
|
||||
|
||||
from batdetect2 import api
|
||||
|
||||
@ -267,7 +266,6 @@ def test_process_file_with_spec_slices():
|
||||
assert len(results["spec_slices"]) == len(detections)
|
||||
|
||||
|
||||
|
||||
def test_process_file_with_empty_predictions_does_not_fail(
|
||||
tmp_path: Path,
|
||||
):
|
||||
|
136
tests/test_audio_utils.py
Normal file
136
tests/test_audio_utils.py
Normal file
@ -0,0 +1,136 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from hypothesis import given
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from batdetect2.detector import parameters
|
||||
from batdetect2.utils import audio_utils, detector_utils
|
||||
|
||||
|
||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||
def test_can_compute_correct_spectrogram_width(duration: float):
|
||||
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||
|
||||
length = int(duration * samplerate)
|
||||
audio = np.random.rand(length)
|
||||
|
||||
spectrogram, _ = audio_utils.generate_spectrogram(
|
||||
audio,
|
||||
samplerate,
|
||||
params,
|
||||
)
|
||||
|
||||
# convert to pytorch
|
||||
spectrogram = torch.from_numpy(spectrogram)
|
||||
|
||||
# add batch and channel dimensions
|
||||
spectrogram = spectrogram.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# resize the spec
|
||||
resize_factor = params["resize_factor"]
|
||||
spec_op_shape = (
|
||||
int(params["spec_height"] * resize_factor),
|
||||
int(spectrogram.shape[-1] * resize_factor),
|
||||
)
|
||||
spectrogram = F.interpolate(
|
||||
spectrogram,
|
||||
size=spec_op_shape,
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
expected_width = audio_utils.compute_spectrogram_width(
|
||||
length,
|
||||
samplerate=parameters.TARGET_SAMPLERATE_HZ,
|
||||
window_duration=params["fft_win_length"],
|
||||
window_overlap=params["fft_overlap"],
|
||||
resize_factor=params["resize_factor"],
|
||||
)
|
||||
|
||||
assert spectrogram.shape[-1] == expected_width
|
||||
|
||||
|
||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||
def test_pad_audio_without_fixed_size(duration: float):
|
||||
# Test the pad_audio function
|
||||
# This function is used to pad audio with zeros to a specific length
|
||||
# It is used in the generate_spectrogram function
|
||||
# The function is tested with a simplepas
|
||||
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||
|
||||
length = int(duration * samplerate)
|
||||
audio = np.random.rand(length)
|
||||
|
||||
# pad the audio to be divisible by divide factor
|
||||
padded_audio = audio_utils.pad_audio(
|
||||
audio,
|
||||
samplerate=samplerate,
|
||||
window_duration=params["fft_win_length"],
|
||||
window_overlap=params["fft_overlap"],
|
||||
resize_factor=params["resize_factor"],
|
||||
divide_factor=params["spec_divide_factor"],
|
||||
)
|
||||
|
||||
# check that the padded audio is divisible by the divide factor
|
||||
expected_width = audio_utils.compute_spectrogram_width(
|
||||
len(padded_audio),
|
||||
samplerate=parameters.TARGET_SAMPLERATE_HZ,
|
||||
window_duration=params["fft_win_length"],
|
||||
window_overlap=params["fft_overlap"],
|
||||
resize_factor=params["resize_factor"],
|
||||
)
|
||||
|
||||
assert expected_width % params["spec_divide_factor"] == 0
|
||||
|
||||
|
||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
|
||||
duration: float,
|
||||
):
|
||||
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||
length = int(duration * samplerate)
|
||||
audio = np.random.rand(length)
|
||||
_, spectrogram, _ = detector_utils.compute_spectrogram(
|
||||
audio,
|
||||
samplerate,
|
||||
params,
|
||||
torch.device("cpu"),
|
||||
)
|
||||
assert spectrogram.shape[-1] % params["spec_divide_factor"] == 0
|
||||
|
||||
|
||||
@given(
|
||||
duration=st.floats(min_value=0.1, max_value=2),
|
||||
width=st.integers(min_value=128, max_value=1024),
|
||||
)
|
||||
def test_pad_audio_with_fixed_width(duration: float, width: int):
|
||||
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||
|
||||
length = int(duration * samplerate)
|
||||
audio = np.random.rand(length)
|
||||
|
||||
# pad the audio to be divisible by divide factor
|
||||
padded_audio = audio_utils.pad_audio(
|
||||
audio,
|
||||
samplerate=samplerate,
|
||||
window_duration=params["fft_win_length"],
|
||||
window_overlap=params["fft_overlap"],
|
||||
resize_factor=params["resize_factor"],
|
||||
divide_factor=params["spec_divide_factor"],
|
||||
fixed_width=width,
|
||||
)
|
||||
|
||||
# check that the padded audio is divisible by the divide factor
|
||||
expected_width = audio_utils.compute_spectrogram_width(
|
||||
len(padded_audio),
|
||||
samplerate=parameters.TARGET_SAMPLERATE_HZ,
|
||||
window_duration=params["fft_win_length"],
|
||||
window_overlap=params["fft_overlap"],
|
||||
resize_factor=params["resize_factor"],
|
||||
)
|
||||
assert expected_width == width
|
@ -1,22 +1,26 @@
|
||||
"""Test the command line interface."""
|
||||
|
||||
from pathlib import Path
|
||||
from click.testing import CliRunner
|
||||
|
||||
import pandas as pd
|
||||
from click.testing import CliRunner
|
||||
|
||||
from batdetect2.cli import cli
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
def test_cli_base_command():
|
||||
"""Test the base command."""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "BatDetect2 - Bat Call Detection and Classification" in result.output
|
||||
assert (
|
||||
"BatDetect2 - Bat Call Detection and Classification" in result.output
|
||||
)
|
||||
|
||||
|
||||
def test_cli_detect_command_help():
|
||||
"""Test the detect command help."""
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(cli, ["detect", "--help"])
|
||||
assert result.exit_code == 0
|
||||
assert "Detect bat calls in files in AUDIO_DIR" in result.output
|
||||
@ -30,7 +34,6 @@ def test_cli_detect_command_on_test_audio(tmp_path):
|
||||
if results_dir.exists():
|
||||
results_dir.rmdir()
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
@ -54,7 +57,6 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
|
||||
if results_dir.exists():
|
||||
results_dir.rmdir()
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
@ -68,8 +70,7 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert 'Time Expansion Factor: 10' in result.stdout
|
||||
|
||||
assert "Time Expansion Factor: 10" in result.stdout
|
||||
|
||||
|
||||
def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
|
||||
@ -80,7 +81,6 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
|
||||
if results_dir.exists():
|
||||
results_dir.rmdir()
|
||||
|
||||
runner = CliRunner()
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
@ -94,13 +94,12 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
|
||||
assert result.exit_code == 0
|
||||
assert results_dir.exists()
|
||||
|
||||
|
||||
csv_files = [path.name for path in results_dir.glob("*.csv")]
|
||||
|
||||
expected_files = [
|
||||
"20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
|
||||
"20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
|
||||
"20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv"
|
||||
"20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
|
||||
]
|
||||
|
||||
for expected_file in expected_files:
|
||||
@ -108,3 +107,26 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
|
||||
|
||||
df = pd.read_csv(results_dir / expected_file)
|
||||
assert not (df.duration == -1).any()
|
||||
|
||||
|
||||
def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
|
||||
results_dir = tmp_path / "results"
|
||||
target = tmp_path / "audio"
|
||||
target.mkdir()
|
||||
|
||||
# Create an empty file with the .wav extension
|
||||
empty_file = target / "empty.wav"
|
||||
empty_file.touch()
|
||||
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
args=[
|
||||
"detect",
|
||||
str(target),
|
||||
str(results_dir),
|
||||
"0.3",
|
||||
"--spec_features",
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert f"Error processing file {empty_file}" in result.output
|
||||
|
73
tests/test_contrib.py
Normal file
73
tests/test_contrib.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""Test suite to ensure user provided files are correctly processed."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from click.testing import CliRunner
|
||||
|
||||
from batdetect2.cli import cli
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
def test_can_process_jeff37_files(
|
||||
contrib_dir: Path,
|
||||
tmp_path: Path,
|
||||
):
|
||||
"""This test stems from issue #31.
|
||||
|
||||
A user provided a set of files which which batdetect2 cli failed and
|
||||
generated the following error message:
|
||||
|
||||
[2272] "Error processing file!: negative dimensions are not allowed"
|
||||
|
||||
This test ensures that the error message is not generated when running
|
||||
batdetect2 cli with the same set of files.
|
||||
"""
|
||||
path = contrib_dir / "jeff37"
|
||||
assert path.exists()
|
||||
|
||||
results_dir = tmp_path / "results"
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"detect",
|
||||
str(path),
|
||||
str(results_dir),
|
||||
"0.3",
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert results_dir.exists()
|
||||
assert len(list(results_dir.glob("*.csv"))) == 5
|
||||
assert len(list(results_dir.glob("*.json"))) == 5
|
||||
|
||||
|
||||
def test_can_process_padpadpadpad_files(
|
||||
contrib_dir: Path,
|
||||
tmp_path: Path,
|
||||
):
|
||||
"""This test stems from issue #29.
|
||||
|
||||
Batdetect2 cli failed on the files provided by the user @padpadpadpad
|
||||
with the following error message:
|
||||
|
||||
AttributeError: module 'numpy' has no attribute 'AxisError'
|
||||
|
||||
This test ensures that the files are processed without any error.
|
||||
"""
|
||||
path = contrib_dir / "padpadpadpad"
|
||||
assert path.exists()
|
||||
results_dir = tmp_path / "results"
|
||||
result = runner.invoke(
|
||||
cli,
|
||||
[
|
||||
"detect",
|
||||
str(path),
|
||||
str(results_dir),
|
||||
"0.3",
|
||||
],
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert results_dir.exists()
|
||||
assert len(list(results_dir.glob("*.csv"))) == 2
|
||||
assert len(list(results_dir.glob("*.json"))) == 2
|
78
tests/test_model.py
Normal file
78
tests/test_model.py
Normal file
@ -0,0 +1,78 @@
|
||||
"""Test suite for model functions."""
|
||||
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from batdetect2 import api
|
||||
from batdetect2.detector import parameters
|
||||
|
||||
|
||||
def test_can_import_model_without_warnings():
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
api.load_model()
|
||||
|
||||
|
||||
@settings(deadline=None, max_examples=5)
|
||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||
def test_can_import_model_without_pickle(duration: float):
|
||||
# NOTE: remove this test once no other issues are found This is a temporary
|
||||
# test to check that change in model loading did not impact model behaviour
|
||||
# in any way.
|
||||
|
||||
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||
audio = np.random.rand(int(duration * samplerate))
|
||||
|
||||
model_without_pickle, model_params_without_pickle = api.load_model(
|
||||
weights_only=True
|
||||
)
|
||||
model_with_pickle, model_params_with_pickle = api.load_model(
|
||||
weights_only=False
|
||||
)
|
||||
|
||||
assert model_params_without_pickle == model_params_with_pickle
|
||||
|
||||
predictions_without_pickle, _, _ = api.process_audio(
|
||||
audio,
|
||||
model=model_without_pickle,
|
||||
)
|
||||
predictions_with_pickle, _, _ = api.process_audio(
|
||||
audio,
|
||||
model=model_with_pickle,
|
||||
)
|
||||
|
||||
assert predictions_without_pickle == predictions_with_pickle
|
||||
|
||||
|
||||
def test_can_import_model_without_pickle_on_test_data(
|
||||
example_audio_files: List[Path],
|
||||
):
|
||||
# NOTE: remove this test once no other issues are found This is a temporary
|
||||
# test to check that change in model loading did not impact model behaviour
|
||||
# in any way.
|
||||
|
||||
model_without_pickle, model_params_without_pickle = api.load_model(
|
||||
weights_only=True
|
||||
)
|
||||
model_with_pickle, model_params_with_pickle = api.load_model(
|
||||
weights_only=False
|
||||
)
|
||||
|
||||
assert model_params_without_pickle == model_params_with_pickle
|
||||
|
||||
for audio_file in example_audio_files:
|
||||
audio = api.load_audio(str(audio_file))
|
||||
predictions_without_pickle, _, _ = api.process_audio(
|
||||
audio,
|
||||
model=model_without_pickle,
|
||||
)
|
||||
predictions_with_pickle, _, _ = api.process_audio(
|
||||
audio,
|
||||
model=model_with_pickle,
|
||||
)
|
||||
assert predictions_without_pickle == predictions_with_pickle
|
Loading…
Reference in New Issue
Block a user