mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Compare commits
No commits in common. "0adb1bbea7d1f908db65a1b036385ecc286f187e" and "4ae567bc1db0efcda3680fe3e8523805831dbc80" have entirely different histories.
0adb1bbea7
...
4ae567bc1d
@ -1,10 +1,8 @@
|
||||
[bumpversion]
|
||||
current_version = 2.0.0b1
|
||||
current_version = 1.3.1
|
||||
commit = True
|
||||
tag = True
|
||||
|
||||
[bumpversion:file:src/batdetect2/__init__.py]
|
||||
[bumpversion:file:batdetect2/__init__.py]
|
||||
|
||||
[bumpversion:file:pyproject.toml]
|
||||
|
||||
[bumpversion:file:docs/source/conf.py]
|
||||
|
||||
79
.github/workflows/ci.yml
vendored
79
.github/workflows/ci.yml
vendored
@ -1,79 +0,0 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
concurrency:
|
||||
group: ci-${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
checks:
|
||||
name: Checks
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install just
|
||||
uses: taiki-e/install-action@just
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: |
|
||||
pyproject.toml
|
||||
uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --all-extras --all-groups
|
||||
|
||||
- name: Run formatting, lint, and type checks
|
||||
run: just check
|
||||
|
||||
tests:
|
||||
name: Tests (Python ${{ matrix.python-version }})
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.10"
|
||||
- "3.11"
|
||||
- "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install just
|
||||
uses: taiki-e/install-action@just
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: |
|
||||
pyproject.toml
|
||||
uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --all-extras --all-groups
|
||||
|
||||
- name: Run test suite
|
||||
run: just test
|
||||
69
.github/workflows/docs-pages.yml
vendored
69
.github/workflows/docs-pages.yml
vendored
@ -1,69 +0,0 @@
|
||||
name: Docs Pages
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: docs-pages
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build Docs
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install just
|
||||
uses: taiki-e/install-action@just
|
||||
|
||||
- name: Configure GitHub Pages
|
||||
uses: actions/configure-pages@v5
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: |
|
||||
pyproject.toml
|
||||
uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --all-extras --all-groups
|
||||
|
||||
- name: Build docs
|
||||
run: just check-docs
|
||||
|
||||
- name: Upload Pages artifact
|
||||
uses: actions/upload-pages-artifact@v4
|
||||
with:
|
||||
path: docs/build
|
||||
|
||||
deploy:
|
||||
name: Deploy Docs
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pages: write
|
||||
id-token: write
|
||||
environment:
|
||||
name: github-pages
|
||||
url: ${{ steps.deployment.outputs.page_url }}
|
||||
|
||||
steps:
|
||||
- name: Deploy to GitHub Pages
|
||||
id: deployment
|
||||
uses: actions/deploy-pages@v4
|
||||
70
.github/workflows/publish-pypi.yml
vendored
70
.github/workflows/publish-pypi.yml
vendored
@ -1,70 +0,0 @@
|
||||
name: Publish PyPI
|
||||
|
||||
on:
|
||||
release:
|
||||
types:
|
||||
- published
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: publish-pypi
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build Distributions
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install just
|
||||
uses: taiki-e/install-action@just
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: |
|
||||
pyproject.toml
|
||||
uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: uv sync --all-extras --all-groups
|
||||
|
||||
- name: Build distributions
|
||||
run: just build-dist
|
||||
|
||||
- name: Upload distributions
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-dists
|
||||
path: dist/
|
||||
|
||||
publish:
|
||||
name: Publish to PyPI
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/batdetect2
|
||||
|
||||
steps:
|
||||
- name: Download distributions
|
||||
uses: actions/download-artifact@v5
|
||||
with:
|
||||
name: release-dists
|
||||
path: dist/
|
||||
|
||||
- name: Publish to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
with:
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
29
.github/workflows/python-package.yml
vendored
Normal file
29
.github/workflows/python-package.yml
vendored
Normal file
@ -0,0 +1,29 @@
|
||||
name: Python package
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "uv.lock"
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
run: uv python install ${{ matrix.python-version }}
|
||||
- name: Install the project
|
||||
run: uv sync --all-extras --dev
|
||||
- name: Test with pytest
|
||||
run: uv run pytest
|
||||
30
.github/workflows/python-publish.yml
vendored
Normal file
30
.github/workflows/python-publish.yml
vendored
Normal file
@ -0,0 +1,30 @@
|
||||
name: Upload Python Package
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: "3.x"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install build
|
||||
- name: Build package
|
||||
run: python -m build
|
||||
- name: Publish package
|
||||
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
|
||||
with:
|
||||
user: __token__
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
28
.gitignore
vendored
28
.gitignore
vendored
@ -50,7 +50,6 @@ cover/
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
docs/build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
@ -96,15 +95,8 @@ dmypy.json
|
||||
*.json
|
||||
plots/*
|
||||
|
||||
!example_data/anns/*.json
|
||||
|
||||
# Model experiments
|
||||
experiments/*
|
||||
DvcLiveLogger/checkpoints
|
||||
logs/
|
||||
mlruns/
|
||||
/outputs/
|
||||
notebooks/lightning_logs
|
||||
|
||||
# Jupiter notebooks
|
||||
.virtual_documents
|
||||
@ -113,24 +105,8 @@ notebooks/lightning_logs
|
||||
|
||||
# DO Include
|
||||
!batdetect2_notebook.ipynb
|
||||
!src/batdetect2/models/checkpoints/*.pth.tar
|
||||
!batdetect2/models/*.pth.tar
|
||||
!tests/data/*.wav
|
||||
!notebooks/*.ipynb
|
||||
!tests/data/**/*.wav
|
||||
.aider*
|
||||
|
||||
# Intermediate artifacts
|
||||
notebooks/lightning_logs
|
||||
example_data/preprocessed
|
||||
|
||||
# Dev notebooks
|
||||
notebooks/tmp
|
||||
/tmp
|
||||
/.agents/skills
|
||||
/notebooks
|
||||
/AGENTS.md
|
||||
/scripts
|
||||
/todo.md
|
||||
|
||||
# Assets
|
||||
!assets/*
|
||||
/models
|
||||
|
||||
5
.pylintrc
Normal file
5
.pylintrc
Normal file
@ -0,0 +1,5 @@
|
||||
[TYPECHECK]
|
||||
|
||||
# List of members which are set dynamically and missed by Pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed.
|
||||
generated-members=torch.*
|
||||
237
README.md
237
README.md
@ -1,166 +1,161 @@
|
||||
# BatDetect2
|
||||
|
||||
<img style="display:block-inline;" width="64" height="64" src="assets/bat_icon.png">
|
||||
|
||||
Code for detecting and classifying bat echolocation calls in high-frequency
|
||||
audio recordings.
|
||||
|
||||
> [!WARNING]
|
||||
> `batdetect2` 2.0.0b1 is out.
|
||||
> This is a beta release and we are gathering user feedback.
|
||||
> If you run into issues or have feedback on the new workflows, please use the
|
||||
> GitHub issues page to let us know.
|
||||
>
|
||||
> There are many changes and new recommended workflows.
|
||||
> We have left the previous `batdetect2.api` module intact, but if you run
|
||||
> into issues or want to upgrade, see the
|
||||
> [migration guide](docs/source/legacy/migration-guide.md) in the docs site.
|
||||
>
|
||||
> This update also ships with a refreshed default model.
|
||||
> It was trained in the same way and on the same data as before, but you should
|
||||
> still expect small output differences in some cases.
|
||||
|
||||
## What is BatDetect2
|
||||
|
||||
BatDetect2 is a deep learning model for detecting and classifying bat
|
||||
echolocation calls.
|
||||
The model generates multiple predictions for each input recording by providing a
|
||||
bounding box and predicted class for each individual call within it.
|
||||
|
||||
This repository also holds `batdetect2`, a Python-based tool to run, train,
|
||||
finetune and evaluate BatDetect2-type models, including the built-in model for
|
||||
detecting UK bat species.
|
||||
You can use the tool from the command line (terminal) or from Python as needed.
|
||||
|
||||
## Getting Started
|
||||
|
||||
We have [extensive documentation](docs/source/index.md) on how to use
|
||||
`batdetect2`.
|
||||
|
||||
The docs site is still being built and will be live soon.
|
||||
If you want a quick peek for now, see the `docs/` folder in this repository.
|
||||
|
||||
See our [getting started](docs/source/getting_started.md) guide and then jump
|
||||
into any of our tutorials:
|
||||
|
||||
- Run the model on a folder of recordings:
|
||||
`docs/source/tutorials/run-inference-on-folder.md`
|
||||
- Train your own model:
|
||||
`docs/source/tutorials/train-a-custom-model.md`
|
||||
- Evaluate your model:
|
||||
`docs/source/tutorials/evaluate-on-a-test-set.md`
|
||||
- Fine-tune a model:
|
||||
`docs/source/tutorials/integrate-with-a-python-pipeline.md`
|
||||
|
||||
### Try the model
|
||||
|
||||
If you want to try the model for UK bat species without installing anything, you
|
||||
can try the following:
|
||||
|
||||
1. Demo of the model (for UK species) on
|
||||
[huggingface](https://huggingface.co/spaces/macaodha/batdetect2).
|
||||
|
||||
2. Alternatively, click
|
||||
[here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb)
|
||||
to run the model using Google Colab.
|
||||
You can also run this notebook locally.
|
||||
|
||||
### Installing BatDetect2
|
||||
<img style="display: block-inline;" width="64" height="64" src="ims/bat_icon.png"> Code for detecting and classifying bat echolocation calls in high frequency audio recordings.
|
||||
|
||||
> [!NOTE]
|
||||
> `2.0.0b1` is a pre-release on PyPI.
|
||||
> You may need to request it explicitly by version, for example:
|
||||
>
|
||||
> ```bash
|
||||
> uvx --from batdetect2==2.0.0b1 batdetect2
|
||||
> uv tool install batdetect2==2.0.0b1
|
||||
> pip install batdetect2==2.0.0b1
|
||||
> ```
|
||||
> We’re actively working to make it easier to train and fine-tune BatDetect2 models using custom data. A major update is coming soon to the main branch—stay tuned! In the meantime, you can follow our progress in the train branch.
|
||||
|
||||
If you have `uv` installed (if not, we recommend it; follow the instructions
|
||||
[here](https://docs.astral.sh/uv/getting-started/installation/)), then you can
|
||||
run `batdetect2` one-off with
|
||||
## Getting started
|
||||
### Python Environment
|
||||
|
||||
We recommend using an isolated Python environment to avoid dependency issues. Choose one
|
||||
of the following options:
|
||||
|
||||
* Install the Anaconda Python 3.10 distribution for your operating system from [here](https://www.continuum.io/downloads). Create a new environment and activate it:
|
||||
|
||||
```bash
|
||||
uvx batdetect2
|
||||
conda create -y --name batdetect2 python==3.10
|
||||
conda activate batdetect2
|
||||
```
|
||||
|
||||
or if you want to install it permanently:
|
||||
* If you already have Python installed (version >= 3.8,< 3.11) and prefer using virtual environments then:
|
||||
|
||||
```bash
|
||||
uv tool install batdetect2
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
and test it with
|
||||
### Installing BatDetect2
|
||||
You can use pip to install `batdetect2`:
|
||||
|
||||
```bash
|
||||
batdetect2
|
||||
pip install batdetect2
|
||||
```
|
||||
|
||||
### Run BatDetect2 on a folder of recordings
|
||||
|
||||
Once installed, you can run BatDetect2 on a folder of `.wav` files.
|
||||
By default it will use the model trained on UK data.
|
||||
|
||||
Example command:
|
||||
Alternatively, download this code from the repository (by clicking on the green button on top right) and unzip it.
|
||||
Once unzipped, run this from extracted folder.
|
||||
|
||||
```bash
|
||||
batdetect2 process directory example_data/audio outputs
|
||||
pip install .
|
||||
```
|
||||
|
||||
This will scan the audio files in `example_data/audio` and save model outputs to
|
||||
`outputs`.
|
||||
If you have your own model checkpoint, you can use it:
|
||||
Make sure you have the environment activated before installing `batdetect2`.
|
||||
|
||||
|
||||
## Try the model
|
||||
1) You can try a demo of the model (for UK species) on [huggingface](https://huggingface.co/spaces/macaodha/batdetect2).
|
||||
|
||||
2) Alternatively, click [here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb) to run the model using Google Colab. You can also run this notebook locally.
|
||||
|
||||
|
||||
## Running the model on your own data
|
||||
|
||||
After following the above steps to install the code you can run the model on your own data.
|
||||
|
||||
|
||||
### Using the command line
|
||||
|
||||
You can run the model by opening the command line and typing:
|
||||
```bash
|
||||
batdetect2 process directory --model path/to/checkpoint.ckpt example_data/audio outputs
|
||||
batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD
|
||||
```
|
||||
e.g.
|
||||
```bash
|
||||
batdetect2 detect example_data/audio/ example_data/anns/ 0.3
|
||||
```
|
||||
|
||||
For the full walkthrough, use
|
||||
`docs/source/tutorials/run-inference-on-folder.md`.
|
||||
`AUDIO_DIR` is the path on your computer to the audio wav files of interest.
|
||||
`ANN_DIR` is the path on your computer where the model predictions will be saved. The model will output both `.csv` and `.json` results for each audio file.
|
||||
`DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistakes.
|
||||
|
||||
There are also optional arguments, e.g. you can request that the model outputs features (i.e. estimated call parameters) such as duration, max_frequency, etc. by setting the flag `--spec_features`. These will be saved as `*_spec_features.csv` files:
|
||||
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --spec_features`
|
||||
|
||||
You can also specify which model to use by setting the `--model_path` argument. If not specified, it will default to using a model trained on UK data e.g.
|
||||
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --model_path models/Net2DFast_UK_same.pth.tar`
|
||||
|
||||
|
||||
### Using the Python API
|
||||
|
||||
If you prefer to process your data within a Python script then you can use the `batdetect2` Python API.
|
||||
|
||||
```python
|
||||
from batdetect2 import api
|
||||
|
||||
AUDIO_FILE = "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav"
|
||||
|
||||
# Process a whole file
|
||||
results = api.process_file(AUDIO_FILE)
|
||||
|
||||
# Or, load audio and compute spectrograms
|
||||
audio = api.load_audio(AUDIO_FILE)
|
||||
spec = api.generate_spectrogram(audio)
|
||||
|
||||
# And process the audio or the spectrogram with the model
|
||||
detections, features, spec = api.process_audio(audio)
|
||||
detections, features = api.process_spectrogram(spec)
|
||||
|
||||
# Do something else ...
|
||||
```
|
||||
|
||||
You can integrate the detections or the extracted features to your custom analysis pipeline.
|
||||
|
||||
#### Using the Python API with HTTP
|
||||
|
||||
```python
|
||||
from batdetect2 import api
|
||||
import io
|
||||
import requests
|
||||
|
||||
AUDIO_URL = "<insert your audio url here>"
|
||||
|
||||
# Process a whole file from a url
|
||||
results = api.process_url(AUDIO_URL)
|
||||
|
||||
# Or, load audio and compute spectrograms
|
||||
# 'requests.get(AUDIO_URL).content' fetches the raw bytes. You are free to use other sources to fetch the raw bytes
|
||||
audio = api.load_audio(io.BytesIO(requests.get(AUDIO_URL).content))
|
||||
spec = api.generate_spectrogram(audio)
|
||||
|
||||
# And process the audio or the spectrogram with the model
|
||||
detections, features, spec = api.process_audio(audio)
|
||||
detections, features = api.process_spectrogram(spec)
|
||||
```
|
||||
|
||||
## Training the model on your own data
|
||||
Take a look at the steps outlined in finetuning readme [here](batdetect2/finetune/readme.md) for a description of how to train your own model.
|
||||
|
||||
|
||||
## Data and annotations
|
||||
The raw audio data and annotations used to train the models in the paper will be added soon.
|
||||
The audio interface used to annotate audio data for training and evaluation is available [here](https://github.com/macaodha/batdetect2_GUI).
|
||||
|
||||
The raw audio data and annotations used to train the models in the paper will be
|
||||
added soon.
|
||||
`batdetect2` supports annotations in various formats and is compatible with the
|
||||
outputs of [`whombat`](https://github.com/mbsantiago/whombat/) and this
|
||||
[earlier version](https://github.com/macaodha/batdetect2_GUI).
|
||||
If you're interested in supporting another format, please reach out or submit a
|
||||
PR.
|
||||
|
||||
## Warning
|
||||
The models developed and shared as part of this repository should be used with caution.
|
||||
While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment.
|
||||
Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted.
|
||||
|
||||
The models developed and shared as part of this repository should be used with
|
||||
caution.
|
||||
While they have been evaluated on held-out audio data, great care should be
|
||||
taken when using the model outputs for any form of biodiversity assessment.
|
||||
Your data may differ, and as a result it is very strongly recommended that you
|
||||
validate the model first using data with known species to ensure that the
|
||||
outputs can be trusted.
|
||||
If you train a model, make the best effort to be transparent about its training
|
||||
and evaluation data, and inform downstream users about its limitations.
|
||||
|
||||
## FAQ
|
||||
For more information please consult our [FAQ](faq.md).
|
||||
|
||||
For more information please consult our [FAQ](docs/source/faq.md).
|
||||
|
||||
## Reference
|
||||
|
||||
If you find our work useful in your research, please consider citing our paper,
|
||||
which you can find
|
||||
[here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1):
|
||||
|
||||
If you find our work useful in your research please consider citing our paper which you can find [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1):
|
||||
```
|
||||
@article{batdetect2_2022,
|
||||
title = {Towards a General Approach for Bat Echolocation Detection and Classification},
|
||||
author = {Mac Aodha, Oisin and Mart\'{i}nez Balvanera, Santiago and Damstra, Elise and Cooke, Martyn and Eichinski, Philip and Browning, Ella and Barataud, Michel and Boughey, Katherine and Coles, Roger and Giacomini, Giada and MacSwiney G., M. Cristina and K. Obrist, Martin and Parsons, Stuart and Sattler, Thomas and Jones, Kate E.},
|
||||
author = {Mac Aodha, Oisin and Mart\'{i}nez Balvanera, Santiago and Damstra, Elise and Cooke, Martyn and Eichinski, Philip and Browning, Ella and Barataudm, Michel and Boughey, Katherine and Coles, Roger and Giacomini, Giada and MacSwiney G., M. Cristina and K. Obrist, Martin and Parsons, Stuart and Sattler, Thomas and Jones, Kate E.},
|
||||
journal = {bioRxiv},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
## Acknowledgements
|
||||
Thanks to all the contributors who spent time collecting and annotating audio data.
|
||||
|
||||
Thanks to all the contributors who spent time collecting and annotating audio
|
||||
data.
|
||||
|
||||
### TODOs
|
||||
- [x] Release the code and pretrained model
|
||||
- [ ] Release the datasets and annotations used the experiments in the paper
|
||||
- [ ] Add the scripts used to generate the tables and figures from the paper
|
||||
|
||||
6
batdetect2/__init__.py
Normal file
6
batdetect2/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
import logging
|
||||
|
||||
numba_logger = logging.getLogger("numba")
|
||||
numba_logger.setLevel(logging.WARNING)
|
||||
|
||||
__version__ = "1.3.1"
|
||||
@ -96,9 +96,10 @@ If you wish to use a custom model or change the default parameters, please
|
||||
consult the API documentation in the code.
|
||||
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import List, Optional, Tuple, BinaryIO, Any, Union
|
||||
|
||||
from .types import AudioPath
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
@ -120,6 +121,12 @@ from batdetect2.types import (
|
||||
)
|
||||
from batdetect2.utils.detector_utils import list_audio_files, load_model
|
||||
|
||||
import audioread
|
||||
import os
|
||||
import soundfile as sf
|
||||
import requests
|
||||
import io
|
||||
|
||||
# Remove warnings from torch
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="torch")
|
||||
|
||||
@ -164,7 +171,7 @@ def load_audio(
|
||||
time_exp_fact: float = 1,
|
||||
target_samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
scale: bool = False,
|
||||
max_duration: float | None = None,
|
||||
max_duration: Optional[float] = None,
|
||||
) -> np.ndarray:
|
||||
"""Load audio from file.
|
||||
|
||||
@ -202,7 +209,7 @@ def load_audio(
|
||||
def generate_spectrogram(
|
||||
audio: np.ndarray,
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
config: SpectrogramParameters | None = None,
|
||||
config: Optional[SpectrogramParameters] = None,
|
||||
device: torch.device = DEVICE,
|
||||
) -> torch.Tensor:
|
||||
"""Generate spectrogram from audio array.
|
||||
@ -226,10 +233,11 @@ def generate_spectrogram(
|
||||
if config is None:
|
||||
config = DEFAULT_SPECTROGRAM_PARAMETERS
|
||||
|
||||
_, spec = du.compute_spectrogram(
|
||||
_, spec, _ = du.compute_spectrogram(
|
||||
audio,
|
||||
samp_rate,
|
||||
config,
|
||||
return_np=False,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@ -237,41 +245,89 @@ def generate_spectrogram(
|
||||
|
||||
|
||||
def process_file(
|
||||
audio_file: str,
|
||||
path: AudioPath,
|
||||
model: DetectionModel = MODEL,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
device: torch.device = DEVICE,
|
||||
file_id: Optional[str] = None
|
||||
) -> du.RunResults:
|
||||
"""Process audio file with model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_file : str
|
||||
Path to audio file.
|
||||
path : AudioPath
|
||||
Path to audio data.
|
||||
model : DetectionModel, optional
|
||||
Detection model. Uses default model if not specified.
|
||||
config : Optional[ProcessingConfiguration], optional
|
||||
Processing configuration, by default None (uses default parameters).
|
||||
device : torch.device, optional
|
||||
Device to use, by default tries to use GPU if available.
|
||||
file_id: Optional[str],
|
||||
Give the data an id. If path is a string path to a file this can be ignored and
|
||||
the file_id will be the basename of the file.
|
||||
"""
|
||||
if config is None:
|
||||
config = CONFIG
|
||||
|
||||
return du.process_file(
|
||||
audio_file,
|
||||
path,
|
||||
model,
|
||||
config,
|
||||
device,
|
||||
file_id
|
||||
)
|
||||
|
||||
def process_url(
|
||||
url: str,
|
||||
model: DetectionModel = MODEL,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
device: torch.device = DEVICE,
|
||||
file_id: Optional[str] = None
|
||||
) -> du.RunResults:
|
||||
"""Process audio file with model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
url : str
|
||||
HTTP URL to load the audio data from
|
||||
model : DetectionModel, optional
|
||||
Detection model. Uses default model if not specified.
|
||||
config : Optional[ProcessingConfiguration], optional
|
||||
Processing configuration, by default None (uses default parameters).
|
||||
device : torch.device, optional
|
||||
Device to use, by default tries to use GPU if available.
|
||||
file_id: Optional[str],
|
||||
Give the data an id. Defaults to the URL
|
||||
"""
|
||||
if config is None:
|
||||
config = CONFIG
|
||||
|
||||
if file_id is None:
|
||||
file_id = url
|
||||
|
||||
response = requests.get(url)
|
||||
|
||||
# Raise exception on HTTP error
|
||||
response.raise_for_status()
|
||||
|
||||
# Retrieve body as raw bytes
|
||||
raw_audio_data = response.content
|
||||
|
||||
return du.process_file(
|
||||
io.BytesIO(raw_audio_data),
|
||||
model,
|
||||
config,
|
||||
device,
|
||||
file_id
|
||||
)
|
||||
|
||||
def process_spectrogram(
|
||||
spec: torch.Tensor,
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
model: DetectionModel = MODEL,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
) -> tuple[list[Annotation], np.ndarray]:
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
) -> Tuple[List[Annotation], np.ndarray]:
|
||||
"""Process spectrogram with model.
|
||||
|
||||
Parameters
|
||||
@ -311,9 +367,9 @@ def process_audio(
|
||||
audio: np.ndarray,
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
model: DetectionModel = MODEL,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
device: torch.device = DEVICE,
|
||||
) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
|
||||
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
|
||||
"""Process audio array with model.
|
||||
|
||||
Parameters
|
||||
@ -355,8 +411,8 @@ def process_audio(
|
||||
def postprocess(
|
||||
outputs: ModelOutput,
|
||||
samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
) -> tuple[list[Annotation], np.ndarray]:
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
) -> Tuple[List[Annotation], np.ndarray]:
|
||||
"""Postprocess model outputs.
|
||||
|
||||
Convert model tensor outputs to predicted bounding boxes and
|
||||
@ -410,9 +466,7 @@ def print_summary(results: RunResults) -> None:
|
||||
Detection result.
|
||||
"""
|
||||
print("Results for " + results["pred_dict"]["id"])
|
||||
print(
|
||||
"{} calls detected\n".format(len(results["pred_dict"]["annotation"]))
|
||||
)
|
||||
print("{} calls detected\n".format(len(results["pred_dict"]["annotation"])))
|
||||
|
||||
print("time\tprob\tlfreq\tspecies_name")
|
||||
for ann in results["pred_dict"]["annotation"]:
|
||||
@ -1,24 +1,32 @@
|
||||
"""BatDetect2 command line interface."""
|
||||
|
||||
import os
|
||||
|
||||
import click
|
||||
|
||||
from batdetect2.cli.base import cli
|
||||
from batdetect2 import api
|
||||
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
|
||||
from batdetect2.types import ProcessingConfiguration
|
||||
from batdetect2.utils.detector_utils import save_results_to_file
|
||||
|
||||
DEFAULT_MODEL_PATH = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"models",
|
||||
"checkpoints",
|
||||
"Net2DFast_UK_same.pth.tar",
|
||||
)
|
||||
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
@cli.command(
|
||||
short_help="Legacy detection command.",
|
||||
epilog=(
|
||||
"Deprecated workflow. Prefer `batdetect2 process directory` for "
|
||||
"new analyses."
|
||||
),
|
||||
)
|
||||
INFO_STR = """
|
||||
BatDetect2 - Detection and Classification
|
||||
Assumes audio files are mono, not stereo.
|
||||
Spaces in the input paths will throw an error. Wrap in quotes.
|
||||
Input files should be short in duration e.g. < 30 seconds.
|
||||
"""
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
"""BatDetect2 - Bat Call Detection and Classification."""
|
||||
click.echo(INFO_STR)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument(
|
||||
"audio_dir",
|
||||
type=click.Path(exists=True),
|
||||
@ -37,6 +45,12 @@ DEFAULT_MODEL_PATH = os.path.join(
|
||||
default=False,
|
||||
help="Extracts CNN call features",
|
||||
)
|
||||
@click.option(
|
||||
"--chunk_size",
|
||||
type=float,
|
||||
default=2,
|
||||
help="Specifies the duration of chunks in seconds. BatDetect2 will divide longer files into smaller chunks and process them independently. Larger chunks increase computation time and memory usage but may provide more contextual information for inference.",
|
||||
)
|
||||
@click.option(
|
||||
"--spec_features",
|
||||
is_flag=True,
|
||||
@ -72,12 +86,10 @@ def detect(
|
||||
ann_dir: str,
|
||||
detection_threshold: float,
|
||||
time_expansion_factor: int,
|
||||
chunk_size: float,
|
||||
**args,
|
||||
):
|
||||
"""Legacy detection command for directory-based inference.
|
||||
|
||||
Detect bat calls in files in `AUDIO_DIR` and save predictions to
|
||||
`ANN_DIR`.
|
||||
"""Detect bat calls in files in AUDIO_DIR and save predictions to ANN_DIR.
|
||||
|
||||
DETECTION_THRESHOLD is the detection threshold. All predictions with a
|
||||
score below this threshold will be discarded. Values between 0 and 1.
|
||||
@ -87,21 +99,7 @@ def detect(
|
||||
Spaces in the input paths will throw an error. Wrap in quotes.
|
||||
|
||||
Input files should be short in duration e.g. < 30 seconds.
|
||||
|
||||
Note
|
||||
----
|
||||
This command is kept for backwards compatibility. Prefer
|
||||
`batdetect2 process directory` for new workflows.
|
||||
"""
|
||||
from batdetect2 import api
|
||||
from batdetect2.utils.detector_utils import save_results_to_file
|
||||
|
||||
message = (
|
||||
"The `batdetect2 detect` command is deprecated. Prefer "
|
||||
"`batdetect2 process directory` for new analyses."
|
||||
)
|
||||
click.secho(f"WARNING: {message}", fg="yellow", err=True)
|
||||
|
||||
click.echo(f"Loading model: {args['model_path']}")
|
||||
model, params = api.load_model(args["model_path"])
|
||||
|
||||
@ -117,7 +115,7 @@ def detect(
|
||||
**args,
|
||||
"time_expansion": time_expansion_factor,
|
||||
"spec_slices": False,
|
||||
"chunk_size": 2,
|
||||
"chunk_size": chunk_size,
|
||||
"detection_threshold": detection_threshold,
|
||||
}
|
||||
)
|
||||
@ -151,8 +149,13 @@ def detect(
|
||||
click.echo(f" {err}")
|
||||
|
||||
|
||||
def print_config(config):
|
||||
"""Print the processing configuration values."""
|
||||
def print_config(config: ProcessingConfiguration):
|
||||
"""Print the processing configuration."""
|
||||
click.echo("\nProcessing Configuration:")
|
||||
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
|
||||
click.echo(f"Detection Threshold: {config.get('detection_threshold')}")
|
||||
click.echo(f"Chunk Size: {config.get('chunk_size')}s")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@ -1,6 +1,5 @@
|
||||
"""Functions to compute features from predictions."""
|
||||
|
||||
from typing import Dict, List
|
||||
from typing import Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -8,26 +7,15 @@ from batdetect2 import types
|
||||
from batdetect2.detector.parameters import MAX_FREQ_HZ, MIN_FREQ_HZ
|
||||
|
||||
|
||||
def convert_int_to_freq(
|
||||
spec_ind: int,
|
||||
spec_height: int,
|
||||
min_freq: float,
|
||||
max_freq: float,
|
||||
) -> int:
|
||||
def convert_int_to_freq(spec_ind, spec_height, min_freq, max_freq):
|
||||
"""Convert spectrogram index to frequency in Hz.""" ""
|
||||
spec_ind = spec_height - spec_ind
|
||||
return int(
|
||||
round(
|
||||
(spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq,
|
||||
2,
|
||||
)
|
||||
return round(
|
||||
(spec_ind / float(spec_height)) * (max_freq - min_freq) + min_freq, 2
|
||||
)
|
||||
|
||||
|
||||
def extract_spec_slices(
|
||||
spec: np.ndarray,
|
||||
pred_nms: types.PredictionResults,
|
||||
) -> List[np.ndarray]:
|
||||
def extract_spec_slices(spec, pred_nms):
|
||||
"""Extract spectrogram slices from spectrogram.
|
||||
|
||||
The slices are extracted based on detected call locations.
|
||||
@ -86,7 +74,7 @@ def compute_bandwidth(
|
||||
|
||||
def compute_max_power_bb(
|
||||
prediction: types.Prediction,
|
||||
spec: np.ndarray | None = None,
|
||||
spec: Optional[np.ndarray] = None,
|
||||
min_freq: int = MIN_FREQ_HZ,
|
||||
max_freq: int = MAX_FREQ_HZ,
|
||||
**_,
|
||||
@ -121,7 +109,7 @@ def compute_max_power_bb(
|
||||
|
||||
return int(
|
||||
convert_int_to_freq(
|
||||
int(y_high + max_power_ind),
|
||||
y_high + max_power_ind,
|
||||
spec.shape[0],
|
||||
min_freq,
|
||||
max_freq,
|
||||
@ -131,7 +119,7 @@ def compute_max_power_bb(
|
||||
|
||||
def compute_max_power(
|
||||
prediction: types.Prediction,
|
||||
spec: np.ndarray | None = None,
|
||||
spec: Optional[np.ndarray] = None,
|
||||
min_freq: int = MIN_FREQ_HZ,
|
||||
max_freq: int = MAX_FREQ_HZ,
|
||||
**_,
|
||||
@ -147,17 +135,19 @@ def compute_max_power(
|
||||
spec_call = spec[:, x_start:x_end]
|
||||
power_per_freq_band = np.sum(spec_call, axis=1)
|
||||
max_power_ind = np.argmax(power_per_freq_band)
|
||||
return convert_int_to_freq(
|
||||
int(max_power_ind),
|
||||
return int(
|
||||
convert_int_to_freq(
|
||||
max_power_ind,
|
||||
spec.shape[0],
|
||||
min_freq,
|
||||
max_freq,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def compute_max_power_first(
|
||||
prediction: types.Prediction,
|
||||
spec: np.ndarray | None = None,
|
||||
spec: Optional[np.ndarray] = None,
|
||||
min_freq: int = MIN_FREQ_HZ,
|
||||
max_freq: int = MAX_FREQ_HZ,
|
||||
**_,
|
||||
@ -174,17 +164,19 @@ def compute_max_power_first(
|
||||
first_half = spec_call[:, : int(spec_call.shape[1] / 2)]
|
||||
power_per_freq_band = np.sum(first_half, axis=1)
|
||||
max_power_ind = np.argmax(power_per_freq_band)
|
||||
return convert_int_to_freq(
|
||||
int(max_power_ind),
|
||||
return int(
|
||||
convert_int_to_freq(
|
||||
max_power_ind,
|
||||
spec.shape[0],
|
||||
min_freq,
|
||||
max_freq,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def compute_max_power_second(
|
||||
prediction: types.Prediction,
|
||||
spec: np.ndarray | None = None,
|
||||
spec: Optional[np.ndarray] = None,
|
||||
min_freq: int = MIN_FREQ_HZ,
|
||||
max_freq: int = MAX_FREQ_HZ,
|
||||
**_,
|
||||
@ -201,17 +193,19 @@ def compute_max_power_second(
|
||||
second_half = spec_call[:, int(spec_call.shape[1] / 2) :]
|
||||
power_per_freq_band = np.sum(second_half, axis=1)
|
||||
max_power_ind = np.argmax(power_per_freq_band)
|
||||
return convert_int_to_freq(
|
||||
int(max_power_ind),
|
||||
return int(
|
||||
convert_int_to_freq(
|
||||
max_power_ind,
|
||||
spec.shape[0],
|
||||
min_freq,
|
||||
max_freq,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def compute_call_interval(
|
||||
prediction: types.Prediction,
|
||||
previous: types.Prediction | None = None,
|
||||
previous: Optional[types.Prediction] = None,
|
||||
**_,
|
||||
) -> float:
|
||||
"""Compute time between this call and the previous call in seconds."""
|
||||
@ -242,7 +236,7 @@ def get_feats(
|
||||
spec: np.ndarray,
|
||||
pred_nms: types.PredictionResults,
|
||||
params: types.FeatureExtractionParameters,
|
||||
) -> np.ndarray:
|
||||
):
|
||||
"""Extract features from spectrogram based on detected call locations.
|
||||
|
||||
The features extracted are:
|
||||
@ -53,13 +53,7 @@ class SelfAttention(nn.Module):
|
||||
|
||||
class ConvBlockDownCoordF(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chn,
|
||||
out_chn,
|
||||
ip_height,
|
||||
k_size=3,
|
||||
pad_size=1,
|
||||
stride=1,
|
||||
self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1
|
||||
):
|
||||
super(ConvBlockDownCoordF, self).__init__()
|
||||
self.coords = nn.Parameter(
|
||||
@ -85,13 +79,7 @@ class ConvBlockDownCoordF(nn.Module):
|
||||
|
||||
class ConvBlockDownStandard(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_chn,
|
||||
out_chn,
|
||||
ip_height=None,
|
||||
k_size=3,
|
||||
pad_size=1,
|
||||
stride=1,
|
||||
self, in_chn, out_chn, ip_height=None, k_size=3, pad_size=1, stride=1
|
||||
):
|
||||
super(ConvBlockDownStandard, self).__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import torch.fft
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
@ -94,10 +95,7 @@ class Net2DFast(nn.Module):
|
||||
num_filts // 4, 2, kernel_size=1, padding=0
|
||||
)
|
||||
self.conv_classes_op = nn.Conv2d(
|
||||
num_filts // 4,
|
||||
self.num_classes + 1,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
|
||||
)
|
||||
|
||||
if self.emb_dim > 0:
|
||||
@ -105,15 +103,15 @@ class Net2DFast(nn.Module):
|
||||
num_filts, self.emb_dim, kernel_size=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
def forward(self, ip, return_feats=False) -> ModelOutput:
|
||||
# encoder
|
||||
x1 = self.conv_dn_0(spec)
|
||||
x1 = self.conv_dn_0(ip)
|
||||
x2 = self.conv_dn_1(x1)
|
||||
x3 = self.conv_dn_2(x2)
|
||||
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
|
||||
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
|
||||
|
||||
# bottleneck
|
||||
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
|
||||
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
|
||||
x = self.att(x)
|
||||
x = x.repeat([1, 1, self.bneck_height * 4, 1])
|
||||
|
||||
@ -123,13 +121,13 @@ class Net2DFast(nn.Module):
|
||||
x = self.conv_up_4(x + x1)
|
||||
|
||||
# output
|
||||
x = F.relu_(self.conv_op_bn(self.conv_op(x)))
|
||||
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
|
||||
cls = self.conv_classes_op(x)
|
||||
comb = torch.softmax(cls, 1)
|
||||
|
||||
return ModelOutput(
|
||||
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
|
||||
pred_size=F.relu(self.conv_size_op(x)),
|
||||
pred_size=F.relu(self.conv_size_op(x), inplace=True),
|
||||
pred_class=comb,
|
||||
pred_class_un_norm=cls,
|
||||
features=x,
|
||||
@ -209,10 +207,7 @@ class Net2DFastNoAttn(nn.Module):
|
||||
num_filts // 4, 2, kernel_size=1, padding=0
|
||||
)
|
||||
self.conv_classes_op = nn.Conv2d(
|
||||
num_filts // 4,
|
||||
self.num_classes + 1,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
|
||||
)
|
||||
|
||||
if self.emb_dim > 0:
|
||||
@ -220,26 +215,26 @@ class Net2DFastNoAttn(nn.Module):
|
||||
num_filts, self.emb_dim, kernel_size=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
x1 = self.conv_dn_0(spec)
|
||||
def forward(self, ip, return_feats=False) -> ModelOutput:
|
||||
x1 = self.conv_dn_0(ip)
|
||||
x2 = self.conv_dn_1(x1)
|
||||
x3 = self.conv_dn_2(x2)
|
||||
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
|
||||
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
|
||||
|
||||
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
|
||||
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
|
||||
x = x.repeat([1, 1, self.bneck_height * 4, 1])
|
||||
|
||||
x = self.conv_up_2(x + x3)
|
||||
x = self.conv_up_3(x + x2)
|
||||
x = self.conv_up_4(x + x1)
|
||||
|
||||
x = F.relu_(self.conv_op_bn(self.conv_op(x)))
|
||||
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
|
||||
cls = self.conv_classes_op(x)
|
||||
comb = torch.softmax(cls, 1)
|
||||
|
||||
return ModelOutput(
|
||||
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
|
||||
pred_size=F.relu_(self.conv_size_op(x)),
|
||||
pred_size=F.relu(self.conv_size_op(x), inplace=True),
|
||||
pred_class=comb,
|
||||
pred_class_un_norm=cls,
|
||||
features=x,
|
||||
@ -329,13 +324,13 @@ class Net2DFastNoCoordConv(nn.Module):
|
||||
num_filts, self.emb_dim, kernel_size=1, padding=0
|
||||
)
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
x1 = self.conv_dn_0(spec)
|
||||
def forward(self, ip, return_feats=False) -> ModelOutput:
|
||||
x1 = self.conv_dn_0(ip)
|
||||
x2 = self.conv_dn_1(x1)
|
||||
x3 = self.conv_dn_2(x2)
|
||||
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
|
||||
x3 = F.relu(self.conv_dn_3_bn(self.conv_dn_3(x3)), inplace=True)
|
||||
|
||||
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
|
||||
x = F.relu(self.conv_1d_bn(self.conv_1d(x3)), inplace=True)
|
||||
x = self.att(x)
|
||||
x = x.repeat([1, 1, self.bneck_height * 4, 1])
|
||||
|
||||
@ -343,13 +338,15 @@ class Net2DFastNoCoordConv(nn.Module):
|
||||
x = self.conv_up_3(x + x2)
|
||||
x = self.conv_up_4(x + x1)
|
||||
|
||||
x = F.relu_(self.conv_op_bn(self.conv_op(x)))
|
||||
x = F.relu(self.conv_op_bn(self.conv_op(x)), inplace=True)
|
||||
cls = self.conv_classes_op(x)
|
||||
comb = torch.softmax(cls, 1)
|
||||
|
||||
pred_emb = (self.conv_emb(x) if self.emb_dim > 0 else None,)
|
||||
|
||||
return ModelOutput(
|
||||
pred_det=comb[:, :-1, :, :].sum(1).unsqueeze(1),
|
||||
pred_size=F.relu_(self.conv_size_op(x)),
|
||||
pred_size=F.relu(self.conv_size_op(x), inplace=True),
|
||||
pred_class=comb,
|
||||
pred_class_un_norm=cls,
|
||||
features=x,
|
||||
232
batdetect2/detector/parameters.py
Normal file
232
batdetect2/detector/parameters.py
Normal file
@ -0,0 +1,232 @@
|
||||
import datetime
|
||||
import os
|
||||
|
||||
from batdetect2.types import ProcessingConfiguration, SpectrogramParameters
|
||||
|
||||
TARGET_SAMPLERATE_HZ = 256000
|
||||
FFT_WIN_LENGTH_S = 512 / 256000.0
|
||||
FFT_OVERLAP = 0.75
|
||||
MAX_FREQ_HZ = 120000
|
||||
MIN_FREQ_HZ = 10000
|
||||
RESIZE_FACTOR = 0.5
|
||||
SPEC_DIVIDE_FACTOR = 32
|
||||
SPEC_HEIGHT = 256
|
||||
SCALE_RAW_AUDIO = False
|
||||
DETECTION_THRESHOLD = 0.01
|
||||
NMS_KERNEL_SIZE = 9
|
||||
NMS_TOP_K_PER_SEC = 200
|
||||
SPEC_SCALE = "pcen"
|
||||
DENOISE_SPEC_AVG = True
|
||||
MAX_SCALE_SPEC = False
|
||||
|
||||
|
||||
DEFAULT_MODEL_PATH = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"models",
|
||||
"Net2DFast_UK_same.pth.tar",
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_SPECTROGRAM_PARAMETERS: SpectrogramParameters = {
|
||||
"fft_win_length": FFT_WIN_LENGTH_S,
|
||||
"fft_overlap": FFT_OVERLAP,
|
||||
"spec_height": SPEC_HEIGHT,
|
||||
"resize_factor": RESIZE_FACTOR,
|
||||
"spec_divide_factor": SPEC_DIVIDE_FACTOR,
|
||||
"max_freq": MAX_FREQ_HZ,
|
||||
"min_freq": MIN_FREQ_HZ,
|
||||
"spec_scale": SPEC_SCALE,
|
||||
"denoise_spec_avg": DENOISE_SPEC_AVG,
|
||||
"max_scale_spec": MAX_SCALE_SPEC,
|
||||
}
|
||||
|
||||
|
||||
DEFAULT_PROCESSING_CONFIGURATIONS: ProcessingConfiguration = {
|
||||
"detection_threshold": DETECTION_THRESHOLD,
|
||||
"spec_slices": False,
|
||||
"chunk_size": 3,
|
||||
"spec_features": False,
|
||||
"cnn_features": False,
|
||||
"quiet": True,
|
||||
"target_samp_rate": TARGET_SAMPLERATE_HZ,
|
||||
"fft_win_length": FFT_WIN_LENGTH_S,
|
||||
"fft_overlap": FFT_OVERLAP,
|
||||
"resize_factor": RESIZE_FACTOR,
|
||||
"spec_divide_factor": SPEC_DIVIDE_FACTOR,
|
||||
"spec_height": SPEC_HEIGHT,
|
||||
"scale_raw_audio": SCALE_RAW_AUDIO,
|
||||
"class_names": [],
|
||||
"time_expansion": 1,
|
||||
"top_n": 3,
|
||||
"return_raw_preds": False,
|
||||
"max_duration": None,
|
||||
"nms_kernel_size": NMS_KERNEL_SIZE,
|
||||
"max_freq": MAX_FREQ_HZ,
|
||||
"min_freq": MIN_FREQ_HZ,
|
||||
"nms_top_k_per_sec": NMS_TOP_K_PER_SEC,
|
||||
"spec_scale": SPEC_SCALE,
|
||||
"denoise_spec_avg": DENOISE_SPEC_AVG,
|
||||
"max_scale_spec": MAX_SCALE_SPEC,
|
||||
}
|
||||
|
||||
|
||||
def mk_dir(path):
|
||||
if not os.path.isdir(path):
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
def get_params(make_dirs=False, exps_dir="../../experiments/"):
|
||||
params = {}
|
||||
|
||||
params[
|
||||
"model_name"
|
||||
] = "Net2DFast" # Net2DFast, Net2DSkip, Net2DSimple, Net2DSkipDS, Net2DRN
|
||||
params["num_filters"] = 128
|
||||
|
||||
now_str = datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S")
|
||||
model_name = now_str + ".pth.tar"
|
||||
params["experiment"] = os.path.join(exps_dir, now_str, "")
|
||||
params["model_file_name"] = os.path.join(params["experiment"], model_name)
|
||||
params["op_im_dir"] = os.path.join(params["experiment"], "op_ims", "")
|
||||
params["op_im_dir_test"] = os.path.join(
|
||||
params["experiment"], "op_ims_test", ""
|
||||
)
|
||||
# params['notes'] = '' # can save notes about an experiment here
|
||||
|
||||
# spec parameters
|
||||
params[
|
||||
"target_samp_rate"
|
||||
] = TARGET_SAMPLERATE_HZ # resamples all audio so that it is at this rate
|
||||
params[
|
||||
"fft_win_length"
|
||||
] = FFT_WIN_LENGTH_S # in milliseconds, amount of time per stft time step
|
||||
params["fft_overlap"] = FFT_OVERLAP # stft window overlap
|
||||
|
||||
params[
|
||||
"max_freq"
|
||||
] = MAX_FREQ_HZ # in Hz, everything above this will be discarded
|
||||
params[
|
||||
"min_freq"
|
||||
] = MIN_FREQ_HZ # in Hz, everything below this will be discarded
|
||||
|
||||
params[
|
||||
"resize_factor"
|
||||
] = RESIZE_FACTOR # resize so the spectrogram at the input of the network
|
||||
params[
|
||||
"spec_height"
|
||||
] = SPEC_HEIGHT # units are number of frequency bins (before resizing is performed)
|
||||
params[
|
||||
"spec_train_width"
|
||||
] = 512 # units are number of time steps (before resizing is performed)
|
||||
params[
|
||||
"spec_divide_factor"
|
||||
] = SPEC_DIVIDE_FACTOR # spectrogram should be divisible by this amount in width and height
|
||||
|
||||
# spec processing params
|
||||
params[
|
||||
"denoise_spec_avg"
|
||||
] = DENOISE_SPEC_AVG # removes the mean for each frequency band
|
||||
params[
|
||||
"scale_raw_audio"
|
||||
] = SCALE_RAW_AUDIO # scales the raw audio to [-1, 1]
|
||||
params[
|
||||
"max_scale_spec"
|
||||
] = MAX_SCALE_SPEC # scales the spectrogram so that it is max 1
|
||||
params["spec_scale"] = SPEC_SCALE # 'log', 'pcen', 'none'
|
||||
|
||||
# detection params
|
||||
params[
|
||||
"detection_overlap"
|
||||
] = 0.01 # has to be within this number of ms to count as detection
|
||||
params[
|
||||
"ignore_start_end"
|
||||
] = 0.01 # if start of GT calls are within this time from the start/end of file ignore
|
||||
params[
|
||||
"detection_threshold"
|
||||
] = DETECTION_THRESHOLD # the smaller this is the better the recall will be
|
||||
params[
|
||||
"nms_kernel_size"
|
||||
] = NMS_KERNEL_SIZE # size of the kernel for non-max suppression
|
||||
params[
|
||||
"nms_top_k_per_sec"
|
||||
] = NMS_TOP_K_PER_SEC # keep top K highest predictions per second of audio
|
||||
params["target_sigma"] = 2.0
|
||||
|
||||
# augmentation params
|
||||
params[
|
||||
"aug_prob"
|
||||
] = 0.20 # augmentations will be performed with this probability
|
||||
params["augment_at_train"] = True
|
||||
params["augment_at_train_combine"] = True
|
||||
params[
|
||||
"echo_max_delay"
|
||||
] = 0.005 # simulate echo by adding copy of raw audio
|
||||
params["stretch_squeeze_delta"] = 0.04 # stretch or squeeze spec
|
||||
params[
|
||||
"mask_max_time_perc"
|
||||
] = 0.05 # max mask size - here percentage, not ideal
|
||||
params[
|
||||
"mask_max_freq_perc"
|
||||
] = 0.10 # max mask size - here percentage, not ideal
|
||||
params[
|
||||
"spec_amp_scaling"
|
||||
] = 2.0 # multiply the "volume" by 0:X times current amount
|
||||
params["aug_sampling_rates"] = [
|
||||
220500,
|
||||
256000,
|
||||
300000,
|
||||
312500,
|
||||
384000,
|
||||
441000,
|
||||
500000,
|
||||
]
|
||||
|
||||
# loss params
|
||||
params["train_loss"] = "focal" # mse or focal
|
||||
params["det_loss_weight"] = 1.0 # weight for the detection part of the loss
|
||||
params["size_loss_weight"] = 0.1 # weight for the bbox size loss
|
||||
params["class_loss_weight"] = 2.0 # weight for the classification loss
|
||||
params["individual_loss_weight"] = 0.0 # not used
|
||||
if params["individual_loss_weight"] == 0.0:
|
||||
params[
|
||||
"emb_dim"
|
||||
] = 0 # number of dimensions used for individual id embedding
|
||||
else:
|
||||
params["emb_dim"] = 3
|
||||
|
||||
# train params
|
||||
params["lr"] = 0.001
|
||||
params["batch_size"] = 8
|
||||
params["num_workers"] = 4
|
||||
params["num_epochs"] = 200
|
||||
params["num_eval_epochs"] = 5 # run evaluation every X epochs
|
||||
params["device"] = "cuda"
|
||||
params["save_test_image_during_train"] = False
|
||||
params["save_test_image_after_train"] = True
|
||||
|
||||
params["convert_to_genus"] = False
|
||||
params["genus_mapping"] = []
|
||||
params["class_names"] = []
|
||||
params["classes_to_ignore"] = ["", " ", "Unknown", "Not Bat"]
|
||||
params["generic_class"] = ["Bat"]
|
||||
params["events_of_interest"] = [
|
||||
"Echolocation"
|
||||
] # will ignore all other types of events e.g. social calls
|
||||
|
||||
# the classes in this list are standardized during training so that the same low and high freq are used
|
||||
params["standardize_classs_names"] = []
|
||||
|
||||
# create directories
|
||||
if make_dirs:
|
||||
print("Model name : " + params["model_name"])
|
||||
print("Model file : " + params["model_file_name"])
|
||||
print("Experiment : " + params["experiment"])
|
||||
|
||||
mk_dir(params["experiment"])
|
||||
if params["save_test_image_during_train"]:
|
||||
mk_dir(params["op_im_dir"])
|
||||
if params["save_test_image_after_train"]:
|
||||
mk_dir(params["op_im_dir_test"])
|
||||
mk_dir(os.path.dirname(params["model_file_name"]))
|
||||
|
||||
return params
|
||||
@ -1,4 +1,5 @@
|
||||
"""Post-processing of the output of the model."""
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -43,7 +44,7 @@ def run_nms(
|
||||
outputs: ModelOutput,
|
||||
params: NonMaximumSuppressionConfig,
|
||||
sampling_rate: np.ndarray,
|
||||
) -> tuple[list[PredictionResults], list[np.ndarray]]:
|
||||
) -> Tuple[List[PredictionResults], List[np.ndarray]]:
|
||||
"""Run non-maximum suppression on the output of the model.
|
||||
|
||||
Model outputs processed are expected to have a batch dimension.
|
||||
@ -71,8 +72,8 @@ def run_nms(
|
||||
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
|
||||
|
||||
# loop over batch to save outputs
|
||||
preds: list[PredictionResults] = []
|
||||
feats: list[np.ndarray] = []
|
||||
preds: List[PredictionResults] = []
|
||||
feats: List[np.ndarray] = []
|
||||
for num_detection in range(pred_det_nms.shape[0]):
|
||||
# get valid indices
|
||||
inds_ord = torch.argsort(x_pos[num_detection, :])
|
||||
@ -149,7 +150,7 @@ def run_nms(
|
||||
|
||||
def non_max_suppression(
|
||||
heat: torch.Tensor,
|
||||
kernel_size: int | tuple[int, int],
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
):
|
||||
# kernel can be an int or list/tuple
|
||||
if isinstance(kernel_size, int):
|
||||
@ -7,19 +7,20 @@ import copy
|
||||
import json
|
||||
import os
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
|
||||
import batdetect2.evaluate.legacy.evaluate_models as evl
|
||||
import batdetect2.train.legacy.train_utils as tu
|
||||
from batdetect2.detector import parameters
|
||||
import batdetect2.train.evaluate as evl
|
||||
import batdetect2.train.train_utils as tu
|
||||
import batdetect2.utils.detector_utils as du
|
||||
import batdetect2.utils.plot_utils as pu
|
||||
from batdetect2.detector import parameters
|
||||
|
||||
|
||||
def get_blank_annotation(ip_str):
|
||||
|
||||
res = {}
|
||||
res["class_name"] = ""
|
||||
res["duration"] = -1
|
||||
@ -76,6 +77,7 @@ def create_genus_mapping(gt_test, preds, class_names):
|
||||
|
||||
|
||||
def load_tadarida_pred(ip_dir, dataset, file_of_interest):
|
||||
|
||||
res, ann = get_blank_annotation("Generated by Tadarida")
|
||||
|
||||
# create the annotations in the correct format
|
||||
@ -118,6 +120,7 @@ def load_sonobat_meta(
|
||||
class_names,
|
||||
only_accepted_species=True,
|
||||
):
|
||||
|
||||
sp_dict = {}
|
||||
for ss in class_names:
|
||||
sp_key = ss.split(" ")[0][:3] + ss.split(" ")[1][:3]
|
||||
@ -179,6 +182,7 @@ def load_sonobat_meta(
|
||||
|
||||
|
||||
def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
|
||||
|
||||
# create the annotations in the correct format
|
||||
res, ann = get_blank_annotation("Generated by Sonobat")
|
||||
res_c = copy.deepcopy(res)
|
||||
@ -217,6 +221,7 @@ def load_sonobat_preds(dataset, id, sb_meta, set_class_name=None):
|
||||
|
||||
|
||||
def bb_overlap(bb_g_in, bb_p_in):
|
||||
|
||||
freq_scale = 10000000.0 # ensure that both axis are roughly the same range
|
||||
bb_g = [
|
||||
bb_g_in["start_time"],
|
||||
@ -325,8 +330,7 @@ def load_gt_data(datasets, events_of_interest, class_names, classes_to_ignore):
|
||||
for dd in datasets:
|
||||
print("\n" + dd["dataset_name"])
|
||||
gt_dataset = tu.load_set_of_anns(
|
||||
[dd],
|
||||
events_of_interest=events_of_interest,
|
||||
[dd], events_of_interest=events_of_interest, verbose=True
|
||||
)
|
||||
gt_dataset = [
|
||||
parse_data(gg, class_names, classes_to_ignore, False)
|
||||
@ -357,7 +361,7 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001):
|
||||
clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
|
||||
clf.fit(x_train, y_train)
|
||||
y_pred = clf.predict(x_train)
|
||||
(y_pred == y_train).mean()
|
||||
tr_acc = (y_pred == y_train).mean()
|
||||
# print('Train acc', round(tr_acc*100, 2))
|
||||
return clf, un_train_class
|
||||
|
||||
@ -450,7 +454,7 @@ def add_root_path_back(data_sets, ann_path, wav_path):
|
||||
|
||||
|
||||
def check_classes_in_train(gt_list, class_names):
|
||||
np.sum([gg["start_times"].shape[0] for gg in gt_list])
|
||||
num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list])
|
||||
num_with_no_class = 0
|
||||
for gt in gt_list:
|
||||
for cc in gt["class_names"]:
|
||||
@ -460,6 +464,7 @@ def check_classes_in_train(gt_list, class_names):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"op_dir",
|
||||
@ -548,9 +553,7 @@ if __name__ == "__main__":
|
||||
test_dict["dataset_name"] = args["test_file"].replace(".json", "")
|
||||
test_dict["is_test"] = True
|
||||
test_dict["is_binary"] = True
|
||||
test_dict["ann_path"] = os.path.join(
|
||||
args["ann_dir"], args["test_file"]
|
||||
)
|
||||
test_dict["ann_path"] = os.path.join(args["ann_dir"], args["test_file"])
|
||||
test_dict["wav_path"] = args["data_dir"]
|
||||
test_sets = [test_dict]
|
||||
|
||||
@ -569,7 +572,7 @@ if __name__ == "__main__":
|
||||
num_with_no_class = check_classes_in_train(gt_test, class_names)
|
||||
if total_num_calls == num_with_no_class:
|
||||
print("Classes from the test set are not in the train set.")
|
||||
raise AssertionError()
|
||||
assert False
|
||||
|
||||
# only need the train data if evaluating Sonobat or Tadarida
|
||||
if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
|
||||
@ -743,7 +746,7 @@ if __name__ == "__main__":
|
||||
# check if the class names are the same
|
||||
if params_bd["class_names"] != class_names:
|
||||
print("Warning: Class names are not the same as the trained model")
|
||||
raise AssertionError()
|
||||
assert False
|
||||
|
||||
run_config = {
|
||||
**bd_args,
|
||||
@ -753,7 +756,7 @@ if __name__ == "__main__":
|
||||
|
||||
preds_bd = []
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
for gg in gt_test:
|
||||
for ii, gg in enumerate(gt_test):
|
||||
pred = du.process_file(
|
||||
gg["file_path"],
|
||||
model,
|
||||
@ -1,31 +1,33 @@
|
||||
import argparse
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import List
|
||||
import sys
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import torch.nn.functional as F
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
import batdetect2.detector.models as models
|
||||
import batdetect2.detector.parameters as parameters
|
||||
import batdetect2.train.legacy.audio_dataloader as adl
|
||||
import batdetect2.train.legacy.train_model as tm
|
||||
import batdetect2.train.legacy.train_utils as tu
|
||||
import batdetect2.detector.post_process as pp
|
||||
import batdetect2.train.audio_dataloader as adl
|
||||
import batdetect2.train.evaluate as evl
|
||||
import batdetect2.train.losses as losses
|
||||
import batdetect2.train.train_model as tm
|
||||
import batdetect2.train.train_utils as tu
|
||||
import batdetect2.utils.detector_utils as du
|
||||
import batdetect2.utils.plot_utils as pu
|
||||
from batdetect2 import types
|
||||
from batdetect2.detector.models import Net2DFast
|
||||
|
||||
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
if __name__ == "__main__":
|
||||
info_str = "\nBatDetect - Finetune Model\n"
|
||||
|
||||
|
||||
def parse_arugments():
|
||||
print(info_str)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"audio_path",
|
||||
type=str,
|
||||
help="Input directory for audio",
|
||||
"audio_path", type=str, help="Input directory for audio"
|
||||
)
|
||||
parser.add_argument(
|
||||
"train_ann_path",
|
||||
@ -37,15 +39,7 @@ def parse_arugments():
|
||||
type=str,
|
||||
help="Path to where test annotation file is stored",
|
||||
)
|
||||
parser.add_argument(
|
||||
"model_path", type=str, help="Path to pretrained model"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--experiment_dir",
|
||||
type=str,
|
||||
default=os.path.join(BASE_DIR, "experiments"),
|
||||
help="Path to where experiment files are stored",
|
||||
)
|
||||
parser.add_argument("model_path", type=str, help="Path to pretrained model")
|
||||
parser.add_argument(
|
||||
"--op_model_name",
|
||||
type=str,
|
||||
@ -77,64 +71,107 @@ def parse_arugments():
|
||||
parser.add_argument(
|
||||
"--notes", type=str, default="", help="Notes to save in text file"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
|
||||
def select_device(warn=True) -> str:
|
||||
params = parameters.get_params(True, "../../experiments/")
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
|
||||
if warn:
|
||||
warnings.warn(
|
||||
"No GPU available, using the CPU instead. Please consider using a GPU "
|
||||
"to speed up training.",
|
||||
stacklevel=2,
|
||||
params["device"] = "cuda"
|
||||
else:
|
||||
params["device"] = "cpu"
|
||||
print(
|
||||
"\nNote, this will be a lot faster if you use computer with a GPU.\n"
|
||||
)
|
||||
|
||||
return "cpu"
|
||||
print("\nAudio directory: " + args["audio_path"])
|
||||
print("Train file: " + args["train_ann_path"])
|
||||
print("Test file: " + args["test_ann_path"])
|
||||
print("Loading model: " + args["model_path"])
|
||||
|
||||
dataset_name = (
|
||||
os.path.basename(args["train_ann_path"])
|
||||
.replace(".json", "")
|
||||
.replace("_TRAIN", "")
|
||||
)
|
||||
|
||||
def load_annotations(
|
||||
dataset_name: str,
|
||||
ann_path: str,
|
||||
audio_path: str,
|
||||
classes_to_ignore: List[str] | None = None,
|
||||
events_of_interest: List[str] | None = None,
|
||||
) -> List[types.FileAnnotation]:
|
||||
train_sets: List[types.DatasetDict] = []
|
||||
if args["train_from_scratch"]:
|
||||
print("\nTraining model from scratch i.e. not using pretrained weights")
|
||||
model, params_train = du.load_model(args["model_path"], False)
|
||||
else:
|
||||
model, params_train = du.load_model(args["model_path"], True)
|
||||
model.to(params["device"])
|
||||
|
||||
params["num_epochs"] = args["num_epochs"]
|
||||
if args["op_model_name"] != "":
|
||||
params["model_file_name"] = args["op_model_name"]
|
||||
classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
|
||||
|
||||
# save notes file
|
||||
params["notes"] = args["notes"]
|
||||
if args["notes"] != "":
|
||||
tu.write_notes_file(params["experiment"] + "notes.txt", args["notes"])
|
||||
|
||||
# load train annotations
|
||||
train_sets = []
|
||||
train_sets.append(
|
||||
tu.get_blank_dataset_dict(
|
||||
dataset_name, False, args["train_ann_path"], args["audio_path"]
|
||||
)
|
||||
)
|
||||
params["train_sets"] = [
|
||||
tu.get_blank_dataset_dict(
|
||||
dataset_name,
|
||||
is_test=False,
|
||||
ann_path=ann_path,
|
||||
wav_path=audio_path,
|
||||
False,
|
||||
os.path.basename(args["train_ann_path"]),
|
||||
args["audio_path"],
|
||||
)
|
||||
]
|
||||
|
||||
print("\nTrain set:")
|
||||
(
|
||||
data_train,
|
||||
params["class_names"],
|
||||
params["class_inv_freq"],
|
||||
) = tu.load_set_of_anns(
|
||||
train_sets, classes_to_ignore, params["events_of_interest"]
|
||||
)
|
||||
print("Number of files", len(data_train))
|
||||
|
||||
params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
|
||||
params["class_names"]
|
||||
)
|
||||
params["class_names_short"] = tu.get_short_class_names(
|
||||
params["class_names"]
|
||||
)
|
||||
|
||||
return tu.load_set_of_anns(
|
||||
train_sets,
|
||||
events_of_interest=events_of_interest,
|
||||
classes_to_ignore=classes_to_ignore,
|
||||
# load test annotations
|
||||
test_sets = []
|
||||
test_sets.append(
|
||||
tu.get_blank_dataset_dict(
|
||||
dataset_name, True, args["test_ann_path"], args["audio_path"]
|
||||
)
|
||||
)
|
||||
params["test_sets"] = [
|
||||
tu.get_blank_dataset_dict(
|
||||
dataset_name,
|
||||
True,
|
||||
os.path.basename(args["test_ann_path"]),
|
||||
args["audio_path"],
|
||||
)
|
||||
]
|
||||
|
||||
print("\nTest set:")
|
||||
data_test, _, _ = tu.load_set_of_anns(
|
||||
test_sets, classes_to_ignore, params["events_of_interest"]
|
||||
)
|
||||
print("Number of files", len(data_test))
|
||||
|
||||
def finetune_model(
|
||||
model: types.DetectionModel,
|
||||
data_train: List[types.FileAnnotation],
|
||||
data_test: List[types.FileAnnotation],
|
||||
params: parameters.TrainingParameters,
|
||||
model_params: types.ModelParameters,
|
||||
finetune_only_last_layer: bool = False,
|
||||
save_images: bool = True,
|
||||
):
|
||||
# train loader
|
||||
train_dataset = adl.AudioLoader(data_train, params, is_train=True)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=params.batch_size,
|
||||
batch_size=params["batch_size"],
|
||||
shuffle=True,
|
||||
num_workers=params.num_workers,
|
||||
num_workers=params["num_workers"],
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
@ -144,36 +181,32 @@ def finetune_model(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=params.num_workers,
|
||||
num_workers=params["num_workers"],
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
inputs_train = next(iter(train_loader))
|
||||
params.ip_height = inputs_train["spec"].shape[2]
|
||||
params["ip_height"] = inputs_train["spec"].shape[2]
|
||||
print("\ntrain batch size :", inputs_train["spec"].shape)
|
||||
|
||||
# Check that the model is the same as the one used to train the pretrained
|
||||
# weights
|
||||
assert model_params["model_name"] == "Net2DFast"
|
||||
assert isinstance(model, Net2DFast)
|
||||
assert params_train["model_name"] == "Net2DFast"
|
||||
print(
|
||||
"\n\nSOME hyperparams need to be the same as the loaded model "
|
||||
"(e.g. FFT) - currently they are getting overwritten.\n\n"
|
||||
"\n\nSOME hyperparams need to be the same as the loaded model (e.g. FFT) - currently they are getting overwritten.\n\n"
|
||||
)
|
||||
|
||||
# set the number of output classes
|
||||
num_filts = model.conv_classes_op.in_channels
|
||||
(k_size,) = model.conv_classes_op.kernel_size
|
||||
(pad,) = model.conv_classes_op.padding
|
||||
k_size = model.conv_classes_op.kernel_size
|
||||
pad = model.conv_classes_op.padding
|
||||
model.conv_classes_op = torch.nn.Conv2d(
|
||||
num_filts,
|
||||
len(params.class_names) + 1,
|
||||
len(params["class_names"]) + 1,
|
||||
kernel_size=k_size,
|
||||
padding=pad,
|
||||
)
|
||||
model.conv_classes_op.to(params.device)
|
||||
model.conv_classes_op.to(params["device"])
|
||||
|
||||
if finetune_only_last_layer:
|
||||
if args["finetune_only_last_layer"]:
|
||||
print("\nOnly finetuning the final layers.\n")
|
||||
train_layers_i = [
|
||||
"conv_classes",
|
||||
@ -190,26 +223,19 @@ def finetune_model(
|
||||
else:
|
||||
param.requires_grad = False
|
||||
|
||||
optimizer = torch.optim.Adam(
|
||||
model.parameters(),
|
||||
lr=params.lr,
|
||||
)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
|
||||
scheduler = CosineAnnealingLR(
|
||||
optimizer,
|
||||
params.num_epochs * len(train_loader),
|
||||
optimizer, params["num_epochs"] * len(train_loader)
|
||||
)
|
||||
|
||||
if params.train_loss == "mse":
|
||||
if params["train_loss"] == "mse":
|
||||
det_criterion = losses.mse_loss
|
||||
elif params.train_loss == "focal":
|
||||
elif params["train_loss"] == "focal":
|
||||
det_criterion = losses.focal_loss
|
||||
else:
|
||||
raise ValueError("Unknown loss function")
|
||||
|
||||
# plotting
|
||||
train_plt_ls = pu.LossPlotter(
|
||||
params.experiment / "train_loss.png",
|
||||
params.num_epochs + 1,
|
||||
params["experiment"] + "train_loss.png",
|
||||
params["num_epochs"] + 1,
|
||||
["train_loss"],
|
||||
None,
|
||||
None,
|
||||
@ -217,8 +243,8 @@ def finetune_model(
|
||||
logy=True,
|
||||
)
|
||||
test_plt_ls = pu.LossPlotter(
|
||||
params.experiment / "test_loss.png",
|
||||
params.num_epochs + 1,
|
||||
params["experiment"] + "test_loss.png",
|
||||
params["num_epochs"] + 1,
|
||||
["test_loss"],
|
||||
None,
|
||||
None,
|
||||
@ -226,24 +252,24 @@ def finetune_model(
|
||||
logy=True,
|
||||
)
|
||||
test_plt = pu.LossPlotter(
|
||||
params.experiment / "test.png",
|
||||
params.num_epochs + 1,
|
||||
params["experiment"] + "test.png",
|
||||
params["num_epochs"] + 1,
|
||||
["avg_prec", "rec_at_x", "avg_prec_class", "file_acc", "top_class"],
|
||||
[0, 1],
|
||||
None,
|
||||
["epoch", ""],
|
||||
)
|
||||
test_plt_class = pu.LossPlotter(
|
||||
params.experiment / "test_avg_prec.png",
|
||||
params.num_epochs + 1,
|
||||
params.class_names_short,
|
||||
params["experiment"] + "test_avg_prec.png",
|
||||
params["num_epochs"] + 1,
|
||||
params["class_names_short"],
|
||||
[0, 1],
|
||||
params.class_names_short,
|
||||
params["class_names_short"],
|
||||
["epoch", "avg_prec"],
|
||||
)
|
||||
|
||||
# main train loop
|
||||
for epoch in range(0, params.num_epochs + 1):
|
||||
for epoch in range(0, params["num_epochs"] + 1):
|
||||
train_loss = tm.train(
|
||||
model,
|
||||
epoch,
|
||||
@ -255,14 +281,10 @@ def finetune_model(
|
||||
)
|
||||
train_plt_ls.update_and_save(epoch, [train_loss["train_loss"]])
|
||||
|
||||
if epoch % params.num_eval_epochs == 0:
|
||||
if epoch % params["num_eval_epochs"] == 0:
|
||||
# detection accuracy on test set
|
||||
test_res, test_loss = tm.test(
|
||||
model,
|
||||
epoch,
|
||||
test_loader,
|
||||
det_criterion,
|
||||
params,
|
||||
model, epoch, test_loader, det_criterion, params
|
||||
)
|
||||
test_plt_ls.update_and_save(epoch, [test_loss["test_loss"]])
|
||||
test_plt.update_and_save(
|
||||
@ -279,106 +301,18 @@ def finetune_model(
|
||||
epoch, [rs["avg_prec"] for rs in test_res["class_pr"]]
|
||||
)
|
||||
pu.plot_pr_curve_class(
|
||||
params.experiment, "test_pr", "test_pr", test_res
|
||||
params["experiment"], "test_pr", "test_pr", test_res
|
||||
)
|
||||
|
||||
# save finetuned model
|
||||
print(f"saving model to: {params.model_file_name}")
|
||||
print("saving model to: " + params["model_file_name"])
|
||||
op_state = {
|
||||
"epoch": epoch + 1,
|
||||
"state_dict": model.state_dict(),
|
||||
"params": params,
|
||||
}
|
||||
torch.save(op_state, params.model_file_name)
|
||||
torch.save(op_state, params["model_file_name"])
|
||||
|
||||
# save an image with associated prediction for each batch in the test set
|
||||
if save_images:
|
||||
if not args["do_not_save_images"]:
|
||||
tm.save_images_batch(model, test_loader, params)
|
||||
|
||||
|
||||
def main():
|
||||
info_str = "\nBatDetect - Finetune Model\n"
|
||||
print(info_str)
|
||||
|
||||
args = parse_arugments()
|
||||
|
||||
# Load experiment parameters
|
||||
params = parameters.get_params(
|
||||
make_dirs=True,
|
||||
exps_dir=args.experiment_dir,
|
||||
device=select_device(),
|
||||
num_epochs=args.num_epochs,
|
||||
notes=args.notes,
|
||||
)
|
||||
|
||||
print("\nAudio directory: " + args.audio_path)
|
||||
print("Train file: " + args.train_ann_path)
|
||||
print("Test file: " + args.test_ann_path)
|
||||
print("Loading model: " + args.model_path)
|
||||
|
||||
if args.train_from_scratch:
|
||||
print(
|
||||
"\nTraining model from scratch i.e. not using pretrained weights"
|
||||
)
|
||||
|
||||
model, model_params = du.load_model(
|
||||
args.model_path,
|
||||
load_weights=not args.train_from_scratch,
|
||||
device=params.device,
|
||||
)
|
||||
|
||||
if args.op_model_name != "":
|
||||
params.model_file_name = args.op_model_name
|
||||
|
||||
classes_to_ignore = params.classes_to_ignore + params.generic_class
|
||||
|
||||
# save notes file
|
||||
if params.notes:
|
||||
tu.write_notes_file(
|
||||
params.experiment / "notes.txt",
|
||||
args.notes,
|
||||
)
|
||||
|
||||
# NOTE:??
|
||||
dataset_name = (
|
||||
os.path.basename(args.train_ann_path)
|
||||
.replace(".json", "")
|
||||
.replace("_TRAIN", "")
|
||||
)
|
||||
|
||||
# ==== LOAD DATA ====
|
||||
|
||||
# load train annotations
|
||||
data_train = load_annotations(
|
||||
dataset_name,
|
||||
args.train_ann_path,
|
||||
args.audio_path,
|
||||
params.events_of_interest,
|
||||
)
|
||||
print("\nTrain set:")
|
||||
print("Number of files", len(data_train))
|
||||
|
||||
# load test annotations
|
||||
data_test = load_annotations(
|
||||
dataset_name,
|
||||
args.test_ann_path,
|
||||
args.audio_path,
|
||||
classes_to_ignore,
|
||||
params.events_of_interest,
|
||||
)
|
||||
print("\nTrain set:")
|
||||
print("Number of files", len(data_train))
|
||||
|
||||
finetune_model(
|
||||
model,
|
||||
data_train,
|
||||
data_test,
|
||||
params,
|
||||
model_params,
|
||||
finetune_only_last_layer=args.finetune_only_last_layer,
|
||||
save_images=args.do_not_save_images,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
201
batdetect2/finetune/prep_data_finetune.py
Normal file
201
batdetect2/finetune/prep_data_finetune.py
Normal file
@ -0,0 +1,201 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
import batdetect2.train.train_utils as tu
|
||||
|
||||
|
||||
def print_dataset_stats(data, split_name, classes_to_ignore):
|
||||
print("\nSplit:", split_name)
|
||||
print("Num files:", len(data))
|
||||
|
||||
class_cnts = {}
|
||||
for dd in data:
|
||||
for aa in dd["annotation"]:
|
||||
if aa["class"] not in classes_to_ignore:
|
||||
if aa["class"] in class_cnts:
|
||||
class_cnts[aa["class"]] += 1
|
||||
else:
|
||||
class_cnts[aa["class"]] = 1
|
||||
|
||||
if len(class_cnts) == 0:
|
||||
class_names = []
|
||||
else:
|
||||
class_names = np.sort([*class_cnts]).tolist()
|
||||
print("Class count:")
|
||||
str_len = np.max([len(cc) for cc in class_names]) + 5
|
||||
|
||||
for ii, cc in enumerate(class_names):
|
||||
print(str(ii).ljust(5) + cc.ljust(str_len) + str(class_cnts[cc]))
|
||||
|
||||
return class_names
|
||||
|
||||
|
||||
def load_file_names(file_name):
|
||||
if os.path.isfile(file_name):
|
||||
with open(file_name) as da:
|
||||
files = [line.rstrip() for line in da.readlines()]
|
||||
for ff in files:
|
||||
if ff.lower()[-3:] != "wav":
|
||||
print("Error: Filenames need to end in .wav - ", ff)
|
||||
assert False
|
||||
else:
|
||||
print("Error: Input file not found - ", file_name)
|
||||
assert False
|
||||
|
||||
return files
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
info_str = "\nBatDetect - Prepare Data for Finetuning\n"
|
||||
|
||||
print(info_str)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"dataset_name", type=str, help="Name to call your dataset"
|
||||
)
|
||||
parser.add_argument("audio_dir", type=str, help="Input directory for audio")
|
||||
parser.add_argument(
|
||||
"ann_dir",
|
||||
type=str,
|
||||
help="Input directory for where the audio annotations are stored",
|
||||
)
|
||||
parser.add_argument(
|
||||
"op_dir",
|
||||
type=str,
|
||||
help="Path where the train and test splits will be stored",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--percent_val",
|
||||
type=float,
|
||||
default=0.20,
|
||||
help="Hold out this much data for validation. Should be number between 0 and 1",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rand_seed",
|
||||
type=int,
|
||||
default=2001,
|
||||
help="Random seed used for creating the validation split",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_file",
|
||||
type=str,
|
||||
default="",
|
||||
help="Text file where each line is a wav file in train split",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_file",
|
||||
type=str,
|
||||
default="",
|
||||
help="Text file where each line is a wav file in test split",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_class_names",
|
||||
type=str,
|
||||
default="",
|
||||
help='Specify names of classes that you want to change. Separate with ";"',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_class_names",
|
||||
type=str,
|
||||
default="",
|
||||
help='New class names to use instead. One to one mapping with "--input_class_names". \
|
||||
Separate with ";"',
|
||||
)
|
||||
args = vars(parser.parse_args())
|
||||
|
||||
np.random.seed(args["rand_seed"])
|
||||
|
||||
classes_to_ignore = ["", " ", "Unknown", "Not Bat"]
|
||||
generic_class = ["Bat"]
|
||||
events_of_interest = ["Echolocation"]
|
||||
|
||||
if args["input_class_names"] != "" and args["output_class_names"] != "":
|
||||
# change the names of the classes
|
||||
ip_names = args["input_class_names"].split(";")
|
||||
op_names = args["output_class_names"].split(";")
|
||||
name_dict = dict(zip(ip_names, op_names))
|
||||
else:
|
||||
name_dict = False
|
||||
|
||||
# load annotations
|
||||
data_all, _, _ = tu.load_set_of_anns(
|
||||
{"ann_path": args["ann_dir"], "wav_path": args["audio_dir"]},
|
||||
classes_to_ignore,
|
||||
events_of_interest,
|
||||
False,
|
||||
False,
|
||||
list_of_anns=True,
|
||||
filter_issues=True,
|
||||
name_replace=name_dict,
|
||||
)
|
||||
|
||||
print("Dataset name: " + args["dataset_name"])
|
||||
print("Audio directory: " + args["audio_dir"])
|
||||
print("Annotation directory: " + args["ann_dir"])
|
||||
print("Ouput directory: " + args["op_dir"])
|
||||
print("Num annotated files: " + str(len(data_all)))
|
||||
|
||||
if args["train_file"] != "" and args["test_file"] != "":
|
||||
# user has specifed the train / test split
|
||||
train_files = load_file_names(args["train_file"])
|
||||
test_files = load_file_names(args["test_file"])
|
||||
file_names_all = [dd["id"] for dd in data_all]
|
||||
train_inds = [
|
||||
file_names_all.index(ff)
|
||||
for ff in train_files
|
||||
if ff in file_names_all
|
||||
]
|
||||
test_inds = [
|
||||
file_names_all.index(ff)
|
||||
for ff in test_files
|
||||
if ff in file_names_all
|
||||
]
|
||||
|
||||
else:
|
||||
# split the data into train and test at the file level
|
||||
num_exs = len(data_all)
|
||||
test_inds = np.random.choice(
|
||||
np.arange(num_exs),
|
||||
int(num_exs * args["percent_val"]),
|
||||
replace=False,
|
||||
)
|
||||
test_inds = np.sort(test_inds)
|
||||
train_inds = np.setdiff1d(np.arange(num_exs), test_inds)
|
||||
|
||||
data_train = [data_all[ii] for ii in train_inds]
|
||||
data_test = [data_all[ii] for ii in test_inds]
|
||||
|
||||
if not os.path.isdir(args["op_dir"]):
|
||||
os.makedirs(args["op_dir"])
|
||||
op_name = os.path.join(args["op_dir"], args["dataset_name"])
|
||||
op_name_train = op_name + "_TRAIN.json"
|
||||
op_name_test = op_name + "_TEST.json"
|
||||
|
||||
class_un_train = print_dataset_stats(data_train, "Train", classes_to_ignore)
|
||||
class_un_test = print_dataset_stats(data_test, "Test", classes_to_ignore)
|
||||
|
||||
if len(data_train) > 0 and len(data_test) > 0:
|
||||
if class_un_train != class_un_test:
|
||||
print(
|
||||
'\nError: some classes are not in both the training and test sets.\
|
||||
\nTry a different random seed "--rand_seed".'
|
||||
)
|
||||
assert False
|
||||
|
||||
print("\n")
|
||||
if len(data_train) == 0:
|
||||
print("No train annotations to save")
|
||||
else:
|
||||
print("Saving: ", op_name_train)
|
||||
with open(op_name_train, "w") as da:
|
||||
json.dump(data_train, da, indent=2)
|
||||
|
||||
if len(data_test) == 0:
|
||||
print("No test annotations to save")
|
||||
else:
|
||||
print("Saving: ", op_name_test)
|
||||
with open(op_name_test, "w") as da:
|
||||
json.dump(data_test, da, indent=2)
|
||||
@ -1,11 +1,11 @@
|
||||
"""Plot functions to visualize detections and spectrograms."""
|
||||
|
||||
from typing import cast
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
import matplotlib.ticker as tick
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import axes, patches
|
||||
import matplotlib.ticker as tick
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS
|
||||
@ -24,10 +24,10 @@ __all__ = [
|
||||
|
||||
|
||||
def spectrogram(
|
||||
spec: torch.Tensor | np.ndarray,
|
||||
config: ProcessingConfiguration | None = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
spec: Union[torch.Tensor, np.ndarray],
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
cmap: str = "plasma",
|
||||
start_time: float = 0,
|
||||
) -> axes.Axes:
|
||||
@ -35,18 +35,18 @@ def spectrogram(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec: Spectrogram to plot.
|
||||
config: Configuration
|
||||
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
|
||||
config (Optional[ProcessingConfiguration], optional): Configuration
|
||||
used to compute the spectrogram. Defaults to None. If None,
|
||||
the default configuration will be used.
|
||||
ax: Matplotlib axes object.
|
||||
ax (Optional[axes.Axes], optional): Matplotlib axes object.
|
||||
Defaults to None. if provided, the spectrogram will be plotted
|
||||
on this axes.
|
||||
figsize: Figure size.
|
||||
figsize (Optional[Tuple[int, int]], optional): Figure size.
|
||||
Defaults to None. If `ax` is None, this will be used to create
|
||||
a new figure of the given size.
|
||||
cmap: Colormap to use. Defaults to "plasma".
|
||||
start_time: Start time of the spectrogram.
|
||||
cmap (str, optional): Colormap to use. Defaults to "plasma".
|
||||
start_time (float, optional): Start time of the spectrogram.
|
||||
Defaults to 0. This is useful if plotting a spectrogram
|
||||
of a segment of a longer audio file.
|
||||
|
||||
@ -103,11 +103,11 @@ def spectrogram(
|
||||
|
||||
|
||||
def spectrogram_with_detections(
|
||||
spec: torch.Tensor | np.ndarray,
|
||||
dets: list[Annotation],
|
||||
config: ProcessingConfiguration | None = None,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
spec: Union[torch.Tensor, np.ndarray],
|
||||
dets: List[Annotation],
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
cmap: str = "plasma",
|
||||
with_names: bool = True,
|
||||
start_time: float = 0,
|
||||
@ -117,21 +117,21 @@ def spectrogram_with_detections(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec: Spectrogram to plot.
|
||||
detections: List of detections.
|
||||
config: Configuration
|
||||
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
|
||||
detections (List[Annotation]): List of detections.
|
||||
config (Optional[ProcessingConfiguration], optional): Configuration
|
||||
used to compute the spectrogram. Defaults to None. If None,
|
||||
the default configuration will be used.
|
||||
ax: Matplotlib axes object.
|
||||
ax (Optional[axes.Axes], optional): Matplotlib axes object.
|
||||
Defaults to None. if provided, the spectrogram will be plotted
|
||||
on this axes.
|
||||
figsize: Figure size.
|
||||
figsize (Optional[Tuple[int, int]], optional): Figure size.
|
||||
Defaults to None. If `ax` is None, this will be used to create
|
||||
a new figure of the given size.
|
||||
cmap: Colormap to use. Defaults to "plasma".
|
||||
with_names: Whether to plot the name of the
|
||||
cmap (str, optional): Colormap to use. Defaults to "plasma".
|
||||
with_names (bool, optional): Whether to plot the name of the
|
||||
predicted class next to the detection. Defaults to True.
|
||||
start_time: Start time of the spectrogram.
|
||||
start_time (float, optional): Start time of the spectrogram.
|
||||
Defaults to 0. This is useful if plotting a spectrogram
|
||||
of a segment of a longer audio file.
|
||||
**kwargs: Additional keyword arguments to pass to the
|
||||
@ -167,9 +167,9 @@ def spectrogram_with_detections(
|
||||
|
||||
|
||||
def detections(
|
||||
dets: list[Annotation],
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
dets: List[Annotation],
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
with_names: bool = True,
|
||||
**kwargs,
|
||||
) -> axes.Axes:
|
||||
@ -177,14 +177,14 @@ def detections(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dets: List of detections.
|
||||
ax: Matplotlib axes object.
|
||||
dets (List[Annotation]): List of detections.
|
||||
ax (Optional[axes.Axes], optional): Matplotlib axes object.
|
||||
Defaults to None. if provided, the spectrogram will be plotted
|
||||
on this axes.
|
||||
figsize: Figure size.
|
||||
figsize (Optional[Tuple[int, int]], optional): Figure size.
|
||||
Defaults to None. If `ax` is None, this will be used to create
|
||||
a new figure of the given size.
|
||||
with_names: Whether to plot the name of the
|
||||
with_names (bool, optional): Whether to plot the name of the
|
||||
predicted class next to the detection. Defaults to True.
|
||||
**kwargs: Additional keyword arguments to pass to the
|
||||
`plot.detection` function.
|
||||
@ -213,8 +213,8 @@ def detections(
|
||||
|
||||
def detection(
|
||||
det: Annotation,
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
linewidth: float = 1,
|
||||
edgecolor: str = "w",
|
||||
facecolor: str = "none",
|
||||
@ -224,19 +224,19 @@ def detection(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
det: Detection to plot.
|
||||
ax: Matplotlib axes object. Defaults
|
||||
det (Annotation): Detection to plot.
|
||||
ax (Optional[axes.Axes], optional): Matplotlib axes object. Defaults
|
||||
to None. If provided, the spectrogram will be plotted on this axes.
|
||||
figsize: Figure size. Defaults
|
||||
figsize (Optional[Tuple[int, int]], optional): Figure size. Defaults
|
||||
to None. If `ax` is None, this will be used to create a new figure
|
||||
of the given size.
|
||||
linewidth: Line width of the detection.
|
||||
linewidth (float, optional): Line width of the detection.
|
||||
Defaults to 1.
|
||||
edgecolor: Edge color of the detection.
|
||||
edgecolor (str, optional): Edge color of the detection.
|
||||
Defaults to "w", i.e. white.
|
||||
facecolor: Face color of the detection.
|
||||
facecolor (str, optional): Face color of the detection.
|
||||
Defaults to "none", i.e. transparent.
|
||||
with_name: Whether to plot the name of the
|
||||
with_name (bool, optional): Whether to plot the name of the
|
||||
predicted class next to the detection. Defaults to True.
|
||||
|
||||
Returns
|
||||
@ -277,22 +277,22 @@ def detection(
|
||||
|
||||
|
||||
def _compute_spec_extent(
|
||||
shape: tuple[int, int],
|
||||
shape: Tuple[int, int],
|
||||
params: SpectrogramParameters,
|
||||
) -> tuple[float, float, float, float]:
|
||||
) -> Tuple[float, float, float, float]:
|
||||
"""Compute the extent of a spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shape: Shape of the spectrogram.
|
||||
shape (Tuple[int, int]): Shape of the spectrogram.
|
||||
The first dimension is the frequency axis and the second
|
||||
dimension is the time axis.
|
||||
params: Spectrogram parameters.
|
||||
params (SpectrogramParameters): Spectrogram parameters.
|
||||
Should be the same as the ones used to compute the spectrogram.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[float, float, float, float]: Extent of the spectrogram.
|
||||
Tuple[float, float, float, float]: Extent of the spectrogram.
|
||||
The first two values are the minimum and maximum time values,
|
||||
the last two values are the minimum and maximum frequency values.
|
||||
"""
|
||||
@ -306,9 +306,6 @@ def _compute_spec_extent(
|
||||
|
||||
# If the spectrogram is not resized, the duration is correct
|
||||
# but if it is resized, the duration needs to be adjusted
|
||||
# NOTE: For now we can only detect if the spectrogram is resized
|
||||
# by checking if the height is equal to the specified height,
|
||||
# but this could fail.
|
||||
resize_factor = params["resize_factor"]
|
||||
spec_height = params["spec_height"]
|
||||
if spec_height * resize_factor == shape[0]:
|
||||
603
batdetect2/train/audio_dataloader.py
Normal file
603
batdetect2/train/audio_dataloader.py
Normal file
@ -0,0 +1,603 @@
|
||||
import copy
|
||||
from typing import Tuple
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
|
||||
import batdetect2.utils.audio_utils as au
|
||||
from batdetect2.types import AnnotationGroup, HeatmapParameters
|
||||
|
||||
|
||||
def generate_gt_heatmaps(
|
||||
spec_op_shape: Tuple[int, int],
|
||||
sampling_rate: int,
|
||||
ann: AnnotationGroup,
|
||||
params: HeatmapParameters,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, AnnotationGroup]:
|
||||
"""Generate ground truth heatmaps from annotations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec_op_shape : Tuple[int, int]
|
||||
Shape of the input spectrogram.
|
||||
sampling_rate : int
|
||||
Sampling rate of the input audio in Hz.
|
||||
ann : AnnotationGroup
|
||||
Dictionary containing the annotation information.
|
||||
params : HeatmapParameters
|
||||
Parameters controlling the generation of the heatmaps.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
y_2d_det : np.ndarray
|
||||
2D heatmap of the presence of an event.
|
||||
|
||||
y_2d_size : np.ndarray
|
||||
2D heatmap of the size of the bounding box associated to event.
|
||||
|
||||
y_2d_classes : np.ndarray
|
||||
3D array containing the ground-truth class probabilities for each
|
||||
pixel.
|
||||
|
||||
ann_aug : AnnotationGroup
|
||||
A dictionary containing the annotation information of the
|
||||
annotations that are within the input spectrogram, augmented with
|
||||
the x and y indices of their pixel location in the input spectrogram.
|
||||
|
||||
"""
|
||||
# spec may be resized on input into the network
|
||||
num_classes = len(params["class_names"])
|
||||
op_height = spec_op_shape[0]
|
||||
op_width = spec_op_shape[1]
|
||||
freq_per_bin = (params["max_freq"] - params["min_freq"]) / op_height
|
||||
|
||||
# start and end times
|
||||
x_pos_start = au.time_to_x_coords(
|
||||
ann["start_times"],
|
||||
sampling_rate,
|
||||
params["fft_win_length"],
|
||||
params["fft_overlap"],
|
||||
)
|
||||
x_pos_start = (params["resize_factor"] * x_pos_start).astype(np.int)
|
||||
x_pos_end = au.time_to_x_coords(
|
||||
ann["end_times"],
|
||||
sampling_rate,
|
||||
params["fft_win_length"],
|
||||
params["fft_overlap"],
|
||||
)
|
||||
x_pos_end = (params["resize_factor"] * x_pos_end).astype(np.int)
|
||||
|
||||
# location on y axis i.e. frequency
|
||||
y_pos_low = (ann["low_freqs"] - params["min_freq"]) / freq_per_bin
|
||||
y_pos_low = (op_height - y_pos_low).astype(np.int)
|
||||
y_pos_high = (ann["high_freqs"] - params["min_freq"]) / freq_per_bin
|
||||
y_pos_high = (op_height - y_pos_high).astype(np.int)
|
||||
bb_widths = x_pos_end - x_pos_start
|
||||
bb_heights = y_pos_low - y_pos_high
|
||||
|
||||
# Only include annotations that are within the input spectrogram
|
||||
valid_inds = np.where(
|
||||
(x_pos_start >= 0)
|
||||
& (x_pos_start < op_width)
|
||||
& (y_pos_low >= 0)
|
||||
& (y_pos_low < (op_height - 1))
|
||||
)[0]
|
||||
|
||||
ann_aug: AnnotationGroup = {
|
||||
"start_times": ann["start_times"][valid_inds],
|
||||
"end_times": ann["end_times"][valid_inds],
|
||||
"high_freqs": ann["high_freqs"][valid_inds],
|
||||
"low_freqs": ann["low_freqs"][valid_inds],
|
||||
"class_ids": ann["class_ids"][valid_inds],
|
||||
"individual_ids": ann["individual_ids"][valid_inds],
|
||||
}
|
||||
ann_aug["x_inds"] = x_pos_start[valid_inds]
|
||||
ann_aug["y_inds"] = y_pos_low[valid_inds]
|
||||
# keys = [
|
||||
# "start_times",
|
||||
# "end_times",
|
||||
# "high_freqs",
|
||||
# "low_freqs",
|
||||
# "class_ids",
|
||||
# "individual_ids",
|
||||
# ]
|
||||
# for kk in keys:
|
||||
# ann_aug[kk] = ann[kk][valid_inds]
|
||||
|
||||
# if the number of calls is only 1, then it is unique
|
||||
# TODO would be better if we found these unique calls at the merging stage
|
||||
if len(ann_aug["individual_ids"]) == 1:
|
||||
ann_aug["individual_ids"][0] = 0
|
||||
|
||||
y_2d_det = np.zeros((1, op_height, op_width), dtype=np.float32)
|
||||
y_2d_size = np.zeros((2, op_height, op_width), dtype=np.float32)
|
||||
# num classes and "background" class
|
||||
y_2d_classes: np.ndarray = np.zeros(
|
||||
(num_classes + 1, op_height, op_width), dtype=np.float32
|
||||
)
|
||||
|
||||
# create 2D ground truth heatmaps
|
||||
for ii in valid_inds:
|
||||
draw_gaussian(
|
||||
y_2d_det[0, :],
|
||||
(x_pos_start[ii], y_pos_low[ii]),
|
||||
params["target_sigma"],
|
||||
)
|
||||
# draw_gaussian(y_2d_det[0,:], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
|
||||
y_2d_size[0, y_pos_low[ii], x_pos_start[ii]] = bb_widths[ii]
|
||||
y_2d_size[1, y_pos_low[ii], x_pos_start[ii]] = bb_heights[ii]
|
||||
|
||||
cls_id = ann["class_ids"][ii]
|
||||
if cls_id > -1:
|
||||
draw_gaussian(
|
||||
y_2d_classes[cls_id, :],
|
||||
(x_pos_start[ii], y_pos_low[ii]),
|
||||
params["target_sigma"],
|
||||
)
|
||||
# draw_gaussian(y_2d_classes[cls_id, :], (x_pos_start[ii], y_pos_low[ii]), params['target_sigma'], params['target_sigma']*2)
|
||||
|
||||
# be careful as this will have a 1.0 places where we have event but dont know gt class
|
||||
# this will be masked in training anyway
|
||||
y_2d_classes[num_classes, :] = 1.0 - y_2d_classes.sum(0)
|
||||
y_2d_classes = y_2d_classes / y_2d_classes.sum(0)[np.newaxis, ...]
|
||||
y_2d_classes[np.isnan(y_2d_classes)] = 0.0
|
||||
|
||||
return y_2d_det, y_2d_size, y_2d_classes, ann_aug
|
||||
|
||||
|
||||
def draw_gaussian(heatmap, center, sigmax, sigmay=None):
|
||||
# center is (x, y)
|
||||
# this edits the heatmap inplace
|
||||
|
||||
if sigmay is None:
|
||||
sigmay = sigmax
|
||||
tmp_size = np.maximum(sigmax, sigmay) * 3
|
||||
mu_x = int(center[0] + 0.5)
|
||||
mu_y = int(center[1] + 0.5)
|
||||
w, h = heatmap.shape[0], heatmap.shape[1]
|
||||
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
|
||||
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
|
||||
|
||||
if ul[0] >= h or ul[1] >= w or br[0] < 0 or br[1] < 0:
|
||||
return False
|
||||
|
||||
size = 2 * tmp_size + 1
|
||||
x = np.arange(0, size, 1, np.float32)
|
||||
y = x[:, np.newaxis]
|
||||
x0 = y0 = size // 2
|
||||
# g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
|
||||
g = np.exp(
|
||||
-((x - x0) ** 2) / (2 * sigmax**2)
|
||||
- ((y - y0) ** 2) / (2 * sigmay**2)
|
||||
)
|
||||
g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
|
||||
g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
|
||||
img_x = max(0, ul[0]), min(br[0], h)
|
||||
img_y = max(0, ul[1]), min(br[1], w)
|
||||
heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]] = np.maximum(
|
||||
heatmap[img_y[0] : img_y[1], img_x[0] : img_x[1]],
|
||||
g[g_y[0] : g_y[1], g_x[0] : g_x[1]],
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def pad_aray(ip_array, pad_size):
|
||||
return np.hstack((ip_array, np.ones(pad_size, dtype=np.int) * -1))
|
||||
|
||||
|
||||
def warp_spec_aug(spec, ann, return_spec_for_viz, params):
|
||||
# This is messy
|
||||
# Augment spectrogram by randomly stretch and squeezing
|
||||
# NOTE this also changes the start and stop time in place
|
||||
|
||||
# not taking care of spec for viz
|
||||
if return_spec_for_viz:
|
||||
assert False
|
||||
|
||||
delta = params["stretch_squeeze_delta"]
|
||||
op_size = (spec.shape[1], spec.shape[2])
|
||||
resize_fract_r = np.random.rand() * delta * 2 - delta + 1.0
|
||||
resize_amt = int(spec.shape[2] * resize_fract_r)
|
||||
if resize_amt >= spec.shape[2]:
|
||||
spec_r = torch.cat(
|
||||
(
|
||||
spec,
|
||||
torch.zeros(
|
||||
(1, spec.shape[1], resize_amt - spec.shape[2]),
|
||||
dtype=spec.dtype,
|
||||
),
|
||||
),
|
||||
2,
|
||||
)
|
||||
else:
|
||||
spec_r = spec[:, :, :resize_amt]
|
||||
spec = F.interpolate(
|
||||
spec_r.unsqueeze(0), size=op_size, mode="bilinear", align_corners=False
|
||||
).squeeze(0)
|
||||
ann["start_times"] *= 1.0 / resize_fract_r
|
||||
ann["end_times"] *= 1.0 / resize_fract_r
|
||||
return spec
|
||||
|
||||
|
||||
def mask_time_aug(spec, params):
|
||||
# Mask out a random block of time - repeat up to 3 times
|
||||
# SpecAugment: A Simple Data Augmentation Methodfor Automatic Speech Recognition
|
||||
fm = torchaudio.transforms.TimeMasking(
|
||||
int(spec.shape[1] * params["mask_max_time_perc"])
|
||||
)
|
||||
for ii in range(np.random.randint(1, 4)):
|
||||
spec = fm(spec)
|
||||
return spec
|
||||
|
||||
|
||||
def mask_freq_aug(spec, params):
|
||||
# Mask out a random frequncy range - repeat up to 3 times
|
||||
# SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
|
||||
fm = torchaudio.transforms.FrequencyMasking(
|
||||
int(spec.shape[1] * params["mask_max_freq_perc"])
|
||||
)
|
||||
for ii in range(np.random.randint(1, 4)):
|
||||
spec = fm(spec)
|
||||
return spec
|
||||
|
||||
|
||||
def scale_vol_aug(spec, params):
|
||||
return spec * np.random.random() * params["spec_amp_scaling"]
|
||||
|
||||
|
||||
def echo_aug(audio, sampling_rate, params):
|
||||
sample_offset = (
|
||||
int(params["echo_max_delay"] * np.random.random() * sampling_rate) + 1
|
||||
)
|
||||
audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
|
||||
return audio
|
||||
|
||||
|
||||
def resample_aug(audio, sampling_rate, params):
|
||||
sampling_rate_old = sampling_rate
|
||||
sampling_rate = np.random.choice(params["aug_sampling_rates"])
|
||||
audio = librosa.resample(
|
||||
audio,
|
||||
orig_sr=sampling_rate_old,
|
||||
target_sr=sampling_rate,
|
||||
res_type="polyphase",
|
||||
)
|
||||
|
||||
audio = au.pad_audio(
|
||||
audio,
|
||||
sampling_rate,
|
||||
params["fft_win_length"],
|
||||
params["fft_overlap"],
|
||||
params["resize_factor"],
|
||||
params["spec_divide_factor"],
|
||||
params["spec_train_width"],
|
||||
)
|
||||
duration = audio.shape[0] / float(sampling_rate)
|
||||
return audio, sampling_rate, duration
|
||||
|
||||
|
||||
def resample_audio(num_samples, sampling_rate, audio2, sampling_rate2):
|
||||
if sampling_rate != sampling_rate2:
|
||||
audio2 = librosa.resample(
|
||||
audio2,
|
||||
orig_sr=sampling_rate2,
|
||||
target_sr=sampling_rate,
|
||||
res_type="polyphase",
|
||||
)
|
||||
sampling_rate2 = sampling_rate
|
||||
if audio2.shape[0] < num_samples:
|
||||
audio2 = np.hstack(
|
||||
(
|
||||
audio2,
|
||||
np.zeros((num_samples - audio2.shape[0]), dtype=audio2.dtype),
|
||||
)
|
||||
)
|
||||
elif audio2.shape[0] > num_samples:
|
||||
audio2 = audio2[:num_samples]
|
||||
return audio2, sampling_rate2
|
||||
|
||||
|
||||
def combine_audio_aug(audio, sampling_rate, ann, audio2, sampling_rate2, ann2):
|
||||
|
||||
# resample so they are the same
|
||||
audio2, sampling_rate2 = resample_audio(
|
||||
audio.shape[0], sampling_rate, audio2, sampling_rate2
|
||||
)
|
||||
|
||||
# # set mean and std to be the same
|
||||
# audio2 = (audio2 - audio2.mean())
|
||||
# audio2 = (audio2/audio2.std())*audio.std()
|
||||
# audio2 = audio2 + audio.mean()
|
||||
|
||||
if (
|
||||
ann["annotated"]
|
||||
and (ann2["annotated"])
|
||||
and (sampling_rate2 == sampling_rate)
|
||||
and (audio.shape[0] == audio2.shape[0])
|
||||
):
|
||||
comb_weight = 0.3 + np.random.random() * 0.4
|
||||
audio = comb_weight * audio + (1 - comb_weight) * audio2
|
||||
inds = np.argsort(np.hstack((ann["start_times"], ann2["start_times"])))
|
||||
for kk in ann.keys():
|
||||
|
||||
# when combining calls from different files, assume they come from different individuals
|
||||
if kk == "individual_ids":
|
||||
if (ann[kk] > -1).sum() > 0:
|
||||
ann2[kk][ann2[kk] > -1] += np.max(ann[kk][ann[kk] > -1]) + 1
|
||||
|
||||
if (kk != "class_id_file") and (kk != "annotated"):
|
||||
ann[kk] = np.hstack((ann[kk], ann2[kk]))[inds]
|
||||
|
||||
return audio, ann
|
||||
|
||||
|
||||
class AudioLoader(torch.utils.data.Dataset):
|
||||
def __init__(self, data_anns_ip, params, dataset_name=None, is_train=False):
|
||||
|
||||
self.data_anns = []
|
||||
self.is_train = is_train
|
||||
self.params = params
|
||||
self.return_spec_for_viz = False
|
||||
|
||||
for ii in range(len(data_anns_ip)):
|
||||
dd = copy.deepcopy(data_anns_ip[ii])
|
||||
|
||||
# filter out unused annotation here
|
||||
filtered_annotations = []
|
||||
for ii, aa in enumerate(dd["annotation"]):
|
||||
|
||||
if "individual" in aa.keys():
|
||||
aa["individual"] = int(aa["individual"])
|
||||
|
||||
# if only one call labeled it has to be from the same individual
|
||||
if len(dd["annotation"]) == 1:
|
||||
aa["individual"] = 0
|
||||
|
||||
# convert class name into class label
|
||||
if aa["class"] in self.params["class_names"]:
|
||||
aa["class_id"] = self.params["class_names"].index(
|
||||
aa["class"]
|
||||
)
|
||||
else:
|
||||
aa["class_id"] = -1
|
||||
|
||||
if aa["class"] not in self.params["classes_to_ignore"]:
|
||||
filtered_annotations.append(aa)
|
||||
|
||||
dd["annotation"] = filtered_annotations
|
||||
dd["start_times"] = np.array(
|
||||
[aa["start_time"] for aa in dd["annotation"]]
|
||||
)
|
||||
dd["end_times"] = np.array(
|
||||
[aa["end_time"] for aa in dd["annotation"]]
|
||||
)
|
||||
dd["high_freqs"] = np.array(
|
||||
[float(aa["high_freq"]) for aa in dd["annotation"]]
|
||||
)
|
||||
dd["low_freqs"] = np.array(
|
||||
[float(aa["low_freq"]) for aa in dd["annotation"]]
|
||||
)
|
||||
dd["class_ids"] = np.array(
|
||||
[aa["class_id"] for aa in dd["annotation"]]
|
||||
).astype(np.int)
|
||||
dd["individual_ids"] = np.array(
|
||||
[aa["individual"] for aa in dd["annotation"]]
|
||||
).astype(np.int)
|
||||
|
||||
# file level class name
|
||||
dd["class_id_file"] = -1
|
||||
if "class_name" in dd.keys():
|
||||
if dd["class_name"] in self.params["class_names"]:
|
||||
dd["class_id_file"] = self.params["class_names"].index(
|
||||
dd["class_name"]
|
||||
)
|
||||
|
||||
self.data_anns.append(dd)
|
||||
|
||||
ann_cnt = [len(aa["annotation"]) for aa in self.data_anns]
|
||||
self.max_num_anns = 2 * np.max(
|
||||
ann_cnt
|
||||
) # x2 because we may be combining files during training
|
||||
|
||||
print("\n")
|
||||
if dataset_name is not None:
|
||||
print("Dataset : " + dataset_name)
|
||||
if self.is_train:
|
||||
print("Split type : train")
|
||||
else:
|
||||
print("Split type : test")
|
||||
print("Num files : " + str(len(self.data_anns)))
|
||||
print("Num calls : " + str(np.sum(ann_cnt)))
|
||||
|
||||
def get_file_and_anns(self, index=None):
|
||||
|
||||
# if no file specified, choose random one
|
||||
if index == None:
|
||||
index = np.random.randint(0, len(self.data_anns))
|
||||
|
||||
audio_file = self.data_anns[index]["file_path"]
|
||||
sampling_rate, audio_raw = au.load_audio(
|
||||
audio_file,
|
||||
self.data_anns[index]["time_exp"],
|
||||
self.params["target_samp_rate"],
|
||||
self.params["scale_raw_audio"],
|
||||
)
|
||||
|
||||
# copy annotation
|
||||
ann = {}
|
||||
ann["annotated"] = self.data_anns[index]["annotated"]
|
||||
ann["class_id_file"] = self.data_anns[index]["class_id_file"]
|
||||
keys = [
|
||||
"start_times",
|
||||
"end_times",
|
||||
"high_freqs",
|
||||
"low_freqs",
|
||||
"class_ids",
|
||||
"individual_ids",
|
||||
]
|
||||
for kk in keys:
|
||||
ann[kk] = self.data_anns[index][kk].copy()
|
||||
|
||||
# if train then grab a random crop
|
||||
if self.is_train:
|
||||
nfft = int(self.params["fft_win_length"] * sampling_rate)
|
||||
noverlap = int(self.params["fft_overlap"] * nfft)
|
||||
length_samples = (
|
||||
self.params["spec_train_width"] * (nfft - noverlap) + noverlap
|
||||
)
|
||||
|
||||
if audio_raw.shape[0] - length_samples > 0:
|
||||
sample_crop = np.random.randint(
|
||||
audio_raw.shape[0] - length_samples
|
||||
)
|
||||
else:
|
||||
sample_crop = 0
|
||||
audio_raw = audio_raw[sample_crop : sample_crop + length_samples]
|
||||
ann["start_times"] = ann["start_times"] - sample_crop / float(
|
||||
sampling_rate
|
||||
)
|
||||
ann["end_times"] = ann["end_times"] - sample_crop / float(
|
||||
sampling_rate
|
||||
)
|
||||
|
||||
# pad audio
|
||||
if self.is_train:
|
||||
op_spec_target_size = self.params["spec_train_width"]
|
||||
else:
|
||||
op_spec_target_size = None
|
||||
audio_raw = au.pad_audio(
|
||||
audio_raw,
|
||||
sampling_rate,
|
||||
self.params["fft_win_length"],
|
||||
self.params["fft_overlap"],
|
||||
self.params["resize_factor"],
|
||||
self.params["spec_divide_factor"],
|
||||
op_spec_target_size,
|
||||
)
|
||||
duration = audio_raw.shape[0] / float(sampling_rate)
|
||||
|
||||
# sort based on time
|
||||
inds = np.argsort(ann["start_times"])
|
||||
for kk in ann.keys():
|
||||
if (kk != "class_id_file") and (kk != "annotated"):
|
||||
ann[kk] = ann[kk][inds]
|
||||
|
||||
return audio_raw, sampling_rate, duration, ann
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
# load audio file
|
||||
audio, sampling_rate, duration, ann = self.get_file_and_anns(index)
|
||||
|
||||
# augment on raw audio
|
||||
if self.is_train and self.params["augment_at_train"]:
|
||||
# augment - combine with random audio file
|
||||
if (
|
||||
self.params["augment_at_train_combine"]
|
||||
and np.random.random() < self.params["aug_prob"]
|
||||
):
|
||||
(
|
||||
audio2,
|
||||
sampling_rate2,
|
||||
duration2,
|
||||
ann2,
|
||||
) = self.get_file_and_anns()
|
||||
audio, ann = combine_audio_aug(
|
||||
audio, sampling_rate, ann, audio2, sampling_rate2, ann2
|
||||
)
|
||||
|
||||
# simulate echo by adding delayed copy of the file
|
||||
if np.random.random() < self.params["aug_prob"]:
|
||||
audio = echo_aug(audio, sampling_rate, self.params)
|
||||
|
||||
# resample the audio
|
||||
# if np.random.random() < self.params['aug_prob']:
|
||||
# audio, sampling_rate, duration = resample_aug(audio, sampling_rate, self.params)
|
||||
|
||||
# create spectrogram
|
||||
spec, spec_for_viz = au.generate_spectrogram(
|
||||
audio, sampling_rate, self.params, self.return_spec_for_viz
|
||||
)
|
||||
rsf = self.params["resize_factor"]
|
||||
spec_op_shape = (
|
||||
int(self.params["spec_height"] * rsf),
|
||||
int(spec.shape[1] * rsf),
|
||||
)
|
||||
|
||||
# resize the spec
|
||||
spec = torch.from_numpy(spec).unsqueeze(0).unsqueeze(0)
|
||||
spec = F.interpolate(
|
||||
spec, size=spec_op_shape, mode="bilinear", align_corners=False
|
||||
).squeeze(0)
|
||||
|
||||
# augment spectrogram
|
||||
if self.is_train and self.params["augment_at_train"]:
|
||||
|
||||
if np.random.random() < self.params["aug_prob"]:
|
||||
spec = scale_vol_aug(spec, self.params)
|
||||
|
||||
if np.random.random() < self.params["aug_prob"]:
|
||||
spec = warp_spec_aug(
|
||||
spec, ann, self.return_spec_for_viz, self.params
|
||||
)
|
||||
|
||||
if np.random.random() < self.params["aug_prob"]:
|
||||
spec = mask_time_aug(spec, self.params)
|
||||
|
||||
if np.random.random() < self.params["aug_prob"]:
|
||||
spec = mask_freq_aug(spec, self.params)
|
||||
|
||||
outputs = {}
|
||||
outputs["spec"] = spec
|
||||
if self.return_spec_for_viz:
|
||||
outputs["spec_for_viz"] = torch.from_numpy(spec_for_viz).unsqueeze(
|
||||
0
|
||||
)
|
||||
|
||||
# create ground truth heatmaps
|
||||
(
|
||||
outputs["y_2d_det"],
|
||||
outputs["y_2d_size"],
|
||||
outputs["y_2d_classes"],
|
||||
ann_aug,
|
||||
) = generate_gt_heatmaps(spec_op_shape, sampling_rate, ann, self.params)
|
||||
|
||||
# hack to get around requirement that all vectors are the same length in
|
||||
# the output batch
|
||||
pad_size = self.max_num_anns - len(ann_aug["individual_ids"])
|
||||
outputs["is_valid"] = pad_aray(
|
||||
np.ones(len(ann_aug["individual_ids"])), pad_size
|
||||
)
|
||||
keys = [
|
||||
"class_ids",
|
||||
"individual_ids",
|
||||
"x_inds",
|
||||
"y_inds",
|
||||
"start_times",
|
||||
"end_times",
|
||||
"low_freqs",
|
||||
"high_freqs",
|
||||
]
|
||||
for kk in keys:
|
||||
outputs[kk] = pad_aray(ann_aug[kk], pad_size)
|
||||
|
||||
# convert to pytorch
|
||||
for kk in outputs.keys():
|
||||
if type(outputs[kk]) != torch.Tensor:
|
||||
outputs[kk] = torch.from_numpy(outputs[kk])
|
||||
|
||||
# scalars
|
||||
outputs["class_id_file"] = ann["class_id_file"]
|
||||
outputs["annotated"] = ann["annotated"]
|
||||
outputs["duration"] = duration
|
||||
outputs["sampling_rate"] = sampling_rate
|
||||
outputs["file_id"] = index
|
||||
|
||||
return outputs
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_anns)
|
||||
@ -1,14 +1,20 @@
|
||||
import numpy as np
|
||||
from sklearn.metrics import auc, roc_curve
|
||||
from sklearn.metrics import (
|
||||
accuracy_score,
|
||||
auc,
|
||||
balanced_accuracy_score,
|
||||
roc_curve,
|
||||
)
|
||||
|
||||
|
||||
def compute_error_auc(op_str, gt, pred, prob):
|
||||
|
||||
# classification error
|
||||
pred_int = (pred > prob).astype(np.int32)
|
||||
pred_int = (pred > prob).astype(np.int)
|
||||
class_acc = (pred_int == gt).mean() * 100.0
|
||||
|
||||
# ROC - area under curve
|
||||
fpr, tpr, _ = roc_curve(gt, pred)
|
||||
fpr, tpr, thresholds = roc_curve(gt, pred)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
|
||||
print(
|
||||
@ -19,6 +25,7 @@ def compute_error_auc(op_str, gt, pred, prob):
|
||||
|
||||
|
||||
def calc_average_precision(recall, precision):
|
||||
|
||||
precision[np.isnan(precision)] = 0
|
||||
recall[np.isnan(recall)] = 0
|
||||
|
||||
@ -84,6 +91,7 @@ def compute_pre_rec(
|
||||
pred_class = []
|
||||
file_ids = []
|
||||
for pid, pp in enumerate(preds):
|
||||
|
||||
# filter predicted calls that are too near the start or end of the file
|
||||
file_dur = gts[pid]["duration"]
|
||||
valid_inds = (pp["start_times"] >= ignore_start_end) & (
|
||||
@ -121,7 +129,7 @@ def compute_pre_rec(
|
||||
file_ids.append([pid] * valid_inds.sum())
|
||||
|
||||
confidence = np.hstack(confidence)
|
||||
file_ids = np.hstack(file_ids).astype(int)
|
||||
file_ids = np.hstack(file_ids).astype(np.int)
|
||||
pred_boxes = np.vstack(pred_boxes)
|
||||
if len(pred_class) > 0:
|
||||
pred_class = np.hstack(pred_class)
|
||||
@ -133,6 +141,7 @@ def compute_pre_rec(
|
||||
gt_generic_class = []
|
||||
num_positives = 0
|
||||
for gg in gts:
|
||||
|
||||
# filter ground truth calls that are too near the start or end of the file
|
||||
file_dur = gg["duration"]
|
||||
valid_inds = (gg["start_times"] >= ignore_start_end) & (
|
||||
@ -141,7 +150,8 @@ def compute_pre_rec(
|
||||
|
||||
# note, files with the incorrect duration will cause a problem
|
||||
if (gg["start_times"] > file_dur).sum() > 0:
|
||||
raise ValueError(f"Error: file duration incorrect for {gg['id']}")
|
||||
print("Error: file duration incorrect for", gg["id"])
|
||||
assert False
|
||||
|
||||
boxes = np.vstack(
|
||||
(
|
||||
@ -187,8 +197,6 @@ def compute_pre_rec(
|
||||
gt_id = file_ids[ind]
|
||||
|
||||
valid_det = False
|
||||
det_ind = 0
|
||||
|
||||
if gt_boxes[gt_id].shape[0] > 0:
|
||||
# compute overlap
|
||||
valid_det, det_ind = compute_affinity_1d(
|
||||
@ -197,6 +205,7 @@ def compute_pre_rec(
|
||||
|
||||
# valid detection that has not already been assigned
|
||||
if valid_det and (gt_assigned[gt_id][det_ind] == 0):
|
||||
|
||||
count_as_true_pos = True
|
||||
if eval_mode == "top_class" and (
|
||||
gt_class[gt_id][det_ind] != pred_class[ind]
|
||||
@ -218,7 +227,7 @@ def compute_pre_rec(
|
||||
# store threshold values - used for plotting
|
||||
conf_sorted = np.sort(confidence)[::-1][valid_inds]
|
||||
thresholds = np.linspace(0.1, 0.9, 9)
|
||||
thresholds_inds = np.zeros(len(thresholds), dtype=int)
|
||||
thresholds_inds = np.zeros(len(thresholds), dtype=np.int)
|
||||
for ii, tt in enumerate(thresholds):
|
||||
thresholds_inds[ii] = np.argmin(conf_sorted > tt)
|
||||
thresholds_inds[thresholds_inds == 0] = -1
|
||||
@ -330,7 +339,7 @@ def compute_file_accuracy(gts, preds, num_classes):
|
||||
).mean(0)
|
||||
best_thresh = np.argmax(acc_per_thresh)
|
||||
best_acc = acc_per_thresh[best_thresh]
|
||||
pred_valid = pred_valid_all[:, best_thresh].astype(int).tolist()
|
||||
pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist()
|
||||
|
||||
res = {}
|
||||
res["num_valid_files"] = len(gt_valid)
|
||||
63
batdetect2/train/losses.py
Normal file
63
batdetect2/train/losses.py
Normal file
@ -0,0 +1,63 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def bbox_size_loss(pred_size, gt_size):
|
||||
"""
|
||||
Bounding box size loss. Only compute loss where there is a bounding box.
|
||||
"""
|
||||
gt_size_mask = (gt_size > 0).float()
|
||||
return F.l1_loss(pred_size * gt_size_mask, gt_size, reduction="sum") / (
|
||||
gt_size_mask.sum() + 1e-5
|
||||
)
|
||||
|
||||
|
||||
def focal_loss(pred, gt, weights=None, valid_mask=None):
|
||||
"""
|
||||
Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
|
||||
pred (batch x c x h x w)
|
||||
gt (batch x c x h x w)
|
||||
"""
|
||||
eps = 1e-5
|
||||
beta = 4
|
||||
alpha = 2
|
||||
|
||||
pos_inds = gt.eq(1).float()
|
||||
neg_inds = gt.lt(1).float()
|
||||
|
||||
pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
|
||||
neg_loss = (
|
||||
torch.log(1 - pred + eps)
|
||||
* torch.pow(pred, alpha)
|
||||
* torch.pow(1 - gt, beta)
|
||||
* neg_inds
|
||||
)
|
||||
|
||||
if weights is not None:
|
||||
pos_loss = pos_loss * weights
|
||||
# neg_loss = neg_loss*weights
|
||||
|
||||
if valid_mask is not None:
|
||||
pos_loss = pos_loss * valid_mask
|
||||
neg_loss = neg_loss * valid_mask
|
||||
|
||||
pos_loss = pos_loss.sum()
|
||||
neg_loss = neg_loss.sum()
|
||||
|
||||
num_pos = pos_inds.float().sum()
|
||||
if num_pos == 0:
|
||||
loss = -neg_loss
|
||||
else:
|
||||
loss = -(pos_loss + neg_loss) / num_pos
|
||||
return loss
|
||||
|
||||
|
||||
def mse_loss(pred, gt, weights=None, valid_mask=None):
|
||||
"""
|
||||
Mean squared error loss.
|
||||
"""
|
||||
if valid_mask is None:
|
||||
op = ((gt - pred) ** 2).mean()
|
||||
else:
|
||||
op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
|
||||
return op
|
||||
@ -1,13 +1,10 @@
|
||||
## How to train a model from scratch
|
||||
|
||||
> **Warning**
|
||||
> This code in currently broken. Will fix soon, stay tuned.
|
||||
|
||||
`python train_model.py data_dir annotation_dir` e.g.
|
||||
`python train_model.py /data1/bat_data/data/ /data1/bat_data/annotations/anns/`
|
||||
|
||||
More comprehensive instructions are provided in the finetune directory.
|
||||
|
||||
|
||||
## Training on your own data
|
||||
You can either use the finetuning scripts to finetune from an existing training dataset. Follow the instructions in the `../finetune/` directory.
|
||||
|
||||
@ -2,17 +2,16 @@ import argparse
|
||||
import json
|
||||
import warnings
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
import batdetect2.detector.post_process as pp
|
||||
import batdetect2.train.audio_dataloader as adl
|
||||
import batdetect2.train.evaluate as evl
|
||||
import batdetect2.train.train_split as ts
|
||||
import batdetect2.train.train_utils as tu
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
import batdetect2.detector.post_process as pp
|
||||
import batdetect2.utils.plot_utils as pu
|
||||
from batdetect2.detector import models, parameters
|
||||
from batdetect2.train import losses
|
||||
@ -30,7 +29,7 @@ def save_images_batch(model, data_loader, params):
|
||||
|
||||
ind = 0 # first image in each batch
|
||||
with torch.no_grad():
|
||||
for inputs in data_loader:
|
||||
for batch_idx, inputs in enumerate(data_loader):
|
||||
data = inputs["spec"].to(params["device"])
|
||||
outputs = model(data)
|
||||
|
||||
@ -82,12 +81,7 @@ def save_image(
|
||||
|
||||
|
||||
def loss_fun(
|
||||
outputs,
|
||||
gt_det,
|
||||
gt_size,
|
||||
gt_class,
|
||||
det_criterion,
|
||||
params,
|
||||
outputs, gt_det, gt_size, gt_class, det_criterion, params, class_inv_freq
|
||||
):
|
||||
# detection loss
|
||||
loss = params["det_loss_weight"] * det_criterion(
|
||||
@ -110,13 +104,7 @@ def loss_fun(
|
||||
|
||||
|
||||
def train(
|
||||
model,
|
||||
epoch,
|
||||
data_loader,
|
||||
det_criterion,
|
||||
optimizer,
|
||||
scheduler,
|
||||
params,
|
||||
model, epoch, data_loader, det_criterion, optimizer, scheduler, params
|
||||
):
|
||||
model.train()
|
||||
|
||||
@ -321,7 +309,7 @@ def select_model(params):
|
||||
resize_factor=params["resize_factor"],
|
||||
)
|
||||
else:
|
||||
raise ValueError("No valid network specified")
|
||||
print("No valid network specified")
|
||||
return model
|
||||
|
||||
|
||||
@ -331,9 +319,9 @@ def main():
|
||||
params = parameters.get_params(True)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
params.device = "cuda"
|
||||
params["device"] = "cuda"
|
||||
else:
|
||||
params.device = "cpu"
|
||||
params["device"] = "cpu"
|
||||
|
||||
# setup arg parser and populate it with exiting parameters - will not work with lists
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -361,16 +349,13 @@ def main():
|
||||
default="Rhinolophus ferrumequinum;Rhinolophus hipposideros",
|
||||
help='Will set low and high frequency the same for these classes. Separate names with ";"',
|
||||
)
|
||||
|
||||
for key, val in params.items():
|
||||
parser.add_argument("--" + key, type=type(val), default=val)
|
||||
params = vars(parser.parse_args())
|
||||
|
||||
# save notes file
|
||||
if params["notes"] != "":
|
||||
tu.write_notes_file(
|
||||
params["experiment"] + "notes.txt", params["notes"]
|
||||
)
|
||||
tu.write_notes_file(params["experiment"] + "notes.txt", params["notes"])
|
||||
|
||||
# load the training and test meta data - there are different splits defined
|
||||
train_sets, test_sets = ts.get_train_test_data(
|
||||
@ -389,11 +374,15 @@ def main():
|
||||
for tt in train_sets:
|
||||
print(tt["ann_path"])
|
||||
classes_to_ignore = params["classes_to_ignore"] + params["generic_class"]
|
||||
data_train = tu.load_set_of_anns(
|
||||
(
|
||||
data_train,
|
||||
params["class_names"],
|
||||
params["class_inv_freq"],
|
||||
) = tu.load_set_of_anns(
|
||||
train_sets,
|
||||
classes_to_ignore=classes_to_ignore,
|
||||
events_of_interest=params["events_of_interest"],
|
||||
convert_to_genus=params["convert_to_genus"],
|
||||
classes_to_ignore,
|
||||
params["events_of_interest"],
|
||||
params["convert_to_genus"],
|
||||
)
|
||||
params["genus_names"], params["genus_mapping"] = tu.get_genus_mapping(
|
||||
params["class_names"]
|
||||
@ -426,12 +415,11 @@ def main():
|
||||
print("\nTesting on:")
|
||||
for tt in test_sets:
|
||||
print(tt["ann_path"])
|
||||
|
||||
data_test = tu.load_set_of_anns(
|
||||
data_test, _, _ = tu.load_set_of_anns(
|
||||
test_sets,
|
||||
classes_to_ignore=classes_to_ignore,
|
||||
events_of_interest=params["events_of_interest"],
|
||||
convert_to_genus=params["convert_to_genus"],
|
||||
classes_to_ignore,
|
||||
params["events_of_interest"],
|
||||
params["convert_to_genus"],
|
||||
)
|
||||
data_train = tu.remove_dupes(data_train, data_test)
|
||||
test_dataset = adl.AudioLoader(data_test, params, is_train=False)
|
||||
@ -459,13 +447,10 @@ def main():
|
||||
scheduler = CosineAnnealingLR(
|
||||
optimizer, params["num_epochs"] * len(train_loader)
|
||||
)
|
||||
|
||||
if params["train_loss"] == "mse":
|
||||
det_criterion = losses.mse_loss
|
||||
elif params["train_loss"] == "focal":
|
||||
det_criterion = losses.focal_loss
|
||||
else:
|
||||
raise ValueError("No valid loss specified")
|
||||
|
||||
# save parameters to file
|
||||
with open(params["experiment"] + "params.json", "w") as da:
|
||||
@ -10,12 +10,13 @@ def get_train_test_data(ann_dir, wav_dir, split_name, load_extra=True):
|
||||
train_sets, test_sets = split_same(ann_dir, wav_dir, load_extra)
|
||||
else:
|
||||
print("Split not defined")
|
||||
raise AssertionError()
|
||||
assert False
|
||||
|
||||
return train_sets, test_sets
|
||||
|
||||
|
||||
def split_diff(ann_dir, wav_dir, load_extra=True):
|
||||
|
||||
train_sets = []
|
||||
if load_extra:
|
||||
train_sets.append(
|
||||
@ -143,6 +144,7 @@ def split_diff(ann_dir, wav_dir, load_extra=True):
|
||||
|
||||
|
||||
def split_same(ann_dir, wav_dir, load_extra=True):
|
||||
|
||||
train_sets = []
|
||||
if load_extra:
|
||||
train_sets.append(
|
||||
207
batdetect2/train/train_utils.py
Normal file
207
batdetect2/train/train_utils.py
Normal file
@ -0,0 +1,207 @@
|
||||
import glob
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def write_notes_file(file_name, text):
|
||||
with open(file_name, "a") as da:
|
||||
da.write(text + "\n")
|
||||
|
||||
|
||||
def get_blank_dataset_dict(dataset_name, is_test, ann_path, wav_path):
|
||||
ddict = {
|
||||
"dataset_name": dataset_name,
|
||||
"is_test": is_test,
|
||||
"is_binary": False,
|
||||
"ann_path": ann_path,
|
||||
"wav_path": wav_path,
|
||||
}
|
||||
return ddict
|
||||
|
||||
|
||||
def get_short_class_names(class_names, str_len=3):
|
||||
class_names_short = []
|
||||
for cc in class_names:
|
||||
class_names_short.append(
|
||||
" ".join([sp[:str_len] for sp in cc.split(" ")])
|
||||
)
|
||||
return class_names_short
|
||||
|
||||
|
||||
def remove_dupes(data_train, data_test):
|
||||
test_ids = [dd["id"] for dd in data_test]
|
||||
data_train_prune = []
|
||||
for aa in data_train:
|
||||
if aa["id"] not in test_ids:
|
||||
data_train_prune.append(aa)
|
||||
diff = len(data_train) - len(data_train_prune)
|
||||
if diff != 0:
|
||||
print(diff, "items removed from train set")
|
||||
return data_train_prune
|
||||
|
||||
|
||||
def get_genus_mapping(class_names):
|
||||
genus_names, genus_mapping = np.unique(
|
||||
[cc.split(" ")[0] for cc in class_names], return_inverse=True
|
||||
)
|
||||
return genus_names.tolist(), genus_mapping.tolist()
|
||||
|
||||
|
||||
def standardize_low_freq(data, class_of_interest):
|
||||
# address the issue of highly variable low frequency annotations
|
||||
# this often happens for contstant frequency calls
|
||||
# for the class of interest sets the low and high freq to be the dataset mean
|
||||
low_freqs = []
|
||||
high_freqs = []
|
||||
for dd in data:
|
||||
for aa in dd["annotation"]:
|
||||
if aa["class"] == class_of_interest:
|
||||
low_freqs.append(aa["low_freq"])
|
||||
high_freqs.append(aa["high_freq"])
|
||||
|
||||
low_mean = np.mean(low_freqs)
|
||||
high_mean = np.mean(high_freqs)
|
||||
assert low_mean < high_mean
|
||||
|
||||
print("\nStandardizing low and high frequency for:")
|
||||
print(class_of_interest)
|
||||
print("low: ", round(low_mean, 2))
|
||||
print("high: ", round(high_mean, 2))
|
||||
|
||||
# only set the low freq, high stays the same
|
||||
# assumes that low_mean < high_mean
|
||||
for dd in data:
|
||||
for aa in dd["annotation"]:
|
||||
if aa["class"] == class_of_interest:
|
||||
aa["low_freq"] = low_mean
|
||||
if aa["high_freq"] < low_mean:
|
||||
aa["high_freq"] = high_mean
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def load_set_of_anns(
|
||||
data,
|
||||
classes_to_ignore=[],
|
||||
events_of_interest=None,
|
||||
convert_to_genus=False,
|
||||
verbose=True,
|
||||
list_of_anns=False,
|
||||
filter_issues=False,
|
||||
name_replace=False,
|
||||
):
|
||||
|
||||
# load the annotations
|
||||
anns = []
|
||||
if list_of_anns:
|
||||
# path to list of individual json files
|
||||
anns.extend(load_anns_from_path(data["ann_path"], data["wav_path"]))
|
||||
else:
|
||||
# dictionary of datasets
|
||||
for dd in data:
|
||||
anns.extend(load_anns(dd["ann_path"], dd["wav_path"]))
|
||||
|
||||
# discarding unannoated files
|
||||
anns = [aa for aa in anns if aa["annotated"] is True]
|
||||
|
||||
# filter files that have annotation issues - is the input is a dictionary of
|
||||
# datasets, this will lilely have already been done
|
||||
if filter_issues:
|
||||
anns = [aa for aa in anns if aa["issues"] is False]
|
||||
|
||||
# check for some basic formatting errors with class names
|
||||
for ann in anns:
|
||||
for aa in ann["annotation"]:
|
||||
aa["class"] = aa["class"].strip()
|
||||
|
||||
# only load specified events - i.e. types of calls
|
||||
if events_of_interest is not None:
|
||||
for ann in anns:
|
||||
filtered_events = []
|
||||
for aa in ann["annotation"]:
|
||||
if aa["event"] in events_of_interest:
|
||||
filtered_events.append(aa)
|
||||
ann["annotation"] = filtered_events
|
||||
|
||||
# change class names
|
||||
# replace_names will be a dictionary mapping input name to output
|
||||
if type(name_replace) is dict:
|
||||
for ann in anns:
|
||||
for aa in ann["annotation"]:
|
||||
if aa["class"] in name_replace:
|
||||
aa["class"] = name_replace[aa["class"]]
|
||||
|
||||
# convert everything to genus name
|
||||
if convert_to_genus:
|
||||
for ann in anns:
|
||||
for aa in ann["annotation"]:
|
||||
aa["class"] = aa["class"].split(" ")[0]
|
||||
|
||||
# get unique class names
|
||||
class_names_all = []
|
||||
for ann in anns:
|
||||
for aa in ann["annotation"]:
|
||||
if aa["class"] not in classes_to_ignore:
|
||||
class_names_all.append(aa["class"])
|
||||
|
||||
class_names, class_cnts = np.unique(class_names_all, return_counts=True)
|
||||
class_inv_freq = class_cnts.sum() / (
|
||||
len(class_names) * class_cnts.astype(np.float32)
|
||||
)
|
||||
|
||||
if verbose:
|
||||
print("Class count:")
|
||||
str_len = np.max([len(cc) for cc in class_names]) + 5
|
||||
for cc in range(len(class_names)):
|
||||
print(
|
||||
str(cc).ljust(5)
|
||||
+ class_names[cc].ljust(str_len)
|
||||
+ str(class_cnts[cc])
|
||||
)
|
||||
|
||||
if len(classes_to_ignore) == 0:
|
||||
return anns
|
||||
else:
|
||||
return anns, class_names.tolist(), class_inv_freq.tolist()
|
||||
|
||||
|
||||
def load_anns(ann_file_name, raw_audio_dir):
|
||||
with open(ann_file_name) as da:
|
||||
anns = json.load(da)
|
||||
|
||||
for aa in anns:
|
||||
aa["file_path"] = raw_audio_dir + aa["id"]
|
||||
|
||||
return anns
|
||||
|
||||
|
||||
def load_anns_from_path(ann_file_dir, raw_audio_dir):
|
||||
files = glob.glob(ann_file_dir + "*.json")
|
||||
anns = []
|
||||
for ff in files:
|
||||
with open(ff) as da:
|
||||
ann = json.load(da)
|
||||
ann["file_path"] = raw_audio_dir + ann["id"]
|
||||
anns.append(ann)
|
||||
|
||||
return anns
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
@ -1,14 +1,29 @@
|
||||
"""Types used in the code base."""
|
||||
|
||||
import sys
|
||||
from typing import Any, NamedTuple, Protocol, TypedDict
|
||||
from typing import List, NamedTuple, Optional, Union, Any, BinaryIO
|
||||
|
||||
import audioread
|
||||
import os
|
||||
import soundfile as sf
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import NotRequired
|
||||
else:
|
||||
try:
|
||||
from typing import TypedDict
|
||||
except ImportError:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
try:
|
||||
from typing import Protocol
|
||||
except ImportError:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
|
||||
try:
|
||||
from typing import NotRequired # type: ignore
|
||||
except ImportError:
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
|
||||
@ -16,7 +31,8 @@ __all__ = [
|
||||
"Annotation",
|
||||
"DetectionModel",
|
||||
"FeatureExtractionParameters",
|
||||
"FileAnnotation",
|
||||
"FeatureExtractor",
|
||||
"FileAnnotations",
|
||||
"ModelOutput",
|
||||
"ModelParameters",
|
||||
"NonMaximumSuppressionConfig",
|
||||
@ -26,9 +42,11 @@ __all__ = [
|
||||
"ResultParams",
|
||||
"RunResults",
|
||||
"SpectrogramParameters",
|
||||
"AudioLoaderAnnotationGroup",
|
||||
]
|
||||
|
||||
AudioPath = Union[
|
||||
str, int, os.PathLike[Any], sf.SoundFile, audioread.AudioFile, BinaryIO
|
||||
]
|
||||
|
||||
class SpectrogramParameters(TypedDict):
|
||||
"""Parameters for generating spectrograms."""
|
||||
@ -82,11 +100,8 @@ class ModelParameters(TypedDict):
|
||||
resize_factor: float
|
||||
"""Resize factor."""
|
||||
|
||||
class_names: list[str]
|
||||
"""Class names.
|
||||
|
||||
The model is trained to detect these classes.
|
||||
"""
|
||||
class_names: List[str]
|
||||
"""Class names. The model is trained to detect these classes."""
|
||||
|
||||
|
||||
DictWithClass = TypedDict("DictWithClass", {"class": str})
|
||||
@ -95,8 +110,8 @@ DictWithClass = TypedDict("DictWithClass", {"class": str})
|
||||
class Annotation(DictWithClass):
|
||||
"""Format of annotations.
|
||||
|
||||
This is the format of a single annotation as expected by the
|
||||
annotation tool.
|
||||
This is the format of a single annotation as expected by the annotation
|
||||
tool.
|
||||
"""
|
||||
|
||||
start_time: float
|
||||
@ -105,10 +120,10 @@ class Annotation(DictWithClass):
|
||||
end_time: float
|
||||
"""End time in seconds."""
|
||||
|
||||
low_freq: float
|
||||
low_freq: int
|
||||
"""Low frequency in Hz."""
|
||||
|
||||
high_freq: float
|
||||
high_freq: int
|
||||
"""High frequency in Hz."""
|
||||
|
||||
class_prob: float
|
||||
@ -123,11 +138,8 @@ class Annotation(DictWithClass):
|
||||
event: str
|
||||
"""Type of detected event."""
|
||||
|
||||
class_id: NotRequired[int]
|
||||
"""Numeric ID for the class of the annotation."""
|
||||
|
||||
|
||||
class FileAnnotation(TypedDict):
|
||||
class FileAnnotations(TypedDict):
|
||||
"""Format of results.
|
||||
|
||||
This is the format of the results expected by the annotation tool.
|
||||
@ -149,41 +161,41 @@ class FileAnnotation(TypedDict):
|
||||
"""Time expansion factor."""
|
||||
|
||||
class_name: str
|
||||
"""Class predicted at file level."""
|
||||
"""Class predicted at file level"""
|
||||
|
||||
notes: str
|
||||
"""Notes of file."""
|
||||
|
||||
annotation: list[Annotation]
|
||||
annotation: List[Annotation]
|
||||
"""List of annotations."""
|
||||
|
||||
|
||||
class RunResults(TypedDict):
|
||||
"""Run results."""
|
||||
|
||||
pred_dict: FileAnnotation
|
||||
pred_dict: FileAnnotations
|
||||
"""Predictions in the format expected by the annotation tool."""
|
||||
|
||||
spec_feats: NotRequired[list[np.ndarray]]
|
||||
spec_feats: NotRequired[List[np.ndarray]]
|
||||
"""Spectrogram features."""
|
||||
|
||||
spec_feat_names: NotRequired[list[str]]
|
||||
spec_feat_names: NotRequired[List[str]]
|
||||
"""Spectrogram feature names."""
|
||||
|
||||
cnn_feats: NotRequired[list[np.ndarray]]
|
||||
cnn_feats: NotRequired[List[np.ndarray]]
|
||||
"""CNN features."""
|
||||
|
||||
cnn_feat_names: NotRequired[list[str]]
|
||||
cnn_feat_names: NotRequired[List[str]]
|
||||
"""CNN feature names."""
|
||||
|
||||
spec_slices: NotRequired[list[np.ndarray]]
|
||||
spec_slices: NotRequired[List[np.ndarray]]
|
||||
"""Spectrogram slices."""
|
||||
|
||||
|
||||
class ResultParams(TypedDict):
|
||||
"""Result parameters."""
|
||||
|
||||
class_names: list[str]
|
||||
class_names: List[str]
|
||||
"""Class names."""
|
||||
|
||||
spec_features: bool
|
||||
@ -230,13 +242,13 @@ class ProcessingConfiguration(TypedDict):
|
||||
scale_raw_audio: bool
|
||||
"""Whether to scale the raw audio to be between -1 and 1."""
|
||||
|
||||
class_names: list[str]
|
||||
class_names: List[str]
|
||||
"""Names of the classes the model can detect."""
|
||||
|
||||
detection_threshold: float
|
||||
"""Threshold for detection probability."""
|
||||
|
||||
time_expansion: float | None
|
||||
time_expansion: Optional[float]
|
||||
"""Time expansion factor of the processed recordings."""
|
||||
|
||||
top_n: int
|
||||
@ -245,7 +257,7 @@ class ProcessingConfiguration(TypedDict):
|
||||
return_raw_preds: bool
|
||||
"""Whether to return raw predictions."""
|
||||
|
||||
max_duration: float | None
|
||||
max_duration: Optional[float]
|
||||
"""Maximum duration of audio file to process in seconds."""
|
||||
|
||||
nms_kernel_size: int
|
||||
@ -386,9 +398,9 @@ class PredictionResults(TypedDict):
|
||||
class DetectionModel(Protocol):
|
||||
"""Protocol for detection models.
|
||||
|
||||
This protocol is used to define the interface for the detection
|
||||
models. This allows us to use the same code for training and
|
||||
inference, even though the models are different.
|
||||
This protocol is used to define the interface for the detection models.
|
||||
This allows us to use the same code for training and inference, even
|
||||
though the models are different.
|
||||
"""
|
||||
|
||||
num_classes: int
|
||||
@ -408,14 +420,16 @@ class DetectionModel(Protocol):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
spec: torch.Tensor,
|
||||
ip: torch.Tensor,
|
||||
return_feats: bool = False,
|
||||
) -> ModelOutput:
|
||||
"""Forward pass of the model."""
|
||||
...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
spec: torch.Tensor,
|
||||
ip: torch.Tensor,
|
||||
return_feats: bool = False,
|
||||
) -> ModelOutput:
|
||||
"""Forward pass of the model."""
|
||||
...
|
||||
@ -462,7 +476,7 @@ class FeatureExtractionParameters(TypedDict):
|
||||
class HeatmapParameters(TypedDict):
|
||||
"""Parameters that control the heatmap generation function."""
|
||||
|
||||
class_names: list[str]
|
||||
class_names: List[str]
|
||||
|
||||
fft_win_length: float
|
||||
"""Length of the FFT window in seconds."""
|
||||
@ -480,10 +494,8 @@ class HeatmapParameters(TypedDict):
|
||||
"""Maximum frequency to consider in Hz."""
|
||||
|
||||
target_sigma: float
|
||||
"""Sigma for the Gaussian kernel.
|
||||
|
||||
Controls the width of the points in the heatmap.
|
||||
"""
|
||||
"""Sigma for the Gaussian kernel. Controls the width of the points in
|
||||
the heatmap."""
|
||||
|
||||
|
||||
class AnnotationGroup(TypedDict):
|
||||
@ -511,15 +523,6 @@ class AnnotationGroup(TypedDict):
|
||||
individual_ids: np.ndarray
|
||||
"""Individual IDs of the annotations."""
|
||||
|
||||
annotated: NotRequired[bool]
|
||||
"""Wether the annotation group is complete or not.
|
||||
|
||||
Usually annotation groups are associated to a single audio clip. If
|
||||
the annotation group is complete, it means that all relevant sound
|
||||
events have been annotated. If it is not complete, it means that
|
||||
some sound events might not have been annotated.
|
||||
"""
|
||||
|
||||
x_inds: NotRequired[np.ndarray]
|
||||
"""X coordinate of the annotations in the spectrogram."""
|
||||
|
||||
@ -527,87 +530,9 @@ class AnnotationGroup(TypedDict):
|
||||
"""Y coordinate of the annotations in the spectrogram."""
|
||||
|
||||
|
||||
class AudioLoaderAnnotationGroup(TypedDict):
|
||||
"""Group of annotation items for the training audio loader.
|
||||
|
||||
This class is used to store the annotations for the training audio
|
||||
loader. It inherits from `AnnotationGroup` and `FileAnnotations`.
|
||||
"""
|
||||
|
||||
id: str
|
||||
duration: float
|
||||
issues: bool
|
||||
file_path: str
|
||||
time_exp: float
|
||||
class_name: str
|
||||
notes: str
|
||||
start_times: np.ndarray
|
||||
end_times: np.ndarray
|
||||
low_freqs: np.ndarray
|
||||
high_freqs: np.ndarray
|
||||
class_ids: np.ndarray
|
||||
individual_ids: np.ndarray
|
||||
x_inds: np.ndarray
|
||||
y_inds: np.ndarray
|
||||
annotation: list[Annotation]
|
||||
annotated: bool
|
||||
class_id_file: int
|
||||
"""ID of the class of the file."""
|
||||
|
||||
|
||||
class AudioLoaderParameters(TypedDict):
|
||||
class_names: list[str]
|
||||
classes_to_ignore: list[str]
|
||||
target_samp_rate: int
|
||||
scale_raw_audio: bool
|
||||
fft_win_length: float
|
||||
fft_overlap: float
|
||||
spec_train_width: int
|
||||
resize_factor: float
|
||||
spec_divide_factor: int
|
||||
augment_at_train: bool
|
||||
augment_at_train_combine: bool
|
||||
aug_prob: float
|
||||
spec_height: int
|
||||
echo_max_delay: float
|
||||
spec_amp_scaling: float
|
||||
stretch_squeeze_delta: float
|
||||
mask_max_time_perc: float
|
||||
mask_max_freq_perc: float
|
||||
max_freq: float
|
||||
min_freq: float
|
||||
spec_scale: str
|
||||
denoise_spec_avg: bool
|
||||
max_scale_spec: bool
|
||||
target_sigma: float
|
||||
|
||||
|
||||
class FeatureExtractor(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
prediction: Prediction,
|
||||
**kwargs: Any,
|
||||
) -> float: ...
|
||||
"""Protocol for feature extractors."""
|
||||
|
||||
|
||||
class DatasetDict(TypedDict):
|
||||
"""Dataset dictionary.
|
||||
|
||||
This is the format of the dictionary that contains the dataset
|
||||
information.
|
||||
"""
|
||||
|
||||
dataset_name: str
|
||||
"""Name of the dataset."""
|
||||
|
||||
is_test: bool
|
||||
"""Whether the dataset is a test set."""
|
||||
|
||||
is_binary: bool
|
||||
"""Whether the dataset is binary."""
|
||||
|
||||
ann_path: str
|
||||
"""Path to the annotations."""
|
||||
|
||||
wav_path: str
|
||||
"""Path to the audio files."""
|
||||
def __call__(self, prediction: Prediction, **kwargs) -> Union[float, int]:
|
||||
"""Extract features from a prediction."""
|
||||
...
|
||||
@ -1,16 +1,24 @@
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union, Any, BinaryIO
|
||||
|
||||
from ..types import AudioPath
|
||||
|
||||
import librosa
|
||||
import librosa.core.spectrum
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import audioread
|
||||
import os
|
||||
import soundfile as sf
|
||||
|
||||
from batdetect2.detector import parameters
|
||||
|
||||
from . import wavfile
|
||||
|
||||
__all__ = [
|
||||
"load_audio",
|
||||
"load_audio_and_samplerate",
|
||||
"generate_spectrogram",
|
||||
"pad_audio",
|
||||
]
|
||||
@ -77,7 +85,7 @@ def generate_spectrogram(
|
||||
spec = np.vstack(
|
||||
(np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)
|
||||
)
|
||||
spec = spec[-max_freq : spec.shape[0] - min_freq, :]
|
||||
spec_cropped = spec[-max_freq : spec.shape[0] - min_freq, :]
|
||||
|
||||
if params["spec_scale"] == "log":
|
||||
log_scaling = (
|
||||
@ -89,7 +97,7 @@ def generate_spectrogram(
|
||||
np.abs(
|
||||
np.hanning(
|
||||
int(params["fft_win_length"] * sampling_rate)
|
||||
).astype(np.float32)
|
||||
)
|
||||
)
|
||||
** 2
|
||||
).sum()
|
||||
@ -97,9 +105,9 @@ def generate_spectrogram(
|
||||
)
|
||||
# log_scaling = (1.0 / sampling_rate)*0.1
|
||||
# log_scaling = (1.0 / sampling_rate)*10e4
|
||||
spec = np.log1p(log_scaling * spec)
|
||||
spec = np.log1p(log_scaling * spec_cropped)
|
||||
elif params["spec_scale"] == "pcen":
|
||||
spec = pcen(spec, sampling_rate)
|
||||
spec = pcen(spec_cropped, sampling_rate)
|
||||
|
||||
elif params["spec_scale"] == "none":
|
||||
pass
|
||||
@ -133,54 +141,73 @@ def generate_spectrogram(
|
||||
).sum()
|
||||
)
|
||||
)
|
||||
spec_for_viz = np.log1p(log_scaling * spec).astype(np.float32)
|
||||
spec_for_viz = np.log1p(log_scaling * spec_cropped).astype(np.float32)
|
||||
else:
|
||||
spec_for_viz = None
|
||||
|
||||
return spec, spec_for_viz
|
||||
|
||||
|
||||
def load_audio(
|
||||
audio_file: str,
|
||||
path: AudioPath,
|
||||
time_exp_fact: float,
|
||||
target_samp_rate: int,
|
||||
scale: bool = False,
|
||||
max_duration: float | None = None,
|
||||
) -> tuple[int, np.ndarray]:
|
||||
max_duration: Optional[float] = None,
|
||||
) -> Tuple[int, np.ndarray ]:
|
||||
"""Load an audio file and resample it to the target sampling rate.
|
||||
|
||||
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
||||
Only mono files are supported.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_file: str
|
||||
Path to the audio file.
|
||||
target_samp_rate: int
|
||||
Target sampling rate.
|
||||
scale: bool, optional
|
||||
Whether to scale the audio to [-1, 1]. Default: False.
|
||||
max_duration: float, optional
|
||||
Maximum duration of the audio in seconds. Defaults to None.
|
||||
If provided, the audio is clipped to this duration.
|
||||
Args:
|
||||
path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): path to the input file.
|
||||
target_samp_rate (int): Target sampling rate.
|
||||
scale (bool): Whether to scale the audio to [-1, 1].
|
||||
max_duration (float): Maximum duration of the audio in seconds.
|
||||
|
||||
Returns
|
||||
-------
|
||||
sampling_rate: int
|
||||
The sampling rate of the audio.
|
||||
audio_raw: np.ndarray
|
||||
The audio signal in a numpy array.
|
||||
Returns:
|
||||
sampling_rate: The sampling rate of the audio.
|
||||
audio_raw: The audio signal in a numpy array.
|
||||
|
||||
Raises
|
||||
------
|
||||
Raises:
|
||||
ValueError: If the audio file is stereo.
|
||||
|
||||
"""
|
||||
sample_rate, audio_data, _ = load_audio_and_samplerate(path, time_exp_fact, target_samp_rate, scale, max_duration)
|
||||
return sample_rate, audio_data
|
||||
|
||||
def load_audio_and_samplerate(
|
||||
path: AudioPath,
|
||||
time_exp_fact: float,
|
||||
target_samp_rate: int,
|
||||
scale: bool = False,
|
||||
max_duration: Optional[float] = None,
|
||||
) -> Tuple[int, np.ndarray, Union[float, int]]:
|
||||
"""Load an audio file and resample it to the target sampling rate.
|
||||
|
||||
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
||||
Only mono files are supported.
|
||||
|
||||
Args:
|
||||
path (string, int, pathlib.Path, soundfile.SoundFile, audioread object, or file-like object): path to the input file.
|
||||
target_samp_rate (int): Target sampling rate.
|
||||
scale (bool): Whether to scale the audio to [-1, 1].
|
||||
max_duration (float): Maximum duration of the audio in seconds.
|
||||
|
||||
Returns:
|
||||
sampling_rate: The sampling rate of the audio.
|
||||
audio_raw: The audio signal in a numpy array.
|
||||
file_sampling_rate: The original sampling rate of the audio
|
||||
|
||||
Raises:
|
||||
ValueError: If the audio file is stereo.
|
||||
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
|
||||
# sampling_rate, audio_raw = wavfile.read(audio_file)
|
||||
audio_raw, sampling_rate = librosa.load(
|
||||
audio_file,
|
||||
audio_raw, file_sampling_rate = librosa.load(
|
||||
path,
|
||||
sr=None,
|
||||
dtype=np.float32,
|
||||
)
|
||||
@ -188,7 +215,7 @@ def load_audio(
|
||||
if len(audio_raw.shape) > 1:
|
||||
raise ValueError("Currently does not handle stereo files")
|
||||
|
||||
sampling_rate = sampling_rate * time_exp_fact
|
||||
sampling_rate = file_sampling_rate * time_exp_fact
|
||||
|
||||
# resample - need to do this after correcting for time expansion
|
||||
sampling_rate_old = sampling_rate
|
||||
@ -216,7 +243,7 @@ def load_audio(
|
||||
audio_raw = audio_raw - audio_raw.mean()
|
||||
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
|
||||
|
||||
return sampling_rate, audio_raw
|
||||
return sampling_rate, audio_raw, file_sampling_rate
|
||||
|
||||
|
||||
def compute_spectrogram_width(
|
||||
@ -240,7 +267,7 @@ def pad_audio(
|
||||
window_overlap: float = parameters.FFT_OVERLAP,
|
||||
resize_factor: float = parameters.RESIZE_FACTOR,
|
||||
divide_factor: int = parameters.SPEC_DIVIDE_FACTOR,
|
||||
fixed_width: int | None = None,
|
||||
fixed_width: Optional[int] = None,
|
||||
):
|
||||
"""Pad audio to be evenly divisible by `divide_factor`.
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Iterator
|
||||
from typing import Any, Iterator, List, Optional, Tuple, Union, BinaryIO
|
||||
|
||||
from ..types import AudioPath
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
@ -21,7 +22,7 @@ from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
|
||||
from batdetect2.types import (
|
||||
Annotation,
|
||||
DetectionModel,
|
||||
FileAnnotation,
|
||||
FileAnnotations,
|
||||
ModelOutput,
|
||||
ModelParameters,
|
||||
PredictionResults,
|
||||
@ -31,6 +32,13 @@ from batdetect2.types import (
|
||||
SpectrogramParameters,
|
||||
)
|
||||
|
||||
import audioread
|
||||
import os
|
||||
import io
|
||||
import soundfile as sf
|
||||
import hashlib
|
||||
import uuid
|
||||
|
||||
__all__ = [
|
||||
"load_model",
|
||||
"list_audio_files",
|
||||
@ -60,7 +68,7 @@ def get_default_bd_args():
|
||||
return args
|
||||
|
||||
|
||||
def list_audio_files(ip_dir: str) -> list[str]:
|
||||
def list_audio_files(ip_dir: str) -> List[str]:
|
||||
"""Get all audio files in directory.
|
||||
|
||||
Args:
|
||||
@ -84,9 +92,9 @@ def list_audio_files(ip_dir: str) -> list[str]:
|
||||
def load_model(
|
||||
model_path: str = DEFAULT_MODEL_PATH,
|
||||
load_weights: bool = True,
|
||||
device: torch.device | str | None = None,
|
||||
device: Optional[torch.device] = None,
|
||||
weights_only: bool = True,
|
||||
) -> tuple[DetectionModel, ModelParameters]:
|
||||
) -> Tuple[DetectionModel, ModelParameters]:
|
||||
"""Load model from file.
|
||||
|
||||
Args:
|
||||
@ -185,16 +193,15 @@ def _merge_results(predictions, spec_feats, cnn_feats, spec_slices):
|
||||
|
||||
def get_annotations_from_preds(
|
||||
predictions: PredictionResults,
|
||||
class_names: list[str],
|
||||
) -> list[Annotation]:
|
||||
class_names: List[str],
|
||||
) -> List[Annotation]:
|
||||
"""Get list of annotations from predictions."""
|
||||
# Get the best class prediction probability and index for each detection
|
||||
class_prob_best = predictions["class_probs"].max(0)
|
||||
class_ind_best = predictions["class_probs"].argmax(0)
|
||||
|
||||
# Pack the results into a list of dictionaries
|
||||
annotations: list[Annotation] = [
|
||||
Annotation(
|
||||
annotations: List[Annotation] = [
|
||||
{
|
||||
"start_time": round(float(start_time), 4),
|
||||
"end_time": round(float(end_time), 4),
|
||||
@ -206,7 +213,6 @@ def get_annotations_from_preds(
|
||||
"individual": "-1",
|
||||
"event": "Echolocation",
|
||||
}
|
||||
)
|
||||
for (
|
||||
start_time,
|
||||
end_time,
|
||||
@ -223,7 +229,6 @@ def get_annotations_from_preds(
|
||||
class_ind_best,
|
||||
class_prob_best,
|
||||
predictions["det_probs"],
|
||||
strict=False,
|
||||
)
|
||||
]
|
||||
return annotations
|
||||
@ -234,8 +239,8 @@ def format_single_result(
|
||||
time_exp: float,
|
||||
duration: float,
|
||||
predictions: PredictionResults,
|
||||
class_names: list[str],
|
||||
) -> FileAnnotation:
|
||||
class_names: List[str],
|
||||
) -> FileAnnotations:
|
||||
"""Format results into the format expected by the annotation tool.
|
||||
|
||||
Args:
|
||||
@ -282,7 +287,7 @@ def convert_results(
|
||||
spec_feats,
|
||||
cnn_feats,
|
||||
spec_slices,
|
||||
nyquist_freq: float | None = None,
|
||||
nyquist_freq: Optional[float] = None,
|
||||
) -> RunResults:
|
||||
"""Convert results to dictionary as expected by the annotation tool.
|
||||
|
||||
@ -317,11 +322,9 @@ def convert_results(
|
||||
]
|
||||
|
||||
# combine into final results dictionary
|
||||
results: RunResults = RunResults( # type: ignore[missing-argument]
|
||||
{
|
||||
results: RunResults = {
|
||||
"pred_dict": pred_dict,
|
||||
}
|
||||
)
|
||||
|
||||
# add spectrogram features if they exist
|
||||
if len(spec_feats) > 0 and params["spec_features"]:
|
||||
@ -417,7 +420,8 @@ def compute_spectrogram(
|
||||
sampling_rate: int,
|
||||
params: SpectrogramParameters,
|
||||
device: torch.device,
|
||||
) -> tuple[float, torch.Tensor]:
|
||||
return_np: bool = False,
|
||||
) -> Tuple[float, torch.Tensor, Optional[np.ndarray]]:
|
||||
"""Compute a spectrogram from an audio array.
|
||||
|
||||
Will pad the audio array so that it is evenly divisible by the
|
||||
@ -426,16 +430,24 @@ def compute_spectrogram(
|
||||
Parameters
|
||||
----------
|
||||
audio : np.ndarray
|
||||
|
||||
sampling_rate : int
|
||||
|
||||
params : SpectrogramParameters
|
||||
The parameters to use for generating the spectrogram.
|
||||
|
||||
return_np : bool, optional
|
||||
Whether to return the spectrogram as a numpy array as well as a
|
||||
torch tensor. The default is False.
|
||||
|
||||
Returns
|
||||
-------
|
||||
duration : float
|
||||
The duration of the spectrgram in seconds.
|
||||
|
||||
spec : torch.Tensor
|
||||
The spectrogram as a torch tensor.
|
||||
|
||||
spec_np : np.ndarray, optional
|
||||
The spectrogram as a numpy array. Only returned if `return_np` is
|
||||
True, otherwise None.
|
||||
@ -472,14 +484,20 @@ def compute_spectrogram(
|
||||
mode="bilinear",
|
||||
align_corners=False,
|
||||
)
|
||||
return duration, spec
|
||||
|
||||
if return_np:
|
||||
spec_np = spec[0, 0, :].cpu().data.numpy()
|
||||
else:
|
||||
spec_np = None
|
||||
|
||||
return duration, spec, spec_np
|
||||
|
||||
|
||||
def iterate_over_chunks(
|
||||
audio: np.ndarray,
|
||||
samplerate: float,
|
||||
samplerate: int,
|
||||
chunk_size: float,
|
||||
) -> Iterator[tuple[float, np.ndarray]]:
|
||||
) -> Iterator[Tuple[float, np.ndarray]]:
|
||||
"""Iterate over audio in chunks of size chunk_size.
|
||||
|
||||
Parameters
|
||||
@ -510,10 +528,10 @@ def iterate_over_chunks(
|
||||
|
||||
def _process_spectrogram(
|
||||
spec: torch.Tensor,
|
||||
samplerate: float,
|
||||
samplerate: int,
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
) -> tuple[PredictionResults, np.ndarray]:
|
||||
) -> Tuple[PredictionResults, np.ndarray]:
|
||||
# evaluate model
|
||||
with torch.no_grad():
|
||||
outputs = model(spec)
|
||||
@ -550,7 +568,7 @@ def postprocess_model_outputs(
|
||||
outputs: ModelOutput,
|
||||
samp_rate: int,
|
||||
config: ProcessingConfiguration,
|
||||
) -> tuple[list[Annotation], np.ndarray]:
|
||||
) -> Tuple[List[Annotation], np.ndarray]:
|
||||
# run non-max suppression
|
||||
pred_nms_list, features = pp.run_nms(
|
||||
outputs,
|
||||
@ -589,7 +607,7 @@ def process_spectrogram(
|
||||
samplerate: int,
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
) -> tuple[list[Annotation], np.ndarray]:
|
||||
) -> Tuple[List[Annotation], np.ndarray]:
|
||||
"""Process a spectrogram with detection model.
|
||||
|
||||
Will run non-maximum suppression on the output of the model.
|
||||
@ -608,9 +626,9 @@ def process_spectrogram(
|
||||
|
||||
Returns
|
||||
-------
|
||||
detections
|
||||
detections: List[Annotation]
|
||||
List of detections predicted by the model.
|
||||
features
|
||||
features : np.ndarray
|
||||
An array of CNN features associated with each annotation.
|
||||
The array is of shape (num_detections, num_features).
|
||||
Is empty if `config["cnn_features"]` is False.
|
||||
@ -636,9 +654,9 @@ def _process_audio_array(
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
device: torch.device,
|
||||
) -> tuple[PredictionResults, np.ndarray, torch.Tensor]:
|
||||
) -> Tuple[PredictionResults, np.ndarray, torch.Tensor]:
|
||||
# load audio file and compute spectrogram
|
||||
_, spec = compute_spectrogram(
|
||||
_, spec, _ = compute_spectrogram(
|
||||
audio,
|
||||
sampling_rate,
|
||||
{
|
||||
@ -654,6 +672,7 @@ def _process_audio_array(
|
||||
"max_scale_spec": config["max_scale_spec"],
|
||||
},
|
||||
device,
|
||||
return_np=False,
|
||||
)
|
||||
|
||||
# process spectrogram with model
|
||||
@ -673,7 +692,7 @@ def process_audio_array(
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
device: torch.device,
|
||||
) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
|
||||
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
|
||||
"""Process a single audio array with detection model.
|
||||
|
||||
Parameters
|
||||
@ -693,7 +712,7 @@ def process_audio_array(
|
||||
|
||||
Returns
|
||||
-------
|
||||
annotations : list[Annotation]
|
||||
annotations : List[Annotation]
|
||||
List of annotations predicted by the model.
|
||||
features : np.ndarray
|
||||
Array of CNN features associated with each annotation.
|
||||
@ -718,11 +737,12 @@ def process_audio_array(
|
||||
|
||||
|
||||
def process_file(
|
||||
audio_file: str,
|
||||
path: AudioPath,
|
||||
model: DetectionModel,
|
||||
config: ProcessingConfiguration,
|
||||
device: torch.device,
|
||||
) -> RunResults | Any:
|
||||
file_id: Optional[str] = None
|
||||
) -> Union[RunResults, Any]:
|
||||
"""Process a single audio file with detection model.
|
||||
|
||||
Will split the audio file into chunks if it is too long and
|
||||
@ -730,7 +750,7 @@ def process_file(
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_file : str
|
||||
path : AudioPath
|
||||
Path to audio file.
|
||||
|
||||
model : torch.nn.Module
|
||||
@ -739,6 +759,9 @@ def process_file(
|
||||
config : ProcessingConfiguration
|
||||
Configuration for processing.
|
||||
|
||||
file_id: Optional[str],
|
||||
Give the data an id. Defaults to the filename if path is a string. Otherwise an md5 will be calculated from the binary data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
results : Results or Any
|
||||
@ -751,19 +774,17 @@ def process_file(
|
||||
cnn_feats = []
|
||||
spec_slices = []
|
||||
|
||||
# Get original sampling rate
|
||||
file_samp_rate = librosa.get_samplerate(audio_file)
|
||||
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
|
||||
|
||||
# load audio file
|
||||
sampling_rate, audio_full = au.load_audio(
|
||||
audio_file,
|
||||
sampling_rate, audio_full, file_samp_rate = au.load_audio_and_samplerate(
|
||||
path,
|
||||
time_exp_fact=config.get("time_expansion", 1) or 1,
|
||||
target_samp_rate=config["target_samp_rate"],
|
||||
scale=config["scale_raw_audio"],
|
||||
max_duration=config.get("max_duration"),
|
||||
)
|
||||
|
||||
orig_samp_rate = file_samp_rate * (config.get("time_expansion") or 1)
|
||||
|
||||
# loop through larger file and split into chunks
|
||||
# TODO: fix so that it overlaps correctly and takes care of
|
||||
# duplicate detections at borders
|
||||
@ -801,6 +822,7 @@ def process_file(
|
||||
cnn_feats.append(features[0])
|
||||
|
||||
if config["spec_slices"]:
|
||||
# FIX: This is not currently working. Returns empty slices
|
||||
spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms))
|
||||
|
||||
# Merge results from chunks
|
||||
@ -811,9 +833,13 @@ def process_file(
|
||||
spec_slices,
|
||||
)
|
||||
|
||||
_file_id = file_id
|
||||
if _file_id is None:
|
||||
_file_id = _generate_id(path)
|
||||
|
||||
# convert results to a dictionary in the right format
|
||||
results = convert_results(
|
||||
file_id=os.path.basename(audio_file),
|
||||
file_id=_file_id,
|
||||
time_exp=config.get("time_expansion", 1) or 1,
|
||||
duration=audio_full.shape[0] / float(sampling_rate),
|
||||
params=config,
|
||||
@ -833,6 +859,22 @@ def process_file(
|
||||
|
||||
return results
|
||||
|
||||
def _generate_id(path: AudioPath) -> str:
|
||||
""" Generate an id based on the path.
|
||||
|
||||
If the path is a str or PathLike it will parsed as the basename.
|
||||
This should ensure backwards compatibility with previous versions.
|
||||
"""
|
||||
if isinstance(path, str) or isinstance(path, os.PathLike):
|
||||
return os.path.basename(path)
|
||||
elif isinstance(path, (BinaryIO, io.BytesIO)):
|
||||
path.seek(0)
|
||||
md5 = hashlib.md5(path.read()).hexdigest()
|
||||
path.seek(0)
|
||||
return md5
|
||||
else:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def summarize_results(results, predictions, config):
|
||||
"""Print summary of results."""
|
||||
@ -87,7 +87,9 @@ def save_ann_spec(
|
||||
y_extent = [0, duration, min_freq, max_freq]
|
||||
|
||||
plt.close("all")
|
||||
plt.figure(0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100)
|
||||
fig = plt.figure(
|
||||
0, figsize=(spec.shape[1] / 100, spec.shape[0] / 100), dpi=100
|
||||
)
|
||||
plt.imshow(
|
||||
spec,
|
||||
aspect="auto",
|
||||
@ -367,7 +369,7 @@ def plot_pr_curve_class(
|
||||
|
||||
# print(class_name)
|
||||
# plot the location of the confidence threshold values
|
||||
for jj, _tt in enumerate(rr["thresholds"]):
|
||||
for jj, tt in enumerate(rr["thresholds"]):
|
||||
ind = rr["thresholds_inds"][jj]
|
||||
if ind > -1:
|
||||
plt.plot(
|
||||
@ -415,9 +417,7 @@ def plot_confusion_matrix(
|
||||
cm_norm = cm.sum(1)
|
||||
|
||||
valid_inds = np.where(cm_norm > 0)[0]
|
||||
cm[valid_inds, :] = (
|
||||
cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
|
||||
)
|
||||
cm[valid_inds, :] = cm[valid_inds, :] / cm_norm[valid_inds][..., np.newaxis]
|
||||
cm[np.where(cm_norm == -0)[0], :] = np.nan
|
||||
|
||||
if verbose:
|
||||
@ -133,7 +133,7 @@ class InteractivePlotter:
|
||||
self.fig.canvas.mpl_connect("key_press_event", self.key_press)
|
||||
|
||||
def mouse_hover(self, event):
|
||||
self.annot.get_visible()
|
||||
vis = self.annot.get_visible()
|
||||
if event.inaxes == self.ax[0]:
|
||||
cont, ind = self.low_dim_plt.contains(event)
|
||||
if cont:
|
||||
@ -155,9 +155,9 @@ class InteractivePlotter:
|
||||
|
||||
# draw bounding box around call
|
||||
self.ax[1].patches[0].remove()
|
||||
spec_width_orig = self.spec_slices[self.current_id].shape[
|
||||
1
|
||||
] / (1.0 + 2.0 * self.spec_pad)
|
||||
spec_width_orig = self.spec_slices[self.current_id].shape[1] / (
|
||||
1.0 + 2.0 * self.spec_pad
|
||||
)
|
||||
xx = w_diff + self.spec_pad * spec_width_orig
|
||||
ww = spec_width_orig
|
||||
yy = self.call_info[self.current_id]["low_freq"] / 1000
|
||||
@ -183,9 +183,7 @@ class InteractivePlotter:
|
||||
round(self.call_info[self.current_id]["start_time"], 3)
|
||||
)
|
||||
+ ", prob="
|
||||
+ str(
|
||||
round(self.call_info[self.current_id]["det_prob"], 3)
|
||||
)
|
||||
+ str(round(self.call_info[self.current_id]["det_prob"], 3))
|
||||
)
|
||||
self.ax[0].set_xlabel(info_str)
|
||||
|
||||
@ -8,7 +8,6 @@ Functions
|
||||
`write`: Write a numpy array as a WAV file.
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import os
|
||||
@ -43,7 +42,7 @@ def _read_fmt_chunk(fid):
|
||||
size, comp, noc, rate, sbytes, ba, bits = res
|
||||
if comp not in KNOWN_WAVE_FORMATS or size > 16:
|
||||
comp = WAVE_FORMAT_PCM
|
||||
warnings.warn("Unknown wave file format", WavFileWarning, stacklevel=2)
|
||||
warnings.warn("Unknown wave file format", WavFileWarning)
|
||||
if size > 16:
|
||||
fid.read(size - 16)
|
||||
|
||||
@ -157,6 +156,7 @@ def read(filename, mmap=False):
|
||||
fid = open(filename, "rb")
|
||||
|
||||
try:
|
||||
|
||||
# some files seem to have the size recorded in the header greater than
|
||||
# the actual file size.
|
||||
fid.seek(0, os.SEEK_END)
|
||||
@ -1,20 +0,0 @@
|
||||
# Minimal makefile for Sphinx documentation
|
||||
#
|
||||
|
||||
# You can set these variables from the command line, and also
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= sphinx-build
|
||||
SOURCEDIR = source
|
||||
BUILDDIR = build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
@ -1,35 +0,0 @@
|
||||
@ECHO OFF
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=source
|
||||
set BUILDDIR=build
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.https://www.sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
|
||||
:end
|
||||
popd
|
||||
@ -1,78 +0,0 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
|
||||
project = "batdetect2"
|
||||
copyright = "2025, Oisin Mac Aodha, Santiago Martinez Balvanera"
|
||||
author = "Oisin Mac Aodha, Santiago Martinez Balvanera"
|
||||
release = "2.0.0b1"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
extensions = [
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.intersphinx",
|
||||
"sphinxcontrib.autodoc_pydantic",
|
||||
"sphinx_click",
|
||||
"numpydoc",
|
||||
"myst_parser",
|
||||
"sphinx_autodoc_typehints",
|
||||
]
|
||||
|
||||
templates_path = ["_templates"]
|
||||
exclude_patterns = []
|
||||
|
||||
source_suffix = {
|
||||
".rst": "restructuredtext",
|
||||
".txt": "markdown",
|
||||
".md": "markdown",
|
||||
}
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
html_theme = "sphinx_book_theme"
|
||||
html_static_path = ["_static"]
|
||||
html_theme_options = {
|
||||
"home_page_in_toc": True,
|
||||
"show_navbar_depth": 2,
|
||||
"show_toc_level": 2,
|
||||
}
|
||||
|
||||
intersphinx_mapping = {
|
||||
"python": ("https://docs.python.org/3", None),
|
||||
"click": ("https://click.palletsprojects.com/en/stable/", None),
|
||||
"librosa": ("https://librosa.org/doc/latest/", None),
|
||||
"lightning": ("https://lightning.ai/docs/pytorch/stable/", None),
|
||||
"loguru": ("https://loguru.readthedocs.io/en/stable/", None),
|
||||
"numpy": ("https://numpy.org/doc/stable/", None),
|
||||
"omegaconf": ("https://omegaconf.readthedocs.io/en/latest/", None),
|
||||
"pytorch": ("https://pytorch.org/docs/stable/", None),
|
||||
"soundevent": ("https://mbsantiago.github.io/soundevent/", None),
|
||||
"pydantic": ("https://docs.pydantic.dev/latest/", None),
|
||||
"xarray": ("https://docs.xarray.dev/en/stable/", None),
|
||||
}
|
||||
|
||||
# -- Options for autodoc ------------------------------------------------------
|
||||
autosummary_generate = False
|
||||
autosummary_imported_members = True
|
||||
|
||||
autodoc_default_options = {
|
||||
"members": True,
|
||||
"undoc-members": False,
|
||||
"private-members": False,
|
||||
"special-members": False,
|
||||
"inherited-members": False,
|
||||
"show-inheritance": True,
|
||||
"module-first": True,
|
||||
}
|
||||
|
||||
numpydoc_show_class_members = False
|
||||
numpydoc_show_inherited_class_members = False
|
||||
numpydoc_class_members_toctree = False
|
||||
@ -1,34 +0,0 @@
|
||||
# Development and contribution
|
||||
|
||||
Thanks for your interest in improving batdetect2.
|
||||
|
||||
## Ways to contribute
|
||||
|
||||
- Report bugs and request features on
|
||||
[GitHub Issues](https://github.com/macaodha/batdetect2/issues)
|
||||
- Improve docs by opening pull requests with clearer examples, fixes, or
|
||||
missing workflows
|
||||
- Contribute code for models, data handling, evaluation, or CLI workflows
|
||||
|
||||
## Basic contribution workflow
|
||||
|
||||
1. Open an issue (or comment on an existing one) so work is visible.
|
||||
2. Create a branch for your change.
|
||||
3. Run checks locally before opening a PR:
|
||||
|
||||
```bash
|
||||
just check
|
||||
just docs
|
||||
```
|
||||
|
||||
4. Open a pull request with a clear summary of what changed and why.
|
||||
|
||||
## Development environment
|
||||
|
||||
Use `uv` for dependency and environment management.
|
||||
|
||||
```bash
|
||||
uv sync
|
||||
```
|
||||
|
||||
For more setup details, see {doc}`../getting_started`.
|
||||
@ -1,48 +0,0 @@
|
||||
# Evaluation concepts and matching
|
||||
|
||||
Evaluation is not just "run predictions and compute one number".
|
||||
|
||||
The reported metric depends on the evaluation task, the matching rule, and the treatment of clip boundaries and generic labels.
|
||||
|
||||
## Task families answer different questions
|
||||
|
||||
Built-in task families include:
|
||||
|
||||
- sound event detection,
|
||||
- sound event classification,
|
||||
- top-class detection,
|
||||
- clip detection,
|
||||
- clip classification.
|
||||
|
||||
Choose the task that matches the scientific or engineering question.
|
||||
|
||||
## Matching matters
|
||||
|
||||
For sound-event-style tasks, predictions and annotations are matched using an affinity function.
|
||||
|
||||
Important controls include:
|
||||
|
||||
- `affinity`,
|
||||
- `affinity_threshold`,
|
||||
- `strict_match`,
|
||||
- `ignore_start_end`.
|
||||
|
||||
Small changes here can change the reported metric without changing the underlying predictions.
|
||||
|
||||
## Boundary handling matters
|
||||
|
||||
The evaluation base task can exclude events near clip boundaries through `ignore_start_end`.
|
||||
|
||||
This is useful when clip boundaries make matches ambiguous.
|
||||
|
||||
## Generic labels can matter in classification
|
||||
|
||||
Classification tasks can include or exclude generic targets depending on configuration.
|
||||
|
||||
That affects what counts as a valid class-level comparison.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Evaluate on a test set: {doc}`../tutorials/evaluate-on-a-test-set`
|
||||
- Evaluation config reference: {doc}`../reference/evaluation-config`
|
||||
- Model output and validation: {doc}`model-output-and-validation`
|
||||
@ -1,43 +0,0 @@
|
||||
# Extracted features and embeddings
|
||||
|
||||
The current API exposes a per-detection `features` vector.
|
||||
|
||||
Older BatDetect2 workflows also exposed concepts such as `cnn_feats`,
|
||||
`spec_features`, and `spec_slices`.
|
||||
|
||||
## What the current feature vector is
|
||||
|
||||
In the current stack, each retained detection can carry an internal feature
|
||||
representation produced by the model output pipeline.
|
||||
|
||||
This is useful for downstream exploration, comparison, and custom analysis.
|
||||
|
||||
## What these features are not
|
||||
|
||||
They are not automatically human-interpretable ecological variables.
|
||||
|
||||
They are also not a substitute for careful validation.
|
||||
|
||||
## Why people refer to them as embeddings
|
||||
|
||||
In practice, users often treat these feature vectors as embeddings because they
|
||||
can be used as dense learned representations of detections.
|
||||
|
||||
That usage is reasonable, but you should still treat them as model-derived
|
||||
internal representations whose meaning depends on the training setup.
|
||||
|
||||
## Legacy terminology versus current terminology
|
||||
|
||||
- legacy `cnn_feats` referred to CNN feature outputs in the older workflow,
|
||||
- legacy `spec_features` referred to lower-level extracted call features,
|
||||
- current `features` are the per-detection vectors attached to `Detection`
|
||||
objects.
|
||||
|
||||
These are related ideas, but not necessarily one-to-one replacements.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Inspect detection features in Python:
|
||||
{doc}`../how_to/inspect-detection-features-in-python`
|
||||
- Legacy migration guide:
|
||||
{doc}`../legacy/migration-guide`
|
||||
@ -1,19 +0,0 @@
|
||||
# Understanding
|
||||
|
||||
Understanding pages explain how BatDetect2 works, what its outputs mean, and how to reason about trade-offs.
|
||||
|
||||
Use this section when you want help interpreting the tool, not just running it.
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
what-batdetect2-predicts
|
||||
interpreting-formatted-outputs
|
||||
extracted-features-and-embeddings
|
||||
model-output-and-validation
|
||||
postprocessing-and-thresholds
|
||||
pipeline-overview
|
||||
preprocessing-consistency
|
||||
target-encoding-and-decoding
|
||||
evaluation-concepts-and-matching
|
||||
```
|
||||
@ -1,36 +0,0 @@
|
||||
# Interpreting formatted outputs
|
||||
|
||||
BatDetect2 can write predictions in several output formats.
|
||||
|
||||
Those formats are different views of the same underlying detections, not different model behaviors.
|
||||
|
||||
## Separate the underlying detection from the serialized file
|
||||
|
||||
Internally, the current stack works with clip-level detections containing geometry, detection score, class scores, and features.
|
||||
|
||||
Output formatters then serialize those detections in different ways.
|
||||
|
||||
## Raw outputs are richest
|
||||
|
||||
The `raw` format preserves the broadest structured view of detections and is a good default when you want to inspect or reload predictions later.
|
||||
|
||||
## Tabular outputs are for analysis convenience
|
||||
|
||||
The `parquet` format is convenient for data analysis workflows, but the tabular representation is only one projection of the underlying detection object.
|
||||
|
||||
## Legacy-shaped outputs are mainly for compatibility
|
||||
|
||||
The `batdetect2` formatter writes the older BatDetect2-style JSON shape.
|
||||
|
||||
Use it when you need compatibility with older downstream tools or workflows.
|
||||
|
||||
## The meaning does not come from the file extension
|
||||
|
||||
Do not assume that a `.json`, `.parquet`, or `.nc` file changes what the model predicted.
|
||||
|
||||
It changes how the prediction is packaged and how much detail is retained.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Output formats reference: {doc}`../reference/output-formats`
|
||||
- Outputs config reference: {doc}`../reference/outputs-config`
|
||||
@ -1,29 +0,0 @@
|
||||
# Model output and validation
|
||||
|
||||
BatDetect2 outputs model predictions, not ground truth. The same configuration
|
||||
can behave differently across recording conditions, species compositions, and
|
||||
acoustic environments.
|
||||
|
||||
## Why threshold choice matters
|
||||
|
||||
- Lower detection thresholds increase sensitivity but can increase false
|
||||
positives.
|
||||
- Higher thresholds reduce false positives but can miss faint calls.
|
||||
|
||||
No threshold is universally correct. The right setting depends on your survey
|
||||
objectives and tolerance for false positives versus missed detections.
|
||||
|
||||
## Why local validation is required
|
||||
|
||||
Model performance depends on how similar your data are to training data.
|
||||
Before ecological interpretation, validate predictions on a representative,
|
||||
locally reviewed subset.
|
||||
|
||||
Recommended validation checks:
|
||||
|
||||
1. Compare detection counts against expert-reviewed clips.
|
||||
2. Inspect species-level predictions for plausible confusion patterns.
|
||||
3. Repeat checks across sites, seasons, and recorder setups.
|
||||
|
||||
For practical threshold workflows, see
|
||||
{doc}`../how_to/tune-detection-threshold`.
|
||||
@ -1,34 +0,0 @@
|
||||
# Pipeline overview
|
||||
|
||||
batdetect2 processes recordings as a sequence of modules. Each stage has a
|
||||
clear role and configuration surface.
|
||||
|
||||
## End-to-end flow
|
||||
|
||||
1. Audio loading
|
||||
2. Preprocessing (waveform -> spectrogram)
|
||||
3. Detector forward pass
|
||||
4. Postprocessing (peaks, decoding, thresholds)
|
||||
5. Output formatting and export
|
||||
|
||||
## Why the modular design matters
|
||||
|
||||
The model, preprocessing, postprocessing, targets, and output formatting are
|
||||
configured separately. That makes it easier to:
|
||||
|
||||
- swap components without rewriting the whole pipeline,
|
||||
- keep experiments reproducible,
|
||||
- adapt workflows to new datasets.
|
||||
|
||||
## Core objects in the stack
|
||||
|
||||
- `BatDetect2API` orchestrates training, inference, and evaluation workflows.
|
||||
- `ModelConfig` defines architecture, preprocessing, postprocessing, and
|
||||
targets.
|
||||
- `Targets` controls event filtering, class encoding/decoding, and ROI mapping.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Preprocessing rationale: {doc}`preprocessing-consistency`
|
||||
- Postprocessing rationale: {doc}`postprocessing-and-thresholds`
|
||||
- Target rationale: {doc}`target-encoding-and-decoding`
|
||||
@ -1,43 +0,0 @@
|
||||
# Postprocessing and thresholds
|
||||
|
||||
After the detector runs on a spectrogram, the model output is still a set of
|
||||
dense prediction tensors. Postprocessing turns that into a final list of call
|
||||
detections with positions, sizes, and class scores.
|
||||
|
||||
## What postprocessing does
|
||||
|
||||
In broad terms, the pipeline:
|
||||
|
||||
1. suppresses nearby duplicate peaks,
|
||||
2. extracts candidate detections,
|
||||
3. reads size and class values at each detected location,
|
||||
4. decodes outputs into call-level predictions.
|
||||
|
||||
This is where score thresholds and output density limits are applied.
|
||||
|
||||
## Why thresholds matter
|
||||
|
||||
Thresholds control the balance between sensitivity and precision.
|
||||
|
||||
- Lower thresholds keep more detections, including weaker calls, but may add
|
||||
false positives.
|
||||
- Higher thresholds remove low-confidence detections, but may miss faint calls.
|
||||
|
||||
You can tune this behavior per run without retraining the model.
|
||||
|
||||
## Two common threshold controls
|
||||
|
||||
- `detection_threshold`: minimum score required to keep a detection.
|
||||
- `classification_threshold`: minimum class score used when assigning class
|
||||
labels.
|
||||
|
||||
Both settings shape the final output and should be validated on reviewed local
|
||||
data.
|
||||
|
||||
## Practical workflow
|
||||
|
||||
Tune thresholds on a representative subset first, then lock settings for the
|
||||
full analysis run.
|
||||
|
||||
- How-to: {doc}`../how_to/tune-detection-threshold`
|
||||
- CLI reference: {doc}`../reference/cli/predict`
|
||||
@ -1,36 +0,0 @@
|
||||
# Preprocessing consistency
|
||||
|
||||
Preprocessing consistency is one of the biggest factors behind stable model
|
||||
performance.
|
||||
|
||||
## Why consistency matters
|
||||
|
||||
The detector is trained on spectrograms produced by a specific preprocessing
|
||||
pipeline. If inference uses different settings, the model can see a shifted
|
||||
input distribution and performance may drop.
|
||||
|
||||
Typical mismatch sources:
|
||||
|
||||
- sample-rate differences,
|
||||
- changed frequency crop,
|
||||
- changed STFT window/hop,
|
||||
- changed spectrogram transforms.
|
||||
|
||||
## Practical implication
|
||||
|
||||
When possible, keep preprocessing settings aligned between:
|
||||
|
||||
- training,
|
||||
- evaluation,
|
||||
- deployment inference.
|
||||
|
||||
If you intentionally change preprocessing, treat this as a new experiment and
|
||||
re-validate on reviewed local data.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Configure audio preprocessing:
|
||||
{doc}`../how_to/configure-audio-preprocessing`
|
||||
- Configure spectrogram preprocessing:
|
||||
{doc}`../how_to/configure-spectrogram-preprocessing`
|
||||
- Preprocessing config reference: {doc}`../reference/preprocessing-config`
|
||||
@ -1,40 +0,0 @@
|
||||
# Target encoding and decoding
|
||||
|
||||
batdetect2 turns annotated sound events into training targets, then maps model
|
||||
outputs back into interpretable predictions.
|
||||
|
||||
## Encoding path (annotations -> model targets)
|
||||
|
||||
At training time, the target system:
|
||||
|
||||
1. checks whether an event belongs to the configured detection target,
|
||||
2. assigns a classification label (or none for non-specific class matches),
|
||||
3. maps event geometry into position and size targets.
|
||||
|
||||
This behaviour is configured through `TargetConfig`,
|
||||
`TargetClassConfig`, and ROI mapper settings.
|
||||
|
||||
## Decoding path (model outputs -> tags and geometry)
|
||||
|
||||
At inference time, class labels and ROI parameters are decoded back into
|
||||
annotation tags and geometry.
|
||||
|
||||
This makes outputs interpretable in the same conceptual space as your original
|
||||
annotations.
|
||||
|
||||
## Why this matters
|
||||
|
||||
Target definitions are not just metadata. They directly shape:
|
||||
|
||||
- what events are treated as positive examples,
|
||||
- which class names the model learns,
|
||||
- how geometry is represented and reconstructed.
|
||||
|
||||
Small changes here can alter both training outcomes and prediction semantics.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Configure detection target logic: {doc}`../how_to/configure-target-definitions`
|
||||
- Configure class mapping: {doc}`../how_to/define-target-classes`
|
||||
- Configure ROI mapping: {doc}`../how_to/configure-roi-mapping`
|
||||
- Target config reference: {doc}`../reference/targets-config-workflow`
|
||||
@ -1,45 +0,0 @@
|
||||
# What BatDetect2 predicts
|
||||
|
||||
BatDetect2 predicts call-level events, not recording-level truth.
|
||||
|
||||
For each retained detection, the current stack can expose:
|
||||
|
||||
- a geometry describing where the event sits in time-frequency space,
|
||||
- a detection score,
|
||||
- a class-score vector,
|
||||
- an internal feature vector.
|
||||
|
||||
## Detection score versus class scores
|
||||
|
||||
These are different outputs and should not be interpreted as the same thing.
|
||||
|
||||
- The detection score is about whether the event is kept as a detection.
|
||||
- The class-score vector ranks classes for that detected event.
|
||||
|
||||
A detection can be kept while still having uncertain class identity.
|
||||
|
||||
## Predictions are conditional on the workflow
|
||||
|
||||
The final output also depends on:
|
||||
|
||||
- preprocessing,
|
||||
- postprocessing,
|
||||
- thresholds,
|
||||
- target definitions,
|
||||
- output transforms.
|
||||
|
||||
That is why two runs can differ even when they use the same checkpoint.
|
||||
|
||||
## What BatDetect2 does not predict
|
||||
|
||||
BatDetect2 does not directly output ecological truth.
|
||||
|
||||
It also does not eliminate the need for local validation.
|
||||
|
||||
Use reviewed local data before making ecological claims.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Model output and validation: {doc}`model-output-and-validation`
|
||||
- Postprocessing and thresholds: {doc}`postprocessing-and-thresholds`
|
||||
- Interpreting formatted outputs: {doc}`interpreting-formatted-outputs`
|
||||
@ -1,81 +0,0 @@
|
||||
# FAQ
|
||||
|
||||
## Installation and setup
|
||||
|
||||
### Do I need Python knowledge to use batdetect2?
|
||||
|
||||
Not much.
|
||||
If you only want to run the model on your own recordings, you can use the CLI and follow the steps in {doc}`getting_started`.
|
||||
|
||||
Some command-line familiarity helps, but you do not need to write Python code for standard inference workflows.
|
||||
|
||||
### Are there plans for an R version?
|
||||
|
||||
Not currently.
|
||||
Output files are plain formats (for example CSV/JSON), so you can read and analyze them in R or other environments.
|
||||
|
||||
### I cannot get installation working. What should I do?
|
||||
|
||||
First, re-check {doc}`getting_started` and confirm your environment is active.
|
||||
If it still fails, open an issue with your OS, install method, and full error output: [GitHub Issues](https://github.com/macaodha/batdetect2/issues).
|
||||
|
||||
## Model behavior and performance
|
||||
|
||||
### The model does not perform well on my data
|
||||
|
||||
This usually means your data distribution differs from training data.
|
||||
The best next step is to validate on reviewed local data and then fine-tune/train on your own annotations if needed.
|
||||
|
||||
### The model confuses insects/noise with bats
|
||||
|
||||
This can happen, especially when recording conditions differ from training conditions.
|
||||
Threshold tuning and training with local annotations can improve results.
|
||||
|
||||
See {doc}`how_to/tune-detection-threshold`.
|
||||
|
||||
### The model struggles with feeding buzzes or social calls
|
||||
|
||||
This is a known limitation of available training data in some settings.
|
||||
If you have high-quality annotated examples, they are valuable for improving models.
|
||||
|
||||
### Calls in the same sequence are predicted as different species
|
||||
|
||||
Currently we do not do any sophisticated post processing on the results output by the model.
|
||||
We return a probability associated with each species for each call.
|
||||
You can use these predictions to clean up the noisy predictions for sequences of calls.
|
||||
|
||||
### Can I trust model outputs for biodiversity conclusions?
|
||||
|
||||
The models developed and shared as part of this repository should be used with caution.
|
||||
While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment.
|
||||
Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted.
|
||||
|
||||
### The pipeline is slow
|
||||
|
||||
Runtime depends on hardware and recording duration.
|
||||
GPU inference is often much faster than CPU.
|
||||
|
||||
## Training and scope
|
||||
|
||||
### Can I train on my own species set?
|
||||
|
||||
Yes.
|
||||
You can train/fine-tune with your own annotated data and species labels.
|
||||
|
||||
### Does this work on frequency-division or zero-crossing recordings?
|
||||
|
||||
Not directly.
|
||||
The workflow assumes audio can be converted to spectrograms from the raw waveform.
|
||||
|
||||
### Can this be used for non-bat bioacoustics (for example insects or birds)?
|
||||
|
||||
Potentially yes, but expect retraining and configuration changes.
|
||||
Open an issue if you want guidance for a specific use case.
|
||||
|
||||
## Usage and licensing
|
||||
|
||||
### Can I use this for commercial purposes?
|
||||
|
||||
No.
|
||||
This project is currently for non-commercial use.
|
||||
See the repository license for details.
|
||||
@ -1,91 +0,0 @@
|
||||
# Getting started
|
||||
|
||||
BatDetect2 can be used in two ways: through the `batdetect2` command line interface (CLI), or as the `batdetect2` Python package.
|
||||
The CLI route does not require coding.
|
||||
You run commands in the terminal and, in some cases, write configuration files.
|
||||
The Python route gives you more flexibility and lets you integrate the model into your own workflows or experiments.
|
||||
For most common use cases, both routes give you the same results.
|
||||
|
||||
## Try it out
|
||||
|
||||
If you want to try BatDetect2 before installing anything locally:
|
||||
|
||||
- [Hugging Face demo (UK species)](https://huggingface.co/spaces/macaodha/batdetect2)
|
||||
- [Google Colab notebook](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb)
|
||||
|
||||
## Installation
|
||||
|
||||
To use `batdetect2` on your machine, you need to install it first.
|
||||
We recommend using `uv` for that.
|
||||
`uv` is a tool that helps manage Python software cleanly, without mixing it into the rest of your machine.
|
||||
Install `uv` first by following the [installation instructions](https://docs.astral.sh/uv/getting-started/installation/).
|
||||
|
||||
### One-off usage
|
||||
|
||||
If you are not ready to install `batdetect2` permanently, you can try it with:
|
||||
|
||||
```bash
|
||||
uvx batdetect2
|
||||
```
|
||||
|
||||
This still downloads the code and dependencies and runs them on your machine, but the environment is temporary.
|
||||
|
||||
### Install the CLI
|
||||
|
||||
If you want the `batdetect2` CLI to always be available in your terminal, run:
|
||||
|
||||
```bash
|
||||
uv tool install batdetect2
|
||||
```
|
||||
|
||||
If you need to upgrade later:
|
||||
|
||||
```bash
|
||||
uv tool upgrade batdetect2
|
||||
```
|
||||
|
||||
Verify the CLI is available:
|
||||
|
||||
```bash
|
||||
batdetect2
|
||||
```
|
||||
|
||||
You can then run your first workflow.
|
||||
See {doc}`tutorials/run-inference-on-folder` for more details.
|
||||
|
||||
### Add it to your Python project
|
||||
|
||||
If you are using BatDetect2 from Python code and already manage your projects with `uv`, you can add it with:
|
||||
|
||||
```bash
|
||||
uv add batdetect2
|
||||
```
|
||||
|
||||
If you want to upgrade it later:
|
||||
|
||||
```bash
|
||||
uv add -U batdetect2
|
||||
```
|
||||
|
||||
#### Alternative with `pip`
|
||||
|
||||
If you prefer `pip`, you can use:
|
||||
|
||||
```bash
|
||||
pip install batdetect2
|
||||
```
|
||||
|
||||
It is a good idea to create a separate virtual environment first so this does not interfere with other Python environments.
|
||||
|
||||
```bash
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
## What's next
|
||||
|
||||
- Run your first workflow on a folder of recordings: {doc}`tutorials/run-inference-on-folder`
|
||||
- If you write code and want the Python route: {doc}`tutorials/integrate-with-a-python-pipeline`
|
||||
- For common practical tasks, go to {doc}`how_to/index`
|
||||
- For detailed command help, go to {doc}`reference/cli/index`
|
||||
- To understand the model and its outputs, go to {doc}`explanation/index`
|
||||
@ -1,112 +0,0 @@
|
||||
# How to choose a model
|
||||
|
||||
Use this guide when you want to choose which model checkpoint BatDetect2 loads.
|
||||
|
||||
You can choose a model in both the CLI and the Python API.
|
||||
|
||||
## Where you can choose the model
|
||||
|
||||
In the CLI, use `--model` with commands that load a checkpoint, including:
|
||||
|
||||
- `batdetect2 process`
|
||||
- `batdetect2 evaluate`
|
||||
- `batdetect2 train`
|
||||
- `batdetect2 finetune`
|
||||
|
||||
In Python, pass the model source to `BatDetect2API.from_checkpoint(...)`.
|
||||
|
||||
If you do not choose a model, BatDetect2 uses the built-in default UK model.
|
||||
|
||||
## Use a local checkpoint path
|
||||
|
||||
Use a local path when you already have a checkpoint file on disk.
|
||||
|
||||
CLI example:
|
||||
|
||||
```bash
|
||||
batdetect2 process directory \
|
||||
path/to/audio \
|
||||
path/to/outputs \
|
||||
--model path/to/model.ckpt
|
||||
```
|
||||
|
||||
Python example:
|
||||
|
||||
```python
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
|
||||
api = BatDetect2API.from_checkpoint("path/to/model.ckpt")
|
||||
```
|
||||
|
||||
## Use a bundled checkpoint alias
|
||||
|
||||
BatDetect2 also supports bundled checkpoint aliases.
|
||||
|
||||
The built-in UK model is available as `uk_same`.
|
||||
The alias `batdetect2_uk_same` also works.
|
||||
|
||||
CLI example:
|
||||
|
||||
```bash
|
||||
batdetect2 process directory \
|
||||
path/to/audio \
|
||||
path/to/outputs \
|
||||
--model uk_same
|
||||
```
|
||||
|
||||
Python example:
|
||||
|
||||
```python
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
|
||||
api = BatDetect2API.from_checkpoint("uk_same")
|
||||
```
|
||||
|
||||
## Use a Hugging Face URI
|
||||
|
||||
You can also load a checkpoint from Hugging Face with a URI like:
|
||||
|
||||
```text
|
||||
hf://owner/repo/path/to/model.ckpt
|
||||
```
|
||||
|
||||
This needs the optional Hugging Face dependency to be installed.
|
||||
For example, install it with `pip install batdetect2[huggingface]`.
|
||||
|
||||
CLI example:
|
||||
|
||||
```bash
|
||||
batdetect2 process directory \
|
||||
path/to/audio \
|
||||
path/to/outputs \
|
||||
--model hf://owner/repo/path/to/model.ckpt
|
||||
```
|
||||
|
||||
Python example:
|
||||
|
||||
```python
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
"hf://owner/repo/path/to/model.ckpt"
|
||||
)
|
||||
```
|
||||
|
||||
## Choose the right source
|
||||
|
||||
- Use a local path when you already have a checkpoint file.
|
||||
- Use an alias when you want one of the bundled models.
|
||||
- Use a Hugging Face URI when the checkpoint lives in a Hugging Face repo.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Run inference on a folder:
|
||||
{doc}`../tutorials/run-inference-on-folder`
|
||||
- `BatDetect2API` reference:
|
||||
{doc}`../reference/api`
|
||||
- Process command reference:
|
||||
{doc}`../reference/cli/predict`
|
||||
- Train a custom model:
|
||||
{doc}`../tutorials/train-a-custom-model`
|
||||
- Fine-tune from a checkpoint:
|
||||
{doc}`fine-tune-from-a-checkpoint`
|
||||
@ -1,71 +0,0 @@
|
||||
# How to choose an inference input mode
|
||||
|
||||
Use this guide to decide whether `process directory`, `process file_list`, or
|
||||
`process dataset` is the right entry point for your run.
|
||||
|
||||
## Use `process directory` when the recordings already live together
|
||||
|
||||
This is the simplest choice.
|
||||
|
||||
Use it when:
|
||||
|
||||
- your recordings are already organized in one directory tree,
|
||||
- you want BatDetect2 to discover audio files for you,
|
||||
- you are doing a first pass over a folder of recordings.
|
||||
|
||||
```bash
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs
|
||||
```
|
||||
|
||||
## Use `process file_list` when you need explicit control over the file set
|
||||
|
||||
Use it when:
|
||||
|
||||
- you want to run only a selected subset,
|
||||
- your files are spread across directories,
|
||||
- another tool has already produced the exact list of recordings to process.
|
||||
|
||||
The list file should contain one path per line.
|
||||
|
||||
```bash
|
||||
batdetect2 process file_list \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_files.txt \
|
||||
path/to/outputs
|
||||
```
|
||||
|
||||
## Use `process dataset` when your workflow is already annotation-set driven
|
||||
|
||||
Use it when:
|
||||
|
||||
- your project already has a `soundevent` annotation set,
|
||||
- you want prediction runs aligned with that annotation metadata,
|
||||
- you want BatDetect2 to resolve recording paths from the annotation set.
|
||||
|
||||
```bash
|
||||
batdetect2 process dataset \
|
||||
path/to/model.ckpt \
|
||||
path/to/annotation_set.json \
|
||||
path/to/outputs
|
||||
```
|
||||
|
||||
The dataset command reads a `soundevent` annotation set and extracts unique
|
||||
recording paths before inference.
|
||||
|
||||
## Rule of thumb
|
||||
|
||||
- Start with `directory` for the easiest first run.
|
||||
- Use `file_list` when selection matters.
|
||||
- Use `dataset` when the rest of your workflow is already dataset-based.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Run batch predictions:
|
||||
{doc}`run-batch-predictions`
|
||||
- Tune inference clipping:
|
||||
{doc}`tune-inference-clipping`
|
||||
- Process command reference:
|
||||
{doc}`../reference/cli/predict`
|
||||
@ -1,74 +0,0 @@
|
||||
# How to choose and configure evaluation tasks
|
||||
|
||||
Use this guide when the default evaluation tasks do not match the question you
|
||||
want to answer.
|
||||
|
||||
## Know the default first
|
||||
|
||||
By default, BatDetect2 evaluation starts with:
|
||||
|
||||
- sound event detection,
|
||||
- sound event classification.
|
||||
|
||||
Those are good defaults for many projects, but not for all of them.
|
||||
|
||||
## Choose the task that matches the question
|
||||
|
||||
Common built-in task families include:
|
||||
|
||||
- `sound_event_detection`
|
||||
- `sound_event_classification`
|
||||
- `top_class_detection`
|
||||
- `clip_detection`
|
||||
- `clip_classification`
|
||||
|
||||
Choose based on the question you care about.
|
||||
|
||||
- Use sound-event tasks when you care about individual call events.
|
||||
- Use clip tasks when you care about clip-level presence or clip-level class
|
||||
evidence.
|
||||
- Use top-class detection when you want matching based on the highest-scoring
|
||||
class per detection.
|
||||
|
||||
## Configure tasks in `EvaluationConfig`
|
||||
|
||||
Example:
|
||||
|
||||
```yaml
|
||||
tasks:
|
||||
- name: sound_event_detection
|
||||
prefix: detection
|
||||
affinity_threshold: 0.0
|
||||
strict_match: true
|
||||
- name: clip_classification
|
||||
prefix: clip_classification
|
||||
```
|
||||
|
||||
Pass the config with:
|
||||
|
||||
```bash
|
||||
batdetect2 evaluate \
|
||||
path/to/test_dataset.yaml \
|
||||
--model path/to/model.ckpt \
|
||||
--base-dir path/to/project_root \
|
||||
--evaluation-config path/to/evaluation.yaml
|
||||
```
|
||||
|
||||
Include `--base-dir` when the dataset config resolves recordings through
|
||||
relative paths.
|
||||
|
||||
## Change one thing at a time
|
||||
|
||||
When comparing models or settings, avoid changing task definitions, thresholds,
|
||||
matching behavior, and datasets all at once.
|
||||
|
||||
Otherwise it becomes hard to explain why the metric changed.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Evaluation tutorial:
|
||||
{doc}`../tutorials/evaluate-on-a-test-set`
|
||||
- Evaluation config reference:
|
||||
{doc}`../reference/evaluation-config`
|
||||
- Evaluation concepts:
|
||||
{doc}`../explanation/evaluation-concepts-and-matching`
|
||||
@ -1,53 +0,0 @@
|
||||
# How to configure an AOEF dataset source
|
||||
|
||||
Use this guide when your annotations are stored in AOEF/soundevent JSON files,
|
||||
including exports from Whombat.
|
||||
|
||||
## 1) Add an AOEF source entry
|
||||
|
||||
In your dataset config, add a source with `format: aoef`.
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
- name: my_aoef_source
|
||||
format: aoef
|
||||
audio_dir: /path/to/audio
|
||||
annotations_path: /path/to/annotations.soundevent.json
|
||||
```
|
||||
|
||||
## 2) Choose filtering behavior for annotation projects
|
||||
|
||||
If `annotations_path` is an `AnnotationProject`, you can filter by task state.
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
- name: whombat_verified
|
||||
format: aoef
|
||||
audio_dir: /path/to/audio
|
||||
annotations_path: /path/to/project_export.aoef
|
||||
filter:
|
||||
only_completed: true
|
||||
only_verified: true
|
||||
exclude_issues: true
|
||||
```
|
||||
|
||||
If you omit `filter`, default project filtering is applied.
|
||||
|
||||
To disable filtering for project files:
|
||||
|
||||
```yaml
|
||||
filter: null
|
||||
```
|
||||
|
||||
## 3) Check that the source loads
|
||||
|
||||
Run a summary on your dataset config:
|
||||
|
||||
```bash
|
||||
batdetect2 data summary path/to/dataset.yaml
|
||||
```
|
||||
|
||||
## 4) Continue to training or evaluation
|
||||
|
||||
- For training: {doc}`../tutorials/train-a-custom-model`
|
||||
- For field-level reference: {doc}`../reference/data-sources`
|
||||
@ -1,66 +0,0 @@
|
||||
# How to configure audio preprocessing
|
||||
|
||||
Use this guide to set sample-rate and waveform-level preprocessing behaviour.
|
||||
|
||||
## 1) Set audio loader settings
|
||||
|
||||
The audio loader config controls resampling.
|
||||
|
||||
```yaml
|
||||
samplerate: 256000
|
||||
resample:
|
||||
enabled: true
|
||||
method: poly
|
||||
```
|
||||
|
||||
If your recordings are already at the expected sample rate, you can disable
|
||||
resampling.
|
||||
|
||||
```yaml
|
||||
samplerate: 256000
|
||||
resample:
|
||||
enabled: false
|
||||
```
|
||||
|
||||
## 2) Set waveform transforms in preprocessing config
|
||||
|
||||
Waveform transforms are configured in `preprocess.audio_transforms`.
|
||||
|
||||
```yaml
|
||||
preprocess:
|
||||
audio_transforms:
|
||||
- name: center_audio
|
||||
- name: scale_audio
|
||||
- name: fix_duration
|
||||
duration: 0.5
|
||||
```
|
||||
|
||||
Available built-ins:
|
||||
|
||||
- `center_audio`
|
||||
- `scale_audio`
|
||||
- `fix_duration`
|
||||
|
||||
## 3) Use the config in your workflow
|
||||
|
||||
For CLI inference/evaluation, use `--audio-config`.
|
||||
|
||||
```bash
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs \
|
||||
--audio-config path/to/audio.yaml
|
||||
```
|
||||
|
||||
## 4) Verify quickly on a small subset
|
||||
|
||||
Run on a small folder first and confirm that outputs and runtime are as expected
|
||||
before full-batch runs.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Spectrogram settings:
|
||||
{doc}`configure-spectrogram-preprocessing`
|
||||
- Preprocessing config reference:
|
||||
{doc}`../reference/preprocessing-config`
|
||||
@ -1,57 +0,0 @@
|
||||
# How to configure ROI mapping
|
||||
|
||||
Use this guide to control how annotation geometry is encoded into training
|
||||
targets and decoded back into boxes.
|
||||
|
||||
## 1) Set the default ROI mapper
|
||||
|
||||
The default mapper is `anchor_bbox`.
|
||||
|
||||
```yaml
|
||||
roi:
|
||||
default:
|
||||
name: anchor_bbox
|
||||
anchor: bottom-left
|
||||
time_scale: 1000.0
|
||||
frequency_scale: 0.001163
|
||||
```
|
||||
|
||||
## 2) Choose an anchor strategy
|
||||
|
||||
Typical options include `bottom-left` and `center`.
|
||||
|
||||
- `bottom-left` is the current default.
|
||||
- `center` can be easier to reason about in some workflows.
|
||||
|
||||
## 3) Set scale factors intentionally
|
||||
|
||||
- `time_scale` controls width scaling.
|
||||
- `frequency_scale` controls height scaling.
|
||||
|
||||
Use values that are consistent with your model setup and keep them fixed when
|
||||
comparing experiments.
|
||||
|
||||
## 4) (Optional) override ROI mapping for specific classes
|
||||
|
||||
Add class-specific mappers under `roi.overrides`.
|
||||
|
||||
```yaml
|
||||
roi:
|
||||
default:
|
||||
name: anchor_bbox
|
||||
anchor: bottom-left
|
||||
time_scale: 1000.0
|
||||
frequency_scale: 0.001163
|
||||
overrides:
|
||||
species_x:
|
||||
name: anchor_bbox
|
||||
anchor: center
|
||||
time_scale: 1000.0
|
||||
frequency_scale: 0.001163
|
||||
```
|
||||
|
||||
## Related pages
|
||||
|
||||
- Target definitions: {doc}`configure-target-definitions`
|
||||
- Class definitions: {doc}`define-target-classes`
|
||||
- Target encoding overview: {doc}`../explanation/target-encoding-and-decoding`
|
||||
@ -1,59 +0,0 @@
|
||||
# How to configure spectrogram preprocessing
|
||||
|
||||
Use this guide to set STFT, frequency range, and spectrogram transforms.
|
||||
|
||||
## 1) Configure STFT and frequency range
|
||||
|
||||
```yaml
|
||||
preprocess:
|
||||
stft:
|
||||
window_duration: 0.002
|
||||
window_overlap: 0.75
|
||||
window_fn: hann
|
||||
frequencies:
|
||||
min_freq: 10000
|
||||
max_freq: 120000
|
||||
```
|
||||
|
||||
## 2) Configure spectrogram transforms
|
||||
|
||||
`spectrogram_transforms` are applied in order.
|
||||
|
||||
```yaml
|
||||
preprocess:
|
||||
spectrogram_transforms:
|
||||
- name: pcen
|
||||
time_constant: 0.4
|
||||
gain: 0.98
|
||||
bias: 2.0
|
||||
power: 0.5
|
||||
- name: spectral_mean_subtraction
|
||||
- name: scale_amplitude
|
||||
scale: db
|
||||
```
|
||||
|
||||
Common built-ins:
|
||||
|
||||
- `pcen`
|
||||
- `spectral_mean_subtraction`
|
||||
- `scale_amplitude` (`db` or `power`)
|
||||
- `peak_normalize`
|
||||
|
||||
## 3) Configure output size
|
||||
|
||||
```yaml
|
||||
preprocess:
|
||||
size:
|
||||
height: 128
|
||||
resize_factor: 0.5
|
||||
```
|
||||
|
||||
## 4) Keep train and inference settings aligned
|
||||
|
||||
Use the same preprocessing setup for training and prediction whenever possible.
|
||||
Large mismatches can degrade model performance.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Why consistency matters: {doc}`../explanation/preprocessing-consistency`
|
||||
- Preprocessing config reference: {doc}`../reference/preprocessing-config`
|
||||
@ -1,58 +0,0 @@
|
||||
# How to configure target definitions
|
||||
|
||||
Use this guide to define which annotated sound events are considered valid
|
||||
detection targets.
|
||||
|
||||
## 1) Start from a targets config file
|
||||
|
||||
```yaml
|
||||
detection_target:
|
||||
name: bat
|
||||
match_if:
|
||||
name: has_tag
|
||||
tag:
|
||||
key: call_type
|
||||
value: Echolocation
|
||||
assign_tags:
|
||||
- key: call_type
|
||||
value: Echolocation
|
||||
- key: order
|
||||
value: Chiroptera
|
||||
```
|
||||
|
||||
`match_if` decides whether an annotation is included in the detection target.
|
||||
|
||||
## 2) Use condition combinators when needed
|
||||
|
||||
You can combine conditions with `all_of`, `any_of`, and `not`.
|
||||
|
||||
```yaml
|
||||
detection_target:
|
||||
name: bat
|
||||
match_if:
|
||||
name: all_of
|
||||
conditions:
|
||||
- name: has_tag
|
||||
tag:
|
||||
key: call_type
|
||||
value: Echolocation
|
||||
- name: not
|
||||
condition:
|
||||
name: has_any_tag
|
||||
tags:
|
||||
- key: call_type
|
||||
value: Social
|
||||
- key: class
|
||||
value: Not Bat
|
||||
```
|
||||
|
||||
## 3) Verify with a small sample first
|
||||
|
||||
Before full training, inspect a small annotation subset and confirm that the
|
||||
selection logic keeps the events you expect.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Class mapping: {doc}`define-target-classes`
|
||||
- ROI mapping: {doc}`configure-roi-mapping`
|
||||
- Targets reference: {doc}`../reference/targets-config-workflow`
|
||||
@ -1,59 +0,0 @@
|
||||
# How to define target classes
|
||||
|
||||
Use this guide to map annotations to classification labels used during
|
||||
training.
|
||||
|
||||
## 1) Add classification target entries
|
||||
|
||||
Each entry defines a class name and matching tags.
|
||||
|
||||
```yaml
|
||||
classification_targets:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: class
|
||||
value: Pipistrellus pipistrellus
|
||||
- name: pippyg
|
||||
tags:
|
||||
- key: class
|
||||
value: Pipistrellus pygmaeus
|
||||
```
|
||||
|
||||
## 2) Use `assign_tags` to control decoded output tags
|
||||
|
||||
If you want prediction output tags to differ from matching tags, set
|
||||
`assign_tags` explicitly.
|
||||
|
||||
```yaml
|
||||
classification_targets:
|
||||
- name: pipistrelle_group
|
||||
tags:
|
||||
- key: class
|
||||
value: Pipistrellus pipistrellus
|
||||
assign_tags:
|
||||
- key: genus
|
||||
value: Pipistrellus
|
||||
```
|
||||
|
||||
## 3) Use `match_if` for complex class rules
|
||||
|
||||
For advanced conditions, use `match_if` instead of `tags`.
|
||||
|
||||
```yaml
|
||||
classification_targets:
|
||||
- name: long_call
|
||||
match_if:
|
||||
name: duration
|
||||
operator: gt
|
||||
seconds: 0.02
|
||||
```
|
||||
|
||||
## 4) Confirm class names are unique
|
||||
|
||||
`classification_targets.name` values must be unique.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Detection-target filtering: {doc}`configure-target-definitions`
|
||||
- ROI mapping: {doc}`configure-roi-mapping`
|
||||
- Targets config reference: {doc}`../reference/targets-config-workflow`
|
||||
@ -1,45 +0,0 @@
|
||||
# How to fine-tune from a checkpoint
|
||||
|
||||
Use this guide when you want to continue from an existing checkpoint instead of training a fresh model config.
|
||||
|
||||
## Use `--model` for checkpoint-based training
|
||||
|
||||
Pass a checkpoint with `--model`.
|
||||
|
||||
Do not combine `--model` with `--model-config`.
|
||||
|
||||
```bash
|
||||
batdetect2 train \
|
||||
path/to/train_dataset.yaml \
|
||||
--val-dataset path/to/val_dataset.yaml \
|
||||
--model path/to/model.ckpt \
|
||||
--training-config path/to/training.yaml
|
||||
```
|
||||
|
||||
## Keep targets and preprocessing aligned
|
||||
|
||||
If you override targets or audio-related settings while fine-tuning, validate that they still match the checkpoint and your dataset.
|
||||
|
||||
Mismatches here can produce confusing failures or invalid comparisons.
|
||||
|
||||
## Decide what question the fine-tune should answer
|
||||
|
||||
Common fine-tuning goals are:
|
||||
|
||||
- adapting to local recording conditions,
|
||||
- adapting to a new label set,
|
||||
- improving performance on a narrower deployment context.
|
||||
|
||||
Make that goal explicit before comparing results.
|
||||
|
||||
## Evaluate after fine-tuning
|
||||
|
||||
Always compare the fine-tuned checkpoint against a held-out dataset.
|
||||
|
||||
Use the same evaluation setup when comparing before and after.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Training tutorial: {doc}`../tutorials/train-a-custom-model`
|
||||
- Evaluate a test set: {doc}`../tutorials/evaluate-on-a-test-set`
|
||||
- Train command reference: {doc}`../reference/cli/train`
|
||||
@ -1,66 +0,0 @@
|
||||
# How to import legacy batdetect2 annotations
|
||||
|
||||
Use this guide if your annotations are in older batdetect2 JSON formats.
|
||||
|
||||
Two legacy formats are supported:
|
||||
|
||||
- `batdetect2`: one annotation JSON file per recording
|
||||
- `batdetect2_file`: one merged JSON file for many recordings
|
||||
|
||||
## 1) Choose the correct source format
|
||||
|
||||
Directory-based annotations (`format: batdetect2`):
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
- name: legacy_per_file
|
||||
format: batdetect2
|
||||
audio_dir: /path/to/audio
|
||||
annotations_dir: /path/to/annotation_json_dir
|
||||
```
|
||||
|
||||
Merged annotation file (`format: batdetect2_file`):
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
- name: legacy_merged
|
||||
format: batdetect2_file
|
||||
audio_dir: /path/to/audio
|
||||
annotations_path: /path/to/merged_annotations.json
|
||||
```
|
||||
|
||||
## 2) Set optional legacy filters
|
||||
|
||||
Legacy filters are based on `annotated` and `issues` flags.
|
||||
|
||||
```yaml
|
||||
filter:
|
||||
only_annotated: true
|
||||
exclude_issues: true
|
||||
```
|
||||
|
||||
To load all entries regardless of flags:
|
||||
|
||||
```yaml
|
||||
filter: null
|
||||
```
|
||||
|
||||
## 3) Validate and convert if needed
|
||||
|
||||
Check loaded records:
|
||||
|
||||
```bash
|
||||
batdetect2 data summary path/to/dataset.yaml
|
||||
```
|
||||
|
||||
Convert to annotation-set output for downstream tooling:
|
||||
|
||||
```bash
|
||||
batdetect2 data convert path/to/dataset.yaml --output path/to/output.json
|
||||
```
|
||||
|
||||
## 4) Continue with current workflows
|
||||
|
||||
- Run predictions: {doc}`run-batch-predictions`
|
||||
- Train on imported data: {doc}`../tutorials/train-a-custom-model`
|
||||
- Field-level reference: {doc}`../reference/data-sources`
|
||||
@ -1,30 +0,0 @@
|
||||
# How-to Guides
|
||||
|
||||
How-to guides help you answer practical questions once you are past the first
|
||||
tutorial.
|
||||
|
||||
Use this section when you already know the basic workflow and want help with one
|
||||
specific task.
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
choose-a-model
|
||||
choose-an-inference-input-mode
|
||||
run-batch-predictions
|
||||
tune-inference-clipping
|
||||
tune-detection-threshold
|
||||
inspect-class-scores-in-python
|
||||
inspect-detection-features-in-python
|
||||
save-predictions-in-different-output-formats
|
||||
fine-tune-from-a-checkpoint
|
||||
choose-and-configure-evaluation-tasks
|
||||
interpret-evaluation-outputs
|
||||
configure-aoef-dataset
|
||||
import-legacy-batdetect2-annotations
|
||||
configure-audio-preprocessing
|
||||
configure-spectrogram-preprocessing
|
||||
configure-target-definitions
|
||||
define-target-classes
|
||||
configure-roi-mapping
|
||||
```
|
||||
@ -1,44 +0,0 @@
|
||||
# How to inspect class scores in Python
|
||||
|
||||
Use this guide when you need more than the top class label for each detection.
|
||||
|
||||
## Get the ranked class scores
|
||||
|
||||
`BatDetect2API.get_class_scores` returns `(class_name, score)` pairs for one detection.
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
|
||||
api = BatDetect2API.from_checkpoint(Path("path/to/model.ckpt"))
|
||||
prediction = api.process_file(Path("path/to/audio.wav"))
|
||||
|
||||
for detection in prediction.detections:
|
||||
print("detection score:", detection.detection_score)
|
||||
for class_name, score in api.get_class_scores(detection):
|
||||
print(class_name, score)
|
||||
```
|
||||
|
||||
## Separate detection confidence from class ranking
|
||||
|
||||
Keep these two ideas separate:
|
||||
|
||||
- `detection_score` tells you how strongly the model kept the event as a detection,
|
||||
- `class_scores` tell you how the model ranked classes for that detected event.
|
||||
|
||||
A detection can have a reasonable detection score while still having uncertain class ranking.
|
||||
|
||||
## Hide the top class if needed
|
||||
|
||||
If you want to inspect only the alternatives, pass `include_top_class=False`.
|
||||
|
||||
```python
|
||||
api.get_class_scores(detection, include_top_class=False)
|
||||
```
|
||||
|
||||
## Related pages
|
||||
|
||||
- Python tutorial: {doc}`../tutorials/integrate-with-a-python-pipeline`
|
||||
- API reference: {doc}`../reference/api`
|
||||
- Understanding scores: {doc}`../explanation/what-batdetect2-predicts`
|
||||
@ -1,49 +0,0 @@
|
||||
# How to inspect detection features in Python
|
||||
|
||||
Use this guide when you want the per-detection feature vectors exposed by the current API.
|
||||
|
||||
## Get the feature vector for one detection
|
||||
|
||||
Each detection carries a `features` vector.
|
||||
|
||||
The API exposes it through `get_detection_features`.
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
|
||||
api = BatDetect2API.from_checkpoint(Path("path/to/model.ckpt"))
|
||||
prediction = api.process_file(Path("path/to/audio.wav"))
|
||||
|
||||
for detection in prediction.detections:
|
||||
features = api.get_detection_features(detection)
|
||||
print(features.shape)
|
||||
```
|
||||
|
||||
## Use features for exploration, not as ground truth labels
|
||||
|
||||
These features are internal model representations attached to detections.
|
||||
|
||||
They can be useful for:
|
||||
|
||||
- exploratory visualization,
|
||||
- downstream clustering,
|
||||
- comparison across detections,
|
||||
- building extra analysis pipelines.
|
||||
|
||||
They do not replace validation.
|
||||
|
||||
They also do not automatically have a one-to-one interpretation as ecological variables.
|
||||
|
||||
## Save predictions with features included
|
||||
|
||||
If you need features on disk, use an output format that supports them, such as `raw` or `parquet`, and keep feature inclusion enabled.
|
||||
|
||||
See {doc}`save-predictions-in-different-output-formats`.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Understanding features and embeddings: {doc}`../explanation/extracted-features-and-embeddings`
|
||||
- Output formats reference: {doc}`../reference/output-formats`
|
||||
- API reference: {doc}`../reference/api`
|
||||
@ -1,41 +0,0 @@
|
||||
# How to interpret evaluation outputs
|
||||
|
||||
Use this guide after `batdetect2 evaluate` has written metrics and plots to disk.
|
||||
|
||||
## Start by identifying the task
|
||||
|
||||
Do not interpret a metric until you know which evaluation task produced it.
|
||||
|
||||
For example, a detection score and a clip-classification score answer different questions.
|
||||
|
||||
## Read the output directory as a bundle
|
||||
|
||||
Treat the evaluation output directory as one package:
|
||||
|
||||
- metrics,
|
||||
- plots,
|
||||
- saved predictions,
|
||||
- config context.
|
||||
|
||||
Do not lift a single number out of context and treat it as the whole story.
|
||||
|
||||
## Look for failure patterns, not just overall averages
|
||||
|
||||
Check:
|
||||
|
||||
- whether errors concentrate in certain taxa,
|
||||
- whether specific sites or recorder setups behave differently,
|
||||
- whether threshold choices are driving the result,
|
||||
- whether predictions are near clip boundaries or matching thresholds.
|
||||
|
||||
## Keep validation and deployment questions separate
|
||||
|
||||
A model can look good on one task and still be a poor fit for your deployment question.
|
||||
|
||||
Interpret the outputs in relation to the real use case, not only the easiest metric to report.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Evaluation tutorial: {doc}`../tutorials/evaluate-on-a-test-set`
|
||||
- Evaluation concepts: {doc}`../explanation/evaluation-concepts-and-matching`
|
||||
- Model output and validation: {doc}`../explanation/model-output-and-validation`
|
||||
@ -1,62 +0,0 @@
|
||||
# How to run batch processing
|
||||
|
||||
This guide shows practical command patterns for directory-based and file-list
|
||||
processing runs.
|
||||
|
||||
Use it after you already know which input mode you want and need concrete
|
||||
command templates for a repeatable batch run.
|
||||
|
||||
## Process a directory
|
||||
|
||||
```bash
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs
|
||||
```
|
||||
|
||||
Use this when BatDetect2 should discover the audio files for you.
|
||||
|
||||
## Process a file list
|
||||
|
||||
```bash
|
||||
batdetect2 process file_list \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_files.txt \
|
||||
path/to/outputs
|
||||
```
|
||||
|
||||
Use this when another part of your workflow already produced the exact recording
|
||||
list to process.
|
||||
|
||||
## Process a dataset config
|
||||
|
||||
```bash
|
||||
batdetect2 process dataset \
|
||||
path/to/model.ckpt \
|
||||
path/to/annotation_set.json \
|
||||
path/to/outputs
|
||||
```
|
||||
|
||||
Use this when your project already has a `soundevent` annotation set and you
|
||||
want to extract unique recording paths from it.
|
||||
|
||||
## Useful options
|
||||
|
||||
- `--batch-size` to control throughput.
|
||||
- `--workers` to set data-loading parallelism.
|
||||
- `--format` to select output format.
|
||||
- `--inference-config` to control clipping and loader behavior.
|
||||
- `--outputs-config` to control serialization and output transforms.
|
||||
- `--detection-threshold` to override the detection threshold for a run.
|
||||
|
||||
## Practical workflow
|
||||
|
||||
For large runs:
|
||||
|
||||
1. test the command on a small reviewed subset,
|
||||
2. lock the config files and command shape,
|
||||
3. write outputs to a dedicated directory per run,
|
||||
4. record the checkpoint, config paths, and thresholds used.
|
||||
|
||||
For complete option details, see {doc}`../reference/cli/predict`.
|
||||
@ -1,95 +0,0 @@
|
||||
# How to save predictions in different output formats
|
||||
|
||||
Use this guide when you need BatDetect2 outputs in a specific representation for
|
||||
downstream tools.
|
||||
|
||||
## Choose the format that matches the job
|
||||
|
||||
Current built-in output formats include:
|
||||
|
||||
- `raw`:
|
||||
one NetCDF file per clip, best for rich structured outputs,
|
||||
- `parquet`:
|
||||
tabular storage for data analysis workflows,
|
||||
- `soundevent`:
|
||||
prediction-set JSON for soundevent-style tooling,
|
||||
- `batdetect2`:
|
||||
legacy-compatible per-recording JSON and CSV outputs.
|
||||
|
||||
## Select a format from the CLI
|
||||
|
||||
Use `--format` for quick experiments.
|
||||
|
||||
```bash
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs \
|
||||
--format parquet
|
||||
```
|
||||
|
||||
## Use an outputs config for repeatable runs
|
||||
|
||||
Use an outputs config when you want reproducible control over format and
|
||||
transforms.
|
||||
|
||||
Example:
|
||||
|
||||
```yaml
|
||||
format:
|
||||
name: raw
|
||||
include_class_scores: true
|
||||
include_features: true
|
||||
include_geometry: true
|
||||
transform:
|
||||
detection_transforms: []
|
||||
clip_transforms: []
|
||||
```
|
||||
|
||||
Run with:
|
||||
|
||||
```bash
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs \
|
||||
--outputs-config path/to/outputs.yaml
|
||||
```
|
||||
|
||||
## Pick the simplest useful format
|
||||
|
||||
- Use `raw` if you want the richest output surface and easy round-tripping.
|
||||
- Use `parquet` if you want tabular analysis in Python or data-lake workflows.
|
||||
- Use `soundevent` if you want prediction-set JSON.
|
||||
- Use `batdetect2` when you need legacy BatDetect2-style outputs.
|
||||
|
||||
## Enable legacy CNN feature CSVs
|
||||
|
||||
The `batdetect2` formatter can also write the legacy CNN feature sidecar CSVs.
|
||||
This is controlled through the outputs config.
|
||||
|
||||
Example:
|
||||
|
||||
```yaml
|
||||
format:
|
||||
name: batdetect2
|
||||
write_cnn_features_csv: true
|
||||
transform:
|
||||
detection_transforms: []
|
||||
clip_transforms: []
|
||||
```
|
||||
|
||||
When enabled, BatDetect2 writes:
|
||||
|
||||
- one `.json` file per recording,
|
||||
- one detection `.csv` file per recording,
|
||||
- one `_cnn_features.csv` file per recording when detections are present.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Outputs config reference:
|
||||
{doc}`../reference/outputs-config`
|
||||
- Output formats reference:
|
||||
{doc}`../reference/output-formats`
|
||||
- Output transforms reference:
|
||||
{doc}`../reference/output-transforms`
|
||||
@ -1,51 +0,0 @@
|
||||
# How to tune detection threshold
|
||||
|
||||
Use this guide to compare detection outputs at different threshold values.
|
||||
|
||||
The goal is not to find a universal threshold.
|
||||
|
||||
The goal is to choose a threshold that fits your reviewed local data and the
|
||||
project trade-off between missed calls and false positives.
|
||||
|
||||
## 1) Start with a baseline run
|
||||
|
||||
Run an initial prediction workflow and keep outputs in a dedicated folder.
|
||||
|
||||
## 2) Sweep threshold values
|
||||
|
||||
Run `process` multiple times with different thresholds (for example `0.1`,
|
||||
`0.3`, `0.5`) and compare output counts and quality on the same validation
|
||||
subset.
|
||||
|
||||
```bash
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs_thr_03 \
|
||||
--detection-threshold 0.3
|
||||
```
|
||||
|
||||
Keep each threshold run in a separate output directory.
|
||||
|
||||
That makes it easier to compare counts and inspect example files without mixing
|
||||
results.
|
||||
|
||||
## 3) Validate against known calls
|
||||
|
||||
Use files with trusted annotations or expert review to select a threshold that
|
||||
fits your project goals.
|
||||
|
||||
Check both:
|
||||
|
||||
- obvious false positives,
|
||||
- obvious missed calls.
|
||||
|
||||
If class interpretation matters downstream, inspect class ranking behavior as
|
||||
well, not just detection counts.
|
||||
|
||||
## 4) Record your chosen setting
|
||||
|
||||
Write down the chosen threshold and rationale so analyses are reproducible.
|
||||
|
||||
For conceptual trade-offs, see
|
||||
{doc}`../explanation/model-output-and-validation`.
|
||||
@ -1,73 +0,0 @@
|
||||
# How to tune inference clipping
|
||||
|
||||
Use this guide when long recordings need to be split into smaller clips during
|
||||
inference.
|
||||
|
||||
## What clipping controls
|
||||
|
||||
`InferenceConfig.clipping` controls how recordings are split before batching.
|
||||
|
||||
Key fields are:
|
||||
|
||||
- `duration`:
|
||||
clip duration in seconds,
|
||||
- `overlap`:
|
||||
overlap between adjacent clips,
|
||||
- `max_empty`:
|
||||
how much empty padding is allowed,
|
||||
- `discard_empty`:
|
||||
whether empty clips are dropped.
|
||||
|
||||
## Start from the defaults
|
||||
|
||||
Use the built-in clipping behavior first unless you already know you need
|
||||
something else.
|
||||
|
||||
Only tune clipping when:
|
||||
|
||||
- recordings are much longer than your normal working set,
|
||||
- you are seeing edge effects around calls,
|
||||
- you need tighter control over throughput or padding behavior.
|
||||
|
||||
## Override clipping with an inference config
|
||||
|
||||
Create an inference config file and pass it to `process` or `evaluate`.
|
||||
|
||||
Example:
|
||||
|
||||
```yaml
|
||||
clipping:
|
||||
enabled: true
|
||||
duration: 0.5
|
||||
overlap: 0.1
|
||||
max_empty: 0.0
|
||||
discard_empty: true
|
||||
loader:
|
||||
batch_size: 8
|
||||
```
|
||||
|
||||
Run with:
|
||||
|
||||
```bash
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs \
|
||||
--inference-config path/to/inference.yaml
|
||||
```
|
||||
|
||||
## Validate clipping changes on a small reviewed subset
|
||||
|
||||
Changing clipping changes what the model sees per batch and can change how
|
||||
events near clip boundaries behave.
|
||||
|
||||
Check a reviewed subset before applying clipping changes to a full project.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Inference config reference:
|
||||
{doc}`../reference/inference-config`
|
||||
- Run batch predictions:
|
||||
{doc}`run-batch-predictions`
|
||||
- Understanding the pipeline:
|
||||
{doc}`../explanation/pipeline-overview`
|
||||
@ -1,114 +0,0 @@
|
||||
# Home
|
||||
|
||||
Welcome to the BatDetect2 documentation.
|
||||
|
||||
## What is BatDetect2?
|
||||
|
||||
`batdetect2` is a deep learning model and software package for detecting and
|
||||
classifying bat echolocation calls in high-frequency audio recordings.
|
||||
|
||||
You can use it from the command line or from Python, depending on how much
|
||||
control you need.
|
||||
|
||||
In practice, BatDetect2 scans a recording, finds sounds that look like bat
|
||||
calls, and returns one result for each detected call.
|
||||
Each result can include where the call appears in the recording, shown as a box
|
||||
with start and end time and the lowest and highest frequency, how confident the
|
||||
model is that it found a call, and how strongly it matches the available
|
||||
classes.
|
||||
|
||||
The built-in default model is trained for 17 UK species.
|
||||
The package also supports custom training, fine-tuning, evaluation, and more
|
||||
advanced workflows from Python.
|
||||
|
||||
For more detail on the underlying approach, see the pre-print:
|
||||
[Towards a General Approach for Bat Echolocation Detection and Classification](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)
|
||||
|
||||
```{warning}
|
||||
Treat outputs as model predictions, not ground truth.
|
||||
Always validate on reviewed local data before using results for ecological inference.
|
||||
```
|
||||
|
||||
## What can I do with it?
|
||||
|
||||
- I want to run the model on my recordings:
|
||||
{doc}`tutorials/run-inference-on-folder`
|
||||
- I write code and want to use it from Python:
|
||||
{doc}`tutorials/integrate-with-a-python-pipeline`
|
||||
- I want to train or fine-tune a custom model:
|
||||
{doc}`tutorials/train-a-custom-model`
|
||||
- I want to evaluate a trained model on held-out data:
|
||||
{doc}`tutorials/evaluate-on-a-test-set`
|
||||
|
||||
```{note}
|
||||
Looking for the previous BatDetect2 workflow?
|
||||
See {doc}`legacy/index`.
|
||||
The legacy docs are still available, but new workflows should use `batdetect2 process` and `BatDetect2API`.
|
||||
```
|
||||
|
||||
## How to use this site
|
||||
|
||||
Start with {doc}`getting_started` if you are new.
|
||||
|
||||
Then choose the section that matches what you need.
|
||||
|
||||
If you are here mainly to run the model on recordings, start with Tutorials.
|
||||
|
||||
| Section | Best for | Start here |
|
||||
| ------------- | --------------------------------------------- | ------------------------ |
|
||||
| Tutorials | Step-by-step routes for the most common tasks | {doc}`tutorials/index` |
|
||||
| How-to guides | Answers to specific practical questions | {doc}`how_to/index` |
|
||||
| Reference | Detailed command and settings help | {doc}`reference/index` |
|
||||
| Understanding | Concepts, interpretation, and trade-offs | {doc}`explanation/index` |
|
||||
| Legacy | Previous workflow and migration guidance | {doc}`legacy/index` |
|
||||
|
||||
## Get in touch
|
||||
|
||||
- GitHub repository:
|
||||
[macaodha/batdetect2](https://github.com/macaodha/batdetect2)
|
||||
- Questions, bug reports, and feature requests:
|
||||
[GitHub Issues](https://github.com/macaodha/batdetect2/issues)
|
||||
- Common questions:
|
||||
{doc}`faq`
|
||||
- Want to contribute?
|
||||
See {doc}`development/index`
|
||||
|
||||
## Cite this work
|
||||
|
||||
If you use BatDetect2 in research, please cite:
|
||||
|
||||
Mac Aodha, O., Martinez Balvanera, S., Damstra, E., et al.
|
||||
(2022).
|
||||
_Towards a General Approach for Bat Echolocation Detection and Classification_.
|
||||
bioRxiv.
|
||||
|
||||
or the bibtex entry
|
||||
|
||||
```bibtex
|
||||
@article{batdetect2_2022,
|
||||
title = {Towards a General Approach for Bat Echolocation Detection and Classification},
|
||||
author = {Mac Aodha, Oisin and Mart\'{i}nez Balvanera, Santiago and Damstra, Elise and Cooke, Martyn and Eichinski, Philip and Browning, Ella and Barataudm, Michel and Boughey, Katherine and Coles, Roger and Giacomini, Giada and MacSwiney G., M. Cristina and K. Obrist, Martin and Parsons, Stuart and Sattler, Thomas and Jones, Kate E.},
|
||||
journal = {bioRxiv},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
:caption: Get Started
|
||||
|
||||
getting_started
|
||||
faq
|
||||
tutorials/index
|
||||
how_to/index
|
||||
reference/index
|
||||
explanation/index
|
||||
legacy/index
|
||||
```
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
:caption: Contributing
|
||||
|
||||
development/index
|
||||
```
|
||||
@ -1,53 +0,0 @@
|
||||
# CLI workflow: `batdetect2 detect`
|
||||
|
||||
This page documents the previous CLI workflow based on `batdetect2 detect`.
|
||||
|
||||
```{warning}
|
||||
This is documentation for a previous version of batdetect2.
|
||||
For new workflows, use `batdetect2 process directory` instead.
|
||||
If you are migrating, start with {doc}`migration-guide`.
|
||||
```
|
||||
|
||||
## Processing a folder of audio files
|
||||
|
||||
```bash
|
||||
batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD
|
||||
```
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
batdetect2 detect example_data/audio/ example_data/anns/ 0.3
|
||||
```
|
||||
|
||||
This command scans a directory of audio files, runs the BatDetect2 detector on
|
||||
each file, and writes BatDetect2-style outputs into `ANN_DIR`.
|
||||
Those outputs usually include one JSON file and one CSV file per recording, and
|
||||
can optionally include extra feature CSVs.
|
||||
|
||||
`AUDIO_DIR` is the folder containing the input `.wav` files.
|
||||
`ANN_DIR` is the folder where model outputs are written.
|
||||
|
||||
`DETECTION_THRESHOLD` controls which detections are kept.
|
||||
Predictions below this score are discarded.
|
||||
Smaller values keep more detections, but usually also increase mistakes.
|
||||
|
||||
Common options:
|
||||
|
||||
- `--cnn_features` Write extra CNN feature CSV files for each recording.
|
||||
- `--spec_features` Extract and write traditional acoustic spectrogram feature
|
||||
CSV files.
|
||||
These are saved as `*_spec_features.csv` files.
|
||||
- `--time_expansion_factor` Set the time expansion factor used for all files in
|
||||
the run.
|
||||
- `--save_preds_if_empty` Save output files even when no detections are found.
|
||||
- `--model_path` Use a specific checkpoint instead of the included default
|
||||
model.
|
||||
If omitted, the command uses the default model trained on UK data.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Migration guide:
|
||||
{doc}`migration-guide`
|
||||
- Current process docs:
|
||||
{doc}`../reference/cli/predict`
|
||||
@ -1,28 +0,0 @@
|
||||
# BatDetect2 v1.0 documentation
|
||||
|
||||
This section documents the BatDetect2 workflow for version 1.
|
||||
|
||||
Use these pages if you need to keep working with the older `batdetect2 detect` command or the older `batdetect2.api` interface.
|
||||
|
||||
For new projects, we recommend the current workflow:
|
||||
|
||||
- CLI:
|
||||
`batdetect2 process`
|
||||
- Python:
|
||||
`batdetect2.api_v2.BatDetect2API`
|
||||
|
||||
If you are moving from the older workflow, start with {doc}`migration-guide`.
|
||||
|
||||
```{warning}
|
||||
These pages describe the previous workflow.
|
||||
They are kept for continuity and migration support.
|
||||
New users should start with {doc}`../getting_started` and {doc}`../tutorials/index`.
|
||||
```
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
cli-detect
|
||||
python-api
|
||||
migration-guide
|
||||
```
|
||||
@ -1,123 +0,0 @@
|
||||
# BatDetect2 2.0 migration guide
|
||||
|
||||
Use this guide when moving from BatDetect2 1.x workflows to the CLI and API in
|
||||
2.x.
|
||||
|
||||
## Why migrate
|
||||
|
||||
You get access to newer features.
|
||||
The codebase changed quite a bit and now gives you much more control over the
|
||||
workflow through config files, improved training and fine-tuning code, and a
|
||||
more flexible sound target definition system.
|
||||
|
||||
You can also run newer or improved models.
|
||||
That includes updated versions of the UK model, plus other models trained with
|
||||
the newer codebase.
|
||||
|
||||
We are no longer actively supporting version 1.
|
||||
No new enhancements are planned there, and only major bug fixes may still be
|
||||
considered.
|
||||
Future work is focused on version 2, including compatibility with newer Python
|
||||
versions.
|
||||
|
||||
## Deprecation plan
|
||||
|
||||
We have kept the `batdetect2.api` module and the `batdetect2 detect` CLI command
|
||||
in place for now.
|
||||
You can keep using them without changing your current workflow.
|
||||
However, many of the internal functions were relocated, removed or modified.
|
||||
If your code relied on anything outside of the `api` module, it may break.
|
||||
It is worth checking the new docs first, since there may already be a newer
|
||||
feature that covers your use case.
|
||||
If not, please open an issue.
|
||||
|
||||
Because the old `api` and CLI command are now redundant with the newer stack, we
|
||||
plan to remove them in about a year.
|
||||
If you want to keep pipelines up to date and long-running, it is a good idea to
|
||||
migrate to version 2.
|
||||
|
||||
## How to migrate
|
||||
|
||||
If you are only using the `batdetect2 detect` CLI command or the
|
||||
`batdetect2.api` module, the migration should be fairly simple.
|
||||
This guide only covers these two entry points.
|
||||
|
||||
### CLI mapping
|
||||
|
||||
- `batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD` -> `batdetect2
|
||||
process directory AUDIO_DIR OUTPUT_PATH --detection-threshold
|
||||
DETECTION_THRESHOLD ...`
|
||||
|
||||
Main changes:
|
||||
|
||||
- outputs can be written in different formats.
|
||||
See the output format reference for the available options.
|
||||
- the detection threshold is now an option instead of a required positional
|
||||
argument.
|
||||
- options like saving CNN features are now controlled through config rather than
|
||||
command flags.
|
||||
- there are separate subcommands for processing a directory, file list, or
|
||||
dataset.
|
||||
|
||||
### Python API mapping
|
||||
|
||||
- old:
|
||||
`import batdetect2.api as api`
|
||||
- current:
|
||||
`from batdetect2 import BatDetect2API`
|
||||
|
||||
Typical migration shape:
|
||||
|
||||
```python
|
||||
from pathlib import Path
|
||||
|
||||
from batdetect2 import BatDetect2API
|
||||
|
||||
# If no checkpoint is provided, the default UK model is loaded
|
||||
api = BatDetect2API.from_checkpoint()
|
||||
prediction = api.process_file(Path("path/to/audio.wav"))
|
||||
```
|
||||
|
||||
Useful replacements:
|
||||
|
||||
- `batdetect2.api.process_file` -> current `BatDetect2API.process_file`
|
||||
- `batdetect2.api.process_audio` -> current `BatDetect2API.process_audio`
|
||||
- `batdetect2.api.process_spectrogram` -> current
|
||||
`BatDetect2API.process_spectrogram`
|
||||
- one-off batch loops -> `BatDetect2API.process_files` or CLI `process`
|
||||
|
||||
### Model changes
|
||||
|
||||
The default checkpoint used by the new CLI `process` commands and by
|
||||
`BatDetect2API` is a newer model trained from scratch using the updated training
|
||||
code, but the same model architecture, training procedure, and data.
|
||||
Performance did not change substantially, but some differences are still
|
||||
expected.
|
||||
|
||||
### Species names
|
||||
|
||||
For the default UK model there are two naming changes:
|
||||
|
||||
1. The original model had a typo and instead of `Barbastella barbastellus` it
|
||||
used `Barbastellus barbastellus`.
|
||||
This has now been corrected.
|
||||
2. There has been a recent change in name for `Eptesicus serotinus` to
|
||||
`Cnephaeus serotinus`.
|
||||
|
||||
## Stay on version 1
|
||||
|
||||
If you prefer not to migrate to version 2 yet, you can keep using version 1.
|
||||
In that case, it is a good idea to pin your dependency:
|
||||
|
||||
```bash
|
||||
pip install "batdetect2>=1.3.1,<2"
|
||||
```
|
||||
|
||||
## Related pages
|
||||
|
||||
- Getting started:
|
||||
{doc}`../getting_started`
|
||||
- Tutorials:
|
||||
{doc}`../tutorials/index`
|
||||
- API reference:
|
||||
{doc}`../reference/api`
|
||||
@ -1,55 +0,0 @@
|
||||
# Legacy Python API: `batdetect2.api`
|
||||
|
||||
This page documents the previous Python API workflow based on `batdetect2.api`.
|
||||
|
||||
```{warning}
|
||||
This is documentation for a previous version of batdetect2.
|
||||
For new workflows, use `batdetect2.BatDetect2API`.
|
||||
If you are migrating, start with {doc}`migration-guide`.
|
||||
```
|
||||
|
||||
## Using BatDetect2 in Python
|
||||
|
||||
If you prefer to process data inside a Python script, you can use the `batdetect2.api` module.
|
||||
|
||||
This interface gives you a simple entry point for running the built-in BatDetect2 model and also exposes the default model and default configuration more directly than the current API.
|
||||
|
||||
You can process a whole file in one step, or load audio, generate a spectrogram, and work with lower-level functions yourself.
|
||||
|
||||
Common functions:
|
||||
|
||||
- `process_file` Load an audio file, run the model, and return BatDetect2-style results for that recording.
|
||||
- `process_audio` Run inference on an audio array that is already loaded in memory.
|
||||
- `process_spectrogram` Run inference starting from a spectrogram tensor instead of raw audio.
|
||||
- `load_audio` Load and resample audio using the legacy preprocessing path.
|
||||
- `generate_spectrogram` Convert audio into the spectrogram representation expected by the model.
|
||||
- `postprocess` Convert raw model outputs into detections and extracted features.
|
||||
|
||||
Typical usage:
|
||||
|
||||
```python
|
||||
import batdetect2.api as api
|
||||
|
||||
AUDIO_FILE = "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav"
|
||||
|
||||
# Process a whole file
|
||||
results = api.process_file(AUDIO_FILE)
|
||||
annotations = results["pred_dict"]["annotation"]
|
||||
|
||||
# Or, load audio and compute spectrograms
|
||||
audio = api.load_audio(AUDIO_FILE)
|
||||
spec = api.generate_spectrogram(audio)
|
||||
|
||||
# And process the audio or the spectrogram with the model
|
||||
detections, features, spec = api.process_audio(audio)
|
||||
detections, features = api.process_spectrogram(spec)
|
||||
|
||||
# Integrate the detections or extracted features into your own analysis
|
||||
```
|
||||
|
||||
This interface is most useful when you want to work directly with detections, features, spectrograms, or intermediate arrays inside your own code.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Migration guide: {doc}`migration-guide`
|
||||
- Current API reference: {doc}`../reference/api`
|
||||
@ -1,39 +0,0 @@
|
||||
# `BatDetect2API` reference
|
||||
|
||||
`BatDetect2API` is the main Python entry point for BatDetect2.
|
||||
|
||||
Use it when you want to load a model, run prediction, inspect detections,
|
||||
evaluate results, or train from Python.
|
||||
|
||||
Defined in `batdetect2.api_v2`.
|
||||
|
||||
## Main ways to create it
|
||||
|
||||
- `BatDetect2API.from_checkpoint(path, ...)`
|
||||
- load a trained checkpoint, a bundled checkpoint alias, or a Hugging Face
|
||||
checkpoint.
|
||||
- `BatDetect2API.from_config(model_config=..., targets_config=..., ...)`
|
||||
- build a full model stack from config objects.
|
||||
|
||||
## Common tasks
|
||||
|
||||
- Load a checkpoint and run prediction on one file.
|
||||
- Run prediction on many files or clips.
|
||||
- Save predictions in one of the supported output formats.
|
||||
- Evaluate a model on labelled data.
|
||||
- Fine-tune an existing checkpoint on new targets.
|
||||
|
||||
## Generated reference
|
||||
|
||||
```{eval-rst}
|
||||
.. autoclass:: batdetect2.api_v2.BatDetect2API
|
||||
```
|
||||
|
||||
## Related pages
|
||||
|
||||
- Python tutorial:
|
||||
{doc}`../tutorials/integrate-with-a-python-pipeline`
|
||||
- Outputs config reference:
|
||||
{doc}`outputs-config`
|
||||
- Output formats reference:
|
||||
{doc}`output-formats`
|
||||
@ -1,8 +0,0 @@
|
||||
Base command
|
||||
============
|
||||
|
||||
The options on this page apply to all subcommands.
|
||||
|
||||
.. click:: batdetect2.cli:cli
|
||||
:prog: batdetect2
|
||||
:nested: none
|
||||
@ -1,8 +0,0 @@
|
||||
Data command
|
||||
============
|
||||
|
||||
Inspect and convert dataset config files.
|
||||
|
||||
.. click:: batdetect2.cli.data:data
|
||||
:prog: batdetect2 data
|
||||
:nested: full
|
||||
@ -1,18 +0,0 @@
|
||||
Legacy detect command
|
||||
=====================
|
||||
|
||||
.. warning::
|
||||
|
||||
``batdetect2 detect`` is a legacy compatibility command.
|
||||
Prefer ``batdetect2 process directory`` for new workflows.
|
||||
|
||||
Migration at a glance
|
||||
---------------------
|
||||
|
||||
- Legacy: ``batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD``
|
||||
- Current: ``batdetect2 process directory MODEL_PATH AUDIO_DIR OUTPUT_PATH``
|
||||
with optional ``--detection-threshold``
|
||||
|
||||
.. click:: batdetect2.cli.compat:detect
|
||||
:prog: batdetect2 detect
|
||||
:nested: none
|
||||
@ -1,11 +0,0 @@
|
||||
Evaluate command
|
||||
================
|
||||
|
||||
Use ``batdetect2 evaluate`` to compare a checkpoint against labelled test data.
|
||||
|
||||
This command writes metrics and any configured artifacts to the output
|
||||
directory.
|
||||
|
||||
.. click:: batdetect2.cli.evaluate:evaluate_command
|
||||
:prog: batdetect2 evaluate
|
||||
:nested: none
|
||||
@ -1,11 +0,0 @@
|
||||
Finetune command
|
||||
================
|
||||
|
||||
Use ``batdetect2 finetune`` to adapt an existing checkpoint to a new target
|
||||
definition.
|
||||
|
||||
If you do not pass ``--model``, the bundled ``uk_same`` checkpoint is used.
|
||||
|
||||
.. click:: batdetect2.cli.finetune:finetune_command
|
||||
:prog: batdetect2 finetune
|
||||
:nested: none
|
||||
@ -1,50 +0,0 @@
|
||||
# CLI reference
|
||||
|
||||
Use this section to find the right command quickly, then open the command page
|
||||
for the full option list.
|
||||
|
||||
## Command map
|
||||
|
||||
| Command | Use it for | Required positional args |
|
||||
| --- | --- | --- |
|
||||
| `batdetect2 process` | Run inference on audio | Depends on subcommand (`directory`, `file_list`, `dataset`) |
|
||||
| `batdetect2 data` | Inspect and convert dataset configs | Depends on subcommand (`summary`, `convert`) |
|
||||
| `batdetect2 train` | Train or fine-tune models | `TRAIN_DATASET` |
|
||||
| `batdetect2 finetune` | Fine-tune a checkpoint on new targets | `TRAIN_DATASET` plus `--targets` |
|
||||
| `batdetect2 evaluate` | Evaluate a checkpoint on a test dataset | `TEST_DATASET` |
|
||||
| `batdetect2 detect` | Legacy compatibility workflow | `AUDIO_DIR`, `ANN_DIR`, `DETECTION_THRESHOLD` |
|
||||
|
||||
## Notes
|
||||
|
||||
- Global CLI options are documented in {doc}`base`.
|
||||
- Paths with spaces should be wrapped in quotes.
|
||||
- Input audio is expected to be mono.
|
||||
- `process` uses the optional `--detection-threshold` override.
|
||||
- `evaluate` takes `TEST_DATASET` as a positional argument and uses `--model`
|
||||
for the checkpoint override.
|
||||
- `finetune` defaults to the bundled `uk_same` checkpoint if `--model` is not
|
||||
provided.
|
||||
|
||||
```{warning}
|
||||
`batdetect2 detect` is a legacy command.
|
||||
Prefer `batdetect2 process directory` for new workflows.
|
||||
```
|
||||
|
||||
## Related pages
|
||||
|
||||
- {doc}`../../tutorials/run-inference-on-folder`
|
||||
- {doc}`../../how_to/run-batch-predictions`
|
||||
- {doc}`../../how_to/tune-detection-threshold`
|
||||
- {doc}`../configs`
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
Base command and global options <base>
|
||||
Process command group <predict>
|
||||
Data command group <data>
|
||||
Train command <train>
|
||||
Finetune command <finetune>
|
||||
Evaluate command <evaluate>
|
||||
Legacy detect command <detect_legacy>
|
||||
```
|
||||
@ -1,17 +0,0 @@
|
||||
Process command
|
||||
===============
|
||||
|
||||
Use ``batdetect2 process`` to run inference on audio.
|
||||
|
||||
Choose a subcommand based on how you want to provide the input:
|
||||
|
||||
- ``directory`` for all supported audio files in one folder
|
||||
- ``file_list`` for a text file with one audio path per line
|
||||
- ``dataset`` for recordings referenced by a dataset file
|
||||
|
||||
Use ``--detection-threshold`` when you want to override the configured
|
||||
threshold for one run.
|
||||
|
||||
.. click:: batdetect2.cli.inference:process
|
||||
:prog: batdetect2 process
|
||||
:nested: full
|
||||
@ -1,12 +0,0 @@
|
||||
Train command
|
||||
=============
|
||||
|
||||
Use ``batdetect2 train`` to start from a fresh model config or continue from an
|
||||
existing checkpoint.
|
||||
|
||||
If you want to adapt an existing checkpoint to a new target definition, use
|
||||
``batdetect2 finetune`` instead.
|
||||
|
||||
.. click:: batdetect2.cli.train:train_command
|
||||
:prog: batdetect2 train
|
||||
:nested: none
|
||||
@ -1,18 +0,0 @@
|
||||
Config reference
|
||||
================
|
||||
|
||||
BatDetect2 uses separate config objects for different workflow surfaces.
|
||||
|
||||
Use the dedicated reference pages for each config family:
|
||||
|
||||
- model config
|
||||
- training config
|
||||
- logging config
|
||||
- inference config
|
||||
- evaluation config
|
||||
- outputs config
|
||||
- preprocessing config
|
||||
- postprocess config
|
||||
- targets config workflow
|
||||
|
||||
Example config files live under `example_data/configs/`.
|
||||
@ -1,76 +0,0 @@
|
||||
# Data source reference
|
||||
|
||||
This page summarizes dataset source formats and their config fields.
|
||||
|
||||
## Supported source formats
|
||||
|
||||
| Format | Description |
|
||||
| --- | --- |
|
||||
| `aoef` | AOEF/soundevent annotation files (`AnnotationSet` or `AnnotationProject`) |
|
||||
| `batdetect2` | Legacy format with one JSON annotation file per recording |
|
||||
| `batdetect2_file` | Legacy format with one merged JSON annotation file |
|
||||
|
||||
## AOEF (`format: aoef`)
|
||||
|
||||
Required fields:
|
||||
|
||||
- `name`
|
||||
- `format`
|
||||
- `audio_dir`
|
||||
- `annotations_path`
|
||||
|
||||
Optional fields:
|
||||
|
||||
- `description`
|
||||
- `filter`
|
||||
|
||||
`filter` is only used when `annotations_path` points to an
|
||||
`AnnotationProject`.
|
||||
|
||||
AOEF filter options:
|
||||
|
||||
- `only_completed` (default: `true`)
|
||||
- `only_verified` (default: `false`)
|
||||
- `exclude_issues` (default: `true`)
|
||||
|
||||
Use `filter: null` to disable project filtering.
|
||||
|
||||
## Legacy per-file (`format: batdetect2`)
|
||||
|
||||
Required fields:
|
||||
|
||||
- `name`
|
||||
- `format`
|
||||
- `audio_dir`
|
||||
- `annotations_dir`
|
||||
|
||||
Optional fields:
|
||||
|
||||
- `description`
|
||||
- `filter`
|
||||
|
||||
## Legacy merged file (`format: batdetect2_file`)
|
||||
|
||||
Required fields:
|
||||
|
||||
- `name`
|
||||
- `format`
|
||||
- `audio_dir`
|
||||
- `annotations_path`
|
||||
|
||||
Optional fields:
|
||||
|
||||
- `description`
|
||||
- `filter`
|
||||
|
||||
Legacy filter options:
|
||||
|
||||
- `only_annotated` (default: `true`)
|
||||
- `exclude_issues` (default: `true`)
|
||||
|
||||
Use `filter: null` to disable filtering.
|
||||
|
||||
## Related guides
|
||||
|
||||
- {doc}`../how_to/configure-aoef-dataset`
|
||||
- {doc}`../how_to/import-legacy-batdetect2-annotations`
|
||||
@ -1,42 +0,0 @@
|
||||
# Detections reference
|
||||
|
||||
These are the main prediction objects returned by BatDetect2 inference methods.
|
||||
|
||||
Defined in `batdetect2.postprocess.types`.
|
||||
|
||||
## `ClipDetections`
|
||||
|
||||
`ClipDetections` represents the predictions for one clip or one full recording.
|
||||
|
||||
Fields:
|
||||
|
||||
- `clip`
|
||||
- the `soundevent` clip metadata for the processed audio.
|
||||
- `detections`
|
||||
- list of `Detection` objects for that clip.
|
||||
|
||||
## `Detection`
|
||||
|
||||
`Detection` represents one detected event.
|
||||
|
||||
Fields:
|
||||
|
||||
- `geometry`
|
||||
- time-frequency geometry for the detected event.
|
||||
- `detection_score`
|
||||
- confidence that there is an event at this location.
|
||||
- `class_scores`
|
||||
- class ranking scores for the detected event.
|
||||
- `features`
|
||||
- per-detection feature vector from the model.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Python tutorial:
|
||||
{doc}`../tutorials/integrate-with-a-python-pipeline`
|
||||
- API reference:
|
||||
{doc}`api`
|
||||
- What BatDetect2 predicts:
|
||||
{doc}`../explanation/what-batdetect2-predicts`
|
||||
- Features and embeddings:
|
||||
{doc}`../explanation/extracted-features-and-embeddings`
|
||||
@ -1,46 +0,0 @@
|
||||
# Evaluation config reference
|
||||
|
||||
`EvaluationConfig` defines which evaluation tasks run and which plots they generate.
|
||||
|
||||
Defined in `batdetect2.evaluate.config`.
|
||||
|
||||
## Top-level fields
|
||||
|
||||
- `tasks`
|
||||
- list of task configs.
|
||||
|
||||
## Built-in task families
|
||||
|
||||
Current built-in tasks include:
|
||||
|
||||
- `sound_event_detection`
|
||||
- `sound_event_classification`
|
||||
- `top_class_detection`
|
||||
- `clip_detection`
|
||||
- `clip_classification`
|
||||
|
||||
## Shared task controls
|
||||
|
||||
Common task-level controls include:
|
||||
|
||||
- `prefix`
|
||||
- `ignore_start_end`
|
||||
|
||||
Sound-event-style tasks also support:
|
||||
|
||||
- `affinity`
|
||||
- `affinity_threshold`
|
||||
- `strict_match`
|
||||
|
||||
## Default behavior
|
||||
|
||||
The default evaluation config starts with:
|
||||
|
||||
- sound event detection,
|
||||
- sound event classification.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Choose and configure evaluation tasks: {doc}`../how_to/choose-and-configure-evaluation-tasks`
|
||||
- Evaluation concepts: {doc}`../explanation/evaluation-concepts-and-matching`
|
||||
- Evaluate CLI reference: {doc}`cli/evaluate`
|
||||
@ -1,28 +0,0 @@
|
||||
# Reference documentation
|
||||
|
||||
Reference pages are the detailed lookup pages.
|
||||
|
||||
Use this section when you need exact command options, setting names, output
|
||||
details, or Python API entries.
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
cli/index
|
||||
api
|
||||
detections
|
||||
model-config
|
||||
training-config
|
||||
logging-config
|
||||
inference-config
|
||||
evaluation-config
|
||||
outputs-config
|
||||
output-formats
|
||||
output-transforms
|
||||
data-sources
|
||||
preprocessing-config
|
||||
postprocess-config
|
||||
targets-config-workflow
|
||||
configs
|
||||
targets
|
||||
```
|
||||
@ -1,41 +0,0 @@
|
||||
# Inference config reference
|
||||
|
||||
`InferenceConfig` controls how files are clipped and batched during prediction-time workflows.
|
||||
|
||||
Defined in `batdetect2.inference.config`.
|
||||
|
||||
## Top-level fields
|
||||
|
||||
- `loader`
|
||||
- data-loader settings for inference.
|
||||
- `clipping`
|
||||
- controls how recordings are split into clips before batching.
|
||||
|
||||
## `loader`
|
||||
|
||||
Current built-in loader field:
|
||||
|
||||
- `batch_size` (int, default `8`)
|
||||
|
||||
## `clipping`
|
||||
|
||||
Fields:
|
||||
|
||||
- `enabled` (bool)
|
||||
- `duration` (float, seconds)
|
||||
- `overlap` (float, seconds)
|
||||
- `max_empty` (float)
|
||||
- `discard_empty` (bool)
|
||||
|
||||
## When to override this config
|
||||
|
||||
Override `InferenceConfig` when:
|
||||
|
||||
- long recordings need different clipping behavior,
|
||||
- you want to tune batch size for your hardware,
|
||||
- you need reproducible prediction settings across runs.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Tune inference clipping: {doc}`../how_to/tune-inference-clipping`
|
||||
- Predict CLI reference: {doc}`cli/predict`
|
||||
@ -1,46 +0,0 @@
|
||||
# Logging config reference
|
||||
|
||||
`AppLoggingConfig` controls which logger backend BatDetect2 uses for training,
|
||||
evaluation, and inference.
|
||||
|
||||
Defined in `batdetect2.logging`.
|
||||
|
||||
## Top-level fields
|
||||
|
||||
- `train`
|
||||
- logger config for training runs.
|
||||
- `evaluation`
|
||||
- logger config for evaluation runs.
|
||||
- `inference`
|
||||
- logger config for inference runs.
|
||||
|
||||
## Built-in logger backends
|
||||
|
||||
Current built-in logger backends are:
|
||||
|
||||
- `csv`
|
||||
- `tensorboard`
|
||||
- `mlflow`
|
||||
- `dvclive`
|
||||
|
||||
## Default behaviour
|
||||
|
||||
By default:
|
||||
|
||||
- training uses `csv`,
|
||||
- evaluation uses `csv`,
|
||||
- inference uses `csv`.
|
||||
|
||||
With the CSV logger, training writes a `metrics.csv` file in the log folder.
|
||||
|
||||
Example files live under `example_data/configs/`, including
|
||||
`example_data/configs/logging.yaml`.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Train command reference:
|
||||
{doc}`cli/train`
|
||||
- Evaluate command reference:
|
||||
{doc}`cli/evaluate`
|
||||
- Run inference on a folder:
|
||||
{doc}`../tutorials/run-inference-on-folder`
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user