Merge branch 'main' into train

This commit is contained in:
mbsantiago 2024-11-15 20:10:36 +00:00
commit 6a9e33c729
25 changed files with 1382 additions and 1021 deletions

8
.bumpversion.cfg Normal file
View File

@ -0,0 +1,8 @@
[bumpversion]
current_version = 1.1.1
commit = True
tag = True
[bumpversion:file:batdetect2/__init__.py]
[bumpversion:file:pyproject.toml]

View File

@ -1,34 +1,29 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
name: Python package
on:
push:
branches: [ "main" ]
branches: ["main"]
pull_request:
branches: [ "main" ]
branches: ["main"]
jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest
pip install .
- name: Test with pytest
run: |
pytest
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
enable-cache: true
cache-dependency-glob: "uv.lock"
- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}
- name: Install the project
run: uv sync --all-extras --dev
- name: Test with pytest
run: uv run pytest

View File

@ -1,11 +1,3 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package
on:
@ -17,23 +9,22 @@ permissions:
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: '3.x'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: "3.x"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

4
.gitignore vendored
View File

@ -103,13 +103,11 @@ experiments/*
.ipynb_checkpoints
*.ipynb
# Bump2version
.bumpversion.cfg
# DO Include
!batdetect2_notebook.ipynb
!batdetect2/models/checkpoints/*.pth.tar
!tests/data/*.wav
!notebooks/*.ipynb
!tests/data/**/*.wav
notebooks/lightning_logs
example_data/preprocessed

View File

@ -1 +1,6 @@
__version__ = '1.0.8'
import logging
numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)
__version__ = "1.1.1"

View File

@ -1,5 +1,3 @@
"""BatDetect2 command line interface."""
import click
from batdetect2 import api
@ -114,10 +112,9 @@ def detect(
):
results_path = audio_file.replace(audio_dir, ann_dir)
save_results_to_file(results, results_path)
except (RuntimeError, ValueError, LookupError) as err:
except (RuntimeError, ValueError, LookupError, EOFError) as err:
error_files.append(audio_file)
click.secho(f"Error processing file!: {err}", fg="red")
raise err
click.secho(f"Error processing file {audio_file}: {err}", fg="red")
click.echo(f"\nResults saved to: {ann_dir}")

View File

@ -1,13 +1,11 @@
"""Types used in the code base."""
from typing import Any, List, NamedTuple, Optional
import numpy as np
import torch
try:
from typing import TypedDict
except ImportError:
from typing_extensions import TypedDict
from typing import TypedDict
try:
@ -15,9 +13,8 @@ try:
except ImportError:
from typing_extensions import Protocol
try:
from typing import NotRequired
from typing import NotRequired # type: ignore
except ImportError:
from typing_extensions import NotRequired

View File

