Compare commits

...

77 Commits
v1.0.4 ... main

Author SHA1 Message Date
Santiago Martinez Balvanera
4cd71497e7
Merge pull request #52 from kaviecos/http_documentation
Some checks failed
Python package / build (3.10) (push) Has been cancelled
Python package / build (3.11) (push) Has been cancelled
Python package / build (3.12) (push) Has been cancelled
Python package / build (3.9) (push) Has been cancelled
Http documentation
2025-06-03 23:38:45 +01:00
Kavi Askholm Mellerup
ba670932d5
Update README.md 2025-06-03 14:26:54 +02:00
Kavi Askholm Mellerup
d3747c57f2
Update README.md
Added section for using the API with HTTP.
2025-06-03 14:20:20 +02:00
mbsantiago
42a838e9f2 Bump version: 1.2.0 → 1.3.0
Some checks failed
Python package / build (3.10) (push) Has been cancelled
Python package / build (3.11) (push) Has been cancelled
Python package / build (3.12) (push) Has been cancelled
Python package / build (3.9) (push) Has been cancelled
2025-05-16 15:02:15 +01:00
Santiago Martinez Balvanera
c10903a646
Merge pull request #44 from kaviecos/http_support
Http support
2025-05-16 15:01:08 +01:00
Kavi
4282e2ae70 Added AudioPath as an alias for the path definition 2025-05-16 15:13:08 +02:00
Kavi
52570738f2 Renamed load_audio_data to load_audio_and_samplerate 2025-05-16 14:56:35 +02:00
Kavi
cbd362d6ea Updated docstrings for tests 2025-05-16 14:53:35 +02:00
mbsantiago
4b75e13fa2 Bump version: 1.1.1 → 1.2.0
Some checks failed
Python package / build (3.12) (push) Has been cancelled
Python package / build (3.9) (push) Has been cancelled
Python package / build (3.10) (push) Has been cancelled
Python package / build (3.11) (push) Has been cancelled
2025-03-12 18:03:43 +00:00
Santiago Martinez Balvanera
98bf506634
Merge pull request #45 from macaodha/feat/add-chunk-size-to-cli
Added the chunk_size param to the detect command
2025-03-12 18:01:48 +00:00
mbsantiago
b4c59f7de1 Added the chunk_size param to the detect command 2025-03-12 17:59:18 +00:00
Kavi
54ca555587 Fixed code to support Python3.9 syntax 2025-02-27 13:51:58 +01:00
Kavi
230b6167bc Added load_audio_data() which returns the original sample rate. Changed load_audio() implementation so that it uses load_audio_data but retains its signature. du.process_file() now does not need to call get_samplerate 2025-02-27 08:10:27 +01:00
Kavi
f62bc99ab2 Added api method to process a URL 2025-02-26 14:13:21 +01:00
Kavi
47dbdc79c2 Added tests for api and load_audio 2025-02-26 14:12:42 +01:00
Kavi
e10e270de4 Fix error in get_samplerate when reading io.BytesIO. 2025-02-26 14:12:09 +01:00
Kavi
6af7fef316 Fix 'unknown' id by providing a _generate_id() function. 2025-02-26 14:11:11 +01:00
Kavi
838a1ade0d Updated get_samplerate test to use example data file. 2025-02-25 14:46:40 +01:00
Kavi
66ac7e608f Changed the signature of api.process_file, au.load_audio and du.process_file. This allows users to use the same args for processing data as librosa.load() 2025-02-25 14:24:48 +01:00
mbsantiago
2100a3e483 Bump version: 1.1.0 → 1.1.1 2024-11-11 13:01:57 +00:00
mbsantiago
1d3cd2e305 Update lock 2024-11-11 13:01:55 +00:00
Santiago Martinez Balvanera
d5753b95bb
Merge pull request #39 from macaodha/fix/handle-empty-files-gracefully
fix: Handle Empty Audio Files Gracefully (GH-20)
2024-11-11 12:59:28 +00:00
mbsantiago
69f59ff559 Added the EOFError to the list of expected errors when processing files 2024-11-11 12:44:21 +00:00
mbsantiago
1a11174bc4 Add a test to validate that empty files are handled gracefully 2024-11-11 12:43:46 +00:00
Santiago Martinez Balvanera
c5c9476e52
Merge pull request #38 from macaodha/test/add_audio_files_provided_by_padpadpadpad_to_test_suite
test: Add failing audio files from GH-29 to contrib test suite
2024-11-11 12:23:03 +00:00
mbsantiago
270b3f212d Created test to verify no errors occurred when running on padpadpadpad recordings 2024-11-11 12:13:00 +00:00
mbsantiago
f61d1d8c72 Add audio files provided by @padpadpadpad to the contrib test files 2024-11-11 12:12:38 +00:00
Santiago Martinez Balvanera
4627ddd739
Merge pull request #37 from macaodha/fix/GH-30-torch-deprecation-warning-weights-only
fix: Address PyTorch Model Loading Deprecation Warning (GH-30)
2024-11-11 12:02:26 +00:00
mbsantiago
3477d7b5b4 Run the same test with example data instead of random audio 2024-11-11 11:57:46 +00:00
mbsantiago
394c66a2ee Added test to validate that changing model loading behaviour did not change model predictions 2024-11-11 11:46:27 +00:00
mbsantiago
d085b3212c Added weights_only argument to model loading function 2024-11-11 11:46:06 +00:00
mbsantiago
c393c5c29b Bump version: 1.0.8 → 1.1.0 2024-11-11 11:18:39 +00:00
mbsantiago
3c22ff28a7 add bump2version config 2024-11-11 11:18:37 +00:00
Santiago Martinez Balvanera
7dc28695b2
Merge pull request #36 from macaodha/fix/GH-31-negative-dimension-are-not-allowed
fix: Resolve detect Command Failure with Specific Audio Files (GH-31)
2024-11-10 22:53:45 +00:00
mbsantiago
505cca2dea Original test now passing. Issue seems to be fixed 2024-11-10 22:41:03 +00:00
mbsantiago
7906842a16 Added test to ensure pad_audio function, and utils, are working as expected 2024-11-10 22:39:30 +00:00
mbsantiago
a4b22d6590 Improve the pad_audio function
This function was the culprit of the error. Broke the function into
other helper functions to make the flow easier to follow
2024-11-10 22:39:10 +00:00
mbsantiago
25e0a53ad1 Add hypothesis to dev dependencies for easier testing 2024-11-10 22:38:13 +00:00
mbsantiago
039c002796 Remove unnecessary imports 2024-11-10 22:37:57 +00:00
mbsantiago
c97a87b2a4 Remove numba debug logging for easier debugging 2024-11-10 22:37:45 +00:00
mbsantiago
d93d8284d0 Added a test that replicates the error 2024-11-10 20:06:58 +00:00
Santiago Martinez Balvanera
697b5dbddb
Merge pull request #35 from macaodha/fix/update-python-version-support
feat: Drop Python 3.8 Support, Add Python 3.12 Support
2024-11-10 19:52:42 +00:00
mbsantiago
d5bf8f5ad8 Drop 3.9 and add 3.12 to python-version matrix in test github workflow 2024-11-10 19:46:12 +00:00
mbsantiago
fcbccbe012 Update uv lock 2024-11-10 19:45:01 +00:00
mbsantiago
4917641e2c Drop support for python 3.8 and add for 3.12 2024-11-10 19:44:57 +00:00
Santiago Martinez Balvanera
39c3918103
Merge pull request #34 from macaodha/feat/migrate-to-numpy-2
Feat/migrate to numpy 2
2024-11-10 19:17:33 +00:00
mbsantiago
1ac3808fee Remove the numpy<2 requirement from the dependencies specification 2024-11-10 19:13:44 +00:00
mbsantiago
9e0ad7fd78 address all linting errors from rule NPY201 2024-11-10 19:13:30 +00:00
mbsantiago
95bb0985e7 Add ruff rule to help migrating to numpy 2.0 2024-11-10 19:13:11 +00:00
Santiago Martinez Balvanera
cb088359ae
Merge pull request #33 from macaodha/feat/migrate-to-uv
Feat/migrate to uv
2024-11-10 19:04:47 +00:00
mbsantiago
c5030123aa Restrict pytorch version for python 3.8 compatibility 2024-11-10 18:59:35 +00:00
mbsantiago
1c1fbd8019 Added dev dependencies and updated github actions to use uv 2024-11-10 18:32:50 +00:00
mbsantiago
c65fe1c9f9 change pyproject metadata to use uv and hatch instead of pdm 2024-11-10 18:20:18 +00:00
Santiago Martinez Balvanera
d05bec880a
Merge pull request #32 from ccarrizosa/fix/np_exception
Fix numpy exception handling
2024-11-10 18:15:31 +00:00
ccarrizosa
8597ef0a1c Limit numpy versions to <2 2024-11-10 15:54:16 +01:00
ccarrizosa
2d8a7b67f8 Revert support for newest numpy versions. 2024-11-10 15:54:01 +01:00
ccarrizosa
68351d2224 Fix numpy exception handling 2024-11-09 22:35:11 +01:00
mbsantiago
3f34164028 update gitignore 2024-10-15 22:39:27 +01:00
Santiago Martinez
3744709c97 Bump version: 1.0.7 → 1.0.8 2024-01-30 12:41:32 +00:00
Santiago Martinez Balvanera
875a581044
Merge pull request #25 from macaodha/fix/update-dependencies
Add support for python 3.11. Update dependencies requirements
2024-01-30 12:40:54 +00:00
Santiago Martinez
c40197bf1c Lower scipy version to support python 3.8 2024-01-30 12:36:25 +00:00
Santiago Martinez
bbba4625be Add support for python 3.11. Update dependencies requirements 2024-01-30 12:34:47 +00:00
Santiago Martinez
963cc53fd3 Bump version: 1.0.6 → 1.0.7 2023-11-24 15:41:50 +00:00
Santiago Martinez
860e63dddf fix: implemented a cleaning step to remove detections above the nyquist limit 2023-11-24 15:40:58 +00:00
Santiago Martinez
986cfc463c tests: Added a test to check that detections above the nyquist freq are excluded 2023-11-24 15:40:37 +00:00
Oisin Mac Aodha
14aefafe14
Update README.md 2023-08-25 10:36:59 +01:00
Oisin Mac Aodha
1cef6e37e0
Update README.md 2023-08-25 10:35:23 +01:00
Santiago Martinez
9baf60ff2d Bump version: 1.0.5 → 1.0.6 2023-08-03 12:42:52 +01:00
Oisin Mac Aodha
70877495d4
Merge pull request #16 from macaodha/fix/GH-15-spectrogram-features-computation
Fix/gh 15 spectrogram features computation
2023-08-03 12:36:19 +01:00
Santiago Martinez
3288f52bbd tests: added tests for feature computation 2023-08-03 11:46:06 +01:00
Santiago Martinez
8e8779a72e fix: call interval kwargs name error 2023-08-03 11:46:06 +01:00
Santiago Martinez
36d616530a tests: Added test for using the spec_feature flag 2023-08-03 11:46:06 +01:00
Oisin Mac Aodha
14d001a71c
Merge pull request #12 from arky/minor-fixes
Minor typos fixed
2023-06-28 13:48:10 +01:00
Arky
a54d197fdf Minor typos fixed 2023-06-16 20:02:05 +07:00
Santiago Martinez
cca883a730 Bump version: 1.0.4 → 1.0.5 2023-05-11 14:06:58 +01:00
Santiago Martinez Balvanera
e5370e98db
Merge pull request #11 from macaodha/fix/GH-10-merge-results-index-error
fix: the case of no detections is now handled better
2023-05-11 14:05:57 +01:00
Santiago Martinez
04af74228b fix: the case of no detections is now handled better 2023-05-11 13:59:20 +01:00
33 changed files with 3165 additions and 1642 deletions

8
.bumpversion.cfg Normal file
View File

@ -0,0 +1,8 @@
[bumpversion]
current_version = 1.3.0
commit = True
tag = True
[bumpversion:file:batdetect2/__init__.py]
[bumpversion:file:pyproject.toml]

View File

@ -1,34 +1,29 @@
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
name: Python package name: Python package
on: on:
push: push:
branches: [ "main" ] branches: ["main"]
pull_request: pull_request:
branches: [ "main" ] branches: ["main"]
jobs: jobs:
build: build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
python-version: ["3.8", "3.9", "3.10"] python-version: ["3.9", "3.10", "3.11", "3.12"]
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} - name: Install uv
uses: actions/setup-python@v3 uses: astral-sh/setup-uv@v3
with: with:
python-version: ${{ matrix.python-version }} enable-cache: true
- name: Install dependencies cache-dependency-glob: "uv.lock"
run: | - name: Set up Python ${{ matrix.python-version }}
python -m pip install --upgrade pip run: uv python install ${{ matrix.python-version }}
python -m pip install pytest - name: Install the project
pip install . run: uv sync --all-extras --dev
- name: Test with pytest - name: Test with pytest
run: | run: uv run pytest
pytest

View File

@ -1,11 +1,3 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package name: Upload Python Package
on: on:
@ -17,23 +9,22 @@ permissions:
jobs: jobs:
deploy: deploy:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v3 uses: actions/setup-python@v3
with: with:
python-version: '3.x' python-version: "3.x"
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install build pip install build
- name: Build package - name: Build package
run: python -m build run: python -m build
- name: Publish package - name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with: with:
user: __token__ user: __token__
password: ${{ secrets.PYPI_API_TOKEN }} password: ${{ secrets.PYPI_API_TOKEN }}

13
.gitignore vendored
View File

@ -65,7 +65,7 @@ ipython_config.py
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control. # in version control.
# https://pdm.fming.dev/#use-with-ide # https://pdm.fming.dev/#use-with-ide
.pdm.toml .pdm-python
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/ __pypackages__/
@ -102,10 +102,11 @@ experiments/*
.virtual_documents .virtual_documents
.ipynb_checkpoints .ipynb_checkpoints
*.ipynb *.ipynb
# DO Include
!batdetect2_notebook.ipynb !batdetect2_notebook.ipynb
# Batdetect Models [Include]
!batdetect2/models/*.pth.tar !batdetect2/models/*.pth.tar
!tests/data/*.wav
# Bump2version !tests/data/**/*.wav
.bumpversion.cfg notebooks/lightning_logs
example_data/preprocessed

View File

@ -29,7 +29,7 @@ pip install batdetect2
``` ```
Alternatively, download this code from the repository (by clicking on the green button on top right) and unzip it. Alternatively, download this code from the repository (by clicking on the green button on top right) and unzip it.
Once unziped, run this from extracted folder. Once unzipped, run this from extracted folder.
```bash ```bash
pip install . pip install .
@ -96,9 +96,30 @@ detections, features = api.process_spectrogram(spec)
You can integrate the detections or the extracted features to your custom analysis pipeline. You can integrate the detections or the extracted features to your custom analysis pipeline.
#### Using the Python API with HTTP
```python
from batdetect2 import api
import io
import requests
AUDIO_URL = "<insert your audio url here>"
# Process a whole file from a url
results = api.process_url(AUDIO_URL)
# Or, load audio and compute spectrograms
# 'requests.get(AUDIO_URL).content' fetches the raw bytes. You are free to use other sources to fetch the raw bytes
audio = api.load_audio(io.BytesIO(requests.get(AUDIO_URL).content))
spec = api.generate_spectrogram(audio)
# And process the audio or the spectrogram with the model
detections, features, spec = api.process_audio(audio)
detections, features = api.process_spectrogram(spec)
```
## Training the model on your own data ## Training the model on your own data
Take a look at the steps outlined in fintuning readme [here](bat_detect/finetune/readme.md) for a description of how to train your own model. Take a look at the steps outlined in finetuning readme [here](batdetect2/finetune/readme.md) for a description of how to train your own model.
## Data and annotations ## Data and annotations

View File

@ -1 +1,6 @@
__version__ = '1.0.4' import logging
numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)
__version__ = "1.3.0"

View File

@ -97,8 +97,9 @@ consult the API documentation in the code.
""" """
import warnings import warnings
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, BinaryIO, Any, Union
from .types import AudioPath
import numpy as np import numpy as np
import torch import torch
@ -120,6 +121,12 @@ from batdetect2.types import (
) )
from batdetect2.utils.detector_utils import list_audio_files, load_model from batdetect2.utils.detector_utils import list_audio_files, load_model
import audioread
import os
import soundfile as sf
import requests
import io
# Remove warnings from torch # Remove warnings from torch
warnings.filterwarnings("ignore", category=UserWarning, module="torch") warnings.filterwarnings("ignore", category=UserWarning, module="torch")
@ -238,34 +245,82 @@ def generate_spectrogram(
def process_file( def process_file(
audio_file: str, path: AudioPath,
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None, config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
file_id: Optional[str] = None
) -> du.RunResults: ) -> du.RunResults:
"""Process audio file with model. """Process audio file with model.
Parameters Parameters
---------- ----------
audio_file : str path : AudioPath
Path to audio file. Path to audio data.
model : DetectionModel, optional model : DetectionModel, optional
Detection model. Uses default model if not specified. Detection model. Uses default model if not specified.
config : Optional[ProcessingConfiguration], optional config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters). Processing configuration, by default None (uses default parameters).
device : torch.device, optional device : torch.device, optional
Device to use, by default tries to use GPU if available. Device to use, by default tries to use GPU if available.
file_id: Optional[str],
Give the data an id. If path is a string path to a file this can be ignored and
the file_id will be the basename of the file.
""" """
if config is None: if config is None:
config = CONFIG config = CONFIG
return du.process_file( return du.process_file(
audio_file, path,
model, model,
config, config,
device, device,
file_id
) )
def process_url(
url: str,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE,
file_id: Optional[str] = None
) -> du.RunResults:
"""Process audio file with model.
Parameters
----------
url : str
HTTP URL to load the audio data from
model : DetectionModel, optional
Detection model. Uses default model if not specified.
config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters).
device : torch.device, optional
Device to use, by default tries to use GPU if available.
file_id: Optional[str],
Give the data an id. Defaults to the URL
"""
if config is None:
config = CONFIG
if file_id is None:
file_id = url
response = requests.get(url)
# Raise exception on HTTP error
response.raise_for_status()
# Retrieve body as raw bytes
raw_audio_data = response.content
return du.process_file(
io.BytesIO(raw_audio_data),
model,
config,
device,
file_id
)
def process_spectrogram( def process_spectrogram(
spec: torch.Tensor, spec: torch.Tensor,

View File

@ -1,4 +1,5 @@
"""BatDetect2 command line interface.""" """BatDetect2 command line interface."""
import os import os
import click import click
@ -44,6 +45,12 @@ def cli():
default=False, default=False,
help="Extracts CNN call features", help="Extracts CNN call features",
) )
@click.option(
"--chunk_size",
type=float,
default=2,
help="Specifies the duration of chunks in seconds. BatDetect2 will divide longer files into smaller chunks and process them independently. Larger chunks increase computation time and memory usage but may provide more contextual information for inference.",
)
@click.option( @click.option(
"--spec_features", "--spec_features",
is_flag=True, is_flag=True,
@ -79,6 +86,7 @@ def detect(
ann_dir: str, ann_dir: str,
detection_threshold: float, detection_threshold: float,
time_expansion_factor: int, time_expansion_factor: int,
chunk_size: float,
**args, **args,
): ):
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR. """Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
@ -107,7 +115,7 @@ def detect(
**args, **args,
"time_expansion": time_expansion_factor, "time_expansion": time_expansion_factor,
"spec_slices": False, "spec_slices": False,
"chunk_size": 2, "chunk_size": chunk_size,
"detection_threshold": detection_threshold, "detection_threshold": detection_threshold,
} }
) )
@ -129,10 +137,9 @@ def detect(
): ):
results_path = audio_file.replace(audio_dir, ann_dir) results_path = audio_file.replace(audio_dir, ann_dir)
save_results_to_file(results, results_path) save_results_to_file(results, results_path)
except (RuntimeError, ValueError, LookupError) as err: except (RuntimeError, ValueError, LookupError, EOFError) as err:
error_files.append(audio_file) error_files.append(audio_file)
click.secho(f"Error processing file!: {err}", fg="red") click.secho(f"Error processing file {audio_file}: {err}", fg="red")
raise err
click.echo(f"\nResults saved to: {ann_dir}") click.echo(f"\nResults saved to: {ann_dir}")
@ -147,6 +154,7 @@ def print_config(config: ProcessingConfiguration):
click.echo("\nProcessing Configuration:") click.echo("\nProcessing Configuration:")
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}") click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
click.echo(f"Detection Threshold: {config.get('detection_threshold')}") click.echo(f"Detection Threshold: {config.get('detection_threshold')}")
click.echo(f"Chunk Size: {config.get('chunk_size')}s")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,22 +1,27 @@
"""Functions to compute features from predictions."""
from typing import Dict, Optional
import numpy as np import numpy as np
from batdetect2 import types
from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ
def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq): def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
"""Convert spectrogram index to frequency in Hz.""" ""
spec_ind = spec_height - spec_ind spec_ind = spec_height - spec_ind
return round( return round(
(spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2 (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
) )
def extract_spec_slices(spec, pred_nms, params): def extract_spec_slices(spec, pred_nms):
""" """Extract spectrogram slices from spectrogram.
Extracts spectrogram slices from spectrogram based on detected call locations.
"""
The slices are extracted based on detected call locations.
"""
x_pos = pred_nms["x_pos"] x_pos = pred_nms["x_pos"]
y_pos = pred_nms["y_pos"]
bb_width = pred_nms["bb_width"] bb_width = pred_nms["bb_width"]
bb_height = pred_nms["bb_height"]
slices = [] slices = []
# add 20% padding either side of call # add 20% padding either side of call
@ -35,100 +40,273 @@ def extract_spec_slices(spec, pred_nms, params):
return slices return slices
def get_feature_names(): def compute_duration(
feature_names = [ prediction: types.Prediction,
"duration", **_,
"low_freq_bb", ) -> float:
"high_freq_bb", """Compute duration of call in seconds."""
"bandwidth", return round(prediction["end_time"] - prediction["start_time"], 5)
"max_power_bb",
"max_power",
"max_power_first",
"max_power_second",
"call_interval",
]
return feature_names
def get_feats(spec, pred_nms, params): def compute_low_freq(
prediction: types.Prediction,
**_,
) -> float:
"""Compute lowest frequency in call in Hz."""
return int(prediction["low_freq"])
def compute_high_freq(
prediction: types.Prediction,
**_,
) -> float:
"""Compute highest frequency in call in Hz."""
return int(prediction["high_freq"])
def compute_bandwidth(
prediction: types.Prediction,
**_,
) -> float:
"""Compute bandwidth of call in Hz."""
return int(prediction["high_freq"] - prediction["low_freq"])
def compute_max_power_bb(
prediction: types.Prediction,
spec: Optional[np.ndarray] = None,
min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ,
**_,
) -> float:
"""Compute frequency with maximum power in call in Hz.
This is the frequency with the maximum power in the bounding box of the
call.
""" """
Extracts features from spectrogram based on detected call locations. if spec is None:
Condsider re-extracting spectrogram for this to get better temporal resolution. return np.nan
x_start = max(0, prediction["x_pos"])
x_end = min(
spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
)
# y low is the lowest freq but it will have a higher value due to array
# starting at 0 at top
y_low = min(spec.shape[0] - 1, prediction["y_pos"])
y_high = max(0, prediction["y_pos"] - prediction["bb_height"])
spec_bb = spec[y_high:y_low, x_start:x_end]
power_per_freq_band = np.sum(spec_bb, axis=1)
try:
max_power_ind = np.argmax(power_per_freq_band)
except ValueError:
# If the call is too short, the bounding box might be empty.
# In this case, return NaN.
return np.nan
return int(
convert_int_to_freq(
y_high + max_power_ind,
spec.shape[0],
min_freq,
max_freq,
)
)
def compute_max_power(
prediction: types.Prediction,
spec: Optional[np.ndarray] = None,
min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ,
**_,
) -> float:
"""Compute frequency with maximum power in during the call in Hz."""
if spec is None:
return np.nan
x_start = max(0, prediction["x_pos"])
x_end = min(
spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
)
spec_call = spec[:, x_start:x_end]
power_per_freq_band = np.sum(spec_call, axis=1)
max_power_ind = np.argmax(power_per_freq_band)
return int(
convert_int_to_freq(
max_power_ind,
spec.shape[0],
min_freq,
max_freq,
)
)
def compute_max_power_first(
prediction: types.Prediction,
spec: Optional[np.ndarray] = None,
min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ,
**_,
) -> float:
"""Compute frequency with maximum power in first half of call in Hz."""
if spec is None:
return np.nan
x_start = max(0, prediction["x_pos"])
x_end = min(
spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
)
spec_call = spec[:, x_start:x_end]
first_half = spec_call[:, : int(spec_call.shape[1] / 2)]
power_per_freq_band = np.sum(first_half, axis=1)
max_power_ind = np.argmax(power_per_freq_band)
return int(
convert_int_to_freq(
max_power_ind,
spec.shape[0],
min_freq,
max_freq,
)
)
def compute_max_power_second(
prediction: types.Prediction,
spec: Optional[np.ndarray] = None,
min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ,
**_,
) -> float:
"""Compute frequency with maximum power in second half of call in Hz."""
if spec is None:
return np.nan
x_start = max(0, prediction["x_pos"])
x_end = min(
spec.shape[1] - 1, prediction["x_pos"] + prediction["bb_width"]
)
spec_call = spec[:, x_start:x_end]
second_half = spec_call[:, int(spec_call.shape[1] / 2) :]
power_per_freq_band = np.sum(second_half, axis=1)
max_power_ind = np.argmax(power_per_freq_band)
return int(
convert_int_to_freq(
max_power_ind,
spec.shape[0],
min_freq,
max_freq,
)
)
def compute_call_interval(
prediction: types.Prediction,
previous: Optional[types.Prediction] = None,
**_,
) -> float:
"""Compute time between this call and the previous call in seconds."""
if previous is None:
return np.nan
return round(prediction["start_time"] - previous["end_time"], 5)
# NOTE: The order of the features in this dictionary is important. The
# features are extracted in this order and the order of the columns in the
# output csv file is determined by this order. In order to avoid breaking
# changes in the output csv file, new features should be added to the end of
# this dictionary.
FEATURES: Dict[str, types.FeatureExtractor] = {
"duration": compute_duration,
"low_freq_bb": compute_low_freq,
"high_freq_bb": compute_high_freq,
"bandwidth": compute_bandwidth,
"max_power_bb": compute_max_power_bb,
"max_power": compute_max_power,
"max_power_first": compute_max_power_first,
"max_power_second": compute_max_power_second,
"call_interval": compute_call_interval,
}
def get_feats(
spec: np.ndarray,
pred_nms: types.PredictionResults,
params: types.FeatureExtractionParameters,
):
"""Extract features from spectrogram based on detected call locations.
The features extracted are:
- duration: duration of call in seconds
- low_freq: lowest frequency in call in kHz
- high_freq: highest frequency in call in kHz
- bandwidth: high_freq - low_freq
- max_power_bb: frequency with maximum power in call in kHz
- max_power: frequency with maximum power in spectrogram in kHz
- max_power_first: frequency with maximum power in first half of call in
kHz.
- max_power_second: frequency with maximum power in second half of call in
kHz.
- call_interval: time between this call and the previous call in seconds
Consider re-extracting spectrogram for this to get better temporal
resolution.
For more possible features check out: For more possible features check out:
https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt https://github.com/YvesBas/Tadarida-D/blob/master/Manual_Tadarida-D.odt
Parameters
----------
spec : np.ndarray
Spectrogram from which to extract features.
pred_nms : types.PredictionResults
Information about detected calls from which to extract features.
params : types.FeatureExtractionParameters
Parameters for feature extraction.
Returns
-------
features : np.ndarray
Extracted features for each detected call. Shape is
(num_detections, num_features).
""" """
x_pos = pred_nms["x_pos"]
y_pos = pred_nms["y_pos"]
bb_width = pred_nms["bb_width"]
bb_height = pred_nms["bb_height"]
feature_names = get_feature_names()
num_detections = len(pred_nms["det_probs"]) num_detections = len(pred_nms["det_probs"])
features = ( features = np.empty((num_detections, len(FEATURES)), dtype=np.float32)
np.ones((num_detections, len(feature_names)), dtype=np.float32) * -1 previous = None
)
for ff in range(num_detections): for row in range(num_detections):
x_start = int(np.maximum(0, x_pos[ff])) prediction: types.Prediction = {
x_end = int( "det_prob": float(pred_nms["det_probs"][row]),
np.minimum(spec.shape[1] - 1, np.round(x_pos[ff] + bb_width[ff])) "class_prob": pred_nms["class_probs"][:, row],
) "start_time": float(pred_nms["start_times"][row]),
# y low is the lowest freq but it will have a higher value due to array starting at 0 at top "end_time": float(pred_nms["end_times"][row]),
y_low = int(np.minimum(spec.shape[0] - 1, y_pos[ff])) "low_freq": float(pred_nms["low_freqs"][row]),
y_high = int(np.maximum(0, np.round(y_pos[ff] - bb_height[ff]))) "high_freq": float(pred_nms["high_freqs"][row]),
spec_slice = spec[:, x_start:x_end] "x_pos": int(pred_nms["x_pos"][row]),
"y_pos": int(pred_nms["y_pos"][row]),
"bb_width": int(pred_nms["bb_width"][row]),
"bb_height": int(pred_nms["bb_height"][row]),
}
if spec_slice.shape[1] > 1: for col, feature in enumerate(FEATURES.values()):
features[ff, 0] = round( features[row, col] = feature(
pred_nms["end_times"][ff] - pred_nms["start_times"][ff], 5 prediction,
) previous=previous,
features[ff, 1] = int(pred_nms["low_freqs"][ff]) spec=spec,
features[ff, 2] = int(pred_nms["high_freqs"][ff]) **params,
features[ff, 3] = int(
pred_nms["high_freqs"][ff] - pred_nms["low_freqs"][ff]
)
features[ff, 4] = int(
convert_int_to_freq(
y_high + spec_slice[y_high:y_low, :].sum(1).argmax(),
spec.shape[0],
params["min_freq"],
params["max_freq"],
)
)
features[ff, 5] = int(
convert_int_to_freq(
spec_slice.sum(1).argmax(),
spec.shape[0],
params["min_freq"],
params["max_freq"],
)
)
hlf_val = spec_slice.shape[1] // 2
features[ff, 6] = int(
convert_int_to_freq(
spec_slice[:, :hlf_val].sum(1).argmax(),
spec.shape[0],
params["min_freq"],
params["max_freq"],
)
)
features[ff, 7] = int(
convert_int_to_freq(
spec_slice[:, hlf_val:].sum(1).argmax(),
spec.shape[0],
params["min_freq"],
params["max_freq"],
)
) )
if ff > 0: previous = prediction
features[ff, 8] = round(
pred_nms["start_times"][ff]
- pred_nms["start_times"][ff - 1],
5,
)
return features return features
def get_feature_names():
"""Get names of features in the order they are extracted."""
return list(FEATURES.keys())

View File

@ -1,7 +1,5 @@
import glob import glob
import json import json
import os
import random
import numpy as np import numpy as np

View File

@ -1,5 +1,10 @@
"""Types used in the code base.""" """Types used in the code base."""
from typing import List, NamedTuple, Optional
from typing import List, NamedTuple, Optional, Union, Any, BinaryIO
import audioread
import os
import soundfile as sf
import numpy as np import numpy as np
import torch import torch
@ -17,7 +22,7 @@ except ImportError:
try: try:
from typing import NotRequired from typing import NotRequired # type: ignore
except ImportError: except ImportError:
from typing_extensions import NotRequired from typing_extensions import NotRequired
@ -25,10 +30,13 @@ except ImportError:
__all__ = [ __all__ = [
"Annotation", "Annotation",
"DetectionModel", "DetectionModel",
"FeatureExtractionParameters",
"FeatureExtractor",
"FileAnnotations", "FileAnnotations",
"ModelOutput", "ModelOutput",
"ModelParameters", "ModelParameters",
"NonMaximumSuppressionConfig", "NonMaximumSuppressionConfig",
"Prediction",
"PredictionResults", "PredictionResults",
"ProcessingConfiguration", "ProcessingConfiguration",
"ResultParams", "ResultParams",
@ -36,6 +44,9 @@ __all__ = [
"SpectrogramParameters", "SpectrogramParameters",
] ]
AudioPath = Union[
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
]
class SpectrogramParameters(TypedDict): class SpectrogramParameters(TypedDict):
"""Parameters for generating spectrograms.""" """Parameters for generating spectrograms."""
@ -312,6 +323,40 @@ class ModelOutput(NamedTuple):
"""Tensor with intermediate features.""" """Tensor with intermediate features."""
class Prediction(TypedDict):
"""Singe prediction."""
det_prob: float
"""Detection probability."""
x_pos: int
"""X position of the detection in pixels."""
y_pos: int
"""Y position of the detection in pixels."""
bb_width: int
"""Width of the detection in pixels."""
bb_height: int
"""Height of the detection in pixels."""
start_time: float
"""Start time of the detection in seconds."""
end_time: float
"""End time of the detection in seconds."""
low_freq: float
"""Low frequency of the detection in Hz."""
high_freq: float
"""High frequency of the detection in Hz."""
class_prob: np.ndarray
"""Vector holding the probability of each class."""
class PredictionResults(TypedDict): class PredictionResults(TypedDict):
"""Results of the prediction. """Results of the prediction.
@ -418,6 +463,16 @@ class NonMaximumSuppressionConfig(TypedDict):
"""Threshold for detection probability.""" """Threshold for detection probability."""
class FeatureExtractionParameters(TypedDict):
"""Parameters that control the feature extraction function."""
min_freq: int
"""Minimum frequency to consider in Hz."""
max_freq: int
"""Maximum frequency to consider in Hz."""
class HeatmapParameters(TypedDict): class HeatmapParameters(TypedDict):
"""Parameters that control the heatmap generation function.""" """Parameters that control the heatmap generation function."""
@ -473,3 +528,11 @@ class AnnotationGroup(TypedDict):
y_inds: NotRequired[np.ndarray] y_inds: NotRequired[np.ndarray]
"""Y coordinate of the annotations in the spectrogram.""" """Y coordinate of the annotations in the spectrogram."""
class FeatureExtractor(Protocol):
"""Protocol for feature extractors."""
def __call__(self, prediction: Prediction, **kwargs) -> Union[float, int]:
"""Extract features from a prediction."""
...

View File

@ -1,34 +1,67 @@
import warnings import warnings
from typing import Optional, Tuple from typing import Optional, Tuple, Union, Any, BinaryIO
from ..types import AudioPath
import librosa import librosa
import librosa.core.spectrum import librosa.core.spectrum
import numpy as np import numpy as np
import torch import torch
import audioread
import os
import soundfile as sf
from batdetect2.detector import parameters
from . import wavfile from . import wavfile
__all__ = [ __all__ = [
"load_audio", "load_audio",
"load_audio_and_samplerate",
"generate_spectrogram", "generate_spectrogram",
"pad_audio", "pad_audio",
] ]
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap): def time_to_x_coords(
nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor time_in_file: float,
noverlap = np.floor(fft_overlap * nfft) samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap) window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
) -> float:
nfft = np.floor(window_duration * samplerate) # int() uses floor
noverlap = np.floor(window_overlap * nfft)
return (time_in_file * samplerate - noverlap) / (nfft - noverlap)
# NOTE this is also defined in post_process def x_coords_to_time(
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap): x_pos: int,
nfft = np.floor(fft_win_length * sampling_rate) samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
noverlap = np.floor(fft_overlap * nfft) window_duration: float = parameters.FFT_WIN_LENGTH_S,
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate window_overlap: float = parameters.FFT_OVERLAP,
) -> float:
n_fft = np.floor(window_duration * samplerate)
n_overlap = np.floor(window_overlap * n_fft)
n_step = n_fft - n_overlap
return ((x_pos * n_step) + n_overlap) / samplerate
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
def x_coord_to_sample(
x_pos: int,
samplerate: float = parameters.TARGET_SAMPLERATE_HZ,
window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
resize_factor: float = parameters.RESIZE_FACTOR,
) -> int:
n_fft = np.floor(window_duration * samplerate)
n_overlap = np.floor(window_overlap * n_fft)
n_step = n_fft - n_overlap
x_pos = int(x_pos / resize_factor)
return int((x_pos * n_step) + n_overlap)
def generate_spectrogram( def generate_spectrogram(
audio, audio,
sampling_rate, sampling_rate,
@ -114,21 +147,20 @@ def generate_spectrogram(
return spec, spec_for_viz return spec, spec_for_viz
def load_audio( def load_audio(
audio_file: str, path: AudioPath,
time_exp_fact: float, time_exp_fact: float,
target_samp_rate: int, target_samp_rate: int,
scale: bool = False, scale: bool = False,
max_duration: Optional[float] = None, max_duration: Optional[float] = None,
) -> Tuple[int, np.ndarray]: ) -> Tuple[int, np.ndarray ]:
"""Load an audio file and resample it to the target sampling rate. """Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration. The audio is also scaled to [-1, 1] and clipped to the maximum duration.
Only mono files are supported. Only mono files are supported.
Args: Args:
audio_file (str): Path to the audio file. path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): path to the input file.
target_samp_rate (int): Target sampling rate. target_samp_rate (int): Target sampling rate.
scale (bool): Whether to scale the audio to [-1, 1]. scale (bool): Whether to scale the audio to [-1, 1].
max_duration (float): Maximum duration of the audio in seconds. max_duration (float): Maximum duration of the audio in seconds.
@ -140,20 +172,50 @@ def load_audio(
Raises: Raises:
ValueError: If the audio file is stereo. ValueError: If the audio file is stereo.
"""
sample_rate, audio_data, _ = load_audio_and_samplerate(path, time_exp_fact, target_samp_rate, scale, max_duration)
return sample_rate, audio_data
def load_audio_and_samplerate(
path: AudioPath,
time_exp_fact: float,
target_samp_rate: int,
scale: bool = False,
max_duration: Optional[float] = None,
) -> Tuple[int, np.ndarray, Union[float, int]]:
"""Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
Only mono files are supported.
Args:
path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): path to the input file.
target_samp_rate (int): Target sampling rate.
scale (bool): Whether to scale the audio to [-1, 1].
max_duration (float): Maximum duration of the audio in seconds.
Returns:
sampling_rate: The sampling rate of the audio.
audio_raw: The audio signal in a numpy array.
file_sampling_rate: The original sampling rate of the audio
Raises:
ValueError: If the audio file is stereo.
""" """
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning) warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
# sampling_rate, audio_raw = wavfile.read(audio_file) # sampling_rate, audio_raw = wavfile.read(audio_file)
audio_raw, sampling_rate = librosa.load( audio_raw, file_sampling_rate = librosa.load(
audio_file, path,
sr=None, sr=None,
dtype=np.float32, dtype=np.float32,
) )
if len(audio_raw.shape) > 1: if len(audio_raw.shape) > 1:
raise ValueError("Currently does not handle stereo files") raise ValueError("Currently does not handle stereo files")
sampling_rate = sampling_rate * time_exp_fact sampling_rate = file_sampling_rate * time_exp_fact
# resample - need to do this after correcting for time expansion # resample - need to do this after correcting for time expansion
sampling_rate_old = sampling_rate sampling_rate_old = sampling_rate
@ -181,58 +243,121 @@ def load_audio(
audio_raw = audio_raw - audio_raw.mean() audio_raw = audio_raw - audio_raw.mean()
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6) audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
return sampling_rate, audio_raw return sampling_rate, audio_raw, file_sampling_rate
def compute_spectrogram_width(
length: int,
samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
window_duration: float = parameters.FFT_WIN_LENGTH_S,
window_overlap: float = parameters.FFT_OVERLAP,
resize_factor: float = parameters.RESIZE_FACTOR,
) -> int:
n_fft = int(window_duration * samplerate)
n_overlap = int(window_overlap * n_fft)
n_step = n_fft - n_overlap
width = (length - n_overlap) // n_step
return int(width * resize_factor)
def pad_audio( def pad_audio(
audio_raw, audio: np.ndarray,
fs, samplerate: int = parameters.TARGET_SAMPLERATE_HZ,
ms, window_duration: float = parameters.FFT_WIN_LENGTH_S,
overlap_perc, window_overlap: float = parameters.FFT_OVERLAP,
resize_factor, resize_factor: float = parameters.RESIZE_FACTOR,
divide_factor, divide_factor: int = parameters.SPEC_DIVIDE_FACTOR,
fixed_width=None, fixed_width: Optional[int] = None,
): ):
# Adds zeros to the end of the raw data so that the generated sepctrogram """Pad audio to be evenly divisible by `divide_factor`.
# will be evenly divisible by `divide_factor`
# Also deals with very short audio clips and fixed_width during training
# This code could be clearer, clean up This function pads the audio signal with zeros to ensure that the
nfft = int(ms * fs) generated spectrogram length will be evenly divisible by `divide_factor`.
noverlap = int(overlap_perc * nfft) This is important for the model to work correctly.
step = nfft - noverlap
min_size = int(divide_factor * (1.0 / resize_factor))
spec_width = (audio_raw.shape[0] - noverlap) // step
spec_width_rs = spec_width * resize_factor
if fixed_width is not None and spec_width < fixed_width: This `divide_factor` comes from the model architecture as it downscales
# too small the spectrogram by this factor, so the input must be divisible by this
# used during training to ensure all the batches are the same size integer number.
diff = fixed_width * step + noverlap - audio_raw.shape[0]
audio_raw = np.hstack( Parameters
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype)) ----------
audio : np.ndarray
The audio signal.
samplerate : int
The sampling rate of the audio signal.
window_size : float
The window size in seconds used for the spectrogram computation.
window_overlap : float
The overlap between windows in the spectrogram computation.
resize_factor : float
This factor is used to resize the spectrogram after the STFT
computation. Default is 0.5 which means that the spectrogram will be
reduced by half. Important to take into account for the final size of
the spectrogram.
divide_factor : int
The factor by which the spectrogram will be divided.
fixed_width : int, optional
If provided, the audio will be padded or cut so that the resulting
spectrogram width will be equal to this value.
Returns
-------
np.ndarray
The padded audio signal.
"""
spec_width = compute_spectrogram_width(
audio.shape[0],
samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
resize_factor=resize_factor,
)
if fixed_width:
target_samples = x_coord_to_sample(
fixed_width,
samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
resize_factor=resize_factor,
) )
elif fixed_width is not None and spec_width > fixed_width: if spec_width < fixed_width:
# too big # need to be at least min_size
# used during training to ensure all the batches are the same size diff = target_samples - audio.shape[0]
diff = fixed_width * step + noverlap - audio_raw.shape[0] return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
audio_raw = audio_raw[:diff]
elif ( if spec_width > fixed_width:
spec_width_rs < min_size return audio[:target_samples]
or (np.floor(spec_width_rs) % divide_factor) != 0
): return audio
# need to be at least min_size
div_amt = np.ceil(spec_width_rs / float(divide_factor)) min_width = int(divide_factor / resize_factor)
div_amt = np.maximum(1, div_amt)
target_size = int(div_amt * divide_factor * (1.0 / resize_factor)) if spec_width < min_width:
diff = target_size * step + noverlap - audio_raw.shape[0] target_samples = x_coord_to_sample(
audio_raw = np.hstack( min_width,
(audio_raw, np.zeros(diff, dtype=audio_raw.dtype)) samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
resize_factor=resize_factor,
) )
diff = target_samples - audio.shape[0]
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
return audio_raw if (spec_width % divide_factor) == 0:
return audio
target_width = int(np.ceil(spec_width / divide_factor)) * divide_factor
target_samples = x_coord_to_sample(
target_width,
samplerate=samplerate,
window_duration=window_duration,
window_overlap=window_overlap,
resize_factor=resize_factor,
)
diff = target_samples - audio.shape[0]
return np.hstack((audio, np.zeros(diff, dtype=audio.dtype)))
def gen_mag_spectrogram(x, fs, ms, overlap_perc): def gen_mag_spectrogram(x, fs, ms, overlap_perc):
@ -247,7 +372,11 @@ def gen_mag_spectrogram(x, fs, ms, overlap_perc):
# compute spec # compute spec
spec, _ = librosa.core.spectrum._spectrogram( spec, _ = librosa.core.spectrum._spectrogram(
y=x, power=1, n_fft=nfft, hop_length=step, center=False y=x,
power=1,
n_fft=nfft,
hop_length=step,
center=False,
) )
# remove DC component and flip vertical orientation # remove DC component and flip vertical orientation

View File

@ -1,12 +1,19 @@
import json import json
import os import os
from typing import Any, Iterator, List, Optional, Tuple, Union from typing import Any, Iterator, List, Optional, Tuple, Union, BinaryIO
from ..types import AudioPath
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
try:
from numpy.exceptions import AxisError
except ImportError:
from numpy import AxisError # type: ignore
import batdetect2.detector.compute_features as feats import batdetect2.detector.compute_features as feats
import batdetect2.detector.post_process as pp import batdetect2.detector.post_process as pp
import batdetect2.utils.audio_utils as au import batdetect2.utils.audio_utils as au
@ -25,6 +32,13 @@ from batdetect2.types import (
SpectrogramParameters, SpectrogramParameters,
) )
import audioread
import os
import io
import soundfile as sf
import hashlib
import uuid
__all__ = [ __all__ = [
"load_model", "load_model",
"list_audio_files", "list_audio_files",
@ -66,7 +80,6 @@ def list_audio_files(ip_dir: str) -> List[str]:
Raises: Raises:
FileNotFoundError: Input directory not found. FileNotFoundError: Input directory not found.
""" """
matches = [] matches = []
for root, _, filenames in os.walk(ip_dir): for root, _, filenames in os.walk(ip_dir):
@ -80,6 +93,7 @@ def load_model(
model_path: str = DEFAULT_MODEL_PATH, model_path: str = DEFAULT_MODEL_PATH,
load_weights: bool = True, load_weights: bool = True,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
weights_only: bool = True,
) -> Tuple[DetectionModel, ModelParameters]: ) -> Tuple[DetectionModel, ModelParameters]:
"""Load model from file. """Load model from file.
@ -100,7 +114,11 @@ def load_model(
if not os.path.isfile(model_path): if not os.path.isfile(model_path):
raise FileNotFoundError("Model file not found.") raise FileNotFoundError("Model file not found.")
net_params = torch.load(model_path, map_location=device) net_params = torch.load(
model_path,
map_location=device,
weights_only=weights_only,
)
params = net_params["params"] params = net_params["params"]
@ -143,7 +161,19 @@ def load_model(
def _merge_results(predictions, spec_feats, cnn_feats, spec_slices): def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
predictions_m = {} predictions_m = {
"det_probs": np.array([]),
"x_pos": np.array([]),
"y_pos": np.array([]),
"bb_widths": np.array([]),
"bb_heights": np.array([]),
"start_times": np.array([]),
"end_times": np.array([]),
"low_freqs": np.array([]),
"high_freqs": np.array([]),
"class_probs": np.array([]),
}
num_preds = np.sum([len(pp["det_probs"]) for pp in predictions]) num_preds = np.sum([len(pp["det_probs"]) for pp in predictions])
if num_preds > 0: if num_preds > 0:
@ -151,10 +181,6 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
predictions_m[key] = np.hstack( predictions_m[key] = np.hstack(
[pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0] [pp[key] for pp in predictions if pp["det_probs"].shape[0] > 0]
) )
else:
# hack in case where no detected calls as we need some of the key
# names in dict
predictions_m = predictions[0]
if len(spec_feats) > 0: if len(spec_feats) > 0:
spec_feats = np.vstack(spec_feats) spec_feats = np.vstack(spec_feats)
@ -226,11 +252,19 @@ def format_single_result(
Returns: Returns:
dict: Results in the format expected by the annotation tool. dict: Results in the format expected by the annotation tool.
""" """
# Get a single class prediction for the file try:
class_overall = pp.overall_class_pred( # Get a single class prediction for the file
predictions["det_probs"], class_overall = pp.overall_class_pred(
predictions["class_probs"], predictions["det_probs"],
) predictions["class_probs"],
)
class_name = class_names[np.argmax(class_overall)]
annotations = get_annotations_from_preds(predictions, class_names)
except (AxisError, ValueError):
# No detections
class_overall = np.zeros(len(class_names))
class_name = "None"
annotations = []
return { return {
"id": file_id, "id": file_id,
@ -239,8 +273,8 @@ def format_single_result(
"notes": "Automatically generated.", "notes": "Automatically generated.",
"time_exp": time_exp, "time_exp": time_exp,
"duration": round(float(duration), 4), "duration": round(float(duration), 4),
"annotation": get_annotations_from_preds(predictions, class_names), "annotation": annotations,
"class_name": class_names[np.argmax(class_overall)], "class_name": class_name,
} }
@ -253,6 +287,7 @@ def convert_results(
spec_feats, spec_feats,
cnn_feats, cnn_feats,
spec_slices, spec_slices,
nyquist_freq: Optional[float] = None,
) -> RunResults: ) -> RunResults:
"""Convert results to dictionary as expected by the annotation tool. """Convert results to dictionary as expected by the annotation tool.
@ -268,8 +303,8 @@ def convert_results(
Returns: Returns:
dict: Dictionary with results. dict: Dictionary with results.
""" """
pred_dict = format_single_result( pred_dict = format_single_result(
file_id, file_id,
time_exp, time_exp,
@ -278,6 +313,14 @@ def convert_results(
params["class_names"], params["class_names"],
) )
# Remove high frequency detections
if nyquist_freq is not None:
pred_dict["annotation"] = [
pred
for pred in pred_dict["annotation"]
if pred["high_freq"] <= nyquist_freq
]
# combine into final results dictionary # combine into final results dictionary
results: RunResults = { results: RunResults = {
"pred_dict": pred_dict, "pred_dict": pred_dict,
@ -310,7 +353,6 @@ def save_results_to_file(results, op_path: str) -> None:
Args: Args:
results (dict): Results. results (dict): Results.
op_path (str): Output path. op_path (str): Output path.
""" """
# make directory if it does not exist # make directory if it does not exist
if not os.path.isdir(os.path.dirname(op_path)): if not os.path.isdir(os.path.dirname(op_path)):
@ -472,7 +514,6 @@ def iterate_over_chunks(
chunk_start : float chunk_start : float
Start time of chunk in seconds. Start time of chunk in seconds.
chunk : np.ndarray chunk : np.ndarray
""" """
nsamples = audio.shape[0] nsamples = audio.shape[0]
duration_full = nsamples / samplerate duration_full = nsamples / samplerate
@ -678,7 +719,6 @@ def process_audio_array(
        The array is of shape (num_detections, num_features).
    spec : torch.Tensor
        Spectrogram of the audio used as input.
    """
    pred_nms, features, spec = _process_audio_array(
        audio,
@@ -697,10 +737,11 @@ def process_audio_array(
def process_file(
-    audio_file: str,
+    path: AudioPath,
    model: DetectionModel,
    config: ProcessingConfiguration,
    device: torch.device,
+    file_id: Optional[str] = None,
) -> Union[RunResults, Any]:
    """Process a single audio file with detection model.
@@ -709,7 +750,7 @@ def process_file(
    Parameters
    ----------
-    audio_file : str
+    path : AudioPath
        Path to audio file.
    model : torch.nn.Module
@@ -717,6 +758,9 @@ def process_file(
    config : ProcessingConfiguration
        Configuration for processing.
+    file_id : Optional[str]
+        Give the data an id. Defaults to the basename if path is a string;
+        otherwise an md5 hash is calculated from the binary data.
    Returns
    -------
@@ -731,14 +775,16 @@ def process_file(
    spec_slices = []
    # load audio file
-    sampling_rate, audio_full = au.load_audio(
-        audio_file,
+    sampling_rate, audio_full, file_samp_rate = au.load_audio_and_samplerate(
+        path,
        time_exp_fact=config.get("time_expansion", 1) or 1,
        target_samp_rate=config["target_samp_rate"],
        scale=config["scale_raw_audio"],
        max_duration=config.get("max_duration"),
    )
+    orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
    # loop through larger file and split into chunks
    # TODO: fix so that it overlaps correctly and takes care of
    # duplicate detections at borders
@@ -757,7 +803,7 @@ def process_file(
    )
    # convert to numpy
-    spec_np = spec.detach().cpu().numpy()
+    spec_np = spec.detach().cpu().numpy().squeeze()
    # add chunk time to start and end times
    pred_nms["start_times"] += chunk_time
@@ -777,9 +823,7 @@ def process_file(
    if config["spec_slices"]:
        # FIX: This is not currently working. Returns empty slices
-        spec_slices.extend(
-            feats.extract_spec_slices(spec_np, pred_nms, config)
-        )
+        spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))
    # Merge results from chunks
    predictions, spec_feats, cnn_feats, spec_slices = _merge_results(
@@ -789,9 +833,13 @@ def process_file(
        spec_slices,
    )
+    _file_id = file_id
+    if _file_id is None:
+        _file_id = _generate_id(path)
    # convert results to a dictionary in the right format
    results = convert_results(
-        file_id=os.path.basename(audio_file),
+        file_id=_file_id,
        time_exp=config.get("time_expansion", 1) or 1,
        duration=audio_full.shape[0] / float(sampling_rate),
        params=config,
@@ -799,6 +847,7 @@ def process_file(
        spec_feats=spec_feats,
        cnn_feats=cnn_feats,
        spec_slices=spec_slices,
+        nyquist_freq=orig_samp_rate / 2,
    )
    # summarize results
@@ -810,6 +859,22 @@ def process_file(
    return results
+def _generate_id(path: AudioPath) -> str:
+    """Generate an id based on the path.
+
+    If the path is a str or PathLike, its basename is used. This should
+    ensure backwards compatibility with previous versions.
+    """
+    if isinstance(path, str) or isinstance(path, os.PathLike):
+        return os.path.basename(path)
+    elif isinstance(path, (BinaryIO, io.BytesIO)):
+        path.seek(0)
+        md5 = hashlib.md5(path.read()).hexdigest()
+        path.seek(0)
+        return md5
+    else:
+        return str(uuid.uuid4())
def summarize_results(results, predictions, config):
    """Print summary of results."""

pdm.lock (generated, 1337 lines; diff suppressed because it is too large)

pyproject.toml
@@ -1,82 +1,82 @@
-[tool.pdm]
-[tool.pdm.dev-dependencies]
-dev = [
-    "pytest>=7.2.2",
-]
[project]
name = "batdetect2"
-version = "1.0.4"
+version = "1.3.0"
description = "Deep learning model for detecting and classifying bat echolocation calls in high frequency audio recordings."
authors = [
    { "name" = "Oisin Mac Aodha", "email" = "oisin.macaodha@ed.ac.uk" },
-    { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" }
+    { "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
]
dependencies = [
-    "librosa",
-    "matplotlib",
-    "numpy",
-    "pandas",
-    "scikit-learn",
-    "scipy",
-    "torch>=1.13.1,<2",
-    "torchaudio",
-    "torchvision",
-    "click",
+    "click>=8.1.7",
+    "librosa>=0.10.1",
+    "matplotlib>=3.7.1",
+    "numpy>=1.23.5",
+    "pandas>=1.5.3",
+    "scikit-learn>=1.2.2",
+    "scipy>=1.10.1",
+    "torch>=1.13.1,<2.5.0",
+    "torchaudio>=1.13.1,<2.5.0",
+    "torchvision>=0.14.0",
]
-requires-python = ">=3.8,<3.11"
+requires-python = ">=3.9,<3.13"
readme = "README.md"
license = { text = "CC-by-nc-4" }
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Science/Research",
    "Natural Language :: English",
    "Operating System :: OS Independent",
-    "Programming Language :: Python :: 3.8",
-    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.9",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Software Development :: Libraries :: Python Modules",
    "Topic :: Multimedia :: Sound/Audio :: Analysis",
]
keywords = [
    "bat",
    "echolocation",
    "deep learning",
    "audio",
    "machine learning",
    "classification",
    "detection",
]
[build-system]
-requires = ["pdm-pep517>=1.0.0"]
-build-backend = "pdm.pep517.api"
+requires = ["hatchling"]
+build-backend = "hatchling.build"
[project.scripts]
batdetect2 = "batdetect2.cli:cli"
-[tool.black]
-line-length = 80
-[[tool.mypy.overrides]]
-module = [
-    "librosa",
-    "pandas",
-]
-ignore_missing_imports = true
-[tool.pylsp-mypy]
-enabled = false
-live_mode = true
-strict = true
-[tool.pydocstyle]
-convention = "numpy"
+[tool.uv]
+dev-dependencies = [
+    "debugpy>=1.8.8",
+    "hypothesis>=6.118.7",
+    "pyright>=1.1.388",
+    "pytest>=7.2.2",
+    "ruff>=0.7.3",
+]
+[tool.ruff]
+line-length = 79
+target-version = "py39"
+[tool.ruff.format]
+docstring-code-format = true
+docstring-code-line-length = 79
+[tool.ruff.lint]
+select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
+[tool.ruff.lint.pydocstyle]
+convention = "numpy"
[tool.pyright]
-include = [
-    "bat_detect",
-    "tests",
-]
+include = ["batdetect2", "tests"]
venvPath = "."
venv = ".venv"
+pythonVersion = "3.9"
+pythonPlatform = "All"
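A quick, illustrative sanity check of the new packaging metadata from an installed environment, using only the standard library:

from importlib import metadata

# Both values come straight from the [project] table above.
print(metadata.version("batdetect2"))                      # "1.3.0"
print(metadata.metadata("batdetect2")["Requires-Python"])  # ">=3.9,<3.13"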

tests/conftest.py (new file)
@@ -0,0 +1,40 @@
from pathlib import Path
from typing import List
import pytest
@pytest.fixture
def example_data_dir() -> Path:
pkg_dir = Path(__file__).parent.parent
example_data_dir = pkg_dir / "example_data"
assert example_data_dir.exists()
return example_data_dir
@pytest.fixture
def example_audio_dir(example_data_dir: Path) -> Path:
example_audio_dir = example_data_dir / "audio"
assert example_audio_dir.exists()
return example_audio_dir
@pytest.fixture
def example_audio_files(example_audio_dir: Path) -> List[Path]:
audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
assert len(audio_files) == 3
return audio_files
@pytest.fixture
def data_dir() -> Path:
dir = Path(__file__).parent / "data"
assert dir.exists()
return dir
@pytest.fixture
def contrib_dir(data_dir) -> Path:
dir = data_dir / "contrib"
assert dir.exists()
return dir
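An illustrative example (not part of the suite) of how pytest injects these fixtures by parameter name:

from pathlib import Path
from typing import List

def test_example_audio_files_are_wav(example_audio_files: List[Path]):
    # example_audio_files comes from the fixture defined above.
    for path in example_audio_files:
        assert path.suffix.lower() == ".wav"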

(8 binary files not shown)
tests/test_api.py
@@ -2,17 +2,21 @@
import os
from glob import glob
+from pathlib import Path
import numpy as np
+import soundfile as sf
import torch
from torch import nn
from batdetect2 import api
+import io
PKG_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TEST_DATA_DIR = os.path.join(PKG_DIR, "example_data", "audio")
TEST_DATA = glob(os.path.join(TEST_DATA_DIR, "*.wav"))
+DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
def test_load_model_with_default_params():
    """Test loading model with default parameters."""
@@ -262,3 +266,44 @@ def test_process_file_with_spec_slices():
    assert "spec_slices" in results
    assert isinstance(results["spec_slices"], list)
    assert len(results["spec_slices"]) == len(detections)
+def test_process_file_with_empty_predictions_does_not_fail(
+    tmp_path: Path,
+):
+    """Test process file with empty predictions does not fail."""
+    # Create empty file
+    empty_file = tmp_path / "empty.wav"
+    empty_wav = np.zeros((0, 1), dtype=np.float32)
+    sf.write(empty_file, empty_wav, 256000)
+    # Process file
+    results = api.process_file(str(empty_file))
+    assert results is not None
+    assert len(results["pred_dict"]["annotation"]) == 0
+def test_process_file_file_id_defaults_to_basename():
+    """Test that process_file assigns basename as an id if no file_id is provided."""
+    # Recording donated by @kdarras
+    basename = "20230322_172000_selec2.wav"
+    path = os.path.join(DATA_DIR, basename)
+    output = api.process_file(path)
+    predictions = output["pred_dict"]
+    id = predictions["id"]
+    assert id == basename
+def test_bytesio_file_id_defaults_to_md5():
+    """Test that process_file assigns an md5 sum as an id if no file_id is provided when using binary data."""
+    # Recording donated by @kdarras
+    basename = "20230322_172000_selec2.wav"
+    path = os.path.join(DATA_DIR, basename)
+    with open(path, "rb") as f:
+        data = io.BytesIO(f.read())
+    output = api.process_file(data)
+    predictions = output["pred_dict"]
+    id = predictions["id"]
+    assert id == "7ade9ebf1a9fe5477ff3a2dc57001929"
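The expected digest in the final assertion can be reproduced with hashlib alone, mirroring _generate_id above:

import hashlib
import os

path = os.path.join("tests", "data", "20230322_172000_selec2.wav")
with open(path, "rb") as f:
    digest = hashlib.md5(f.read()).hexdigest()
# Matches the id assigned to BytesIO inputs by _generate_id.
assert digest == "7ade9ebf1a9fe5477ff3a2dc57001929"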

tests/test_audio_utils.py (new file)
@@ -0,0 +1,156 @@
import numpy as np
import torch
import torch.nn.functional as F
from hypothesis import given
from hypothesis import strategies as st
from batdetect2.detector import parameters
from batdetect2.utils import audio_utils, detector_utils
import io
import os
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
@given(duration=st.floats(min_value=0.1, max_value=2))
def test_can_compute_correct_spectrogram_width(duration: float):
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
length = int(duration * samplerate)
audio = np.random.rand(length)
spectrogram, _ = audio_utils.generate_spectrogram(
audio,
samplerate,
params,
)
# convert to pytorch
spectrogram = torch.from_numpy(spectrogram)
# add batch and channel dimensions
spectrogram = spectrogram.unsqueeze(0).unsqueeze(0)
# resize the spec
resize_factor = params["resize_factor"]
spec_op_shape = (
int(params["spec_height"] * resize_factor),
int(spectrogram.shape[-1] * resize_factor),
)
spectrogram = F.interpolate(
spectrogram,
size=spec_op_shape,
mode="bilinear",
align_corners=False,
)
expected_width = audio_utils.compute_spectrogram_width(
length,
samplerate=parameters.TARGET_SAMPLERATE_HZ,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
)
assert spectrogram.shape[-1] == expected_width
@given(duration=st.floats(min_value=0.1, max_value=2))
def test_pad_audio_without_fixed_size(duration: float):
# Test the pad_audio function
# This function is used to pad audio with zeros to a specific length
# It is used in the generate_spectrogram function
# The function is tested with a simple random audio signal
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
length = int(duration * samplerate)
audio = np.random.rand(length)
# pad the audio to be divisible by divide factor
padded_audio = audio_utils.pad_audio(
audio,
samplerate=samplerate,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
divide_factor=params["spec_divide_factor"],
)
# check that the padded audio is divisible by the divide factor
expected_width = audio_utils.compute_spectrogram_width(
len(padded_audio),
samplerate=parameters.TARGET_SAMPLERATE_HZ,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
)
assert expected_width % params["spec_divide_factor"] == 0
@given(duration=st.floats(min_value=0.1, max_value=2))
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
duration: float,
):
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
length = int(duration * samplerate)
audio = np.random.rand(length)
_, spectrogram, _ = detector_utils.compute_spectrogram(
audio,
samplerate,
params,
torch.device("cpu"),
)
assert spectrogram.shape[-1] % params["spec_divide_factor"] == 0
@given(
duration=st.floats(min_value=0.1, max_value=2),
width=st.integers(min_value=128, max_value=1024),
)
def test_pad_audio_with_fixed_width(duration: float, width: int):
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
length = int(duration * samplerate)
audio = np.random.rand(length)
# pad the audio to be divisible by divide factor
padded_audio = audio_utils.pad_audio(
audio,
samplerate=samplerate,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
divide_factor=params["spec_divide_factor"],
fixed_width=width,
)
# check that the padded audio is divisible by the divide factor
expected_width = audio_utils.compute_spectrogram_width(
len(padded_audio),
samplerate=parameters.TARGET_SAMPLERATE_HZ,
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
resize_factor=params["resize_factor"],
)
assert expected_width == width
def test_load_audio_using_bytesio():
basename = "20230322_172000_selec2.wav"
path = os.path.join(DATA_DIR, basename)
with open(path, "rb") as f:
data = io.BytesIO(f.read())
sample_rate, audio_data, file_sample_rate = audio_utils.load_audio_and_samplerate(data, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ)
expected_sample_rate, expected_audio_data, exp_file_sample_rate = audio_utils.load_audio_and_samplerate(path, time_exp_fact=1, target_samp_rate=parameters.TARGET_SAMPLERATE_HZ)
assert expected_sample_rate == sample_rate
assert exp_file_sample_rate == file_sample_rate
assert np.array_equal(audio_data, expected_audio_data)
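A sketch of the three-value return and of recovering the original recording rate for time-expanded files (cf. orig_samp_rate in process_file above; the expansion factor of 10 here is purely illustrative):

from batdetect2.detector import parameters
from batdetect2.utils import audio_utils

samp_rate, audio, file_samp_rate = audio_utils.load_audio_and_samplerate(
    "tests/data/20230322_172000_selec2.wav",
    time_exp_fact=10,  # pretend the file was recorded with 10x time expansion
    target_samp_rate=parameters.TARGET_SAMPLERATE_HZ,
)
# file_samp_rate is the rate stored in the file header; the original
# recording rate and its Nyquist limit follow from the expansion factor.
orig_samp_rate = file_samp_rate * 10
nyquist_freq = orig_samp_rate / 2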

tests/test_cli.py
@@ -1,20 +1,26 @@
"""Test the command line interface.""" """Test the command line interface."""
from pathlib import Path
import pandas as pd
from click.testing import CliRunner from click.testing import CliRunner
from batdetect2.cli import cli from batdetect2.cli import cli
runner = CliRunner()
def test_cli_base_command(): def test_cli_base_command():
"""Test the base command.""" """Test the base command."""
runner = CliRunner()
result = runner.invoke(cli, ["--help"]) result = runner.invoke(cli, ["--help"])
assert result.exit_code == 0 assert result.exit_code == 0
assert "BatDetect2 - Bat Call Detection and Classification" in result.output assert (
"BatDetect2 - Bat Call Detection and Classification" in result.output
)
def test_cli_detect_command_help(): def test_cli_detect_command_help():
"""Test the detect command help.""" """Test the detect command help."""
runner = CliRunner()
result = runner.invoke(cli, ["detect", "--help"]) result = runner.invoke(cli, ["detect", "--help"])
assert result.exit_code == 0 assert result.exit_code == 0
assert "Detect bat calls in files in AUDIO_DIR" in result.output assert "Detect bat calls in files in AUDIO_DIR" in result.output
@ -28,7 +34,6 @@ def test_cli_detect_command_on_test_audio(tmp_path):
if results_dir.exists(): if results_dir.exists():
results_dir.rmdir() results_dir.rmdir()
runner = CliRunner()
result = runner.invoke( result = runner.invoke(
cli, cli,
[ [
@ -52,7 +57,6 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
if results_dir.exists(): if results_dir.exists():
results_dir.rmdir() results_dir.rmdir()
runner = CliRunner()
result = runner.invoke( result = runner.invoke(
cli, cli,
[ [
@ -66,4 +70,89 @@ def test_cli_detect_command_with_non_trivial_time_expansion(tmp_path):
) )
assert result.exit_code == 0 assert result.exit_code == 0
assert 'Time Expansion Factor: 10' in result.stdout assert "Time Expansion Factor: 10" in result.stdout
def test_cli_detect_command_with_the_spec_feature_flag(tmp_path: Path):
"""Test the detect command with the spec feature flag."""
results_dir = tmp_path / "results"
# Remove results dir if it exists
if results_dir.exists():
results_dir.rmdir()
result = runner.invoke(
cli,
[
"detect",
"example_data/audio",
str(results_dir),
"0.3",
"--spec_features",
],
)
assert result.exit_code == 0
assert results_dir.exists()
csv_files = [path.name for path in results_dir.glob("*.csv")]
expected_files = [
"20170701_213954-MYOMYS-LR_0_0.5.wav_spec_features.csv",
"20180530_213516-EPTSER-LR_0_0.5.wav_spec_features.csv",
"20180627_215323-RHIFER-LR_0_0.5.wav_spec_features.csv",
]
for expected_file in expected_files:
assert expected_file in csv_files
df = pd.read_csv(results_dir / expected_file)
assert not (df.duration == -1).any()
def test_cli_detect_fails_gracefully_on_empty_file(tmp_path: Path):
results_dir = tmp_path / "results"
target = tmp_path / "audio"
target.mkdir()
# Create an empty file with the .wav extension
empty_file = target / "empty.wav"
empty_file.touch()
result = runner.invoke(
cli,
args=[
"detect",
str(target),
str(results_dir),
"0.3",
"--spec_features",
],
)
assert result.exit_code == 0
assert f"Error processing file {empty_file}" in result.output
def test_can_set_chunk_size(tmp_path: Path):
results_dir = tmp_path / "results"
# Remove results dir if it exists
if results_dir.exists():
results_dir.rmdir()
result = runner.invoke(
cli,
[
"detect",
"example_data/audio",
str(results_dir),
"0.3",
"--chunk_size",
"1",
],
)
assert "Chunk Size: 1.0s" in result.output
assert result.exit_code == 0
assert results_dir.exists()
assert len(list(results_dir.glob("*.csv"))) == 3
assert len(list(results_dir.glob("*.json"))) == 3
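The same detect invocation can be driven outside the test runner through the installed console script (a sketch; assumes the package is installed so the batdetect2 entry point from [project.scripts] is on PATH, and that example_data is present):

import subprocess

subprocess.run(
    [
        "batdetect2", "detect",
        "example_data/audio", "results", "0.3",
        "--chunk_size", "1",   # seconds per processing chunk
        "--spec_features",     # also write *_spec_features.csv files
    ],
    check=True,
)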

tests/test_contrib.py (new file)
@@ -0,0 +1,73 @@
"""Test suite to ensure user provided files are correctly processed."""
from pathlib import Path
from click.testing import CliRunner
from batdetect2.cli import cli
runner = CliRunner()
def test_can_process_jeff37_files(
contrib_dir: Path,
tmp_path: Path,
):
"""This test stems from issue #31.
A user provided a set of files on which the batdetect2 cli failed,
generating the following error message:
[2272] "Error processing file!: negative dimensions are not allowed"
This test ensures that the error message is not generated when running
batdetect2 cli with the same set of files.
"""
path = contrib_dir / "jeff37"
assert path.exists()
results_dir = tmp_path / "results"
result = runner.invoke(
cli,
[
"detect",
str(path),
str(results_dir),
"0.3",
],
)
assert result.exit_code == 0
assert results_dir.exists()
assert len(list(results_dir.glob("*.csv"))) == 5
assert len(list(results_dir.glob("*.json"))) == 5
def test_can_process_padpadpadpad_files(
contrib_dir: Path,
tmp_path: Path,
):
"""This test stems from issue #29.
Batdetect2 cli failed on the files provided by the user @padpadpadpad
with the following error message:
AttributeError: module 'numpy' has no attribute 'AxisError'
This test ensures that the files are processed without any error.
"""
path = contrib_dir / "padpadpadpad"
assert path.exists()
results_dir = tmp_path / "results"
result = runner.invoke(
cli,
[
"detect",
str(path),
str(results_dir),
"0.3",
],
)
assert result.exit_code == 0
assert results_dir.exists()
assert len(list(results_dir.glob("*.csv"))) == 2
assert len(list(results_dir.glob("*.json"))) == 2

tests/test_detections.py (new file)
@@ -0,0 +1,23 @@
"""Test suite to ensure that model detections are not incorrect."""
import os
from batdetect2 import api
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
def test_no_detections_above_nyquist():
"""Test that no detections are made above the nyquist frequency."""
# Recording donated by @@kdarras
path = os.path.join(DATA_DIR, "20230322_172000_selec2.wav")
# This recording has a sampling rate of 192 kHz
nyquist = 192_000 / 2
output = api.process_file(path)
predictions = output["pred_dict"]
assert len(predictions["annotation"]) != 0
assert all(
pred["high_freq"] < nyquist for pred in predictions["annotation"]
)

tests/test_features.py (new file)
@@ -0,0 +1,291 @@
"""Test suite for feature extraction functions."""
import logging
import librosa
import numpy as np
import pytest
import batdetect2.detector.compute_features as feats
from batdetect2 import api, types
from batdetect2.utils import audio_utils as au
numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)
def index_to_freq(
index: int,
spec_height: int,
min_freq: int,
max_freq: int,
) -> float:
"""Convert spectrogram index to frequency in Hz."""
index = spec_height - index
return round(
(index / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
)
def index_to_time(
index: int,
spec_width: int,
spec_duration: float,
) -> float:
"""Convert spectrogram index to time in seconds."""
return round((index / float(spec_width)) * spec_duration, 2)
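# A worked example of the two helpers above, using the values from the test
# below (a 100x100 spectrogram spanning 3 s and 10-120 kHz):
#   index_to_freq(80, 100, 10_000, 120_000)
#     = ((100 - 80) / 100) * (120_000 - 10_000) + 10_000 = 32_000.0 Hz
#   index_to_time(20, 100, 3) = (20 / 100) * 3 = 0.6 s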
def test_get_feats_function_with_empty_spectrogram():
"""Test get_feats function with empty spectrogram.
This tests that the overall flow of the function works, even if the
spectrogram is empty.
"""
spec_duration = 3
spec_width = 100
spec_height = 100
min_freq = 10_000
max_freq = 120_000
spectrogram = np.zeros((spec_height, spec_width))
x_pos = 20
y_pos = 80
bb_width = 20
bb_height = 20
start_time = index_to_time(x_pos, spec_width, spec_duration)
end_time = index_to_time(x_pos + bb_width, spec_width, spec_duration)
low_freq = index_to_freq(y_pos, spec_height, min_freq, max_freq)
high_freq = index_to_freq(
y_pos - bb_height, spec_height, min_freq, max_freq
)
pred_nms: types.PredictionResults = {
"det_probs": np.array([1]),
"class_probs": np.array([[1]]),
"x_pos": np.array([x_pos]),
"y_pos": np.array([y_pos]),
"bb_width": np.array([bb_width]),
"bb_height": np.array([bb_height]),
"start_times": np.array([start_time]),
"end_times": np.array([end_time]),
"low_freqs": np.array([low_freq]),
"high_freqs": np.array([high_freq]),
}
params: types.FeatureExtractionParameters = {
"min_freq": min_freq,
"max_freq": max_freq,
}
features = feats.get_feats(spectrogram, pred_nms, params)
assert low_freq < high_freq
assert isinstance(features, np.ndarray)
assert features.shape == (len(pred_nms["det_probs"]), 9)
assert np.isclose(
features[0],
np.array(
[
end_time - start_time,
low_freq,
high_freq,
high_freq - low_freq,
high_freq,
max_freq,
max_freq,
max_freq,
np.nan,
]
),
equal_nan=True,
).all()
@pytest.mark.parametrize(
"max_power",
[
30_000,
31_000,
32_000,
33_000,
34_000,
35_000,
36_000,
37_000,
38_000,
39_000,
40_000,
],
)
def test_compute_max_power_bb(max_power: int):
"""Test compute_max_power_bb function."""
duration = 1
samplerate = 256_000
min_freq = 0
max_freq = 128_000
start_time = 0.3
end_time = 0.6
low_freq = 30_000
high_freq = 40_000
audio = np.zeros((int(duration * samplerate),))
# Add a signal during the time and frequency range of interest
audio[
int(start_time * samplerate) : int(end_time * samplerate)
] = 0.5 * librosa.tone(
max_power, sr=samplerate, duration=end_time - start_time
)
# Add a more powerful signal outside frequency range of interest
audio[
int(start_time * samplerate) : int(end_time * samplerate)
] += 2 * librosa.tone(
80_000, sr=samplerate, duration=end_time - start_time
)
params = api.get_config(
min_freq=min_freq,
max_freq=max_freq,
target_samp_rate=samplerate,
)
spec, _ = au.generate_spectrogram(
audio,
samplerate,
params,
)
x_start = int(
au.time_to_x_coords(
start_time,
samplerate,
params["fft_win_length"],
params["fft_overlap"],
)
)
x_end = int(
au.time_to_x_coords(
end_time,
samplerate,
params["fft_win_length"],
params["fft_overlap"],
)
)
num_freq_bins = spec.shape[0]
y_low = num_freq_bins - int(num_freq_bins * low_freq / max_freq)
y_high = num_freq_bins - int(num_freq_bins * high_freq / max_freq)
prediction: types.Prediction = {
"det_prob": 1,
"class_prob": np.ones((1,)),
"x_pos": x_start,
"y_pos": int(y_low),
"bb_width": int(x_end - x_start),
"bb_height": int(y_low - y_high),
"start_time": start_time,
"end_time": end_time,
"low_freq": low_freq,
"high_freq": high_freq,
}
print(prediction)
max_power_bb = feats.compute_max_power_bb(
prediction,
spec,
min_freq=min_freq,
max_freq=max_freq,
)
assert abs(max_power_bb - max_power) <= 500
def test_compute_max_power():
"""Test compute_max_power_bb function."""
duration = 3
samplerate = 16_000
min_freq = 0
max_freq = 8_000
start_time = 1
end_time = 2
low_freq = 3_000
high_freq = 4_000
max_power = 5_000
audio = np.zeros((int(duration * samplerate),))
# Add a signal during the time and frequency range of interest
audio[
int(start_time * samplerate) : int(end_time * samplerate)
] = 0.5 * librosa.tone(
3_500, sr=samplerate, duration=end_time - start_time
)
# Add a more powerful signal outside frequency range of interest
audio[
int(start_time * samplerate) : int(end_time * samplerate)
] += 2 * librosa.tone(
max_power, sr=samplerate, duration=end_time - start_time
)
params = api.get_config(
min_freq=min_freq,
max_freq=max_freq,
target_samp_rate=samplerate,
)
spec, _ = au.generate_spectrogram(
audio,
samplerate,
params,
)
x_start = int(
au.time_to_x_coords(
start_time,
samplerate,
params["fft_win_length"],
params["fft_overlap"],
)
)
x_end = int(
au.time_to_x_coords(
end_time,
samplerate,
params["fft_win_length"],
params["fft_overlap"],
)
)
num_freq_bins = spec.shape[0]
y_low = int(num_freq_bins * low_freq / max_freq)
y_high = int(num_freq_bins * high_freq / max_freq)
prediction: types.Prediction = {
"det_prob": 1,
"class_prob": np.ones((1,)),
"x_pos": x_start,
"y_pos": int(y_high),
"bb_width": int(x_end - x_start),
"bb_height": int(y_high - y_low),
"start_time": start_time,
"end_time": end_time,
"low_freq": low_freq,
"high_freq": high_freq,
}
computed_max_power = feats.compute_max_power(
prediction,
spec,
min_freq=min_freq,
max_freq=max_freq,
)
assert abs(computed_max_power - max_power) < 100

tests/test_model.py (new file)
@@ -0,0 +1,78 @@
"""Test suite for model functions."""
import warnings
from pathlib import Path
from typing import List
import numpy as np
from hypothesis import given, settings
from hypothesis import strategies as st
from batdetect2 import api
from batdetect2.detector import parameters
def test_can_import_model_without_warnings():
with warnings.catch_warnings():
warnings.simplefilter("error")
api.load_model()
@settings(deadline=None, max_examples=5)
@given(duration=st.floats(min_value=0.1, max_value=2))
def test_can_import_model_without_pickle(duration: float):
# NOTE: remove this test once no other issues are found. This is a temporary
# test to check that the change in model loading did not impact model
# behaviour in any way.
samplerate = parameters.TARGET_SAMPLERATE_HZ
audio = np.random.rand(int(duration * samplerate))
model_without_pickle, model_params_without_pickle = api.load_model(
weights_only=True
)
model_with_pickle, model_params_with_pickle = api.load_model(
weights_only=False
)
assert model_params_without_pickle == model_params_with_pickle
predictions_without_pickle, _, _ = api.process_audio(
audio,
model=model_without_pickle,
)
predictions_with_pickle, _, _ = api.process_audio(
audio,
model=model_with_pickle,
)
assert predictions_without_pickle == predictions_with_pickle
def test_can_import_model_without_pickle_on_test_data(
example_audio_files: List[Path],
):
# NOTE: remove this test once no other issues are found. This is a temporary
# test to check that the change in model loading did not impact model
# behaviour in any way.
model_without_pickle, model_params_without_pickle = api.load_model(
weights_only=True
)
model_with_pickle, model_params_with_pickle = api.load_model(
weights_only=False
)
assert model_params_without_pickle == model_params_with_pickle
for audio_file in example_audio_files:
audio = api.load_audio(str(audio_file))
predictions_without_pickle, _, _ = api.process_audio(
audio,
model=model_without_pickle,
)
predictions_with_pickle, _, _ = api.process_audio(
audio,
model=model_with_pickle,
)
assert predictions_without_pickle == predictions_with_pickle
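A minimal sketch of the two loading modes these tests compare; weights_only=True corresponds to torch.load's weights-only mode, which avoids unpickling arbitrary objects:

from batdetect2 import api

# Safer path: load tensors only, no arbitrary unpickling.
model, params = api.load_model(weights_only=True)

# Legacy path, kept for comparison while the change beds in.
legacy_model, legacy_params = api.load_model(weights_only=False)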

uv.lock (new generated file, 1548 lines; diff suppressed because it is too large)