Compare commits

31 Commits
v1.1.0 ... main

Author SHA1 Message Date
Santiago Martinez Balvanera
4cd71497e7
Merge pull request #52 from kaviecos/http_documentation
Some checks failed
Python package / build (3.10) (push) Has been cancelled
Python package / build (3.11) (push) Has been cancelled
Python package / build (3.12) (push) Has been cancelled
Python package / build (3.9) (push) Has been cancelled
Http documentation
2025-06-03 23:38:45 +01:00
Kavi Askholm Mellerup
ba670932d5
Update README.md 2025-06-03 14:26:54 +02:00
Kavi Askholm Mellerup
d3747c57f2
Update README.md
Added section for using the API with HTTP.
2025-06-03 14:20:20 +02:00
mbsantiago
42a838e9f2 Bump version: 1.2.0 → 1.3.0
Some checks failed
Python package / build (3.10) (push) Has been cancelled
Python package / build (3.11) (push) Has been cancelled
Python package / build (3.12) (push) Has been cancelled
Python package / build (3.9) (push) Has been cancelled
2025-05-16 15:02:15 +01:00
Santiago Martinez Balvanera
c10903a646
Merge pull request #44 from kaviecos/http_support
Http support
2025-05-16 15:01:08 +01:00
Kavi
4282e2ae70 Added AudioPath as an alias for the path definition 2025-05-16 15:13:08 +02:00
Kavi
52570738f2 Renamed load_audio_data to load_audio_and_samplerate 2025-05-16 14:56:35 +02:00
Kavi
cbd362d6ea Updated docstrings for tests 2025-05-16 14:53:35 +02:00
mbsantiago
4b75e13fa2 Bump version: 1.1.1 → 1.2.0
Some checks failed
Python package / build (3.12) (push) Has been cancelled
Python package / build (3.9) (push) Has been cancelled
Python package / build (3.10) (push) Has been cancelled
Python package / build (3.11) (push) Has been cancelled
2025-03-12 18:03:43 +00:00
Santiago Martinez Balvanera
98bf506634
Merge pull request #45 from macaodha/feat/add-chunk-size-to-cli
Added the chunk_size param to the detect command
2025-03-12 18:01:48 +00:00
mbsantiago
b4c59f7de1 Added the chunk_size param to the detect command 2025-03-12 17:59:18 +00:00
Kavi
54ca555587 Fixed code to support Python3.9 syntax 2025-02-27 13:51:58 +01:00
Kavi
230b6167bc Added load_audio_data() which returns the original sample rate. Changed load_audio() implementation so that it uses load_audio_data but retains its signature. du.process_file() now does not need to call get_samplerate 2025-02-27 08:10:27 +01:00
Kavi
f62bc99ab2 Added api method to process a URL 2025-02-26 14:13:21 +01:00
Kavi
47dbdc79c2 Added tests for api and load_audio 2025-02-26 14:12:42 +01:00
Kavi
e10e270de4 Fix error in get_samplerate when reading io.BytesIO. 2025-02-26 14:12:09 +01:00
Kavi
6af7fef316 Fix 'unknown' id by providing a _generate_id() function. 2025-02-26 14:11:11 +01:00
Kavi
838a1ade0d Updated get_samplerate test to use example data file. 2025-02-25 14:46:40 +01:00
Kavi
66ac7e608f Changed the signature of api.process_file, au.load_audio and du.process_file. This allows users to use the same args for processing data as librosa.load() 2025-02-25 14:24:48 +01:00
mbsantiago
2100a3e483 Bump version: 1.1.0 → 1.1.1 2024-11-11 13:01:57 +00:00
mbsantiago
1d3cd2e305 Update lock 2024-11-11 13:01:55 +00:00
Santiago Martinez Balvanera
d5753b95bb
Merge pull request #39 from macaodha/fix/handle-empty-files-gracefully
fix: Handle Empty Audio Files Gracefully (GH-20)
2024-11-11 12:59:28 +00:00
mbsantiago
69f59ff559 Added the EOFError to the list of expected errors when processing files 2024-11-11 12:44:21 +00:00
mbsantiago
1a11174bc4 Add a test to validate that empty files are handled gracefully 2024-11-11 12:43:46 +00:00
Santiago Martinez Balvanera
c5c9476e52
Merge pull request #38 from macaodha/test/add_audio_files_provided_by_padpadpadpad_to_test_suite
test: Add failing audio files from GH-29 to contrib test suite
2024-11-11 12:23:03 +00:00
mbsantiago
270b3f212d Created test to verify no errors occurred when running on padpadpadpad recordings 2024-11-11 12:13:00 +00:00
mbsantiago
f61d1d8c72 Add audio files provided by @padpadpadpad to the contrib test files 2024-11-11 12:12:38 +00:00
Santiago Martinez Balvanera
4627ddd739
Merge pull request #37 from macaodha/fix/GH-30-torch-deprecation-warning-weights-only
fix: Address PyTorch Model Loading Deprecation Warning (GH-30)
2024-11-11 12:02:26 +00:00
mbsantiago
3477d7b5b4 Run the same test with example data instead of random audio 2024-11-11 11:57:46 +00:00
mbsantiago
394c66a2ee Added test to validate that changing model loading behaviour did not change model predictions 2024-11-11 11:46:27 +00:00
mbsantiago
d085b3212c Added weights_only argument to model loading function 2024-11-11 11:46:06 +00:00
19 changed files with 439 additions and 52 deletions

@ -1,5 +1,5 @@
[bumpversion]
current_version = 1.1.0
current_version = 1.3.0
commit = True
tag = True

@ -96,6 +96,27 @@ detections, features = api.process_spectrogram(spec)
You can integrate the detections or the extracted features to your custom analysis pipeline.
#### Using the Python API with HTTP
```python
from batdetect2 import api
import io
import requests
AUDIO_URL = "<insert your audio url here>"
# Process a whole file from a url
results = api.process_url(AUDIO_URL)
# Or, load audio and compute spectrograms
# 'requests.get(AUDIO_URL).content' fetches the raw bytes. You are free to use other sources to fetch the raw bytes
audio = api.load_audio(io.BytesIO(requests.get(AUDIO_URL).content))
spec = api.generate_spectrogram(audio)
# And process the audio or the spectrogram with the model
detections, features, spec = api.process_audio(audio)
detections, features = api.process_spectrogram(spec)
```
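If you want more control over the download step (for example a timeout, or a different HTTP client), the same flow can be sketched by hand. This is only an illustration: `fetch_bytes` and `process_remote_audio` are hypothetical helper names, not part of the batdetect2 API, and the 30 second timeout is an arbitrary choice. The sketch relies on `api.process_file` accepting a file-like object and a `file_id` keyword.

```python
import io
from typing import Any, Callable, Optional


def fetch_bytes(url: str, timeout: float = 30.0) -> bytes:
    """Download the raw bytes behind a URL (hypothetical helper)."""
    import requests  # imported lazily so the sketch loads without it

    response = requests.get(url, timeout=timeout)
    # Raise an exception on HTTP errors instead of processing an error page
    response.raise_for_status()
    return response.content


def process_remote_audio(
    url: str,
    fetch: Callable[[str], bytes] = fetch_bytes,
    process: Optional[Callable[..., Any]] = None,
) -> Any:
    """Fetch audio bytes and hand them to the detector (hypothetical helper)."""
    if process is None:
        # Assumption: api.process_file accepts a binary stream and file_id
        from batdetect2 import api

        process = api.process_file

    # Wrap the bytes in an in-memory buffer so they can be read like a file
    buffer = io.BytesIO(fetch(url))
    # Use the URL as the id so results can be traced back to their source
    return process(buffer, file_id=url)
```

Calling `process_remote_audio(AUDIO_URL)` would then download and process the recording. The `fetch` and `process` arguments exist only so the download step or the detector can be swapped out, for instance in tests.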
## Training the model on your own data
Take a look at the steps outlined in finetuning readme [here](batdetect2/finetune/readme.md) for a description of how to train your own model.

@ -3,4 +3,4 @@ import logging
numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)
__version__ = "1.1.0"
__version__ = "1.3.0"

@ -97,8 +97,9 @@ consult the API documentation in the code.
"""
import warnings
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, BinaryIO, Any, Union
from .types import AudioPath
import numpy as np
import torch
@ -120,6 +121,12 @@ from batdetect2.types import (
)
from batdetect2.utils.detector_utils import list_audio_files, load_model
import audioread
import os
import soundfile as sf
import requests
import io
# Remove warnings from torch
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
@ -238,34 +245,82 @@ def generate_spectrogram(
def process_file(
audio_file: str,
path: AudioPath,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE,
file_id: Optional[str] = None
) -> du.RunResults:
"""Process audio file with model.
Parameters
----------
audio_file : str
Path to audio file.
path : AudioPath
Path to audio data.
model : DetectionModel, optional
Detection model. Uses default model if not specified.
config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters).
device : torch.device, optional
Device to use, by default tries to use GPU if available.
file_id : Optional[str], optional
Identifier for the data. If path is a string path to a file, this can be
omitted and the file_id defaults to the basename of the file.
"""
if config is None:
config = CONFIG
return du.process_file(
audio_file,
path,
model,
config,
device,
file_id
)
def process_url(
url: str,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE,
file_id: Optional[str] = None
) -> du.RunResults:
"""Process audio fetched from a URL with model.
Parameters
----------
url : str
HTTP URL to load the audio data from
model : DetectionModel, optional
Detection model. Uses default model if not specified.
config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters).
device : torch.device, optional
Device to use, by default tries to use GPU if available.
file_id : Optional[str], optional
Identifier for the data. Defaults to the URL.
"""
if config is None:
config = CONFIG
if file_id is None:
file_id = url
response = requests.get(url)
# Raise exception on HTTP error
response.raise_for_status()
# Retrieve body as raw bytes
raw_audio_data = response.content
return du.process_file(
io.BytesIO(raw_audio_data),
model,
config,
device,
file_id
)
def process_spectrogram(
spec: torch.Tensor,

@ -1,4 +1,5 @@
"""BatDetect2 command line interface."""
import os
import click
@ -44,6 +45,12 @@ def cli():
default=False,
help="Extracts CNN call features",
)
@click.option(
"--chunk_size",
type=float,
default=2,
help="Specifies the duration of chunks in seconds. BatDetect2 will divide longer files into smaller chunks and process them independently. Larger chunks increase computation time and memory usage but may provide more contextual information for inference.",
)
@click.option(
"--spec_features",
is_flag=True,
@ -79,6 +86,7 @@ def detect(
ann_dir: str,
detection_threshold: float,
time_expansion_factor: int,
chunk_size: float,
**args,
):
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
@ -107,7 +115,7 @@ def detect(
**args,
"time_expansion": time_expansion_factor,
"spec_slices": False,
"chunk_size": 2,
"chunk_size": chunk_size,
"detection_threshold": detection_threshold,
}
)
@ -129,10 +137,9 @@ def detect(
):
results_path = audio_file.replace(audio_dir, ann_dir)
save_results_to_file(results, results_path)
except (RuntimeError, ValueError, LookupError) as err:
except (RuntimeError, ValueError, LookupError, EOFError) as err:
error_files.append(audio_file)
click.secho(f"Error processing file!: {err}", fg="red")
raise err
click.secho(f"Error processing file {audio_file}: {err}", fg="red")
click.echo(f"\nResults saved to: {ann_dir}")
@ -147,6 +154,7 @@ def print_config(config: ProcessingConfiguration):
click.echo("\nProcessing Configuration:")
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
click.echo(f"Detection Threshold: {config.get('detection_threshold')}")
click.echo(f"Chunk Size: {config.get('chunk_size')}s")
if __name__ == "__main__":
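
The `--chunk_size` option sets the duration of each independently processed segment. As a rough illustration of what that means for a recording, the sketch below computes chunk boundaries assuming simple, non-overlapping chunks. The diff does not show how `process_file` actually splits the audio, so `chunk_bounds` is a hypothetical helper and not the library's implementation.

```python
import math
from typing import List, Tuple


def chunk_bounds(duration: float, chunk_size: float) -> List[Tuple[float, float]]:
    """Return (start, end) times in seconds for non-overlapping chunks.

    Hypothetical helper: assumes chunks are contiguous and that the last
    chunk is shortened to fit the end of the recording.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")

    # Always produce at least one chunk, even for very short recordings
    num_chunks = max(1, math.ceil(duration / chunk_size))
    return [
        (i * chunk_size, min((i + 1) * chunk_size, duration))
        for i in range(num_chunks)
    ]


# A 5 second recording with the default chunk size of 2 seconds
print(chunk_bounds(5.0, 2.0))  # [(0.0, 2.0), (2.0, 4.0), (4.0, 5.0)]
```

Under this assumption, passing `--chunk_size 1` for the same recording would yield five chunks instead of three.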

@ -1,6 +1,10 @@
"""Types used in the code base."""
from typing import List, NamedTuple, Optional, Union
from typing import List, NamedTuple, Optional, Union, Any, BinaryIO
import audioread
import os
import soundfile as sf
import numpy as np
import torch
@ -40,6 +44,9 @@ __all__ = [
"SpectrogramParameters",
]
AudioPath = Union[
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
]
class SpectrogramParameters(TypedDict):
"""Parameters for generating spectrograms."""

@ -1,17 +1,24 @@
import warnings
from typing import Optional, Tuple
from typing import Optional, Tuple, Union, Any, BinaryIO
from ..types import AudioPath
import librosa
import librosa.core.spectrum
import numpy as np
import torch
import audioread
import os
import soundfile as sf
from batdetect2.detector import parameters
from . import wavfile
__all__ = [
"load_audio",
"load_audio_and_samplerate",
"generate_spectrogram",
"pad_audio",
]
@ -140,9 +147,8 @@ def generate_spectrogram(
return spec, spec_for_viz
def load_audio(
audio_file: str,
path: AudioPath,
time_exp_fact: float,
target_samp_rate: int,
scale: bool = False,
@ -154,7 +160,7 @@ def load_audio(
Only mono files are supported.
Args:
audio_file (str): Path to the audio file.
path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): path to the input file.
target_samp_rate (int): Target sampling rate.
scale (bool): Whether to scale the audio to [-1, 1].
max_duration (float): Maximum duration of the audio in seconds.
@ -166,12 +172,42 @@ def load_audio(
Raises:
ValueError: If the audio file is stereo.
"""
sample_rate, audio_data, _ = load_audio_and_samplerate(path, time_exp_fact, target_samp_rate, scale, max_duration)
return sample_rate, audio_data
def load_audio_and_samplerate(
path: AudioPath,
time_exp_fact: float,
target_samp_rate: int,
scale: bool = False,
max_duration: Optional[float] = None,
) -> Tuple[int, np.ndarray, Union[float, int]]:
"""Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
Only mono files are supported.
Args:
path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): path to the input file.
target_samp_rate (int): Target sampling rate.
scale (bool): Whether to scale the audio to [-1, 1].
max_duration (float): Maximum duration of the audio in seconds.
Returns:
sampling_rate: The sampling rate of the audio.
audio_raw: The audio signal in a numpy array.
file_sampling_rate: The original sampling rate of the audio
Raises:
ValueError: If the audio file is stereo.
"""
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
# sampling_rate, audio_raw = wavfile.read(audio_file)
audio_raw, sampling_rate = librosa.load(
audio_file,
audio_raw, file_sampling_rate = librosa.load(
path,
sr=None,
dtype=np.float32,
)
@ -179,7 +215,7 @@ def load_audio(
if len(audio_raw.shape) > 1:
raise ValueError("Currently does not handle stereo files")
sampling_rate = sampling_rate * time_exp_fact
sampling_rate = file_sampling_rate * time_exp_fact
# resample - need to do this after correcting for time expansion
sampling_rate_old = sampling_rate
@ -207,7 +243,7 @@ def load_audio(
audio_raw = audio_raw - audio_raw.mean()
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
return sampling_rate, audio_raw
return sampling_rate, audio_raw, file_sampling_rate
def compute_spectrogram_width(
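
The change above keeps two rates apart: the rate stored in the file (`file_sampling_rate`) and the rate after correcting for time expansion (`sampling_rate = file_sampling_rate * time_exp_fact`). A small worked example of that arithmetic follows. The helper names are hypothetical and the numbers are illustrative only, not taken from the test data.

```python
def corrected_rate(file_sampling_rate: int, time_exp_fact: float) -> float:
    """Sampling rate of the original signal, undoing time expansion."""
    return file_sampling_rate * time_exp_fact


def real_duration(num_samples: int, file_sampling_rate: int, time_exp_fact: float) -> float:
    """Duration in seconds of the original (non-expanded) signal."""
    return num_samples / corrected_rate(file_sampling_rate, time_exp_fact)


# A recording stored at 44.1 kHz that was slowed down 10x when it was saved
print(corrected_rate(44100, 10))          # 441000
# 441000 samples play back for 10 s at 44.1 kHz but cover 1 s of real time
print(real_duration(441000, 44100, 10))   # 1.0
```

Returning the file's own rate alongside the corrected one lets a caller such as `du.process_file` compute the original sampling rate without opening the file a second time.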

@ -1,8 +1,9 @@
import json
import os
from typing import Any, Iterator, List, Optional, Tuple, Union
from typing import Any, Iterator, List, Optional, Tuple, Union, BinaryIO
from ..types import AudioPath
import librosa
import numpy as np
import pandas as pd
import torch
@ -31,6 +32,13 @@ from batdetect2.types import (
SpectrogramParameters,
)
import audioread
import os
import io
import soundfile as sf
import hashlib
import uuid
__all__ = [
"load_model",
"list_audio_files",
@ -85,6 +93,7 @@ def load_model(
model_path: str = DEFAULT_MODEL_PATH,
load_weights: bool = True,
device: Optional[torch.device] = None,
weights_only: bool = True,
) -> Tuple[DetectionModel, ModelParameters]:
"""Load model from file.
@ -105,7 +114,11 @@ def load_model(
if not os.path.isfile(model_path):
raise FileNotFoundError("Model file not found.")
net_params = torch.load(model_path, map_location=device)
net_params = torch.load(
model_path,
map_location=device,
weights_only=weights_only,
)
params = net_params["params"]
@ -724,10 +737,11 @@ def process_audio_array(
def process_file(
audio_file: str,
path: AudioPath,
model: DetectionModel,
config: ProcessingConfiguration,
device: torch.device,
file_id: Optional[str] = None
) -> Union[RunResults, Any]:
"""Process a single audio file with detection model.
@ -736,7 +750,7 @@ def process_file(
Parameters
----------
audio_file : str
path : AudioPath
Path to audio file.
model : torch.nn.Module
@ -745,6 +759,9 @@ def process_file(
config : ProcessingConfiguration
Configuration for processing.
file_id : Optional[str], optional
Identifier for the data. Defaults to the basename if path is a str or os.PathLike. For an in-memory binary stream, an MD5 hash is calculated from the data; otherwise a random UUID is used.
Returns
-------
results : Results or Any
@ -757,19 +774,17 @@ def process_file(
cnn_feats = []
spec_slices = []
# Get original sampling rate
file_samp_rate = librosa.get_samplerate(audio_file)
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
# load audio file
sampling_rate, audio_full = au.load_audio(
audio_file,
sampling_rate, audio_full, file_samp_rate = au.load_audio_and_samplerate(
path,
time_exp_fact=config.get("time_expansion", 1) or 1,
target_samp_rate=config["target_samp_rate"],
scale=config["scale_raw_audio"],
max_duration=config.get("max_duration"),
)
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
# loop through larger file and split into chunks
# TODO: fix so that it overlaps correctly and takes care of
# duplicate detections at borders
@ -818,9 +833,13 @@ def process_file(
spec_slices,
)
_file_id = file_id
if _file_id is None:
_file_id = _generate_id(path)
# convert results to a dictionary in the right format
results = convert_results(
file_id=os.path.basename(audio_file),
file_id=_file_id,
time_exp=config.get("time_expansion", 1) or 1,
duration=audio_full.shape[0] / float(sampling_rate),
params=config,
@ -840,6 +859,22 @@ def process_file(
return results
def _generate_id(path: AudioPath) -> str:
""" Generate an id based on the path.
If the path is a str or PathLike, it will be parsed as the basename.
This should ensure backwards compatibility with previous versions.
"""
if isinstance(path, str) or isinstance(path, os.PathLike):
return os.path.basename(path)
elif isinstance(path, (BinaryIO, io.BytesIO)):
path.seek(0)
md5 = hashlib.md5(path.read()).hexdigest()
path.seek(0)
return md5
else:
return str(uuid.uuid4())
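
The id fallback can be tried in isolation with a standalone sketch that uses only the standard library. `generate_id` below is a hypothetical stand-in, not the committed `_generate_id`. It deliberately differs in one respect: it checks against `io.RawIOBase` and `io.BufferedIOBase`, so a handle returned by `open(path, "rb")` is hashed as well. That widening is an assumption made for this sketch.

```python
import hashlib
import io
import os
import uuid
from typing import Any


def generate_id(source: Any) -> str:
    """Pick an id: basename for paths, MD5 for binary streams, else a UUID."""
    if isinstance(source, (str, os.PathLike)):
        # Paths keep the historical behaviour: the file's basename
        return os.path.basename(source)

    if isinstance(source, (io.RawIOBase, io.BufferedIOBase)) and source.seekable():
        # Hash the full content, then rewind so the stream can be read again
        source.seek(0)
        digest = hashlib.md5(source.read()).hexdigest()
        source.seek(0)
        return digest

    # Anything else (e.g. an already opened decoder object) gets a random id
    return str(uuid.uuid4())


print(generate_id("recordings/bat_001.wav"))  # bat_001.wav
print(generate_id(io.BytesIO(b"abc")))        # 900150983cd24fb0d6963f7d28e17f72
```

Because the MD5 depends only on the bytes, the same audio fetched twice gets the same id. The UUID fallback is different on every call.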
def summarize_results(results, predictions, config):
"""Print summary of results."""

@ -1,6 +1,6 @@
[project]
name = "batdetect2"
version = "1.1.0"
version = "1.3.0"
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
authors = [
{ "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },

@ -1,8 +1,31 @@
from pathlib import Path
from typing import List
import pytest
@pytest.fixture
def example_data_dir() -> Path:
pkg_dir = Path(__file__).parent.parent
example_data_dir = pkg_dir / "example_data"
assert example_data_dir.exists()
return example_data_dir
@pytest.fixture
def example_audio_dir(example_data_dir: Path) -> Path:
example_audio_dir = example_data_dir / "audio"
assert example_audio_dir.exists()
return example_audio_dir
@pytest.fixture
def example_audio_files(example_audio_dir: Path) -> List[Path]:
audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
assert len(audio_files) == 3
return audio_files
@pytest.fixture
def data_dir() -> Path:
dir = Path(__file__).parent / "data"

Binary file not shown.

Binary file not shown.

@ -1,21 +1,22 @@
"""Test bat detect module API."""
from pathlib import Path
import os
from glob import glob
from pathlib import Path
import numpy as np
import soundfile as sf
import torch
from torch import nn
import soundfile as sf
from batdetect2 import api
import io
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
def test_load_model_with_default_params():
"""Test loading model with default parameters."""
@ -267,7 +268,6 @@ def test_process_file_with_spec_slices():
assert len(results["spec_slices"]) == len(detections)
def test_process_file_with_empty_predictions_does_not_fail(
tmp_path: Path,
):
@ -282,3 +282,28 @@ def test_process_file_with_empty_predictions_does_not_fail(
assert results is not None
assert len(results["pred_dict"]["annotation"]) == 0
def test_process_file_file_id_defaults_to_basename():
"""Test that process_file assigns basename as an id if no file_id is provided."""
# Recording donated by @kdarras
basename = "20230322_172000_selec2.wav"
path = os.path.join(DATA_DIR, basename)
output = api.process_file(path)
predictions = output["pred_dict"]
id = predictions["id"]
assert id == basename
def test_bytesio_file_id_defaults_to_md5():
"""Test that process_file assigns an md5 sum as an id if no file_id is provided when using binary data."""
# Recording donated by @kdarras
basename = "20230322_172000_selec2.wav"
path = os.path.join(DATA_DIR, basename)
with open(path, "rb") as f:
data = io.BytesIO(f.read())
output = api.process_file(data)
predictions = output["pred_dict"]
id = predictions["id"]
assert id == "7ade9ebf1a9fe5477ff3a2dc57001929"

@ -6,7 +6,10 @@ from hypothesis import strategies as st
from batdetect2.detector import parameters
from batdetect2.utils import audio_utils, detector_utils
import io
import os
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
@given(duration=st.floats(min_value=0.1, max_value=2))
def test_can_compute_correct_spectrogram_width(duration: float):
@ -134,3 +137,20 @@ def test_pad_audio_with_fixed_width(duration: float, width: int):
resize_factor=params["resize_factor"],
)
assert expected_width == width
def test_load_audio_using_bytesio():
basename = "20230322_172000_selec2.wav"
path = os.path.join(DATA_DIR, basename)
with open(path, "rb") as f:
data = io.BytesIO(f.read())
sample_rate, audio_data, file_sample_rate = audio_utils.load_audio_and_samplerate(data, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ)
expected_sample_rate, expected_audio_data, exp_file_sample_rate = audio_utils.load_audio_and_samplerate(path, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ)
assert expected_sample_rate == sample_rate
assert exp_file_sample_rate == file_sample_rate
assert np.array_equal(audio_data, expected_audio_data)

@ -1,22 +1,26 @@
"""Test the command line interface."""
from pathlib import Path
from click.testing import CliRunner
import pandas as pd
from click.testing import CliRunner
from batdetect2.cli import cli
runner = CliRunner()
def test_cli_base_command():
"""Test the base command."""
runner = CliRunner()
result = runner.invoke(cli, ["--help"])
assert result.exit_code == 0
assert "BatDetect2 - Bat Call Detection and Classification" in result.output
assert (
"BatDetect2 - Bat Call Detection and Classification" in result.output
)
def test_cli_detect_command_help():
"""Test the detect command help."""
runner = CliRunner()
result = runner.invoke(cli, ["detect", "--help"])
assert result.exit_code == 0
assert "Detect bat calls in files in AUDIO_DIR" in result.output
@ -30,7 +34,6 @@ def test_cli_detect_command_on_test_audio(tmp_path):
if results_dir.exists():
results_dir.rmdir()
runner = CliRunner()
result = runner.invoke(
cli,
[
@ -54,7 +57,6 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
if results_dir.exists():
results_dir.rmdir()
runner = CliRunner()
result = runner.invoke(
cli,
[
@ -68,8 +70,7 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
)
assert result.exit_code == 0
assert 'Time Expansion Factor: 10' in result.stdout
assert "Time Expansion Factor: 10" in result.stdout
def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
@ -80,7 +81,6 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
if results_dir.exists():
results_dir.rmdir()
runner = CliRunner()
result = runner.invoke(
cli,
[
@ -94,13 +94,12 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
assert result.exit_code == 0
assert results_dir.exists()
csv_files = [path.name for path in results_dir.glob("*.csv")]
expected_files = [
"20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
"20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
"20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv"
"20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
]
for expected_file in expected_files:
@ -108,3 +107,52 @@ def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
df = pd.read_csv(results_dir / expected_file)
assert not (df.duration == -1).any()
def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
results_dir = tmp_path / "results"
target = tmp_path / "audio"
target.mkdir()
# Create an empty file with the .wav extension
empty_file = target / "empty.wav"
empty_file.touch()
result = runner.invoke(
cli,
args=[
"detect",
str(target),
str(results_dir),
"0.3",
"--spec_features",
],
)
assert result.exit_code == 0
assert f"Error processing file {empty_file}" in result.output
def test_can_set_chunk_size(tmp_path: Path):
results_dir = tmp_path / "results"
# Remove results dir if it exists
if results_dir.exists():
results_dir.rmdir()
result = runner.invoke(
cli,
[
"detect",
"example_data/audio",
str(results_dir),
"0.3",
"--chunk_size",
"1",
],
)
assert "Chunk Size: 1.0s" in result.output
assert result.exit_code == 0
assert results_dir.exists()
assert len(list(results_dir.glob("*.csv"))) == 3
assert len(list(results_dir.glob("*.json"))) == 3

@ -9,7 +9,7 @@ from batdetect2.cli import cli
runner = CliRunner()
def test_files_negative_dimensions_are_not_allowed(
def test_can_process_jeff37_files(
contrib_dir: Path,
tmp_path: Path,
):
@ -40,3 +40,34 @@ def test_files_negative_dimensions_are_not_allowed(
assert results_dir.exists()
assert len(list(results_dir.glob("*.csv"))) == 5
assert len(list(results_dir.glob("*.json"))) == 5
def test_can_process_padpadpadpad_files(
contrib_dir: Path,
tmp_path: Path,
):
"""This test stems from issue #29.
Batdetect2 cli failed on the files provided by the user @padpadpadpad
with the following error message:
AttributeError: module 'numpy' has no attribute 'AxisError'
This test ensures that the files are processed without any error.
"""
path = contrib_dir / "padpadpadpad"
assert path.exists()
results_dir = tmp_path / "results"
result = runner.invoke(
cli,
[
"detect",
str(path),
str(results_dir),
"0.3",
],
)
assert result.exit_code == 0
assert results_dir.exists()
assert len(list(results_dir.glob("*.csv"))) == 2
assert len(list(results_dir.glob("*.json"))) == 2

tests/test_model.py (new file, 78 lines added)

@ -0,0 +1,78 @@
"""Test suite for model functions."""
import warnings
from pathlib import Path
from typing import List
import numpy as np
from hypothesis import given, settings
from hypothesis import strategies as st
from batdetect2 import api
from batdetect2.detector import parameters
def test_can_import_model_without_warnings():
with warnings.catch_warnings():
warnings.simplefilter("error")
api.load_model()
@settings(deadline=None, max_examples=5)
@given(duration=st.floats(min_value=0.1, max_value=2))
def test_can_import_model_without_pickle(duration: float):
# NOTE: remove this test once no other issues are found. This is a temporary
# test to check that the change in model loading did not impact model behaviour
# in any way.
samplerate = parameters.TARGET_SAMPLERATE_HZ
audio = np.random.rand(int(duration * samplerate))
model_without_pickle, model_params_without_pickle = api.load_model(
weights_only=True
)
model_with_pickle, model_params_with_pickle = api.load_model(
weights_only=False
)
assert model_params_without_pickle == model_params_with_pickle
predictions_without_pickle, _, _ = api.process_audio(
audio,
model=model_without_pickle,
)
predictions_with_pickle, _, _ = api.process_audio(
audio,
model=model_with_pickle,
)
assert predictions_without_pickle == predictions_with_pickle
def test_can_import_model_without_pickle_on_test_data(
example_audio_files: List[Path],
):
# NOTE: remove this test once no other issues are found. This is a temporary
# test to check that the change in model loading did not impact model behaviour
# in any way.
model_without_pickle, model_params_without_pickle = api.load_model(
weights_only=True
)
model_with_pickle, model_params_with_pickle = api.load_model(
weights_only=False
)
assert model_params_without_pickle == model_params_with_pickle
for audio_file in example_audio_files:
audio = api.load_audio(str(audio_file))
predictions_without_pickle, _, _ = api.process_audio(
audio,
model=model_without_pickle,
)
predictions_with_pickle, _, _ = api.process_audio(
audio,
model=model_with_pickle,
)
assert predictions_without_pickle == predictions_with_pickle

uv.lock (generated, 2 lines changed)

@ -26,7 +26,7 @@ wheels = [
[[package]]
name = "batdetect2"
version = "1.0.8"
version = "1.1.0"
source = { editable = "." }
dependencies = [
{ name = "click" },