mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Merge branch 'main' into train
This commit is contained in:
commit
6a9e33c729
8
.bumpversion.cfg
Normal file
8
.bumpversion.cfg
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
[bumpversion]
|
||||||
|
current_version = 1.1.1
|
||||||
|
commit = True
|
||||||
|
tag = True
|
||||||
|
|
||||||
|
[bumpversion:file:batdetect2/__init__.py]
|
||||||
|
|
||||||
|
[bumpversion:file:pyproject.toml]
|
31
.github/workflows/python-package.yml
vendored
31
.github/workflows/python-package.yml
vendored
@ -1,34 +1,29 @@
|
|||||||
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
|
|
||||||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
|
|
||||||
|
|
||||||
name: Python package
|
name: Python package
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [ "main" ]
|
branches: ["main"]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ "main" ]
|
branches: ["main"]
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build:
|
build:
|
||||||
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
- name: Install uv
|
||||||
uses: actions/setup-python@v3
|
uses: astral-sh/setup-uv@v3
|
||||||
with:
|
with:
|
||||||
python-version: ${{ matrix.python-version }}
|
enable-cache: true
|
||||||
- name: Install dependencies
|
cache-dependency-glob: "uv.lock"
|
||||||
run: |
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
python -m pip install --upgrade pip
|
run: uv python install ${{ matrix.python-version }}
|
||||||
python -m pip install pytest
|
- name: Install the project
|
||||||
pip install .
|
run: uv sync --all-extras --dev
|
||||||
- name: Test with pytest
|
- name: Test with pytest
|
||||||
run: |
|
run: uv run pytest
|
||||||
pytest
|
|
||||||
|
13
.github/workflows/python-publish.yml
vendored
13
.github/workflows/python-publish.yml
vendored
@ -1,11 +1,3 @@
|
|||||||
# This workflow will upload a Python Package using Twine when a release is created
|
|
||||||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
|
|
||||||
|
|
||||||
# This workflow uses actions that are not certified by GitHub.
|
|
||||||
# They are provided by a third-party and are governed by
|
|
||||||
# separate terms of service, privacy policy, and support
|
|
||||||
# documentation.
|
|
||||||
|
|
||||||
name: Upload Python Package
|
name: Upload Python Package
|
||||||
|
|
||||||
on:
|
on:
|
||||||
@ -17,15 +9,14 @@ permissions:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
deploy:
|
deploy:
|
||||||
|
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v3
|
uses: actions/setup-python@v3
|
||||||
with:
|
with:
|
||||||
python-version: '3.x'
|
python-version: "3.x"
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -103,13 +103,11 @@ experiments/*
|
|||||||
.ipynb_checkpoints
|
.ipynb_checkpoints
|
||||||
*.ipynb
|
*.ipynb
|
||||||
|
|
||||||
# Bump2version
|
|
||||||
.bumpversion.cfg
|
|
||||||
|
|
||||||
# DO Include
|
# DO Include
|
||||||
!batdetect2_notebook.ipynb
|
!batdetect2_notebook.ipynb
|
||||||
!batdetect2/models/checkpoints/*.pth.tar
|
!batdetect2/models/checkpoints/*.pth.tar
|
||||||
!tests/data/*.wav
|
!tests/data/*.wav
|
||||||
!notebooks/*.ipynb
|
!notebooks/*.ipynb
|
||||||
|
!tests/data/**/*.wav
|
||||||
notebooks/lightning_logs
|
notebooks/lightning_logs
|
||||||
example_data/preprocessed
|
example_data/preprocessed
|
||||||
|
@ -1 +1,6 @@
|
|||||||
__version__ = '1.0.8'
|
import logging
|
||||||
|
|
||||||
|
numba_logger = logging.getLogger("numba")
|
||||||
|
numba_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
__version__ = "1.1.1"
|
||||||
|
@ -1,5 +1,3 @@
|
|||||||
"""BatDetect2 command line interface."""
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from batdetect2 import api
|
from batdetect2 import api
|
||||||
@ -114,10 +112,9 @@ def detect(
|
|||||||
):
|
):
|
||||||
results_path = audio_file.replace(audio_dir, ann_dir)
|
results_path = audio_file.replace(audio_dir, ann_dir)
|
||||||
save_results_to_file(results, results_path)
|
save_results_to_file(results, results_path)
|
||||||
except (RuntimeError, ValueError, LookupError) as err:
|
except (RuntimeError, ValueError, LookupError, EOFError) as err:
|
||||||
error_files.append(audio_file)
|
error_files.append(audio_file)
|
||||||
click.secho(f"Error processing file!: {err}", fg="red")
|
click.secho(f"Error processing file {audio_file}: {err}", fg="red")
|
||||||
raise err
|
|
||||||
|
|
||||||
click.echo(f"\nResults saved to: {ann_dir}")
|
click.echo(f"\nResults saved to: {ann_dir}")
|
||||||
|
|
||||||
|
@ -1,13 +1,11 @@
|
|||||||
"""Types used in the code base."""
|
"""Types used in the code base."""
|
||||||
from typing import Any, List, NamedTuple, Optional
|
from typing import Any, List, NamedTuple, Optional
|
||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
try:
|
from typing import TypedDict
|
||||||
from typing import TypedDict
|
|
||||||
except ImportError:
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -15,9 +13,8 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from typing import NotRequired
|
from typing import NotRequired # type: ignore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
|
@ -6,6 +6,8 @@ import librosa.core.spectrum
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from batdetect2.detector import parameters
|
||||||
|
|
||||||
from . import wavfile
|
from . import wavfile
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -15,20 +17,44 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
|
def time_to_x_coords(
|
||||||
nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor
|
time_in_file: float,
|
||||||
noverlap = np.floor(fft_overlap * nfft)
|
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
|
||||||
return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
|
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||||
|
window_overlap: float = parameters.FFT_OVERLAP,
|
||||||
|
) -> float:
|
||||||
|
nfft = np.floor(window_duration * samplerate) # int() uses floor
|
||||||
|
noverlap = np.floor(window_overlap * nfft)
|
||||||
|
return (time_in_file * samplerate - noverlap) / (nfft - noverlap)
|
||||||
|
|
||||||
|
|
||||||
# NOTE this is also defined in post_process
|
def x_coords_to_time(
|
||||||
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
|
x_pos: int,
|
||||||
nfft = np.floor(fft_win_length * sampling_rate)
|
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
|
||||||
noverlap = np.floor(fft_overlap * nfft)
|
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||||
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
|
window_overlap: float = parameters.FFT_OVERLAP,
|
||||||
|
) -> float:
|
||||||
|
n_fft = np.floor(window_duration * samplerate)
|
||||||
|
n_overlap = np.floor(window_overlap * n_fft)
|
||||||
|
n_step = n_fft - n_overlap
|
||||||
|
return ((x_pos * n_step) + n_overlap) / samplerate
|
||||||
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
|
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
|
||||||
|
|
||||||
|
|
||||||
|
def x_coord_to_sample(
|
||||||
|
x_pos: int,
|
||||||
|
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
|
||||||
|
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||||
|
window_overlap: float = parameters.FFT_OVERLAP,
|
||||||
|
resize_factor: float = parameters.RESIZE_FACTOR,
|
||||||
|
) -> int:
|
||||||
|
n_fft = np.floor(window_duration * samplerate)
|
||||||
|
n_overlap = np.floor(window_overlap * n_fft)
|
||||||
|
n_step = n_fft - n_overlap
|
||||||
|
x_pos = int(x_pos / resize_factor)
|
||||||
|
return int((x_pos * n_step) + n_overlap)
|
||||||
|
|
||||||
|
|
||||||
def generate_spectrogram(
|
def generate_spectrogram(
|
||||||
audio,
|
audio,
|
||||||
sampling_rate,
|
sampling_rate,
|
||||||
@ -194,55 +220,118 @@ def load_audio(
|
|||||||
return sampling_rate, audio_raw
|
return sampling_rate, audio_raw
|
||||||
|
|
||||||
|
|
||||||
|
def compute_spectrogram_width(
|
||||||
|
length: int,
|
||||||
|
samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
|
||||||
|
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||||
|
window_overlap: float = parameters.FFT_OVERLAP,
|
||||||
|
resize_factor: float = parameters.RESIZE_FACTOR,
|
||||||
|
) -> int:
|
||||||
|
n_fft = int(window_duration * samplerate)
|
||||||
|
n_overlap = int(window_overlap * n_fft)
|
||||||
|
n_step = n_fft - n_overlap
|
||||||
|
width = (length - n_overlap) // n_step
|
||||||
|
return int(width * resize_factor)
|
||||||
|
|
||||||
|
|
||||||
def pad_audio(
|
def pad_audio(
|
||||||
audio_raw,
|
audio: np.ndarray,
|
||||||
fs,
|
samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
|
||||||
ms,
|
window_duration: float = parameters.FFT_WIN_LENGTH_S,
|
||||||
overlap_perc,
|
window_overlap: float = parameters.FFT_OVERLAP,
|
||||||
resize_factor,
|
resize_factor: float = parameters.RESIZE_FACTOR,
|
||||||
divide_factor,
|
divide_factor: int = parameters.SPEC_DIVIDE_FACTOR,
|
||||||
fixed_width=None,
|
fixed_width: Optional[int] = None,
|
||||||
):
|
):
|
||||||
# Adds zeros to the end of the raw data so that the generated spectrogram
|
"""Pad audio to be evenly divisible by `divide_factor`.
|
||||||
# will be evenly divisible by `divide_factor`
|
|
||||||
# Also deals with very short audio clips and fixed_width during training
|
|
||||||
|
|
||||||
# This code could be clearer, clean up
|
This function pads the audio signal with zeros to ensure that the
|
||||||
nfft = int(ms * fs)
|
generated spectrogram length will be evenly divisible by `divide_factor`.
|
||||||
noverlap = int(overlap_perc * nfft)
|
This is important for the model to work correctly.
|
||||||
step = nfft - noverlap
|
|
||||||
min_size = int(divide_factor * (1.0 / resize_factor))
|
|
||||||
spec_width = (audio_raw.shape[0] - noverlap) // step
|
|
||||||
spec_width_rs = spec_width * resize_factor
|
|
||||||
|
|
||||||
if fixed_width is not None and spec_width < fixed_width:
|
This `divide_factor` comes from the model architecture as it downscales
|
||||||
# too small
|
the spectrogram by this factor, so the input must be divisible by this
|
||||||
# used during training to ensure all the batches are the same size
|
integer number.
|
||||||
diff = fixed_width * step + noverlap - audio_raw.shape[0]
|
|
||||||
audio_raw = np.hstack(
|
Parameters
|
||||||
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
|
----------
|
||||||
|
audio : np.ndarray
|
||||||
|
The audio signal.
|
||||||
|
samplerate : int
|
||||||
|
The sampling rate of the audio signal.
|
||||||
|
window_duration : float
|
||||||
|
The window size in seconds used for the spectrogram computation.
|
||||||
|
window_overlap : float
|
||||||
|
The overlap between windows in the spectrogram computation.
|
||||||
|
resize_factor : float
|
||||||
|
This factor is used to resize the spectrogram after the STFT
|
||||||
|
computation. Default is 0.5 which means that the spectrogram will be
|
||||||
|
reduced by half. Important to take into account for the final size of
|
||||||
|
the spectrogram.
|
||||||
|
divide_factor : int
|
||||||
|
The integer that the spectrogram width must be evenly divisible by.
|
||||||
|
fixed_width : int, optional
|
||||||
|
If provided, the audio will be padded or cut so that the resulting
|
||||||
|
spectrogram width will be equal to this value.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray
|
||||||
|
The padded audio signal.
|
||||||
|
"""
|
||||||
|
spec_width = compute_spectrogram_width(
|
||||||
|
audio.shape[0],
|
||||||
|
samplerate=samplerate,
|
||||||
|
window_duration=window_duration,
|
||||||
|
window_overlap=window_overlap,
|
||||||
|
resize_factor=resize_factor,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif fixed_width is not None and spec_width > fixed_width:
|
if fixed_width:
|
||||||
# too big
|
target_samples = x_coord_to_sample(
|
||||||
# used during training to ensure all the batches are the same size
|
fixed_width,
|
||||||
diff = fixed_width * step + noverlap - audio_raw.shape[0]
|
samplerate=samplerate,
|
||||||
audio_raw = audio_raw[:diff]
|
window_duration=window_duration,
|
||||||
|
window_overlap=window_overlap,
|
||||||
|
resize_factor=resize_factor,
|
||||||
|
)
|
||||||
|
|
||||||
elif (
|
if spec_width < fixed_width:
|
||||||
spec_width_rs < min_size
|
|
||||||
or (np.floor(spec_width_rs) % divide_factor) != 0
|
|
||||||
):
|
|
||||||
# need to be at least min_size
|
# need to be at least min_size
|
||||||
div_amt = np.ceil(spec_width_rs / float(divide_factor))
|
diff = target_samples - audio.shape[0]
|
||||||
div_amt = np.maximum(1, div_amt)
|
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
|
||||||
target_size = int(div_amt * divide_factor * (1.0 / resize_factor))
|
|
||||||
diff = target_size * step + noverlap - audio_raw.shape[0]
|
|
||||||
audio_raw = np.hstack(
|
|
||||||
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype))
|
|
||||||
)
|
|
||||||
|
|
||||||
return audio_raw
|
if spec_width > fixed_width:
|
||||||
|
return audio[:target_samples]
|
||||||
|
|
||||||
|
return audio
|
||||||
|
|
||||||
|
min_width = int(divide_factor / resize_factor)
|
||||||
|
|
||||||
|
if spec_width < min_width:
|
||||||
|
target_samples = x_coord_to_sample(
|
||||||
|
min_width,
|
||||||
|
samplerate=samplerate,
|
||||||
|
window_duration=window_duration,
|
||||||
|
window_overlap=window_overlap,
|
||||||
|
resize_factor=resize_factor,
|
||||||
|
)
|
||||||
|
diff = target_samples - audio.shape[0]
|
||||||
|
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
|
||||||
|
|
||||||
|
if (spec_width % divide_factor) == 0:
|
||||||
|
return audio
|
||||||
|
|
||||||
|
target_width = int(np.ceil(spec_width / divide_factor)) * divide_factor
|
||||||
|
target_samples = x_coord_to_sample(
|
||||||
|
target_width,
|
||||||
|
samplerate=samplerate,
|
||||||
|
window_duration=window_duration,
|
||||||
|
window_overlap=window_overlap,
|
||||||
|
resize_factor=resize_factor,
|
||||||
|
)
|
||||||
|
diff = target_samples - audio.shape[0]
|
||||||
|
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
|
||||||
|
|
||||||
|
|
||||||
def gen_mag_spectrogram(x, fs, ms, overlap_perc):
|
def gen_mag_spectrogram(x, fs, ms, overlap_perc):
|
||||||
|
@ -8,6 +8,11 @@ import pandas as pd
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
try:
|
||||||
|
from numpy.exceptions import AxisError
|
||||||
|
except ImportError:
|
||||||
|
from numpy import AxisError # type: ignore
|
||||||
|
|
||||||
import batdetect2.detector.compute_features as feats
|
import batdetect2.detector.compute_features as feats
|
||||||
import batdetect2.detector.post_process as pp
|
import batdetect2.detector.post_process as pp
|
||||||
import batdetect2.utils.audio_utils as au
|
import batdetect2.utils.audio_utils as au
|
||||||
@ -80,6 +85,7 @@ def load_model(
|
|||||||
model_path: str = DEFAULT_MODEL_PATH,
|
model_path: str = DEFAULT_MODEL_PATH,
|
||||||
load_weights: bool = True,
|
load_weights: bool = True,
|
||||||
device: Union[torch.device, str, None] = None,
|
device: Union[torch.device, str, None] = None,
|
||||||
|
weights_only: bool = True,
|
||||||
) -> Tuple[DetectionModel, ModelParameters]:
|
) -> Tuple[DetectionModel, ModelParameters]:
|
||||||
"""Load model from file.
|
"""Load model from file.
|
||||||
|
|
||||||
@ -100,7 +106,11 @@ def load_model(
|
|||||||
if not os.path.isfile(model_path):
|
if not os.path.isfile(model_path):
|
||||||
raise FileNotFoundError("Model file not found.")
|
raise FileNotFoundError("Model file not found.")
|
||||||
|
|
||||||
net_params = torch.load(model_path, map_location=device)
|
net_params = torch.load(
|
||||||
|
model_path,
|
||||||
|
map_location=device,
|
||||||
|
weights_only=weights_only,
|
||||||
|
)
|
||||||
|
|
||||||
params = net_params["params"]
|
params = net_params["params"]
|
||||||
|
|
||||||
@ -242,7 +252,7 @@ def format_single_result(
|
|||||||
)
|
)
|
||||||
class_name = class_names[np.argmax(class_overall)]
|
class_name = class_names[np.argmax(class_overall)]
|
||||||
annotations = get_annotations_from_preds(predictions, class_names)
|
annotations = get_annotations_from_preds(predictions, class_names)
|
||||||
except (np.AxisError, ValueError):
|
except (AxisError, ValueError):
|
||||||
# No detections
|
# No detections
|
||||||
class_overall = np.zeros(len(class_names))
|
class_overall = np.zeros(len(class_names))
|
||||||
class_name = "None"
|
class_name = "None"
|
||||||
@ -738,9 +748,7 @@ def process_file(
|
|||||||
|
|
||||||
# Get original sampling rate
|
# Get original sampling rate
|
||||||
file_samp_rate = librosa.get_samplerate(audio_file)
|
file_samp_rate = librosa.get_samplerate(audio_file)
|
||||||
orig_samp_rate = file_samp_rate * float(
|
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
|
||||||
config.get("time_expansion", 1.0) or 1.0
|
|
||||||
)
|
|
||||||
|
|
||||||
# load audio file
|
# load audio file
|
||||||
sampling_rate, audio_full = au.load_audio(
|
sampling_rate, audio_full = au.load_audio(
|
||||||
|
@ -1,33 +1,22 @@
|
|||||||
[tool]
|
|
||||||
uv = { dev-dependencies = [
|
|
||||||
"ipykernel>=6.29.4",
|
|
||||||
"setuptools>=69.5.1",
|
|
||||||
"pytest>=8.1.1",
|
|
||||||
] }
|
|
||||||
[tool.pdm]
|
|
||||||
[tool.pdm.dev-dependencies]
|
|
||||||
dev = [
|
|
||||||
"pytest>=7.2.2",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "batdetect2"
|
name = "batdetect2"
|
||||||
version = "1.0.8"
|
version = "1.1.1"
|
||||||
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
|
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
|
||||||
authors = [
|
authors = [
|
||||||
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
|
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
|
||||||
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" }
|
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"click>=8.1.7",
|
||||||
"librosa>=0.10.1",
|
"librosa>=0.10.1",
|
||||||
"matplotlib>=3.7.1",
|
"matplotlib>=3.7.1",
|
||||||
"numpy>=1.23.5",
|
"numpy>=1.23.5",
|
||||||
"pandas>=1.5.3",
|
"pandas>=1.5.3",
|
||||||
"scikit-learn>=1.2.2",
|
"scikit-learn>=1.2.2",
|
||||||
"scipy>=1.10.1",
|
"scipy>=1.10.1",
|
||||||
"torch>=1.13.1",
|
"torch>=1.13.1,<2.5.0",
|
||||||
"torchaudio",
|
"torchaudio>=1.13.1,<2.5.0",
|
||||||
"torchvision",
|
"torchvision>=0.14.0",
|
||||||
"soundevent[audio,geometry,plot]>=2.0.1",
|
"soundevent[audio,geometry,plot]>=2.0.1",
|
||||||
"click>=8.1.7",
|
"click>=8.1.7",
|
||||||
"netcdf4>=1.6.5",
|
"netcdf4>=1.6.5",
|
||||||
@ -38,7 +27,7 @@ dependencies = [
|
|||||||
"lightning[extra]>=2.2.2",
|
"lightning[extra]>=2.2.2",
|
||||||
"tensorboard>=2.16.2",
|
"tensorboard>=2.16.2",
|
||||||
]
|
]
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9,<3.13"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { text = "CC-by-nc-4" }
|
license = { text = "CC-by-nc-4" }
|
||||||
classifiers = [
|
classifiers = [
|
||||||
@ -46,8 +35,10 @@ classifiers = [
|
|||||||
"Intended Audience :: Science/Research",
|
"Intended Audience :: Science/Research",
|
||||||
"Natural Language :: English",
|
"Natural Language :: English",
|
||||||
"Operating System :: OS Independent",
|
"Operating System :: OS Independent",
|
||||||
"Programming Language :: Python :: 3.8",
|
|
||||||
"Programming Language :: Python :: 3.9",
|
"Programming Language :: Python :: 3.9",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||||
"Topic :: Multimedia :: Sound/Audio :: Analysis",
|
"Topic :: Multimedia :: Sound/Audio :: Analysis",
|
||||||
@ -63,41 +54,40 @@ keywords = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["pdm-pep517>=1.0.0"]
|
requires = ["hatchling"]
|
||||||
build-backend = "pdm.pep517.api"
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
batdetect2 = "batdetect2.cli:cli"
|
batdetect2 = "batdetect2.cli:cli"
|
||||||
|
|
||||||
[tool.black]
|
[tool.uv]
|
||||||
line-length = 79
|
dev-dependencies = [
|
||||||
|
"debugpy>=1.8.8",
|
||||||
[tool.isort]
|
"hypothesis>=6.118.7",
|
||||||
profile = "black"
|
"pyright>=1.1.388",
|
||||||
line_length = 79
|
"pytest>=7.2.2",
|
||||||
|
"ruff>=0.7.3",
|
||||||
|
"ipykernel>=6.29.4",
|
||||||
|
"setuptools>=69.5.1",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 79
|
line-length = 79
|
||||||
|
target-version = "py39"
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
[tool.ruff.format]
|
||||||
module = [
|
docstring-code-format = true
|
||||||
"librosa",
|
docstring-code-line-length = 79
|
||||||
"pandas",
|
|
||||||
]
|
|
||||||
ignore_missing_imports = true
|
|
||||||
|
|
||||||
[tool.pylsp-mypy]
|
[tool.ruff.lint]
|
||||||
enabled = false
|
select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
|
||||||
live_mode = true
|
|
||||||
strict = true
|
|
||||||
|
|
||||||
[tool.pydocstyle]
|
[tool.ruff.lint.pydocstyle]
|
||||||
convention = "numpy"
|
convention = "numpy"
|
||||||
|
|
||||||
[tool.pyright]
|
[tool.pyright]
|
||||||
include = [
|
include = ["batdetect2", "tests"]
|
||||||
"bat_detect",
|
|
||||||
"tests",
|
|
||||||
]
|
|
||||||
venvPath = "."
|
venvPath = "."
|
||||||
venv = ".venv"
|
venv = ".venv"
|
||||||
|
pythonVersion = "3.9"
|
||||||
|
pythonPlatform = "All"
|
||||||
|
40
tests/conftest.py
Normal file
40
tests/conftest.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def example_data_dir() -> Path:
|
||||||
|
pkg_dir = Path(__file__).parent.parent
|
||||||
|
example_data_dir = pkg_dir / "example_data"
|
||||||
|
assert example_data_dir.exists()
|
||||||
|
return example_data_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def example_audio_dir(example_data_dir: Path) -> Path:
|
||||||
|
example_audio_dir = example_data_dir / "audio"
|
||||||
|
assert example_audio_dir.exists()
|
||||||
|
return example_audio_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def example_audio_files(example_audio_dir: Path) -> List[Path]:
|
||||||
|
audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
|
||||||
|
assert len(audio_files) == 3
|
||||||
|
return audio_files
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def data_dir() -> Path:
|
||||||
|
dir = Path(__file__).parent / "data"
|
||||||
|
assert dir.exists()
|
||||||
|
return dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def contrib_dir(data_dir) -> Path:
|
||||||
|
dir = data_dir / "contrib"
|
||||||
|
assert dir.exists()
|
||||||
|
return dir
|
BIN
tests/data/contrib/jeff37/0166_20240531_223911.wav
Executable file
BIN
tests/data/contrib/jeff37/0166_20240531_223911.wav
Executable file
Binary file not shown.
BIN
tests/data/contrib/jeff37/0166_20240602_225340.wav
Executable file
BIN
tests/data/contrib/jeff37/0166_20240602_225340.wav
Executable file
Binary file not shown.
BIN
tests/data/contrib/jeff37/0166_20240603_033731.wav
Executable file
BIN
tests/data/contrib/jeff37/0166_20240603_033731.wav
Executable file
Binary file not shown.
BIN
tests/data/contrib/jeff37/0166_20240603_033937.wav
Executable file
BIN
tests/data/contrib/jeff37/0166_20240603_033937.wav
Executable file
Binary file not shown.
BIN
tests/data/contrib/jeff37/0166_20240604_233500.wav
Executable file
BIN
tests/data/contrib/jeff37/0166_20240604_233500.wav
Executable file
Binary file not shown.
BIN
tests/data/contrib/padpadpadpad/Audiomoth.WAV
Normal file
BIN
tests/data/contrib/padpadpadpad/Audiomoth.WAV
Normal file
Binary file not shown.
BIN
tests/data/contrib/padpadpadpad/AudiomothNoBatCalls.WAV
Normal file
BIN
tests/data/contrib/padpadpadpad/AudiomothNoBatCalls.WAV
Normal file
Binary file not shown.
BIN
tests/data/contrib/padpadpadpad/Echometer.wav
Normal file
BIN
tests/data/contrib/padpadpadpad/Echometer.wav
Normal file
Binary file not shown.
@ -1,14 +1,13 @@
|
|||||||
"""Test bat detect module API."""
|
"""Test bat detect module API."""
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import soundfile as sf
|
|
||||||
|
|
||||||
from batdetect2 import api
|
from batdetect2 import api
|
||||||
|
|
||||||
@ -267,7 +266,6 @@ def test_process_file_with_spec_slices():
|
|||||||
assert len(results["spec_slices"]) == len(detections)
|
assert len(results["spec_slices"]) == len(detections)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_process_file_with_empty_predictions_does_not_fail(
|
def test_process_file_with_empty_predictions_does_not_fail(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
):
|
):
|
||||||
|
136
tests/test_audio_utils.py
Normal file
136
tests/test_audio_utils.py
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from hypothesis import given
|
||||||
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
|
from batdetect2.detector import parameters
|
||||||
|
from batdetect2.utils import audio_utils, detector_utils
|
||||||
|
|
||||||
|
|
||||||
|
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||||
|
def test_can_compute_correct_spectrogram_width(duration: float):
|
||||||
|
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||||
|
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||||
|
|
||||||
|
length = int(duration * samplerate)
|
||||||
|
audio = np.random.rand(length)
|
||||||
|
|
||||||
|
spectrogram, _ = audio_utils.generate_spectrogram(
|
||||||
|
audio,
|
||||||
|
samplerate,
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
|
||||||
|
# convert to pytorch
|
||||||
|
spectrogram = torch.from_numpy(spectrogram)
|
||||||
|
|
||||||
|
# add batch and channel dimensions
|
||||||
|
spectrogram = spectrogram.unsqueeze(0).unsqueeze(0)
|
||||||
|
|
||||||
|
# resize the spec
|
||||||
|
resize_factor = params["resize_factor"]
|
||||||
|
spec_op_shape = (
|
||||||
|
int(params["spec_height"] * resize_factor),
|
||||||
|
int(spectrogram.shape[-1] * resize_factor),
|
||||||
|
)
|
||||||
|
spectrogram = F.interpolate(
|
||||||
|
spectrogram,
|
||||||
|
size=spec_op_shape,
|
||||||
|
mode="bilinear",
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_width = audio_utils.compute_spectrogram_width(
|
||||||
|
length,
|
||||||
|
samplerate=parameters.TARGET_SAMPLERATE_HZ,
|
||||||
|
window_duration=params["fft_win_length"],
|
||||||
|
window_overlap=params["fft_overlap"],
|
||||||
|
resize_factor=params["resize_factor"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert spectrogram.shape[-1] == expected_width
|
||||||
|
|
||||||
|
|
||||||
|
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||||
|
def test_pad_audio_without_fixed_size(duration: float):
|
||||||
|
# Test the pad_audio function
|
||||||
|
# This function is used to pad audio with zeros to a specific length
|
||||||
|
# It is used in the generate_spectrogram function
|
||||||
|
# The function is tested with a simple case: padding without a fixed width
|
||||||
|
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||||
|
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||||
|
|
||||||
|
length = int(duration * samplerate)
|
||||||
|
audio = np.random.rand(length)
|
||||||
|
|
||||||
|
# pad the audio to be divisible by divide factor
|
||||||
|
padded_audio = audio_utils.pad_audio(
|
||||||
|
audio,
|
||||||
|
samplerate=samplerate,
|
||||||
|
window_duration=params["fft_win_length"],
|
||||||
|
window_overlap=params["fft_overlap"],
|
||||||
|
resize_factor=params["resize_factor"],
|
||||||
|
divide_factor=params["spec_divide_factor"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# check that the spectrogram width of the padded audio is divisible by the divide factor
|
||||||
|
expected_width = audio_utils.compute_spectrogram_width(
|
||||||
|
len(padded_audio),
|
||||||
|
samplerate=parameters.TARGET_SAMPLERATE_HZ,
|
||||||
|
window_duration=params["fft_win_length"],
|
||||||
|
window_overlap=params["fft_overlap"],
|
||||||
|
resize_factor=params["resize_factor"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert expected_width % params["spec_divide_factor"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||||
|
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
|
||||||
|
duration: float,
|
||||||
|
):
|
||||||
|
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||||
|
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||||
|
length = int(duration * samplerate)
|
||||||
|
audio = np.random.rand(length)
|
||||||
|
_, spectrogram, _ = detector_utils.compute_spectrogram(
|
||||||
|
audio,
|
||||||
|
samplerate,
|
||||||
|
params,
|
||||||
|
torch.device("cpu"),
|
||||||
|
)
|
||||||
|
assert spectrogram.shape[-1] % params["spec_divide_factor"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@given(
|
||||||
|
duration=st.floats(min_value=0.1, max_value=2),
|
||||||
|
width=st.integers(min_value=128, max_value=1024),
|
||||||
|
)
|
||||||
|
def test_pad_audio_with_fixed_width(duration: float, width: int):
|
||||||
|
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||||
|
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||||
|
|
||||||
|
length = int(duration * samplerate)
|
||||||
|
audio = np.random.rand(length)
|
||||||
|
|
||||||
|
# pad the audio to be divisible by divide factor
|
||||||
|
padded_audio = audio_utils.pad_audio(
|
||||||
|
audio,
|
||||||
|
samplerate=samplerate,
|
||||||
|
window_duration=params["fft_win_length"],
|
||||||
|
window_overlap=params["fft_overlap"],
|
||||||
|
resize_factor=params["resize_factor"],
|
||||||
|
divide_factor=params["spec_divide_factor"],
|
||||||
|
fixed_width=width,
|
||||||
|
)
|
||||||
|
|
||||||
|
# check that the padded audio yields a spectrogram of the requested width
|
||||||
|
expected_width = audio_utils.compute_spectrogram_width(
|
||||||
|
len(padded_audio),
|
||||||
|
samplerate=parameters.TARGET_SAMPLERATE_HZ,
|
||||||
|
window_duration=params["fft_win_length"],
|
||||||
|
window_overlap=params["fft_overlap"],
|
||||||
|
resize_factor=params["resize_factor"],
|
||||||
|
)
|
||||||
|
assert expected_width == width
|
@ -1,22 +1,26 @@
|
|||||||
"""Test the command line interface."""
|
"""Test the command line interface."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from click.testing import CliRunner
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
from batdetect2.cli import cli
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
def test_cli_base_command():
|
def test_cli_base_command():
|
||||||
"""Test the base command."""
|
"""Test the base command."""
|
||||||
runner = CliRunner()
|
|
||||||
result = runner.invoke(cli, ["--help"])
|
result = runner.invoke(cli, ["--help"])
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert "BatDetect2 - Bat Call Detection and Classification" in result.output
|
assert (
|
||||||
|
"BatDetect2 - Bat Call Detection and Classification" in result.output
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_cli_detect_command_help():
|
def test_cli_detect_command_help():
|
||||||
"""Test the detect command help."""
|
"""Test the detect command help."""
|
||||||
runner = CliRunner()
|
|
||||||
result = runner.invoke(cli, ["detect", "--help"])
|
result = runner.invoke(cli, ["detect", "--help"])
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert "Detect bat calls in files in AUDIO_DIR" in result.output
|
assert "Detect bat calls in files in AUDIO_DIR" in result.output
|
||||||
@ -30,7 +34,6 @@ def test_cli_detect_command_on_test_audio(tmp_path):
|
|||||||
if results_dir.exists():
|
if results_dir.exists():
|
||||||
results_dir.rmdir()
|
results_dir.rmdir()
|
||||||
|
|
||||||
runner = CliRunner()
|
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@ -54,7 +57,6 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
|
|||||||
if results_dir.exists():
|
if results_dir.exists():
|
||||||
results_dir.rmdir()
|
results_dir.rmdir()
|
||||||
|
|
||||||
runner = CliRunner()
|
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@ -68,8 +70,7 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert 'Time Expansion Factor: 10' in result.stdout
|
assert "Time Expansion Factor: 10" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
|
def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
|
||||||
@ -80,7 +81,6 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
|
|||||||
if results_dir.exists():
|
if results_dir.exists():
|
||||||
results_dir.rmdir()
|
results_dir.rmdir()
|
||||||
|
|
||||||
runner = CliRunner()
|
|
||||||
result = runner.invoke(
|
result = runner.invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
@ -94,13 +94,12 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
|
|||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert results_dir.exists()
|
assert results_dir.exists()
|
||||||
|
|
||||||
|
|
||||||
csv_files = [path.name for path in results_dir.glob("*.csv")]
|
csv_files = [path.name for path in results_dir.glob("*.csv")]
|
||||||
|
|
||||||
expected_files = [
|
expected_files = [
|
||||||
"20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
|
"20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
|
||||||
"20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
|
"20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
|
||||||
"20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv"
|
"20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
|
||||||
]
|
]
|
||||||
|
|
||||||
for expected_file in expected_files:
|
for expected_file in expected_files:
|
||||||
@ -108,3 +107,26 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
|
|||||||
|
|
||||||
df = pd.read_csv(results_dir / expected_file)
|
df = pd.read_csv(results_dir / expected_file)
|
||||||
assert not (df.duration == -1).any()
|
assert not (df.duration == -1).any()
|
||||||
|
|
||||||
|
|
||||||
|
def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
|
||||||
|
results_dir = tmp_path / "results"
|
||||||
|
target = tmp_path / "audio"
|
||||||
|
target.mkdir()
|
||||||
|
|
||||||
|
# Create an empty file with the .wav extension
|
||||||
|
empty_file = target / "empty.wav"
|
||||||
|
empty_file.touch()
|
||||||
|
|
||||||
|
result = runner.invoke(
|
||||||
|
cli,
|
||||||
|
args=[
|
||||||
|
"detect",
|
||||||
|
str(target),
|
||||||
|
str(results_dir),
|
||||||
|
"0.3",
|
||||||
|
"--spec_features",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert f"Error processing file {empty_file}" in result.output
|
||||||
|
73
tests/test_contrib.py
Normal file
73
tests/test_contrib.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
"""Test suite to ensure user provided files are correctly processed."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_process_jeff37_files(
|
||||||
|
contrib_dir: Path,
|
||||||
|
tmp_path: Path,
|
||||||
|
):
|
||||||
|
"""This test stems from issue #31.
|
||||||
|
|
||||||
|
A user provided a set of files with which batdetect2 cli failed and
|
||||||
|
generated the following error message:
|
||||||
|
|
||||||
|
[2272] "Error processing file!: negative dimensions are not allowed"
|
||||||
|
|
||||||
|
This test ensures that the error message is not generated when running
|
||||||
|
batdetect2 cli with the same set of files.
|
||||||
|
"""
|
||||||
|
path = contrib_dir / "jeff37"
|
||||||
|
assert path.exists()
|
||||||
|
|
||||||
|
results_dir = tmp_path / "results"
|
||||||
|
result = runner.invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"detect",
|
||||||
|
str(path),
|
||||||
|
str(results_dir),
|
||||||
|
"0.3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert results_dir.exists()
|
||||||
|
assert len(list(results_dir.glob("*.csv"))) == 5
|
||||||
|
assert len(list(results_dir.glob("*.json"))) == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_process_padpadpadpad_files(
|
||||||
|
contrib_dir: Path,
|
||||||
|
tmp_path: Path,
|
||||||
|
):
|
||||||
|
"""This test stems from issue #29.
|
||||||
|
|
||||||
|
Batdetect2 cli failed on the files provided by the user @padpadpadpad
|
||||||
|
with the following error message:
|
||||||
|
|
||||||
|
AttributeError: module 'numpy' has no attribute 'AxisError'
|
||||||
|
|
||||||
|
This test ensures that the files are processed without any error.
|
||||||
|
"""
|
||||||
|
path = contrib_dir / "padpadpadpad"
|
||||||
|
assert path.exists()
|
||||||
|
results_dir = tmp_path / "results"
|
||||||
|
result = runner.invoke(
|
||||||
|
cli,
|
||||||
|
[
|
||||||
|
"detect",
|
||||||
|
str(path),
|
||||||
|
str(results_dir),
|
||||||
|
"0.3",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert results_dir.exists()
|
||||||
|
assert len(list(results_dir.glob("*.csv"))) == 2
|
||||||
|
assert len(list(results_dir.glob("*.json"))) == 2
|
78
tests/test_model.py
Normal file
78
tests/test_model.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
"""Test suite for model functions."""
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from hypothesis import given, settings
|
||||||
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
|
from batdetect2 import api
|
||||||
|
from batdetect2.detector import parameters
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_import_model_without_warnings():
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("error")
|
||||||
|
api.load_model()
|
||||||
|
|
||||||
|
|
||||||
|
@settings(deadline=None, max_examples=5)
|
||||||
|
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||||
|
def test_can_import_model_without_pickle(duration: float):
|
||||||
|
# NOTE: remove this test once no other issues are found. This is a temporary
|
||||||
|
# test to check that change in model loading did not impact model behaviour
|
||||||
|
# in any way.
|
||||||
|
|
||||||
|
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||||
|
audio = np.random.rand(int(duration * samplerate))
|
||||||
|
|
||||||
|
model_without_pickle, model_params_without_pickle = api.load_model(
|
||||||
|
weights_only=True
|
||||||
|
)
|
||||||
|
model_with_pickle, model_params_with_pickle = api.load_model(
|
||||||
|
weights_only=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert model_params_without_pickle == model_params_with_pickle
|
||||||
|
|
||||||
|
predictions_without_pickle, _, _ = api.process_audio(
|
||||||
|
audio,
|
||||||
|
model=model_without_pickle,
|
||||||
|
)
|
||||||
|
predictions_with_pickle, _, _ = api.process_audio(
|
||||||
|
audio,
|
||||||
|
model=model_with_pickle,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert predictions_without_pickle == predictions_with_pickle
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_import_model_without_pickle_on_test_data(
|
||||||
|
example_audio_files: List[Path],
|
||||||
|
):
|
||||||
|
# NOTE: remove this test once no other issues are found. This is a temporary
|
||||||
|
# test to check that change in model loading did not impact model behaviour
|
||||||
|
# in any way.
|
||||||
|
|
||||||
|
model_without_pickle, model_params_without_pickle = api.load_model(
|
||||||
|
weights_only=True
|
||||||
|
)
|
||||||
|
model_with_pickle, model_params_with_pickle = api.load_model(
|
||||||
|
weights_only=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert model_params_without_pickle == model_params_with_pickle
|
||||||
|
|
||||||
|
for audio_file in example_audio_files:
|
||||||
|
audio = api.load_audio(str(audio_file))
|
||||||
|
predictions_without_pickle, _, _ = api.process_audio(
|
||||||
|
audio,
|
||||||
|
model=model_without_pickle,
|
||||||
|
)
|
||||||
|
predictions_with_pickle, _, _ = api.process_audio(
|
||||||
|
audio,
|
||||||
|
model=model_with_pickle,
|
||||||
|
)
|
||||||
|
assert predictions_without_pickle == predictions_with_pickle
|
Loading…
Reference in New Issue
Block a user