Compare commits

..

No commits in common. "0adb1bbea7d1f908db65a1b036385ecc286f187e" and "4ae567bc1db0efcda3680fe3e8523805831dbc80" have entirely different histories.

363 changed files with 6981 additions and 43034 deletions

View File

@ -1,10 +1,8 @@
[bumpversion] [bumpversion]
current_version = 2.0.0b1 current_version = 1.3.1
commit = True commit = True
tag = True tag = True
[bumpversion:file:src/batdetect2/__init__.py] [bumpversion:file:batdetect2/__init__.py]
[bumpversion:file:pyproject.toml] [bumpversion:file:pyproject.toml]
[bumpversion:file:docs/source/conf.py]

View File

@ -1,79 +0,0 @@
name: CI
on:
pull_request:
push:
branches:
- main
concurrency:
group: ci-${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
checks:
name: Checks
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install just
uses: taiki-e/install-action@just
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
cache-dependency-glob: |
pyproject.toml
uv.lock
- name: Install dependencies
run: uv sync --all-extras --all-groups
- name: Run formatting, lint, and type checks
run: just check
tests:
name: Tests (Python ${{ matrix.python-version }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version:
- "3.10"
- "3.11"
- "3.12"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install just
uses: taiki-e/install-action@just
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
cache-dependency-glob: |
pyproject.toml
uv.lock
- name: Install dependencies
run: uv sync --all-extras --all-groups
- name: Run test suite
run: just test

View File

@ -1,69 +0,0 @@
name: Docs Pages
on:
push:
branches:
- main
workflow_dispatch:
permissions:
contents: read
concurrency:
group: docs-pages
cancel-in-progress: true
jobs:
build:
name: Build Docs
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install just
uses: taiki-e/install-action@just
- name: Configure GitHub Pages
uses: actions/configure-pages@v5
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
cache-dependency-glob: |
pyproject.toml
uv.lock
- name: Install dependencies
run: uv sync --all-extras --all-groups
- name: Build docs
run: just check-docs
- name: Upload Pages artifact
uses: actions/upload-pages-artifact@v4
with:
path: docs/build
deploy:
name: Deploy Docs
needs: build
runs-on: ubuntu-latest
permissions:
pages: write
id-token: write
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4

View File

@ -1,70 +0,0 @@
name: Publish PyPI
on:
release:
types:
- published
permissions:
contents: read
concurrency:
group: publish-pypi
cancel-in-progress: false
jobs:
build:
name: Build Distributions
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install just
uses: taiki-e/install-action@just
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
cache-dependency-glob: |
pyproject.toml
uv.lock
- name: Install dependencies
run: uv sync --all-extras --all-groups
- name: Build distributions
run: just build-dist
- name: Upload distributions
uses: actions/upload-artifact@v4
with:
name: release-dists
path: dist/
publish:
name: Publish to PyPI
needs: build
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/batdetect2
steps:
- name: Download distributions
uses: actions/download-artifact@v5
with:
name: release-dists
path: dist/
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}

29
.github/workflows/python-package.yml vendored Normal file
View File

@ -0,0 +1,29 @@
name: Python package
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
enable-cache: true
cache-dependency-glob: "uv.lock"
- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}
- name: Install the project
run: uv sync --all-extras --dev
- name: Test with pytest
run: uv run pytest

30
.github/workflows/python-publish.yml vendored Normal file
View File

@ -0,0 +1,30 @@
name: Upload Python Package
on:
release:
types: [published]
permissions:
contents: read
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: "3.x"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

28
.gitignore vendored
View File

@ -50,7 +50,6 @@ cover/
# Sphinx documentation # Sphinx documentation
docs/_build/ docs/_build/
docs/build/
# PyBuilder # PyBuilder
.pybuilder/ .pybuilder/
@ -96,15 +95,8 @@ dmypy.json
*.json *.json
plots/* plots/*
!example_data/anns/*.json
# Model experiments # Model experiments
experiments/* experiments/*
DvcLiveLogger/checkpoints
logs/
mlruns/
/outputs/
notebooks/lightning_logs
# Jupiter notebooks # Jupiter notebooks
.virtual_documents .virtual_documents
@ -113,24 +105,8 @@ notebooks/lightning_logs
# DO Include # DO Include
!batdetect2_notebook.ipynb !batdetect2_notebook.ipynb
!src/batdetect2/models/checkpoints/*.pth.tar !batdetect2/models/*.pth.tar
!tests/data/*.wav !tests/data/*.wav
!notebooks/*.ipynb
!tests/data/**/*.wav !tests/data/**/*.wav
.aider* notebooks/lightning_logs
# Intermediate artifacts
example_data/preprocessed example_data/preprocessed
# Dev notebooks
notebooks/tmp
/tmp
/.agents/skills
/notebooks
/AGENTS.md
/scripts
/todo.md
# Assets
!assets/*
/models

5
.pylintrc Normal file
View File

@ -0,0 +1,5 @@
[TYPECHECK]
# List of members which are set dynamically and missed by Pylint inference
# system, and so shouldn't trigger E1101 when accessed.
generated-members=torch.*

237
README.md
View File

@ -1,166 +1,161 @@
# BatDetect2 # BatDetect2
<img style="display: block-inline;" width="64" height="64" src="ims/bat_icon.png"> Code for detecting and classifying bat echolocation calls in high frequency audio recordings.
<img style="display:block-inline;" width="64" height="64" src="assets/bat_icon.png">
Code for detecting and classifying bat echolocation calls in high-frequency
audio recordings.
> [!WARNING]
> `batdetect2` 2.0.0b1 is out.
> This is a beta release and we are gathering user feedback.
> If you run into issues or have feedback on the new workflows, please use the
> GitHub issues page to let us know.
>
> There are many changes and new recommended workflows.
> We have left the previous `batdetect2.api` module intact, but if you run
> into issues or want to upgrade, see the
> [migration guide](docs/source/legacy/migration-guide.md) in the docs site.
>
> This update also ships with a refreshed default model.
> It was trained in the same way and on the same data as before, but you should
> still expect small output differences in some cases.
## What is BatDetect2
BatDetect2 is a deep learning model for detecting and classifying bat
echolocation calls.
The model generates multiple predictions for each input recording by providing a
bounding box and predicted class for each individual call within it.
This repository also holds `batdetect2`, a Python-based tool to run, train,
finetune and evaluate BatDetect2-type models, including the built-in model for
detecting UK bat species.
You can use the tool from the command line (terminal) or from Python as needed.
## Getting Started
We have [extensive documentation](docs/source/index.md) on how to use
`batdetect2`.
The docs site is still being built and will be live soon.
If you want a quick peek for now, see the `docs/` folder in this repository.
See our [getting started](docs/source/getting_started.md) guide and then jump
into any of our tutorials:
- Run the model on a folder of recordings:
`docs/source/tutorials/run-inference-on-folder.md`
- Train your own model:
`docs/source/tutorials/train-a-custom-model.md`
- Evaluate your model:
`docs/source/tutorials/evaluate-on-a-test-set.md`
- Fine-tune a model:
`docs/source/tutorials/integrate-with-a-python-pipeline.md`
### Try the model
If you want to try the model for UK bat species without installing anything, you
can try the following:
1. Demo of the model (for UK species) on
[huggingface](https://huggingface.co/spaces/macaodha/batdetect2).
2. Alternatively, click
[here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb)
to run the model using Google Colab.
You can also run this notebook locally.
### Installing BatDetect2
> [!NOTE] > [!NOTE]
> `2.0.0b1` is a pre-release on PyPI. > Were actively working to make it easier to train and fine-tune BatDetect2 models using custom data. A major update is coming soon to the main branch—stay tuned! In the meantime, you can follow our progress in the train branch.
> You may need to request it explicitly by version, for example:
>
> ```bash
> uvx --from batdetect2==2.0.0b1 batdetect2
> uv tool install batdetect2==2.0.0b1
> pip install batdetect2==2.0.0b1
> ```
If you have `uv` installed (if not, we recommend it; follow the instructions ## Getting started
[here](https://docs.astral.sh/uv/getting-started/installation/)), then you can ### Python Environment
run `batdetect2` one-off with
We recommend using an isolated Python environment to avoid dependency issues. Choose one
of the following options:
* Install the Anaconda Python 3.10 distribution for your operating system from [here](https://www.continuum.io/downloads). Create a new environment and activate it:
```bash ```bash
uvx batdetect2 conda create -y --name batdetect2 python==3.10
conda activate batdetect2
``` ```
or if you want to install it permanently: * If you already have Python installed (version >= 3.8,< 3.11) and prefer using virtual environments then:
```bash ```bash
uv tool install batdetect2 python -m venv .venv
source .venv/bin/activate
``` ```
and test it with ### Installing BatDetect2
You can use pip to install `batdetect2`:
```bash ```bash
batdetect2 pip install batdetect2
``` ```
### Run BatDetect2 on a folder of recordings Alternatively, download this code from the repository (by clicking on the green button on top right) and unzip it.
Once unzipped, run this from extracted folder.
Once installed, you can run BatDetect2 on a folder of `.wav` files.
By default it will use the model trained on UK data.
Example command:
```bash ```bash
batdetect2 process directory example_data/audio outputs pip install .
``` ```
This will scan the audio files in `example_data/audio` and save model outputs to Make sure you have the environment activated before installing `batdetect2`.
`outputs`.
If you have your own model checkpoint, you can use it:
## Try the model
1) You can try a demo of the model (for UK species) on [huggingface](https://huggingface.co/spaces/macaodha/batdetect2).
2) Alternatively, click [here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb) to run the model using Google Colab. You can also run this notebook locally.
## Running the model on your own data
After following the above steps to install the code you can run the model on your own data.
### Using the command line
You can run the model by opening the command line and typing:
```bash ```bash
batdetect2 process directory --model path/to/checkpoint.ckpt example_data/audio outputs batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD
```
e.g.
```bash
batdetect2 detect example_data/audio/ example_data/anns/ 0.3
``` ```
For the full walkthrough, use `AUDIO_DIR` is the path on your computer to the audio wav files of interest.
`docs/source/tutorials/run-inference-on-folder.md`. `ANN_DIR` is the path on your computer where the model predictions will be saved. The model will output both `.csv` and `.json` results for each audio file.
`DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistakes.
There are also optional arguments, e.g. you can request that the model outputs features (i.e. estimated call parameters) such as duration, max_frequency, etc. by setting the flag `--spec_features`. These will be saved as `*_spec_features.csv` files:
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --spec_features`
You can also specify which model to use by setting the `--model_path` argument. If not specified, it will default to using a model trained on UK data e.g.
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --model_path models/Net2DFast_UK_same.pth.tar`
### Using the Python API
If you prefer to process your data within a Python script then you can use the `batdetect2` Python API.
```python
from batdetect2 import api
AUDIO_FILE = "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav"
# Process a whole file
results = api.process_file(AUDIO_FILE)
# Or, load audio and compute spectrograms
audio = api.load_audio(AUDIO_FILE)
spec = api.generate_spectrogram(audio)
# And process the audio or the spectrogram with the model
detections, features, spec = api.process_audio(audio)
detections, features = api.process_spectrogram(spec)
# Do something else ...
```
You can integrate the detections or the extracted features to your custom analysis pipeline.
#### Using the Python API with HTTP
```python
from batdetect2 import api
import io
import requests
AUDIO_URL = "<insert your audio url here>"
# Process a whole file from a url
results = api.process_url(AUDIO_URL)
# Or, load audio and compute spectrograms
# 'requests.get(AUDIO_URL).content' fetches the raw bytes. You are free to use other sources to fetch the raw bytes
audio = api.load_audio(io.BytesIO(requests.get(AUDIO_URL).content))
spec = api.generate_spectrogram(audio)
# And process the audio or the spectrogram with the model
detections, features, spec = api.process_audio(audio)
detections, features = api.process_spectrogram(spec)
```
## Training the model on your own data
Take a look at the steps outlined in finetuning readme [here](batdetect2/finetune/readme.md) for a description of how to train your own model.
## Data and annotations ## Data and annotations
The raw audio data and annotations used to train the models in the paper will be added soon.
The audio interface used to annotate audio data for training and evaluation is available [here](https://github.com/macaodha/batdetect2_GUI).
The raw audio data and annotations used to train the models in the paper will be
added soon.
`batdetect2` supports annotations in various formats and is compatible with the
outputs of [`whombat`](https://github.com/mbsantiago/whombat/) and this
[earlier version](https://github.com/macaodha/batdetect2_GUI).
If you're interested in supporting another format, please reach out or submit a
PR.
## Warning ## Warning
The models developed and shared as part of this repository should be used with caution.
While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment.
Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted.
The models developed and shared as part of this repository should be used with
caution.
While they have been evaluated on held-out audio data, great care should be
taken when using the model outputs for any form of biodiversity assessment.
Your data may differ, and as a result it is very strongly recommended that you
validate the model first using data with known species to ensure that the
outputs can be trusted.
If you train a model, make the best effort to be transparent about its training
and evaluation data, and inform downstream users about its limitations.
## FAQ ## FAQ
For more information please consult our [FAQ](faq.md).
For more information please consult our [FAQ](docs/source/faq.md).
## Reference ## Reference
If you find our work useful in your research please consider citing our paper which you can find [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1):
If you find our work useful in your research, please consider citing our paper,
which you can find
[here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1):
``` ```
@article{batdetect2_2022, @article{batdetect2_2022,
title = {Towards a General Approach for Bat Echolocation Detection and Classification}, title = {Towards a General Approach for Bat Echolocation Detection and Classification},
author = {Mac Aodha, Oisin and Mart\'{i}nez Balvanera, Santiago and Damstra, Elise and Cooke, Martyn and Eichinski, Philip and Browning, Ella and Barataud, Michel and Boughey, Katherine and Coles, Roger and Giacomini, Giada and MacSwiney G., M. Cristina and K. Obrist, Martin and Parsons, Stuart and Sattler, Thomas and Jones, Kate E.}, author = {Mac Aodha, Oisin and Mart\'{i}nez Balvanera, Santiago and Damstra, Elise and Cooke, Martyn and Eichinski, Philip and Browning, Ella and Barataudm, Michel and Boughey, Katherine and Coles, Roger and Giacomini, Giada and MacSwiney G., M. Cristina and K. Obrist, Martin and Parsons, Stuart and Sattler, Thomas and Jones, Kate E.},
journal = {bioRxiv}, journal = {bioRxiv},
year = {2022} year = {2022}
} }
``` ```
## Acknowledgements ## Acknowledgements
Thanks to all the contributors who spent time collecting and annotating audio data.
Thanks to all the contributors who spent time collecting and annotating audio
data. ### TODOs
- [x] Release the code and pretrained model
- [ ] Release the datasets and annotations used the experiments in the paper
- [ ] Add the scripts used to generate the tables and figures from the paper

6
batdetect2/__init__.py Normal file
View File

@ -0,0 +1,6 @@
import logging
numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)
__version__ = "1.3.1"

View File

@ -32,7 +32,7 @@ results will be combined into a dictionary with the following keys:
for each detection. The CNN features are the output of the CNN before for each detection. The CNN features are the output of the CNN before
the final classification layer. You can use these features to train the final classification layer. You can use these features to train
your own classifier, or to do other processing on the detections. your own classifier, or to do other processing on the detections.
They are in the same order as the detections in They are in the same order as the detections in
`results['pred_dict']['annotation']`. Will only be returned if the `results['pred_dict']['annotation']`. Will only be returned if the
`cnn_feats` parameter in the config is set to `True`. `cnn_feats` parameter in the config is set to `True`.
- `spec_slices`: Optional. A list of `numpy` arrays containing the spectrogram - `spec_slices`: Optional. A list of `numpy` arrays containing the spectrogram
@ -96,9 +96,10 @@ If you wish to use a custom model or change the default parameters, please
consult the API documentation in the code. consult the API documentation in the code.
""" """
import warnings import warnings
from typing import List, Optional, Tuple, BinaryIO, Any, Union
from .types import AudioPath
import numpy as np import numpy as np
import torch import torch
@ -120,6 +121,12 @@ from batdetect2.types import (
) )
from batdetect2.utils.detector_utils import list_audio_files, load_model from batdetect2.utils.detector_utils import list_audio_files, load_model
import audioread
import os
import soundfile as sf
import requests
import io
# Remove warnings from torch # Remove warnings from torch
warnings.filterwarnings("ignore", category=UserWarning, module="torch") warnings.filterwarnings("ignore", category=UserWarning, module="torch")
@ -164,7 +171,7 @@ def load_audio(
time_exp_fact: float = 1, time_exp_fact: float = 1,
target_samp_rate: int = TARGET_SAMPLERATE_HZ, target_samp_rate: int = TARGET_SAMPLERATE_HZ,
scale: bool = False, scale: bool = False,
max_duration: float | None = None, max_duration: Optional[float] = None,
) -> np.ndarray: ) -> np.ndarray:
"""Load audio from file. """Load audio from file.
@ -202,7 +209,7 @@ def load_audio(
def generate_spectrogram( def generate_spectrogram(
audio: np.ndarray, audio: np.ndarray,
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
config: SpectrogramParameters | None = None, config: Optional[SpectrogramParameters] = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
) -> torch.Tensor: ) -> torch.Tensor:
"""Generate spectrogram from audio array. """Generate spectrogram from audio array.
@ -226,10 +233,11 @@ def generate_spectrogram(
if config is None: if config is None:
config = DEFAULT_SPECTROGRAM_PARAMETERS config = DEFAULT_SPECTROGRAM_PARAMETERS
_, spec = du.compute_spectrogram( _, spec, _ = du.compute_spectrogram(
audio, audio,
samp_rate, samp_rate,
config, config,
return_np=False,
device=device, device=device,
) )
@ -237,41 +245,89 @@ def generate_spectrogram(
def process_file( def process_file(
audio_file: str, path: AudioPath,
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: ProcessingConfiguration | None = None, config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
file_id: Optional[str] = None
) -> du.RunResults: ) -> du.RunResults:
"""Process audio file with model. """Process audio file with model.
Parameters Parameters
---------- ----------
audio_file : str path : AudioPath
Path to audio file. Path to audio data.
model : DetectionModel, optional model : DetectionModel, optional
Detection model. Uses default model if not specified. Detection model. Uses default model if not specified.
config : Optional[ProcessingConfiguration], optional config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters). Processing configuration, by default None (uses default parameters).
device : torch.device, optional device : torch.device, optional
Device to use, by default tries to use GPU if available. Device to use, by default tries to use GPU if available.
file_id: Optional[str],
Give the data an id. If path is a string path to a file this can be ignored and
the file_id will be the basename of the file.
""" """
if config is None: if config is None:
config = CONFIG config = CONFIG
return du.process_file( return du.process_file(
audio_file, path,
model, model,
config, config,
device, device,
file_id
) )
def process_url(
url: str,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE,
file_id: Optional[str] = None
) -> du.RunResults:
"""Process audio file with model.
Parameters
----------
url : str
HTTP URL to load the audio data from
model : DetectionModel, optional
Detection model. Uses default model if not specified.
config : Optional[ProcessingConfiguration], optional
Processing configuration, by default None (uses default parameters).
device : torch.device, optional
Device to use, by default tries to use GPU if available.
file_id: Optional[str],
Give the data an id. Defaults to the URL
"""
if config is None:
config = CONFIG
if file_id is None:
file_id = url
response = requests.get(url)
# Raise exception on HTTP error
response.raise_for_status()
# Retrieve body as raw bytes
raw_audio_data = response.content
return du.process_file(
io.BytesIO(raw_audio_data),
model,
config,
device,
file_id
)
def process_spectrogram( def process_spectrogram(
spec: torch.Tensor, spec: torch.Tensor,
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: ProcessingConfiguration | None = None, config: Optional[ProcessingConfiguration] = None,
) -> tuple[list[Annotation], np.ndarray]: ) -> Tuple[List[Annotation], np.ndarray]:
"""Process spectrogram with model. """Process spectrogram with model.
Parameters Parameters
@ -311,9 +367,9 @@ def process_audio(
audio: np.ndarray, audio: np.ndarray,
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: ProcessingConfiguration | None = None, config: Optional[ProcessingConfiguration] = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
) -> tuple[list[Annotation], np.ndarray, torch.Tensor]: ) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
"""Process audio array with model. """Process audio array with model.
Parameters Parameters
@ -355,8 +411,8 @@ def process_audio(
def postprocess( def postprocess(
outputs: ModelOutput, outputs: ModelOutput,
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
config: ProcessingConfiguration | None = None, config: Optional[ProcessingConfiguration] = None,
) -> tuple[list[Annotation], np.ndarray]: ) -> Tuple[List[Annotation], np.ndarray]:
"""Postprocess model outputs. """Postprocess model outputs.
Convert model tensor outputs to predicted bounding boxes and Convert model tensor outputs to predicted bounding boxes and
@ -410,9 +466,7 @@ def print_summary(results: RunResults) -> None:
Detection result. Detection result.
""" """
print("Results for " + results["pred_dict"]["id"]) print("Results for " + results["pred_dict"]["id"])
print( print("{} calls detected\n".format(len(results["pred_dict"]["annotation"])))
"{} calls detected\n".format(len(results["pred_dict"]["annotation"]))
)
print("time\tprob\tlfreq\tspecies_name") print("time\tprob\tlfreq\tspecies_name")
for ann in results["pred_dict"]["annotation"]: for ann in results["pred_dict"]["annotation"]:

View File

@ -1,24 +1,32 @@
"""BatDetect2 command line interface."""
import os import os
import click import click
from batdetect2.cli.base import cli from batdetect2 import api
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import ProcessingConfiguration
from batdetect2.utils.detector_utils import save_results_to_file
DEFAULT_MODEL_PATH = os.path.join( CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
os.path.dirname(os.path.dirname(__file__)),
"models",
"checkpoints",
"Net2DFast_UK_same.pth.tar",
)
@cli.command( INFO_STR = """
short_help="Legacy detection command.", BatDetect2 - Detection and Classification
epilog=( Assumes audio files are mono, not stereo.
"Deprecated workflow. Prefer `batdetect2 process directory` for " Spaces in the input paths will throw an error. Wrap in quotes.
"new analyses." Input files should be short in duration e.g. < 30 seconds.
), """
)
@click.group()
def cli():
"""BatDetect2 - Bat Call Detection and Classification."""
click.echo(INFO_STR)
@cli.command()
@click.argument( @click.argument(
"audio_dir", "audio_dir",
type=click.Path(exists=True), type=click.Path(exists=True),
@ -37,6 +45,12 @@ DEFAULT_MODEL_PATH = os.path.join(
default=False, default=False,
help="Extracts CNN call features", help="Extracts CNN call features",
) )
@click.option(
"--chunk_size",
type=float,
default=2,
help="Specifies the duration of chunks in seconds. BatDetect2 will divide longer files into smaller chunks and process them independently. Larger chunks increase computation time and memory usage but may provide more contextual information for inference.",
)
@click.option( @click.option(
"--spec_features", "--spec_features",
is_flag=True, is_flag=True,
@ -72,12 +86,10 @@ def detect(
ann_dir: str, ann_dir: str,
detection_threshold: float, detection_threshold: float,
time_expansion_factor: int, time_expansion_factor: int,
chunk_size: float,
**args, **args,
): ):
"""Legacy detection command for directory-based inference. """Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
Detect bat calls in files in `AUDIO_DIR` and save predictions to
`ANN_DIR`.
DETECTION_THRESHOLD is the detection threshold. All predictions with a DETECTION_THRESHOLD is the detection threshold. All predictions with a
score below this threshold will be discarded. Values between 0 and 1. score below this threshold will be discarded. Values between 0 and 1.
@ -87,21 +99,7 @@ def detect(
Spaces in the input paths will throw an error. Wrap in quotes. Spaces in the input paths will throw an error. Wrap in quotes.
Input files should be short in duration e.g. < 30 seconds. Input files should be short in duration e.g. < 30 seconds.
Note
----
This command is kept for backwards compatibility. Prefer
`batdetect2 process directory` for new workflows.
""" """
from batdetect2 import api
from batdetect2.utils.detector_utils import save_results_to_file
message = (
"The `batdetect2 detect` command is deprecated. Prefer "
"`batdetect2 process directory` for new analyses."
)
click.secho(f"WARNING: {message}", fg="yellow", err=True)
click.echo(f"Loading model: {args['model_path']}") click.echo(f"Loading model: {args['model_path']}")
model, params = api.load_model(args["model_path"]) model, params = api.load_model(args["model_path"])
@ -117,7 +115,7 @@ def detect(
**args, **args,
"time_expansion": time_expansion_factor, "time_expansion": time_expansion_factor,
"spec_slices": False, "spec_slices": False,
"chunk_size": 2, "chunk_size": chunk_size,
"detection_threshold": detection_threshold, "detection_threshold": detection_threshold,
} }
) )
@ -151,8 +149,13 @@ def detect(
click.echo(f" {err}") click.echo(f" {err}")
def print_config(config): def print_config(config: ProcessingConfiguration):
"""Print the processing configuration values.""" """Print the processing configuration."""
click.echo("\nProcessing Configuration:") click.echo("\nProcessing Configuration:")
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}") click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
click.echo(f"Detection Threshold: {config.get('detection_threshold')}") click.echo(f"Detection Threshold: {config.get('detection_threshold')}")
click.echo(f"Chunk Size: {config.get('chunk_size')}s")
if __name__ == "__main__":
cli()

View File

@ -1,6 +1,5 @@
"""Functions to compute features from predictions.""" """Functions to compute features from predictions."""
from typing import Dict, Optional
from typing import Dict, List
import numpy as np import numpy as np
@ -8,26 +7,15 @@ from batdetect2 import types
from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ
def convert_int_to_freq( def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
spec_ind: int,
spec_height: int,
min_freq: float,
max_freq: float,
) -> int:
"""Convert spectrogram index to frequency in Hz.""" "" """Convert spectrogram index to frequency in Hz.""" ""
spec_ind = spec_height - spec_ind spec_ind = spec_height - spec_ind
return int( return round(
round( (spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
(spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq,
2,
)
) )
def extract_spec_slices( def extract_spec_slices(spec, pred_nms):
spec: np.ndarray,
pred_nms: types.PredictionResults,
) -> List[np.ndarray]:
"""Extract spectrogram slices from spectrogram. """Extract spectrogram slices from spectrogram.
The slices are extracted based on detected call locations. The slices are extracted based on detected call locations.
@ -86,7 +74,7 @@ def compute_bandwidth(
def compute_max_power_bb( def compute_max_power_bb(
prediction: types.Prediction, prediction: types.Prediction,
spec: np.ndarray | None = None, spec: Optional[np.ndarray] = None,
min_freq: int = MIN_FREQ_HZ, min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ, max_freq: int = MAX_FREQ_HZ,
**_, **_,
@ -121,7 +109,7 @@ def compute_max_power_bb(
return int( return int(
convert_int_to_freq( convert_int_to_freq(
int(y_high + max_power_ind), y_high + max_power_ind,
spec.shape[0], spec.shape[0],
min_freq, min_freq,
max_freq, max_freq,
@ -131,7 +119,7 @@ def compute_max_power_bb(
def compute_max_power( def compute_max_power(
prediction: types.Prediction, prediction: types.Prediction,
spec: np.ndarray | None = None, spec: Optional[np.ndarray] = None,
min_freq: int = MIN_FREQ_HZ, min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ, max_freq: int = MAX_FREQ_HZ,
**_, **_,
@ -147,17 +135,19 @@ def compute_max_power(
spec_call = spec[:, x_start:x_end] spec_call = spec[:, x_start:x_end]
power_per_freq_band = np.sum(spec_call, axis=1) power_per_freq_band = np.sum(spec_call, axis=1)
max_power_ind = np.argmax(power_per_freq_band) max_power_ind = np.argmax(power_per_freq_band)
return convert_int_to_freq( return int(
int(max_power_ind), convert_int_to_freq(
spec.shape[0], max_power_ind,
min_freq, spec.shape[0],
max_freq, min_freq,
max_freq,
)
) )
def compute_max_power_first( def compute_max_power_first(
prediction: types.Prediction, prediction: types.Prediction,
spec: np.ndarray | None = None, spec: Optional[np.ndarray] = None,
min_freq: int = MIN_FREQ_HZ, min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ, max_freq: int = MAX_FREQ_HZ,
**_, **_,
@ -174,17 +164,19 @@ def compute_max_power_first(
first_half = spec_call[:, : int(spec_call.shape[1] / 2)] first_half = spec_call[:, : int(spec_call.shape[1] / 2)]
power_per_freq_band = np.sum(first_half, axis=1) power_per_freq_band = np.sum(first_half, axis=1)
max_power_ind = np.argmax(power_per_freq_band) max_power_ind = np.argmax(power_per_freq_band)
return convert_int_to_freq( return int(
int(max_power_ind), convert_int_to_freq(
spec.shape[0], max_power_ind,
min_freq, spec.shape[0],
max_freq, min_freq,
max_freq,
)
) )
def compute_max_power_second( def compute_max_power_second(
prediction: types.Prediction, prediction: types.Prediction,
spec: np.ndarray | None = None, spec: Optional[np.ndarray] = None,
min_freq: int = MIN_FREQ_HZ, min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ, max_freq: int = MAX_FREQ_HZ,
**_, **_,
@ -201,17 +193,19 @@ def compute_max_power_second(
second_half = spec_call[:, int(spec_call.shape[1] / 2) :] second_half = spec_call[:, int(spec_call.shape[1] / 2) :]
power_per_freq_band = np.sum(second_half, axis=1) power_per_freq_band = np.sum(second_half, axis=1)
max_power_ind = np.argmax(power_per_freq_band) max_power_ind = np.argmax(power_per_freq_band)
return convert_int_to_freq( return int(
int(max_power_ind), convert_int_to_freq(
spec.shape[0], max_power_ind,
min_freq, spec.shape[0],
max_freq, min_freq,
max_freq,
)
) )
def compute_call_interval( def compute_call_interval(
prediction: types.Prediction, prediction: types.Prediction,
previous: types.Prediction | None = None, previous: Optional[types.Prediction] = None,
**_, **_,
) -> float: ) -> float:
"""Compute time between this call and the previous call in seconds.""" """Compute time between this call and the previous call in seconds."""
@ -242,7 +236,7 @@ def get_feats(
spec: np.ndarray, spec: np.ndarray,
pred_nms: types.PredictionResults, pred_nms: types.PredictionResults,
params: types.FeatureExtractionParameters, params: types.FeatureExtractionParameters,
) -> np.ndarray: ):
"""Extract features from spectrogram based on detected call locations. """Extract features from spectrogram based on detected call locations.
The features extracted are: The features extracted are:

View File

@ -53,13 +53,7 @@ class SelfAttention(nn.Module):
class ConvBlockDownCoordF(nn.Module): class ConvBlockDownCoordF(nn.Module):
def __init__( def __init__(
self, self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1
in_chn,
out_chn,
ip_height,
k_size=3,
pad_size=1,
stride=1,
): ):
super(ConvBlockDownCoordF, self).__init__() super(ConvBlockDownCoordF, self).__init__()
self.coords = nn.Parameter( self.coords = nn.Parameter(
@ -85,13 +79,7 @@ class ConvBlockDownCoordF(nn.Module):
class ConvBlockDownStandard(nn.Module): class ConvBlockDownStandard(nn.Module):
def __init__( def __init__(
self, self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1
in_chn,
out_chn,
ip_height=None,
k_size=3,
pad_size=1,
stride=1,
): ):
super(ConvBlockDownStandard, self).__init__() super(ConvBlockDownStandard, self).__init__()
self.conv = nn.Conv2d( self.conv = nn.Conv2d(

View File

@ -1,4 +1,5 @@
import torch import torch
import torch.fft
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
@ -94,10 +95,7 @@ class Net2DFast(nn.Module):
num_filts // 4, 2, kernel_size=1, padding=0 num_filts // 4, 2, kernel_size=1, padding=0
) )
self.conv_classes_op = nn.Conv2d( self.conv_classes_op = nn.Conv2d(
num_filts // 4, num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
self.num_classes + 1,
kernel_size=1,
padding=0,
) )
if self.emb_dim > 0: if self.emb_dim > 0:
@ -105,15 +103,15 @@ class Net2DFast(nn.Module):
num_filts, self.emb_dim, kernel_size=1, padding=0 num_filts, self.emb_dim, kernel_size=1, padding=0
) )
def forward(self, spec: torch.Tensor) -> ModelOutput: def forward(self, ip, return_feats=False) -> ModelOutput:
# encoder # encoder
x1 = self.conv_dn_0(spec) x1 = self.conv_dn_0(ip)
x2 = self.conv_dn_1(x1) x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2) x3 = self.conv_dn_2(x2)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3))) x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
# bottleneck # bottleneck
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3))) x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
x = self.att(x) x = self.att(x)
x = x.repeat([1, 1, self.bneck_height * 4, 1]) x = x.repeat([1, 1, self.bneck_height * 4, 1])
@ -123,13 +121,13 @@ class Net2DFast(nn.Module):
x = self.conv_up_4(x + x1) x = self.conv_up_4(x + x1)
# output # output
x = F.relu_(self.conv_op_bn(self.conv_op(x))) x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
cls = self.conv_classes_op(x) cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1) comb = torch.softmax(cls, 1)
return ModelOutput( return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu(self.conv_size_op(x)), pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_class=comb, pred_class=comb,
pred_class_un_norm=cls, pred_class_un_norm=cls,
features=x, features=x,
@ -209,10 +207,7 @@ class Net2DFastNoAttn(nn.Module):
num_filts // 4, 2, kernel_size=1, padding=0 num_filts // 4, 2, kernel_size=1, padding=0
) )
self.conv_classes_op = nn.Conv2d( self.conv_classes_op = nn.Conv2d(
num_filts // 4, num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
self.num_classes + 1,
kernel_size=1,
padding=0,
) )
if self.emb_dim > 0: if self.emb_dim > 0:
@ -220,26 +215,26 @@ class Net2DFastNoAttn(nn.Module):
num_filts, self.emb_dim, kernel_size=1, padding=0 num_filts, self.emb_dim, kernel_size=1, padding=0
) )
def forward(self, spec: torch.Tensor) -> ModelOutput: def forward(self, ip, return_feats=False) -> ModelOutput:
x1 = self.conv_dn_0(spec) x1 = self.conv_dn_0(ip)
x2 = self.conv_dn_1(x1) x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2) x3 = self.conv_dn_2(x2)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3))) x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3))) x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
x = x.repeat([1, 1, self.bneck_height * 4, 1]) x = x.repeat([1, 1, self.bneck_height * 4, 1])
x = self.conv_up_2(x + x3) x = self.conv_up_2(x + x3)
x = self.conv_up_3(x + x2) x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1) x = self.conv_up_4(x + x1)
x = F.relu_(self.conv_op_bn(self.conv_op(x))) x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
cls = self.conv_classes_op(x) cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1) comb = torch.softmax(cls, 1)
return ModelOutput( return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu_(self.conv_size_op(x)), pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_class=comb, pred_class=comb,
pred_class_un_norm=cls, pred_class_un_norm=cls,
features=x, features=x,
@ -329,13 +324,13 @@ class Net2DFastNoCoordConv(nn.Module):
num_filts, self.emb_dim, kernel_size=1, padding=0 num_filts, self.emb_dim, kernel_size=1, padding=0
) )
def forward(self, spec: torch.Tensor) -> ModelOutput: def forward(self, ip, return_feats=False) -> ModelOutput:
x1 = self.conv_dn_0(spec) x1 = self.conv_dn_0(ip)
x2 = self.conv_dn_1(x1) x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2) x3 = self.conv_dn_2(x2)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3))) x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3))) x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
x = self.att(x) x = self.att(x)
x = x.repeat([1, 1, self.bneck_height * 4, 1]) x = x.repeat([1, 1, self.bneck_height * 4, 1])
@ -343,13 +338,15 @@ class Net2DFastNoCoordConv(nn.Module):
x = self.conv_up_3(x + x2) x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1) x = self.conv_up_4(x + x1)
x = F.relu_(self.conv_op_bn(self.conv_op(x))) x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
cls = self.conv_classes_op(x) cls = self.conv_classes_op(x)
comb = torch.softmax(cls, 1) comb = torch.softmax(cls, 1)
pred_emb = (self.conv_emb(x) if self.emb_dim > 0 else None,)
return ModelOutput( return ModelOutput(
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1), pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
pred_size=F.relu_(self.conv_size_op(x)), pred_size=F.relu(self.conv_size_op(x), inplace=True),
pred_class=comb, pred_class=comb,
pred_class_un_norm=cls, pred_class_un_norm=cls,
features=x, features=x,

View File

@ -0,0 +1,232 @@
import datetime
import os
from batdetect2.types import ProcessingConfiguration, SpectrogramParameters
TARGET_SAMPLERATE_HZ = 256000
FFT_WIN_LENGTH_S = 512 / 256000.0
FFT_OVERLAP = 0.75
MAX_FREQ_HZ = 120000
MIN_FREQ_HZ = 10000
RESIZE_FACTOR = 0.5
SPEC_DIVIDE_FACTOR = 32
SPEC_HEIGHT = 256
SCALE_RAW_AUDIO = False
DETECTION_THRESHOLD = 0.01
NMS_KERNEL_SIZE = 9
NMS_TOP_K_PER_SEC = 200
SPEC_SCALE = "pcen"
DENOISE_SPEC_AVG = True
MAX_SCALE_SPEC = False
DEFAULT_MODEL_PATH = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"models",
"Net2DFast_UK_same.pth.tar",
)
DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
"fft_win_length": FFT_WIN_LENGTH_S,
"fft_overlap": FFT_OVERLAP,
"spec_height": SPEC_HEIGHT,
"resize_factor": RESIZE_FACTOR,
"spec_divide_factor": SPEC_DIVIDE_FACTOR,
"max_freq": MAX_FREQ_HZ,
"min_freq": MIN_FREQ_HZ,
"spec_scale": SPEC_SCALE,
"denoise_spec_avg": DENOISE_SPEC_AVG,
"max_scale_spec": MAX_SCALE_SPEC,
}
DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
"detection_threshold": DETECTION_THRESHOLD,
"spec_slices": False,
"chunk_size": 3,
"spec_features": False,
"cnn_features": False,
"quiet": True,
"target_samp_rate": TARGET_SAMPLERATE_HZ,
"fft_win_length": FFT_WIN_LENGTH_S,
"fft_overlap": FFT_OVERLAP,
"resize_factor": RESIZE_FACTOR,
"spec_divide_factor": SPEC_DIVIDE_FACTOR,
"spec_height": SPEC_HEIGHT,
"scale_raw_audio": SCALE_RAW_AUDIO,
"class_names": [],
"time_expansion": 1,
"top_n": 3,
"return_raw_preds": False,
"max_duration": None,
"nms_kernel_size": NMS_KERNEL_SIZE,
"max_freq": MAX_FREQ_HZ,
"min_freq": MIN_FREQ_HZ,
"nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
"spec_scale": SPEC_SCALE,
"denoise_spec_avg": DENOISE_SPEC_AVG,
"max_scale_spec": MAX_SCALE_SPEC,
}
def mk_dir(path):
if not os.path.isdir(path):
os.makedirs(path)
def get_params(make_dirs=False, exps_dir="../../experiments/"):
params = {}
params[
"model_name"
] = "Net2DFast" # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
params["num_filters"] = 128
now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
model_name = now_str + ".pth.tar"
params["experiment"] = os.path.join(exps_dir, now_str, "")
params["model_file_name"] = os.path.join(params["experiment"], model_name)
params["op_im_dir"] = os.path.join(params["experiment"], "op_ims", "")
params["op_im_dir_test"] = os.path.join(
params["experiment"], "op_ims_test", ""
)
# params['notes'] = '' # can save notes about an experiment here
# spec parameters
params[
"target_samp_rate"
] = TARGET_SAMPLERATE_HZ # resamples all audio so that it is at this rate
params[
"fft_win_length"
] = FFT_WIN_LENGTH_S # in milliseconds, amount of time per stft time step
params["fft_overlap"] = FFT_OVERLAP # stft window overlap
params[
"max_freq"
] = MAX_FREQ_HZ # in Hz, everything above this will be discarded
params[
"min_freq"
] = MIN_FREQ_HZ # in Hz, everything below this will be discarded
params[
"resize_factor"
] = RESIZE_FACTOR # resize so the spectrogram at the input of the network
params[
"spec_height"
] = SPEC_HEIGHT # units are number of frequency bins (before resizing is performed)
params[
"spec_train_width"
] = 512 # units are number of time steps (before resizing is performed)
params[
"spec_divide_factor"
] = SPEC_DIVIDE_FACTOR # spectrogram should be divisible by this amount in width and height
# spec processing params
params[
"denoise_spec_avg"
] = DENOISE_SPEC_AVG # removes the mean for each frequency band
params[
"scale_raw_audio"
] = SCALE_RAW_AUDIO # scales the raw audio to [-1, 1]
params[
"max_scale_spec"
] = MAX_SCALE_SPEC # scales the spectrogram so that it is max 1
params["spec_scale"] = SPEC_SCALE # 'log', 'pcen', 'none'
# detection params
params[
"detection_overlap"
] = 0.01 # has to be within this number of ms to count as detection
params[
"ignore_start_end"
] = 0.01 # if start of GT calls are within this time from the start/end of file ignore
params[
"detection_threshold"
] = DETECTION_THRESHOLD # the smaller this is the better the recall will be
params[
"nms_kernel_size"
] = NMS_KERNEL_SIZE # size of the kernel for non-max suppression
params[
"nms_top_k_per_sec"
] = NMS_TOP_K_PER_SEC # keep top K highest predictions per second of audio
params["target_sigma"] = 2.0
# augmentation params
params[
"aug_prob"
] = 0.20 # augmentations will be performed with this probability
params["augment_at_train"] = True
params["augment_at_train_combine"] = True
params[
"echo_max_delay"
] = 0.005 # simulate echo by adding copy of raw audio
params["stretch_squeeze_delta"] = 0.04 # stretch or squeeze spec
params[
"mask_max_time_perc"
] = 0.05 # max mask size - here percentage, not ideal
params[
"mask_max_freq_perc"
] = 0.10 # max mask size - here percentage, not ideal
params[
"spec_amp_scaling"
] = 2.0 # multiply the "volume" by 0:X times current amount
params["aug_sampling_rates"] = [
220500,
256000,
300000,
312500,
384000,
441000,
500000,
]
# loss params
params["train_loss"] = "focal" # mse or focal
params["det_loss_weight"] = 1.0 # weight for the detection part of the loss
params["size_loss_weight"] = 0.1 # weight for the bbox size loss
params["class_loss_weight"] = 2.0 # weight for the classification loss
params["individual_loss_weight"] = 0.0 # not used
if params["individual_loss_weight"] == 0.0:
params[
"emb_dim"
] = 0 # number of dimensions used for individual id embedding
else:
params["emb_dim"] = 3
# train params
params["lr"] = 0.001
params["batch_size"] = 8
params["num_workers"] = 4
params["num_epochs"] = 200
params["num_eval_epochs"] = 5 # run evaluation every X epochs
params["device"] = "cuda"
params["save_test_image_during_train"] = False
params["save_test_image_after_train"] = True
params["convert_to_genus"] = False
params["genus_mapping"] = []
params["class_names"] = []
params["classes_to_ignore"] = ["", " ", "Unknown", "Not Bat"]
params["generic_class"] = ["Bat"]
params["events_of_interest"] = [
"Echolocation"
] # will ignore all other types of events e.g. social calls
# the classes in this list are standardized during training so that the same low and high freq are used
params["standardize_classs_names"] = []
# create directories
if make_dirs:
print("Model name : " + params["model_name"])
print("Model file : " + params["model_file_name"])
print("Experiment : " + params["experiment"])
mk_dir(params["experiment"])
if params["save_test_image_during_train"]:
mk_dir(params["op_im_dir"])
if params["save_test_image_after_train"]:
mk_dir(params["op_im_dir_test"])
mk_dir(os.path.dirname(params["model_file_name"]))
return params

View File

@ -1,4 +1,5 @@
"""Post-processing of the output of the model.""" """Post-processing of the output of the model."""
from typing import List, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -43,7 +44,7 @@ def run_nms(
outputs: ModelOutput, outputs: ModelOutput,
params: NonMaximumSuppressionConfig, params: NonMaximumSuppressionConfig,
sampling_rate: np.ndarray, sampling_rate: np.ndarray,
) -> tuple[list[PredictionResults], list[np.ndarray]]: ) -> Tuple[List[PredictionResults], List[np.ndarray]]:
"""Run non-maximum suppression on the output of the model. """Run non-maximum suppression on the output of the model.
Model outputs processed are expected to have a batch dimension. Model outputs processed are expected to have a batch dimension.
@ -71,8 +72,8 @@ def run_nms(
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k) scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
# loop over batch to save outputs # loop over batch to save outputs
preds: list[PredictionResults] = [] preds: List[PredictionResults] = []
feats: list[np.ndarray] = [] feats: List[np.ndarray] = []
for num_detection in range(pred_det_nms.shape[0]): for num_detection in range(pred_det_nms.shape[0]):
# get valid indices # get valid indices
inds_ord = torch.argsort(x_pos[num_detection, :]) inds_ord = torch.argsort(x_pos[num_detection, :])
@ -149,7 +150,7 @@ def run_nms(
def non_max_suppression( def non_max_suppression(
heat: torch.Tensor, heat: torch.Tensor,
kernel_size: int | tuple[int, int], kernel_size: Union[int, Tuple[int, int]],
): ):
# kernel can be an int or list/tuple # kernel can be an int or list/tuple
if isinstance(kernel_size, int): if isinstance(kernel_size, int):

View File

@ -7,19 +7,20 @@ import copy
import json import json
import os import os
import torch
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torch
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
import batdetect2.evaluate.legacy.evaluate_models as evl from batdetect2.detector import parameters
import batdetect2.train.legacy.train_utils as tu import batdetect2.train.evaluate as evl
import batdetect2.train.train_utils as tu
import batdetect2.utils.detector_utils as du import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu import batdetect2.utils.plot_utils as pu
from batdetect2.detector import parameters
def get_blank_annotation(ip_str): def get_blank_annotation(ip_str):
res = {} res = {}
res["class_name"] = "" res["class_name"] = ""
res["duration"] = -1 res["duration"] = -1
@ -76,6 +77,7 @@ def create_genus_mapping(gt_test, preds, class_names):
def load_tadarida_pred(ip_dir, dataset, file_of_interest): def load_tadarida_pred(ip_dir, dataset, file_of_interest):
res, ann = get_blank_annotation("Generated by Tadarida") res, ann = get_blank_annotation("Generated by Tadarida")
# create the annotations in the correct format # create the annotations in the correct format
@ -118,6 +120,7 @@ def load_sonobat_meta(
class_names, class_names,
only_accepted_species=True, only_accepted_species=True,
): ):
sp_dict = {} sp_dict = {}
for ss in class_names: for ss in class_names:
sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3] sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]
@ -179,6 +182,7 @@ def load_sonobat_meta(
def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None): def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
# create the annotations in the correct format # create the annotations in the correct format
res, ann = get_blank_annotation("Generated by Sonobat") res, ann = get_blank_annotation("Generated by Sonobat")
res_c = copy.deepcopy(res) res_c = copy.deepcopy(res)
@ -217,6 +221,7 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
def bb_overlap(bb_g_in, bb_p_in): def bb_overlap(bb_g_in, bb_p_in):
freq_scale = 10000000.0 # ensure that both axis are roughly the same range freq_scale = 10000000.0 # ensure that both axis are roughly the same range
bb_g = [ bb_g = [
bb_g_in["start_time"], bb_g_in["start_time"],
@ -325,8 +330,7 @@ def load_gt_data(datasets, events_of_interest, class_names, classes_to_ignore):
for dd in datasets: for dd in datasets:
print("\n" + dd["dataset_name"]) print("\n" + dd["dataset_name"])
gt_dataset = tu.load_set_of_anns( gt_dataset = tu.load_set_of_anns(
[dd], [dd], events_of_interest=events_of_interest, verbose=True
events_of_interest=events_of_interest,
) )
gt_dataset = [ gt_dataset = [
parse_data(gg, class_names, classes_to_ignore, False) parse_data(gg, class_names, classes_to_ignore, False)
@ -357,7 +361,7 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001):
clf = RandomForestClassifier(random_state=seed, n_jobs=-1) clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
clf.fit(x_train, y_train) clf.fit(x_train, y_train)
y_pred = clf.predict(x_train) y_pred = clf.predict(x_train)
(y_pred == y_train).mean() tr_acc = (y_pred == y_train).mean()
# print('Train acc', round(tr_acc*100, 2)) # print('Train acc', round(tr_acc*100, 2))
return clf, un_train_class return clf, un_train_class
@ -450,7 +454,7 @@ def add_root_path_back(data_sets, ann_path, wav_path):
def check_classes_in_train(gt_list, class_names): def check_classes_in_train(gt_list, class_names):
np.sum([gg["start_times"].shape[0] for gg in gt_list]) num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list])
num_with_no_class = 0 num_with_no_class = 0
for gt in gt_list: for gt in gt_list:
for cc in gt["class_names"]: for cc in gt["class_names"]:
@ -460,6 +464,7 @@ def check_classes_in_train(gt_list, class_names):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"op_dir", "op_dir",
@ -548,9 +553,7 @@ if __name__ == "__main__":
test_dict["dataset_name"] = args["test_file"].replace(".json", "") test_dict["dataset_name"] = args["test_file"].replace(".json", "")
test_dict["is_test"] = True test_dict["is_test"] = True
test_dict["is_binary"] = True test_dict["is_binary"] = True
test_dict["ann_path"] = os.path.join( test_dict["ann_path"] = os.path.join(args["ann_dir"], args["test_file"])
args["ann_dir"], args["test_file"]
)
test_dict["wav_path"] = args["data_dir"] test_dict["wav_path"] = args["data_dir"]
test_sets = [test_dict] test_sets = [test_dict]
@ -569,7 +572,7 @@ if __name__ == "__main__":
num_with_no_class = check_classes_in_train(gt_test, class_names) num_with_no_class = check_classes_in_train(gt_test, class_names)
if total_num_calls == num_with_no_class: if total_num_calls == num_with_no_class:
print("Classes from the test set are not in the train set.") print("Classes from the test set are not in the train set.")
raise AssertionError() assert False
# only need the train data if evaluating Sonobat or Tadarida # only need the train data if evaluating Sonobat or Tadarida
if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "": if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
@ -743,7 +746,7 @@ if __name__ == "__main__":
# check if the class names are the same # check if the class names are the same
if params_bd["class_names"] != class_names: if params_bd["class_names"] != class_names:
print("Warning: Class names are not the same as the trained model") print("Warning: Class names are not the same as the trained model")
raise AssertionError() assert False
run_config = { run_config = {
**bd_args, **bd_args,
@ -753,7 +756,7 @@ if __name__ == "__main__":
preds_bd = [] preds_bd = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for gg in gt_test: for ii, gg in enumerate(gt_test):
pred = du.process_file( pred = du.process_file(
gg["file_path"], gg["file_path"],
model, model,

View File

@ -1,31 +1,33 @@
import argparse import argparse
import glob
import json
import os import os
import warnings import sys
from typing import List
import matplotlib.pyplot as plt
import numpy as np
import torch import torch
import torch.utils.data import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
import batdetect2.detector.models as models
import batdetect2.detector.parameters as parameters import batdetect2.detector.parameters as parameters
import batdetect2.train.legacy.audio_dataloader as adl import batdetect2.detector.post_process as pp
import batdetect2.train.legacy.train_model as tm import batdetect2.train.audio_dataloader as adl
import batdetect2.train.legacy.train_utils as tu import batdetect2.train.evaluate as evl
import batdetect2.train.losses as losses import batdetect2.train.losses as losses
import batdetect2.train.train_model as tm
import batdetect2.train.train_utils as tu
import batdetect2.utils.detector_utils as du import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu import batdetect2.utils.plot_utils as pu
from batdetect2 import types
from batdetect2.detector.models import Net2DFast
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) if __name__ == "__main__":
info_str = "\nBatDetect - Finetune Model\n"
print(info_str)
def parse_arugments():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"audio_path", "audio_path", type=str, help="Input directory for audio"
type=str,
help="Input directory for audio",
) )
parser.add_argument( parser.add_argument(
"train_ann_path", "train_ann_path",
@ -37,15 +39,7 @@ def parse_arugments():
type=str, type=str,
help="Path to where test annotation file is stored", help="Path to where test annotation file is stored",
) )
parser.add_argument( parser.add_argument("model_path", type=str, help="Path to pretrained model")
"model_path", type=str, help="Path to pretrained model"
)
parser.add_argument(
"--experiment_dir",
type=str,
default=os.path.join(BASE_DIR, "experiments"),
help="Path to where experiment files are stored",
)
parser.add_argument( parser.add_argument(
"--op_model_name", "--op_model_name",
type=str, type=str,
@ -77,64 +71,107 @@ def parse_arugments():
parser.add_argument( parser.add_argument(
"--notes", type=str, default="", help="Notes to save in text file" "--notes", type=str, default="", help="Notes to save in text file"
) )
args = parser.parse_args() args = vars(parser.parse_args())
return args
params = parameters.get_params(True, "../../experiments/")
def select_device(warn=True) -> str:
if torch.cuda.is_available(): if torch.cuda.is_available():
return "cuda" params["device"] = "cuda"
else:
if warn: params["device"] = "cpu"
warnings.warn( print(
"No GPU available, using the CPU instead. Please consider using a GPU " "\nNote, this will be a lot faster if you use computer with a GPU.\n"
"to speed up training.",
stacklevel=2,
) )
return "cpu" print("\nAudio directory: " + args["audio_path"])
print("Train file: " + args["train_ann_path"])
print("Test file: " + args["test_ann_path"])
print("Loading model: " + args["model_path"])
dataset_name = (
os.path.basename(args["train_ann_path"])
.replace(".json", "")
.replace("_TRAIN", "")
)
def load_annotations( if args["train_from_scratch"]:
dataset_name: str, print("\nTraining model from scratch i.e. not using pretrained weights")
ann_path: str, model, params_train = du.load_model(args["model_path"], False)
audio_path: str, else:
classes_to_ignore: List[str] | None = None, model, params_train = du.load_model(args["model_path"], True)
events_of_interest: List[str] | None = None, model.to(params["device"])
) -> List[types.FileAnnotation]:
train_sets: List[types.DatasetDict] = [] params["num_epochs"] = args["num_epochs"]
if args["op_model_name"] != "":
params["model_file_name"] = args["op_model_name"]
classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
# save notes file
params["notes"] = args["notes"]
if args["notes"] != "":
tu.write_notes_file(params["experiment"] + "notes.txt", args["notes"])
# load train annotations
train_sets = []
train_sets.append( train_sets.append(
tu.get_blank_dataset_dict( tu.get_blank_dataset_dict(
dataset_name, dataset_name, False, args["train_ann_path"], args["audio_path"]
is_test=False,
ann_path=ann_path,
wav_path=audio_path,
) )
) )
params["train_sets"] = [
tu.get_blank_dataset_dict(
dataset_name,
False,
os.path.basename(args["train_ann_path"]),
args["audio_path"],
)
]
return tu.load_set_of_anns( print("\nTrain set:")
train_sets, (
events_of_interest=events_of_interest, data_train,
classes_to_ignore=classes_to_ignore, params["class_names"],
params["class_inv_freq"],
) = tu.load_set_of_anns(
train_sets, classes_to_ignore, params["events_of_interest"]
)
print("Number of files", len(data_train))
params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
params["class_names"]
)
params["class_names_short"] = tu.get_short_class_names(
params["class_names"]
) )
# load test annotations
test_sets = []
test_sets.append(
tu.get_blank_dataset_dict(
dataset_name, True, args["test_ann_path"], args["audio_path"]
)
)
params["test_sets"] = [
tu.get_blank_dataset_dict(
dataset_name,
True,
os.path.basename(args["test_ann_path"]),
args["audio_path"],
)
]
print("\nTest set:")
data_test, _, _ = tu.load_set_of_anns(
test_sets, classes_to_ignore, params["events_of_interest"]
)
print("Number of files", len(data_test))
def finetune_model(
model: types.DetectionModel,
data_train: List[types.FileAnnotation],
data_test: List[types.FileAnnotation],
params: parameters.TrainingParameters,
model_params: types.ModelParameters,
finetune_only_last_layer: bool = False,
save_images: bool = True,
):
# train loader # train loader
train_dataset = adl.AudioLoader(data_train, params, is_train=True) train_dataset = adl.AudioLoader(data_train, params, is_train=True)
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_dataset, train_dataset,
batch_size=params.batch_size, batch_size=params["batch_size"],
shuffle=True, shuffle=True,
num_workers=params.num_workers, num_workers=params["num_workers"],
pin_memory=True, pin_memory=True,
) )
@ -144,36 +181,32 @@ def finetune_model(
test_dataset, test_dataset,
batch_size=1, batch_size=1,
shuffle=False, shuffle=False,
num_workers=params.num_workers, num_workers=params["num_workers"],
pin_memory=True, pin_memory=True,
) )
inputs_train = next(iter(train_loader)) inputs_train = next(iter(train_loader))
params.ip_height = inputs_train["spec"].shape[2] params["ip_height"] = inputs_train["spec"].shape[2]
print("\ntrain batch size :", inputs_train["spec"].shape) print("\ntrain batch size :", inputs_train["spec"].shape)
# Check that the model is the same as the one used to train the pretrained assert params_train["model_name"] == "Net2DFast"
# weights
assert model_params["model_name"] == "Net2DFast"
assert isinstance(model, Net2DFast)
print( print(
"\n\nSOME hyperparams need to be the same as the loaded model " "\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n"
"(e.g. FFT) - currently they are getting overwritten.\n\n"
) )
# set the number of output classes # set the number of output classes
num_filts = model.conv_classes_op.in_channels num_filts = model.conv_classes_op.in_channels
(k_size,) = model.conv_classes_op.kernel_size k_size = model.conv_classes_op.kernel_size
(pad,) = model.conv_classes_op.padding pad = model.conv_classes_op.padding
model.conv_classes_op = torch.nn.Conv2d( model.conv_classes_op = torch.nn.Conv2d(
num_filts, num_filts,
len(params.class_names) + 1, len(params["class_names"]) + 1,
kernel_size=k_size, kernel_size=k_size,
padding=pad, padding=pad,
) )
model.conv_classes_op.to(params.device) model.conv_classes_op.to(params["device"])
if finetune_only_last_layer: if args["finetune_only_last_layer"]:
print("\nOnly finetuning the final layers.\n") print("\nOnly finetuning the final layers.\n")
train_layers_i = [ train_layers_i = [
"conv_classes", "conv_classes",
@ -190,26 +223,19 @@ def finetune_model(
else: else:
param.requires_grad = False param.requires_grad = False
optimizer = torch.optim.Adam( optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
model.parameters(),
lr=params.lr,
)
scheduler = CosineAnnealingLR( scheduler = CosineAnnealingLR(
optimizer, optimizer, params["num_epochs"] * len(train_loader)
params.num_epochs * len(train_loader),
) )
if params["train_loss"] == "mse":
if params.train_loss == "mse":
det_criterion = losses.mse_loss det_criterion = losses.mse_loss
elif params.train_loss == "focal": elif params["train_loss"] == "focal":
det_criterion = losses.focal_loss det_criterion = losses.focal_loss
else:
raise ValueError("Unknown loss function")
# plotting # plotting
train_plt_ls = pu.LossPlotter( train_plt_ls = pu.LossPlotter(
params.experiment / "train_loss.png", params["experiment"] + "train_loss.png",
params.num_epochs + 1, params["num_epochs"] + 1,
["train_loss"], ["train_loss"],
None, None,
None, None,
@ -217,8 +243,8 @@ def finetune_model(
logy=True, logy=True,
) )
test_plt_ls = pu.LossPlotter( test_plt_ls = pu.LossPlotter(
params.experiment / "test_loss.png", params["experiment"] + "test_loss.png",
params.num_epochs + 1, params["num_epochs"] + 1,
["test_loss"], ["test_loss"],
None, None,
None, None,
@ -226,24 +252,24 @@ def finetune_model(
logy=True, logy=True,
) )
test_plt = pu.LossPlotter( test_plt = pu.LossPlotter(
params.experiment / "test.png", params["experiment"] + "test.png",
params.num_epochs + 1, params["num_epochs"] + 1,
["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"], ["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
[0, 1], [0, 1],
None, None,
["epoch", ""], ["epoch", ""],
) )
test_plt_class = pu.LossPlotter( test_plt_class = pu.LossPlotter(
params.experiment / "test_avg_prec.png", params["experiment"] + "test_avg_prec.png",
params.num_epochs + 1, params["num_epochs"] + 1,
params.class_names_short, params["class_names_short"],
[0, 1], [0, 1],
params.class_names_short, params["class_names_short"],
["epoch", "avg_prec"], ["epoch", "avg_prec"],
) )
# main train loop # main train loop
for epoch in range(0, params.num_epochs + 1): for epoch in range(0, params["num_epochs"] + 1):
train_loss = tm.train( train_loss = tm.train(
model, model,
epoch, epoch,
@ -255,14 +281,10 @@ def finetune_model(
) )
train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]]) train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])
if epoch % params.num_eval_epochs == 0: if epoch % params["num_eval_epochs"] == 0:
# detection accuracy on test set # detection accuracy on test set
test_res, test_loss = tm.test( test_res, test_loss = tm.test(
model, model, epoch, test_loader, det_criterion, params
epoch,
test_loader,
det_criterion,
params,
) )
test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]]) test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]])
test_plt.update_and_save( test_plt.update_and_save(
@ -279,106 +301,18 @@ def finetune_model(
epoch, [rs["avg_prec"] for rs in test_res["class_pr"]] epoch, [rs["avg_prec"] for rs in test_res["class_pr"]]
) )
pu.plot_pr_curve_class( pu.plot_pr_curve_class(
params.experiment, "test_pr", "test_pr", test_res params["experiment"], "test_pr", "test_pr", test_res
) )
# save finetuned model # save finetuned model
print(f"saving model to: {params.model_file_name}") print("saving model to: " + params["model_file_name"])
op_state = { op_state = {
"epoch": epoch + 1, "epoch": epoch + 1,
"state_dict": model.state_dict(), "state_dict": model.state_dict(),
"params": params, "params": params,
} }
torch.save(op_state, params.model_file_name) torch.save(op_state, params["model_file_name"])
# save an image with associated prediction for each batch in the test set # save an image with associated prediction for each batch in the test set
if save_images: if not args["do_not_save_images"]:
tm.save_images_batch(model, test_loader, params) tm.save_images_batch(model, test_loader, params)
def main():
info_str = "\nBatDetect - Finetune Model\n"
print(info_str)
args = parse_arugments()
# Load experiment parameters
params = parameters.get_params(
make_dirs=True,
exps_dir=args.experiment_dir,
device=select_device(),
num_epochs=args.num_epochs,
notes=args.notes,
)
print("\nAudio directory: " + args.audio_path)
print("Train file: " + args.train_ann_path)
print("Test file: " + args.test_ann_path)
print("Loading model: " + args.model_path)
if args.train_from_scratch:
print(
"\nTraining model from scratch i.e. not using pretrained weights"
)
model, model_params = du.load_model(
args.model_path,
load_weights=not args.train_from_scratch,
device=params.device,
)
if args.op_model_name != "":
params.model_file_name = args.op_model_name
classes_to_ignore = params.classes_to_ignore + params.generic_class
# save notes file
if params.notes:
tu.write_notes_file(
params.experiment / "notes.txt",
args.notes,
)
# NOTE:??
dataset_name = (
os.path.basename(args.train_ann_path)
.replace(".json", "")
.replace("_TRAIN", "")
)
# ==== LOAD DATA ====
# load train annotations
data_train = load_annotations(
dataset_name,
args.train_ann_path,
args.audio_path,
params.events_of_interest,
)
print("\nTrain set:")
print("Number of files", len(data_train))
# load test annotations
data_test = load_annotations(
dataset_name,
args.test_ann_path,
args.audio_path,
classes_to_ignore,
params.events_of_interest,
)
print("\nTrain set:")
print("Number of files", len(data_train))
finetune_model(
model,
data_train,
data_test,
params,
model_params,
finetune_only_last_layer=args.finetune_only_last_layer,
save_images=args.do_not_save_images,
)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,201 @@
import argparse
import json
import os
import numpy as np
import batdetect2.train.train_utils as tu
def print_dataset_stats(data, split_name, classes_to_ignore):
print("\nSplit:", split_name)
print("Num files:", len(data))
class_cnts = {}
for dd in data:
for aa in dd["annotation"]:
if aa["class"] not in classes_to_ignore:
if aa["class"] in class_cnts:
class_cnts[aa["class"]] += 1
else:
class_cnts[aa["class"]] = 1
if len(class_cnts) == 0:
class_names = []
else:
class_names = np.sort([*class_cnts]).tolist()
print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5
for ii, cc in enumerate(class_names):
print(str(ii).ljust(5) + cc.ljust(str_len) + str(class_cnts[cc]))
return class_names
def load_file_names(file_name):
if os.path.isfile(file_name):
with open(file_name) as da:
files = [line.rstrip() for line in da.readlines()]
for ff in files:
if ff.lower()[-3:] != "wav":
print("Error: Filenames need to end in .wav - ", ff)
assert False
else:
print("Error: Input file not found - ", file_name)
assert False
return files
if __name__ == "__main__":
info_str = "\nBatDetect - Prepare Data for Finetuning\n"
print(info_str)
parser = argparse.ArgumentParser()
parser.add_argument(
"dataset_name", type=str, help="Name to call your dataset"
)
parser.add_argument("audio_dir", type=str, help="Input directory for audio")
parser.add_argument(
"ann_dir",
type=str,
help="Input directory for where the audio annotations are stored",
)
parser.add_argument(
"op_dir",
type=str,
help="Path where the train and test splits will be stored",
)
parser.add_argument(
"--percent_val",
type=float,
default=0.20,
help="Hold out this much data for validation. Should be number between 0 and 1",
)
parser.add_argument(
"--rand_seed",
type=int,
default=2001,
help="Random seed used for creating the validation split",
)
parser.add_argument(
"--train_file",
type=str,
default="",
help="Text file where each line is a wav file in train split",
)
parser.add_argument(
"--test_file",
type=str,
default="",
help="Text file where each line is a wav file in test split",
)
parser.add_argument(
"--input_class_names",
type=str,
default="",
help='Specify names of classes that you want to change. Separate with ";"',
)
parser.add_argument(
"--output_class_names",
type=str,
default="",
help='New class names to use instead. One to one mapping with "--input_class_names". \
Separate with ";"',
)
args = vars(parser.parse_args())
np.random.seed(args["rand_seed"])
classes_to_ignore = ["", " ", "Unknown", "Not Bat"]
generic_class = ["Bat"]
events_of_interest = ["Echolocation"]
if args["input_class_names"] != "" and args["output_class_names"] != "":
# change the names of the classes
ip_names = args["input_class_names"].split(";")
op_names = args["output_class_names"].split(";")
name_dict = dict(zip(ip_names, op_names))
else:
name_dict = False
# load annotations
data_all, _, _ = tu.load_set_of_anns(
{"ann_path": args["ann_dir"], "wav_path": args["audio_dir"]},
classes_to_ignore,
events_of_interest,
False,
False,
list_of_anns=True,
filter_issues=True,
name_replace=name_dict,
)
print("Dataset name: " + args["dataset_name"])
print("Audio directory: " + args["audio_dir"])
print("Annotation directory: " + args["ann_dir"])
print("Ouput directory: " + args["op_dir"])
print("Num annotated files: " + str(len(data_all)))
if args["train_file"] != "" and args["test_file"] != "":
# user has specifed the train / test split
train_files = load_file_names(args["train_file"])
test_files = load_file_names(args["test_file"])
file_names_all = [dd["id"] for dd in data_all]
train_inds = [
file_names_all.index(ff)
for ff in train_files
if ff in file_names_all
]
test_inds = [
file_names_all.index(ff)
for ff in test_files
if ff in file_names_all
]
else:
# split the data into train and test at the file level
num_exs = len(data_all)
test_inds = np.random.choice(
np.arange(num_exs),
int(num_exs * args["percent_val"]),
replace=False,
)
test_inds = np.sort(test_inds)
train_inds = np.setdiff1d(np.arange(num_exs), test_inds)
data_train = [data_all[ii] for ii in train_inds]
data_test = [data_all[ii] for ii in test_inds]
if not os.path.isdir(args["op_dir"]):
os.makedirs(args["op_dir"])
op_name = os.path.join(args["op_dir"], args["dataset_name"])
op_name_train = op_name + "_TRAIN.json"
op_name_test = op_name + "_TEST.json"
class_un_train = print_dataset_stats(data_train, "Train", classes_to_ignore)
class_un_test = print_dataset_stats(data_test, "Test", classes_to_ignore)
if len(data_train) > 0 and len(data_test) > 0:
if class_un_train != class_un_test:
print(
'\nError: some classes are not in both the training and test sets.\
\nTry a different random seed "--rand_seed".'
)
assert False
print("\n")
if len(data_train) == 0:
print("No train annotations to save")
else:
print("Saving: ", op_name_train)
with open(op_name_train, "w") as da:
json.dump(data_train, da, indent=2)
if len(data_test) == 0:
print("No test annotations to save")
else:
print("Saving: ", op_name_test)
with open(op_name_test, "w") as da:
json.dump(data_test, da, indent=2)

View File

@ -1,11 +1,11 @@
"""Plot functions to visualize detections and spectrograms.""" """Plot functions to visualize detections and spectrograms."""
from typing import cast from typing import List, Optional, Tuple, Union, cast
import matplotlib.ticker as tick
import numpy as np import numpy as np
import torch import torch
from matplotlib import axes, patches from matplotlib import axes, patches
import matplotlib.ticker as tick
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS
@ -24,10 +24,10 @@ __all__ = [
def spectrogram( def spectrogram(
spec: torch.Tensor | np.ndarray, spec: Union[torch.Tensor, np.ndarray],
config: ProcessingConfiguration | None = None, config: Optional[ProcessingConfiguration] = None,
ax: axes.Axes | None = None, ax: Optional[axes.Axes] = None,
figsize: tuple[int, int] | None = None, figsize: Optional[Tuple[int, int]] = None,
cmap: str = "plasma", cmap: str = "plasma",
start_time: float = 0, start_time: float = 0,
) -> axes.Axes: ) -> axes.Axes:
@ -35,18 +35,18 @@ def spectrogram(
Parameters Parameters
---------- ----------
spec: Spectrogram to plot. spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
config: Configuration config (Optional[ProcessingConfiguration], optional): Configuration
used to compute the spectrogram. Defaults to None. If None, used to compute the spectrogram. Defaults to None. If None,
the default configuration will be used. the default configuration will be used.
ax: Matplotlib axes object. ax (Optional[axes.Axes], optional): Matplotlib axes object.
Defaults to None. if provided, the spectrogram will be plotted Defaults to None. if provided, the spectrogram will be plotted
on this axes. on this axes.
figsize: Figure size. figsize (Optional[Tuple[int, int]], optional): Figure size.
Defaults to None. If `ax` is None, this will be used to create Defaults to None. If `ax` is None, this will be used to create
a new figure of the given size. a new figure of the given size.
cmap: Colormap to use. Defaults to "plasma". cmap (str, optional): Colormap to use. Defaults to "plasma".
start_time: Start time of the spectrogram. start_time (float, optional): Start time of the spectrogram.
Defaults to 0. This is useful if plotting a spectrogram Defaults to 0. This is useful if plotting a spectrogram
of a segment of a longer audio file. of a segment of a longer audio file.
@ -103,11 +103,11 @@ def spectrogram(
def spectrogram_with_detections( def spectrogram_with_detections(
spec: torch.Tensor | np.ndarray, spec: Union[torch.Tensor, np.ndarray],
dets: list[Annotation], dets: List[Annotation],
config: ProcessingConfiguration | None = None, config: Optional[ProcessingConfiguration] = None,
ax: axes.Axes | None = None, ax: Optional[axes.Axes] = None,
figsize: tuple[int, int] | None = None, figsize: Optional[Tuple[int, int]] = None,
cmap: str = "plasma", cmap: str = "plasma",
with_names: bool = True, with_names: bool = True,
start_time: float = 0, start_time: float = 0,
@ -117,21 +117,21 @@ def spectrogram_with_detections(
Parameters Parameters
---------- ----------
spec: Spectrogram to plot. spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
detections: List of detections. detections (List[Annotation]): List of detections.
config: Configuration config (Optional[ProcessingConfiguration], optional): Configuration
used to compute the spectrogram. Defaults to None. If None, used to compute the spectrogram. Defaults to None. If None,
the default configuration will be used. the default configuration will be used.
ax: Matplotlib axes object. ax (Optional[axes.Axes], optional): Matplotlib axes object.
Defaults to None. if provided, the spectrogram will be plotted Defaults to None. if provided, the spectrogram will be plotted
on this axes. on this axes.
figsize: Figure size. figsize (Optional[Tuple[int, int]], optional): Figure size.
Defaults to None. If `ax` is None, this will be used to create Defaults to None. If `ax` is None, this will be used to create
a new figure of the given size. a new figure of the given size.
cmap: Colormap to use. Defaults to "plasma". cmap (str, optional): Colormap to use. Defaults to "plasma".
with_names: Whether to plot the name of the with_names (bool, optional): Whether to plot the name of the
predicted class next to the detection. Defaults to True. predicted class next to the detection. Defaults to True.
start_time: Start time of the spectrogram. start_time (float, optional): Start time of the spectrogram.
Defaults to 0. This is useful if plotting a spectrogram Defaults to 0. This is useful if plotting a spectrogram
of a segment of a longer audio file. of a segment of a longer audio file.
**kwargs: Additional keyword arguments to pass to the **kwargs: Additional keyword arguments to pass to the
@ -167,9 +167,9 @@ def spectrogram_with_detections(
def detections( def detections(
dets: list[Annotation], dets: List[Annotation],
ax: axes.Axes | None = None, ax: Optional[axes.Axes] = None,
figsize: tuple[int, int] | None = None, figsize: Optional[Tuple[int, int]] = None,
with_names: bool = True, with_names: bool = True,
**kwargs, **kwargs,
) -> axes.Axes: ) -> axes.Axes:
@ -177,14 +177,14 @@ def detections(
Parameters Parameters
---------- ----------
dets: List of detections. dets (List[Annotation]): List of detections.
ax: Matplotlib axes object. ax (Optional[axes.Axes], optional): Matplotlib axes object.
Defaults to None. if provided, the spectrogram will be plotted Defaults to None. if provided, the spectrogram will be plotted
on this axes. on this axes.
figsize: Figure size. figsize (Optional[Tuple[int, int]], optional): Figure size.
Defaults to None. If `ax` is None, this will be used to create Defaults to None. If `ax` is None, this will be used to create
a new figure of the given size. a new figure of the given size.
with_names: Whether to plot the name of the with_names (bool, optional): Whether to plot the name of the
predicted class next to the detection. Defaults to True. predicted class next to the detection. Defaults to True.
**kwargs: Additional keyword arguments to pass to the **kwargs: Additional keyword arguments to pass to the
`plot.detection` function. `plot.detection` function.
@ -213,8 +213,8 @@ def detections(
def detection( def detection(
det: Annotation, det: Annotation,
ax: axes.Axes | None = None, ax: Optional[axes.Axes] = None,
figsize: tuple[int, int] | None = None, figsize: Optional[Tuple[int, int]] = None,
linewidth: float = 1, linewidth: float = 1,
edgecolor: str = "w", edgecolor: str = "w",
facecolor: str = "none", facecolor: str = "none",
@ -224,19 +224,19 @@ def detection(
Parameters Parameters
---------- ----------
det: Detection to plot. det (Annotation): Detection to plot.
ax: Matplotlib axes object. Defaults ax (Optional[axes.Axes], optional): Matplotlib axes object. Defaults
to None. If provided, the spectrogram will be plotted on this axes. to None. If provided, the spectrogram will be plotted on this axes.
figsize: Figure size. Defaults figsize (Optional[Tuple[int, int]], optional): Figure size. Defaults
to None. If `ax` is None, this will be used to create a new figure to None. If `ax` is None, this will be used to create a new figure
of the given size. of the given size.
linewidth: Line width of the detection. linewidth (float, optional): Line width of the detection.
Defaults to 1. Defaults to 1.
edgecolor: Edge color of the detection. edgecolor (str, optional): Edge color of the detection.
Defaults to "w", i.e. white. Defaults to "w", i.e. white.
facecolor: Face color of the detection. facecolor (str, optional): Face color of the detection.
Defaults to "none", i.e. transparent. Defaults to "none", i.e. transparent.
with_name: Whether to plot the name of the with_name (bool, optional): Whether to plot the name of the
predicted class next to the detection. Defaults to True. predicted class next to the detection. Defaults to True.
Returns Returns
@ -277,22 +277,22 @@ def detection(
def _compute_spec_extent( def _compute_spec_extent(
shape: tuple[int, int], shape: Tuple[int, int],
params: SpectrogramParameters, params: SpectrogramParameters,
) -> tuple[float, float, float, float]: ) -> Tuple[float, float, float, float]:
"""Compute the extent of a spectrogram. """Compute the extent of a spectrogram.
Parameters Parameters
---------- ----------
shape: Shape of the spectrogram. shape (Tuple[int, int]): Shape of the spectrogram.
The first dimension is the frequency axis and the second The first dimension is the frequency axis and the second
dimension is the time axis. dimension is the time axis.
params: Spectrogram parameters. params (SpectrogramParameters): Spectrogram parameters.
Should be the same as the ones used to compute the spectrogram. Should be the same as the ones used to compute the spectrogram.
Returns Returns
------- -------
tuple[float, float, float, float]: Extent of the spectrogram. Tuple[float, float, float, float]: Extent of the spectrogram.
The first two values are the minimum and maximum time values, The first two values are the minimum and maximum time values,
the last two values are the minimum and maximum frequency values. the last two values are the minimum and maximum frequency values.
""" """
@ -306,9 +306,6 @@ def _compute_spec_extent(
# If the spectrogram is not resized, the duration is correct # If the spectrogram is not resized, the duration is correct
# but if it is resized, the duration needs to be adjusted # but if it is resized, the duration needs to be adjusted
# NOTE: For now we can only detect if the spectrogram is resized
# by checking if the height is equal to the specified height,
# but this could fail.
resize_factor = params["resize_factor"] resize_factor = params["resize_factor"]
spec_height = params["spec_height"] spec_height = params["spec_height"]
if spec_height * resize_factor == shape[0]: if spec_height * resize_factor == shape[0]:

View File

@ -0,0 +1,603 @@
import copy
from typing import Tuple
import librosa
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
import batdetect2.utils.audio_utils as au
from batdetect2.types import AnnotationGroup, HeatmapParameters
def generate_gt_heatmaps(
spec_op_shape: Tuple[int, int],
sampling_rate: int,
ann: AnnotationGroup,
params: HeatmapParameters,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AnnotationGroup]:
"""Generate ground truth heatmaps from annotations.
Parameters
----------
spec_op_shape : Tuple[int, int]
Shape of the input spectrogram.
sampling_rate : int
Sampling rate of the input audio in Hz.
ann : AnnotationGroup
Dictionary containing the annotation information.
params : HeatmapParameters
Parameters controlling the generation of the heatmaps.
Returns
-------
y_2d_det : np.ndarray
2D heatmap of the presence of an event.
y_2d_size : np.ndarray
2D heatmap of the size of the bounding box associated to event.
y_2d_classes : np.ndarray
3D array containing the ground-truth class probabilities for each
pixel.
ann_aug : AnnotationGroup
A dictionary containing the annotation information of the
annotations that are within the input spectrogram, augmented with
the x and y indices of their pixel location in the input spectrogram.
"""
# spec may be resized on input into the network
num_classes = len(params["class_names"])
op_height = spec_op_shape[0]
op_width = spec_op_shape[1]
freq_per_bin = (params["max_freq"] - params["min_freq"]) / op_height
# start and end times
x_pos_start = au.time_to_x_coords(
ann["start_times"],
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
)
x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int)
x_pos_end = au.time_to_x_coords(
ann["end_times"],
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
)
x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int)
# location on y axis i.e. frequency
y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin
y_pos_low = (op_height - y_pos_low).astype(np.int)
y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin
y_pos_high = (op_height - y_pos_high).astype(np.int)
bb_widths = x_pos_end - x_pos_start
bb_heights = y_pos_low - y_pos_high
# Only include annotations that are within the input spectrogram
valid_inds = np.where(
(x_pos_start >= 0)
& (x_pos_start < op_width)
& (y_pos_low >= 0)
& (y_pos_low < (op_height - 1))
)[0]
ann_aug: AnnotationGroup = {
"start_times": ann["start_times"][valid_inds],
"end_times": ann["end_times"][valid_inds],
"high_freqs": ann["high_freqs"][valid_inds],
"low_freqs": ann["low_freqs"][valid_inds],
"class_ids": ann["class_ids"][valid_inds],
"individual_ids": ann["individual_ids"][valid_inds],
}
ann_aug["x_inds"] = x_pos_start[valid_inds]
ann_aug["y_inds"] = y_pos_low[valid_inds]
# keys = [
# "start_times",
# "end_times",
# "high_freqs",
# "low_freqs",
# "class_ids",
# "individual_ids",
# ]
# for kk in keys:
# ann_aug[kk] = ann[kk][valid_inds]
# if the number of calls is only 1, then it is unique
# TODO would be better if we found these unique calls at the merging stage
if len(ann_aug["individual_ids"]) == 1:
ann_aug["individual_ids"][0] = 0
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
# num classes and "background" class
y_2d_classes: np.ndarray = np.zeros(
(num_classes + 1, op_height, op_width), dtype=np.float32
)
# create 2D ground truth heatmaps
for ii in valid_inds:
draw_gaussian(
y_2d_det[0, :],
(x_pos_start[ii], y_pos_low[ii]),
params["target_sigma"],
)
# draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
cls_id = ann["class_ids"][ii]
if cls_id > -1:
draw_gaussian(
y_2d_classes[cls_id, :],
(x_pos_start[ii], y_pos_low[ii]),
params["target_sigma"],
)
# draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
# be careful as this will have a 1.0 places where we have event but dont know gt class
# this will be masked in training anyway
y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0)
y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...]
y_2d_classes[np.isnan(y_2d_classes)] = 0.0
return y_2d_det, y_2d_size, y_2d_classes, ann_aug
def draw_gaussian(heatmap, center, sigmax, sigmay=None):
# center is (x, y)
# this edits the heatmap inplace
if sigmay is None:
sigmay = sigmax
tmp_size = np.maximum(sigmax, sigmay) * 3
mu_x = int(center[0] + 0.5)
mu_y = int(center[1] + 0.5)
w, h = heatmap.shape[0], heatmap.shape[1]
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0:
return False
size = 2 * tmp_size + 1
x = np.arange(0, size, 1, np.float32)
y = x[:, np.newaxis]
x0 = y0 = size // 2
# g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
g = np.exp(
-((x - x0) ** 2) / (2 * sigmax**2)
- ((y - y0) ** 2) / (2 * sigmay**2)
)
g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
img_x = max(0, ul[0]), min(br[0], h)
img_y = max(0, ul[1]), min(br[1], w)
heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum(
heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]],
g[g_y[0] : g_y[1], g_x[0] : g_x[1]],
)
return True
def pad_aray(ip_array, pad_size):
return np.hstack((ip_array, np.ones(pad_size, dtype=np.int) * -1))
def warp_spec_aug(spec, ann, return_spec_for_viz, params):
# This is messy
# Augment spectrogram by randomly stretch and squeezing
# NOTE this also changes the start and stop time in place
# not taking care of spec for viz
if return_spec_for_viz:
assert False
delta = params["stretch_squeeze_delta"]
op_size = (spec.shape[1], spec.shape[2])
resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
resize_amt = int(spec.shape[2] * resize_fract_r)
if resize_amt >= spec.shape[2]:
spec_r = torch.cat(
(
spec,
torch.zeros(
(1, spec.shape[1], resize_amt - spec.shape[2]),
dtype=spec.dtype,
),
),
2,
)
else:
spec_r = spec[:, :, :resize_amt]
spec = F.interpolate(
spec_r.unsqueeze(0), size=op_size, mode="bilinear", align_corners=False
).squeeze(0)
ann["start_times"] *= 1.0 / resize_fract_r
ann["end_times"] *= 1.0 / resize_fract_r
return spec
def mask_time_aug(spec, params):
# Mask out a random block of time - repeat up to 3 times
# SpecAugment: A Simple Data Augmentation Methodfor Automatic Speech Recognition
fm = torchaudio.transforms.TimeMasking(
int(spec.shape[1] * params["mask_max_time_perc"])
)
for ii in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec
def mask_freq_aug(spec, params):
# Mask out a random frequncy range - repeat up to 3 times
# SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
fm = torchaudio.transforms.FrequencyMasking(
int(spec.shape[1] * params["mask_max_freq_perc"])
)
for ii in range(np.random.randint(1, 4)):
spec = fm(spec)
return spec
def scale_vol_aug(spec, params):
return spec * np.random.random() * params["spec_amp_scaling"]
def echo_aug(audio, sampling_rate, params):
sample_offset = (
int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1
)
audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
return audio
def resample_aug(audio, sampling_rate, params):
sampling_rate_old = sampling_rate
sampling_rate = np.random.choice(params["aug_sampling_rates"])
audio = librosa.resample(
audio,
orig_sr=sampling_rate_old,
target_sr=sampling_rate,
res_type="polyphase",
)
audio = au.pad_audio(
audio,
sampling_rate,
params["fft_win_length"],
params["fft_overlap"],
params["resize_factor"],
params["spec_divide_factor"],
params["spec_train_width"],
)
duration = audio.shape[0] / float(sampling_rate)
return audio, sampling_rate, duration
def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
if sampling_rate != sampling_rate2:
audio2 = librosa.resample(
audio2,
orig_sr=sampling_rate2,
target_sr=sampling_rate,
res_type="polyphase",
)
sampling_rate2 = sampling_rate
if audio2.shape[0] < num_samples:
audio2 = np.hstack(
(
audio2,
np.zeros((num_samples - audio2.shape[0]), dtype=audio2.dtype),
)
)
elif audio2.shape[0] > num_samples:
audio2 = audio2[:num_samples]
return audio2, sampling_rate2
def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
# resample so they are the same
audio2, sampling_rate2 = resample_audio(
audio.shape[0], sampling_rate, audio2, sampling_rate2
)
# # set mean and std to be the same
# audio2 = (audio2 - audio2.mean())
# audio2 = (audio2/audio2.std())*audio.std()
# audio2 = audio2 + audio.mean()
if (
ann["annotated"]
and (ann2["annotated"])
and (sampling_rate2 == sampling_rate)
and (audio.shape[0] == audio2.shape[0])
):
comb_weight = 0.3 + np.random.random() * 0.4
audio = comb_weight * audio + (1 - comb_weight) * audio2
inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
for kk in ann.keys():
# when combining calls from different files, assume they come from different individuals
if kk == "individual_ids":
if (ann[kk] > -1).sum() > 0:
ann2[kk][ann2[kk] > -1] += np.max(ann[kk][ann[kk] > -1]) + 1
if (kk != "class_id_file") and (kk != "annotated"):
ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
return audio, ann
class AudioLoader(torch.utils.data.Dataset):
def __init__(self, data_anns_ip, params, dataset_name=None, is_train=False):
self.data_anns = []
self.is_train = is_train
self.params = params
self.return_spec_for_viz = False
for ii in range(len(data_anns_ip)):
dd = copy.deepcopy(data_anns_ip[ii])
# filter out unused annotation here
filtered_annotations = []
for ii, aa in enumerate(dd["annotation"]):
if "individual" in aa.keys():
aa["individual"] = int(aa["individual"])
# if only one call labeled it has to be from the same individual
if len(dd["annotation"]) == 1:
aa["individual"] = 0
# convert class name into class label
if aa["class"] in self.params["class_names"]:
aa["class_id"] = self.params["class_names"].index(
aa["class"]
)
else:
aa["class_id"] = -1
if aa["class"] not in self.params["classes_to_ignore"]:
filtered_annotations.append(aa)
dd["annotation"] = filtered_annotations
dd["start_times"] = np.array(
[aa["start_time"] for aa in dd["annotation"]]
)
dd["end_times"] = np.array(
[aa["end_time"] for aa in dd["annotation"]]
)
dd["high_freqs"] = np.array(
[float(aa["high_freq"]) for aa in dd["annotation"]]
)
dd["low_freqs"] = np.array(
[float(aa["low_freq"]) for aa in dd["annotation"]]
)
dd["class_ids"] = np.array(
[aa["class_id"] for aa in dd["annotation"]]
).astype(np.int)
dd["individual_ids"] = np.array(
[aa["individual"] for aa in dd["annotation"]]
).astype(np.int)
# file level class name
dd["class_id_file"] = -1
if "class_name" in dd.keys():
if dd["class_name"] in self.params["class_names"]:
dd["class_id_file"] = self.params["class_names"].index(
dd["class_name"]
)
self.data_anns.append(dd)
ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
self.max_num_anns = 2 * np.max(
ann_cnt
) # x2 because we may be combining files during training
print("\n")
if dataset_name is not None:
print("Dataset : " + dataset_name)
if self.is_train:
print("Split type : train")
else:
print("Split type : test")
print("Num files : " + str(len(self.data_anns)))
print("Num calls : " + str(np.sum(ann_cnt)))
def get_file_and_anns(self, index=None):
# if no file specified, choose random one
if index == None:
index = np.random.randint(0, len(self.data_anns))
audio_file = self.data_anns[index]["file_path"]
sampling_rate, audio_raw = au.load_audio(
audio_file,
self.data_anns[index]["time_exp"],
self.params["target_samp_rate"],
self.params["scale_raw_audio"],
)
# copy annotation
ann = {}
ann["annotated"] = self.data_anns[index]["annotated"]
ann["class_id_file"] = self.data_anns[index]["class_id_file"]
keys = [
"start_times",
"end_times",
"high_freqs",
"low_freqs",
"class_ids",
"individual_ids",
]
for kk in keys:
ann[kk] = self.data_anns[index][kk].copy()
# if train then grab a random crop
if self.is_train:
nfft = int(self.params["fft_win_length"] * sampling_rate)
noverlap = int(self.params["fft_overlap"] * nfft)
length_samples = (
self.params["spec_train_width"] * (nfft - noverlap) + noverlap
)
if audio_raw.shape[0] - length_samples > 0:
sample_crop = np.random.randint(
audio_raw.shape[0] - length_samples
)
else:
sample_crop = 0
audio_raw = audio_raw[sample_crop : sample_crop + length_samples]
ann["start_times"] = ann["start_times"] - sample_crop / float(
sampling_rate
)
ann["end_times"] = ann["end_times"] - sample_crop / float(
sampling_rate
)
# pad audio
if self.is_train:
op_spec_target_size = self.params["spec_train_width"]
else:
op_spec_target_size = None
audio_raw = au.pad_audio(
audio_raw,
sampling_rate,
self.params["fft_win_length"],
self.params["fft_overlap"],
self.params["resize_factor"],
self.params["spec_divide_factor"],
op_spec_target_size,
)
duration = audio_raw.shape[0] / float(sampling_rate)
# sort based on time
inds = np.argsort(ann["start_times"])
for kk in ann.keys():
if (kk != "class_id_file") and (kk != "annotated"):
ann[kk] = ann[kk][inds]
return audio_raw, sampling_rate, duration, ann
def __getitem__(self, index):
# load audio file
audio, sampling_rate, duration, ann = self.get_file_and_anns(index)
# augment on raw audio
if self.is_train and self.params["augment_at_train"]:
# augment - combine with random audio file
if (
self.params["augment_at_train_combine"]
and np.random.random() < self.params["aug_prob"]
):
(
audio2,
sampling_rate2,
duration2,
ann2,
) = self.get_file_and_anns()
audio, ann = combine_audio_aug(
audio, sampling_rate, ann, audio2, sampling_rate2, ann2
)
# simulate echo by adding delayed copy of the file
if np.random.random() < self.params["aug_prob"]:
audio = echo_aug(audio, sampling_rate, self.params)
# resample the audio
# if np.random.random() < self.params['aug_prob']:
# audio, sampling_rate, duration = resample_aug(audio, sampling_rate, self.params)
# create spectrogram
spec, spec_for_viz = au.generate_spectrogram(
audio, sampling_rate, self.params, self.return_spec_for_viz
)
rsf = self.params["resize_factor"]
spec_op_shape = (
int(self.params["spec_height"] * rsf),
int(spec.shape[1] * rsf),
)
# resize the spec
spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
spec = F.interpolate(
spec, size=spec_op_shape, mode="bilinear", align_corners=False
).squeeze(0)
# augment spectrogram
if self.is_train and self.params["augment_at_train"]:
if np.random.random() < self.params["aug_prob"]:
spec = scale_vol_aug(spec, self.params)
if np.random.random() < self.params["aug_prob"]:
spec = warp_spec_aug(
spec, ann, self.return_spec_for_viz, self.params
)
if np.random.random() < self.params["aug_prob"]:
spec = mask_time_aug(spec, self.params)
if np.random.random() < self.params["aug_prob"]:
spec = mask_freq_aug(spec, self.params)
outputs = {}
outputs["spec"] = spec
if self.return_spec_for_viz:
outputs["spec_for_viz"] = torch.from_numpy(spec_for_viz).unsqueeze(
0
)
# create ground truth heatmaps
(
outputs["y_2d_det"],
outputs["y_2d_size"],
outputs["y_2d_classes"],
ann_aug,
) = generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, self.params)
# hack to get around requirement that all vectors are the same length in
# the output batch
pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
outputs["is_valid"] = pad_aray(
np.ones(len(ann_aug["individual_ids"])), pad_size
)
keys = [
"class_ids",
"individual_ids",
"x_inds",
"y_inds",
"start_times",
"end_times",
"low_freqs",
"high_freqs",
]
for kk in keys:
outputs[kk] = pad_aray(ann_aug[kk], pad_size)
# convert to pytorch
for kk in outputs.keys():
if type(outputs[kk]) != torch.Tensor:
outputs[kk] = torch.from_numpy(outputs[kk])
# scalars
outputs["class_id_file"] = ann["class_id_file"]
outputs["annotated"] = ann["annotated"]
outputs["duration"] = duration
outputs["sampling_rate"] = sampling_rate
outputs["file_id"] = index
return outputs
def __len__(self):
return len(self.data_anns)

View File

@ -1,14 +1,20 @@
import numpy as np import numpy as np
from sklearn.metrics import auc, roc_curve from sklearn.metrics import (
accuracy_score,
auc,
balanced_accuracy_score,
roc_curve,
)
def compute_error_auc(op_str, gt, pred, prob): def compute_error_auc(op_str, gt, pred, prob):
# classification error # classification error
pred_int = (pred > prob).astype(np.int32) pred_int = (pred > prob).astype(np.int)
class_acc = (pred_int == gt).mean() * 100.0 class_acc = (pred_int == gt).mean() * 100.0
# ROC - area under curve # ROC - area under curve
fpr, tpr, _ = roc_curve(gt, pred) fpr, tpr, thresholds = roc_curve(gt, pred)
roc_auc = auc(fpr, tpr) roc_auc = auc(fpr, tpr)
print( print(
@ -19,6 +25,7 @@ def compute_error_auc(op_str, gt, pred, prob):
def calc_average_precision(recall, precision): def calc_average_precision(recall, precision):
precision[np.isnan(precision)] = 0 precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0 recall[np.isnan(recall)] = 0
@ -84,6 +91,7 @@ def compute_pre_rec(
pred_class = [] pred_class = []
file_ids = [] file_ids = []
for pid, pp in enumerate(preds): for pid, pp in enumerate(preds):
# filter predicted calls that are too near the start or end of the file # filter predicted calls that are too near the start or end of the file
file_dur = gts[pid]["duration"] file_dur = gts[pid]["duration"]
valid_inds = (pp["start_times"] >= ignore_start_end) & ( valid_inds = (pp["start_times"] >= ignore_start_end) & (
@ -121,7 +129,7 @@ def compute_pre_rec(
file_ids.append([pid] * valid_inds.sum()) file_ids.append([pid] * valid_inds.sum())
confidence = np.hstack(confidence) confidence = np.hstack(confidence)
file_ids = np.hstack(file_ids).astype(int) file_ids = np.hstack(file_ids).astype(np.int)
pred_boxes = np.vstack(pred_boxes) pred_boxes = np.vstack(pred_boxes)
if len(pred_class) > 0: if len(pred_class) > 0:
pred_class = np.hstack(pred_class) pred_class = np.hstack(pred_class)
@ -133,6 +141,7 @@ def compute_pre_rec(
gt_generic_class = [] gt_generic_class = []
num_positives = 0 num_positives = 0
for gg in gts: for gg in gts:
# filter ground truth calls that are too near the start or end of the file # filter ground truth calls that are too near the start or end of the file
file_dur = gg["duration"] file_dur = gg["duration"]
valid_inds = (gg["start_times"] >= ignore_start_end) & ( valid_inds = (gg["start_times"] >= ignore_start_end) & (
@ -141,7 +150,8 @@ def compute_pre_rec(
# note, files with the incorrect duration will cause a problem # note, files with the incorrect duration will cause a problem
if (gg["start_times"] > file_dur).sum() > 0: if (gg["start_times"] > file_dur).sum() > 0:
raise ValueError(f"Error: file duration incorrect for {gg['id']}") print("Error: file duration incorrect for", gg["id"])
assert False
boxes = np.vstack( boxes = np.vstack(
( (
@ -187,8 +197,6 @@ def compute_pre_rec(
gt_id = file_ids[ind] gt_id = file_ids[ind]
valid_det = False valid_det = False
det_ind = 0
if gt_boxes[gt_id].shape[0] > 0: if gt_boxes[gt_id].shape[0] > 0:
# compute overlap # compute overlap
valid_det, det_ind = compute_affinity_1d( valid_det, det_ind = compute_affinity_1d(
@ -197,6 +205,7 @@ def compute_pre_rec(
# valid detection that has not already been assigned # valid detection that has not already been assigned
if valid_det and (gt_assigned[gt_id][det_ind] == 0): if valid_det and (gt_assigned[gt_id][det_ind] == 0):
count_as_true_pos = True count_as_true_pos = True
if eval_mode == "top_class" and ( if eval_mode == "top_class" and (
gt_class[gt_id][det_ind] != pred_class[ind] gt_class[gt_id][det_ind] != pred_class[ind]
@ -218,7 +227,7 @@ def compute_pre_rec(
# store threshold values - used for plotting # store threshold values - used for plotting
conf_sorted = np.sort(confidence)[::-1][valid_inds] conf_sorted = np.sort(confidence)[::-1][valid_inds]
thresholds = np.linspace(0.1, 0.9, 9) thresholds = np.linspace(0.1, 0.9, 9)
thresholds_inds = np.zeros(len(thresholds), dtype=int) thresholds_inds = np.zeros(len(thresholds), dtype=np.int)
for ii, tt in enumerate(thresholds): for ii, tt in enumerate(thresholds):
thresholds_inds[ii] = np.argmin(conf_sorted > tt) thresholds_inds[ii] = np.argmin(conf_sorted > tt)
thresholds_inds[thresholds_inds == 0] = -1 thresholds_inds[thresholds_inds == 0] = -1
@ -330,7 +339,7 @@ def compute_file_accuracy(gts, preds, num_classes):
).mean(0) ).mean(0)
best_thresh = np.argmax(acc_per_thresh) best_thresh = np.argmax(acc_per_thresh)
best_acc = acc_per_thresh[best_thresh] best_acc = acc_per_thresh[best_thresh]
pred_valid = pred_valid_all[:, best_thresh].astype(int).tolist() pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist()
res = {} res = {}
res["num_valid_files"] = len(gt_valid) res["num_valid_files"] = len(gt_valid)

View File

@ -0,0 +1,63 @@
import torch
import torch.nn.functional as F
def bbox_size_loss(pred_size, gt_size):
"""
Bounding box size loss. Only compute loss where there is a bounding box.
"""
gt_size_mask = (gt_size > 0).float()
return F.l1_loss(pred_size * gt_size_mask, gt_size, reduction="sum") / (
gt_size_mask.sum() + 1e-5
)
def focal_loss(pred, gt, weights=None, valid_mask=None):
"""
Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
pred (batch x c x h x w)
gt (batch x c x h x w)
"""
eps = 1e-5
beta = 4
alpha = 2
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
neg_loss = (
torch.log(1 - pred + eps)
* torch.pow(pred, alpha)
* torch.pow(1 - gt, beta)
* neg_inds
)
if weights is not None:
pos_loss = pos_loss * weights
# neg_loss = neg_loss*weights
if valid_mask is not None:
pos_loss = pos_loss * valid_mask
neg_loss = neg_loss * valid_mask
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
num_pos = pos_inds.float().sum()
if num_pos == 0:
loss = -neg_loss
else:
loss = -(pos_loss + neg_loss) / num_pos
return loss
def mse_loss(pred, gt, weights=None, valid_mask=None):
"""
Mean squared error loss.
"""
if valid_mask is None:
op = ((gt - pred) ** 2).mean()
else:
op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
return op

View File

@ -1,21 +1,18 @@
## How to train a model from scratch ## How to train a model from scratch
`python train_model.py data_dir annotation_dir` e.g.
> **Warning**
> This code in currently broken. Will fix soon, stay tuned.
`python train_model.py data_dir annotation_dir` e.g.
`python train_model.py /data1/bat_data/data/ /data1/bat_data/annotations/anns/` `python train_model.py /data1/bat_data/data/ /data1/bat_data/annotations/anns/`
More comprehensive instructions are provided in the finetune directory. More comprehensive instructions are provided in the finetune directory.
## Training on your own data ## Training on your own data
You can either use the finetuning scripts to finetune from an existing training dataset. Follow the instructions in the `../finetune/` directory. You can either use the finetuning scripts to finetune from an existing training dataset. Follow the instructions in the `../finetune/` directory.
Alternatively, you can train from scratch. First, you will need to create your own annotation file (like in the finetune example), and then you will need to edit `train_split.py` to add your new dataset and specify which combination of files you want to train on. Alternatively, you can train from scratch. First, you will need to create your own annotation file (like in the finetune example), and then you will need to edit `train_split.py` to add your new dataset and specify which combination of files you want to train on.
Note, if training from scratch and you want to include the existing data, you may need to set all the class names to the generic class name ('Bat') so that the existing species are not added to your model, but instead just used to help perform the bat/not bat task. Note, if training from scratch and you want to include the existing data, you may need to set all the class names to the generic class name ('Bat') so that the existing species are not added to your model, but instead just used to help perform the bat/not bat task.
## Additional notes ## Additional notes
Having blank files with no bats in them is also useful, just make sure that the annotation files lists them as not being annotated (i.e. `is_annotated=True`). Having blank files with no bats in them is also useful, just make sure that the annotation files lists them as not being annotated (i.e. `is_annotated=True`).
Training will be slow without a GPU. Training will be slow without a GPU.

View File

@ -2,17 +2,16 @@ import argparse
import json import json
import warnings import warnings
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
import batdetect2.detector.post_process as pp
import batdetect2.train.audio_dataloader as adl import batdetect2.train.audio_dataloader as adl
import batdetect2.train.evaluate as evl import batdetect2.train.evaluate as evl
import batdetect2.train.train_split as ts import batdetect2.train.train_split as ts
import batdetect2.train.train_utils as tu import batdetect2.train.train_utils as tu
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data
from torch.optim.lr_scheduler import CosineAnnealingLR
import batdetect2.detector.post_process as pp
import batdetect2.utils.plot_utils as pu import batdetect2.utils.plot_utils as pu
from batdetect2.detector import models, parameters from batdetect2.detector import models, parameters
from batdetect2.train import losses from batdetect2.train import losses
@ -30,7 +29,7 @@ def save_images_batch(model, data_loader, params):
ind = 0 # first image in each batch ind = 0 # first image in each batch
with torch.no_grad(): with torch.no_grad():
for inputs in data_loader: for batch_idx, inputs in enumerate(data_loader):
data = inputs["spec"].to(params["device"]) data = inputs["spec"].to(params["device"])
outputs = model(data) outputs = model(data)
@ -82,12 +81,7 @@ def save_image(
def loss_fun( def loss_fun(
outputs, outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
gt_det,
gt_size,
gt_class,
det_criterion,
params,
): ):
# detection loss # detection loss
loss = params["det_loss_weight"] * det_criterion( loss = params["det_loss_weight"] * det_criterion(
@ -110,13 +104,7 @@ def loss_fun(
def train( def train(
model, model, epoch, data_loader, det_criterion, optimizer, scheduler, params
epoch,
data_loader,
det_criterion,
optimizer,
scheduler,
params,
): ):
model.train() model.train()
@ -321,7 +309,7 @@ def select_model(params):
resize_factor=params["resize_factor"], resize_factor=params["resize_factor"],
) )
else: else:
raise ValueError("No valid network specified") print("No valid network specified")
return model return model
@ -331,9 +319,9 @@ def main():
params = parameters.get_params(True) params = parameters.get_params(True)
if torch.cuda.is_available(): if torch.cuda.is_available():
params.device = "cuda" params["device"] = "cuda"
else: else:
params.device = "cpu" params["device"] = "cpu"
# setup arg parser and populate it with exiting parameters - will not work with lists # setup arg parser and populate it with exiting parameters - will not work with lists
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -361,16 +349,13 @@ def main():
default="Rhinolophus ferrumequinum;Rhinolophus hipposideros", default="Rhinolophus ferrumequinum;Rhinolophus hipposideros",
help='Will set low and high frequency the same for these classes. Separate names with ";"', help='Will set low and high frequency the same for these classes. Separate names with ";"',
) )
for key, val in params.items(): for key, val in params.items():
parser.add_argument("--" + key, type=type(val), default=val) parser.add_argument("--" + key, type=type(val), default=val)
params = vars(parser.parse_args()) params = vars(parser.parse_args())
# save notes file # save notes file
if params["notes"] != "": if params["notes"] != "":
tu.write_notes_file( tu.write_notes_file(params["experiment"] + "notes.txt", params["notes"])
params["experiment"] + "notes.txt", params["notes"]
)
# load the training and test meta data - there are different splits defined # load the training and test meta data - there are different splits defined
train_sets, test_sets = ts.get_train_test_data( train_sets, test_sets = ts.get_train_test_data(
@ -389,11 +374,15 @@ def main():
for tt in train_sets: for tt in train_sets:
print(tt["ann_path"]) print(tt["ann_path"])
classes_to_ignore = params["classes_to_ignore"] + params["generic_class"] classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
data_train = tu.load_set_of_anns( (
data_train,
params["class_names"],
params["class_inv_freq"],
) = tu.load_set_of_anns(
train_sets, train_sets,
classes_to_ignore=classes_to_ignore, classes_to_ignore,
events_of_interest=params["events_of_interest"], params["events_of_interest"],
convert_to_genus=params["convert_to_genus"], params["convert_to_genus"],
) )
params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping( params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
params["class_names"] params["class_names"]
@ -426,12 +415,11 @@ def main():
print("\nTesting on:") print("\nTesting on:")
for tt in test_sets: for tt in test_sets:
print(tt["ann_path"]) print(tt["ann_path"])
data_test, _, _ = tu.load_set_of_anns(
data_test = tu.load_set_of_anns(
test_sets, test_sets,
classes_to_ignore=classes_to_ignore, classes_to_ignore,
events_of_interest=params["events_of_interest"], params["events_of_interest"],
convert_to_genus=params["convert_to_genus"], params["convert_to_genus"],
) )
data_train = tu.remove_dupes(data_train, data_test) data_train = tu.remove_dupes(data_train, data_test)
test_dataset = adl.AudioLoader(data_test, params, is_train=False) test_dataset = adl.AudioLoader(data_test, params, is_train=False)
@ -459,13 +447,10 @@ def main():
scheduler = CosineAnnealingLR( scheduler = CosineAnnealingLR(
optimizer, params["num_epochs"] * len(train_loader) optimizer, params["num_epochs"] * len(train_loader)
) )
if params["train_loss"] == "mse": if params["train_loss"] == "mse":
det_criterion = losses.mse_loss det_criterion = losses.mse_loss
elif params["train_loss"] == "focal": elif params["train_loss"] == "focal":
det_criterion = losses.focal_loss det_criterion = losses.focal_loss
else:
raise ValueError("No valid loss specified")
# save parameters to file # save parameters to file
with open(params["experiment"] + "params.json", "w") as da: with open(params["experiment"] + "params.json", "w") as da:

View File

@ -10,12 +10,13 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra) train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra)
else: else:
print("Split not defined") print("Split not defined")
raise AssertionError() assert False
return train_sets, test_sets return train_sets, test_sets
def split_diff(ann_dir, wav_dir, load_extra=True): def split_diff(ann_dir, wav_dir, load_extra=True):
train_sets = [] train_sets = []
if load_extra: if load_extra:
train_sets.append( train_sets.append(
@ -143,6 +144,7 @@ def split_diff(ann_dir, wav_dir, load_extra=True):
def split_same(ann_dir, wav_dir, load_extra=True): def split_same(ann_dir, wav_dir, load_extra=True):
train_sets = [] train_sets = []
if load_extra: if load_extra:
train_sets.append( train_sets.append(

View File

@ -0,0 +1,207 @@
import glob
import json
import numpy as np
def write_notes_file(file_name, text):
with open(file_name, "a") as da:
da.write(text + "\n")
def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path):
ddict = {
"dataset_name": dataset_name,
"is_test": is_test,
"is_binary": False,
"ann_path": ann_path,
"wav_path": wav_path,
}
return ddict
def get_short_class_names(class_names, str_len=3):
class_names_short = []
for cc in class_names:
class_names_short.append(
" ".join([sp[:str_len] for sp in cc.split(" ")])
)
return class_names_short
def remove_dupes(data_train, data_test):
test_ids = [dd["id"] for dd in data_test]
data_train_prune = []
for aa in data_train:
if aa["id"] not in test_ids:
data_train_prune.append(aa)
diff = len(data_train) - len(data_train_prune)
if diff != 0:
print(diff, "items removed from train set")
return data_train_prune
def get_genus_mapping(class_names):
genus_names, genus_mapping = np.unique(
[cc.split(" ")[0] for cc in class_names], return_inverse=True
)
return genus_names.tolist(), genus_mapping.tolist()
def standardize_low_freq(data, class_of_interest):
# address the issue of highly variable low frequency annotations
# this often happens for contstant frequency calls
# for the class of interest sets the low and high freq to be the dataset mean
low_freqs = []
high_freqs = []
for dd in data:
for aa in dd["annotation"]:
if aa["class"] == class_of_interest:
low_freqs.append(aa["low_freq"])
high_freqs.append(aa["high_freq"])
low_mean = np.mean(low_freqs)
high_mean = np.mean(high_freqs)
assert low_mean < high_mean
print("\nStandardizing low and high frequency for:")
print(class_of_interest)
print("low: ", round(low_mean, 2))
print("high: ", round(high_mean, 2))
# only set the low freq, high stays the same
# assumes that low_mean < high_mean
for dd in data:
for aa in dd["annotation"]:
if aa["class"] == class_of_interest:
aa["low_freq"] = low_mean
if aa["high_freq"] < low_mean:
aa["high_freq"] = high_mean
return data
def load_set_of_anns(
data,
classes_to_ignore=[],
events_of_interest=None,
convert_to_genus=False,
verbose=True,
list_of_anns=False,
filter_issues=False,
name_replace=False,
):
# load the annotations
anns = []
if list_of_anns:
# path to list of individual json files
anns.extend(load_anns_from_path(data["ann_path"], data["wav_path"]))
else:
# dictionary of datasets
for dd in data:
anns.extend(load_anns(dd["ann_path"], dd["wav_path"]))
# discarding unannoated files
anns = [aa for aa in anns if aa["annotated"] is True]
# filter files that have annotation issues - is the input is a dictionary of
# datasets, this will lilely have already been done
if filter_issues:
anns = [aa for aa in anns if aa["issues"] is False]
# check for some basic formatting errors with class names
for ann in anns:
for aa in ann["annotation"]:
aa["class"] = aa["class"].strip()
# only load specified events - i.e. types of calls
if events_of_interest is not None:
for ann in anns:
filtered_events = []
for aa in ann["annotation"]:
if aa["event"] in events_of_interest:
filtered_events.append(aa)
ann["annotation"] = filtered_events
# change class names
# replace_names will be a dictionary mapping input name to output
if type(name_replace) is dict:
for ann in anns:
for aa in ann["annotation"]:
if aa["class"] in name_replace:
aa["class"] = name_replace[aa["class"]]
# convert everything to genus name
if convert_to_genus:
for ann in anns:
for aa in ann["annotation"]:
aa["class"] = aa["class"].split(" ")[0]
# get unique class names
class_names_all = []
for ann in anns:
for aa in ann["annotation"]:
if aa["class"] not in classes_to_ignore:
class_names_all.append(aa["class"])
class_names, class_cnts = np.unique(class_names_all, return_counts=True)
class_inv_freq = class_cnts.sum() / (
len(class_names) * class_cnts.astype(np.float32)
)
if verbose:
print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5
for cc in range(len(class_names)):
print(
str(cc).ljust(5)
+ class_names[cc].ljust(str_len)
+ str(class_cnts[cc])
)
if len(classes_to_ignore) == 0:
return anns
else:
return anns, class_names.tolist(), class_inv_freq.tolist()
def load_anns(ann_file_name, raw_audio_dir):
with open(ann_file_name) as da:
anns = json.load(da)
for aa in anns:
aa["file_path"] = raw_audio_dir + aa["id"]
return anns
def load_anns_from_path(ann_file_dir, raw_audio_dir):
files = glob.glob(ann_file_dir + "*.json")
anns = []
for ff in files:
with open(ff) as da:
ann = json.load(da)
ann["file_path"] = raw_audio_dir + ann["id"]
anns.append(ann)
return anns
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count

View File

@ -1,14 +1,29 @@
"""Types used in the code base.""" """Types used in the code base."""
import sys from typing import List, NamedTuple, Optional, Union, Any, BinaryIO
from typing import Any, NamedTuple, Protocol, TypedDict
import audioread
import os
import soundfile as sf
import numpy as np import numpy as np
import torch import torch
if sys.version_info >= (3, 11): try:
from typing import NotRequired from typing import TypedDict
else: except ImportError:
from typing_extensions import TypedDict
try:
from typing import Protocol
except ImportError:
from typing_extensions import Protocol
try:
from typing import NotRequired # type: ignore
except ImportError:
from typing_extensions import NotRequired from typing_extensions import NotRequired
@ -16,7 +31,8 @@ __all__ = [
"Annotation", "Annotation",
"DetectionModel", "DetectionModel",
"FeatureExtractionParameters", "FeatureExtractionParameters",
"FileAnnotation", "FeatureExtractor",
"FileAnnotations",
"ModelOutput", "ModelOutput",
"ModelParameters", "ModelParameters",
"NonMaximumSuppressionConfig", "NonMaximumSuppressionConfig",
@ -26,9 +42,11 @@ __all__ = [
"ResultParams", "ResultParams",
"RunResults", "RunResults",
"SpectrogramParameters", "SpectrogramParameters",
"AudioLoaderAnnotationGroup",
] ]
AudioPath = Union[
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
]
class SpectrogramParameters(TypedDict): class SpectrogramParameters(TypedDict):
"""Parameters for generating spectrograms.""" """Parameters for generating spectrograms."""
@ -82,11 +100,8 @@ class ModelParameters(TypedDict):
resize_factor: float resize_factor: float
"""Resize factor.""" """Resize factor."""
class_names: list[str] class_names: List[str]
"""Class names. """Class names. The model is trained to detect these classes."""
The model is trained to detect these classes.
"""
DictWithClass = TypedDict("DictWithClass", {"class": str}) DictWithClass = TypedDict("DictWithClass", {"class": str})
@ -95,8 +110,8 @@ DictWithClass = TypedDict("DictWithClass", {"class": str})
class Annotation(DictWithClass): class Annotation(DictWithClass):
"""Format of annotations. """Format of annotations.
This is the format of a single annotation as expected by the This is the format of a single annotation as expected by the annotation
annotation tool. tool.
""" """
start_time: float start_time: float
@ -105,10 +120,10 @@ class Annotation(DictWithClass):
end_time: float end_time: float
"""End time in seconds.""" """End time in seconds."""
low_freq: float low_freq: int
"""Low frequency in Hz.""" """Low frequency in Hz."""
high_freq: float high_freq: int
"""High frequency in Hz.""" """High frequency in Hz."""
class_prob: float class_prob: float
@ -123,11 +138,8 @@ class Annotation(DictWithClass):
event: str event: str
"""Type of detected event.""" """Type of detected event."""
class_id: NotRequired[int]
"""Numeric ID for the class of the annotation."""
class FileAnnotations(TypedDict):
class FileAnnotation(TypedDict):
"""Format of results. """Format of results.
This is the format of the results expected by the annotation tool. This is the format of the results expected by the annotation tool.
@ -149,41 +161,41 @@ class FileAnnotation(TypedDict):
"""Time expansion factor.""" """Time expansion factor."""
class_name: str class_name: str
"""Class predicted at file level.""" """Class predicted at file level"""
notes: str notes: str
"""Notes of file.""" """Notes of file."""
annotation: list[Annotation] annotation: List[Annotation]
"""List of annotations.""" """List of annotations."""
class RunResults(TypedDict): class RunResults(TypedDict):
"""Run results.""" """Run results."""
pred_dict: FileAnnotation pred_dict: FileAnnotations
"""Predictions in the format expected by the annotation tool.""" """Predictions in the format expected by the annotation tool."""
spec_feats: NotRequired[list[np.ndarray]] spec_feats: NotRequired[List[np.ndarray]]
"""Spectrogram features.""" """Spectrogram features."""
spec_feat_names: NotRequired[list[str]] spec_feat_names: NotRequired[List[str]]
"""Spectrogram feature names.""" """Spectrogram feature names."""
cnn_feats: NotRequired[list[np.ndarray]] cnn_feats: NotRequired[List[np.ndarray]]
"""CNN features.""" """CNN features."""
cnn_feat_names: NotRequired[list[str]] cnn_feat_names: NotRequired[List[str]]
"""CNN feature names.""" """CNN feature names."""
spec_slices: NotRequired[list[np.ndarray]] spec_slices: NotRequired[List[np.ndarray]]
"""Spectrogram slices.""" """Spectrogram slices."""
class ResultParams(TypedDict): class ResultParams(TypedDict):
"""Result parameters.""" """Result parameters."""
class_names: list[str] class_names: List[str]
"""Class names.""" """Class names."""
spec_features: bool spec_features: bool
@ -230,13 +242,13 @@ class ProcessingConfiguration(TypedDict):
scale_raw_audio: bool scale_raw_audio: bool
"""Whether to scale the raw audio to be between -1 and 1.""" """Whether to scale the raw audio to be between -1 and 1."""
class_names: list[str] class_names: List[str]
"""Names of the classes the model can detect.""" """Names of the classes the model can detect."""
detection_threshold: float detection_threshold: float
"""Threshold for detection probability.""" """Threshold for detection probability."""
time_expansion: float | None time_expansion: Optional[float]
"""Time expansion factor of the processed recordings.""" """Time expansion factor of the processed recordings."""
top_n: int top_n: int
@ -245,7 +257,7 @@ class ProcessingConfiguration(TypedDict):
return_raw_preds: bool return_raw_preds: bool
"""Whether to return raw predictions.""" """Whether to return raw predictions."""
max_duration: float | None max_duration: Optional[float]
"""Maximum duration of audio file to process in seconds.""" """Maximum duration of audio file to process in seconds."""
nms_kernel_size: int nms_kernel_size: int
@ -386,9 +398,9 @@ class PredictionResults(TypedDict):
class DetectionModel(Protocol): class DetectionModel(Protocol):
"""Protocol for detection models. """Protocol for detection models.
This protocol is used to define the interface for the detection This protocol is used to define the interface for the detection models.
models. This allows us to use the same code for training and This allows us to use the same code for training and inference, even
inference, even though the models are different. though the models are different.
""" """
num_classes: int num_classes: int
@ -408,14 +420,16 @@ class DetectionModel(Protocol):
def forward( def forward(
self, self,
spec: torch.Tensor, ip: torch.Tensor,
return_feats: bool = False,
) -> ModelOutput: ) -> ModelOutput:
"""Forward pass of the model.""" """Forward pass of the model."""
... ...
def __call__( def __call__(
self, self,
spec: torch.Tensor, ip: torch.Tensor,
return_feats: bool = False,
) -> ModelOutput: ) -> ModelOutput:
"""Forward pass of the model.""" """Forward pass of the model."""
... ...
@ -462,7 +476,7 @@ class FeatureExtractionParameters(TypedDict):
class HeatmapParameters(TypedDict): class HeatmapParameters(TypedDict):
"""Parameters that control the heatmap generation function.""" """Parameters that control the heatmap generation function."""
class_names: list[str] class_names: List[str]
fft_win_length: float fft_win_length: float
"""Length of the FFT window in seconds.""" """Length of the FFT window in seconds."""
@ -480,10 +494,8 @@ class HeatmapParameters(TypedDict):
"""Maximum frequency to consider in Hz.""" """Maximum frequency to consider in Hz."""
target_sigma: float target_sigma: float
"""Sigma for the Gaussian kernel. """Sigma for the Gaussian kernel. Controls the width of the points in
the heatmap."""
Controls the width of the points in the heatmap.
"""
class AnnotationGroup(TypedDict): class AnnotationGroup(TypedDict):
@ -511,15 +523,6 @@ class AnnotationGroup(TypedDict):
individual_ids: np.ndarray individual_ids: np.ndarray
"""Individual IDs of the annotations.""" """Individual IDs of the annotations."""
annotated: NotRequired[bool]
"""Wether the annotation group is complete or not.
Usually annotation groups are associated to a single audio clip. If
the annotation group is complete, it means that all relevant sound
events have been annotated. If it is not complete, it means that
some sound events might not have been annotated.
"""
x_inds: NotRequired[np.ndarray] x_inds: NotRequired[np.ndarray]
"""X coordinate of the annotations in the spectrogram.""" """X coordinate of the annotations in the spectrogram."""
@ -527,87 +530,9 @@ class AnnotationGroup(TypedDict):
"""Y coordinate of the annotations in the spectrogram.""" """Y coordinate of the annotations in the spectrogram."""
class AudioLoaderAnnotationGroup(TypedDict):
"""Group of annotation items for the training audio loader.
This class is used to store the annotations for the training audio
loader. It inherits from `AnnotationGroup` and `FileAnnotations`.
"""
id: str
duration: float
issues: bool
file_path: str
time_exp: float
class_name: str
notes: str
start_times: np.ndarray
end_times: np.ndarray
low_freqs: np.ndarray
high_freqs: np.ndarray
class_ids: np.ndarray
individual_ids: np.ndarray
x_inds: np.ndarray
y_inds: np.ndarray
annotation: list[Annotation]
annotated: bool
class_id_file: int
"""ID of the class of the file."""
class AudioLoaderParameters(TypedDict):
class_names: list[str]
classes_to_ignore: list[str]
target_samp_rate: int
scale_raw_audio: bool
fft_win_length: float
fft_overlap: float
spec_train_width: int
resize_factor: float
spec_divide_factor: int
augment_at_train: bool
augment_at_train_combine: bool
aug_prob: float
spec_height: int
echo_max_delay: float
spec_amp_scaling: float
stretch_squeeze_delta: float
mask_max_time_perc: float
mask_max_freq_perc: float
max_freq: float
min_freq: float
spec_scale: str
denoise_spec_avg: bool
max_scale_spec: bool
target_sigma: float
class FeatureExtractor(Protocol): class FeatureExtractor(Protocol):
def __call__( """Protocol for feature extractors."""
self,
prediction: Prediction,
**kwargs: Any,
) -> float: ...
def __call__(self, prediction: Prediction, **kwargs) -> Union[float, int]:
class DatasetDict(TypedDict): """Extract features from a prediction."""
"""Dataset dictionary. ...
This is the format of the dictionary that contains the dataset
information.
"""
dataset_name: str
"""Name of the dataset."""
is_test: bool
"""Whether the dataset is a test set."""
is_binary: bool
"""Whether the dataset is binary."""
ann_path: str
"""Path to the annotations."""
wav_path: str
"""Path to the audio files."""

View File

@ -1,16 +1,24 @@
import warnings import warnings
from typing import Optional, Tuple, Union, Any, BinaryIO
from ..types import AudioPath
import librosa import librosa
import librosa.core.spectrum import librosa.core.spectrum
import numpy as np import numpy as np
import torch import torch
import audioread
import os
import soundfile as sf
from batdetect2.detector import parameters from batdetect2.detector import parameters
from . import wavfile from . import wavfile
__all__ = [ __all__ = [
"load_audio", "load_audio",
"load_audio_and_samplerate",
"generate_spectrogram", "generate_spectrogram",
"pad_audio", "pad_audio",
] ]
@ -77,7 +85,7 @@ def generate_spectrogram(
spec = np.vstack( spec = np.vstack(
(np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec) (np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)
) )
spec = spec[-max_freq : spec.shape[0] - min_freq, :] spec_cropped = spec[-max_freq : spec.shape[0] - min_freq, :]
if params["spec_scale"] == "log": if params["spec_scale"] == "log":
log_scaling = ( log_scaling = (
@ -89,7 +97,7 @@ def generate_spectrogram(
np.abs( np.abs(
np.hanning( np.hanning(
int(params["fft_win_length"] * sampling_rate) int(params["fft_win_length"] * sampling_rate)
).astype(np.float32) )
) )
** 2 ** 2
).sum() ).sum()
@ -97,9 +105,9 @@ def generate_spectrogram(
) )
# log_scaling = (1.0 / sampling_rate)*0.1 # log_scaling = (1.0 / sampling_rate)*0.1
# log_scaling = (1.0 / sampling_rate)*10e4 # log_scaling = (1.0 / sampling_rate)*10e4
spec = np.log1p(log_scaling * spec) spec = np.log1p(log_scaling * spec_cropped)
elif params["spec_scale"] == "pcen": elif params["spec_scale"] == "pcen":
spec = pcen(spec, sampling_rate) spec = pcen(spec_cropped, sampling_rate)
elif params["spec_scale"] == "none": elif params["spec_scale"] == "none":
pass pass
@ -133,62 +141,81 @@ def generate_spectrogram(
).sum() ).sum()
) )
) )
spec_for_viz = np.log1p(log_scaling * spec).astype(np.float32) spec_for_viz = np.log1p(log_scaling * spec_cropped).astype(np.float32)
else: else:
spec_for_viz = None spec_for_viz = None
return spec, spec_for_viz return spec, spec_for_viz
def load_audio( def load_audio(
audio_file: str, path: AudioPath,
time_exp_fact: float, time_exp_fact: float,
target_samp_rate: int, target_samp_rate: int,
scale: bool = False, scale: bool = False,
max_duration: float | None = None, max_duration: Optional[float] = None,
) -> tuple[int, np.ndarray]: ) -> Tuple[int, np.ndarray ]:
"""Load an audio file and resample it to the target sampling rate. """Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration. The audio is also scaled to [-1, 1] and clipped to the maximum duration.
Only mono files are supported. Only mono files are supported.
Parameters Args:
---------- path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): path to the input file.
audio_file: str target_samp_rate (int): Target sampling rate.
Path to the audio file. scale (bool): Whether to scale the audio to [-1, 1].
target_samp_rate: int max_duration (float): Maximum duration of the audio in seconds.
Target sampling rate.
scale: bool, optional
Whether to scale the audio to [-1, 1]. Default: False.
max_duration: float, optional
Maximum duration of the audio in seconds. Defaults to None.
If provided, the audio is clipped to this duration.
Returns Returns:
------- sampling_rate: The sampling rate of the audio.
sampling_rate: int audio_raw: The audio signal in a numpy array.
The sampling rate of the audio.
audio_raw: np.ndarray
The audio signal in a numpy array.
Raises Raises:
------ ValueError: If the audio file is stereo.
ValueError: If the audio file is stereo.
"""
sample_rate, audio_data, _ = load_audio_and_samplerate(path, time_exp_fact, target_samp_rate, scale, max_duration)
return sample_rate, audio_data
def load_audio_and_samplerate(
path: AudioPath,
time_exp_fact: float,
target_samp_rate: int,
scale: bool = False,
max_duration: Optional[float] = None,
) -> Tuple[int, np.ndarray, Union[float, int]]:
"""Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
Only mono files are supported.
Args:
path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): path to the input file.
target_samp_rate (int): Target sampling rate.
scale (bool): Whether to scale the audio to [-1, 1].
max_duration (float): Maximum duration of the audio in seconds.
Returns:
sampling_rate: The sampling rate of the audio.
audio_raw: The audio signal in a numpy array.
file_sampling_rate: The original sampling rate of the audio
Raises:
ValueError: If the audio file is stereo.
""" """
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning) warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
# sampling_rate, audio_raw = wavfile.read(audio_file) # sampling_rate, audio_raw = wavfile.read(audio_file)
audio_raw, sampling_rate = librosa.load( audio_raw, file_sampling_rate = librosa.load(
audio_file, path,
sr=None, sr=None,
dtype=np.float32, dtype=np.float32,
) )
if len(audio_raw.shape) > 1: if len(audio_raw.shape) > 1:
raise ValueError("Currently does not handle stereo files") raise ValueError("Currently does not handle stereo files")
sampling_rate = sampling_rate * time_exp_fact sampling_rate = file_sampling_rate * time_exp_fact
# resample - need to do this after correcting for time expansion # resample - need to do this after correcting for time expansion
sampling_rate_old = sampling_rate sampling_rate_old = sampling_rate
@ -216,7 +243,7 @@ def load_audio(
audio_raw = audio_raw - audio_raw.mean() audio_raw = audio_raw - audio_raw.mean()
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6) audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
return sampling_rate, audio_raw return sampling_rate, audio_raw, file_sampling_rate
def compute_spectrogram_width( def compute_spectrogram_width(
@ -240,7 +267,7 @@ def pad_audio(
window_overlap: float = parameters.FFT_OVERLAP, window_overlap: float = parameters.FFT_OVERLAP,
resize_factor: float = parameters.RESIZE_FACTOR, resize_factor: float = parameters.RESIZE_FACTOR,
divide_factor: int = parameters.SPEC_DIVIDE_FACTOR, divide_factor: int = parameters.SPEC_DIVIDE_FACTOR,
fixed_width: int | None = None, fixed_width: Optional[int] = None,
): ):
"""Pad audio to be evenly divisible by `divide_factor`. """Pad audio to be evenly divisible by `divide_factor`.

View File

@ -1,8 +1,9 @@
import json import json
import os import os
from typing import Any, Iterator from typing import Any, Iterator, List, Optional, Tuple, Union, BinaryIO
from ..types import AudioPath
import librosa
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torch import torch
@ -21,7 +22,7 @@ from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import ( from batdetect2.types import (
Annotation, Annotation,
DetectionModel, DetectionModel,
FileAnnotation, FileAnnotations,
ModelOutput, ModelOutput,
ModelParameters, ModelParameters,
PredictionResults, PredictionResults,
@ -31,6 +32,13 @@ from batdetect2.types import (
SpectrogramParameters, SpectrogramParameters,
) )
import audioread
import os
import io
import soundfile as sf
import hashlib
import uuid
__all__ = [ __all__ = [
"load_model", "load_model",
"list_audio_files", "list_audio_files",
@ -60,7 +68,7 @@ def get_default_bd_args():
return args return args
def list_audio_files(ip_dir: str) -> list[str]: def list_audio_files(ip_dir: str) -> List[str]:
"""Get all audio files in directory. """Get all audio files in directory.
Args: Args:
@ -84,9 +92,9 @@ def list_audio_files(ip_dir: str) -> list[str]:
def load_model( def load_model(
model_path: str = DEFAULT_MODEL_PATH, model_path: str = DEFAULT_MODEL_PATH,
load_weights: bool = True, load_weights: bool = True,
device: torch.device | str | None = None, device: Optional[torch.device] = None,
weights_only: bool = True, weights_only: bool = True,
) -> tuple[DetectionModel, ModelParameters]: ) -> Tuple[DetectionModel, ModelParameters]:
"""Load model from file. """Load model from file.
Args: Args:
@ -185,28 +193,26 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
def get_annotations_from_preds( def get_annotations_from_preds(
predictions: PredictionResults, predictions: PredictionResults,
class_names: list[str], class_names: List[str],
) -> list[Annotation]: ) -> List[Annotation]:
"""Get list of annotations from predictions.""" """Get list of annotations from predictions."""
# Get the best class prediction probability and index for each detection # Get the best class prediction probability and index for each detection
class_prob_best = predictions["class_probs"].max(0) class_prob_best = predictions["class_probs"].max(0)
class_ind_best = predictions["class_probs"].argmax(0) class_ind_best = predictions["class_probs"].argmax(0)
# Pack the results into a list of dictionaries # Pack the results into a list of dictionaries
annotations: list[Annotation] = [ annotations: List[Annotation] = [
Annotation( {
{ "start_time": round(float(start_time), 4),
"start_time": round(float(start_time), 4), "end_time": round(float(end_time), 4),
"end_time": round(float(end_time), 4), "low_freq": int(low_freq),
"low_freq": int(low_freq), "high_freq": int(high_freq),
"high_freq": int(high_freq), "class": str(class_names[class_index]),
"class": str(class_names[class_index]), "class_prob": round(float(class_prob), 3),
"class_prob": round(float(class_prob), 3), "det_prob": round(float(det_prob), 3),
"det_prob": round(float(det_prob), 3), "individual": "-1",
"individual": "-1", "event": "Echolocation",
"event": "Echolocation", }
}
)
for ( for (
start_time, start_time,
end_time, end_time,
@ -223,7 +229,6 @@ def get_annotations_from_preds(
class_ind_best, class_ind_best,
class_prob_best, class_prob_best,
predictions["det_probs"], predictions["det_probs"],
strict=False,
) )
] ]
return annotations return annotations
@ -234,8 +239,8 @@ def format_single_result(
time_exp: float, time_exp: float,
duration: float, duration: float,
predictions: PredictionResults, predictions: PredictionResults,
class_names: list[str], class_names: List[str],
) -> FileAnnotation: ) -> FileAnnotations:
"""Format results into the format expected by the annotation tool. """Format results into the format expected by the annotation tool.
Args: Args:
@ -282,7 +287,7 @@ def convert_results(
spec_feats, spec_feats,
cnn_feats, cnn_feats,
spec_slices, spec_slices,
nyquist_freq: float | None = None, nyquist_freq: Optional[float] = None,
) -> RunResults: ) -> RunResults:
"""Convert results to dictionary as expected by the annotation tool. """Convert results to dictionary as expected by the annotation tool.
@ -317,11 +322,9 @@ def convert_results(
] ]
# combine into final results dictionary # combine into final results dictionary
results: RunResults = RunResults( # type: ignore[missing-argument] results: RunResults = {
{ "pred_dict": pred_dict,
"pred_dict": pred_dict, }
}
)
# add spectrogram features if they exist # add spectrogram features if they exist
if len(spec_feats) > 0 and params["spec_features"]: if len(spec_feats) > 0 and params["spec_features"]:
@ -417,7 +420,8 @@ def compute_spectrogram(
sampling_rate: int, sampling_rate: int,
params: SpectrogramParameters, params: SpectrogramParameters,
device: torch.device, device: torch.device,
) -> tuple[float, torch.Tensor]: return_np: bool = False,
) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
"""Compute a spectrogram from an audio array. """Compute a spectrogram from an audio array.
Will pad the audio array so that it is evenly divisible by the Will pad the audio array so that it is evenly divisible by the
@ -426,16 +430,24 @@ def compute_spectrogram(
Parameters Parameters
---------- ----------
audio : np.ndarray audio : np.ndarray
sampling_rate : int sampling_rate : int
params : SpectrogramParameters params : SpectrogramParameters
The parameters to use for generating the spectrogram. The parameters to use for generating the spectrogram.
return_np : bool, optional
Whether to return the spectrogram as a numpy array as well as a
torch tensor. The default is False.
Returns Returns
------- -------
duration : float duration : float
The duration of the spectrgram in seconds. The duration of the spectrgram in seconds.
spec : torch.Tensor spec : torch.Tensor
The spectrogram as a torch tensor. The spectrogram as a torch tensor.
spec_np : np.ndarray, optional spec_np : np.ndarray, optional
The spectrogram as a numpy array. Only returned if `return_np` is The spectrogram as a numpy array. Only returned if `return_np` is
True, otherwise None. True, otherwise None.
@ -472,14 +484,20 @@ def compute_spectrogram(
mode="bilinear", mode="bilinear",
align_corners=False, align_corners=False,
) )
return duration, spec
if return_np:
spec_np = spec[0, 0, :].cpu().data.numpy()
else:
spec_np = None
return duration, spec, spec_np
def iterate_over_chunks( def iterate_over_chunks(
audio: np.ndarray, audio: np.ndarray,
samplerate: float, samplerate: int,
chunk_size: float, chunk_size: float,
) -> Iterator[tuple[float, np.ndarray]]: ) -> Iterator[Tuple[float, np.ndarray]]:
"""Iterate over audio in chunks of size chunk_size. """Iterate over audio in chunks of size chunk_size.
Parameters Parameters
@ -510,10 +528,10 @@ def iterate_over_chunks(
def _process_spectrogram( def _process_spectrogram(
spec: torch.Tensor, spec: torch.Tensor,
samplerate: float, samplerate: int,
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
) -> tuple[PredictionResults, np.ndarray]: ) -> Tuple[PredictionResults, np.ndarray]:
# evaluate model # evaluate model
with torch.no_grad(): with torch.no_grad():
outputs = model(spec) outputs = model(spec)
@ -550,7 +568,7 @@ def postprocess_model_outputs(
outputs: ModelOutput, outputs: ModelOutput,
samp_rate: int, samp_rate: int,
config: ProcessingConfiguration, config: ProcessingConfiguration,
) -> tuple[list[Annotation], np.ndarray]: ) -> Tuple[List[Annotation], np.ndarray]:
# run non-max suppression # run non-max suppression
pred_nms_list, features = pp.run_nms( pred_nms_list, features = pp.run_nms(
outputs, outputs,
@ -589,7 +607,7 @@ def process_spectrogram(
samplerate: int, samplerate: int,
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
) -> tuple[list[Annotation], np.ndarray]: ) -> Tuple[List[Annotation], np.ndarray]:
"""Process a spectrogram with detection model. """Process a spectrogram with detection model.
Will run non-maximum suppression on the output of the model. Will run non-maximum suppression on the output of the model.
@ -608,9 +626,9 @@ def process_spectrogram(
Returns Returns
------- -------
detections detections: List[Annotation]
List of detections predicted by the model. List of detections predicted by the model.
features features : np.ndarray
An array of CNN features associated with each annotation. An array of CNN features associated with each annotation.
The array is of shape (num_detections, num_features). The array is of shape (num_detections, num_features).
Is empty if `config["cnn_features"]` is False. Is empty if `config["cnn_features"]` is False.
@ -636,9 +654,9 @@ def _process_audio_array(
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
device: torch.device, device: torch.device,
) -> tuple[PredictionResults, np.ndarray, torch.Tensor]: ) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
# load audio file and compute spectrogram # load audio file and compute spectrogram
_, spec = compute_spectrogram( _, spec, _ = compute_spectrogram(
audio, audio,
sampling_rate, sampling_rate,
{ {
@ -654,6 +672,7 @@ def _process_audio_array(
"max_scale_spec": config["max_scale_spec"], "max_scale_spec": config["max_scale_spec"],
}, },
device, device,
return_np=False,
) )
# process spectrogram with model # process spectrogram with model
@ -673,7 +692,7 @@ def process_audio_array(
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
device: torch.device, device: torch.device,
) -> tuple[list[Annotation], np.ndarray, torch.Tensor]: ) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
"""Process a single audio array with detection model. """Process a single audio array with detection model.
Parameters Parameters
@ -693,7 +712,7 @@ def process_audio_array(
Returns Returns
------- -------
annotations : list[Annotation] annotations : List[Annotation]
List of annotations predicted by the model. List of annotations predicted by the model.
features : np.ndarray features : np.ndarray
Array of CNN features associated with each annotation. Array of CNN features associated with each annotation.
@ -718,11 +737,12 @@ def process_audio_array(
def process_file( def process_file(
audio_file: str, path: AudioPath,
model: DetectionModel, model: DetectionModel,
config: ProcessingConfiguration, config: ProcessingConfiguration,
device: torch.device, device: torch.device,
) -> RunResults | Any: file_id: Optional[str] = None
) -> Union[RunResults, Any]:
"""Process a single audio file with detection model. """Process a single audio file with detection model.
Will split the audio file into chunks if it is too long and Will split the audio file into chunks if it is too long and
@ -730,7 +750,7 @@ def process_file(
Parameters Parameters
---------- ----------
audio_file : str path : AudioPath
Path to audio file. Path to audio file.
model : torch.nn.Module model : torch.nn.Module
@ -738,6 +758,9 @@ def process_file(
config : ProcessingConfiguration config : ProcessingConfiguration
Configuration for processing. Configuration for processing.
file_id: Optional[str],
Give the data an id. Defaults to the filename if path is a string. Otherwise an md5 will be calculated from the binary data.
Returns Returns
------- -------
@ -751,19 +774,17 @@ def process_file(
cnn_feats = [] cnn_feats = []
spec_slices = [] spec_slices = []
# Get original sampling rate
file_samp_rate = librosa.get_samplerate(audio_file)
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
# load audio file # load audio file
sampling_rate, audio_full = au.load_audio( sampling_rate, audio_full, file_samp_rate = au.load_audio_and_samplerate(
audio_file, path,
time_exp_fact=config.get("time_expansion", 1) or 1, time_exp_fact=config.get("time_expansion", 1) or 1,
target_samp_rate=config["target_samp_rate"], target_samp_rate=config["target_samp_rate"],
scale=config["scale_raw_audio"], scale=config["scale_raw_audio"],
max_duration=config.get("max_duration"), max_duration=config.get("max_duration"),
) )
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
# loop through larger file and split into chunks # loop through larger file and split into chunks
# TODO: fix so that it overlaps correctly and takes care of # TODO: fix so that it overlaps correctly and takes care of
# duplicate detections at borders # duplicate detections at borders
@ -801,6 +822,7 @@ def process_file(
cnn_feats.append(features[0]) cnn_feats.append(features[0])
if config["spec_slices"]: if config["spec_slices"]:
# FIX: This is not currently working. Returns empty slices
spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms)) spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))
# Merge results from chunks # Merge results from chunks
@ -811,9 +833,13 @@ def process_file(
spec_slices, spec_slices,
) )
_file_id = file_id
if _file_id is None:
_file_id = _generate_id(path)
# convert results to a dictionary in the right format # convert results to a dictionary in the right format
results = convert_results( results = convert_results(
file_id=os.path.basename(audio_file), file_id=_file_id,
time_exp=config.get("time_expansion", 1) or 1, time_exp=config.get("time_expansion", 1) or 1,
duration=audio_full.shape[0] / float(sampling_rate), duration=audio_full.shape[0] / float(sampling_rate),
params=config, params=config,
@ -833,6 +859,22 @@ def process_file(
return results return results
def _generate_id(path: AudioPath) -> str:
""" Generate an id based on the path.
If the path is a str or PathLike it will parsed as the basename.
This should ensure backwards compatibility with previous versions.
"""
if isinstance(path, str) or isinstance(path, os.PathLike):
return os.path.basename(path)
elif isinstance(path, (BinaryIO, io.BytesIO)):
path.seek(0)
md5 = hashlib.md5(path.read()).hexdigest()
path.seek(0)
return md5
else:
return str(uuid.uuid4())
def summarize_results(results, predictions, config): def summarize_results(results, predictions, config):
"""Print summary of results.""" """Print summary of results."""

View File

@ -87,7 +87,9 @@ def save_ann_spec(
y_extent = [0, duration, min_freq, max_freq] y_extent = [0, duration, min_freq, max_freq]
plt.close("all") plt.close("all")
plt.figure(0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100) fig = plt.figure(
0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100
)
plt.imshow( plt.imshow(
spec, spec,
aspect="auto", aspect="auto",
@ -367,7 +369,7 @@ def plot_pr_curve_class(
# print(class_name) # print(class_name)
# plot the location of the confidence threshold values # plot the location of the confidence threshold values
for jj, _tt in enumerate(rr["thresholds"]): for jj, tt in enumerate(rr["thresholds"]):
ind = rr["thresholds_inds"][jj] ind = rr["thresholds_inds"][jj]
if ind > -1: if ind > -1:
plt.plot( plt.plot(
@ -415,9 +417,7 @@ def plot_confusion_matrix(
cm_norm = cm.sum(1) cm_norm = cm.sum(1)
valid_inds = np.where(cm_norm > 0)[0] valid_inds = np.where(cm_norm > 0)[0]
cm[valid_inds, :] = ( cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
)
cm[np.where(cm_norm == -0)[0], :] = np.nan cm[np.where(cm_norm == -0)[0], :] = np.nan
if verbose: if verbose:

View File

@ -133,7 +133,7 @@ class InteractivePlotter:
self.fig.canvas.mpl_connect("key_press_event", self.key_press) self.fig.canvas.mpl_connect("key_press_event", self.key_press)
def mouse_hover(self, event): def mouse_hover(self, event):
self.annot.get_visible() vis = self.annot.get_visible()
if event.inaxes == self.ax[0]: if event.inaxes == self.ax[0]:
cont, ind = self.low_dim_plt.contains(event) cont, ind = self.low_dim_plt.contains(event)
if cont: if cont:
@ -155,9 +155,9 @@ class InteractivePlotter:
# draw bounding box around call # draw bounding box around call
self.ax[1].patches[0].remove() self.ax[1].patches[0].remove()
spec_width_orig = self.spec_slices[self.current_id].shape[ spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
1 1.0 + 2.0 * self.spec_pad
] / (1.0 + 2.0 * self.spec_pad) )
xx = w_diff + self.spec_pad * spec_width_orig xx = w_diff + self.spec_pad * spec_width_orig
ww = spec_width_orig ww = spec_width_orig
yy = self.call_info[self.current_id]["low_freq"] / 1000 yy = self.call_info[self.current_id]["low_freq"] / 1000
@ -183,9 +183,7 @@ class InteractivePlotter:
round(self.call_info[self.current_id]["start_time"], 3) round(self.call_info[self.current_id]["start_time"], 3)
) )
+ ", prob=" + ", prob="
+ str( + str(round(self.call_info[self.current_id]["det_prob"], 3))
round(self.call_info[self.current_id]["det_prob"], 3)
)
) )
self.ax[0].set_xlabel(info_str) self.ax[0].set_xlabel(info_str)

View File

@ -8,7 +8,6 @@ Functions
`write`: Write a numpy array as a WAV file. `write`: Write a numpy array as a WAV file.
""" """
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import os import os
@ -43,7 +42,7 @@ def _read_fmt_chunk(fid):
size, comp, noc, rate, sbytes, ba, bits = res size, comp, noc, rate, sbytes, ba, bits = res
if comp not in KNOWN_WAVE_FORMATS or size > 16: if comp not in KNOWN_WAVE_FORMATS or size > 16:
comp = WAVE_FORMAT_PCM comp = WAVE_FORMAT_PCM
warnings.warn("Unknown wave file format", WavFileWarning, stacklevel=2) warnings.warn("Unknown wave file format", WavFileWarning)
if size > 16: if size > 16:
fid.read(size - 16) fid.read(size - 16)
@ -157,6 +156,7 @@ def read(filename, mmap=False):
fid = open(filename, "rb") fid = open(filename, "rb")
try: try:
# some files seem to have the size recorded in the header greater than # some files seem to have the size recorded in the header greater than
# the actual file size. # the actual file size.
fid.seek(0, os.SEEK_END) fid.seek(0, os.SEEK_END)

View File

@ -1,20 +0,0 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

View File

@ -1,35 +0,0 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

View File

@ -1,78 +0,0 @@
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = "batdetect2"
copyright = "2025, Oisin Mac Aodha, Santiago Martinez Balvanera"
author = "Oisin Mac Aodha, Santiago Martinez Balvanera"
release = "2.0.0b1"
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinxcontrib.autodoc_pydantic",
"sphinx_click",
"numpydoc",
"myst_parser",
"sphinx_autodoc_typehints",
]
templates_path = ["_templates"]
exclude_patterns = []
source_suffix = {
".rst": "restructuredtext",
".txt": "markdown",
".md": "markdown",
}
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "sphinx_book_theme"
html_static_path = ["_static"]
html_theme_options = {
"home_page_in_toc": True,
"show_navbar_depth": 2,
"show_toc_level": 2,
}
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"click": ("https://click.palletsprojects.com/en/stable/", None),
"librosa": ("https://librosa.org/doc/latest/", None),
"lightning": ("https://lightning.ai/docs/pytorch/stable/", None),
"loguru": ("https://loguru.readthedocs.io/en/stable/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"omegaconf": ("https://omegaconf.readthedocs.io/en/latest/", None),
"pytorch": ("https://pytorch.org/docs/stable/", None),
"soundevent": ("https://mbsantiago.github.io/soundevent/", None),
"pydantic": ("https://docs.pydantic.dev/latest/", None),
"xarray": ("https://docs.xarray.dev/en/stable/", None),
}
# -- Options for autodoc ------------------------------------------------------
autosummary_generate = False
autosummary_imported_members = True
autodoc_default_options = {
"members": True,
"undoc-members": False,
"private-members": False,
"special-members": False,
"inherited-members": False,
"show-inheritance": True,
"module-first": True,
}
numpydoc_show_class_members = False
numpydoc_show_inherited_class_members = False
numpydoc_class_members_toctree = False

View File

@ -1,34 +0,0 @@
# Development and contribution
Thanks for your interest in improving batdetect2.
## Ways to contribute
- Report bugs and request features on
[GitHub Issues](https://github.com/macaodha/batdetect2/issues)
- Improve docs by opening pull requests with clearer examples, fixes, or
missing workflows
- Contribute code for models, data handling, evaluation, or CLI workflows
## Basic contribution workflow
1. Open an issue (or comment on an existing one) so work is visible.
2. Create a branch for your change.
3. Run checks locally before opening a PR:
```bash
just check
just docs
```
4. Open a pull request with a clear summary of what changed and why.
## Development environment
Use `uv` for dependency and environment management.
```bash
uv sync
```
For more setup details, see {doc}`../getting_started`.

View File

@ -1,48 +0,0 @@
# Evaluation concepts and matching
Evaluation is not just "run predictions and compute one number".
The reported metric depends on the evaluation task, the matching rule, and the treatment of clip boundaries and generic labels.
## Task families answer different questions
Built-in task families include:
- sound event detection,
- sound event classification,
- top-class detection,
- clip detection,
- clip classification.
Choose the task that matches the scientific or engineering question.
## Matching matters
For sound-event-style tasks, predictions and annotations are matched using an affinity function.
Important controls include:
- `affinity`,
- `affinity_threshold`,
- `strict_match`,
- `ignore_start_end`.
Small changes here can change the reported metric without changing the underlying predictions.
## Boundary handling matters
The evaluation base task can exclude events near clip boundaries through `ignore_start_end`.
This is useful when clip boundaries make matches ambiguous.
## Generic labels can matter in classification
Classification tasks can include or exclude generic targets depending on configuration.
That affects what counts as a valid class-level comparison.
## Related pages
- Evaluate on a test set: {doc}`../tutorials/evaluate-on-a-test-set`
- Evaluation config reference: {doc}`../reference/evaluation-config`
- Model output and validation: {doc}`model-output-and-validation`

View File

@ -1,43 +0,0 @@
# Extracted features and embeddings
The current API exposes a per-detection `features` vector.
Older BatDetect2 workflows also exposed concepts such as `cnn_feats`,
`spec_features`, and `spec_slices`.
## What the current feature vector is
In the current stack, each retained detection can carry an internal feature
representation produced by the model output pipeline.
This is useful for downstream exploration, comparison, and custom analysis.
## What these features are not
They are not automatically human-interpretable ecological variables.
They are also not a substitute for careful validation.
## Why people refer to them as embeddings
In practice, users often treat these feature vectors as embeddings because they
can be used as dense learned representations of detections.
That usage is reasonable, but you should still treat them as model-derived
internal representations whose meaning depends on the training setup.
## Legacy terminology versus current terminology
- legacy `cnn_feats` referred to CNN feature outputs in the older workflow,
- legacy `spec_features` referred to lower-level extracted call features,
- current `features` are the per-detection vectors attached to `Detection`
objects.
These are related ideas, but not necessarily one-to-one replacements.
## Related pages
- Inspect detection features in Python:
{doc}`../how_to/inspect-detection-features-in-python`
- Legacy migration guide:
{doc}`../legacy/migration-guide`

View File

@ -1,19 +0,0 @@
# Understanding
Understanding pages explain how BatDetect2 works, what its outputs mean, and how to reason about trade-offs.
Use this section when you want help interpreting the tool, not just running it.
```{toctree}
:maxdepth: 1
what-batdetect2-predicts
interpreting-formatted-outputs
extracted-features-and-embeddings
model-output-and-validation
postprocessing-and-thresholds
pipeline-overview
preprocessing-consistency
target-encoding-and-decoding
evaluation-concepts-and-matching
```

View File

@ -1,36 +0,0 @@
# Interpreting formatted outputs
BatDetect2 can write predictions in several output formats.
Those formats are different views of the same underlying detections, not different model behaviors.
## Separate the underlying detection from the serialized file
Internally, the current stack works with clip-level detections containing geometry, detection score, class scores, and features.
Output formatters then serialize those detections in different ways.
## Raw outputs are richest
The `raw` format preserves the broadest structured view of detections and is a good default when you want to inspect or reload predictions later.
## Tabular outputs are for analysis convenience
The `parquet` format is convenient for data analysis workflows, but the tabular representation is only one projection of the underlying detection object.
## Legacy-shaped outputs are mainly for compatibility
The `batdetect2` formatter writes the older BatDetect2-style JSON shape.
Use it when you need compatibility with older downstream tools or workflows.
## The meaning does not come from the file extension
Do not assume that a `.json`, `.parquet`, or `.nc` file changes what the model predicted.
It changes how the prediction is packaged and how much detail is retained.
## Related pages
- Output formats reference: {doc}`../reference/output-formats`
- Outputs config reference: {doc}`../reference/outputs-config`

View File

@ -1,29 +0,0 @@
# Model output and validation
BatDetect2 outputs model predictions, not ground truth. The same configuration
can behave differently across recording conditions, species compositions, and
acoustic environments.
## Why threshold choice matters
- Lower detection thresholds increase sensitivity but can increase false
positives.
- Higher thresholds reduce false positives but can miss faint calls.
No threshold is universally correct. The right setting depends on your survey
objectives and tolerance for false positives versus missed detections.
## Why local validation is required
Model performance depends on how similar your data are to training data.
Before ecological interpretation, validate predictions on a representative,
locally reviewed subset.
Recommended validation checks:
1. Compare detection counts against expert-reviewed clips.
2. Inspect species-level predictions for plausible confusion patterns.
3. Repeat checks across sites, seasons, and recorder setups.
For practical threshold workflows, see
{doc}`../how_to/tune-detection-threshold`.

View File

@ -1,34 +0,0 @@
# Pipeline overview
batdetect2 processes recordings as a sequence of modules. Each stage has a
clear role and configuration surface.
## End-to-end flow
1. Audio loading
2. Preprocessing (waveform -> spectrogram)
3. Detector forward pass
4. Postprocessing (peaks, decoding, thresholds)
5. Output formatting and export
## Why the modular design matters
The model, preprocessing, postprocessing, targets, and output formatting are
configured separately. That makes it easier to:
- swap components without rewriting the whole pipeline,
- keep experiments reproducible,
- adapt workflows to new datasets.
## Core objects in the stack
- `BatDetect2API` orchestrates training, inference, and evaluation workflows.
- `ModelConfig` defines architecture, preprocessing, postprocessing, and
targets.
- `Targets` controls event filtering, class encoding/decoding, and ROI mapping.
## Related pages
- Preprocessing rationale: {doc}`preprocessing-consistency`
- Postprocessing rationale: {doc}`postprocessing-and-thresholds`
- Target rationale: {doc}`target-encoding-and-decoding`

View File

@ -1,43 +0,0 @@
# Postprocessing and thresholds
After the detector runs on a spectrogram, the model output is still a set of
dense prediction tensors. Postprocessing turns that into a final list of call
detections with positions, sizes, and class scores.
## What postprocessing does
In broad terms, the pipeline:
1. suppresses nearby duplicate peaks,
2. extracts candidate detections,
3. reads size and class values at each detected location,
4. decodes outputs into call-level predictions.
This is where score thresholds and output density limits are applied.
## Why thresholds matter
Thresholds control the balance between sensitivity and precision.
- Lower thresholds keep more detections, including weaker calls, but may add
false positives.
- Higher thresholds remove low-confidence detections, but may miss faint calls.
You can tune this behavior per run without retraining the model.
## Two common threshold controls
- `detection_threshold`: minimum score required to keep a detection.
- `classification_threshold`: minimum class score used when assigning class
labels.
Both settings shape the final output and should be validated on reviewed local
data.
## Practical workflow
Tune thresholds on a representative subset first, then lock settings for the
full analysis run.
- How-to: {doc}`../how_to/tune-detection-threshold`
- CLI reference: {doc}`../reference/cli/predict`

View File

@ -1,36 +0,0 @@
# Preprocessing consistency
Preprocessing consistency is one of the biggest factors behind stable model
performance.
## Why consistency matters
The detector is trained on spectrograms produced by a specific preprocessing
pipeline. If inference uses different settings, the model can see a shifted
input distribution and performance may drop.
Typical mismatch sources:
- sample-rate differences,
- changed frequency crop,
- changed STFT window/hop,
- changed spectrogram transforms.
## Practical implication
When possible, keep preprocessing settings aligned between:
- training,
- evaluation,
- deployment inference.
If you intentionally change preprocessing, treat this as a new experiment and
re-validate on reviewed local data.
## Related pages
- Configure audio preprocessing:
{doc}`../how_to/configure-audio-preprocessing`
- Configure spectrogram preprocessing:
{doc}`../how_to/configure-spectrogram-preprocessing`
- Preprocessing config reference: {doc}`../reference/preprocessing-config`

View File

@ -1,40 +0,0 @@
# Target encoding and decoding
batdetect2 turns annotated sound events into training targets, then maps model
outputs back into interpretable predictions.
## Encoding path (annotations -> model targets)
At training time, the target system:
1. checks whether an event belongs to the configured detection target,
2. assigns a classification label (or none for non-specific class matches),
3. maps event geometry into position and size targets.
This behaviour is configured through `TargetConfig`,
`TargetClassConfig`, and ROI mapper settings.
## Decoding path (model outputs -> tags and geometry)
At inference time, class labels and ROI parameters are decoded back into
annotation tags and geometry.
This makes outputs interpretable in the same conceptual space as your original
annotations.
## Why this matters
Target definitions are not just metadata. They directly shape:
- what events are treated as positive examples,
- which class names the model learns,
- how geometry is represented and reconstructed.
Small changes here can alter both training outcomes and prediction semantics.
## Related pages
- Configure detection target logic: {doc}`../how_to/configure-target-definitions`
- Configure class mapping: {doc}`../how_to/define-target-classes`
- Configure ROI mapping: {doc}`../how_to/configure-roi-mapping`
- Target config reference: {doc}`../reference/targets-config-workflow`

View File

@ -1,45 +0,0 @@
# What BatDetect2 predicts
BatDetect2 predicts call-level events, not recording-level truth.
For each retained detection, the current stack can expose:
- a geometry describing where the event sits in time-frequency space,
- a detection score,
- a class-score vector,
- an internal feature vector.
## Detection score versus class scores
These are different outputs and should not be interpreted as the same thing.
- The detection score is about whether the event is kept as a detection.
- The class-score vector ranks classes for that detected event.
A detection can be kept while still having uncertain class identity.
## Predictions are conditional on the workflow
The final output also depends on:
- preprocessing,
- postprocessing,
- thresholds,
- target definitions,
- output transforms.
That is why two runs can differ even when they use the same checkpoint.
## What BatDetect2 does not predict
BatDetect2 does not directly output ecological truth.
It also does not eliminate the need for local validation.
Use reviewed local data before making ecological claims.
## Related pages
- Model output and validation: {doc}`model-output-and-validation`
- Postprocessing and thresholds: {doc}`postprocessing-and-thresholds`
- Interpreting formatted outputs: {doc}`interpreting-formatted-outputs`

View File

@ -1,81 +0,0 @@
# FAQ
## Installation and setup
### Do I need Python knowledge to use batdetect2?
Not much.
If you only want to run the model on your own recordings, you can use the CLI and follow the steps in {doc}`getting_started`.
Some command-line familiarity helps, but you do not need to write Python code for standard inference workflows.
### Are there plans for an R version?
Not currently.
Output files are plain formats (for example CSV/JSON), so you can read and analyze them in R or other environments.
### I cannot get installation working. What should I do?
First, re-check {doc}`getting_started` and confirm your environment is active.
If it still fails, open an issue with your OS, install method, and full error output: [GitHub Issues](https://github.com/macaodha/batdetect2/issues).
## Model behavior and performance
### The model does not perform well on my data
This usually means your data distribution differs from training data.
The best next step is to validate on reviewed local data and then fine-tune/train on your own annotations if needed.
### The model confuses insects/noise with bats
This can happen, especially when recording conditions differ from training conditions.
Threshold tuning and training with local annotations can improve results.
See {doc}`how_to/tune-detection-threshold`.
### The model struggles with feeding buzzes or social calls
This is a known limitation of available training data in some settings.
If you have high-quality annotated examples, they are valuable for improving models.
### Calls in the same sequence are predicted as different species
Currently we do not do any sophisticated post processing on the results output by the model.
We return a probability associated with each species for each call.
You can use these predictions to clean up the noisy predictions for sequences of calls.
### Can I trust model outputs for biodiversity conclusions?
The models developed and shared as part of this repository should be used with caution.
While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment.
Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted.
### The pipeline is slow
Runtime depends on hardware and recording duration.
GPU inference is often much faster than CPU.
## Training and scope
### Can I train on my own species set?
Yes.
You can train/fine-tune with your own annotated data and species labels.
### Does this work on frequency-division or zero-crossing recordings?
Not directly.
The workflow assumes audio can be converted to spectrograms from the raw waveform.
### Can this be used for non-bat bioacoustics (for example insects or birds)?
Potentially yes, but expect retraining and configuration changes.
Open an issue if you want guidance for a specific use case.
## Usage and licensing
### Can I use this for commercial purposes?
No.
This project is currently for non-commercial use.
See the repository license for details.

View File

@ -1,91 +0,0 @@
# Getting started
BatDetect2 can be used in two ways: through the `batdetect2` command line interface (CLI), or as the `batdetect2` Python package.
The CLI route does not require coding.
You run commands in the terminal and, in some cases, write configuration files.
The Python route gives you more flexibility and lets you integrate the model into your own workflows or experiments.
For most common use cases, both routes give you the same results.
## Try it out
If you want to try BatDetect2 before installing anything locally:
- [Hugging Face demo (UK species)](https://huggingface.co/spaces/macaodha/batdetect2)
- [Google Colab notebook](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb)
## Installation
To use `batdetect2` on your machine, you need to install it first.
We recommend using `uv` for that.
`uv` is a tool that helps manage Python software cleanly, without mixing it into the rest of your machine.
Install `uv` first by following the [installation instructions](https://docs.astral.sh/uv/getting-started/installation/).
### One-off usage
If you are not ready to install `batdetect2` permanently, you can try it with:
```bash
uvx batdetect2
```
This still downloads the code and dependencies and runs them on your machine, but the environment is temporary.
### Install the CLI
If you want the `batdetect2` CLI to always be available in your terminal, run:
```bash
uv tool install batdetect2
```
If you need to upgrade later:
```bash
uv tool upgrade batdetect2
```
Verify the CLI is available:
```bash
batdetect2
```
You can then run your first workflow.
See {doc}`tutorials/run-inference-on-folder` for more details.
### Add it to your Python project
If you are using BatDetect2 from Python code and already manage your projects with `uv`, you can add it with:
```bash
uv add batdetect2
```
If you want to upgrade it later:
```bash
uv add -U batdetect2
```
#### Alternative with `pip`
If you prefer `pip`, you can use:
```bash
pip install batdetect2
```
It is a good idea to create a separate virtual environment first so this does not interfere with other Python environments.
```bash
python -m venv .venv
source .venv/bin/activate
```
## What's next
- Run your first workflow on a folder of recordings: {doc}`tutorials/run-inference-on-folder`
- If you write code and want the Python route: {doc}`tutorials/integrate-with-a-python-pipeline`
- For common practical tasks, go to {doc}`how_to/index`
- For detailed command help, go to {doc}`reference/cli/index`
- To understand the model and its outputs, go to {doc}`explanation/index`

View File

@ -1,112 +0,0 @@
# How to choose a model
Use this guide when you want to choose which model checkpoint BatDetect2 loads.
You can choose a model in both the CLI and the Python API.
## Where you can choose the model
In the CLI, use `--model` with commands that load a checkpoint, including:
- `batdetect2 process`
- `batdetect2 evaluate`
- `batdetect2 train`
- `batdetect2 finetune`
In Python, pass the model source to `BatDetect2API.from_checkpoint(...)`.
If you do not choose a model, BatDetect2 uses the built-in default UK model.
## Use a local checkpoint path
Use a local path when you already have a checkpoint file on disk.
CLI example:
```bash
batdetect2 process directory \
path/to/audio \
path/to/outputs \
--model path/to/model.ckpt
```
Python example:
```python
from batdetect2.api_v2 import BatDetect2API
api = BatDetect2API.from_checkpoint("path/to/model.ckpt")
```
## Use a bundled checkpoint alias
BatDetect2 also supports bundled checkpoint aliases.
The built-in UK model is available as `uk_same`.
The alias `batdetect2_uk_same` also works.
CLI example:
```bash
batdetect2 process directory \
path/to/audio \
path/to/outputs \
--model uk_same
```
Python example:
```python
from batdetect2.api_v2 import BatDetect2API
api = BatDetect2API.from_checkpoint("uk_same")
```
## Use a Hugging Face URI
You can also load a checkpoint from Hugging Face with a URI like:
```text
hf://owner/repo/path/to/model.ckpt
```
This needs the optional Hugging Face dependency to be installed.
For example, install it with `pip install batdetect2[huggingface]`.
CLI example:
```bash
batdetect2 process directory \
path/to/audio \
path/to/outputs \
--model hf://owner/repo/path/to/model.ckpt
```
Python example:
```python
from batdetect2.api_v2 import BatDetect2API
api = BatDetect2API.from_checkpoint(
"hf://owner/repo/path/to/model.ckpt"
)
```
## Choose the right source
- Use a local path when you already have a checkpoint file.
- Use an alias when you want one of the bundled models.
- Use a Hugging Face URI when the checkpoint lives in a Hugging Face repo.
## Related pages
- Run inference on a folder:
{doc}`../tutorials/run-inference-on-folder`
- `BatDetect2API` reference:
{doc}`../reference/api`
- Process command reference:
{doc}`../reference/cli/predict`
- Train a custom model:
{doc}`../tutorials/train-a-custom-model`
- Fine-tune from a checkpoint:
{doc}`fine-tune-from-a-checkpoint`

View File

@ -1,71 +0,0 @@
# How to choose an inference input mode
Use this guide to decide whether `process directory`, `process file_list`, or
`process dataset` is the right entry point for your run.
## Use `process directory` when the recordings already live together
This is the simplest choice.
Use it when:
- your recordings are already organized in one directory tree,
- you want BatDetect2 to discover audio files for you,
- you are doing a first pass over a folder of recordings.
```bash
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs
```
## Use `process file_list` when you need explicit control over the file set
Use it when:
- you want to run only a selected subset,
- your files are spread across directories,
- another tool has already produced the exact list of recordings to process.
The list file should contain one path per line.
```bash
batdetect2 process file_list \
path/to/model.ckpt \
path/to/audio_files.txt \
path/to/outputs
```
## Use `process dataset` when your workflow is already annotation-set driven
Use it when:
- your project already has a `soundevent` annotation set,
- you want prediction runs aligned with that annotation metadata,
- you want BatDetect2 to resolve recording paths from the annotation set.
```bash
batdetect2 process dataset \
path/to/model.ckpt \
path/to/annotation_set.json \
path/to/outputs
```
The dataset command reads a `soundevent` annotation set and extracts unique
recording paths before inference.
## Rule of thumb
- Start with `directory` for the easiest first run.
- Use `file_list` when selection matters.
- Use `dataset` when the rest of your workflow is already dataset-based.
## Related pages
- Run batch predictions:
{doc}`run-batch-predictions`
- Tune inference clipping:
{doc}`tune-inference-clipping`
- Process command reference:
{doc}`../reference/cli/predict`

View File

@ -1,74 +0,0 @@
# How to choose and configure evaluation tasks
Use this guide when the default evaluation tasks do not match the question you
want to answer.
## Know the default first
By default, BatDetect2 evaluation starts with:
- sound event detection,
- sound event classification.
Those are good defaults for many projects, but not for all of them.
## Choose the task that matches the question
Common built-in task families include:
- `sound_event_detection`
- `sound_event_classification`
- `top_class_detection`
- `clip_detection`
- `clip_classification`
Choose based on the question you care about.
- Use sound-event tasks when you care about individual call events.
- Use clip tasks when you care about clip-level presence or clip-level class
evidence.
- Use top-class detection when you want matching based on the highest-scoring
class per detection.
## Configure tasks in `EvaluationConfig`
Example:
```yaml
tasks:
- name: sound_event_detection
prefix: detection
affinity_threshold: 0.0
strict_match: true
- name: clip_classification
prefix: clip_classification
```
Pass the config with:
```bash
batdetect2 evaluate \
path/to/test_dataset.yaml \
--model path/to/model.ckpt \
--base-dir path/to/project_root \
--evaluation-config path/to/evaluation.yaml
```
Include `--base-dir` when the dataset config resolves recordings through
relative paths.
## Change one thing at a time
When comparing models or settings, avoid changing task definitions, thresholds,
matching behavior, and datasets all at once.
Otherwise it becomes hard to explain why the metric changed.
## Related pages
- Evaluation tutorial:
{doc}`../tutorials/evaluate-on-a-test-set`
- Evaluation config reference:
{doc}`../reference/evaluation-config`
- Evaluation concepts:
{doc}`../explanation/evaluation-concepts-and-matching`

View File

@ -1,53 +0,0 @@
# How to configure an AOEF dataset source
Use this guide when your annotations are stored in AOEF/soundevent JSON files,
including exports from Whombat.
## 1) Add an AOEF source entry
In your dataset config, add a source with `format: aoef`.
```yaml
sources:
- name: my_aoef_source
format: aoef
audio_dir: /path/to/audio
annotations_path: /path/to/annotations.soundevent.json
```
## 2) Choose filtering behavior for annotation projects
If `annotations_path` is an `AnnotationProject`, you can filter by task state.
```yaml
sources:
- name: whombat_verified
format: aoef
audio_dir: /path/to/audio
annotations_path: /path/to/project_export.aoef
filter:
only_completed: true
only_verified: true
exclude_issues: true
```
If you omit `filter`, default project filtering is applied.
To disable filtering for project files:
```yaml
filter: null
```
## 3) Check that the source loads
Run a summary on your dataset config:
```bash
batdetect2 data summary path/to/dataset.yaml
```
## 4) Continue to training or evaluation
- For training: {doc}`../tutorials/train-a-custom-model`
- For field-level reference: {doc}`../reference/data-sources`

View File

@ -1,66 +0,0 @@
# How to configure audio preprocessing
Use this guide to set sample-rate and waveform-level preprocessing behaviour.
## 1) Set audio loader settings
The audio loader config controls resampling.
```yaml
samplerate: 256000
resample:
enabled: true
method: poly
```
If your recordings are already at the expected sample rate, you can disable
resampling.
```yaml
samplerate: 256000
resample:
enabled: false
```
## 2) Set waveform transforms in preprocessing config
Waveform transforms are configured in `preprocess.audio_transforms`.
```yaml
preprocess:
audio_transforms:
- name: center_audio
- name: scale_audio
- name: fix_duration
duration: 0.5
```
Available built-ins:
- `center_audio`
- `scale_audio`
- `fix_duration`
## 3) Use the config in your workflow
For CLI inference/evaluation, use `--audio-config`.
```bash
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs \
--audio-config path/to/audio.yaml
```
## 4) Verify quickly on a small subset
Run on a small folder first and confirm that outputs and runtime are as expected
before full-batch runs.
## Related pages
- Spectrogram settings:
{doc}`configure-spectrogram-preprocessing`
- Preprocessing config reference:
{doc}`../reference/preprocessing-config`

View File

@ -1,57 +0,0 @@
# How to configure ROI mapping
Use this guide to control how annotation geometry is encoded into training
targets and decoded back into boxes.
## 1) Set the default ROI mapper
The default mapper is `anchor_bbox`.
```yaml
roi:
default:
name: anchor_bbox
anchor: bottom-left
time_scale: 1000.0
frequency_scale: 0.001163
```
## 2) Choose an anchor strategy
Typical options include `bottom-left` and `center`.
- `bottom-left` is the current default.
- `center` can be easier to reason about in some workflows.
## 3) Set scale factors intentionally
- `time_scale` controls width scaling.
- `frequency_scale` controls height scaling.
Use values that are consistent with your model setup and keep them fixed when
comparing experiments.
## 4) (Optional) override ROI mapping for specific classes
Add class-specific mappers under `roi.overrides`.
```yaml
roi:
default:
name: anchor_bbox
anchor: bottom-left
time_scale: 1000.0
frequency_scale: 0.001163
overrides:
species_x:
name: anchor_bbox
anchor: center
time_scale: 1000.0
frequency_scale: 0.001163
```
## Related pages
- Target definitions: {doc}`configure-target-definitions`
- Class definitions: {doc}`define-target-classes`
- Target encoding overview: {doc}`../explanation/target-encoding-and-decoding`

View File

@ -1,59 +0,0 @@
# How to configure spectrogram preprocessing
Use this guide to set STFT, frequency range, and spectrogram transforms.
## 1) Configure STFT and frequency range
```yaml
preprocess:
stft:
window_duration: 0.002
window_overlap: 0.75
window_fn: hann
frequencies:
min_freq: 10000
max_freq: 120000
```
## 2) Configure spectrogram transforms
`spectrogram_transforms` are applied in order.
```yaml
preprocess:
spectrogram_transforms:
- name: pcen
time_constant: 0.4
gain: 0.98
bias: 2.0
power: 0.5
- name: spectral_mean_subtraction
- name: scale_amplitude
scale: db
```
Common built-ins:
- `pcen`
- `spectral_mean_subtraction`
- `scale_amplitude` (`db` or `power`)
- `peak_normalize`
## 3) Configure output size
```yaml
preprocess:
size:
height: 128
resize_factor: 0.5
```
## 4) Keep train and inference settings aligned
Use the same preprocessing setup for training and prediction whenever possible.
Large mismatches can degrade model performance.
## Related pages
- Why consistency matters: {doc}`../explanation/preprocessing-consistency`
- Preprocessing config reference: {doc}`../reference/preprocessing-config`

View File

@ -1,58 +0,0 @@
# How to configure target definitions
Use this guide to define which annotated sound events are considered valid
detection targets.
## 1) Start from a targets config file
```yaml
detection_target:
name: bat
match_if:
name: has_tag
tag:
key: call_type
value: Echolocation
assign_tags:
- key: call_type
value: Echolocation
- key: order
value: Chiroptera
```
`match_if` decides whether an annotation is included in the detection target.
## 2) Use condition combinators when needed
You can combine conditions with `all_of`, `any_of`, and `not`.
```yaml
detection_target:
name: bat
match_if:
name: all_of
conditions:
- name: has_tag
tag:
key: call_type
value: Echolocation
- name: not
condition:
name: has_any_tag
tags:
- key: call_type
value: Social
- key: class
value: Not Bat
```
## 3) Verify with a small sample first
Before full training, inspect a small annotation subset and confirm that the
selection logic keeps the events you expect.
## Related pages
- Class mapping: {doc}`define-target-classes`
- ROI mapping: {doc}`configure-roi-mapping`
- Targets reference: {doc}`../reference/targets-config-workflow`

View File

@ -1,59 +0,0 @@
# How to define target classes
Use this guide to map annotations to classification labels used during
training.
## 1) Add classification target entries
Each entry defines a class name and matching tags.
```yaml
classification_targets:
- name: pippip
tags:
- key: class
value: Pipistrellus pipistrellus
- name: pippyg
tags:
- key: class
value: Pipistrellus pygmaeus
```
## 2) Use `assign_tags` to control decoded output tags
If you want prediction output tags to differ from matching tags, set
`assign_tags` explicitly.
```yaml
classification_targets:
- name: pipistrelle_group
tags:
- key: class
value: Pipistrellus pipistrellus
assign_tags:
- key: genus
value: Pipistrellus
```
## 3) Use `match_if` for complex class rules
For advanced conditions, use `match_if` instead of `tags`.
```yaml
classification_targets:
- name: long_call
match_if:
name: duration
operator: gt
seconds: 0.02
```
## 4) Confirm class names are unique
`classification_targets.name` values must be unique.
## Related pages
- Detection-target filtering: {doc}`configure-target-definitions`
- ROI mapping: {doc}`configure-roi-mapping`
- Targets config reference: {doc}`../reference/targets-config-workflow`

View File

@ -1,45 +0,0 @@
# How to fine-tune from a checkpoint
Use this guide when you want to continue from an existing checkpoint instead of training a fresh model config.
## Use `--model` for checkpoint-based training
Pass a checkpoint with `--model`.
Do not combine `--model` with `--model-config`.
```bash
batdetect2 train \
path/to/train_dataset.yaml \
--val-dataset path/to/val_dataset.yaml \
--model path/to/model.ckpt \
--training-config path/to/training.yaml
```
## Keep targets and preprocessing aligned
If you override targets or audio-related settings while fine-tuning, validate that they still match the checkpoint and your dataset.
Mismatches here can produce confusing failures or invalid comparisons.
## Decide what question the fine-tune should answer
Common fine-tuning goals are:
- adapting to local recording conditions,
- adapting to a new label set,
- improving performance on a narrower deployment context.
Make that goal explicit before comparing results.
## Evaluate after fine-tuning
Always compare the fine-tuned checkpoint against a held-out dataset.
Use the same evaluation setup when comparing before and after.
## Related pages
- Training tutorial: {doc}`../tutorials/train-a-custom-model`
- Evaluate a test set: {doc}`../tutorials/evaluate-on-a-test-set`
- Train command reference: {doc}`../reference/cli/train`

View File

@ -1,66 +0,0 @@
# How to import legacy batdetect2 annotations
Use this guide if your annotations are in older batdetect2 JSON formats.
Two legacy formats are supported:
- `batdetect2`: one annotation JSON file per recording
- `batdetect2_file`: one merged JSON file for many recordings
## 1) Choose the correct source format
Directory-based annotations (`format: batdetect2`):
```yaml
sources:
- name: legacy_per_file
format: batdetect2
audio_dir: /path/to/audio
annotations_dir: /path/to/annotation_json_dir
```
Merged annotation file (`format: batdetect2_file`):
```yaml
sources:
- name: legacy_merged
format: batdetect2_file
audio_dir: /path/to/audio
annotations_path: /path/to/merged_annotations.json
```
## 2) Set optional legacy filters
Legacy filters are based on `annotated` and `issues` flags.
```yaml
filter:
only_annotated: true
exclude_issues: true
```
To load all entries regardless of flags:
```yaml
filter: null
```
## 3) Validate and convert if needed
Check loaded records:
```bash
batdetect2 data summary path/to/dataset.yaml
```
Convert to annotation-set output for downstream tooling:
```bash
batdetect2 data convert path/to/dataset.yaml --output path/to/output.json
```
## 4) Continue with current workflows
- Run predictions: {doc}`run-batch-predictions`
- Train on imported data: {doc}`../tutorials/train-a-custom-model`
- Field-level reference: {doc}`../reference/data-sources`

View File

@ -1,30 +0,0 @@
# How-to Guides
How-to guides help you answer practical questions once you are past the first
tutorial.
Use this section when you already know the basic workflow and want help with one
specific task.
```{toctree}
:maxdepth: 1
choose-a-model
choose-an-inference-input-mode
run-batch-predictions
tune-inference-clipping
tune-detection-threshold
inspect-class-scores-in-python
inspect-detection-features-in-python
save-predictions-in-different-output-formats
fine-tune-from-a-checkpoint
choose-and-configure-evaluation-tasks
interpret-evaluation-outputs
configure-aoef-dataset
import-legacy-batdetect2-annotations
configure-audio-preprocessing
configure-spectrogram-preprocessing
configure-target-definitions
define-target-classes
configure-roi-mapping
```

View File

@ -1,44 +0,0 @@
# How to inspect class scores in Python
Use this guide when you need more than the top class label for each detection.
## Get the ranked class scores
`BatDetect2API.get_class_scores` returns `(class_name, score)` pairs for one detection.
```python
from pathlib import Path
from batdetect2.api_v2 import BatDetect2API
api = BatDetect2API.from_checkpoint(Path("path/to/model.ckpt"))
prediction = api.process_file(Path("path/to/audio.wav"))
for detection in prediction.detections:
print("detection score:", detection.detection_score)
for class_name, score in api.get_class_scores(detection):
print(class_name, score)
```
## Separate detection confidence from class ranking
Keep these two ideas separate:
- `detection_score` tells you how strongly the model kept the event as a detection,
- `class_scores` tell you how the model ranked classes for that detected event.
A detection can have a reasonable detection score while still having uncertain class ranking.
## Hide the top class if needed
If you want to inspect only the alternatives, pass `include_top_class=False`.
```python
api.get_class_scores(detection, include_top_class=False)
```
## Related pages
- Python tutorial: {doc}`../tutorials/integrate-with-a-python-pipeline`
- API reference: {doc}`../reference/api`
- Understanding scores: {doc}`../explanation/what-batdetect2-predicts`

View File

@ -1,49 +0,0 @@
# How to inspect detection features in Python
Use this guide when you want the per-detection feature vectors exposed by the current API.
## Get the feature vector for one detection
Each detection carries a `features` vector.
The API exposes it through `get_detection_features`.
```python
from pathlib import Path
from batdetect2.api_v2 import BatDetect2API
api = BatDetect2API.from_checkpoint(Path("path/to/model.ckpt"))
prediction = api.process_file(Path("path/to/audio.wav"))
for detection in prediction.detections:
features = api.get_detection_features(detection)
print(features.shape)
```
## Use features for exploration, not as ground truth labels
These features are internal model representations attached to detections.
They can be useful for:
- exploratory visualization,
- downstream clustering,
- comparison across detections,
- building extra analysis pipelines.
They do not replace validation.
They also do not automatically have a one-to-one interpretation as ecological variables.
## Save predictions with features included
If you need features on disk, use an output format that supports them, such as `raw` or `parquet`, and keep feature inclusion enabled.
See {doc}`save-predictions-in-different-output-formats`.
## Related pages
- Understanding features and embeddings: {doc}`../explanation/extracted-features-and-embeddings`
- Output formats reference: {doc}`../reference/output-formats`
- API reference: {doc}`../reference/api`

View File

@ -1,41 +0,0 @@
# How to interpret evaluation outputs
Use this guide after `batdetect2 evaluate` has written metrics and plots to disk.
## Start by identifying the task
Do not interpret a metric until you know which evaluation task produced it.
For example, a detection score and a clip-classification score answer different questions.
## Read the output directory as a bundle
Treat the evaluation output directory as one package:
- metrics,
- plots,
- saved predictions,
- config context.
Do not lift a single number out of context and treat it as the whole story.
## Look for failure patterns, not just overall averages
Check:
- whether errors concentrate in certain taxa,
- whether specific sites or recorder setups behave differently,
- whether threshold choices are driving the result,
- whether predictions are near clip boundaries or matching thresholds.
## Keep validation and deployment questions separate
A model can look good on one task and still be a poor fit for your deployment question.
Interpret the outputs in relation to the real use case, not only the easiest metric to report.
## Related pages
- Evaluation tutorial: {doc}`../tutorials/evaluate-on-a-test-set`
- Evaluation concepts: {doc}`../explanation/evaluation-concepts-and-matching`
- Model output and validation: {doc}`../explanation/model-output-and-validation`

View File

@ -1,62 +0,0 @@
# How to run batch processing
This guide shows practical command patterns for directory-based and file-list
processing runs.
Use it after you already know which input mode you want and need concrete
command templates for a repeatable batch run.
## Process a directory
```bash
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs
```
Use this when BatDetect2 should discover the audio files for you.
## Process a file list
```bash
batdetect2 process file_list \
path/to/model.ckpt \
path/to/audio_files.txt \
path/to/outputs
```
Use this when another part of your workflow already produced the exact recording
list to process.
## Process a dataset config
```bash
batdetect2 process dataset \
path/to/model.ckpt \
path/to/annotation_set.json \
path/to/outputs
```
Use this when your project already has a `soundevent` annotation set and you
want to extract unique recording paths from it.
## Useful options
- `--batch-size` to control throughput.
- `--workers` to set data-loading parallelism.
- `--format` to select output format.
- `--inference-config` to control clipping and loader behavior.
- `--outputs-config` to control serialization and output transforms.
- `--detection-threshold` to override the detection threshold for a run.
## Practical workflow
For large runs:
1. test the command on a small reviewed subset,
2. lock the config files and command shape,
3. write outputs to a dedicated directory per run,
4. record the checkpoint, config paths, and thresholds used.
For complete option details, see {doc}`../reference/cli/predict`.

View File

@ -1,95 +0,0 @@
# How to save predictions in different output formats
Use this guide when you need BatDetect2 outputs in a specific representation for
downstream tools.
## Choose the format that matches the job
Current built-in output formats include:
- `raw`:
one NetCDF file per clip, best for rich structured outputs,
- `parquet`:
tabular storage for data analysis workflows,
- `soundevent`:
prediction-set JSON for soundevent-style tooling,
- `batdetect2`:
legacy-compatible per-recording JSON and CSV outputs.
## Select a format from the CLI
Use `--format` for quick experiments.
```bash
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs \
--format parquet
```
## Use an outputs config for repeatable runs
Use an outputs config when you want reproducible control over format and
transforms.
Example:
```yaml
format:
name: raw
include_class_scores: true
include_features: true
include_geometry: true
transform:
detection_transforms: []
clip_transforms: []
```
Run with:
```bash
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs \
--outputs-config path/to/outputs.yaml
```
## Pick the simplest useful format
- Use `raw` if you want the richest output surface and easy round-tripping.
- Use `parquet` if you want tabular analysis in Python or data-lake workflows.
- Use `soundevent` if you want prediction-set JSON.
- Use `batdetect2` when you need legacy BatDetect2-style outputs.
## Enable legacy CNN feature CSVs
The `batdetect2` formatter can also write the legacy CNN feature sidecar CSVs.
This is controlled through the outputs config.
Example:
```yaml
format:
name: batdetect2
write_cnn_features_csv: true
transform:
detection_transforms: []
clip_transforms: []
```
When enabled, BatDetect2 writes:
- one `.json` file per recording,
- one detection `.csv` file per recording,
- one `_cnn_features.csv` file per recording when detections are present.
## Related pages
- Outputs config reference:
{doc}`../reference/outputs-config`
- Output formats reference:
{doc}`../reference/output-formats`
- Output transforms reference:
{doc}`../reference/output-transforms`

View File

@ -1,51 +0,0 @@
# How to tune detection threshold
Use this guide to compare detection outputs at different threshold values.
The goal is not to find a universal threshold.
The goal is to choose a threshold that fits your reviewed local data and the
project trade-off between missed calls and false positives.
## 1) Start with a baseline run
Run an initial prediction workflow and keep outputs in a dedicated folder.
## 2) Sweep threshold values
Run `process` multiple times with different thresholds (for example `0.1`,
`0.3`, `0.5`) and compare output counts and quality on the same validation
subset.
```bash
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs_thr_03 \
--detection-threshold 0.3
```
Keep each threshold run in a separate output directory.
That makes it easier to compare counts and inspect example files without mixing
results.
## 3) Validate against known calls
Use files with trusted annotations or expert review to select a threshold that
fits your project goals.
Check both:
- obvious false positives,
- obvious missed calls.
If class interpretation matters downstream, inspect class ranking behavior as
well, not just detection counts.
## 4) Record your chosen setting
Write down the chosen threshold and rationale so analyses are reproducible.
For conceptual trade-offs, see
{doc}`../explanation/model-output-and-validation`.

View File

@ -1,73 +0,0 @@
# How to tune inference clipping
Use this guide when long recordings need to be split into smaller clips during
inference.
## What clipping controls
`InferenceConfig.clipping` controls how recordings are split before batching.
Key fields are:
- `duration`:
clip duration in seconds,
- `overlap`:
overlap between adjacent clips,
- `max_empty`:
how much empty padding is allowed,
- `discard_empty`:
whether empty clips are dropped.
## Start from the defaults
Use the built-in clipping behavior first unless you already know you need
something else.
Only tune clipping when:
- recordings are much longer than your normal working set,
- you are seeing edge effects around calls,
- you need tighter control over throughput or padding behavior.
## Override clipping with an inference config
Create an inference config file and pass it to `process` or `evaluate`.
Example:
```yaml
clipping:
enabled: true
duration: 0.5
overlap: 0.1
max_empty: 0.0
discard_empty: true
loader:
batch_size: 8
```
Run with:
```bash
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs \
--inference-config path/to/inference.yaml
```
## Validate clipping changes on a small reviewed subset
Changing clipping changes what the model sees per batch and can change how
events near clip boundaries behave.
Check a reviewed subset before applying clipping changes to a full project.
## Related pages
- Inference config reference:
{doc}`../reference/inference-config`
- Run batch predictions:
{doc}`run-batch-predictions`
- Understanding the pipeline:
{doc}`../explanation/pipeline-overview`

View File

@ -1,114 +0,0 @@
# Home
Welcome to the BatDetect2 documentation.
## What is BatDetect2?
`batdetect2` is a deep learning model and software package for detecting and
classifying bat echolocation calls in high-frequency audio recordings.
You can use it from the command line or from Python, depending on how much
control you need.
In practice, BatDetect2 scans a recording, finds sounds that look like bat
calls, and returns one result for each detected call.
Each result can include where the call appears in the recording, shown as a box
with start and end time and the lowest and highest frequency, how confident the
model is that it found a call, and how strongly it matches the available
classes.
The built-in default model is trained for 17 UK species.
The package also supports custom training, fine-tuning, evaluation, and more
advanced workflows from Python.
For more detail on the underlying approach, see the pre-print:
[Towards a General Approach for Bat Echolocation Detection and Classification](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)
```{warning}
Treat outputs as model predictions, not ground truth.
Always validate on reviewed local data before using results for ecological inference.
```
## What can I do with it?
- I want to run the model on my recordings:
{doc}`tutorials/run-inference-on-folder`
- I write code and want to use it from Python:
{doc}`tutorials/integrate-with-a-python-pipeline`
- I want to train or fine-tune a custom model:
{doc}`tutorials/train-a-custom-model`
- I want to evaluate a trained model on held-out data:
{doc}`tutorials/evaluate-on-a-test-set`
```{note}
Looking for the previous BatDetect2 workflow?
See {doc}`legacy/index`.
The legacy docs are still available, but new workflows should use `batdetect2 process` and `BatDetect2API`.
```
## How to use this site
Start with {doc}`getting_started` if you are new.
Then choose the section that matches what you need.
If you are here mainly to run the model on recordings, start with Tutorials.
| Section | Best for | Start here |
| ------------- | --------------------------------------------- | ------------------------ |
| Tutorials | Step-by-step routes for the most common tasks | {doc}`tutorials/index` |
| How-to guides | Answers to specific practical questions | {doc}`how_to/index` |
| Reference | Detailed command and settings help | {doc}`reference/index` |
| Understanding | Concepts, interpretation, and trade-offs | {doc}`explanation/index` |
| Legacy | Previous workflow and migration guidance | {doc}`legacy/index` |
## Get in touch
- GitHub repository:
[macaodha/batdetect2](https://github.com/macaodha/batdetect2)
- Questions, bug reports, and feature requests:
[GitHub Issues](https://github.com/macaodha/batdetect2/issues)
- Common questions:
{doc}`faq`
- Want to contribute?
See {doc}`development/index`
## Cite this work
If you use BatDetect2 in research, please cite:
Mac Aodha, O., Martinez Balvanera, S., Damstra, E., et al.
(2022).
_Towards a General Approach for Bat Echolocation Detection and Classification_.
bioRxiv.
or the bibtex entry
```bibtex
@article{batdetect2_2022,
title = {Towards a General Approach for Bat Echolocation Detection and Classification},
author = {Mac Aodha, Oisin and Mart\'{i}nez Balvanera, Santiago and Damstra, Elise and Cooke, Martyn and Eichinski, Philip and Browning, Ella and Barataudm, Michel and Boughey, Katherine and Coles, Roger and Giacomini, Giada and MacSwiney G., M. Cristina and K. Obrist, Martin and Parsons, Stuart and Sattler, Thomas and Jones, Kate E.},
journal = {bioRxiv},
year = {2022}
}
```
```{toctree}
:maxdepth: 1
:caption: Get Started
getting_started
faq
tutorials/index
how_to/index
reference/index
explanation/index
legacy/index
```
```{toctree}
:maxdepth: 1
:caption: Contributing
development/index
```

View File

@ -1,53 +0,0 @@
# CLI workflow: `batdetect2 detect`
This page documents the previous CLI workflow based on `batdetect2 detect`.
```{warning}
This is documentation for a previous version of batdetect2.
For new workflows, use `batdetect2 process directory` instead.
If you are migrating, start with {doc}`migration-guide`.
```
## Processing a folder of audio files
```bash
batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD
```
Example:
```bash
batdetect2 detect example_data/audio/ example_data/anns/ 0.3
```
This command scans a directory of audio files, runs the BatDetect2 detector on
each file, and writes BatDetect2-style outputs into `ANN_DIR`.
Those outputs usually include one JSON file and one CSV file per recording, and
can optionally include extra feature CSVs.
`AUDIO_DIR` is the folder containing the input `.wav` files.
`ANN_DIR` is the folder where model outputs are written.
`DETECTION_THRESHOLD` controls which detections are kept.
Predictions below this score are discarded.
Smaller values keep more detections, but usually also increase mistakes.
Common options:
- `--cnn_features` Write extra CNN feature CSV files for each recording.
- `--spec_features` Extract and write traditional acoustic spectrogram feature
CSV files.
These are saved as `*_spec_features.csv` files.
- `--time_expansion_factor` Set the time expansion factor used for all files in
the run.
- `--save_preds_if_empty` Save output files even when no detections are found.
- `--model_path` Use a specific checkpoint instead of the included default
model.
If omitted, the command uses the default model trained on UK data.
## Related pages
- Migration guide:
{doc}`migration-guide`
- Current process docs:
{doc}`../reference/cli/predict`

View File

@ -1,28 +0,0 @@
# BatDetect2 v1.0 documentation
This section documents the BatDetect2 workflow for version 1.
Use these pages if you need to keep working with the older `batdetect2 detect` command or the older `batdetect2.api` interface.
For new projects, we recommend the current workflow:
- CLI:
`batdetect2 process`
- Python:
`batdetect2.api_v2.BatDetect2API`
If you are moving from the older workflow, start with {doc}`migration-guide`.
```{warning}
These pages describe the previous workflow.
They are kept for continuity and migration support.
New users should start with {doc}`../getting_started` and {doc}`../tutorials/index`.
```
```{toctree}
:maxdepth: 1
cli-detect
python-api
migration-guide
```

View File

@ -1,123 +0,0 @@
# BatDetect2 2.0 migration guide
Use this guide when moving from BatDetect2 1.x workflows to the CLI and API in
2.x.
## Why migrate
You get access to newer features.
The codebase changed quite a bit and now gives you much more control over the
workflow through config files, improved training and fine-tuning code, and a
more flexible sound target definition system.
You can also run newer or improved models.
That includes updated versions of the UK model, plus other models trained with
the newer codebase.
We are no longer actively supporting version 1.
No new enhancements are planned there, and only major bug fixes may still be
considered.
Future work is focused on version 2, including compatibility with newer Python
versions.
## Deprecation plan
We have kept the `batdetect2.api` module and the `batdetect2 detect` CLI command
in place for now.
You can keep using them without changing your current workflow.
However, many of the internal functions were relocated, removed or modified.
If your code relied on anything outside of the `api` module, it may break.
It is worth checking the new docs first, since there may already be a newer
feature that covers your use case.
If not, please open an issue.
Because the old `api` and CLI command are now redundant with the newer stack, we
plan to remove them in about a year.
If you want to keep pipelines up to date and long-running, it is a good idea to
migrate to version 2.
## How to migrate
If you are only using the `batdetect2 detect` CLI command or the
`batdetect2.api` module, the migration should be fairly simple.
This guide only covers these two entry points.
### CLI mapping
- `batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD` -> `batdetect2
process directory AUDIO_DIR OUTPUT_PATH --detection-threshold
DETECTION_THRESHOLD ...`
Main changes:
- outputs can be written in different formats.
See the output format reference for the available options.
- the detection threshold is now an option instead of a required positional
argument.
- options like saving CNN features are now controlled through config rather than
command flags.
- there are separate subcommands for processing a directory, file list, or
dataset.
### Python API mapping
- old:
`import batdetect2.api as api`
- current:
`from batdetect2 import BatDetect2API`
Typical migration shape:
```python
from pathlib import Path
from batdetect2 import BatDetect2API
# If no checkpoint is provided, the default UK model is loaded
api = BatDetect2API.from_checkpoint()
prediction = api.process_file(Path("path/to/audio.wav"))
```
Useful replacements:
- `batdetect2.api.process_file` -> current `BatDetect2API.process_file`
- `batdetect2.api.process_audio` -> current `BatDetect2API.process_audio`
- `batdetect2.api.process_spectrogram` -> current
`BatDetect2API.process_spectrogram`
- one-off batch loops -> `BatDetect2API.process_files` or CLI `process`
### Model changes
The default checkpoint used by the new CLI `process` commands and by
`BatDetect2API` is a newer model trained from scratch using the updated training
code, but the same model architecture, training procedure, and data.
Performance did not change substantially, but some differences are still
expected.
### Species names
For the default UK model there are two naming changes:
1. The original model had a typo and instead of `Barbastella barbastellus` it
used `Barbastellus barbastellus`.
This has now been corrected.
2. There has been a recent change in name for `Eptesicus serotinus` to
`Cnephaeus serotinus`.
## Stay on version 1
If you prefer not to migrate to version 2 yet, you can keep using version 1.
In that case, it is a good idea to pin your dependency:
```bash
pip install "batdetect2>=1.3.1,<2"
```
## Related pages
- Getting started:
{doc}`../getting_started`
- Tutorials:
{doc}`../tutorials/index`
- API reference:
{doc}`../reference/api`

View File

@ -1,55 +0,0 @@
# Legacy Python API: `batdetect2.api`
This page documents the previous Python API workflow based on `batdetect2.api`.
```{warning}
This is documentation for a previous version of batdetect2.
For new workflows, use `batdetect2.BatDetect2API`.
If you are migrating, start with {doc}`migration-guide`.
```
## Using BatDetect2 in Python
If you prefer to process data inside a Python script, you can use the `batdetect2.api` module.
This interface gives you a simple entry point for running the built-in BatDetect2 model and also exposes the default model and default configuration more directly than the current API.
You can process a whole file in one step, or load audio, generate a spectrogram, and work with lower-level functions yourself.
Common functions:
- `process_file` Load an audio file, run the model, and return BatDetect2-style results for that recording.
- `process_audio` Run inference on an audio array that is already loaded in memory.
- `process_spectrogram` Run inference starting from a spectrogram tensor instead of raw audio.
- `load_audio` Load and resample audio using the legacy preprocessing path.
- `generate_spectrogram` Convert audio into the spectrogram representation expected by the model.
- `postprocess` Convert raw model outputs into detections and extracted features.
Typical usage:
```python
import batdetect2.api as api
AUDIO_FILE = "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav"
# Process a whole file
results = api.process_file(AUDIO_FILE)
annotations = results["pred_dict"]["annotation"]
# Or, load audio and compute spectrograms
audio = api.load_audio(AUDIO_FILE)
spec = api.generate_spectrogram(audio)
# And process the audio or the spectrogram with the model
detections, features, spec = api.process_audio(audio)
detections, features = api.process_spectrogram(spec)
# Integrate the detections or extracted features into your own analysis
```
This interface is most useful when you want to work directly with detections, features, spectrograms, or intermediate arrays inside your own code.
## Related pages
- Migration guide: {doc}`migration-guide`
- Current API reference: {doc}`../reference/api`

View File

@ -1,39 +0,0 @@
# `BatDetect2API` reference
`BatDetect2API` is the main Python entry point for BatDetect2.
Use it when you want to load a model, run prediction, inspect detections,
evaluate results, or train from Python.
Defined in `batdetect2.api_v2`.
## Main ways to create it
- `BatDetect2API.from_checkpoint(path, ...)`
- load a trained checkpoint, a bundled checkpoint alias, or a Hugging Face
checkpoint.
- `BatDetect2API.from_config(model_config=..., targets_config=..., ...)`
- build a full model stack from config objects.
## Common tasks
- Load a checkpoint and run prediction on one file.
- Run prediction on many files or clips.
- Save predictions in one of the supported output formats.
- Evaluate a model on labelled data.
- Fine-tune an existing checkpoint on new targets.
## Generated reference
```{eval-rst}
.. autoclass:: batdetect2.api_v2.BatDetect2API
```
## Related pages
- Python tutorial:
{doc}`../tutorials/integrate-with-a-python-pipeline`
- Outputs config reference:
{doc}`outputs-config`
- Output formats reference:
{doc}`output-formats`

View File

@ -1,8 +0,0 @@
Base command
============
The options on this page apply to all subcommands.
.. click:: batdetect2.cli:cli
:prog: batdetect2
:nested: none

View File

@ -1,8 +0,0 @@
Data command
============
Inspect and convert dataset config files.
.. click:: batdetect2.cli.data:data
:prog: batdetect2 data
:nested: full

View File

@ -1,18 +0,0 @@
Legacy detect command
=====================
.. warning::
``batdetect2 detect`` is a legacy compatibility command.
Prefer ``batdetect2 process directory`` for new workflows.
Migration at a glance
---------------------
- Legacy: ``batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD``
- Current: ``batdetect2 process directory MODEL_PATH AUDIO_DIR OUTPUT_PATH``
with optional ``--detection-threshold``
.. click:: batdetect2.cli.compat:detect
:prog: batdetect2 detect
:nested: none

View File

@ -1,11 +0,0 @@
Evaluate command
================
Use ``batdetect2 evaluate`` to compare a checkpoint against labelled test data.
This command writes metrics and any configured artifacts to the output
directory.
.. click:: batdetect2.cli.evaluate:evaluate_command
:prog: batdetect2 evaluate
:nested: none

View File

@ -1,11 +0,0 @@
Finetune command
================
Use ``batdetect2 finetune`` to adapt an existing checkpoint to a new target
definition.
If you do not pass ``--model``, the bundled ``uk_same`` checkpoint is used.
.. click:: batdetect2.cli.finetune:finetune_command
:prog: batdetect2 finetune
:nested: none

View File

@ -1,50 +0,0 @@
# CLI reference
Use this section to find the right command quickly, then open the command page
for the full option list.
## Command map
| Command | Use it for | Required positional args |
| --- | --- | --- |
| `batdetect2 process` | Run inference on audio | Depends on subcommand (`directory`, `file_list`, `dataset`) |
| `batdetect2 data` | Inspect and convert dataset configs | Depends on subcommand (`summary`, `convert`) |
| `batdetect2 train` | Train or fine-tune models | `TRAIN_DATASET` |
| `batdetect2 finetune` | Fine-tune a checkpoint on new targets | `TRAIN_DATASET` plus `--targets` |
| `batdetect2 evaluate` | Evaluate a checkpoint on a test dataset | `TEST_DATASET` |
| `batdetect2 detect` | Legacy compatibility workflow | `AUDIO_DIR`, `ANN_DIR`, `DETECTION_THRESHOLD` |
## Notes
- Global CLI options are documented in {doc}`base`.
- Paths with spaces should be wrapped in quotes.
- Input audio is expected to be mono.
- `process` uses the optional `--detection-threshold` override.
- `evaluate` takes `TEST_DATASET` as a positional argument and uses `--model`
for the checkpoint override.
- `finetune` defaults to the bundled `uk_same` checkpoint if `--model` is not
provided.
```{warning}
`batdetect2 detect` is a legacy command.
Prefer `batdetect2 process directory` for new workflows.
```
## Related pages
- {doc}`../../tutorials/run-inference-on-folder`
- {doc}`../../how_to/run-batch-predictions`
- {doc}`../../how_to/tune-detection-threshold`
- {doc}`../configs`
```{toctree}
:maxdepth: 1
Base command and global options <base>
Process command group <predict>
Data command group <data>
Train command <train>
Finetune command <finetune>
Evaluate command <evaluate>
Legacy detect command <detect_legacy>
```

View File

@ -1,17 +0,0 @@
Process command
===============
Use ``batdetect2 process`` to run inference on audio.
Choose a subcommand based on how you want to provide the input:
- ``directory`` for all supported audio files in one folder
- ``file_list`` for a text file with one audio path per line
- ``dataset`` for recordings referenced by a dataset file
Use ``--detection-threshold`` when you want to override the configured
threshold for one run.
.. click:: batdetect2.cli.inference:process
:prog: batdetect2 process
:nested: full

View File

@ -1,12 +0,0 @@
Train command
=============
Use ``batdetect2 train`` to start from a fresh model config or continue from an
existing checkpoint.
If you want to adapt an existing checkpoint to a new target definition, use
``batdetect2 finetune`` instead.
.. click:: batdetect2.cli.train:train_command
:prog: batdetect2 train
:nested: none

View File

@ -1,18 +0,0 @@
Config reference
================
BatDetect2 uses separate config objects for different workflow surfaces.
Use the dedicated reference pages for each config family:
- model config
- training config
- logging config
- inference config
- evaluation config
- outputs config
- preprocessing config
- postprocess config
- targets config workflow
Example config files live under `example_data/configs/`.

View File

@ -1,76 +0,0 @@
# Data source reference
This page summarizes dataset source formats and their config fields.
## Supported source formats
| Format | Description |
| --- | --- |
| `aoef` | AOEF/soundevent annotation files (`AnnotationSet` or `AnnotationProject`) |
| `batdetect2` | Legacy format with one JSON annotation file per recording |
| `batdetect2_file` | Legacy format with one merged JSON annotation file |
## AOEF (`format: aoef`)
Required fields:
- `name`
- `format`
- `audio_dir`
- `annotations_path`
Optional fields:
- `description`
- `filter`
`filter` is only used when `annotations_path` points to an
`AnnotationProject`.
AOEF filter options:
- `only_completed` (default: `true`)
- `only_verified` (default: `false`)
- `exclude_issues` (default: `true`)
Use `filter: null` to disable project filtering.
## Legacy per-file (`format: batdetect2`)
Required fields:
- `name`
- `format`
- `audio_dir`
- `annotations_dir`
Optional fields:
- `description`
- `filter`
## Legacy merged file (`format: batdetect2_file`)
Required fields:
- `name`
- `format`
- `audio_dir`
- `annotations_path`
Optional fields:
- `description`
- `filter`
Legacy filter options:
- `only_annotated` (default: `true`)
- `exclude_issues` (default: `true`)
Use `filter: null` to disable filtering.
## Related guides
- {doc}`../how_to/configure-aoef-dataset`
- {doc}`../how_to/import-legacy-batdetect2-annotations`

View File

@ -1,42 +0,0 @@
# Detections reference
These are the main prediction objects returned by BatDetect2 inference methods.
Defined in `batdetect2.postprocess.types`.
## `ClipDetections`
`ClipDetections` represents the predictions for one clip or one full recording.
Fields:
- `clip`
- the `soundevent` clip metadata for the processed audio.
- `detections`
- list of `Detection` objects for that clip.
## `Detection`
`Detection` represents one detected event.
Fields:
- `geometry`
- time-frequency geometry for the detected event.
- `detection_score`
- confidence that there is an event at this location.
- `class_scores`
- class ranking scores for the detected event.
- `features`
- per-detection feature vector from the model.
## Related pages
- Python tutorial:
{doc}`../tutorials/integrate-with-a-python-pipeline`
- API reference:
{doc}`api`
- What BatDetect2 predicts:
{doc}`../explanation/what-batdetect2-predicts`
- Features and embeddings:
{doc}`../explanation/extracted-features-and-embeddings`

View File

@ -1,46 +0,0 @@
# Evaluation config reference
`EvaluationConfig` defines which evaluation tasks run and which plots they generate.
Defined in `batdetect2.evaluate.config`.
## Top-level fields
- `tasks`
- list of task configs.
## Built-in task families
Current built-in tasks include:
- `sound_event_detection`
- `sound_event_classification`
- `top_class_detection`
- `clip_detection`
- `clip_classification`
## Shared task controls
Common task-level controls include:
- `prefix`
- `ignore_start_end`
Sound-event-style tasks also support:
- `affinity`
- `affinity_threshold`
- `strict_match`
## Default behavior
The default evaluation config starts with:
- sound event detection,
- sound event classification.
## Related pages
- Choose and configure evaluation tasks: {doc}`../how_to/choose-and-configure-evaluation-tasks`
- Evaluation concepts: {doc}`../explanation/evaluation-concepts-and-matching`
- Evaluate CLI reference: {doc}`cli/evaluate`

View File

@ -1,28 +0,0 @@
# Reference documentation
Reference pages are the detailed lookup pages.
Use this section when you need exact command options, setting names, output
details, or Python API entries.
```{toctree}
:maxdepth: 1
cli/index
api
detections
model-config
training-config
logging-config
inference-config
evaluation-config
outputs-config
output-formats
output-transforms
data-sources
preprocessing-config
postprocess-config
targets-config-workflow
configs
targets
```

View File

@ -1,41 +0,0 @@
# Inference config reference
`InferenceConfig` controls how files are clipped and batched during prediction-time workflows.
Defined in `batdetect2.inference.config`.
## Top-level fields
- `loader`
- data-loader settings for inference.
- `clipping`
- controls how recordings are split into clips before batching.
## `loader`
Current built-in loader field:
- `batch_size` (int, default `8`)
## `clipping`
Fields:
- `enabled` (bool)
- `duration` (float, seconds)
- `overlap` (float, seconds)
- `max_empty` (float)
- `discard_empty` (bool)
## When to override this config
Override `InferenceConfig` when:
- long recordings need different clipping behavior,
- you want to tune batch size for your hardware,
- you need reproducible prediction settings across runs.
## Related pages
- Tune inference clipping: {doc}`../how_to/tune-inference-clipping`
- Predict CLI reference: {doc}`cli/predict`

View File

@ -1,46 +0,0 @@
# Logging config reference
`AppLoggingConfig` controls which logger backend BatDetect2 uses for training,
evaluation, and inference.
Defined in `batdetect2.logging`.
## Top-level fields
- `train`
- logger config for training runs.
- `evaluation`
- logger config for evaluation runs.
- `inference`
- logger config for inference runs.
## Built-in logger backends
Current built-in logger backends are:
- `csv`
- `tensorboard`
- `mlflow`
- `dvclive`
## Default behaviour
By default:
- training uses `csv`,
- evaluation uses `csv`,
- inference uses `csv`.
With the CSV logger, training writes a `metrics.csv` file in the log folder.
Example files live under `example_data/configs/`, including
`example_data/configs/logging.yaml`.
## Related pages
- Train command reference:
{doc}`cli/train`
- Evaluate command reference:
{doc}`cli/evaluate`
- Run inference on a folder:
{doc}`../tutorials/run-inference-on-folder`

Some files were not shown because too many files have changed in this diff Show More