Compare commits

...

17 Commits

Author SHA1 Message Date
mbsantiago
b0f85b96e3 fix: resolve remaining type check issues 2026-05-06 17:43:29 +01:00
mbsantiago
ce6975770e ci: add GitHub workflows and release helpers 2026-05-06 17:22:18 +01:00
mbsantiago
69d8e2d228 docs: polish README overview and links 2026-05-06 16:50:45 +01:00
mbsantiago
855a79853b docs: refine CLI command docstrings 2026-05-06 14:47:29 +01:00
mbsantiago
6587c6c4e5 feat: rename CLI inference command to process 2026-05-06 14:32:51 +01:00
mbsantiago
831925bd95 feat: expose BatDetect2API at package root 2026-05-06 14:10:24 +01:00
mbsantiago
b4efcfcf0f docs: refresh api reference guidance 2026-05-06 14:06:04 +01:00
mbsantiago
5cc5767eff fix: rename detector heads and refresh bundled checkpoint 2026-05-06 12:50:53 +01:00
mbsantiago
2008c8000f refactor: replace abstract model types with protocols 2026-05-06 12:50:32 +01:00
mbsantiago
a27d1bbfd3 refactor: derive training config from the model 2026-05-06 12:48:40 +01:00
mbsantiago
999dc93d88 docs: improve API and CLI reference docs
Clarify BatDetect2API usage, add examples and NumPy-style docstrings, and tighten CLI help and reference pages for prediction, training, evaluation, and fine-tuning workflows.
2026-05-06 11:19:38 +01:00
mbsantiago
7c05fb8577 feat: default to bundled checkpoint
Fall back to the bundled uk_same model when no checkpoint is provided in the shared loader and fine-tune CLI. Keep tests aligned with the new default resolution behavior.
2026-05-06 10:33:04 +01:00
mbsantiago
31054f64f6 fix: load checkpoints on cpu
Use CPU map_location when restoring Lightning checkpoints so packaged models load reliably without requiring accelerator-specific device state.
2026-05-05 21:49:09 +01:00
mbsantiago
84918086c8 feat: streamline bundled checkpoint handling
Support packaged model aliases and save weights-only checkpoints by default so distributed models stay small while remaining easy to load.
2026-05-05 21:34:54 +01:00
mbsantiago
d83f801515 perf: defer heavy api imports 2026-05-05 16:39:43 +01:00
mbsantiago
5526ac99fc Remove stale dependencies 2026-05-05 16:20:37 +01:00
mbsantiago
f5afa9881c feat: load checkpoints from Hugging Face 2026-05-05 15:46:39 +01:00
77 changed files with 1831 additions and 778 deletions

View File

@ -3,6 +3,8 @@ current_version = 1.1.1
commit = True
tag = True
[bumpversion:file:batdetect2/__init__.py]
[bumpversion:file:src/batdetect2/__init__.py]
[bumpversion:file:pyproject.toml]
[bumpversion:file:docs/source/conf.py]

79
.github/workflows/ci.yml vendored Normal file
View File