@ -6,6 +6,8 @@ import librosa.core.spectrum
import numpy as np
import torch
from batdetect2.detector import parameters
from . import wavfile
__all__ = [
@ -15,20 +17,44 @@ __all__ = [
]
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor
noverlap = np.floor(fft_overlap * nfft)
return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
def time_to_x_coords(
time_in_file: float,
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
) -> float:
nfft = np.floor(window_duration * samplerate) # int() uses floor
noverlap = np.floor(window_overlap * nfft)
return (time_in_file * samplerate - noverlap) / (nfft - noverlap)
# NOTE this is also defined in post_process
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
nfft = np.floor(fft_win_length * sampling_rate)
noverlap = np.floor(fft_overlap * nfft)
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
def x_coords_to_time(
x_pos: int,
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
) -> float:
n_fft = np.floor(window_duration * samplerate)
n_overlap = np.floor(window_overlap * n_fft)
n_step = n_fft - n_overlap
return ((x_pos * n_step) + n_overlap) / samplerate
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
def x_coord_to_sample(
x_pos: int,
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
resize_factor: float = parameters.RESIZE_FACTOR,
) -> int:
n_fft = np.floor(window_duration * samplerate)
n_overlap = np.floor(window_overlap * n_fft)
n_step = n_fft - n_overlap
x_pos = int(x_pos / resize_factor)
return int((x_pos * n_step) + n_overlap)
def generate_spectrogram(
audio,
sampling_rate,
@ -194,55 +220,118 @@ def load_audio(
return sampling_rate, audio_raw
def compute_spectrogram_width(
length: int,
samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
resize_factor: float = parameters.RESIZE_FACTOR,
) -> int:
n_fft = int(window_duration * samplerate)
n_overlap = int(window_overlap * n_fft)
n_step = n_fft - n_overlap
width = (length - n_overlap) // n_step
return int(width * resize_factor)
def pad_audio(
audio_raw,
fs,
ms,
overlap_perc,
resize_factor,
divide_factor,
fixed_width=None,
audio: np.ndarray,
samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
resize_factor: float = parameters.RESIZE_FACTOR,
divide_factor: int = parameters.SPEC_DIVIDE_FACTOR,
fixed_width: Optional[int] = None,
):
# Adds zeros to the end of the raw data so that the generated sepctrogram
# will be evenly divisible by `divide_factor`
# Also deals with very short audio clips and fixed_width during training
"""Pad audio to be evenly divisible by `divide_factor`.
# This code could be clearer, clean up
nfft = int(ms * fs)
noverlap = int(overlap_perc * nfft)
step = nfft - noverlap
min_size = int(divide_factor * (1.0 / resize_factor))
spec_width = (audio_raw.shape[0] - noverlap) // step
spec_width_rs = spec_width * resize_factor
This function pads the audio signal with zeros to ensure that the
generated spectrogram length will be evenly divisible by `divide_factor`.
This is important for the model to work correctly.
if fixed_width is not None and spec_width < fixed_width:
# too small
# used during training to ensure all the batches are the same size
diff = fixed_width * step + noverlap - audio_raw.shape[0]
audio_raw = np.hstack(
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
This `divide_factor` comes from the model architecture as it downscales
the spectrogram by this factor, so the input must be divisible by this
integer number.
Parameters
----------
audio : np.ndarray
The audio signal.
samplerate : int
The sampling rate of the audio signal.
window_size : float
The window size in seconds used for the spectrogram computation.
window_overlap : float
The overlap between windows in the spectrogram computation.
resize_factor : float
This factor is used to resize the spectrogram after the STFT
computation. Default is 0.5 which means that the spectrogram will be
reduced by half. Important to take into account for the final size of
the spectrogram.
divide_factor : int
The factor by which the spectrogram will be divided.
fixed_width : int, optional
If provided, the audio will be padded or cut so that the resulting
spectrogram width will be equal to this value.
Returns
-------
np.ndarray
The padded audio signal.
"""
spec_width = compute_spectrogram_width(
audio.shape[0],
samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
resize_factor=resize_factor,
)
if fixed_width:
target_samples = x_coord_to_sample(
fixed_width,
samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
resize_factor=resize_factor,
)
elif fixed_width is not None and spec_width > fixed_width:
# too big
# used during training to ensure all the batches are the same size
diff = fixed_width * step + noverlap - audio_raw.shape[0]
audio_raw = audio_raw[:diff]
if spec_width < fixed_width:
# need to be at least min_size
diff = target_samples - audio.shape[0]
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
elif (
spec_width_rs < min_size
or (np.floor(spec_width_rs) % divide_factor) != 0
):
# need to be at least min_size
div_amt = np.ceil(spec_width_rs / float(divide_factor))
div_amt = np.maximum(1, div_amt)
target_size = int(div_amt * divide_factor * (1.0 / resize_factor))
diff = target_size * step + noverlap - audio_raw.shape[0]
audio_raw = np.hstack(
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
if spec_width > fixed_width:
return audio[:target_samples]
return audio
min_width = int(divide_factor / resize_factor)
if spec_width < min_width:
target_samples = x_coord_to_sample(
min_width,
samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
resize_factor=resize_factor,
)
diff = target_samples - audio.shape[0]
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
return audio_raw
if (spec_width % divide_factor) == 0:
return audio
target_width = int(np.ceil(spec_width / divide_factor)) * divide_factor
target_samples = x_coord_to_sample(
target_width,
samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
resize_factor=resize_factor,
)
diff = target_samples - audio.shape[0]
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
def gen_mag_spectrogram(x, fs, ms, overlap_perc):

View File

@ -8,6 +8,11 @@ import pandas as pd
import torch
import torch.nn.functional as F
try:
from numpy.exceptions import AxisError
except ImportError:
from numpy import AxisError # type: ignore
import batdetect2.detector.compute_features as feats
import batdetect2.detector.post_process as pp
import batdetect2.utils.audio_utils as au
@ -80,6 +85,7 @@ def load_model(
model_path: str = DEFAULT_MODEL_PATH,
load_weights: bool = True,
device: Union[torch.device, str, None] = None,
weights_only: bool = True,
) -> Tuple[DetectionModel, ModelParameters]:
"""Load model from file.
@ -100,7 +106,11 @@ def load_model(
if not os.path.isfile(model_path):
raise FileNotFoundError("Model file not found.")
net_params = torch.load(model_path, map_location=device)
net_params = torch.load(
model_path,
map_location=device,
weights_only=weights_only,
)
params = net_params["params"]
@ -242,7 +252,7 @@ def format_single_result(
)
class_name = class_names[np.argmax(class_overall)]
annotations = get_annotations_from_preds(predictions, class_names)
except (np.AxisError, ValueError):
except (AxisError, ValueError):
# No detections
class_overall = np.zeros(len(class_names))
class_name = "None"
@ -738,9 +748,7 @@ def process_file(
# Get original sampling rate
file_samp_rate = librosa.get_samplerate(audio_file)
orig_samp_rate = file_samp_rate * float(
config.get("time_expansion", 1.0) or 1.0
)
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
# load audio file
sampling_rate, audio_full = au.load_audio(

View File

@ -1,103 +1,93 @@
[tool]
uv = { dev-dependencies = [
"ipykernel>=6.29.4",
"setuptools>=69.5.1",
"pytest>=8.1.1",
] }
[tool.pdm]
[tool.pdm.dev-dependencies]
dev = [
"pytest>=7.2.2",
]
[project]
name = "batdetect2"
version = "1.0.8"
version = "1.1.1"
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
authors = [
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" }
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
]
dependencies = [
"librosa>=0.10.1",
"matplotlib>=3.7.1",
"numpy>=1.23.5",
"pandas>=1.5.3",
"scikit-learn>=1.2.2",
"scipy>=1.10.1",
"torch>=1.13.1",
"torchaudio",
"torchvision",
"soundevent[audio,geometry,plot]>=2.0.1",
"click>=8.1.7",
"netcdf4>=1.6.5",
"tqdm>=4.66.2",
"pytorch-lightning>=2.2.2",
"cf-xarray>=0.9.0",
"onnx>=1.16.0",
"lightning[extra]>=2.2.2",
"tensorboard>=2.16.2",
"click>=8.1.7",
"librosa>=0.10.1",
"matplotlib>=3.7.1",
"numpy>=1.23.5",
"pandas>=1.5.3",
"scikit-learn>=1.2.2",
"scipy>=1.10.1",
"torch>=1.13.1,<2.5.0",
"torchaudio>=1.13.1,<2.5.0",
"torchvision>=0.14.0",
"soundevent[audio,geometry,plot]>=2.0.1",
"click>=8.1.7",
"netcdf4>=1.6.5",
"tqdm>=4.66.2",
"pytorch-lightning>=2.2.2",
"cf-xarray>=0.9.0",
"onnx>=1.16.0",
"lightning[extra]>=2.2.2",
"tensorboard>=2.16.2",
]
requires-python = ">=3.9"
requires-python = ">=3.9,<3.13"
readme = "README.md"
license = { text = "CC-by-nc-4" }
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Multimedia :: Sound/Audio :: Analysis",
"Development Status :: 4 - Beta",
"Intended Audience :: Science/Research",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Multimedia :: Sound/Audio :: Analysis",
]
keywords = [
"bat",
"echolocation",
"deep learning",
"audio",
"machine learning",
"classification",
"detection",
"bat",
"echolocation",
"deep learning",
"audio",
"machine learning",
"classification",
"detection",
]
[build-system]
requires = ["pdm-pep517>=1.0.0"]
build-backend = "pdm.pep517.api"
requires = ["hatchling"]
build-backend = "hatchling.build"
[project.scripts]
batdetect2 = "batdetect2.cli:cli"
[tool.black]
line-length = 79
[tool.isort]
profile = "black"
line_length = 79
[tool.uv]
dev-dependencies = [
"debugpy>=1.8.8",
"hypothesis>=6.118.7",
"pyright>=1.1.388",
"pytest>=7.2.2",
"ruff>=0.7.3",
"ipykernel>=6.29.4",
"setuptools>=69.5.1",
]
[tool.ruff]
line-length = 79
target-version = "py39"
[[tool.mypy.overrides]]
module = [
"librosa",
"pandas",
]
ignore_missing_imports = true
[tool.ruff.format]
docstring-code-format = true
docstring-code-line-length = 79
[tool.pylsp-mypy]
enabled = false
live_mode = true
strict = true
[tool.ruff.lint]
select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
[tool.pydocstyle]
[tool.ruff.lint.pydocstyle]
convention = "numpy"
[tool.pyright]
include = [
"bat_detect",
"tests",
]
include = ["batdetect2", "tests"]
venvPath = "."
venv = ".venv"
pythonVersion = "3.9"
pythonPlatform = "All"

40
tests/conftest.py Normal file
View File

@ -0,0 +1,40 @@
from pathlib import Path
from typing import List
import pytest
@pytest.fixture
def example_data_dir() -> Path:
pkg_dir = Path(__file__).parent.parent
example_data_dir = pkg_dir / "example_data"
assert example_data_dir.exists()
return example_data_dir
@pytest.fixture
def example_audio_dir(example_data_dir: Path) -> Path:
example_audio_dir = example_data_dir / "audio"
assert example_audio_dir.exists()
return example_audio_dir
@pytest.fixture
def example_audio_files(example_audio_dir: Path) -> List[Path]:
audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
assert len(audio_files) == 3
return audio_files
@pytest.fixture
def data_dir() -> Path:
dir = Path(__file__).parent / "data"
assert dir.exists()
return dir
@pytest.fixture
def contrib_dir(data_dir) -> Path:
dir = data_dir / "contrib"
assert dir.exists()
return dir

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,14 +1,13 @@
"""Test bat detect module API."""
from pathlib import Path
import os
from glob import glob
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
from torch import nn
import soundfile as sf
from batdetect2 import api
@ -267,7 +266,6 @@ def test_process_file_with_spec_slices():
assert len(results["spec_slices"]) == len(detections)
def test_process_file_with_empty_predictions_does_not_fail(
tmp_path: Path,
):

136
tests/test_audio_utils.py Normal file
View File

@ -0,0 +1,136 @@
import numpy as np
import torch
import torch.nn.functional as F
from hypothesis import given
from hypothesis import strategies as st
from batdetect2.detector import parameters
from batdetect2.utils import audio_utils, detector_utils
@given(duration=st.floats(min_value=0.1, max_value=2))
def test_can_compute_correct_spectrogram_width(duration: float):
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
length = int(duration * samplerate)
audio = np.random.rand(length)
spectrogram, _ = audio_utils.generate_spectrogram(
audio,
samplerate,
params,
)
# convert to pytorch
spectrogram = torch.from_numpy(spectrogram)
# add batch and channel dimensions
spectrogram = spectrogram.unsqueeze(0).unsqueeze(0)
# resize the spec
resize_factor = params["resize_factor"]
spec_op_shape = (
int(params["spec_height"] * resize_factor),
int(spectrogram.shape[-1] * resize_factor),
)
spectrogram = F.interpolate(
spectrogram,
size=spec_op_shape,
mode="bilinear",
align_corners=False,
)
expected_width = audio_utils.compute_spectrogram_width(
length,
samplerate=parameters.TARGET_SAMPLERATE_HZ,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
)
assert spectrogram.shape[-1] == expected_width
@given(duration=st.floats(min_value=0.1, max_value=2))
def test_pad_audio_without_fixed_size(duration: float):
# Test the pad_audio function
# This function is used to pad audio with zeros to a specific length
# It is used in the generate_spectrogram function
# The function is tested with a simplepas
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
length = int(duration * samplerate)
audio = np.random.rand(length)
# pad the audio to be divisible by divide factor
padded_audio = audio_utils.pad_audio(
audio,
samplerate=samplerate,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
divide_factor=params["spec_divide_factor"],
)
# check that the padded audio is divisible by the divide factor
expected_width = audio_utils.compute_spectrogram_width(
len(padded_audio),
samplerate=parameters.TARGET_SAMPLERATE_HZ,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
)
assert expected_width % params["spec_divide_factor"] == 0
@given(duration=st.floats(min_value=0.1, max_value=2))
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
duration: float,
):
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
length = int(duration * samplerate)
audio = np.random.rand(length)
_, spectrogram, _ = detector_utils.compute_spectrogram(
audio,
samplerate,
params,
torch.device("cpu"),
)
assert spectrogram.shape[-1] % params["spec_divide_factor"] == 0
@given(
duration=st.floats(min_value=0.1, max_value=2),
width=st.integers(min_value=128, max_value=1024),
)
def test_pad_audio_with_fixed_width(duration: float, width: int):
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
length = int(duration * samplerate)
audio = np.random.rand(length)
# pad the audio to be divisible by divide factor
padded_audio = audio_utils.pad_audio(
audio,
samplerate=samplerate,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
divide_factor=params["spec_divide_factor"],
fixed_width=width,
)
# check that the padded audio is divisible by the divide factor
expected_width = audio_utils.compute_spectrogram_width(
len(padded_audio),
samplerate=parameters.TARGET_SAMPLERATE_HZ,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
)
assert expected_width == width

View File

@ -1,22 +1,26 @@
"""Test the command line interface."""
from pathlib import Path
from click.testing import CliRunner
import pandas as pd
from click.testing import CliRunner
from batdetect2.cli import cli
runner = CliRunner()
def test_cli_base_command():
"""Test the base command."""
runner = CliRunner()
result = runner.invoke(cli, ["--help"])
assert result.exit_code == 0
assert "BatDetect2 - Bat Call Detection and Classification" in result.output
assert (
"BatDetect2 - Bat Call Detection and Classification" in result.output
)
def test_cli_detect_command_help():
"""Test the detect command help."""
runner = CliRunner()
result = runner.invoke(cli, ["detect", "--help"])
assert result.exit_code == 0
assert "Detect bat calls in files in AUDIO_DIR" in result.output
@ -30,7 +34,6 @@ def test_cli_detect_command_on_test_audio(tmp_path):
if results_dir.exists():
results_dir.rmdir()
runner = CliRunner()
result = runner.invoke(
cli,
[
@ -54,7 +57,6 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
if results_dir.exists():
results_dir.rmdir()
runner = CliRunner()
result = runner.invoke(
cli,
[
@ -68,8 +70,7 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
)
assert result.exit_code == 0
assert 'Time Expansion Factor: 10' in result.stdout
assert "Time Expansion Factor: 10" in result.stdout
def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
@ -80,7 +81,6 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
if results_dir.exists():
results_dir.rmdir()
runner = CliRunner()
result = runner.invoke(
cli,
[
@ -94,13 +94,12 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
assert result.exit_code == 0
assert results_dir.exists()
csv_files = [path.name for path in results_dir.glob("*.csv")]
expected_files = [
"20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
"20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
"20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv"
"20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
]
for expected_file in expected_files:
@ -108,3 +107,26 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
df = pd.read_csv(results_dir / expected_file)
assert not (df.duration == -1).any()
def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
results_dir = tmp_path / "results"
target = tmp_path / "audio"
target.mkdir()
# Create an empty file with the .wav extension
empty_file = target / "empty.wav"
empty_file.touch()
result = runner.invoke(
cli,
args=[
"detect",
str(target),
str(results_dir),
"0.3",
"--spec_features",
],
)
assert result.exit_code == 0
assert f"Error processing file {empty_file}" in result.output

73
tests/test_contrib.py Normal file
View File

@ -0,0 +1,73 @@
"""Test suite to ensure user provided files are correctly processed."""
from pathlib import Path
from click.testing import CliRunner
from batdetect2.cli import cli
runner = CliRunner()
def test_can_process_jeff37_files(
contrib_dir: Path,
tmp_path: Path,
):
"""This test stems from issue #31.
A user provided a set of files which which batdetect2 cli failed and
generated the following error message:
[2272] "Error processing file!: negative dimensions are not allowed"
This test ensures that the error message is not generated when running
batdetect2 cli with the same set of files.
"""
path = contrib_dir / "jeff37"
assert path.exists()
results_dir = tmp_path / "results"
result = runner.invoke(
cli,
[
"detect",
str(path),
str(results_dir),
"0.3",
],
)
assert result.exit_code == 0
assert results_dir.exists()
assert len(list(results_dir.glob("*.csv"))) == 5
assert len(list(results_dir.glob("*.json"))) == 5
def test_can_process_padpadpadpad_files(
contrib_dir: Path,
tmp_path: Path,
):
"""This test stems from issue #29.
Batdetect2 cli failed on the files provided by the user @padpadpadpad
with the following error message:
AttributeError: module 'numpy' has no attribute 'AxisError'
This test ensures that the files are processed without any error.
"""
path = contrib_dir / "padpadpadpad"
assert path.exists()
results_dir = tmp_path / "results"
result = runner.invoke(
cli,
[
"detect",
str(path),
str(results_dir),
"0.3",
],
)
assert result.exit_code == 0
assert results_dir.exists()
assert len(list(results_dir.glob("*.csv"))) == 2
assert len(list(results_dir.glob("*.json"))) == 2

78
tests/test_model.py Normal file
View File

@ -0,0 +1,78 @@
"""Test suite for model functions."""
import warnings
from pathlib import Path
from typing import List
import numpy as np
from hypothesis import given, settings
from hypothesis import strategies as st
from batdetect2 import api
from batdetect2.detector import parameters
def test_can_import_model_without_warnings():
with warnings.catch_warnings():
warnings.simplefilter("error")
api.load_model()
@settings(deadline=None, max_examples=5)
@given(duration=st.floats(min_value=0.1, max_value=2))
def test_can_import_model_without_pickle(duration: float):
# NOTE: remove this test once no other issues are found This is a temporary
# test to check that change in model loading did not impact model behaviour
# in any way.
samplerate = parameters.TARGET_SAMPLERATE_HZ
audio = np.random.rand(int(duration * samplerate))
model_without_pickle, model_params_without_pickle = api.load_model(
weights_only=True
)
model_with_pickle, model_params_with_pickle = api.load_model(
weights_only=False
)
assert model_params_without_pickle == model_params_with_pickle
predictions_without_pickle, _, _ = api.process_audio(
audio,
model=model_without_pickle,
)
predictions_with_pickle, _, _ = api.process_audio(
audio,
model=model_with_pickle,
)
assert predictions_without_pickle == predictions_with_pickle
def test_can_import_model_without_pickle_on_test_data(
example_audio_files: List[Path],
):
# NOTE: remove this test once no other issues are found This is a temporary
# test to check that change in model loading did not impact model behaviour
# in any way.
model_without_pickle, model_params_without_pickle = api.load_model(
weights_only=True
)
model_with_pickle, model_params_with_pickle = api.load_model(
weights_only=False
)
assert model_params_without_pickle == model_params_with_pickle
for audio_file in example_audio_files:
audio = api.load_audio(str(audio_file))
predictions_without_pickle, _, _ = api.process_audio(
audio,
model=model_without_pickle,
)
predictions_with_pickle, _, _ = api.process_audio(
audio,
model=model_with_pickle,
)
assert predictions_without_pickle == predictions_with_pickle

1574
uv.lock generated

File diff suppressed because it is too large Load Diff