@ -0,0 +1,79 @@
name: CI
on:
pull_request:
push:
branches:
- main
concurrency:
group: ci-${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
checks:
name: Checks
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install just
uses: taiki-e/install-action@just
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
cache-dependency-glob: |
pyproject.toml
uv.lock
- name: Install dependencies
run: just install-dev
- name: Run formatting, lint, and type checks
run: just check
tests:
name: Tests (Python ${{ matrix.python-version }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version:
- "3.10"
- "3.11"
- "3.12"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install just
uses: taiki-e/install-action@just
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
cache-dependency-glob: |
pyproject.toml
uv.lock
- name: Install dependencies
run: just install-dev
- name: Run test suite
run: just test

69
.github/workflows/docs-pages.yml vendored Normal file
View File

@ -0,0 +1,69 @@
name: Docs Pages
on:
push:
branches:
- main
workflow_dispatch:
permissions:
contents: read
concurrency:
group: docs-pages
cancel-in-progress: true
jobs:
build:
name: Build Docs
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install just
uses: taiki-e/install-action@just
- name: Configure GitHub Pages
uses: actions/configure-pages@v5
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
cache-dependency-glob: |
pyproject.toml
uv.lock
- name: Install dependencies
run: just install-dev
- name: Build docs
run: just check-docs
- name: Upload Pages artifact
uses: actions/upload-pages-artifact@v4
with:
path: docs/build
deploy:
name: Deploy Docs
needs: build
runs-on: ubuntu-latest
permissions:
pages: write
id-token: write
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
steps:
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4

70
.github/workflows/publish-pypi.yml vendored Normal file
View File

@ -0,0 +1,70 @@
name: Publish PyPI
on:
release:
types:
- published
permissions:
contents: read
concurrency:
group: publish-pypi
cancel-in-progress: false
jobs:
build:
name: Build Distributions
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install just
uses: taiki-e/install-action@just
- name: Install uv
uses: astral-sh/setup-uv@v6
with:
enable-cache: true
cache-dependency-glob: |
pyproject.toml
uv.lock
- name: Install dependencies
run: just install-dev
- name: Build distributions
run: just build-dist
- name: Upload distributions
uses: actions/upload-artifact@v4
with:
name: release-dists
path: dist/
publish:
name: Publish to PyPI
needs: build
runs-on: ubuntu-latest
permissions:
id-token: write
environment:
name: pypi
url: https://pypi.org/p/batdetect2
steps:
- name: Download distributions
uses: actions/download-artifact@v5
with:
name: release-dists
path: dist/
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1

View File

@ -1,29 +0,0 @@
name: Python package
on:
push:
branches: ["main"]
pull_request:
branches: ["main"]
jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v4
- name: Install uv
uses: astral-sh/setup-uv@v3
with:
enable-cache: true
cache-dependency-glob: "uv.lock"
- name: Set up Python ${{ matrix.python-version }}
run: uv python install ${{ matrix.python-version }}
- name: Install the project
run: uv sync --all-extras --dev
- name: Test with pytest
run: uv run pytest

View File

@ -1,30 +0,0 @@
name: Upload Python Package
on:
release:
types: [published]
permissions:
contents: read
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v3
with:
python-version: "3.x"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install build
- name: Build package
run: python -m build
- name: Publish package
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}

1
.gitignore vendored
View File

@ -50,6 +50,7 @@ cover/
# Sphinx documentation
docs/_build/
docs/build/
# PyBuilder
.pybuilder/

247
README.md
View File

@ -1,202 +1,137 @@
# BatDetect2
<img style="display: block-inline;" width="64" height="64" src="assets/bat_icon.png"> Code for detecting and classifying bat echolocation calls in high frequency audio recordings.
## What BatDetect2 is useful for
<img style="display:block-inline;" width="64" height="64" src="assets/bat_icon.png">
BatDetect2 can help you screen recordings for bat calls,
find recordings that need expert review,
and compare model outputs across sites or projects with appropriate caution.
Code for detecting and classifying bat echolocation calls in high-frequency
audio recordings.
It is best used as a tool to support ecological work,
not as a replacement for validation or expert interpretation.
> [!WARNING]
> `batdetect2` 2.0.1 is out.
> There are many changes and new recommended workflows.
> We have left the previous `batdetect2.api` module intact, but if you run
> into issues or want to upgrade, see the
> [migration guide](docs/source/legacy/migration-guide.md) in the docs site.
>
> This update also ships with a refreshed default model.
> It was trained in the same way and on the same data as before, but you should still expect small output differences in some cases.
## Start here
## What is BatDetect2
If you want the simplest current workflow,
use the documentation site and start with:
BatDetect2 is a deep learning model for detecting and classifying bat
echolocation calls.
The model generates multiple predictions for each input recording by providing a
bounding box and predicted class for each individual call within it.
- getting started: `docs/source/getting_started.md`
- first tutorial: `docs/source/tutorials/run-inference-on-folder.md`
This repository also holds `batdetect2`, a Python-based tool to run, train,
finetune and evaluate BatDetect2-type models, including the built-in model for
detecting UK bat species.
You can use the tool from the command line (terminal) or from Python as needed.
The current docs default to:
## Getting Started
- the current command-line workflow: `batdetect2 predict`
- the current Python workflow: `batdetect2.api_v2.BatDetect2API`
We have [extensive documentation](docs/source/index.md) on how to use
`batdetect2`.
See our [getting started](docs/source/getting_started.md) guide and then jump
into any of our tutorials:
If you need the previous workflow based on `batdetect2 detect` or `batdetect2.api`,
use the legacy docs section and migration guide in the docs site.
- Run the model on a folder of recordings:
`docs/source/tutorials/run-inference-on-folder.md`
- Train your own model:
`docs/source/tutorials/train-a-custom-model.md`
- Evaluate your model:
`docs/source/tutorials/evaluate-on-a-test-set.md`
- Fine-tune a model:
`docs/source/tutorials/integrate-with-a-python-pipeline.md`
## Install BatDetect2
### Try the model
If you already use Python,
activate the environment where you want BatDetect2 to live.
If you want to try the model for UK bat species without installing anything, you
can try the following:
If not,
create a fresh one first so BatDetect2 stays separate from other software on your machine.
1. Demo of the model (for UK species) on
[huggingface](https://huggingface.co/spaces/macaodha/batdetect2).
Two common options are:
* Install the Anaconda Python 3.10 distribution for your operating system from [here](https://www.continuum.io/downloads). Create a new environment and activate it:
```bash
conda create -y --name batdetect2 python==3.10
conda activate batdetect2
```
* If you already have Python installed (version >= 3.10,< 3.14), you can create a fresh environment with:
```bash
python -m venv .venv
source .venv/bin/activate
```
2. Alternatively, click
[here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb)
to run the model using Google Colab.
You can also run this notebook locally.
### Installing BatDetect2
You can use pip to install `batdetect2`:
If you have `uv` installed (if not, we recommend it; follow the instructions
[here](https://docs.astral.sh/uv/getting-started/installation/)), then you can
run `batdetect2` one-off with
```bash
pip install batdetect2
uvx batdetect2
```
Alternatively, download this code from the repository (by clicking on the green button on top right) and unzip it.
Once unzipped, run this from extracted folder.
or if you want to install it permanently:
```bash
pip install .
uv tool install batdetect2
```
Make sure you have the environment activated before installing `batdetect2`.
and test it with
## Run BatDetect2 on a folder of recordings
Once installed,
the simplest current workflow is to run BatDetect2 on a folder of `.wav` files.
If you are working from this repository checkout,
you can use this example checkpoint path:
```text
src/batdetect2/models/checkpoints/Net2DFast_UK_same.pth.tar
```bash
batdetect2
```
### Run BatDetect2 on a folder of recordings
Once installed, you can run BatDetect2 on a folder of `.wav` files.
By default it will use the model trained on UK data.
Example command:
```bash
batdetect2 predict directory \
src/batdetect2/models/checkpoints/Net2DFast_UK_same.pth.tar \
example_data/audio \
outputs
batdetect2 process directory example_data/audio outputs
```
This will scan the audio files in `example_data/audio`
and save model outputs to `outputs`.
This will scan the audio files in `example_data/audio` and save model outputs to
`outputs`.
If you have your own model checkpoint, you can use it:
For the full beginner walkthrough,
use `docs/source/tutorials/run-inference-on-folder.md`.
## Legacy workflow
The sections below are kept only for people maintaining older BatDetect2 scripts and analysis pipelines.
If you are new to BatDetect2,
stop here and use the current docs and command above.
If you really do need the older workflow,
the reference material is below.
## Try the model
1) You can try a demo of the model (for UK species) on [huggingface](https://huggingface.co/spaces/macaodha/batdetect2).
2) Alternatively, click [here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb) to run the model using Google Colab. You can also run this notebook locally.
## Running the model on your own data
After following the above steps to install the code you can run the model on your own data.
The remainder of this section is legacy reference material.
### Using the command line
The commands below describe the legacy CLI workflow.
For new work, prefer the current docs and `batdetect2 predict`.
You can run the model by opening the command line and typing:
```bash
batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD
```
e.g.
```bash
batdetect2 detect example_data/audio/ example_data/anns/ 0.3
batdetect2 process directory --model path/to/checkpoint.ckpt example_data/audio outputs
```
`AUDIO_DIR` is the path on your computer to the audio wav files of interest.
`ANN_DIR` is the path on your computer where the model predictions will be saved. The model will output both `.csv` and `.json` results for each audio file.
`DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistakes.
There are also optional arguments, e.g. you can request that the model outputs features (i.e. estimated call parameters) such as duration, max_frequency, etc. by setting the flag `--spec_features`. These will be saved as `*_spec_features.csv` files:
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --spec_features`
You can also specify which model to use by setting the `--model_path` argument. If not specified, it will default to using a model trained on UK data e.g.
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --model_path models/Net2DFast_UK_same.pth.tar`
### Using the Python API
The examples below describe the legacy Python API.
For new work, prefer `batdetect2.api_v2.BatDetect2API` and the current docs site.
If you prefer to process your data within a Python script then you can use the `batdetect2` Python API.
```python
from batdetect2 import api
AUDIO_FILE = "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav"
# Process a whole file
results = api.process_file(AUDIO_FILE)
# Or, load audio and compute spectrograms
audio = api.load_audio(AUDIO_FILE)
spec = api.generate_spectrogram(audio)
# And process the audio or the spectrogram with the model
detections, features, spec = api.process_audio(audio)
detections, features = api.process_spectrogram(spec)
# Do something else ...
```
You can integrate the detections or the extracted features to your custom analysis pipeline.
## Training the model on your own data
Take a look at the training tutorial in the docs site first.
If you are working from this repository checkout,
start with `docs/source/tutorials/train-a-custom-model.md`.
For the full walkthrough, use
`docs/source/tutorials/run-inference-on-folder.md`.
## Data and annotations
The raw audio data and annotations used to train the models in the paper will be added soon.
The audio interface used to annotate audio data for training and evaluation is available [here](https://github.com/macaodha/batdetect2_GUI).
The raw audio data and annotations used to train the models in the paper will be
added soon.
`batdetect2` supports annotations in various formats and is compatible with the
outputs of [`whombat`](https://github.com/mbsantiago/whombat/) and this
[earlier version](https://github.com/macaodha/batdetect2_GUI).
If you're interested in supporting another format, please reach out or submit a
PR.
## Warning
The models developed and shared as part of this repository should be used with caution.
While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment.
Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted.
The models developed and shared as part of this repository should be used with
caution.
While they have been evaluated on held-out audio data, great care should be
taken when using the model outputs for any form of biodiversity assessment.
Your data may differ, and as a result it is very strongly recommended that you
validate the model first using data with known species to ensure that the
outputs can be trusted.
If you train a model, make the best effort to be transparent about its training
and evaluation data, and inform downstream users about its limitations.
## FAQ
For more information please consult our [FAQ](docs/source/faq.md).
## Reference
If you find our work useful in your research please consider citing our paper which you can find [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1):
If you find our work useful in your research, please consider citing our paper,
which you can find
[here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1):
```
@article{batdetect2_2022,
title = {Towards a General Approach for Bat Echolocation Detection and Classification},
@ -207,10 +142,6 @@ If you find our work useful in your research please consider citing our paper wh
```
## Acknowledgements
Thanks to all the contributors who spent time collecting and annotating audio data.
### TODOs
- [x] Release the code and pretrained model
- [ ] Release the datasets and annotations used the experiments in the paper
- [ ] Add the scripts used to generate the tables and figures from the paper
Thanks to all the contributors who spent time collecting and annotating audio
data.

View File

@ -1,19 +1,20 @@
# Getting started
If you want to run BatDetect2 on your recordings,
start with the command-line route below.
If you want to run BatDetect2 on your recordings, start with the command-line
route below.
You do not need to write Python code for a standard first run.
BatDetect2 also has a Python interface,
but that is mainly for users writing their own analysis scripts.
BatDetect2 also has a Python interface, but that is mainly for users writing
their own analysis scripts.
- Use the command-line route if you want to run an existing model or train your own model by typing commands in a terminal window.
- Use the command-line route if you want to run an existing model or train your
own model by typing commands in a terminal window.
- Use the Python route only if you already want to work in scripts or notebooks.
```{note}
If you are looking for the previous BatDetect2 workflow based on `batdetect2 detect` or `batdetect2.api`, go to {doc}`legacy/index`.
New docs default to the current `predict` CLI and `BatDetect2API` workflow.
New docs default to the current `process` CLI and `BatDetect2API` workflow.
```
If you want to try BatDetect2 before installing anything locally:
@ -27,15 +28,14 @@ If you want to try BatDetect2 before installing anything locally:
2. Use a model checkpoint.
3. Run the first tutorial on a folder of recordings.
If that is what you want,
you can ignore the Python sections for now.
If that is what you want, you can ignore the Python sections for now.
## Install BatDetect2
We recommend `uv` for both workflows.
`uv` is a tool that helps install Python software cleanly,
without mixing it into the rest of your machine.
`uv` is a tool that helps install Python software cleanly, without mixing it
into the rest of your machine.
- Use `uv tool` to install the CLI.
- Use `uv add` to add `batdetect2` as a dependency in a Python project.
@ -70,7 +70,8 @@ Go to {doc}`tutorials/run-inference-on-folder` for a complete first run.
## Choose a model checkpoint
The current command-line and Python workflows expect an explicit checkpoint path.
The current command-line and Python workflows expect an explicit checkpoint
path.
A checkpoint is the saved model file that BatDetect2 will use for prediction.
@ -85,7 +86,8 @@ In this repository checkout, an example pretrained checkpoint is available at:
src/batdetect2/models/checkpoints/Net2DFast_UK_same.pth.tar
```
Use that path in the tutorial commands if you want a concrete starting point from this source tree.
Use that path in the tutorial commands if you want a concrete starting point
from this source tree.
## Python route for users writing code

View File

@ -1,8 +1,9 @@
# How to choose an inference input mode
Use this guide to decide whether `predict directory`, `predict file_list`, or `predict dataset` is the right entry point for your run.
Use this guide to decide whether `process directory`, `process file_list`, or
`process dataset` is the right entry point for your run.
## Use `predict directory` when the recordings already live together
## Use `process directory` when the recordings already live together
This is the simplest choice.
@ -13,13 +14,13 @@ Use it when:
- you are doing a first pass over a folder of recordings.
```bash
batdetect2 predict directory \
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs
```
## Use `predict file_list` when you need explicit control over the file set
## Use `process file_list` when you need explicit control over the file set
Use it when:
@ -30,13 +31,13 @@ Use it when:
The list file should contain one path per line.
```bash
batdetect2 predict file_list \
batdetect2 process file_list \
path/to/model.ckpt \
path/to/audio_files.txt \
path/to/outputs
```
## Use `predict dataset` when your workflow is already annotation-set driven
## Use `process dataset` when your workflow is already annotation-set driven
Use it when:
@ -45,13 +46,14 @@ Use it when:
- you want BatDetect2 to resolve recording paths from the annotation set.
```bash
batdetect2 predict dataset \
batdetect2 process dataset \
path/to/model.ckpt \
path/to/annotation_set.json \
path/to/outputs
```
The dataset command reads a `soundevent` annotation set and extracts unique recording paths before inference.
The dataset command reads a `soundevent` annotation set and extracts unique
recording paths before inference.
## Rule of thumb
@ -61,6 +63,9 @@ The dataset command reads a `soundevent` annotation set and extracts unique reco
## Related pages
- Run batch predictions: {doc}`run-batch-predictions`
- Tune inference clipping: {doc}`tune-inference-clipping`
- Predict command reference: {doc}`../reference/cli/predict`
- Run batch predictions:
{doc}`run-batch-predictions`
- Tune inference clipping:
{doc}`tune-inference-clipping`
- Process command reference:
{doc}`../reference/cli/predict`

View File

@ -1,6 +1,7 @@
# How to choose and configure evaluation tasks
Use this guide when the default evaluation tasks do not match the question you want to answer.
Use this guide when the default evaluation tasks do not match the question you
want to answer.
## Know the default first
@ -24,8 +25,10 @@ Common built-in task families include:
Choose based on the question you care about.
- Use sound-event tasks when you care about individual call events.
- Use clip tasks when you care about clip-level presence or clip-level class evidence.
- Use top-class detection when you want matching based on the highest-scoring class per detection.
- Use clip tasks when you care about clip-level presence or clip-level class
evidence.
- Use top-class detection when you want matching based on the highest-scoring
class per detection.
## Configure tasks in `EvaluationConfig`
@ -45,22 +48,27 @@ Pass the config with:
```bash
batdetect2 evaluate \
path/to/model.ckpt \
path/to/test_dataset.yaml \
--model path/to/model.ckpt \
--base-dir path/to/project_root \
--evaluation-config path/to/evaluation.yaml
```
Include `--base-dir` when the dataset config resolves recordings through relative paths.
Include `--base-dir` when the dataset config resolves recordings through
relative paths.
## Change one thing at a time
When comparing models or settings, avoid changing task definitions, thresholds, matching behavior, and datasets all at once.
When comparing models or settings, avoid changing task definitions, thresholds,
matching behavior, and datasets all at once.
Otherwise it becomes hard to explain why the metric changed.
## Related pages
- Evaluation tutorial: {doc}`../tutorials/evaluate-on-a-test-set`
- Evaluation config reference: {doc}`../reference/evaluation-config`
- Evaluation concepts: {doc}`../explanation/evaluation-concepts-and-matching`
- Evaluation tutorial:
{doc}`../tutorials/evaluate-on-a-test-set`
- Evaluation config reference:
{doc}`../reference/evaluation-config`
- Evaluation concepts:
{doc}`../explanation/evaluation-concepts-and-matching`

View File

@ -46,7 +46,7 @@ Available built-ins:
For CLI inference/evaluation, use `--audio-config`.
```bash
batdetect2 predict directory \
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs \
@ -55,10 +55,12 @@ batdetect2 predict directory \
## 4) Verify quickly on a small subset
Run on a small folder first and confirm that outputs and runtime are as
expected before full-batch runs.
Run on a small folder first and confirm that outputs and runtime are as expected
before full-batch runs.
## Related pages
- Spectrogram settings: {doc}`configure-spectrogram-preprocessing`
- Preprocessing config reference: {doc}`../reference/preprocessing-config`
- Spectrogram settings:
{doc}`configure-spectrogram-preprocessing`
- Preprocessing config reference:
{doc}`../reference/preprocessing-config`

View File

@ -1,14 +1,15 @@
# How to run batch predictions
# How to run batch processing
This guide shows practical command patterns for directory-based and file-list
prediction runs.
processing runs.
Use it after you already know which input mode you want and need concrete command templates for a repeatable batch run.
Use it after you already know which input mode you want and need concrete
command templates for a repeatable batch run.
## Predict from a directory
## Process a directory
```bash
batdetect2 predict directory \
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs
@ -16,27 +17,29 @@ batdetect2 predict directory \
Use this when BatDetect2 should discover the audio files for you.
## Predict from a file list
## Process a file list
```bash
batdetect2 predict file_list \
batdetect2 process file_list \
path/to/model.ckpt \
path/to/audio_files.txt \
path/to/outputs
```
Use this when another part of your workflow already produced the exact recording list to process.
Use this when another part of your workflow already produced the exact recording
list to process.
## Predict from a dataset config
## Process a dataset config
```bash
batdetect2 predict dataset \
batdetect2 process dataset \
path/to/model.ckpt \
path/to/annotation_set.json \
path/to/outputs
```
Use this when your project already has a `soundevent` annotation set and you want to extract unique recording paths from it.
Use this when your project already has a `soundevent` annotation set and you
want to extract unique recording paths from it.
## Useful options

View File

@ -1,22 +1,27 @@
# How to save predictions in different output formats
Use this guide when you need BatDetect2 outputs in a specific representation for downstream tools.
Use this guide when you need BatDetect2 outputs in a specific representation for
downstream tools.
## Choose the format that matches the job
Current built-in output formats include:
- `raw`: one NetCDF file per clip, best for rich structured outputs,
- `parquet`: tabular storage for data analysis workflows,
- `soundevent`: prediction-set JSON for soundevent-style tooling,
- `batdetect2`: legacy per-recording JSON output.
- `raw`:
one NetCDF file per clip, best for rich structured outputs,
- `parquet`:
tabular storage for data analysis workflows,
- `soundevent`:
prediction-set JSON for soundevent-style tooling,
- `batdetect2`:
legacy per-recording JSON output.
## Select a format from the CLI
Use `--format` for quick experiments.
```bash
batdetect2 predict directory \
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs \
@ -25,7 +30,8 @@ batdetect2 predict directory \
## Use an outputs config for repeatable runs
Use an outputs config when you want reproducible control over format and transforms.
Use an outputs config when you want reproducible control over format and
transforms.
Example:
@ -43,7 +49,7 @@ transform:
Run with:
```bash
batdetect2 predict directory \
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs \
@ -59,6 +65,9 @@ batdetect2 predict directory \
## Related pages
- Outputs config reference: {doc}`../reference/outputs-config`
- Output formats reference: {doc}`../reference/output-formats`
- Output transforms reference: {doc}`../reference/output-transforms`
- Outputs config reference:
{doc}`../reference/outputs-config`
- Output formats reference:
{doc}`../reference/output-formats`
- Output transforms reference:
{doc}`../reference/output-transforms`

View File

@ -4,7 +4,8 @@ Use this guide to compare detection outputs at different threshold values.
The goal is not to find a universal threshold.
The goal is to choose a threshold that fits your reviewed local data and the project trade-off between missed calls and false positives.
The goal is to choose a threshold that fits your reviewed local data and the
project trade-off between missed calls and false positives.
## 1) Start with a baseline run
@ -12,12 +13,12 @@ Run an initial prediction workflow and keep outputs in a dedicated folder.
## 2) Sweep threshold values
Run `predict` multiple times with different thresholds (for example `0.1`,
Run `process` multiple times with different thresholds (for example `0.1`,
`0.3`, `0.5`) and compare output counts and quality on the same validation
subset.
```bash
batdetect2 predict directory \
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs_thr_03 \
@ -26,7 +27,8 @@ batdetect2 predict directory \
Keep each threshold run in a separate output directory.
That makes it easier to compare counts and inspect example files without mixing results.
That makes it easier to compare counts and inspect example files without mixing
results.
## 3) Validate against known calls
@ -38,7 +40,8 @@ Check both:
- obvious false positives,
- obvious missed calls.
If class interpretation matters downstream, inspect class ranking behavior as well, not just detection counts.
If class interpretation matters downstream, inspect class ranking behavior as
well, not just detection counts.
## 4) Record your chosen setting

View File

@ -1,6 +1,7 @@
# How to tune inference clipping
Use this guide when long recordings need to be split into smaller clips during inference.
Use this guide when long recordings need to be split into smaller clips during
inference.
## What clipping controls
@ -8,14 +9,19 @@ Use this guide when long recordings need to be split into smaller clips during i
Key fields are:
- `duration`: clip duration in seconds,
- `overlap`: overlap between adjacent clips,
- `max_empty`: how much empty padding is allowed,
- `discard_empty`: whether empty clips are dropped.
- `duration`:
clip duration in seconds,
- `overlap`:
overlap between adjacent clips,
- `max_empty`:
how much empty padding is allowed,
- `discard_empty`:
whether empty clips are dropped.
## Start from the defaults
Use the built-in clipping behavior first unless you already know you need something else.
Use the built-in clipping behavior first unless you already know you need
something else.
Only tune clipping when:
@ -25,7 +31,7 @@ Only tune clipping when:
## Override clipping with an inference config
Create an inference config file and pass it to `predict` or `evaluate`.
Create an inference config file and pass it to `process` or `evaluate`.
Example:
@ -43,7 +49,7 @@ loader:
Run with:
```bash
batdetect2 predict directory \
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs \
@ -52,12 +58,16 @@ batdetect2 predict directory \
## Validate clipping changes on a small reviewed subset
Changing clipping changes what the model sees per batch and can change how events near clip boundaries behave.
Changing clipping changes what the model sees per batch and can change how
events near clip boundaries behave.
Check a reviewed subset before applying clipping changes to a full project.
## Related pages
- Inference config reference: {doc}`../reference/inference-config`
- Run batch predictions: {doc}`run-batch-predictions`
- Understanding the pipeline: {doc}`../explanation/pipeline-overview`
- Inference config reference:
{doc}`../reference/inference-config`
- Run batch predictions:
{doc}`run-batch-predictions`
- Understanding the pipeline:
{doc}`../explanation/pipeline-overview`

View File

@ -6,25 +6,20 @@ Welcome to the BatDetect2 documentation.
`batdetect2` detects bat echolocation calls in audio recordings.
It can help you screen large collections of recordings,
find files that need expert review,
and support ecology and conservation work where manual review alone would be slow.
It can help you screen large collections of recordings, find files that need
expert review, and support ecology and conservation work where manual review
alone would be slow.
In practice,
BatDetect2 takes recordings,
looks for likely bat calls,
draws a box around each detected event,
and scores the most likely class for that event.
In practice, BatDetect2 takes recordings, looks for likely bat calls, draws a
box around each detected event, and scores the most likely class for that event.
The current default model is trained for 17 UK species.
The library also supports custom training,
fine-tuning,
evaluation,
and more advanced use from Python.
The library also supports custom training, fine-tuning, evaluation, and more
advanced use from Python.
For details on the underlying approach, see the pre-print:
[Towards a General Approach for Bat Echolocation Detection and Classification](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)
[Towards a General Approach for Bat Echolocation Detection and Classification](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)
## A good first use for BatDetect2
@ -56,7 +51,7 @@ Always validate on reviewed local data before using results for ecological infer
```{note}
Looking for the previous BatDetect2 workflow?
See {doc}`legacy/index`.
The legacy docs are still available, but new workflows should use `batdetect2 predict` and `BatDetect2API`.
The legacy docs are still available, but new workflows should use `batdetect2 process` and `BatDetect2API`.
```
## How to use this site
@ -65,8 +60,7 @@ Start with {doc}`getting_started` if you are new.
Then choose the section that matches what you need.
If you are here mainly to run the model on recordings,
start with Tutorials.
If you are here mainly to run the model on recordings, start with Tutorials.
| Section | Best for | Start here |
| --- | --- | --- |
@ -81,7 +75,7 @@ start with Tutorials.
- GitHub repository:
[macaodha/batdetect2](https://github.com/macaodha/batdetect2)
- Questions, bug reports, and feature requests:
[GitHub Issues](https://github.com/macaodha/batdetect2/issues)
[GitHub Issues](https://github.com/macaodha/batdetect2/issues)
- Common questions:
{doc}`faq`
- Want to contribute?

View File

@ -4,7 +4,7 @@ This page documents the previous CLI workflow based on `batdetect2 detect`.
```{warning}
This is legacy documentation.
For new workflows, use `batdetect2 predict directory` instead.
For new workflows, use `batdetect2 process directory` instead.
If you are migrating, start with {doc}`migration-guide`.
```
@ -27,7 +27,7 @@ Common legacy options included:
The closest current CLI entry point is:
```bash
batdetect2 predict directory \
batdetect2 process directory \
path/to/model.ckpt \
path/to/audio_dir \
path/to/outputs
@ -35,5 +35,7 @@ batdetect2 predict directory \
## Related pages
- Migration guide: {doc}`migration-guide`
- Current predict docs: {doc}`../reference/cli/predict`
- Migration guide:
{doc}`migration-guide`
- Current process docs:
{doc}`../reference/cli/predict`

View File

@ -2,12 +2,15 @@
This section documents the previous BatDetect2 workflow.
Use these pages if you need to keep working with the older `batdetect2 detect` command or the older `batdetect2.api` interface.
Use these pages if you need to keep working with the older `batdetect2 detect`
command or the older `batdetect2.api` interface.
For new projects, we recommend the current workflow:
- CLI: `batdetect2 predict`
- Python: `batdetect2.api_v2.BatDetect2API`
- CLI:
`batdetect2 process`
- Python:
`batdetect2.api_v2.BatDetect2API`
If you are moving from the older workflow, start with {doc}`migration-guide`.

View File

@ -1,6 +1,7 @@
# Migration guide: legacy to current workflows
Use this guide when moving from the previous BatDetect2 workflow to the current CLI and API.
Use this guide when moving from the previous BatDetect2 workflow to the current
CLI and API.
## Who should migrate now
@ -9,31 +10,37 @@ You should migrate if:
- you are starting a new workflow,
- you want the current docs path,
- you want the newer CLI and API surface,
- you are maintaining code that does not depend on the exact legacy JSON or feature outputs.
- you are maintaining code that does not depend on the exact legacy JSON or
feature outputs.
You may need the legacy workflow a bit longer if:
- downstream tooling depends on the exact old output structure,
- you rely on older notebooks built around `batdetect2.api`,
- you depend on legacy feature extraction outputs without a validated replacement yet.
- you depend on legacy feature extraction outputs without a validated
replacement yet.
## CLI mapping
- `batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD`
-> `batdetect2 predict directory MODEL_PATH AUDIO_DIR OUTPUT_PATH --detection-threshold ...`
- `batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD` -> `batdetect2
process directory MODEL_PATH AUDIO_DIR OUTPUT_PATH --detection-threshold ...`
Main changes:
- the model path is now a positional argument on the `predict` subcommand,
- the current workflow expects an explicit checkpoint path rather than silently relying on the old default CLI behavior,
- the model path is now a positional argument on the `process` subcommand,
- the current workflow expects an explicit checkpoint path rather than silently
relying on the old default CLI behavior,
- output formatting is configurable,
- threshold override is an option rather than a required positional argument,
- there are separate subcommands for directory, file-list, and dataset-driven inference.
- there are separate subcommands for directory, file-list, and dataset-driven
inference.
## Python API mapping
- old: `import batdetect2.api as api`
- current: `from batdetect2.api_v2 import BatDetect2API`
- old:
`import batdetect2.api as api`
- current:
`from batdetect2.api_v2 import BatDetect2API`
Typical migration shape:
@ -51,7 +58,7 @@ Useful replacements:
- legacy `process_file` -> current `BatDetect2API.process_file`
- legacy `process_audio` -> current `BatDetect2API.process_audio`
- legacy `process_spectrogram` -> current `BatDetect2API.process_spectrogram`
- legacy one-off batch loops -> current `process_files` or CLI `predict`
- legacy one-off batch loops -> current `process_files` or CLI `process`
## Output and terminology changes
@ -78,7 +85,8 @@ Before replacing a legacy workflow in production or research analysis, validate:
- that outputs are being saved in the right format,
- that downstream code reads the new outputs correctly,
- that feature-related assumptions still hold,
- that evaluation and ecological interpretation are unchanged only where you have actually verified that.
- that evaluation and ecological interpretation are unchanged only where you
have actually verified that.
## Migration checklist
@ -91,6 +99,9 @@ Before replacing a legacy workflow in production or research analysis, validate:
## Related pages
- Current getting started: {doc}`../getting_started`
- Current tutorials: {doc}`../tutorials/index`
- Current API reference: {doc}`../reference/api`
- Current getting started:
{doc}`../getting_started`
- Current tutorials:
{doc}`../tutorials/index`
- Current API reference:
{doc}`../reference/api`

View File

@ -1,65 +1,33 @@
# `BatDetect2API` reference
`BatDetect2API` is the main entry point for the current Python workflow.
`BatDetect2API` is the main Python entry point for BatDetect2.
It wraps model loading, inference, evaluation, output formatting, and
training-related entry points behind one object.
Use it when you want to load a model, run prediction, inspect detections,
evaluate results, or train from Python.
Defined in `batdetect2.api_v2`.
## Create an API instance
## Main ways to create it
- `BatDetect2API.from_checkpoint(path, ...)`
- load a trained checkpoint and optional config overrides.
- load a trained checkpoint, a bundled checkpoint alias, or a Hugging Face
checkpoint.
- `BatDetect2API.from_config(model_config=..., targets_config=..., ...)`
- build a full stack from separate config objects.
- build a full model stack from config objects.
## Inference methods
## Common tasks
- `process_file(audio_file, ...)`
- run inference for one recording.
- `process_files(audio_files, ...)`
- run batch inference across a sequence of file paths.
- `process_directory(audio_dir, ...)`
- run inference across the audio files found in one directory.
- `process_clips(clips, ...)`
- run inference on an explicit sequence of clip objects.
- `process_audio(audio, ...)`
- run inference starting from a waveform array.
- `process_spectrogram(spec, ...)`
- run inference starting from a spectrogram tensor.
- Load a checkpoint and run prediction on one file.
- Run prediction on many files or clips.
- Save predictions in one of the supported output formats.
- Evaluate a model on labelled data.
- Fine-tune an existing checkpoint on new targets.
## Prediction inspection helpers
## Generated reference
- `get_top_class_name(detection)`
- return the highest-scoring class name for one detection.
- `get_class_scores(detection, include_top_class=True, sort_descending=True)`
- return ranked `(class_name, score)` pairs.
- `get_detection_features(detection)`
- return the per-detection feature vector.
## Audio loading helpers
- `load_audio(path)`
- `load_recording(recording)`
- `load_clip(clip)`
- `generate_spectrogram(audio)`
## Output persistence helpers
- `save_predictions(predictions, path, audio_dir=None, format=None,
config=None)`
- `load_predictions(path, format=None, config=None)`
Use these when you want to save programmatic predictions without going through
the CLI.
## Training and evaluation entry points
- `train(...)`
- `finetune(...)`
- `evaluate(...)`
- `evaluate_predictions(...)`
```{eval-rst}
.. autoclass:: batdetect2.api_v2.BatDetect2API
```
## Related pages

View File

@ -4,13 +4,13 @@ Legacy detect command
.. warning::
``batdetect2 detect`` is a legacy compatibility command.
Prefer ``batdetect2 predict directory`` for new workflows.
Prefer ``batdetect2 process directory`` for new workflows.
Migration at a glance
---------------------
- Legacy: ``batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD``
- Current: ``batdetect2 predict directory MODEL_PATH AUDIO_DIR OUTPUT_PATH``
- Current: ``batdetect2 process directory MODEL_PATH AUDIO_DIR OUTPUT_PATH``
with optional ``--detection-threshold``
.. click:: batdetect2.cli.compat:detect

View File

@ -1,7 +1,10 @@
Evaluate command
================
Evaluate a checkpoint against a configured test dataset.
Use ``batdetect2 evaluate`` to compare a checkpoint against labelled test data.
This command writes metrics and any configured artifacts to the output
directory.
.. click:: batdetect2.cli.evaluate:evaluate_command
:prog: batdetect2 evaluate

View File

@ -0,0 +1,11 @@
Finetune command
================
Use ``batdetect2 finetune`` to adapt an existing checkpoint to a new target
definition.
If you do not pass ``--model``, the bundled ``uk_same`` checkpoint is used.
.. click:: batdetect2.cli.finetune:finetune_command
:prog: batdetect2 finetune
:nested: none

View File

@ -1,35 +1,33 @@
# CLI reference
Use this section to find the right command quickly, then open the command page
for full options and argument details.
## How to use this section
1. Start with {doc}`base` for options shared across the CLI.
2. Pick the command group or command you need from the command map below.
3. Open the linked page for complete autogenerated option reference.
for the full option list.
## Command map
| Command | Use it for | Required positional args |
| --- | --- | --- |
| `batdetect2 predict` | Run inference on audio | Depends on subcommand (`directory`, `file_list`, `dataset`) |
| `batdetect2 process` | Run inference on audio | Depends on subcommand (`directory`, `file_list`, `dataset`) |
| `batdetect2 data` | Inspect and convert dataset configs | Depends on subcommand (`summary`, `convert`) |
| `batdetect2 train` | Train or fine-tune models | `TRAIN_DATASET` |
| `batdetect2 evaluate` | Evaluate a checkpoint on a test dataset | `MODEL_PATH`, `TEST_DATASET` |
| `batdetect2 finetune` | Fine-tune a checkpoint on new targets | `TRAIN_DATASET` plus `--targets` |
| `batdetect2 evaluate` | Evaluate a checkpoint on a test dataset | `TEST_DATASET` |
| `batdetect2 detect` | Legacy compatibility workflow | `AUDIO_DIR`, `ANN_DIR`, `DETECTION_THRESHOLD` |
## Global options and conventions
## Notes
- Global CLI options are documented in {doc}`base`.
- Paths with spaces should be wrapped in quotes.
- Input audio is expected to be mono.
- Legacy `detect` uses a required threshold argument, while `predict` uses the
optional `--detection-threshold` override.
- `process` uses the optional `--detection-threshold` override.
- `evaluate` takes `TEST_DATASET` as a positional argument and uses `--model`
for the checkpoint override.
- `finetune` defaults to the bundled `uk_same` checkpoint if `--model` is not
provided.
```{warning}
`batdetect2 detect` is a legacy command.
Prefer `batdetect2 predict directory` for new workflows.
Prefer `batdetect2 process directory` for new workflows.
```
## Related pages
@ -43,9 +41,10 @@ Prefer `batdetect2 predict directory` for new workflows.
:maxdepth: 1
Base command and global options <base>
Predict command group <predict>
Process command group <predict>
Data command group <data>
Train command <train>
Finetune command <finetune>
Evaluate command <evaluate>
Legacy detect command <detect_legacy>
```

View File

@ -1,9 +1,17 @@
Predict command
Process command
===============
Run model inference from a directory, a file list, or a dataset.
Use ``--detection-threshold`` to override the model default per run.
Use ``batdetect2 process`` to run inference on audio.
.. click:: batdetect2.cli.inference:predict
:prog: batdetect2 predict
Choose a subcommand based on how you want to provide the input:
- ``directory`` for all supported audio files in one folder
- ``file_list`` for a text file with one audio path per line
- ``dataset`` for recordings referenced by a dataset file
Use ``--detection-threshold`` when you want to override the configured
threshold for one run.
.. click:: batdetect2.cli.inference:process
:prog: batdetect2 process
:nested: full

View File

@ -1,7 +1,11 @@
Train command
=============
Train a model from dataset configs or fine-tune from a checkpoint.
Use ``batdetect2 train`` to start from a fresh model config or continue from an
existing checkpoint.
If you want to adapt an existing checkpoint to a new target definition, use
``batdetect2 finetune`` instead.
.. click:: batdetect2.cli.train:train_command
:prog: batdetect2 train

View File

@ -3,7 +3,8 @@
This tutorial shows how to evaluate a trained checkpoint on a held-out dataset
and inspect the output metrics.
This tutorial is for advanced users who want to compare one trained model against a separate test dataset.
This tutorial is for advanced users who want to compare one trained model
against a separate test dataset.
## Before you start
@ -32,22 +33,22 @@ Use a dataset that was not used for training or tuning.
A held-out dataset is simply a separate dataset kept aside for evaluation.
If you tune thresholds or configs on the same dataset that you report as final evaluation, the results will be optimistic.
If you tune thresholds or configs on the same dataset that you report as final
evaluation, the results will be optimistic.
## 2. Run evaluation
```bash
batdetect2 evaluate \
path/to/model.ckpt \
path/to/test_dataset.yaml \
--model path/to/model.ckpt \
--base-dir path/to/project_root \
--output-dir path/to/eval_outputs
```
This command loads the checkpoint,
runs prediction on the test dataset,
applies the chosen evaluation tasks,
and writes metrics and result files to the output directory.
This command loads the checkpoint, runs prediction on the test dataset, applies
the chosen evaluation tasks, and writes metrics and result files to the output
directory.
Use `--base-dir` whenever the dataset config contains relative paths.
@ -73,7 +74,8 @@ Check:
- which task the metric belongs to,
- which thresholding or matching assumptions were used,
- whether class-level behavior matches your use case,
- whether the failures are concentrated in specific taxa, sites, or recording conditions.
- whether the failures are concentrated in specific taxa, sites, or recording
conditions.
## 5. Record the evaluation setup
@ -85,7 +87,11 @@ That matters for reproducibility and for later model comparisons.
- Compare thresholds on representative files:
{doc}`../how_to/tune-detection-threshold`
- Configure evaluation tasks: {doc}`../how_to/choose-and-configure-evaluation-tasks`
- Interpret evaluation artifacts: {doc}`../how_to/interpret-evaluation-outputs`
- Learn the evaluation concepts: {doc}`../explanation/evaluation-concepts-and-matching`
- Check full evaluate options: {doc}`../reference/cli/evaluate`
- Configure evaluation tasks:
{doc}`../how_to/choose-and-configure-evaluation-tasks`
- Interpret evaluation artifacts:
{doc}`../how_to/interpret-evaluation-outputs`
- Learn the evaluation concepts:
{doc}`../explanation/evaluation-concepts-and-matching`
- Check full evaluate options:
{doc}`../reference/cli/evaluate`

View File

@ -4,7 +4,8 @@ This tutorial walks through a first end-to-end inference run with the CLI.
It is the default starting point for new users.
Use it when you want to run an existing model on a folder of recordings and quickly check what BatDetect2 found.
Use it when you want to run an existing model on a folder of recordings and
quickly check what BatDetect2 found.
## Before you start
@ -24,7 +25,7 @@ src/batdetect2/models/checkpoints/Net2DFast_UK_same.pth.tar
By the end of this tutorial you will have:
- run `batdetect2 predict directory`,
- run `batdetect2 process directory`,
- saved predictions to disk,
- checked that BatDetect2 wrote output files,
- identified the next pages to use for tuning or customization.
@ -48,12 +49,13 @@ project/
outputs/
```
## 2. Run prediction on the directory
## 2. Run processing on the directory
Use this command when you want BatDetect2 to scan a folder of recordings automatically.
Use this command when you want BatDetect2 to scan a folder of recordings
automatically.
```bash
batdetect2 predict directory \
batdetect2 process directory \
path/to/model.pth.tar \
path/to/audio_dir \
path/to/outputs
@ -70,8 +72,7 @@ What this does:
After the command completes, inspect the output directory.
For a first run,
the important check is simple:
For a first run, the important check is simple:
- did BatDetect2 create result files,
- are they in the output directory you expected,
@ -81,8 +82,8 @@ Different workflows can save results in different file formats.
You do not need to learn those details for the first run.
If you later need to choose a specific output format,
go to {doc}`../how_to/save-predictions-in-different-output-formats`.
If you later need to choose a specific output format, go to
{doc}`../how_to/save-predictions-in-different-output-formats`.
## 4. Inspect predictions
@ -103,13 +104,17 @@ Validation comes next.
## 5. Tune only after you have a baseline
If the first run is too noisy or misses obvious calls, tune thresholds on a reviewed subset rather than changing settings blindly across the full dataset.
If the first run is too noisy or misses obvious calls, tune thresholds on a
reviewed subset rather than changing settings blindly across the full dataset.
Use {doc}`../how_to/tune-detection-threshold` for that process.
## What to do next
- If you need a different input mode, use {doc}`../how_to/choose-an-inference-input-mode`.
- If you want to tune sensitivity, use {doc}`../how_to/tune-detection-threshold`.
- If you already write code and want more control from Python, use {doc}`integrate-with-a-python-pipeline`.
- If you need a different input mode, use
{doc}`../how_to/choose-an-inference-input-mode`.
- If you want to tune sensitivity, use
{doc}`../how_to/tune-detection-threshold`.
- If you already write code and want more control from Python, use
{doc}`integrate-with-a-python-pipeline`.
- If you need full command details, use {doc}`../reference/cli/predict`.

View File

@ -17,6 +17,10 @@ help:
install:
uv sync
# Install full development dependencies for CI and docs builds.
install-dev:
uv sync --all-extras --dev
# Testing & Coverage
# Run tests using pytest.
test:
@ -50,6 +54,9 @@ coverage-serve: coverage-html
docs:
uv run sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
# Check that documentation builds successfully.
check-docs: docs
# Serve documentation with live reload.
docs-serve:
uv run sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
@ -84,6 +91,25 @@ check-types:
# Run all checks (format-check, lint, typecheck).
check: check-format check-lint check-types
# Run the standard CI validation sequence.
ci: check test
# Build source and wheel distributions.
build-dist:
uv run --with build python -m build
# Bump the patch version, commit, and tag.
bump-patch:
uvx bump2version patch
# Bump the minor version, commit, and tag.
bump-minor:
uvx bump2version minor
# Bump the major version, commit, and tag.
bump-major:
uvx bump2version major
# Cleaning tasks
# Remove Python bytecode and cache.
clean-pyc:

View File

@ -7,7 +7,6 @@ authors = [
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
]
dependencies = [
"cf-xarray>=0.9.0",
"click>=8.1.7",
"deepmerge>=2.0",
"hydra-core>=1.3.2",
@ -16,21 +15,19 @@ dependencies = [
"loguru>=0.7.3",
"matplotlib>=3.7.1",
"netcdf4>=1.6.5",
"numba>=0.60",
"numpy>=1.23.5",
"omegaconf>=2.3.0",
"onnx>=1.16.0",
"pandas>=1.5.3",
"pydantic>=2.0.0",
"pyyaml>=6.0.2",
"scikit-learn>=1.2.2",
"scipy>=1.10.1",
"seaborn>=0.13.2",
"soundevent[audio,geometry,plot]>=2.10.0",
"soundfile>=0.12.1",
"tensorboard>=2.16.2",
"torch>=1.13.1",
"torchaudio>=1.13.1",
"torchvision>=0.14.0",
"tqdm>=4.66.2",
"xarray>=2024.0.0",
]
requires-python = ">=3.10,<3.14"
readme = "README.md"
@ -66,6 +63,7 @@ build-backend = "hatchling.build"
batdetect2 = "batdetect2.cli:cli"
[dependency-groups]
huggingface = ["huggingface-hub>=0.32.0"]
jupyter = ["ipywidgets>=8.1.5", "jupyter>=1.1.1"]
marimo = ["marimo>=0.12.2", "pyarrow>=20.0.0"]
dev = [

View File

@ -1,11 +1,25 @@
import logging
from typing import TYPE_CHECKING
from loguru import logger
if TYPE_CHECKING:
from batdetect2.api_v2 import BatDetect2API
logger.disable("batdetect2")
numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.WARNING)
__all__ = ["BatDetect2API", "__version__"]
__version__ = "1.1.1"
def __getattr__(name: str):
if name == "BatDetect2API":
from batdetect2.api_v2 import BatDetect2API
return BatDetect2API
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@ -3,13 +3,12 @@ from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING, Literal
import numpy as np
from soundevent import data
if TYPE_CHECKING:
from collections.abc import Sequence
import numpy as np
import torch
from soundevent import data
from batdetect2.audio import AudioConfig, AudioLoader
from batdetect2.data import Dataset
@ -20,7 +19,8 @@ if TYPE_CHECKING:
LoggerConfig,
LoggingCallback,
)
from batdetect2.models import Model, ModelConfig
from batdetect2.models import ModelConfig
from batdetect2.models.types import ModelProtocol
from batdetect2.outputs import (
OutputFormatConfig,
OutputFormatterProtocol,
@ -48,6 +48,31 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
class BatDetect2API:
"""High-level interface for the BatDetect2 workflow.
Use this to load a model, run inference, inspect detections,
evaluate predictions, and train or fine-tune models.
In most cases, start with :meth:`from_checkpoint` to load a trained model.
Use :meth:`from_config` when you want to build a new model with custom
configs.
Examples
--------
Load the default checkpoint and run prediction on one file.
>>> from batdetect2.api_v2 import BatDetect2API
>>> api = BatDetect2API.from_checkpoint()
>>> prediction = api.process_file("recording.wav")
Load a checkpoint and save predictions for a folder of audio.
>>> from pathlib import Path
>>> api = BatDetect2API.from_checkpoint("uk_same")
>>> predictions = api.process_directory("audio")
>>> api.save_predictions(predictions, "outputs/")
"""
def __init__(
self,
model_config: ModelConfig,
@ -65,8 +90,49 @@ class BatDetect2API:
evaluator: EvaluatorProtocol,
formatter: OutputFormatterProtocol,
output_transform: OutputTransformProtocol,
model: Model,
model: ModelProtocol,
):
"""Create a fully configured API instance.
This initializer is mainly for internal use.
In most cases, users should create the API with
:meth:`from_checkpoint` or :meth:`from_config`.
Parameters
----------
model_config : ModelConfig
Model configuration.
audio_config : AudioConfig
Audio loading configuration.
train_config : TrainingConfig
Training configuration.
evaluation_config : EvaluationConfig
Evaluation configuration.
inference_config : InferenceConfig
Inference configuration.
outputs_config : OutputsConfig
Output formatting configuration.
logging_config : AppLoggingConfig
Logging configuration.
targets : TargetProtocol
Target definition used by the model.
roi_mapper : ROIMapperProtocol
ROI mapping used for size targets.
audio_loader : AudioLoader
Audio loader.
preprocessor : PreprocessorProtocol
Preprocessor used before the detector.
postprocessor : PostprocessorProtocol
Postprocessor used after the detector.
evaluator : EvaluatorProtocol
Evaluator used for metrics.
formatter : OutputFormatterProtocol
Default formatter used to save predictions.
output_transform : OutputTransformProtocol
Transform that converts model outputs into detections.
model : ModelProtocol
Model instance.
"""
self.model_config = model_config
self.audio_config = audio_config
self.train_config = train_config
@ -91,6 +157,21 @@ class BatDetect2API:
path: data.PathLike,
base_dir: data.PathLike | None = None,
) -> Dataset:
"""Load a set of annotations from a dataset config file.
Parameters
----------
path : data.PathLike
Path to the dataset config file.
base_dir : data.PathLike | None, optional
Base directory used to resolve relative paths in the dataset
config.
Returns
-------
Dataset
Loaded dataset of annotations.
"""
from batdetect2.data import load_dataset_from_config
return load_dataset_from_config(path, base_dir=base_dir)
@ -107,12 +188,50 @@ class BatDetect2API:
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
):
"""Train the current model on a set of annotations.
Parameters
----------
train_annotations : Sequence[data.ClipAnnotation]
Training annotations.
val_annotations : Sequence[data.ClipAnnotation] | None, optional
Validation annotations. If omitted, training runs without a
validation set.
train_workers : int, optional
Number of worker processes for training data loading.
val_workers : int, optional
Number of worker processes for validation data loading.
checkpoint_dir : Path | None, optional
Directory where checkpoints are saved.
log_dir : Path | None, optional
Directory where logs are written.
experiment_name : str | None, optional
Experiment name used by the configured logger.
num_epochs : int | None, optional
Maximum number of training epochs.
run_name : str | None, optional
Run name used by the configured logger.
seed : int | None, optional
Random seed for reproducibility.
audio_config : AudioConfig | None, optional
Audio config override.
train_config : TrainingConfig | None, optional
Training config override.
logger_config : LoggerConfig | None, optional
Training logger config override.
logging_callbacks : Sequence[LoggingCallback[TrainLoggingContext]], optional
Extra logging callbacks to run during training setup.
Returns
-------
BatDetect2API
This API instance with the trained model.
"""
from batdetect2.train import run_train
self.model.train()
@ -122,7 +241,6 @@ class BatDetect2API:
model=self.model,
targets=self.targets,
roi_mapper=self.roi_mapper,
model_config=model_config or self.model_config,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
train_workers=train_workers,
@ -147,7 +265,7 @@ class BatDetect2API:
targets_config: TargetConfig,
val_annotations: Sequence[data.ClipAnnotation] | None = None,
trainable: Literal[
"all", "heads", "classifier_head", "bbox_head"
"all", "heads", "classifier_head", "size_head"
] = "heads",
train_workers: int = 0,
val_workers: int = 0,
@ -162,7 +280,52 @@ class BatDetect2API:
logger_config: LoggerConfig | None = None,
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
) -> "BatDetect2API":
"""Fine-tune from a checkpoint using a new target definition."""
"""Fine-tune the current model for new target sounds.
Use this when you want to keep the existing model weights but change
the target sounds. You can fine-tune the whole model or just the
heads.
Parameters
----------
train_annotations : Sequence[data.ClipAnnotation]
Training annotations.
targets_config : TargetConfig
Target definition to train against.
val_annotations : Sequence[data.ClipAnnotation] | None, optional
Validation annotations.
trainable : {"all", "heads", "classifier_head", "size_head"}, optional
Which model parameters remain trainable.
train_workers : int, optional
Number of worker processes for training data loading.
val_workers : int, optional
Number of worker processes for validation data loading.
checkpoint_dir : Path | None, optional
Directory where checkpoints are saved.
log_dir : Path | None, optional
Directory where logs are written.
experiment_name : str | None, optional
Experiment name used by the configured logger.
num_epochs : int | None, optional
Maximum number of training epochs.
run_name : str | None, optional
Run name used by the configured logger.
seed : int | None, optional
Random seed for reproducibility.
audio_config : AudioConfig | None, optional
Audio config override.
train_config : TrainingConfig | None, optional
Training config override.
logger_config : LoggerConfig | None, optional
Training logger config override.
logging_callbacks : Sequence[LoggingCallback[TrainLoggingContext]], optional
Extra logging callbacks to run during training setup.
Returns
-------
BatDetect2API
A new API instance configured for the new targets.
"""
from batdetect2.evaluate import build_evaluator
from batdetect2.models import build_model_with_new_targets
from batdetect2.outputs import (
@ -225,7 +388,6 @@ class BatDetect2API:
model=api.model,
targets=api.targets,
roi_mapper=api.roi_mapper,
model_config=api.model_config,
preprocessor=api.preprocessor,
audio_loader=api.audio_loader,
train_workers=train_workers,
@ -257,6 +419,36 @@ class BatDetect2API:
outputs_config: OutputsConfig | None = None,
logger_config: LoggerConfig | None = None,
) -> tuple[dict[str, float], list[ClipDetections]]:
"""Evaluate the current model on a labelled dataset.
Parameters
----------
test_annotations : Sequence[data.ClipAnnotation]
Labelled clips used for evaluation.
num_workers : int, optional
Number of worker processes for dataset loading.
output_dir : data.PathLike, optional
Directory where metrics and plots are written.
experiment_name : str | None, optional
Experiment name used by the configured logger.
run_name : str | None, optional
Run name used by the configured logger.
save_predictions : bool, optional
If ``True``, save formatted predictions alongside metrics.
audio_config : AudioConfig | None, optional
Audio config override.
evaluation_config : EvaluationConfig | None, optional
Evaluation config override.
outputs_config : OutputsConfig | None, optional
Output config override.
logger_config : LoggerConfig | None, optional
Evaluation logger config override.
Returns
-------
tuple[dict[str, float], list[ClipDetections]]
Evaluation metrics and per-clip predictions.
"""
from batdetect2.evaluate import run_evaluate
return run_evaluate(
@ -283,6 +475,22 @@ class BatDetect2API:
predictions: Sequence[ClipDetections],
output_dir: data.PathLike | None = None,
):
"""Evaluate an existing set of predictions.
Parameters
----------
annotations : Sequence[data.ClipAnnotation]
Reference annotations.
predictions : Sequence[ClipDetections]
Predictions to compare against the annotations.
output_dir : data.PathLike | None, optional
Directory where metrics and plots are written.
Returns
-------
dict[str, float]
Computed evaluation metrics.
"""
from batdetect2.evaluate import save_evaluation_results
clip_evals = self.evaluator.evaluate(
@ -302,16 +510,65 @@ class BatDetect2API:
return metrics
def load_audio(self, path: data.PathLike) -> np.ndarray:
"""Load one audio file into a waveform array.
Parameters
----------
path : data.PathLike
Path to the audio file.
Returns
-------
np.ndarray
Audio waveform loaded from disk.
"""
return self.audio_loader.load_file(path)
def load_recording(self, recording: data.Recording) -> np.ndarray:
"""Load one recording object into a waveform array.
Parameters
----------
recording : data.Recording
Recording object describing the audio to load.
Returns
-------
np.ndarray
Audio waveform for the requested recording.
"""
return self.audio_loader.load_recording(recording)
def load_clip(self, clip: data.Clip) -> np.ndarray:
"""Load one clip object into a waveform array.
Parameters
----------
clip : data.Clip
Clip object describing the section of audio to load.
Returns
-------
np.ndarray
Audio waveform for the requested clip.
"""
return self.audio_loader.load_clip(clip)
def get_top_class_name(self, detection: Detection) -> str:
"""Get highest-confidence class name for one detection."""
"""Get the name of the highest-confidence class for one detection.
Parameters
----------
detection : Detection
Detection whose class scores will be inspected.
Returns
-------
str
Class name with the highest score.
"""
import numpy as np
top_index = int(np.argmax(detection.class_scores))
return self.targets.class_names[top_index]
@ -323,7 +580,22 @@ class BatDetect2API:
include_top_class: bool = True,
sort_descending: bool = True,
) -> list[tuple[str, float]]:
"""Get class score list as ``(class_name, score)`` pairs."""
"""Get class scores as ``(class_name, score)`` pairs.
Parameters
----------
detection : Detection
Detection whose class scores will be returned.
include_top_class : bool, optional
If ``False``, omit the highest-scoring class from the result.
sort_descending : bool, optional
If ``True``, sort scores from highest to lowest.
Returns
-------
list[tuple[str, float]]
Class-score pairs for the detection.
"""
scores = [
(class_name, float(score))
@ -347,16 +619,22 @@ class BatDetect2API:
if class_name != top_class_name
]
@staticmethod
def get_detection_features(detection: Detection) -> np.ndarray:
"""Get extracted feature vector for one detection."""
return detection.features
def generate_spectrogram(
self,
audio: np.ndarray,
) -> torch.Tensor:
"""Convert a waveform array into a spectrogram tensor.
Parameters
----------
audio : np.ndarray
Audio waveform.
Returns
-------
torch.Tensor
Spectrogram tensor ready for model inference.
"""
import torch
tensor = torch.tensor(audio).unsqueeze(0)
@ -368,6 +646,25 @@ class BatDetect2API:
batch_size: int | None = None,
detection_threshold: float | None = None,
) -> ClipDetections:
"""Run inference on one audio file.
Parameters
----------
audio_file : data.PathLike
Path to the audio file.
batch_size : int | None, optional
Batch size override. If omitted, the inference config value is
used.
detection_threshold : float | None, optional
Detection score threshold override.
Returns
-------
ClipDetections
Predictions for the full recording.
"""
from soundevent import data
from batdetect2.postprocess import ClipDetections
recording = data.Recording.from_file(audio_file, compute_hash=False)
@ -402,6 +699,20 @@ class BatDetect2API:
audio: np.ndarray,
detection_threshold: float | None = None,
) -> list[Detection]:
"""Run inference on a waveform array.
Parameters
----------
audio : np.ndarray
Audio waveform.
detection_threshold : float | None, optional
Detection score threshold override.
Returns
-------
list[Detection]
Detected calls.
"""
spec = self.generate_spectrogram(audio)
return self.process_spectrogram(
spec,
@ -414,6 +725,27 @@ class BatDetect2API:
start_time: float = 0,
detection_threshold: float | None = None,
) -> list[Detection]:
"""Run inference on one spectrogram tensor.
Parameters
----------
spec : torch.Tensor
Spectrogram tensor for one recording or clip.
start_time : float, optional
Start time in seconds used when creating detections.
detection_threshold : float | None, optional
Detection score threshold override.
Returns
-------
list[Detection]
Detected calls.
Raises
------
ValueError
If a batched spectrogram with more than one item is provided.
"""
if spec.ndim == 4 and spec.shape[0] > 1:
raise ValueError("Batched spectrograms not supported.")
@ -436,6 +768,20 @@ class BatDetect2API:
audio_dir: data.PathLike,
detection_threshold: float | None = None,
) -> list[ClipDetections]:
"""Run inference on all supported audio files in a directory.
Parameters
----------
audio_dir : data.PathLike
Directory containing audio files.
detection_threshold : float | None, optional
Detection score threshold override.
Returns
-------
list[ClipDetections]
Predictions for all supported audio files found in the directory.
"""
from soundevent.audio.files import get_audio_files
files = list(get_audio_files(audio_dir))
@ -454,6 +800,30 @@ class BatDetect2API:
output_config: OutputsConfig | None = None,
detection_threshold: float | None = None,
) -> list[ClipDetections]:
"""Run inference on multiple audio files.
Parameters
----------
audio_files : Sequence[data.PathLike]
Audio file paths.
batch_size : int | None, optional
Batch size override.
num_workers : int, optional
Number of worker processes for audio loading.
audio_config : AudioConfig | None, optional
Audio config override.
inference_config : InferenceConfig | None, optional
Inference config override.
output_config : OutputsConfig | None, optional
Output config override.
detection_threshold : float | None, optional
Detection score threshold override.
Returns
-------
list[ClipDetections]
Predictions for each input file.
"""
from batdetect2.inference import process_file_list
return process_file_list(
@ -482,6 +852,30 @@ class BatDetect2API:
output_config: OutputsConfig | None = None,
detection_threshold: float | None = None,
) -> list[ClipDetections]:
"""Run inference on multiple clip objects.
Parameters
----------
clips : Sequence[data.Clip]
Clips to process.
batch_size : int | None, optional
Batch size override.
num_workers : int, optional
Number of worker processes for audio loading.
audio_config : AudioConfig | None, optional
Audio config override.
inference_config : InferenceConfig | None, optional
Inference config override.
output_config : OutputsConfig | None, optional
Output config override.
detection_threshold : float | None, optional
Detection score threshold override.
Returns
-------
list[ClipDetections]
Predictions for each input clip.
"""
from batdetect2.inference import run_batch_inference
return run_batch_inference(
@ -508,6 +902,21 @@ class BatDetect2API:
format: str | None = None,
config: OutputFormatConfig | None = None,
):
"""Save predictions to disk in one of the supported output formats.
Parameters
----------
predictions : Sequence[ClipDetections]
Predictions to save.
path : data.PathLike
Output file or directory path, depending on the selected format.
audio_dir : data.PathLike | None, optional
Audio root directory used when writing relative paths.
format : str | None, optional
Output format name override.
config : OutputFormatConfig | None, optional
Output format config override.
"""
from batdetect2.outputs import get_output_formatter
formatter = self.formatter
@ -529,6 +938,22 @@ class BatDetect2API:
format: str | None = None,
config: OutputFormatConfig | None = None,
) -> list[object]:
"""Load predictions from disk.
Parameters
----------
path : data.PathLike
Path to a saved prediction file or directory.
format : str | None, optional
Output format name override.
config : OutputFormatConfig | None, optional
Output format config override.
Returns
-------
list[object]
Loaded prediction objects returned by the selected formatter.
"""
from batdetect2.outputs import get_output_formatter
formatter = self.formatter
@ -555,6 +980,36 @@ class BatDetect2API:
outputs_config: OutputsConfig | None = None,
logging_config: AppLoggingConfig | None = None,
) -> "BatDetect2API":
"""Build an API instance from config objects.
Use this when you want to create a new model without loading a saved
checkpoint.
Parameters
----------
model_config : ModelConfig | None, optional
Model config. If omitted, the default model config is used.
targets_config : TargetConfig | None, optional
Target config. If omitted, the default target config is used.
audio_config : AudioConfig | None, optional
Audio config. If omitted, the default audio config is used.
train_config : TrainingConfig | None, optional
Training config. If omitted, the default training config is used.
evaluation_config : EvaluationConfig | None, optional
Evaluation config. If omitted, the default evaluation config is
used.
inference_config : InferenceConfig | None, optional
Inference config. If omitted, the default inference config is used.
outputs_config : OutputsConfig | None, optional
Output config. If omitted, the default outputs config is used.
logging_config : AppLoggingConfig | None, optional
Logging config. If omitted, the default logging config is used.
Returns
-------
BatDetect2API
Configured API instance.
"""
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.evaluate import EvaluationConfig, build_evaluator
from batdetect2.inference import InferenceConfig
@ -653,7 +1108,7 @@ class BatDetect2API:
@classmethod
def from_checkpoint(
cls,
path: data.PathLike,
path: data.PathLike | str | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
@ -661,6 +1116,31 @@ class BatDetect2API:
outputs_config: OutputsConfig | None = None,
logging_config: AppLoggingConfig | None = None,
) -> "BatDetect2API":
"""Build an API instance from a saved checkpoint.
Parameters
----------
path : data.PathLike | str | None, optional
Checkpoint path, bundled checkpoint alias, or Hugging Face URI.
If omitted, the default bundled checkpoint is used.
audio_config : AudioConfig | None, optional
Audio config override.
train_config : TrainingConfig | None, optional
Training config override.
evaluation_config : EvaluationConfig | None, optional
Evaluation config override.
inference_config : InferenceConfig | None, optional
Inference config override.
outputs_config : OutputsConfig | None, optional
Output config override.
logging_config : AppLoggingConfig | None, optional
Logging config override.
Returns
-------
BatDetect2API
Configured API instance.
"""
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.evaluate import EvaluationConfig, build_evaluator
from batdetect2.inference import InferenceConfig
@ -759,7 +1239,7 @@ class BatDetect2API:
def _set_trainable_parameters(
self,
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
trainable: Literal["all", "heads", "classifier_head", "size_head"],
) -> None:
detector = self.model.detector
@ -775,6 +1255,6 @@ class BatDetect2API:
for parameter in detector.classifier_head.parameters():
parameter.requires_grad = True
if trainable in {"heads", "bbox_head"}:
for parameter in detector.bbox_head.parameters():
if trainable in {"heads", "size_head"}:
for parameter in detector.size_head.parameters():
parameter.requires_grad = True

View File

@ -3,7 +3,7 @@ from batdetect2.cli.compat import detect
from batdetect2.cli.data import data
from batdetect2.cli.evaluate import evaluate_command
from batdetect2.cli.finetune import finetune_command
from batdetect2.cli.inference import predict
from batdetect2.cli.inference import process
from batdetect2.cli.train import train_command
__all__ = [
@ -13,7 +13,7 @@ __all__ = [
"train_command",
"finetune_command",
"evaluate_command",
"predict",
"process",
]

View File

@ -2,35 +2,39 @@
import click
from batdetect2.cli.ascii import BATDETECT_ASCII_ART
__all__ = [
"cli",
]
INFO_STR = """
BatDetect2 - Detection and Classification
Assumes audio files are mono, not stereo.
Spaces in the input paths will throw an error. Wrap in quotes.
Input files should be short in duration e.g. < 30 seconds.
BatDetect2
Wrap paths that contain spaces in quotes.
"""
@click.group()
@click.group(invoke_without_command=True)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def cli(verbose: int = 0):
@click.pass_context
def cli(ctx: click.Context, verbose: int = 0):
"""Run the BatDetect2 CLI.
This command initializes logging and exposes subcommands for prediction,
training, evaluation, and dataset utilities.
Use subcommands to run processing, training, evaluation, and dataset
utilities.
"""
click.echo(INFO_STR)
if ctx.invoked_subcommand is None:
click.echo(BATDETECT_ASCII_ART)
click.echo(ctx.get_help())
ctx.exit()
from batdetect2.logging import enable_logging
enable_logging(verbose)
# click.echo(BATDETECT_ASCII_ART)

View File

@ -15,7 +15,7 @@ DEFAULT_MODEL_PATH = os.path.join(
@cli.command(
short_help="Legacy detection command.",
epilog=(
"Deprecated workflow. Prefer `batdetect2 predict directory` for "
"Deprecated workflow. Prefer `batdetect2 process directory` for "
"new analyses."
),
)
@ -91,11 +91,17 @@ def detect(
Note
----
This command is kept for backwards compatibility. Prefer
`batdetect2 predict directory` for new workflows.
`batdetect2 process directory` for new workflows.
"""
from batdetect2 import api
from batdetect2.utils.detector_utils import save_results_to_file
message = (
"The `batdetect2 detect` command is deprecated. Prefer "
"`batdetect2 process directory` for new analyses."
)
click.secho(f"WARNING: {message}", fg="yellow", err=True)
click.echo(f"Loading model: {args['model_path']}")
model, params = api.load_model(args["model_path"])

View File

@ -7,9 +7,9 @@ from batdetect2.cli.base import cli
__all__ = ["data"]
@cli.group(short_help="Inspect and convert datasets.")
@cli.group(short_help="Inspect and manage datasets.")
def data():
"""Inspect and convert dataset configuration files."""
"""Inspect and manage dataset configuration files."""
@data.command(short_help="Print dataset summary information.")
@ -64,7 +64,7 @@ def summary(
base_dir=base_dir,
)
print(f"Number of annotated clips: {len(dataset)}")
click.echo(f"Number of annotated clips: {len(dataset)}")
if targets_path is None:
return
@ -73,7 +73,7 @@ def summary(
summary = compute_class_summary(dataset, targets)
print(summary.sort_values("class_name").to_markdown())
click.echo(summary.sort_values("class_name").to_markdown())
@data.command(short_help="Convert dataset config to annotation set.")
@ -200,6 +200,6 @@ def convert(
if not audio_dir.is_absolute():
audio_dir = audio_dir.resolve()
print(f"Using audio directory: {audio_dir}")
click.echo(f"Using audio directory: {audio_dir}")
io.save(annotation_set, output, audio_dir=audio_dir)

View File

@ -12,38 +12,40 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@cli.command(name="evaluate", short_help="Evaluate a model checkpoint.")
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option(
"--targets",
"targets_config",
type=click.Path(exists=True),
help="Path to targets config file.",
"--model",
"model_path",
type=str,
help=(
"Path to a checkpoint, checkpoint alias, or a Hugging Face "
"URI to fine-tune from. Defaults to uk_same"
),
)
@click.option(
"--audio-config",
type=click.Path(exists=True),
help="Path to audio config file.",
help="Path to an audio config file.",
)
@click.option(
"--evaluation-config",
type=click.Path(exists=True),
help="Path to evaluation config file.",
help="Path to an evaluation config file.",
)
@click.option(
"--inference-config",
type=click.Path(exists=True),
help="Path to inference config file.",
help="Path to an inference config file.",
)
@click.option(
"--outputs-config",
type=click.Path(exists=True),
help="Path to outputs config file.",
help="Path to an outputs config file.",
)
@click.option(
"--logging-config",
type=click.Path(exists=True),
help="Path to logging config file.",
help="Path to a logging config file.",
)
@click.option(
"--base-dir",
@ -80,24 +82,23 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
default=0,
)
def evaluate_command(
model_path: Path,
test_dataset: Path,
base_dir: Path,
targets_config: Path | None,
audio_config: Path | None,
evaluation_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
model_path: str | None = None,
base_dir: Path | None = None,
audio_config: Path | None = None,
evaluation_config: Path | None = None,
inference_config: Path | None = None,
outputs_config: Path | None = None,
logging_config: Path | None = None,
output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: int = 0,
experiment_name: str | None = None,
run_name: str | None = None,
):
"""Evaluate a checkpoint against a test dataset.
"""Evaluate a checkpoint on a labelled test dataset.
Loads model and optional override configs, runs evaluation on
`test_dataset`, and writes metrics/artifacts to `output_dir`.
This command loads a checkpoint, runs evaluation on ``test_dataset``, and
writes metrics to ``output_dir``.
"""
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig

View File

@ -1,5 +1,5 @@
from pathlib import Path
from typing import Literal, cast
from typing import Literal
import click
from loguru import logger
@ -13,13 +13,6 @@ __all__ = ["finetune_command"]
name="finetune", short_help="Fine-tune a checkpoint on new targets."
)
@click.argument("train_dataset", type=click.Path(exists=True))
@click.option(
"--model",
"model_path",
required=True,
type=click.Path(exists=True),
help="Path to a checkpoint to fine-tune from.",
)
@click.option(
"--targets",
"targets_config",
@ -27,6 +20,15 @@ __all__ = ["finetune_command"]
type=click.Path(exists=True),
help="Path to the new targets config file.",
)
@click.option(
"--model",
"model_path",
type=str,
help=(
"Path to a checkpoint, checkpoint alias, or a Hugging Face "
"URI to fine-tune from. Defaults to uk_same"
),
)
@click.option(
"--val-dataset",
type=click.Path(exists=True),
@ -57,7 +59,7 @@ __all__ = ["finetune_command"]
)
@click.option(
"--trainable",
type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
type=click.Choice(["all", "heads", "classifier_head", "size_head"]),
default="heads",
show_default=True,
help="Which model parameters remain trainable during fine-tuning.",
@ -106,8 +108,8 @@ __all__ = ["finetune_command"]
)
def finetune_command(
train_dataset: Path,
model_path: Path,
targets_config: Path,
model_path: str | None = None,
val_dataset: Path | None = None,
ckpt_dir: Path | None = None,
log_dir: Path | None = None,
@ -115,7 +117,9 @@ def finetune_command(
training_config: Path | None = None,
audio_config: Path | None = None,
logging_config: Path | None = None,
trainable: str = "heads",
trainable: Literal[
"all", "heads", "classifier_head", "size_head"
] = "heads",
seed: int | None = None,
num_epochs: int | None = None,
train_workers: int = 0,
@ -192,10 +196,7 @@ def finetune_command(
train_annotations=train_annotations,
val_annotations=val_annotations,
targets_config=target_conf,
trainable=cast(
Literal["all", "heads", "classifier_head", "bbox_head"],
trainable,
),
trainable=trainable,
train_workers=train_workers,
val_workers=val_workers,
checkpoint_dir=ckpt_dir,

View File

@ -13,27 +13,26 @@ if TYPE_CHECKING:
from batdetect2.inference import InferenceConfig
from batdetect2.outputs import OutputsConfig
__all__ = ["predict"]
__all__ = ["process"]
@cli.group(name="predict", short_help="Run prediction workflows.")
def predict() -> None:
"""Run model inference on audio files.
@cli.group(name="process", short_help="Run processing workflows.")
def process() -> None:
"""Run model inference on audio.
Use one of the subcommands to select inputs from a directory, a text file
list, or an annotation dataset.
Choose a subcommand based on how you want to provide input audio.
"""
def common_predict_options(func):
"""Attach options shared by all `predict` subcommands."""
"""Attach options shared by all ``process`` subcommands."""
@click.option(
"--audio-config",
type=click.Path(exists=True),
help=(
"Path to an audio config file. Use this to override audio "
"loading and preprocessing-related settings."
"loading settings."
),
)
@click.option(
@ -41,7 +40,7 @@ def common_predict_options(func):
type=click.Path(exists=True),
help=(
"Path to an inference config file. Use this to override "
"prediction-time thresholds and behavior."
"prediction settings."
),
)
@click.option(
@ -49,23 +48,19 @@ def common_predict_options(func):
type=click.Path(exists=True),
help=(
"Path to an outputs config file. Use this to control the "
"prediction fields written to disk."
"saved output format and fields."
),
)
@click.option(
"--logging-config",
type=click.Path(exists=True),
help=(
"Path to a logging config file. Use this to customize logging "
"format and levels."
),
help=("Path to a logging config file. Use this to change log output."),
)
@click.option(
"--batch-size",
type=int,
help=(
"Batch size for inference. If omitted, the value from the "
"loaded config is used."
"Batch size for inference. If omitted, the config value is used."
),
)
@click.option(
@ -82,7 +77,7 @@ def common_predict_options(func):
type=str,
help=(
"Output format name used by the prediction writer. If omitted, "
"the default output format is used."
"the config default is used."
),
)
@click.option(
@ -91,7 +86,7 @@ def common_predict_options(func):
default=None,
help=(
"Optional detection score threshold override. If omitted, "
"the model default threshold is used."
"the configured threshold is used."
),
)
@wraps(func)
@ -102,7 +97,7 @@ def common_predict_options(func):
def _build_api(
model_path: Path,
model_path: str,
audio_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
@ -144,7 +139,7 @@ def _build_api(
def _run_prediction(
model_path: Path,
model_path: str,
audio_files: list[Path],
output_path: Path,
audio_config: Path | None,
@ -191,16 +186,16 @@ def _run_prediction(
)
@predict.command(
@process.command(
name="directory",
short_help="Predict on audio files in a directory.",
short_help="Process audio files in a directory.",
)
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("model_path", type=str)
@click.argument("audio_dir", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@common_predict_options
def predict_directory_command(
model_path: Path,
model_path: str,
audio_dir: Path,
output_path: Path,
audio_config: Path | None,
@ -212,10 +207,10 @@ def predict_directory_command(
format_name: str | None,
detection_threshold: float | None,
) -> None:
"""Predict on all audio files in a directory.
"""Run processing on all supported audio files in a directory.
Loads a checkpoint, scans `audio_dir` for supported audio files, runs
inference, and saves predictions to `output_path`.
This command scans ``audio_dir`` for audio files, runs processing, and
saves the results to ``output_path``.
"""
from soundevent.audio.files import get_audio_files
@ -235,16 +230,16 @@ def predict_directory_command(
)
@predict.command(
@process.command(
name="file_list",
short_help="Predict on paths listed in a text file.",
short_help="Process paths listed in a text file.",
)
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("model_path", type=str)
@click.argument("file_list", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@common_predict_options
def predict_file_list_command(
model_path: Path,
model_path: str,
file_list: Path,
output_path: Path,
audio_config: Path | None,
@ -256,9 +251,9 @@ def predict_file_list_command(
format_name: str | None,
detection_threshold: float | None,
) -> None:
"""Predict on audio files listed in a text file.
"""Run processing on audio files listed in a text file.
The list file should contain one audio path per line. Empty lines are
The text file should contain one audio path per line. Empty lines are
ignored.
"""
file_list = Path(file_list)
@ -283,16 +278,16 @@ def predict_file_list_command(
)
@predict.command(
@process.command(
name="dataset",
short_help="Predict on recordings from a dataset config.",
short_help="Process recordings from a dataset config.",
)
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("model_path", type=str)
@click.argument("dataset_path", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@common_predict_options
def predict_dataset_command(
model_path: Path,
model_path: str,
dataset_path: Path,
output_path: Path,
audio_config: Path | None,
@ -304,10 +299,10 @@ def predict_dataset_command(
format_name: str | None,
detection_threshold: float | None,
) -> None:
"""Predict on recordings referenced in an annotation dataset.
"""Run processing on recordings referenced in a dataset file.
The dataset is read as a soundevent annotation set and unique recording
paths are extracted before inference.
Recording paths are read from the dataset and each recording is processed
once.
"""
from soundevent import io

View File

@ -13,15 +13,15 @@ __all__ = ["train_command"]
@click.option(
"--val-dataset",
type=click.Path(exists=True),
help="Path to validation dataset config file.",
help="Path to a validation dataset config file.",
)
@click.option(
"--model",
"model_path",
type=click.Path(exists=True),
type=str,
help=(
"Path to a checkpoint to continue training from. If omitted, "
"training starts from a fresh model config."
"Path to a checkpoint, bundled checkpoint alias, or Hugging Face "
"URI. If omitted, training starts from a fresh model config."
),
)
@click.option(
@ -36,7 +36,7 @@ __all__ = ["train_command"]
"--targets",
"targets_config",
type=click.Path(exists=True),
help="Path to targets config file.",
help="Path to a targets config file.",
)
@click.option(
"--model-config",
@ -46,32 +46,32 @@ __all__ = ["train_command"]
@click.option(
"--training-config",
type=click.Path(exists=True),
help="Path to training config file.",
help="Path to a training config file.",
)
@click.option(
"--audio-config",
type=click.Path(exists=True),
help="Path to audio config file.",
help="Path to an audio config file.",
)
@click.option(
"--evaluation-config",
type=click.Path(exists=True),
help="Path to evaluation config file.",
help="Path to an evaluation config file.",
)
@click.option(
"--inference-config",
type=click.Path(exists=True),
help="Path to inference config file.",
help="Path to an inference config file.",
)
@click.option(
"--outputs-config",
type=click.Path(exists=True),
help="Path to outputs config file.",
help="Path to an outputs config file.",
)
@click.option(
"--logging-config",
type=click.Path(exists=True),
help="Path to logging config file.",
help="Path to a logging config file.",
)
@click.option(
"--ckpt-dir",
@ -118,7 +118,7 @@ __all__ = ["train_command"]
def train_command(
train_dataset: Path,
val_dataset: Path | None = None,
model_path: Path | None = None,
model_path: str | None = None,
ckpt_dir: Path | None = None,
log_dir: Path | None = None,
base_dir: Path | None = None,
@ -139,9 +139,8 @@ def train_command(
):
"""Train a BatDetect2 model.
Train either from a fresh config (`--model-config`) or by fine-tuning an
existing checkpoint (`--model`). Training data are loaded from
`train_dataset`, with optional validation data from `--val-dataset`.
Start from a fresh model config or continue from an existing checkpoint.
Training data are loaded from ``train_dataset``.
"""
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig

View File

@ -102,19 +102,19 @@ def convert_to_annotation_group(
x_inds.append(0)
y_inds.append(0)
annotations.append(
Annotation(
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
class_prob=1.0,
det_prob=1.0,
individual="0",
event=event,
class_id=class_id,
)
)
annotation_entry: Annotation = {
"start_time": start_time,
"end_time": end_time,
"low_freq": low_freq,
"high_freq": high_freq,
"class_prob": 1.0,
"det_prob": 1.0,
"individual": "0",
"event": event,
"class": get_recording_class_name(recording),
"class_id": class_id,
}
annotations.append(annotation_entry)
return {
"id": str(recording.path),

View File

@ -53,7 +53,7 @@ class Registry(Generic[T_Type, P_Type]):
def __init__(self, name: str, discriminator: str = "name"):
self._name = name
self._registry: dict[
str, Callable[Concatenate[..., P_Type], T_Type]
str, Callable[Concatenate[Any, P_Type], T_Type]
] = {}
self._discriminator = discriminator
self._config_types: dict[str, Type[BaseModel]] = {}
@ -80,7 +80,7 @@ class Registry(Generic[T_Type, P_Type]):
)
def decorator(
func: Callable[Concatenate[T_Config, P_Type], T_Type],
func: Callable[..., T_Type],
):
self._registry[name] = func
return func
@ -102,8 +102,8 @@ class Registry(Generic[T_Type, P_Type]):
def build(
self,
config: BaseModel,
*args: P_Type.args,
**kwargs: P_Type.kwargs,
*args: Any,
**kwargs: Any,
) -> T_Type:
"""Builds a logic instance from a config object."""

View File

@ -12,13 +12,15 @@ __all__ = [
]
def _default_tasks() -> list[TaskConfig]:
return [
DetectionTaskConfig(),
ClassificationTaskConfig(),
]
class EvaluationConfig(BaseConfig):
tasks: List[TaskConfig] = Field(
default_factory=lambda: [
DetectionTaskConfig(),
ClassificationTaskConfig(),
]
)
tasks: List[TaskConfig] = Field(default_factory=_default_tasks)
def get_default_eval_config() -> EvaluationConfig:

View File

@ -11,7 +11,7 @@ from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
from batdetect2.models import Model
from batdetect2.models.types import ModelProtocol
from batdetect2.outputs import OutputsConfig, build_output_transform
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import ClipDetections
@ -22,7 +22,7 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
def run_evaluate(
model: Model,
model: ModelProtocol,
test_annotations: Sequence[data.ClipAnnotation],
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,

View File

@ -7,14 +7,14 @@ from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.models.types import ModelProtocol
from batdetect2.postprocess.types import ClipDetections
class EvaluationModule(LightningModule):
def __init__(
self,
model: Model,
model: ModelProtocol,
evaluator: EvaluatorProtocol,
):
super().__init__()

View File

@ -25,11 +25,15 @@ from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[ClassificationMetricConfig]:
return [ClassificationAveragePrecisionConfig()]
class ClassificationTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_classification"] = "sound_event_classification"
prefix: str = "classification"
metrics: list[ClassificationMetricConfig] = Field(
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
default_factory=_default_metrics
)
plots: list[ClassificationPlotConfig] = Field(default_factory=list)
include_generics: bool = True

View File

@ -23,13 +23,15 @@ from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[ClipClassificationMetricConfig]:
return [ClipClassificationAveragePrecisionConfig()]
class ClipClassificationTaskConfig(BaseTaskConfig):
name: Literal["clip_classification"] = "clip_classification"
prefix: str = "clip_classification"
metrics: list[ClipClassificationMetricConfig] = Field(
default_factory=lambda: [
ClipClassificationAveragePrecisionConfig(),
]
default_factory=_default_metrics
)
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)

View File

@ -22,13 +22,15 @@ from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[ClipDetectionMetricConfig]:
return [ClipDetectionAveragePrecisionConfig()]
class ClipDetectionTaskConfig(BaseTaskConfig):
name: Literal["clip_detection"] = "clip_detection"
prefix: str = "clip_detection"
metrics: list[ClipDetectionMetricConfig] = Field(
default_factory=lambda: [
ClipDetectionAveragePrecisionConfig(),
]
default_factory=_default_metrics
)
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)

View File

@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[DetectionMetricConfig]:
return [DetectionAveragePrecisionConfig()]
class DetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_detection"] = "sound_event_detection"
prefix: str = "detection"
metrics: list[DetectionMetricConfig] = Field(
default_factory=lambda: [DetectionAveragePrecisionConfig()]
default_factory=_default_metrics
)
plots: list[DetectionPlotConfig] = Field(default_factory=list)

View File

@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
def _default_metrics() -> list[TopClassMetricConfig]:
return [TopClassAveragePrecisionConfig()]
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["top_class_detection"] = "top_class_detection"
prefix: str = "top_class"
metrics: list[TopClassMetricConfig] = Field(
default_factory=lambda: [TopClassAveragePrecisionConfig()]
default_factory=_default_metrics
)
plots: list[TopClassPlotConfig] = Field(default_factory=list)

View File

@ -10,7 +10,7 @@ from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.config import InferenceConfig
from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model
from batdetect2.models.types import ModelProtocol
from batdetect2.outputs import (
OutputsConfig,
OutputTransformProtocol,
@ -22,7 +22,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
def run_batch_inference(
model: Model,
model: ModelProtocol,
clips: Sequence[data.Clip],
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
@ -86,7 +86,7 @@ def run_batch_inference(
def process_file_list(
model: Model,
model: ModelProtocol,
paths: Sequence[data.PathLike],
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,

View File

@ -4,7 +4,7 @@ from lightning import LightningModule
from torch.utils.data import DataLoader
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model
from batdetect2.models.types import ModelProtocol
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
@ -13,7 +13,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
class InferenceModule(LightningModule):
def __init__(
self,
model: Model,
model: ModelProtocol,
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
output_transform: OutputTransformProtocol | None = None,

View File

@ -62,7 +62,7 @@ from batdetect2.models.encoder import (
build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.models.types import DetectionModel
from batdetect2.models.types import DetectorProtocol, ModelProtocol
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.postprocess.types import (
ClipDetectionsTensor,
@ -149,7 +149,7 @@ class Model(torch.nn.Module):
Attributes
----------
detector : DetectionModel
detector : DetectorProtocol
The neural network that processes spectrograms and produces raw
detection, classification, and bounding-box outputs.
preprocessor : PreprocessorProtocol
@ -164,19 +164,21 @@ class Model(torch.nn.Module):
Size-dimension names corresponding to the model size outputs.
"""
detector: DetectionModel
detector: DetectorProtocol
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
class_names: list[str]
dimension_names: list[str]
_config: dict[str, object]
def __init__(
self,
detector: DetectionModel,
detector: DetectorProtocol,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
class_names: list[str],
dimension_names: list[str],
config: dict[str, object],
):
super().__init__()
self.detector = detector
@ -184,6 +186,12 @@ class Model(torch.nn.Module):
self.postprocessor = postprocessor
self.class_names = class_names
self.dimension_names = dimension_names
self._config = config
def get_config(self) -> dict[str, object]:
"""Return the model configuration as plain JSON-serializable data."""
return dict(self._config)
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
"""Run the full detection pipeline on a waveform tensor.
@ -216,7 +224,7 @@ def build_model(
dimension_names: list[str] | None = None,
preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None,
) -> Model:
) -> ModelProtocol:
"""Build a complete, ready-to-use BatDetect2 model.
Assembles a ``Model`` instance from a ``ModelConfig`` and optional
@ -248,7 +256,7 @@ def build_model(
Returns
-------
Model
ModelProtocol
A fully assembled ``Model`` instance ready for inference or
training.
"""
@ -277,8 +285,8 @@ def build_model(
config=config.postprocess,
)
detector = build_detector(
num_classes=len(class_names),
num_sizes=len(dimension_names),
class_names=class_names,
dimension_names=dimension_names,
config=config.architecture,
)
return Model(
@ -287,18 +295,19 @@ def build_model(
preprocessor=preprocessor,
class_names=class_names,
dimension_names=dimension_names,
config=config.model_dump(mode="json"),
)
def build_model_with_new_targets(
model: Model,
model: ModelProtocol,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
) -> Model:
) -> ModelProtocol:
"""Build a new model with a different target set."""
detector = build_detector(
num_classes=len(targets.class_names),
num_sizes=len(roi_mapper.dimension_names),
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
backbone=model.detector.backbone,
)
@ -308,4 +317,5 @@ def build_model_with_new_targets(
preprocessor=model.preprocessor,
class_names=targets.class_names,
dimension_names=roi_mapper.dimension_names,
config=model.get_config(),
)

View File

@ -27,6 +27,7 @@ from typing import Annotated, Literal
import torch
import torch.nn.functional as F
from loguru import logger
from pydantic import Field, TypeAdapter
from soundevent import data
@ -52,7 +53,7 @@ from batdetect2.models.encoder import (
build_encoder,
)
from batdetect2.models.types import (
BackboneModel,
BackboneProtocol,
BottleneckProtocol,
DecoderProtocol,
EncoderProtocol,
@ -104,7 +105,7 @@ class UNetBackboneConfig(BaseConfig):
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
backbone_registry: Registry[BackboneProtocol, []] = Registry("backbone")
@add_import_config(backbone_registry)
@ -118,7 +119,7 @@ class BackboneImportConfig(ImportConfig):
name: Literal["import"] = "import"
class UNetBackbone(BackboneModel):
class UNetBackbone(torch.nn.Module):
"""U-Net-style encoder-decoder backbone network.
Combines an encoder, a bottleneck, and a decoder into a single module
@ -225,7 +226,7 @@ class UNetBackbone(BackboneModel):
@backbone_registry.register(UNetBackboneConfig)
@staticmethod
def from_config(config: UNetBackboneConfig) -> BackboneModel:
def from_config(config: UNetBackboneConfig) -> BackboneProtocol:
encoder = build_encoder(
in_channels=config.in_channels,
input_height=config.input_height,
@ -266,7 +267,7 @@ BackboneConfig = Annotated[
]
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
def build_backbone(config: BackboneConfig | None = None) -> BackboneProtocol:
"""Build a backbone network from configuration.
Looks up the backbone class corresponding to ``config.name`` in the
@ -282,10 +283,14 @@ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
Returns
-------
BackboneModel
BackboneProtocol
An initialised backbone module.
"""
config = config or UNetBackboneConfig()
logger.opt(lazy=True).debug(
"Building model backbone with config: \n{}",
lambda: config.to_yaml_string(),
)
return backbone_registry.build(config)

View File

@ -6,8 +6,8 @@ bounding-box size regression.
Components
----------
- ``Detector`` the ``torch.nn.Module`` that wires together a backbone
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
- ``Detector`` - the ``torch.nn.Module`` that wires together a backbone
(``BackboneProtocol``) with a ``ClassifierHead`` and a ``BBoxHead`` to
produce a ``ModelOutput`` tuple from an input spectrogram.
- ``build_detector`` factory function that builds a ready-to-use
``Detector`` from a backbone configuration and a target class count.
@ -18,15 +18,16 @@ preprocessing and output postprocessing are handled by
"""
import torch
from loguru import logger
from batdetect2.models.backbones import (
BackboneConfig,
UNetBackboneConfig,
build_backbone,
)
from batdetect2.models.backbones import BackboneConfig, build_backbone
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
from batdetect2.models.types import (
BackboneProtocol,
ClassifierHeadProtocol,
DetectorProtocol,
ModelOutput,
SizeHeadProtocol,
)
__all__ = [
"Detector",
@ -34,7 +35,7 @@ __all__ = [
]
class Detector(DetectionModel):
class Detector(torch.nn.Module):
"""Complete BatDetect2 detection and classification model.
Combines a backbone feature extractor with two prediction heads:
@ -51,7 +52,7 @@ class Detector(DetectionModel):
Attributes
----------
backbone : BackboneModel
backbone : BackboneProtocol
The feature extraction backbone.
num_classes : int
Number of target classes (inferred from the classifier head).
@ -61,13 +62,13 @@ class Detector(DetectionModel):
Produces duration and bandwidth predictions from backbone features.
"""
backbone: BackboneModel
backbone: BackboneProtocol
def __init__(
self,
backbone: BackboneModel,
classifier_head: ClassifierHead,
bbox_head: BBoxHead,
backbone: BackboneProtocol,
classifier_head: ClassifierHeadProtocol,
size_head: SizeHeadProtocol,
):
"""Initialise the Detector model.
@ -76,7 +77,7 @@ class Detector(DetectionModel):
Parameters
----------
backbone : BackboneModel
backbone : BackboneProtocol
An initialised backbone module (e.g. built by
``build_backbone``).
classifier_head : ClassifierHead
@ -90,7 +91,7 @@ class Detector(DetectionModel):
self.backbone = backbone
self.num_classes = classifier_head.num_classes
self.classifier_head = classifier_head
self.bbox_head = bbox_head
self.size_head = size_head
def forward(self, spec: torch.Tensor) -> ModelOutput:
"""Run the complete detection model on an input spectrogram.
@ -125,7 +126,7 @@ class Detector(DetectionModel):
features = self.backbone(spec)
classification = self.classifier_head(features)
detection = classification.sum(dim=1, keepdim=True)
size_preds = self.bbox_head(features)
size_preds = self.size_head(features)
return ModelOutput(
detection_probs=detection,
size_preds=size_preds,
@ -135,11 +136,11 @@ class Detector(DetectionModel):
def build_detector(
num_classes: int,
num_sizes: int = 2,
class_names: list[str],
dimension_names: list[str],
config: BackboneConfig | None = None,
backbone: BackboneModel | None = None,
) -> DetectionModel:
backbone: BackboneProtocol | None = None,
) -> DetectorProtocol:
"""Build a complete BatDetect2 detection model.
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
@ -158,7 +159,7 @@ def build_detector(
Returns
-------
DetectionModel
DetectorProtocol
An initialised ``Detector`` instance ready for training or
inference.
@ -168,24 +169,18 @@ def build_detector(
If ``num_classes`` is not positive, or if the backbone
configuration is invalid.
"""
if backbone is None:
config = config or UNetBackboneConfig()
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(), # type: ignore
)
backbone = build_backbone(config=config)
backbone = backbone or build_backbone(config=config)
classifier_head = ClassifierHead(
num_classes=num_classes,
class_names=class_names,
in_channels=backbone.out_channels,
)
bbox_head = BBoxHead(
in_channels=backbone.out_channels,
num_sizes=num_sizes,
dimension_names=dimension_names,
)
return Detector(
backbone=backbone,
classifier_head=classifier_head,
bbox_head=bbox_head,
size_head=bbox_head,
)

View File

@ -54,12 +54,14 @@ class ClassifierHead(nn.Module):
1×1 convolution with ``num_classes + 1`` output channels.
"""
def __init__(self, num_classes: int, in_channels: int):
def __init__(self, class_names: list[str], in_channels: int):
"""Initialise the ClassifierHead."""
super().__init__()
self.num_classes = num_classes
self.class_names = class_names
self.num_classes = len(class_names)
self.in_channels = in_channels
self.classifier = nn.Conv2d(
self.in_channels,
self.num_classes + 1,
@ -165,11 +167,12 @@ class BBoxHead(nn.Module):
1×1 convolution with 2 output channels (duration, bandwidth).
"""
def __init__(self, in_channels: int, num_sizes: int = 2):
def __init__(self, dimension_names: list[str], in_channels: int):
"""Initialise the BBoxHead."""
super().__init__()
self.in_channels = in_channels
self.num_sizes = num_sizes
self.dimension_names = dimension_names
self.num_sizes = len(dimension_names)
self.bbox = nn.Conv2d(
in_channels=self.in_channels,

View File

@ -1,21 +1,42 @@
from abc import ABC, abstractmethod
from typing import NamedTuple, Protocol
from typing import Any, NamedTuple, Protocol
import torch
from batdetect2.postprocess.types import PostprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [
"BackboneModel",
"BackboneProtocol",
"BlockProtocol",
"BottleneckProtocol",
"ClassifierHeadProtocol",
"DecoderProtocol",
"DetectionModel",
"EncoderDecoderModel",
"DetectorProtocol",
"EncoderProtocol",
"ModelOutput",
"ModelProtocol",
"ModuleProtocol",
"SizeHeadProtocol",
]
class BlockProtocol(Protocol):
class ModuleProtocol(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
def train(self, mode: bool = True) -> torch.nn.Module: ...
def eval(self) -> torch.nn.Module: ...
def state_dict(
self, *args: Any, **kwargs: Any
) -> dict[str, torch.Tensor]: ...
def load_state_dict(self, *args: Any, **kwargs: Any) -> Any: ...
def parameters(self) -> Any: ...
class BlockProtocol(ModuleProtocol, Protocol):
in_channels: int
out_channels: int
@ -24,7 +45,7 @@ class BlockProtocol(Protocol):
def get_output_height(self, input_height: int) -> int: ...
class EncoderProtocol(Protocol):
class EncoderProtocol(ModuleProtocol, Protocol):
in_channels: int
out_channels: int
input_height: int
@ -33,7 +54,7 @@ class EncoderProtocol(Protocol):
def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
class BottleneckProtocol(Protocol):
class BottleneckProtocol(ModuleProtocol, Protocol):
in_channels: int
out_channels: int
input_height: int
@ -41,7 +62,7 @@ class BottleneckProtocol(Protocol):
def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
class DecoderProtocol(Protocol):
class DecoderProtocol(ModuleProtocol, Protocol):
in_channels: int
out_channels: int
input_height: int
@ -62,29 +83,42 @@ class ModelOutput(NamedTuple):
features: torch.Tensor
class BackboneModel(ABC, torch.nn.Module):
class BackboneProtocol(ModuleProtocol, Protocol):
input_height: int
out_channels: int
@abstractmethod
def forward(self, spec: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def forward(self, spec: torch.Tensor) -> torch.Tensor: ...
class EncoderDecoderModel(BackboneModel):
bottleneck_channels: int
class ClassifierHeadProtocol(ModuleProtocol, Protocol):
num_classes: int
in_channels: int
class_names: list[str]
@abstractmethod
def encode(self, spec: torch.Tensor) -> torch.Tensor: ...
@abstractmethod
def decode(self, encoded: torch.Tensor) -> torch.Tensor: ...
def forward(self, features: torch.Tensor) -> torch.Tensor: ...
class DetectionModel(ABC, torch.nn.Module):
backbone: BackboneModel
classifier_head: torch.nn.Module
bbox_head: torch.nn.Module
class SizeHeadProtocol(ModuleProtocol, Protocol):
in_channels: int
num_sizes: int
dimension_names: list[str]
def forward(self, features: torch.Tensor) -> torch.Tensor: ...
class DetectorProtocol(ModuleProtocol, Protocol):
backbone: BackboneProtocol
classifier_head: ClassifierHeadProtocol
size_head: SizeHeadProtocol
@abstractmethod
def forward(self, spec: torch.Tensor) -> ModelOutput: ...
class ModelProtocol(ModuleProtocol, Protocol):
detector: DetectorProtocol
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
class_names: list[str]
dimension_names: list[str]
def get_config(self) -> dict[str, Any]: ...

View File

@ -154,17 +154,18 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
top_class_index = int(np.argmax(prediction.class_scores))
top_class_score = float(prediction.class_scores[top_class_index])
top_class = self.get_class_name(top_class_index)
return Annotation(
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
class_prob=top_class_score,
det_prob=float(prediction.detection_score),
individual="",
event=self.event_name,
**{"class": top_class},
)
annotation: Annotation = {
"start_time": start_time,
"end_time": end_time,
"low_freq": low_freq,
"high_freq": high_freq,
"class_prob": top_class_score,
"det_prob": float(prediction.detection_score),
"individual": "",
"event": self.event_name,
"class": top_class,
}
return annotation
@output_formatters.register(BatDetect2OutputConfig)
@staticmethod

View File

@ -26,6 +26,13 @@ __all__ = [
]
def _default_spectrogram_transforms() -> list[SpectrogramTransform]:
return [
PcenConfig(),
SpectralMeanSubtractionConfig(),
]
class PreprocessingConfig(BaseConfig):
"""Unified configuration for the audio preprocessing pipeline.
@ -58,10 +65,7 @@ class PreprocessingConfig(BaseConfig):
audio_transforms: List[AudioTransform] = Field(default_factory=list)
spectrogram_transforms: List[SpectrogramTransform] = Field(
default_factory=lambda: [
PcenConfig(),
SpectralMeanSubtractionConfig(),
]
default_factory=_default_spectrogram_transforms
)
stft: STFTConfig = Field(default_factory=STFTConfig)

View File

@ -71,7 +71,7 @@ class TargetClassConfig(BaseConfig):
DEFAULT_DETECTION_CLASS = TargetClassConfig(
name="bat",
match_if=AllOfConfig( # ty: ignore[unknown-argument]
match_if=AllOfConfig(
conditions=[
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
NotConfig(

View File

@ -1,4 +1,7 @@
from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR
from batdetect2.train.checkpoints import (
DEFAULT_CHECKPOINT_DIR,
resolve_checkpoint_path,
)
from batdetect2.train.config import TrainingConfig
from batdetect2.train.lightning import (
TrainingModule,
@ -26,5 +29,6 @@ __all__ = [
"TrainingModule",
"build_trainer",
"load_model_from_checkpoint",
"resolve_checkpoint_path",
"run_train",
]

View File

@ -2,15 +2,31 @@ from pathlib import Path
from typing import Literal
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from soundevent.data import PathLike
from batdetect2.core import BaseConfig
__all__ = [
"CheckpointConfig",
"DEFAULT_CHECKPOINT",
"build_checkpoint_callback",
"get_bundled_checkpoint_names",
"resolve_checkpoint_path",
]
PACKAGE_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
DEFAULT_CHECKPOINT = "uk_same"
CHECKPOINT_ALIASES = {
DEFAULT_CHECKPOINT: PACKAGE_ROOT
/ "models"
/ "checkpoints"
/ "batdetect2_uk_same.ckpt",
"batdetect2_uk_same": PACKAGE_ROOT
/ "models"
/ "checkpoints"
/ "batdetect2_uk_same.ckpt",
}
class CheckpointConfig(BaseConfig):
@ -18,6 +34,8 @@ class CheckpointConfig(BaseConfig):
monitor: str | None = None
mode: str = "max"
save_top_k: int = 1
# Save distributable inference checkpoints by default.
save_weights_only: bool = True
filename: str | None = None
save_last: bool | Literal["link"] = "link"
every_n_epochs: int | None = 1
@ -47,9 +65,86 @@ def build_checkpoint_callback(
return ModelCheckpoint(
dirpath=str(checkpoint_dir),
save_top_k=config.save_top_k,
save_weights_only=config.save_weights_only,
monitor=config.monitor,
mode=config.mode,
filename=config.filename,
save_last=config.save_last,
every_n_epochs=config.every_n_epochs,
)
def get_bundled_checkpoint_names() -> tuple[str, ...]:
"""Return the supported bundled checkpoint aliases."""
return tuple(CHECKPOINT_ALIASES.keys())
def resolve_checkpoint_from_huggingface(path: str) -> Path:
"""Resolve a Hugging Face checkpoint URI."""
try:
from huggingface_hub import hf_hub_download
except ImportError as error:
raise ValueError(
"Hugging Face checkpoint support is not installed. "
"Install it with `pip install batdetect2[huggingface]`."
) from error
repo_id, filename = _parse_huggingface_uri(path)
return Path(hf_hub_download(repo_id=repo_id, filename=filename))
def resolve_checkpoint_path(path: PathLike | str | None = None) -> Path:
"""Resolve a local path, alias, or Hugging Face checkpoint URI.
Parameters
----------
path : PathLike | str | None
Local checkpoint path, checkpoint alias, or a Hugging Face
URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``. If
omitted, the default alias checkpoint is used.
Returns
-------
Path
Resolved local filesystem path to the checkpoint.
"""
if path is None:
path = DEFAULT_CHECKPOINT
if isinstance(path, str) and path.startswith("hf://"):
return resolve_checkpoint_from_huggingface(path)
if isinstance(path, str) and path in CHECKPOINT_ALIASES:
return Path(CHECKPOINT_ALIASES[path])
path = Path(path)
if path.exists():
return path.resolve()
bundled_names = ", ".join(get_bundled_checkpoint_names())
raise FileNotFoundError(
f"Checkpoint not found: {path}. "
"Expected a local path, a checkpoint alias "
f"({bundled_names}), or a Hugging Face URI."
)
def _parse_huggingface_uri(uri: str) -> tuple[str, str]:
prefix = "hf://"
if not uri.startswith(prefix):
raise ValueError(
"Hugging Face checkpoint URIs must start with 'hf://'."
)
without_prefix = uri.removeprefix(prefix).strip("/")
parts = without_prefix.split("/")
if len(parts) < 3:
raise ValueError(
"Hugging Face checkpoint URIs must be in the form "
"'hf://owner/repo/path/to/checkpoint.ckpt'."
)
repo_id = "/".join(parts[:2])
filename = "/".join(parts[2:])
return repo_id, filename

View File

@ -1,11 +1,13 @@
from dataclasses import dataclass
import lightning as L
import torch
from soundevent.data import PathLike
from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.models.types import ModelOutput
from batdetect2.models import ModelConfig, build_model
from batdetect2.models.types import ModelOutput, ModelProtocol
from batdetect2.targets import TargetConfig
from batdetect2.train.checkpoints import resolve_checkpoint_path
from batdetect2.train.config import TrainingConfig
from batdetect2.train.losses import build_loss
from batdetect2.train.optimizers import build_optimizer
@ -19,7 +21,7 @@ __all__ = [
class TrainingModule(L.LightningModule):
model: Model
model: ModelProtocol
loss: LossProtocol
def __init__(
@ -30,7 +32,7 @@ class TrainingModule(L.LightningModule):
dimension_names: list[str] | None = None,
train_config: dict | None = None,
loss: LossProtocol | None = None,
model: Model | None = None,
model: ModelProtocol | None = None,
):
super().__init__()
@ -130,23 +132,27 @@ class StoredConfig:
def load_model_from_checkpoint(
path: PathLike,
) -> tuple[Model, StoredConfig]:
path: PathLike | str | None = None,
) -> tuple[ModelProtocol, StoredConfig]:
"""Load a model and its configuration from a Lightning checkpoint.
Parameters
----------
path : PathLike
path : PathLike | str | None
Path to a ``.ckpt`` file produced by the BatDetect2 training
pipeline.
pipeline. If omitted, the default bundled checkpoint is used.
Returns
-------
tuple[Model, ModelConfig]
tuple[ModelProtocol, ModelConfig]
The restored ``Model`` instance and the ``ModelConfig`` that
describes its architecture, preprocessing, and postprocessing.
"""
module = TrainingModule.load_from_checkpoint(path) # type: ignore
resolved_path = resolve_checkpoint_path(path)
module = TrainingModule.load_from_checkpoint(
resolved_path,
map_location=torch.device("cpu"),
)
training_config = TrainingConfig.model_validate(module.train_config)
model_config = ModelConfig.model_validate(module.model_config)
targets_config = TargetConfig.model_validate(module.targets_config)
@ -163,7 +169,7 @@ def build_training_module(
class_names: list[str] | None = None,
dimension_names: list[str] | None = None,
train_config: TrainingConfig | None = None,
model: Model | None = None,
model: ModelProtocol | None = None,
) -> TrainingModule:
if model_config is None:
model_config = ModelConfig()

View File

@ -3,6 +3,7 @@ from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import pandas as pd
from lightning.pytorch.loggers import Logger
@ -28,7 +29,7 @@ __all__ = [
@dataclass(frozen=True)
class TrainLoggingContext:
model_config: ModelConfig
model_config: dict[str, Any]
train_config: TrainingConfig
audio_config: AudioConfig
targets: TargetProtocol
@ -49,9 +50,10 @@ class ConfigHyperparameterLogging:
artifact_path: Path,
context: TrainLoggingContext,
) -> None:
model_config = ModelConfig.model_validate(context.model_config)
logger.log_hyperparams(
{
"model": context.model_config.model_dump(
"model": model_config.model_dump(
mode="json",
exclude_none=True,
),

View File

@ -15,7 +15,8 @@ from batdetect2.logging import (
TensorBoardLoggerConfig,
build_logger,
)
from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.models import ModelConfig, build_model
from batdetect2.models.types import ModelProtocol
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import (
ROIMapperProtocol,
@ -50,14 +51,13 @@ DEFAULT_LOG_DIR = Path("outputs") / "logs"
def run_train(
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Sequence[data.ClipAnnotation] | None = None,
model: Model | None = None,
model: ModelProtocol | None = None,
targets: Optional["TargetProtocol"] = None,
roi_mapper: Optional["ROIMapperProtocol"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
labeller: Optional["ClipLabeller"] = None,
audio_config: Optional[AudioConfig] = None,
model_config: Optional[ModelConfig] = None,
targets_config: TargetConfig | None = None,
train_config: Optional[TrainingConfig] = None,
logger_config: LoggerConfig | None = None,
@ -75,7 +75,11 @@ def run_train(
if seed is not None:
seed_everything(seed)
model_config = model_config or ModelConfig()
model_config = (
ModelConfig()
if model is None
else ModelConfig.model_validate(model.get_config())
)
targets_config = targets_config or TargetConfig()
audio_config = audio_config or AudioConfig()
train_config = train_config or TrainingConfig()
@ -172,7 +176,7 @@ def run_train(
root_artifact_path.mkdir(parents=True, exist_ok=True)
logging_context = TrainLoggingContext(
model_config=model_config,
model_config=model_config.model_dump(mode="json"),
train_config=train_config,
audio_config=audio_config,
targets=targets,
@ -214,7 +218,7 @@ def run_train(
def _validate_model_compatibility(
model: Model,
model: ModelProtocol,
model_config: ModelConfig,
class_names: list[str],
dimension_names: list[str],

View File

@ -200,13 +200,14 @@ def test_user_can_read_extracted_features_per_detection(
) -> None:
"""User story: inspect extracted feature vectors per detection."""
# Given
prediction = api_v2.process_file(example_audio_files[0])
assert len(prediction.detections) > 0
# When
feature_vectors = [det.features for det in prediction.detections]
feature_vectors = [
api_v2.get_detection_features(det) for det in prediction.detections
]
# Then
assert len(prediction.detections) > 0
assert len(feature_vectors) == len(prediction.detections)
assert all(vec.ndim == 1 for vec in feature_vectors)
assert all(vec.size > 0 for vec in feature_vectors)
@ -299,14 +300,20 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
value,
)
for key, value in source_detector.bbox_head.state_dict().items():
assert key in detector.bbox_head.state_dict()
for key, value in source_detector.size_head.state_dict().items():
assert key in detector.size_head.state_dict()
torch.testing.assert_close(
detector.bbox_head.state_dict()[key],
detector.size_head.state_dict()[key],
value,
)
def test_api_from_checkpoint_defaults_to_bundled_model() -> None:
api = BatDetect2API.from_checkpoint()
assert api.model.class_names
@pytest.mark.slow
def test_user_can_evaluate_small_dataset_and_get_metrics(
api_v2: BatDetect2API,

View File

@ -18,7 +18,7 @@ def test_user_can_finetune_only_heads(
api = BatDetect2API.from_config()
source_classifier_head = api.model.detector.classifier_head
source_bbox_head = api.model.detector.bbox_head
source_size_head = api.model.detector.size_head
source_backbone = api.model.detector.backbone
finetune_dir = tmp_path / "heads_only"
@ -39,7 +39,7 @@ def test_user_can_finetune_only_heads(
backbone_params = list(detector.backbone.parameters())
classifier_params = list(detector.classifier_head.parameters())
bbox_params = list(detector.bbox_head.parameters())
bbox_params = list(detector.size_head.parameters())
assert backbone_params
assert classifier_params
@ -50,7 +50,7 @@ def test_user_can_finetune_only_heads(
assert finetuned_api is not api
assert detector.backbone is source_backbone
assert detector.classifier_head is not source_classifier_head
assert detector.bbox_head is not source_bbox_head
assert detector.size_head is not source_size_head
assert list(finetune_dir.rglob("*.ckpt"))

View File

@ -11,7 +11,7 @@ def test_cli_base_help_lists_main_commands() -> None:
result = CliRunner().invoke(cli, ["--help"])
assert result.exit_code == 0
assert "predict" in result.output
assert "process" in result.output
assert "train" in result.output
assert "evaluate" in result.output
assert "data" in result.output

View File

@ -15,8 +15,8 @@ def test_cli_evaluate_help() -> None:
result = CliRunner().invoke(cli, ["evaluate", "--help"])
assert result.exit_code == 0
assert "MODEL_PATH" in result.output
assert "TEST_DATASET" in result.output
assert "--model" in result.output
assert "--evaluation-config" in result.output
@ -32,8 +32,9 @@ def test_cli_evaluate_writes_metrics_for_small_dataset(
cli,
[
"evaluate",
str(tiny_checkpoint_path),
str(BASE_DIR / "example_data" / "dataset.yaml"),
"--model",
str(tiny_checkpoint_path),
"--base-dir",
str(BASE_DIR),
"--workers",

View File

@ -1,6 +1,7 @@
"""CLI tests for finetune command."""
from pathlib import Path
from types import SimpleNamespace
import pytest
from click.testing import CliRunner
@ -25,8 +26,41 @@ def test_cli_finetune_help() -> None:
assert "--outputs-config" not in result.output
def test_cli_finetune_requires_model() -> None:
"""User story: finetune requires a checkpoint argument."""
def test_cli_finetune_defaults_to_bundled_model(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""User story: finetune can use the bundled checkpoint by default."""
called = {}
class FakeAPI:
def finetune(self, **kwargs):
called["finetune"] = kwargs
return None
class FakeBatDetect2API:
@classmethod
def from_checkpoint(cls, path=None, **kwargs):
called["path"] = path
called["from_checkpoint_kwargs"] = kwargs
return FakeAPI()
monkeypatch.setattr(
"batdetect2.api_v2.BatDetect2API",
FakeBatDetect2API,
)
monkeypatch.setattr(
"batdetect2.data.load_dataset_config",
lambda path: SimpleNamespace(path=path),
)
monkeypatch.setattr(
"batdetect2.data.load_dataset",
lambda config, base_dir=None: [],
)
monkeypatch.setattr(
"batdetect2.targets.TargetConfig.load",
lambda path: SimpleNamespace(path=path),
)
result = CliRunner().invoke(
cli,
@ -38,8 +72,9 @@ def test_cli_finetune_requires_model() -> None:
],
)
assert result.exit_code != 0
assert "--model" in result.output
assert result.exit_code == 0
assert called["path"] is None
assert "finetune" in called
def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:

View File

@ -1,4 +1,4 @@
"""Behavior tests for predict CLI workflows."""
"""Behavior tests for process CLI workflows."""
from pathlib import Path
@ -9,10 +9,10 @@ from soundevent import data, io
from batdetect2.cli import cli
def test_cli_predict_help() -> None:
"""User story: discover available predict modes."""
def test_cli_process_help() -> None:
"""User story: discover available process modes."""
result = CliRunner().invoke(cli, ["predict", "--help"])
result = CliRunner().invoke(cli, ["process", "--help"])
assert result.exit_code == 0
assert "directory" in result.output
@ -21,19 +21,19 @@ def test_cli_predict_help() -> None:
@pytest.mark.slow
def test_cli_predict_directory_runs_on_real_audio(
def test_cli_process_directory_runs_on_real_audio(
tmp_path: Path,
tiny_checkpoint_path: Path,
single_audio_dir: Path,
) -> None:
"""User story: run prediction for all files in a directory."""
"""User story: process all files in a directory."""
output_path = tmp_path / "predictions"
result = CliRunner().invoke(
cli,
[
"predict",
"process",
"directory",
str(tiny_checkpoint_path),
str(single_audio_dir),
@ -52,12 +52,12 @@ def test_cli_predict_directory_runs_on_real_audio(
assert len(list(output_path.glob("*.json"))) == 1
def test_cli_predict_file_list_runs_on_real_audio(
def test_cli_process_file_list_runs_on_real_audio(
tmp_path: Path,
tiny_checkpoint_path: Path,
single_audio_dir: Path,
) -> None:
"""User story: run prediction from an explicit list of files."""
"""User story: process an explicit list of files."""
audio_file = next(single_audio_dir.glob("*.wav"))
file_list = tmp_path / "files.txt"
@ -68,7 +68,7 @@ def test_cli_predict_file_list_runs_on_real_audio(
result = CliRunner().invoke(
cli,
[
"predict",
"process",
"file_list",
str(tiny_checkpoint_path),
str(file_list),
@ -87,12 +87,12 @@ def test_cli_predict_file_list_runs_on_real_audio(
assert len(list(output_path.glob("*.json"))) == 1
def test_cli_predict_dataset_runs_on_aoef_metadata(
def test_cli_process_dataset_runs_on_aoef_metadata(
tmp_path: Path,
tiny_checkpoint_path: Path,
single_audio_dir: Path,
) -> None:
"""User story: predict from AOEF dataset metadata file."""
"""User story: process from AOEF dataset metadata file."""
audio_file = next(single_audio_dir.glob("*.wav"))
recording = data.Recording.from_file(audio_file)
@ -103,7 +103,7 @@ def test_cli_predict_dataset_runs_on_aoef_metadata(
)
annotation_set = data.AnnotationSet(
name="test",
description="predict dataset test",
description="process dataset test",
clip_annotations=[data.ClipAnnotation(clip=clip, sound_events=[])],
)
@ -115,7 +115,7 @@ def test_cli_predict_dataset_runs_on_aoef_metadata(
result = CliRunner().invoke(
cli,
[
"predict",
"process",
"dataset",
str(tiny_checkpoint_path),
str(dataset_path),
@ -142,7 +142,7 @@ def test_cli_predict_dataset_runs_on_aoef_metadata(
("soundevent", "*.json", True),
],
)
def test_cli_predict_directory_supports_output_format_override(
def test_cli_process_directory_supports_output_format_override(
tmp_path: Path,
tiny_checkpoint_path: Path,
single_audio_dir: Path,
@ -157,7 +157,7 @@ def test_cli_predict_directory_supports_output_format_override(
result = CliRunner().invoke(
cli,
[
"predict",
"process",
"directory",
str(tiny_checkpoint_path),
str(single_audio_dir),
@ -180,12 +180,12 @@ def test_cli_predict_directory_supports_output_format_override(
assert len(list(output_path.glob(expected_pattern))) >= 1
def test_cli_predict_dataset_deduplicates_recordings(
def test_cli_process_dataset_deduplicates_recordings(
tmp_path: Path,
tiny_checkpoint_path: Path,
single_audio_dir: Path,
) -> None:
"""User story: duplicated recording entries are predicted once."""
"""User story: duplicated recording entries are processed once."""
audio_file = next(single_audio_dir.glob("*.wav"))
recording = data.Recording.from_file(audio_file)
@ -215,7 +215,7 @@ def test_cli_predict_dataset_deduplicates_recordings(
result = CliRunner().invoke(
cli,
[
"predict",
"process",
"dataset",
str(tiny_checkpoint_path),
str(dataset_path),
@ -234,7 +234,7 @@ def test_cli_predict_dataset_deduplicates_recordings(
assert len(list(output_path.glob("*.nc"))) == 1
def test_cli_predict_rejects_unknown_output_format(
def test_cli_process_rejects_unknown_output_format(
tmp_path: Path,
tiny_checkpoint_path: Path,
single_audio_dir: Path,
@ -245,7 +245,7 @@ def test_cli_predict_rejects_unknown_output_format(
result = CliRunner().invoke(
cli,
[
"predict",
"process",
"directory",
str(tiny_checkpoint_path),
str(single_audio_dir),

View File

@ -13,7 +13,6 @@ from batdetect2.models.backbones import (
build_backbone,
load_backbone_config,
)
from batdetect2.models.types import BackboneModel
def test_unet_backbone_config_defaults():
@ -61,10 +60,11 @@ def test_build_backbone_custom_config():
assert backbone.encoder.in_channels == 2
def test_build_backbone_returns_backbone_model():
"""build_backbone always returns a BackboneModel instance."""
def test_build_backbone_returns_unet_backbone():
"""build_backbone returns the default UNet backbone."""
backbone = build_backbone()
assert isinstance(backbone, BackboneModel)
assert isinstance(backbone, UNetBackbone)
def test_registry_has_unet_backbone():

View File

@ -1,3 +1,5 @@
from typing import cast
import numpy as np
import pytest
import torch
@ -19,12 +21,15 @@ def dummy_spectrogram() -> torch.Tensor:
def test_build_detector_default():
"""Test building the default detector without a config."""
num_classes = 5
model = build_detector(num_classes=num_classes)
model = build_detector(
class_names=[f"class_{i}" for i in range(num_classes)],
dimension_names=["width", "height"],
)
assert isinstance(model, Detector)
assert model.num_classes == num_classes
assert isinstance(model.classifier_head, ClassifierHead)
assert isinstance(model.bbox_head, BBoxHead)
assert isinstance(model.size_head, BBoxHead)
def test_build_detector_custom_config():
@ -32,13 +37,19 @@ def test_build_detector_custom_config():
num_classes = 3
config = UNetBackboneConfig(in_channels=2, input_height=128)
model = build_detector(num_classes=num_classes, config=config)
model = build_detector(
class_names=[f"class_{i}" for i in range(num_classes)],
dimension_names=["width", "height"],
config=config,
)
assert isinstance(model, Detector)
assert model.backbone.input_height == 128
assert isinstance(model.backbone.encoder, Encoder)
assert model.backbone.encoder.in_channels == 2
backbone = cast(UNetBackbone, model.backbone)
assert isinstance(backbone.encoder, Encoder)
assert backbone.encoder.in_channels == 2
def test_build_detector_custom_size_channels():
@ -47,8 +58,8 @@ def test_build_detector_custom_size_channels():
config = UNetBackboneConfig(in_channels=1, input_height=128)
model = build_detector(
num_classes=num_classes,
num_sizes=num_sizes,
class_names=[f"class_{i}" for i in range(num_classes)],
dimension_names=[f"size_{i}" for i in range(num_sizes)],
config=config,
)
@ -62,7 +73,11 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
num_classes = 4
# Build model matching the dummy input shape
config = UNetBackboneConfig(in_channels=1, input_height=256)
model = build_detector(num_classes=num_classes, config=config)
model = build_detector(
class_names=[f"class_{i}" for i in range(num_classes)],
dimension_names=["width", "height"],
config=config,
)
# Process the spectrogram through the model
# PyTorch expects shape (Batch, Channels, Height, Width)
@ -132,7 +147,11 @@ def test_detector_forward_pass_with_preprocessor(sample_preprocessor):
config = UNetBackboneConfig(
in_channels=spec.shape[1], input_height=spec.shape[2]
)
model = build_detector(num_classes=3, config=config)
model = build_detector(
class_names=["class_0", "class_1", "class_2"],
dimension_names=["width", "height"],
config=config,
)
# Process
output = model(spec)

View File

@ -1,9 +1,17 @@
import sys
import types
from pathlib import Path
import pytest
import torch
from soundevent import data
from batdetect2.train import TrainingConfig, run_train
from batdetect2.train.checkpoints import (
DEFAULT_CHECKPOINT,
get_bundled_checkpoint_names,
resolve_checkpoint_path,
)
pytestmark = pytest.mark.slow
@ -92,3 +100,133 @@ def test_train_controls_which_checkpoints_are_kept(
assert last_checkpoints
assert len(best_checkpoints) == 1
assert "epoch" in best_checkpoints[0].name
def test_train_saves_weights_only_checkpoints_by_default(
tmp_path: Path,
example_annotations: list[data.ClipAnnotation],
) -> None:
config = _build_fast_train_config()
run_train(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
train_config=config,
num_epochs=1,
train_workers=0,
val_workers=0,
checkpoint_dir=tmp_path,
seed=0,
)
checkpoint_path = next(tmp_path.rglob("*.ckpt"))
checkpoint = torch.load(
checkpoint_path,
map_location="cpu",
weights_only=False,
)
assert "state_dict" in checkpoint
assert "hyper_parameters" in checkpoint
assert "pytorch-lightning_version" in checkpoint
assert "optimizer_states" not in checkpoint
assert "lr_schedulers" not in checkpoint
def test_resolve_checkpoint_path_returns_local_path_unchanged(
tmp_path: Path,
) -> None:
local_path = tmp_path / "model.ckpt"
local_path.write_bytes(b"checkpoint")
assert resolve_checkpoint_path(local_path) == local_path
assert resolve_checkpoint_path(str(local_path)) == local_path
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
assert get_bundled_checkpoint_names() == (
DEFAULT_CHECKPOINT,
"batdetect2_uk_same",
)
def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None:
resolved = resolve_checkpoint_path()
assert resolved == resolve_checkpoint_path(DEFAULT_CHECKPOINT)
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
resolved = resolve_checkpoint_path(DEFAULT_CHECKPOINT)
assert resolved.name == "batdetect2_uk_same.ckpt"
assert resolved.exists()
def test_resolve_checkpoint_path_prefers_existing_local_path_over_alias(
tmp_path: Path,
) -> None:
local_path = tmp_path / "uk_same"
local_path.write_bytes(b"checkpoint")
assert resolve_checkpoint_path(local_path) == local_path
assert resolve_checkpoint_path(str(local_path)) == local_path
def test_resolve_checkpoint_path_downloads_huggingface_checkpoint(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
expected_path = tmp_path / "downloaded.ckpt"
def fake_hf_hub_download(repo_id: str, filename: str) -> str:
assert repo_id == "owner/repo"
assert filename == "weights/model.ckpt"
return str(expected_path)
class FakeHuggingFaceHub(types.ModuleType):
hf_hub_download = staticmethod(fake_hf_hub_download)
fake_module = FakeHuggingFaceHub("huggingface_hub")
monkeypatch.setitem(
sys.modules,
"huggingface_hub",
fake_module,
)
resolved = resolve_checkpoint_path("hf://owner/repo/weights/model.ckpt")
assert resolved == expected_path
def test_resolve_checkpoint_path_requires_huggingface_dependency(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.delitem(sys.modules, "huggingface_hub", raising=False)
import builtins
original_import = builtins.__import__
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == "huggingface_hub":
raise ImportError("missing")
return original_import(name, globals, locals, fromlist, level)
monkeypatch.setattr(builtins, "__import__", fake_import)
with pytest.raises(ValueError, match="Hugging Face checkpoint support"):
resolve_checkpoint_path("hf://owner/repo/weights/model.ckpt")
def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
with pytest.raises(ValueError, match="hf://owner/repo/path/to"):
resolve_checkpoint_path("hf://owner/repo")
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
with pytest.raises(
FileNotFoundError,
match="checkpoint alias",
):
resolve_checkpoint_path("missing.ckpt")

View File

@ -368,7 +368,7 @@ def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
assert (
rebuilt_detector.classifier_head is not source_detector.classifier_head
)
assert rebuilt_detector.bbox_head is not source_detector.bbox_head
assert rebuilt_detector.size_head is not source_detector.size_head
assert rebuilt_model.class_names == ["single_class"]
assert rebuilt_model.dimension_names == ["width", "height"]
@ -451,7 +451,6 @@ def test_run_train_rejects_incompatible_model_config(
model=incompatible_model,
targets=targets,
roi_mapper=roi_mapper,
model_config=incompatible_config,
targets_config=targets_config,
train_config=TrainingConfig(),
)