mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Compare commits
17 Commits
5a974711b0
...
b0f85b96e3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b0f85b96e3 | ||
|
|
ce6975770e | ||
|
|
69d8e2d228 | ||
|
|
855a79853b | ||
|
|
6587c6c4e5 | ||
|
|
831925bd95 | ||
|
|
b4efcfcf0f | ||
|
|
5cc5767eff | ||
|
|
2008c8000f | ||
|
|
a27d1bbfd3 | ||
|
|
999dc93d88 | ||
|
|
7c05fb8577 | ||
|
|
31054f64f6 | ||
|
|
84918086c8 | ||
|
|
d83f801515 | ||
|
|
5526ac99fc | ||
|
|
f5afa9881c |
@ -3,6 +3,8 @@ current_version = 1.1.1
|
|||||||
commit = True
|
commit = True
|
||||||
tag = True
|
tag = True
|
||||||
|
|
||||||
[bumpversion:file:batdetect2/__init__.py]
|
[bumpversion:file:src/batdetect2/__init__.py]
|
||||||
|
|
||||||
[bumpversion:file:pyproject.toml]
|
[bumpversion:file:pyproject.toml]
|
||||||
|
|
||||||
|
[bumpversion:file:docs/source/conf.py]
|
||||||
|
|||||||
79
.github/workflows/ci.yml
vendored
Normal file
79
.github/workflows/ci.yml
vendored
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: ci-${{ github.workflow }}-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
checks:
|
||||||
|
name: Checks
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Install just
|
||||||
|
uses: taiki-e/install-action@just
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
with:
|
||||||
|
enable-cache: true
|
||||||
|
cache-dependency-glob: |
|
||||||
|
pyproject.toml
|
||||||
|
uv.lock
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: just install-dev
|
||||||
|
|
||||||
|
- name: Run formatting, lint, and type checks
|
||||||
|
run: just check
|
||||||
|
|
||||||
|
tests:
|
||||||
|
name: Tests (Python ${{ matrix.python-version }})
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
python-version:
|
||||||
|
- "3.10"
|
||||||
|
- "3.11"
|
||||||
|
- "3.12"
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install just
|
||||||
|
uses: taiki-e/install-action@just
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
with:
|
||||||
|
enable-cache: true
|
||||||
|
cache-dependency-glob: |
|
||||||
|
pyproject.toml
|
||||||
|
uv.lock
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: just install-dev
|
||||||
|
|
||||||
|
- name: Run test suite
|
||||||
|
run: just test
|
||||||
69
.github/workflows/docs-pages.yml
vendored
Normal file
69
.github/workflows/docs-pages.yml
vendored
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
name: Docs Pages
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: docs-pages
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
name: Build Docs
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Install just
|
||||||
|
uses: taiki-e/install-action@just
|
||||||
|
|
||||||
|
- name: Configure GitHub Pages
|
||||||
|
uses: actions/configure-pages@v5
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
with:
|
||||||
|
enable-cache: true
|
||||||
|
cache-dependency-glob: |
|
||||||
|
pyproject.toml
|
||||||
|
uv.lock
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: just install-dev
|
||||||
|
|
||||||
|
- name: Build docs
|
||||||
|
run: just check-docs
|
||||||
|
|
||||||
|
- name: Upload Pages artifact
|
||||||
|
uses: actions/upload-pages-artifact@v4
|
||||||
|
with:
|
||||||
|
path: docs/build
|
||||||
|
|
||||||
|
deploy:
|
||||||
|
name: Deploy Docs
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
pages: write
|
||||||
|
id-token: write
|
||||||
|
environment:
|
||||||
|
name: github-pages
|
||||||
|
url: ${{ steps.deployment.outputs.page_url }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Deploy to GitHub Pages
|
||||||
|
id: deployment
|
||||||
|
uses: actions/deploy-pages@v4
|
||||||
70
.github/workflows/publish-pypi.yml
vendored
Normal file
70
.github/workflows/publish-pypi.yml
vendored
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
name: Publish PyPI
|
||||||
|
|
||||||
|
on:
|
||||||
|
release:
|
||||||
|
types:
|
||||||
|
- published
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: publish-pypi
|
||||||
|
cancel-in-progress: false
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
name: Build Distributions
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Install just
|
||||||
|
uses: taiki-e/install-action@just
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@v6
|
||||||
|
with:
|
||||||
|
enable-cache: true
|
||||||
|
cache-dependency-glob: |
|
||||||
|
pyproject.toml
|
||||||
|
uv.lock
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: just install-dev
|
||||||
|
|
||||||
|
- name: Build distributions
|
||||||
|
run: just build-dist
|
||||||
|
|
||||||
|
- name: Upload distributions
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: release-dists
|
||||||
|
path: dist/
|
||||||
|
|
||||||
|
publish:
|
||||||
|
name: Publish to PyPI
|
||||||
|
needs: build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
id-token: write
|
||||||
|
environment:
|
||||||
|
name: pypi
|
||||||
|
url: https://pypi.org/p/batdetect2
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Download distributions
|
||||||
|
uses: actions/download-artifact@v5
|
||||||
|
with:
|
||||||
|
name: release-dists
|
||||||
|
path: dist/
|
||||||
|
|
||||||
|
- name: Publish to PyPI
|
||||||
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
29
.github/workflows/python-package.yml
vendored
29
.github/workflows/python-package.yml
vendored
@ -1,29 +0,0 @@
|
|||||||
name: Python package
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
branches: ["main"]
|
|
||||||
pull_request:
|
|
||||||
branches: ["main"]
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- name: Install uv
|
|
||||||
uses: astral-sh/setup-uv@v3
|
|
||||||
with:
|
|
||||||
enable-cache: true
|
|
||||||
cache-dependency-glob: "uv.lock"
|
|
||||||
- name: Set up Python ${{ matrix.python-version }}
|
|
||||||
run: uv python install ${{ matrix.python-version }}
|
|
||||||
- name: Install the project
|
|
||||||
run: uv sync --all-extras --dev
|
|
||||||
- name: Test with pytest
|
|
||||||
run: uv run pytest
|
|
||||||
30
.github/workflows/python-publish.yml
vendored
30
.github/workflows/python-publish.yml
vendored
@ -1,30 +0,0 @@
|
|||||||
name: Upload Python Package
|
|
||||||
|
|
||||||
on:
|
|
||||||
release:
|
|
||||||
types: [published]
|
|
||||||
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
deploy:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
- name: Set up Python
|
|
||||||
uses: actions/setup-python@v3
|
|
||||||
with:
|
|
||||||
python-version: "3.x"
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
python -m pip install --upgrade pip
|
|
||||||
pip install build
|
|
||||||
- name: Build package
|
|
||||||
run: python -m build
|
|
||||||
- name: Publish package
|
|
||||||
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
|
|
||||||
with:
|
|
||||||
user: __token__
|
|
||||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -50,6 +50,7 @@ cover/
|
|||||||
|
|
||||||
# Sphinx documentation
|
# Sphinx documentation
|
||||||
docs/_build/
|
docs/_build/
|
||||||
|
docs/build/
|
||||||
|
|
||||||
# PyBuilder
|
# PyBuilder
|
||||||
.pybuilder/
|
.pybuilder/
|
||||||
|
|||||||
247
README.md
247
README.md
@ -1,202 +1,137 @@
|
|||||||
# BatDetect2
|
# BatDetect2
|
||||||
<img style="display: block-inline;" width="64" height="64" src="assets/bat_icon.png"> Code for detecting and classifying bat echolocation calls in high frequency audio recordings.
|
|
||||||
|
|
||||||
## What BatDetect2 is useful for
|
<img style="display:inline-block;" width="64" height="64" src="assets/bat_icon.png" alt="BatDetect2 bat icon">
|
||||||
|
|
||||||
BatDetect2 can help you screen recordings for bat calls,
|
Code for detecting and classifying bat echolocation calls in high-frequency
|
||||||
find recordings that need expert review,
|
audio recordings.
|
||||||
and compare model outputs across sites or projects with appropriate caution.
|
|
||||||
|
|
||||||
It is best used as a tool to support ecological work,
|
> [!WARNING]
|
||||||
not as a replacement for validation or expert interpretation.
|
> `batdetect2` 2.0.1 is out.
|
||||||
|
> There are many changes and new recommended workflows.
|
||||||
|
> We have left the previous `batdetect2.api` module intact, but if you run
|
||||||
|
> into issues or want to upgrade, see the
|
||||||
|
> [migration guide](docs/source/legacy/migration-guide.md) in the docs site.
|
||||||
|
>
|
||||||
|
> This update also ships with a refreshed default model.
|
||||||
|
> It was trained in the same way and on the same data as before, but you should still expect small output differences in some cases.
|
||||||
|
|
||||||
## Start here
|
## What is BatDetect2
|
||||||
|
|
||||||
If you want the simplest current workflow,
|
BatDetect2 is a deep learning model for detecting and classifying bat
|
||||||
use the documentation site and start with:
|
echolocation calls.
|
||||||
|
The model generates multiple predictions for each input recording by providing a
|
||||||
|
bounding box and predicted class for each individual call within it.
|
||||||
|
|
||||||
- getting started: `docs/source/getting_started.md`
|
This repository also holds `batdetect2`, a Python-based tool to run, train,
|
||||||
- first tutorial: `docs/source/tutorials/run-inference-on-folder.md`
|
fine-tune and evaluate BatDetect2-type models, including the built-in model for
|
||||||
|
detecting UK bat species.
|
||||||
|
You can use the tool from the command line (terminal) or from Python as needed.
|
||||||
|
|
||||||
The current docs default to:
|
## Getting Started
|
||||||
|
|
||||||
- the current command-line workflow: `batdetect2 predict`
|
We have [extensive documentation](docs/source/index.md) on how to use
|
||||||
- the current Python workflow: `batdetect2.api_v2.BatDetect2API`
|
`batdetect2`.
|
||||||
|
See our [getting started](docs/source/getting_started.md) guide and then jump
|
||||||
|
into any of our tutorials:
|
||||||
|
|
||||||
If you need the previous workflow based on `batdetect2 detect` or `batdetect2.api`,
|
- Run the model on a folder of recordings:
|
||||||
use the legacy docs section and migration guide in the docs site.
|
`docs/source/tutorials/run-inference-on-folder.md`
|
||||||
|
- Train your own model:
|
||||||
|
`docs/source/tutorials/train-a-custom-model.md`
|
||||||
|
- Evaluate your model:
|
||||||
|
`docs/source/tutorials/evaluate-on-a-test-set.md`
|
||||||
|
- Fine-tune a model:
|
||||||
|
`docs/source/tutorials/integrate-with-a-python-pipeline.md`
|
||||||
|
|
||||||
## Install BatDetect2
|
### Try the model
|
||||||
|
|
||||||
If you already use Python,
|
If you want to try the model for UK bat species without installing anything, you
|
||||||
activate the environment where you want BatDetect2 to live.
|
can try the following:
|
||||||
|
|
||||||
If not,
|
1. Demo of the model (for UK species) on
|
||||||
create a fresh one first so BatDetect2 stays separate from other software on your machine.
|
[Hugging Face](https://huggingface.co/spaces/macaodha/batdetect2).
|
||||||
|
|
||||||
Two common options are:
|
2. Alternatively, click
|
||||||
|
[here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb)
|
||||||
* Install the Anaconda Python 3.10 distribution for your operating system from [here](https://www.continuum.io/downloads). Create a new environment and activate it:
|
to run the model using Google Colab.
|
||||||
|
You can also run this notebook locally.
|
||||||
```bash
|
|
||||||
conda create -y --name batdetect2 python==3.10
|
|
||||||
conda activate batdetect2
|
|
||||||
```
|
|
||||||
|
|
||||||
* If you already have Python installed (version >= 3.10,< 3.14), you can create a fresh environment with:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python -m venv .venv
|
|
||||||
source .venv/bin/activate
|
|
||||||
```
|
|
||||||
|
|
||||||
### Installing BatDetect2
|
### Installing BatDetect2
|
||||||
You can use pip to install `batdetect2`:
|
|
||||||
|
If you have `uv` installed (if not, we recommend it; follow the instructions
|
||||||
|
[here](https://docs.astral.sh/uv/getting-started/installation/)), then you can
|
||||||
|
run `batdetect2` one-off with
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install batdetect2
|
uvx batdetect2
|
||||||
```
|
```
|
||||||
|
|
||||||
Alternatively, download this code from the repository (by clicking on the green button on top right) and unzip it.
|
or if you want to install it permanently:
|
||||||
Once unzipped, run this from extracted folder.
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install .
|
uv tool install batdetect2
|
||||||
```
|
```
|
||||||
|
|
||||||
Make sure you have the environment activated before installing `batdetect2`.
|
and test it with
|
||||||
|
|
||||||
## Run BatDetect2 on a folder of recordings
|
```bash
|
||||||
|
batdetect2
|
||||||
Once installed,
|
|
||||||
the simplest current workflow is to run BatDetect2 on a folder of `.wav` files.
|
|
||||||
|
|
||||||
If you are working from this repository checkout,
|
|
||||||
you can use this example checkpoint path:
|
|
||||||
|
|
||||||
```text
|
|
||||||
src/batdetect2/models/checkpoints/Net2DFast_UK_same.pth.tar
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Run BatDetect2 on a folder of recordings
|
||||||
|
|
||||||
|
Once installed, you can run BatDetect2 on a folder of `.wav` files.
|
||||||
|
By default it will use the model trained on UK data.
|
||||||
|
|
||||||
Example command:
|
Example command:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict directory \
|
batdetect2 process directory example_data/audio outputs
|
||||||
src/batdetect2/models/checkpoints/Net2DFast_UK_same.pth.tar \
|
|
||||||
example_data/audio \
|
|
||||||
outputs
|
|
||||||
```
|
```
|
||||||
|
|
||||||
This will scan the audio files in `example_data/audio`
|
This will scan the audio files in `example_data/audio` and save model outputs to
|
||||||
and save model outputs to `outputs`.
|
`outputs`.
|
||||||
|
If you have your own model checkpoint, you can use it:
|
||||||
|
|
||||||
For the full beginner walkthrough,
|
|
||||||
use `docs/source/tutorials/run-inference-on-folder.md`.
|
|
||||||
|
|
||||||
## Legacy workflow
|
|
||||||
|
|
||||||
The sections below are kept only for people maintaining older BatDetect2 scripts and analysis pipelines.
|
|
||||||
|
|
||||||
If you are new to BatDetect2,
|
|
||||||
stop here and use the current docs and command above.
|
|
||||||
|
|
||||||
If you really do need the older workflow,
|
|
||||||
the reference material is below.
|
|
||||||
|
|
||||||
|
|
||||||
## Try the model
|
|
||||||
1) You can try a demo of the model (for UK species) on [huggingface](https://huggingface.co/spaces/macaodha/batdetect2).
|
|
||||||
|
|
||||||
2) Alternatively, click [here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb) to run the model using Google Colab. You can also run this notebook locally.
|
|
||||||
|
|
||||||
|
|
||||||
## Running the model on your own data
|
|
||||||
|
|
||||||
After following the above steps to install the code you can run the model on your own data.
|
|
||||||
|
|
||||||
The remainder of this section is legacy reference material.
|
|
||||||
|
|
||||||
|
|
||||||
### Using the command line
|
|
||||||
|
|
||||||
The commands below describe the legacy CLI workflow.
|
|
||||||
|
|
||||||
For new work, prefer the current docs and `batdetect2 predict`.
|
|
||||||
|
|
||||||
You can run the model by opening the command line and typing:
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD
|
batdetect2 process directory --model path/to/checkpoint.ckpt example_data/audio outputs
|
||||||
```
|
|
||||||
e.g.
|
|
||||||
```bash
|
|
||||||
batdetect2 detect example_data/audio/ example_data/anns/ 0.3
|
|
||||||
```
|
```
|
||||||
|
|
||||||
`AUDIO_DIR` is the path on your computer to the audio wav files of interest.
|
For the full walkthrough, use
|
||||||
`ANN_DIR` is the path on your computer where the model predictions will be saved. The model will output both `.csv` and `.json` results for each audio file.
|
`docs/source/tutorials/run-inference-on-folder.md`.
|
||||||
`DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistakes.
|
|
||||||
|
|
||||||
There are also optional arguments, e.g. you can request that the model outputs features (i.e. estimated call parameters) such as duration, max_frequency, etc. by setting the flag `--spec_features`. These will be saved as `*_spec_features.csv` files:
|
|
||||||
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --spec_features`
|
|
||||||
|
|
||||||
You can also specify which model to use by setting the `--model_path` argument. If not specified, it will default to using a model trained on UK data e.g.
|
|
||||||
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --model_path models/Net2DFast_UK_same.pth.tar`
|
|
||||||
|
|
||||||
|
|
||||||
### Using the Python API
|
|
||||||
|
|
||||||
The examples below describe the legacy Python API.
|
|
||||||
|
|
||||||
For new work, prefer `batdetect2.api_v2.BatDetect2API` and the current docs site.
|
|
||||||
|
|
||||||
If you prefer to process your data within a Python script then you can use the `batdetect2` Python API.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from batdetect2 import api
|
|
||||||
|
|
||||||
AUDIO_FILE = "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav"
|
|
||||||
|
|
||||||
# Process a whole file
|
|
||||||
results = api.process_file(AUDIO_FILE)
|
|
||||||
|
|
||||||
# Or, load audio and compute spectrograms
|
|
||||||
audio = api.load_audio(AUDIO_FILE)
|
|
||||||
spec = api.generate_spectrogram(audio)
|
|
||||||
|
|
||||||
# And process the audio or the spectrogram with the model
|
|
||||||
detections, features, spec = api.process_audio(audio)
|
|
||||||
detections, features = api.process_spectrogram(spec)
|
|
||||||
|
|
||||||
# Do something else ...
|
|
||||||
```
|
|
||||||
|
|
||||||
You can integrate the detections or the extracted features to your custom analysis pipeline.
|
|
||||||
|
|
||||||
|
|
||||||
## Training the model on your own data
|
|
||||||
Take a look at the training tutorial in the docs site first.
|
|
||||||
|
|
||||||
If you are working from this repository checkout,
|
|
||||||
start with `docs/source/tutorials/train-a-custom-model.md`.
|
|
||||||
|
|
||||||
|
|
||||||
## Data and annotations
|
## Data and annotations
|
||||||
The raw audio data and annotations used to train the models in the paper will be added soon.
|
|
||||||
The audio interface used to annotate audio data for training and evaluation is available [here](https://github.com/macaodha/batdetect2_GUI).
|
|
||||||
|
|
||||||
|
The raw audio data and annotations used to train the models in the paper will be
|
||||||
|
added soon.
|
||||||
|
`batdetect2` supports annotations in various formats and is compatible with the
|
||||||
|
outputs of [`whombat`](https://github.com/mbsantiago/whombat/) and this
|
||||||
|
[earlier version](https://github.com/macaodha/batdetect2_GUI).
|
||||||
|
If you're interested in supporting another format, please reach out or submit a
|
||||||
|
PR.
|
||||||
|
|
||||||
## Warning
|
## Warning
|
||||||
The models developed and shared as part of this repository should be used with caution.
|
|
||||||
While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment.
|
|
||||||
Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted.
|
|
||||||
|
|
||||||
|
The models developed and shared as part of this repository should be used with
|
||||||
|
caution.
|
||||||
|
While they have been evaluated on held-out audio data, great care should be
|
||||||
|
taken when using the model outputs for any form of biodiversity assessment.
|
||||||
|
Your data may differ, and as a result it is very strongly recommended that you
|
||||||
|
validate the model first using data with known species to ensure that the
|
||||||
|
outputs can be trusted.
|
||||||
|
If you train a model, make your best effort to be transparent about its training
|
||||||
|
and evaluation data, and inform downstream users about its limitations.
|
||||||
|
|
||||||
## FAQ
|
## FAQ
|
||||||
|
|
||||||
For more information please consult our [FAQ](docs/source/faq.md).
|
For more information please consult our [FAQ](docs/source/faq.md).
|
||||||
|
|
||||||
|
|
||||||
## Reference
|
## Reference
|
||||||
If you find our work useful in your research please consider citing our paper which you can find [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1):
|
|
||||||
|
If you find our work useful in your research, please consider citing our paper,
|
||||||
|
which you can find
|
||||||
|
[here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1):
|
||||||
|
|
||||||
```
|
```
|
||||||
@article{batdetect2_2022,
|
@article{batdetect2_2022,
|
||||||
title = {Towards a General Approach for Bat Echolocation Detection and Classification},
|
title = {Towards a General Approach for Bat Echolocation Detection and Classification},
|
||||||
@ -207,10 +142,6 @@ If you find our work useful in your research please consider citing our paper wh
|
|||||||
```
|
```
|
||||||
|
|
||||||
## Acknowledgements
|
## Acknowledgements
|
||||||
Thanks to all the contributors who spent time collecting and annotating audio data.
|
|
||||||
|
|
||||||
|
Thanks to all the contributors who spent time collecting and annotating audio
|
||||||
### TODOs
|
data.
|
||||||
- [x] Release the code and pretrained model
|
|
||||||
- [ ] Release the datasets and annotations used the experiments in the paper
|
|
||||||
- [ ] Add the scripts used to generate the tables and figures from the paper
|
|
||||||
|
|||||||
@ -1,19 +1,20 @@
|
|||||||
# Getting started
|
# Getting started
|
||||||
|
|
||||||
If you want to run BatDetect2 on your recordings,
|
If you want to run BatDetect2 on your recordings, start with the command-line
|
||||||
start with the command-line route below.
|
route below.
|
||||||
|
|
||||||
You do not need to write Python code for a standard first run.
|
You do not need to write Python code for a standard first run.
|
||||||
|
|
||||||
BatDetect2 also has a Python interface,
|
BatDetect2 also has a Python interface, but that is mainly for users writing
|
||||||
but that is mainly for users writing their own analysis scripts.
|
their own analysis scripts.
|
||||||
|
|
||||||
- Use the command-line route if you want to run an existing model or train your own model by typing commands in a terminal window.
|
- Use the command-line route if you want to run an existing model or train your
|
||||||
|
own model by typing commands in a terminal window.
|
||||||
- Use the Python route only if you already want to work in scripts or notebooks.
|
- Use the Python route only if you already want to work in scripts or notebooks.
|
||||||
|
|
||||||
```{note}
|
```{note}
|
||||||
If you are looking for the previous BatDetect2 workflow based on `batdetect2 detect` or `batdetect2.api`, go to {doc}`legacy/index`.
|
If you are looking for the previous BatDetect2 workflow based on `batdetect2 detect` or `batdetect2.api`, go to {doc}`legacy/index`.
|
||||||
New docs default to the current `predict` CLI and `BatDetect2API` workflow.
|
New docs default to the current `process` CLI and `BatDetect2API` workflow.
|
||||||
```
|
```
|
||||||
|
|
||||||
If you want to try BatDetect2 before installing anything locally:
|
If you want to try BatDetect2 before installing anything locally:
|
||||||
@ -27,15 +28,14 @@ If you want to try BatDetect2 before installing anything locally:
|
|||||||
2. Use a model checkpoint.
|
2. Use a model checkpoint.
|
||||||
3. Run the first tutorial on a folder of recordings.
|
3. Run the first tutorial on a folder of recordings.
|
||||||
|
|
||||||
If that is what you want,
|
If that is what you want, you can ignore the Python sections for now.
|
||||||
you can ignore the Python sections for now.
|
|
||||||
|
|
||||||
## Install BatDetect2
|
## Install BatDetect2
|
||||||
|
|
||||||
We recommend `uv` for both workflows.
|
We recommend `uv` for both workflows.
|
||||||
|
|
||||||
`uv` is a tool that helps install Python software cleanly,
|
`uv` is a tool that helps install Python software cleanly, without mixing it
|
||||||
without mixing it into the rest of your machine.
|
into the rest of your machine.
|
||||||
|
|
||||||
- Use `uv tool` to install the CLI.
|
- Use `uv tool` to install the CLI.
|
||||||
- Use `uv add` to add `batdetect2` as a dependency in a Python project.
|
- Use `uv add` to add `batdetect2` as a dependency in a Python project.
|
||||||
@ -70,7 +70,8 @@ Go to {doc}`tutorials/run-inference-on-folder` for a complete first run.
|
|||||||
|
|
||||||
## Choose a model checkpoint
|
## Choose a model checkpoint
|
||||||
|
|
||||||
The current command-line and Python workflows expect an explicit checkpoint path.
|
The current command-line and Python workflows expect an explicit checkpoint
|
||||||
|
path.
|
||||||
|
|
||||||
A checkpoint is the saved model file that BatDetect2 will use for prediction.
|
A checkpoint is the saved model file that BatDetect2 will use for prediction.
|
||||||
|
|
||||||
@ -85,7 +86,8 @@ In this repository checkout, an example pretrained checkpoint is available at:
|
|||||||
src/batdetect2/models/checkpoints/Net2DFast_UK_same.pth.tar
|
src/batdetect2/models/checkpoints/Net2DFast_UK_same.pth.tar
|
||||||
```
|
```
|
||||||
|
|
||||||
Use that path in the tutorial commands if you want a concrete starting point from this source tree.
|
Use that path in the tutorial commands if you want a concrete starting point
|
||||||
|
from this source tree.
|
||||||
|
|
||||||
## Python route for users writing code
|
## Python route for users writing code
|
||||||
|
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
# How to choose an inference input mode
|
# How to choose an inference input mode
|
||||||
|
|
||||||
Use this guide to decide whether `predict directory`, `predict file_list`, or `predict dataset` is the right entry point for your run.
|
Use this guide to decide whether `process directory`, `process file_list`, or
|
||||||
|
`process dataset` is the right entry point for your run.
|
||||||
|
|
||||||
## Use `predict directory` when the recordings already live together
|
## Use `process directory` when the recordings already live together
|
||||||
|
|
||||||
This is the simplest choice.
|
This is the simplest choice.
|
||||||
|
|
||||||
@ -13,13 +14,13 @@ Use it when:
|
|||||||
- you are doing a first pass over a folder of recordings.
|
- you are doing a first pass over a folder of recordings.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict directory \
|
batdetect2 process directory \
|
||||||
path/to/model.ckpt \
|
path/to/model.ckpt \
|
||||||
path/to/audio_dir \
|
path/to/audio_dir \
|
||||||
path/to/outputs
|
path/to/outputs
|
||||||
```
|
```
|
||||||
|
|
||||||
## Use `predict file_list` when you need explicit control over the file set
|
## Use `process file_list` when you need explicit control over the file set
|
||||||
|
|
||||||
Use it when:
|
Use it when:
|
||||||
|
|
||||||
@ -30,13 +31,13 @@ Use it when:
|
|||||||
The list file should contain one path per line.
|
The list file should contain one path per line.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict file_list \
|
batdetect2 process file_list \
|
||||||
path/to/model.ckpt \
|
path/to/model.ckpt \
|
||||||
path/to/audio_files.txt \
|
path/to/audio_files.txt \
|
||||||
path/to/outputs
|
path/to/outputs
|
||||||
```
|
```
|
||||||
|
|
||||||
## Use `predict dataset` when your workflow is already annotation-set driven
|
## Use `process dataset` when your workflow is already annotation-set driven
|
||||||
|
|
||||||
Use it when:
|
Use it when:
|
||||||
|
|
||||||
@ -45,13 +46,14 @@ Use it when:
|
|||||||
- you want BatDetect2 to resolve recording paths from the annotation set.
|
- you want BatDetect2 to resolve recording paths from the annotation set.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict dataset \
|
batdetect2 process dataset \
|
||||||
path/to/model.ckpt \
|
path/to/model.ckpt \
|
||||||
path/to/annotation_set.json \
|
path/to/annotation_set.json \
|
||||||
path/to/outputs
|
path/to/outputs
|
||||||
```
|
```
|
||||||
|
|
||||||
The dataset command reads a `soundevent` annotation set and extracts unique recording paths before inference.
|
The dataset command reads a `soundevent` annotation set and extracts unique
|
||||||
|
recording paths before inference.
|
||||||
|
|
||||||
## Rule of thumb
|
## Rule of thumb
|
||||||
|
|
||||||
@ -61,6 +63,9 @@ The dataset command reads a `soundevent` annotation set and extracts unique reco
|
|||||||
|
|
||||||
## Related pages
|
## Related pages
|
||||||
|
|
||||||
- Run batch predictions: {doc}`run-batch-predictions`
|
- Run batch predictions:
|
||||||
- Tune inference clipping: {doc}`tune-inference-clipping`
|
{doc}`run-batch-predictions`
|
||||||
- Predict command reference: {doc}`../reference/cli/predict`
|
- Tune inference clipping:
|
||||||
|
{doc}`tune-inference-clipping`
|
||||||
|
- Process command reference:
|
||||||
|
{doc}`../reference/cli/predict`
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# How to choose and configure evaluation tasks
|
# How to choose and configure evaluation tasks
|
||||||
|
|
||||||
Use this guide when the default evaluation tasks do not match the question you want to answer.
|
Use this guide when the default evaluation tasks do not match the question you
|
||||||
|
want to answer.
|
||||||
|
|
||||||
## Know the default first
|
## Know the default first
|
||||||
|
|
||||||
@ -24,8 +25,10 @@ Common built-in task families include:
|
|||||||
Choose based on the question you care about.
|
Choose based on the question you care about.
|
||||||
|
|
||||||
- Use sound-event tasks when you care about individual call events.
|
- Use sound-event tasks when you care about individual call events.
|
||||||
- Use clip tasks when you care about clip-level presence or clip-level class evidence.
|
- Use clip tasks when you care about clip-level presence or clip-level class
|
||||||
- Use top-class detection when you want matching based on the highest-scoring class per detection.
|
evidence.
|
||||||
|
- Use top-class detection when you want matching based on the highest-scoring
|
||||||
|
class per detection.
|
||||||
|
|
||||||
## Configure tasks in `EvaluationConfig`
|
## Configure tasks in `EvaluationConfig`
|
||||||
|
|
||||||
@ -45,22 +48,27 @@ Pass the config with:
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 evaluate \
|
batdetect2 evaluate \
|
||||||
path/to/model.ckpt \
|
|
||||||
path/to/test_dataset.yaml \
|
path/to/test_dataset.yaml \
|
||||||
|
--model path/to/model.ckpt \
|
||||||
--base-dir path/to/project_root \
|
--base-dir path/to/project_root \
|
||||||
--evaluation-config path/to/evaluation.yaml
|
--evaluation-config path/to/evaluation.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
Include `--base-dir` when the dataset config resolves recordings through relative paths.
|
Include `--base-dir` when the dataset config resolves recordings through
|
||||||
|
relative paths.
|
||||||
|
|
||||||
## Change one thing at a time
|
## Change one thing at a time
|
||||||
|
|
||||||
When comparing models or settings, avoid changing task definitions, thresholds, matching behavior, and datasets all at once.
|
When comparing models or settings, avoid changing task definitions, thresholds,
|
||||||
|
matching behavior, and datasets all at once.
|
||||||
|
|
||||||
Otherwise it becomes hard to explain why the metric changed.
|
Otherwise it becomes hard to explain why the metric changed.
|
||||||
|
|
||||||
## Related pages
|
## Related pages
|
||||||
|
|
||||||
- Evaluation tutorial: {doc}`../tutorials/evaluate-on-a-test-set`
|
- Evaluation tutorial:
|
||||||
- Evaluation config reference: {doc}`../reference/evaluation-config`
|
{doc}`../tutorials/evaluate-on-a-test-set`
|
||||||
- Evaluation concepts: {doc}`../explanation/evaluation-concepts-and-matching`
|
- Evaluation config reference:
|
||||||
|
{doc}`../reference/evaluation-config`
|
||||||
|
- Evaluation concepts:
|
||||||
|
{doc}`../explanation/evaluation-concepts-and-matching`
|
||||||
|
|||||||
@ -46,7 +46,7 @@ Available built-ins:
|
|||||||
For CLI inference/evaluation, use `--audio-config`.
|
For CLI inference/evaluation, use `--audio-config`.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict directory \
|
batdetect2 process directory \
|
||||||
path/to/model.ckpt \
|
path/to/model.ckpt \
|
||||||
path/to/audio_dir \
|
path/to/audio_dir \
|
||||||
path/to/outputs \
|
path/to/outputs \
|
||||||
@ -55,10 +55,12 @@ batdetect2 predict directory \
|
|||||||
|
|
||||||
## 4) Verify quickly on a small subset
|
## 4) Verify quickly on a small subset
|
||||||
|
|
||||||
Run on a small folder first and confirm that outputs and runtime are as
|
Run on a small folder first and confirm that outputs and runtime are as expected
|
||||||
expected before full-batch runs.
|
before full-batch runs.
|
||||||
|
|
||||||
## Related pages
|
## Related pages
|
||||||
|
|
||||||
- Spectrogram settings: {doc}`configure-spectrogram-preprocessing`
|
- Spectrogram settings:
|
||||||
- Preprocessing config reference: {doc}`../reference/preprocessing-config`
|
{doc}`configure-spectrogram-preprocessing`
|
||||||
|
- Preprocessing config reference:
|
||||||
|
{doc}`../reference/preprocessing-config`
|
||||||
|
|||||||
@ -1,14 +1,15 @@
|
|||||||
# How to run batch predictions
|
# How to run batch processing
|
||||||
|
|
||||||
This guide shows practical command patterns for directory-based and file-list
|
This guide shows practical command patterns for directory-based and file-list
|
||||||
prediction runs.
|
processing runs.
|
||||||
|
|
||||||
Use it after you already know which input mode you want and need concrete command templates for a repeatable batch run.
|
Use it after you already know which input mode you want and need concrete
|
||||||
|
command templates for a repeatable batch run.
|
||||||
|
|
||||||
## Predict from a directory
|
## Process a directory
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict directory \
|
batdetect2 process directory \
|
||||||
path/to/model.ckpt \
|
path/to/model.ckpt \
|
||||||
path/to/audio_dir \
|
path/to/audio_dir \
|
||||||
path/to/outputs
|
path/to/outputs
|
||||||
@ -16,27 +17,29 @@ batdetect2 predict directory \
|
|||||||
|
|
||||||
Use this when BatDetect2 should discover the audio files for you.
|
Use this when BatDetect2 should discover the audio files for you.
|
||||||
|
|
||||||
## Predict from a file list
|
## Process a file list
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict file_list \
|
batdetect2 process file_list \
|
||||||
path/to/model.ckpt \
|
path/to/model.ckpt \
|
||||||
path/to/audio_files.txt \
|
path/to/audio_files.txt \
|
||||||
path/to/outputs
|
path/to/outputs
|
||||||
```
|
```
|
||||||
|
|
||||||
Use this when another part of your workflow already produced the exact recording list to process.
|
Use this when another part of your workflow already produced the exact recording
|
||||||
|
list to process.
|
||||||
|
|
||||||
## Predict from a dataset config
|
## Process a dataset config
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict dataset \
|
batdetect2 process dataset \
|
||||||
path/to/model.ckpt \
|
path/to/model.ckpt \
|
||||||
path/to/annotation_set.json \
|
path/to/annotation_set.json \
|
||||||
path/to/outputs
|
path/to/outputs
|
||||||
```
|
```
|
||||||
|
|
||||||
Use this when your project already has a `soundevent` annotation set and you want to extract unique recording paths from it.
|
Use this when your project already has a `soundevent` annotation set and you
|
||||||
|
want to extract unique recording paths from it.
|
||||||
|
|
||||||
## Useful options
|
## Useful options
|
||||||
|
|
||||||
|
|||||||
@ -1,22 +1,27 @@
|
|||||||
# How to save predictions in different output formats
|
# How to save predictions in different output formats
|
||||||
|
|
||||||
Use this guide when you need BatDetect2 outputs in a specific representation for downstream tools.
|
Use this guide when you need BatDetect2 outputs in a specific representation for
|
||||||
|
downstream tools.
|
||||||
|
|
||||||
## Choose the format that matches the job
|
## Choose the format that matches the job
|
||||||
|
|
||||||
Current built-in output formats include:
|
Current built-in output formats include:
|
||||||
|
|
||||||
- `raw`: one NetCDF file per clip, best for rich structured outputs,
|
- `raw`:
|
||||||
- `parquet`: tabular storage for data analysis workflows,
|
one NetCDF file per clip, best for rich structured outputs,
|
||||||
- `soundevent`: prediction-set JSON for soundevent-style tooling,
|
- `parquet`:
|
||||||
- `batdetect2`: legacy per-recording JSON output.
|
tabular storage for data analysis workflows,
|
||||||
|
- `soundevent`:
|
||||||
|
prediction-set JSON for soundevent-style tooling,
|
||||||
|
- `batdetect2`:
|
||||||
|
legacy per-recording JSON output.
|
||||||
|
|
||||||
## Select a format from the CLI
|
## Select a format from the CLI
|
||||||
|
|
||||||
Use `--format` for quick experiments.
|
Use `--format` for quick experiments.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict directory \
|
batdetect2 process directory \
|
||||||
path/to/model.ckpt \
|
path/to/model.ckpt \
|
||||||
path/to/audio_dir \
|
path/to/audio_dir \
|
||||||
path/to/outputs \
|
path/to/outputs \
|
||||||
@ -25,7 +30,8 @@ batdetect2 predict directory \
|
|||||||
|
|
||||||
## Use an outputs config for repeatable runs
|
## Use an outputs config for repeatable runs
|
||||||
|
|
||||||
Use an outputs config when you want reproducible control over format and transforms.
|
Use an outputs config when you want reproducible control over format and
|
||||||
|
transforms.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@ -43,7 +49,7 @@ transform:
|
|||||||
Run with:
|
Run with:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict directory \
|
batdetect2 process directory \
|
||||||
path/to/model.ckpt \
|
path/to/model.ckpt \
|
||||||
path/to/audio_dir \
|
path/to/audio_dir \
|
||||||
path/to/outputs \
|
path/to/outputs \
|
||||||
@ -59,6 +65,9 @@ batdetect2 predict directory \
|
|||||||
|
|
||||||
## Related pages
|
## Related pages
|
||||||
|
|
||||||
- Outputs config reference: {doc}`../reference/outputs-config`
|
- Outputs config reference:
|
||||||
- Output formats reference: {doc}`../reference/output-formats`
|
{doc}`../reference/outputs-config`
|
||||||
- Output transforms reference: {doc}`../reference/output-transforms`
|
- Output formats reference:
|
||||||
|
{doc}`../reference/output-formats`
|
||||||
|
- Output transforms reference:
|
||||||
|
{doc}`../reference/output-transforms`
|
||||||
|
|||||||
@ -4,7 +4,8 @@ Use this guide to compare detection outputs at different threshold values.
|
|||||||
|
|
||||||
The goal is not to find a universal threshold.
|
The goal is not to find a universal threshold.
|
||||||
|
|
||||||
The goal is to choose a threshold that fits your reviewed local data and the project trade-off between missed calls and false positives.
|
The goal is to choose a threshold that fits your reviewed local data and the
|
||||||
|
project trade-off between missed calls and false positives.
|
||||||
|
|
||||||
## 1) Start with a baseline run
|
## 1) Start with a baseline run
|
||||||
|
|
||||||
@ -12,12 +13,12 @@ Run an initial prediction workflow and keep outputs in a dedicated folder.
|
|||||||
|
|
||||||
## 2) Sweep threshold values
|
## 2) Sweep threshold values
|
||||||
|
|
||||||
Run `predict` multiple times with different thresholds (for example `0.1`,
|
Run `process` multiple times with different thresholds (for example `0.1`,
|
||||||
`0.3`, `0.5`) and compare output counts and quality on the same validation
|
`0.3`, `0.5`) and compare output counts and quality on the same validation
|
||||||
subset.
|
subset.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict directory \
|
batdetect2 process directory \
|
||||||
path/to/model.ckpt \
|
path/to/model.ckpt \
|
||||||
path/to/audio_dir \
|
path/to/audio_dir \
|
||||||
path/to/outputs_thr_03 \
|
path/to/outputs_thr_03 \
|
||||||
@ -26,7 +27,8 @@ batdetect2 predict directory \
|
|||||||
|
|
||||||
Keep each threshold run in a separate output directory.
|
Keep each threshold run in a separate output directory.
|
||||||
|
|
||||||
That makes it easier to compare counts and inspect example files without mixing results.
|
That makes it easier to compare counts and inspect example files without mixing
|
||||||
|
results.
|
||||||
|
|
||||||
## 3) Validate against known calls
|
## 3) Validate against known calls
|
||||||
|
|
||||||
@ -38,7 +40,8 @@ Check both:
|
|||||||
- obvious false positives,
|
- obvious false positives,
|
||||||
- obvious missed calls.
|
- obvious missed calls.
|
||||||
|
|
||||||
If class interpretation matters downstream, inspect class ranking behavior as well, not just detection counts.
|
If class interpretation matters downstream, inspect class ranking behavior as
|
||||||
|
well, not just detection counts.
|
||||||
|
|
||||||
## 4) Record your chosen setting
|
## 4) Record your chosen setting
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# How to tune inference clipping
|
# How to tune inference clipping
|
||||||
|
|
||||||
Use this guide when long recordings need to be split into smaller clips during inference.
|
Use this guide when long recordings need to be split into smaller clips during
|
||||||
|
inference.
|
||||||
|
|
||||||
## What clipping controls
|
## What clipping controls
|
||||||
|
|
||||||
@ -8,14 +9,19 @@ Use this guide when long recordings need to be split into smaller clips during i
|
|||||||
|
|
||||||
Key fields are:
|
Key fields are:
|
||||||
|
|
||||||
- `duration`: clip duration in seconds,
|
- `duration`:
|
||||||
- `overlap`: overlap between adjacent clips,
|
clip duration in seconds,
|
||||||
- `max_empty`: how much empty padding is allowed,
|
- `overlap`:
|
||||||
- `discard_empty`: whether empty clips are dropped.
|
overlap between adjacent clips,
|
||||||
|
- `max_empty`:
|
||||||
|
how much empty padding is allowed,
|
||||||
|
- `discard_empty`:
|
||||||
|
whether empty clips are dropped.
|
||||||
|
|
||||||
## Start from the defaults
|
## Start from the defaults
|
||||||
|
|
||||||
Use the built-in clipping behavior first unless you already know you need something else.
|
Use the built-in clipping behavior first unless you already know you need
|
||||||
|
something else.
|
||||||
|
|
||||||
Only tune clipping when:
|
Only tune clipping when:
|
||||||
|
|
||||||
@ -25,7 +31,7 @@ Only tune clipping when:
|
|||||||
|
|
||||||
## Override clipping with an inference config
|
## Override clipping with an inference config
|
||||||
|
|
||||||
Create an inference config file and pass it to `predict` or `evaluate`.
|
Create an inference config file and pass it to `process` or `evaluate`.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@ -43,7 +49,7 @@ loader:
|
|||||||
Run with:
|
Run with:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict directory \
|
batdetect2 process directory \
|
||||||
path/to/model.ckpt \
|
path/to/model.ckpt \
|
||||||
path/to/audio_dir \
|
path/to/audio_dir \
|
||||||
path/to/outputs \
|
path/to/outputs \
|
||||||
@ -52,12 +58,16 @@ batdetect2 predict directory \
|
|||||||
|
|
||||||
## Validate clipping changes on a small reviewed subset
|
## Validate clipping changes on a small reviewed subset
|
||||||
|
|
||||||
Changing clipping changes what the model sees per batch and can change how events near clip boundaries behave.
|
Changing clipping changes what the model sees per batch and can change how
|
||||||
|
events near clip boundaries behave.
|
||||||
|
|
||||||
Check a reviewed subset before applying clipping changes to a full project.
|
Check a reviewed subset before applying clipping changes to a full project.
|
||||||
|
|
||||||
## Related pages
|
## Related pages
|
||||||
|
|
||||||
- Inference config reference: {doc}`../reference/inference-config`
|
- Inference config reference:
|
||||||
- Run batch predictions: {doc}`run-batch-predictions`
|
{doc}`../reference/inference-config`
|
||||||
- Understanding the pipeline: {doc}`../explanation/pipeline-overview`
|
- Run batch processing:
|
||||||
|
{doc}`run-batch-predictions`
|
||||||
|
- Understanding the pipeline:
|
||||||
|
{doc}`../explanation/pipeline-overview`
|
||||||
|
|||||||
@ -6,25 +6,20 @@ Welcome to the BatDetect2 documentation.
|
|||||||
|
|
||||||
`batdetect2` detects bat echolocation calls in audio recordings.
|
`batdetect2` detects bat echolocation calls in audio recordings.
|
||||||
|
|
||||||
It can help you screen large collections of recordings,
|
It can help you screen large collections of recordings, find files that need
|
||||||
find files that need expert review,
|
expert review, and support ecology and conservation work where manual review
|
||||||
and support ecology and conservation work where manual review alone would be slow.
|
alone would be slow.
|
||||||
|
|
||||||
In practice,
|
In practice, BatDetect2 takes recordings, looks for likely bat calls, draws a
|
||||||
BatDetect2 takes recordings,
|
box around each detected event, and scores the most likely class for that event.
|
||||||
looks for likely bat calls,
|
|
||||||
draws a box around each detected event,
|
|
||||||
and scores the most likely class for that event.
|
|
||||||
|
|
||||||
The current default model is trained for 17 UK species.
|
The current default model is trained for 17 UK species.
|
||||||
|
|
||||||
The library also supports custom training,
|
The library also supports custom training, fine-tuning, evaluation, and more
|
||||||
fine-tuning,
|
advanced use from Python.
|
||||||
evaluation,
|
|
||||||
and more advanced use from Python.
|
|
||||||
|
|
||||||
For details on the underlying approach, see the pre-print:
|
For details on the underlying approach, see the pre-print:
|
||||||
[Towards a General Approach for Bat Echolocation Detection and Classification](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)
|
[Towards a General Approach for Bat Echolocation Detection and Classification](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)
|
||||||
|
|
||||||
## A good first use for BatDetect2
|
## A good first use for BatDetect2
|
||||||
|
|
||||||
@ -56,7 +51,7 @@ Always validate on reviewed local data before using results for ecological infer
|
|||||||
```{note}
|
```{note}
|
||||||
Looking for the previous BatDetect2 workflow?
|
Looking for the previous BatDetect2 workflow?
|
||||||
See {doc}`legacy/index`.
|
See {doc}`legacy/index`.
|
||||||
The legacy docs are still available, but new workflows should use `batdetect2 predict` and `BatDetect2API`.
|
The legacy docs are still available, but new workflows should use `batdetect2 process` and `BatDetect2API`.
|
||||||
```
|
```
|
||||||
|
|
||||||
## How to use this site
|
## How to use this site
|
||||||
@ -65,8 +60,7 @@ Start with {doc}`getting_started` if you are new.
|
|||||||
|
|
||||||
Then choose the section that matches what you need.
|
Then choose the section that matches what you need.
|
||||||
|
|
||||||
If you are here mainly to run the model on recordings,
|
If you are here mainly to run the model on recordings, start with Tutorials.
|
||||||
start with Tutorials.
|
|
||||||
|
|
||||||
| Section | Best for | Start here |
|
| Section | Best for | Start here |
|
||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
@ -81,7 +75,7 @@ start with Tutorials.
|
|||||||
- GitHub repository:
|
- GitHub repository:
|
||||||
[macaodha/batdetect2](https://github.com/macaodha/batdetect2)
|
[macaodha/batdetect2](https://github.com/macaodha/batdetect2)
|
||||||
- Questions, bug reports, and feature requests:
|
- Questions, bug reports, and feature requests:
|
||||||
[GitHub Issues](https://github.com/macaodha/batdetect2/issues)
|
[GitHub Issues](https://github.com/macaodha/batdetect2/issues)
|
||||||
- Common questions:
|
- Common questions:
|
||||||
{doc}`faq`
|
{doc}`faq`
|
||||||
- Want to contribute?
|
- Want to contribute?
|
||||||
|
|||||||
@ -4,7 +4,7 @@ This page documents the previous CLI workflow based on `batdetect2 detect`.
|
|||||||
|
|
||||||
```{warning}
|
```{warning}
|
||||||
This is legacy documentation.
|
This is legacy documentation.
|
||||||
For new workflows, use `batdetect2 predict directory` instead.
|
For new workflows, use `batdetect2 process directory` instead.
|
||||||
If you are migrating, start with {doc}`migration-guide`.
|
If you are migrating, start with {doc}`migration-guide`.
|
||||||
```
|
```
|
||||||
|
|
||||||
@ -27,7 +27,7 @@ Common legacy options included:
|
|||||||
The closest current CLI entry point is:
|
The closest current CLI entry point is:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict directory \
|
batdetect2 process directory \
|
||||||
path/to/model.ckpt \
|
path/to/model.ckpt \
|
||||||
path/to/audio_dir \
|
path/to/audio_dir \
|
||||||
path/to/outputs
|
path/to/outputs
|
||||||
@ -35,5 +35,7 @@ batdetect2 predict directory \
|
|||||||
|
|
||||||
## Related pages
|
## Related pages
|
||||||
|
|
||||||
- Migration guide: {doc}`migration-guide`
|
- Migration guide:
|
||||||
- Current predict docs: {doc}`../reference/cli/predict`
|
{doc}`migration-guide`
|
||||||
|
- Current process docs:
|
||||||
|
{doc}`../reference/cli/predict`
|
||||||
|
|||||||
@ -2,12 +2,15 @@
|
|||||||
|
|
||||||
This section documents the previous BatDetect2 workflow.
|
This section documents the previous BatDetect2 workflow.
|
||||||
|
|
||||||
Use these pages if you need to keep working with the older `batdetect2 detect` command or the older `batdetect2.api` interface.
|
Use these pages if you need to keep working with the older `batdetect2 detect`
|
||||||
|
command or the older `batdetect2.api` interface.
|
||||||
|
|
||||||
For new projects, we recommend the current workflow:
|
For new projects, we recommend the current workflow:
|
||||||
|
|
||||||
- CLI: `batdetect2 predict`
|
- CLI:
|
||||||
- Python: `batdetect2.api_v2.BatDetect2API`
|
`batdetect2 process`
|
||||||
|
- Python:
|
||||||
|
`batdetect2.api_v2.BatDetect2API`
|
||||||
|
|
||||||
If you are moving from the older workflow, start with {doc}`migration-guide`.
|
If you are moving from the older workflow, start with {doc}`migration-guide`.
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# Migration guide: legacy to current workflows
|
# Migration guide: legacy to current workflows
|
||||||
|
|
||||||
Use this guide when moving from the previous BatDetect2 workflow to the current CLI and API.
|
Use this guide when moving from the previous BatDetect2 workflow to the current
|
||||||
|
CLI and API.
|
||||||
|
|
||||||
## Who should migrate now
|
## Who should migrate now
|
||||||
|
|
||||||
@ -9,31 +10,37 @@ You should migrate if:
|
|||||||
- you are starting a new workflow,
|
- you are starting a new workflow,
|
||||||
- you want the current docs path,
|
- you want the current docs path,
|
||||||
- you want the newer CLI and API surface,
|
- you want the newer CLI and API surface,
|
||||||
- you are maintaining code that does not depend on the exact legacy JSON or feature outputs.
|
- you are maintaining code that does not depend on the exact legacy JSON or
|
||||||
|
feature outputs.
|
||||||
|
|
||||||
You may need the legacy workflow a bit longer if:
|
You may need the legacy workflow a bit longer if:
|
||||||
|
|
||||||
- downstream tooling depends on the exact old output structure,
|
- downstream tooling depends on the exact old output structure,
|
||||||
- you rely on older notebooks built around `batdetect2.api`,
|
- you rely on older notebooks built around `batdetect2.api`,
|
||||||
- you depend on legacy feature extraction outputs without a validated replacement yet.
|
- you depend on legacy feature extraction outputs without a validated
|
||||||
|
replacement yet.
|
||||||
|
|
||||||
## CLI mapping
|
## CLI mapping
|
||||||
|
|
||||||
- `batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD`
|
- `batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD` -> `batdetect2
|
||||||
-> `batdetect2 predict directory MODEL_PATH AUDIO_DIR OUTPUT_PATH --detection-threshold ...`
|
process directory MODEL_PATH AUDIO_DIR OUTPUT_PATH --detection-threshold ...`
|
||||||
|
|
||||||
Main changes:
|
Main changes:
|
||||||
|
|
||||||
- the model path is now a positional argument on the `predict` subcommand,
|
- the model path is now a positional argument on the `process` subcommand,
|
||||||
- the current workflow expects an explicit checkpoint path rather than silently relying on the old default CLI behavior,
|
- the current workflow expects an explicit checkpoint path rather than silently
|
||||||
|
relying on the old default CLI behavior,
|
||||||
- output formatting is configurable,
|
- output formatting is configurable,
|
||||||
- threshold override is an option rather than a required positional argument,
|
- threshold override is an option rather than a required positional argument,
|
||||||
- there are separate subcommands for directory, file-list, and dataset-driven inference.
|
- there are separate subcommands for directory, file-list, and dataset-driven
|
||||||
|
inference.
|
||||||
|
|
||||||
## Python API mapping
|
## Python API mapping
|
||||||
|
|
||||||
- old: `import batdetect2.api as api`
|
- old:
|
||||||
- current: `from batdetect2.api_v2 import BatDetect2API`
|
`import batdetect2.api as api`
|
||||||
|
- current:
|
||||||
|
`from batdetect2.api_v2 import BatDetect2API`
|
||||||
|
|
||||||
Typical migration shape:
|
Typical migration shape:
|
||||||
|
|
||||||
@ -51,7 +58,7 @@ Useful replacements:
|
|||||||
- legacy `process_file` -> current `BatDetect2API.process_file`
|
- legacy `process_file` -> current `BatDetect2API.process_file`
|
||||||
- legacy `process_audio` -> current `BatDetect2API.process_audio`
|
- legacy `process_audio` -> current `BatDetect2API.process_audio`
|
||||||
- legacy `process_spectrogram` -> current `BatDetect2API.process_spectrogram`
|
- legacy `process_spectrogram` -> current `BatDetect2API.process_spectrogram`
|
||||||
- legacy one-off batch loops -> current `process_files` or CLI `predict`
|
- legacy one-off batch loops -> current `process_files` or CLI `process`
|
||||||
|
|
||||||
## Output and terminology changes
|
## Output and terminology changes
|
||||||
|
|
||||||
@ -78,7 +85,8 @@ Before replacing a legacy workflow in production or research analysis, validate:
|
|||||||
- that outputs are being saved in the right format,
|
- that outputs are being saved in the right format,
|
||||||
- that downstream code reads the new outputs correctly,
|
- that downstream code reads the new outputs correctly,
|
||||||
- that feature-related assumptions still hold,
|
- that feature-related assumptions still hold,
|
||||||
- that evaluation and ecological interpretation are unchanged only where you have actually verified that.
|
- that evaluation and ecological interpretation are unchanged only where you
|
||||||
|
have actually verified that.
|
||||||
|
|
||||||
## Migration checklist
|
## Migration checklist
|
||||||
|
|
||||||
@ -91,6 +99,9 @@ Before replacing a legacy workflow in production or research analysis, validate:
|
|||||||
|
|
||||||
## Related pages
|
## Related pages
|
||||||
|
|
||||||
- Current getting started: {doc}`../getting_started`
|
- Current getting started:
|
||||||
- Current tutorials: {doc}`../tutorials/index`
|
{doc}`../getting_started`
|
||||||
- Current API reference: {doc}`../reference/api`
|
- Current tutorials:
|
||||||
|
{doc}`../tutorials/index`
|
||||||
|
- Current API reference:
|
||||||
|
{doc}`../reference/api`
|
||||||
|
|||||||
@ -1,65 +1,33 @@
|
|||||||
# `BatDetect2API` reference
|
# `BatDetect2API` reference
|
||||||
|
|
||||||
`BatDetect2API` is the main entry point for the current Python workflow.
|
`BatDetect2API` is the main Python entry point for BatDetect2.
|
||||||
|
|
||||||
It wraps model loading, inference, evaluation, output formatting, and
|
Use it when you want to load a model, run prediction, inspect detections,
|
||||||
training-related entry points behind one object.
|
evaluate results, or train from Python.
|
||||||
|
|
||||||
Defined in `batdetect2.api_v2`.
|
Defined in `batdetect2.api_v2`.
|
||||||
|
|
||||||
## Create an API instance
|
## Main ways to create it
|
||||||
|
|
||||||
- `BatDetect2API.from_checkpoint(path, ...)`
|
- `BatDetect2API.from_checkpoint(path, ...)`
|
||||||
- load a trained checkpoint and optional config overrides.
|
- load a trained checkpoint, a bundled checkpoint alias, or a Hugging Face
|
||||||
|
checkpoint.
|
||||||
- `BatDetect2API.from_config(model_config=..., targets_config=..., ...)`
|
- `BatDetect2API.from_config(model_config=..., targets_config=..., ...)`
|
||||||
- build a full stack from separate config objects.
|
- build a full model stack from config objects.
|
||||||
|
|
||||||
## Inference methods
|
## Common tasks
|
||||||
|
|
||||||
- `process_file(audio_file, ...)`
|
- Load a checkpoint and run prediction on one file.
|
||||||
- run inference for one recording.
|
- Run prediction on many files or clips.
|
||||||
- `process_files(audio_files, ...)`
|
- Save predictions in one of the supported output formats.
|
||||||
- run batch inference across a sequence of file paths.
|
- Evaluate a model on labelled data.
|
||||||
- `process_directory(audio_dir, ...)`
|
- Fine-tune an existing checkpoint on new targets.
|
||||||
- run inference across the audio files found in one directory.
|
|
||||||
- `process_clips(clips, ...)`
|
|
||||||
- run inference on an explicit sequence of clip objects.
|
|
||||||
- `process_audio(audio, ...)`
|
|
||||||
- run inference starting from a waveform array.
|
|
||||||
- `process_spectrogram(spec, ...)`
|
|
||||||
- run inference starting from a spectrogram tensor.
|
|
||||||
|
|
||||||
## Prediction inspection helpers
|
## Generated reference
|
||||||
|
|
||||||
- `get_top_class_name(detection)`
|
```{eval-rst}
|
||||||
- return the highest-scoring class name for one detection.
|
.. autoclass:: batdetect2.api_v2.BatDetect2API
|
||||||
- `get_class_scores(detection, include_top_class=True, sort_descending=True)`
|
```
|
||||||
- return ranked `(class_name, score)` pairs.
|
|
||||||
- `get_detection_features(detection)`
|
|
||||||
- return the per-detection feature vector.
|
|
||||||
|
|
||||||
## Audio loading helpers
|
|
||||||
|
|
||||||
- `load_audio(path)`
|
|
||||||
- `load_recording(recording)`
|
|
||||||
- `load_clip(clip)`
|
|
||||||
- `generate_spectrogram(audio)`
|
|
||||||
|
|
||||||
## Output persistence helpers
|
|
||||||
|
|
||||||
- `save_predictions(predictions, path, audio_dir=None, format=None,
|
|
||||||
config=None)`
|
|
||||||
- `load_predictions(path, format=None, config=None)`
|
|
||||||
|
|
||||||
Use these when you want to save programmatic predictions without going through
|
|
||||||
the CLI.
|
|
||||||
|
|
||||||
## Training and evaluation entry points
|
|
||||||
|
|
||||||
- `train(...)`
|
|
||||||
- `finetune(...)`
|
|
||||||
- `evaluate(...)`
|
|
||||||
- `evaluate_predictions(...)`
|
|
||||||
|
|
||||||
## Related pages
|
## Related pages
|
||||||
|
|
||||||
|
|||||||
@ -4,13 +4,13 @@ Legacy detect command
|
|||||||
.. warning::
|
.. warning::
|
||||||
|
|
||||||
``batdetect2 detect`` is a legacy compatibility command.
|
``batdetect2 detect`` is a legacy compatibility command.
|
||||||
Prefer ``batdetect2 predict directory`` for new workflows.
|
Prefer ``batdetect2 process directory`` for new workflows.
|
||||||
|
|
||||||
Migration at a glance
|
Migration at a glance
|
||||||
---------------------
|
---------------------
|
||||||
|
|
||||||
- Legacy: ``batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD``
|
- Legacy: ``batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD``
|
||||||
- Current: ``batdetect2 predict directory MODEL_PATH AUDIO_DIR OUTPUT_PATH``
|
- Current: ``batdetect2 process directory MODEL_PATH AUDIO_DIR OUTPUT_PATH``
|
||||||
with optional ``--detection-threshold``
|
with optional ``--detection-threshold``
|
||||||
|
|
||||||
.. click:: batdetect2.cli.compat:detect
|
.. click:: batdetect2.cli.compat:detect
|
||||||
|
|||||||
@ -1,7 +1,10 @@
|
|||||||
Evaluate command
|
Evaluate command
|
||||||
================
|
================
|
||||||
|
|
||||||
Evaluate a checkpoint against a configured test dataset.
|
Use ``batdetect2 evaluate`` to compare a checkpoint against labelled test data.
|
||||||
|
|
||||||
|
This command writes metrics and any configured artifacts to the output
|
||||||
|
directory.
|
||||||
|
|
||||||
.. click:: batdetect2.cli.evaluate:evaluate_command
|
.. click:: batdetect2.cli.evaluate:evaluate_command
|
||||||
:prog: batdetect2 evaluate
|
:prog: batdetect2 evaluate
|
||||||
|
|||||||
11
docs/source/reference/cli/finetune.rst
Normal file
11
docs/source/reference/cli/finetune.rst
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
Finetune command
|
||||||
|
================
|
||||||
|
|
||||||
|
Use ``batdetect2 finetune`` to adapt an existing checkpoint to a new target
|
||||||
|
definition.
|
||||||
|
|
||||||
|
If you do not pass ``--model``, the bundled ``uk_same`` checkpoint is used.
|
||||||
|
|
||||||
|
.. click:: batdetect2.cli.finetune:finetune_command
|
||||||
|
:prog: batdetect2 finetune
|
||||||
|
:nested: none
|
||||||
@ -1,35 +1,33 @@
|
|||||||
# CLI reference
|
# CLI reference
|
||||||
|
|
||||||
Use this section to find the right command quickly, then open the command page
|
Use this section to find the right command quickly, then open the command page
|
||||||
for full options and argument details.
|
for the full option list.
|
||||||
|
|
||||||
## How to use this section
|
|
||||||
|
|
||||||
1. Start with {doc}`base` for options shared across the CLI.
|
|
||||||
2. Pick the command group or command you need from the command map below.
|
|
||||||
3. Open the linked page for complete autogenerated option reference.
|
|
||||||
|
|
||||||
## Command map
|
## Command map
|
||||||
|
|
||||||
| Command | Use it for | Required positional args |
|
| Command | Use it for | Required positional args |
|
||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
| `batdetect2 predict` | Run inference on audio | Depends on subcommand (`directory`, `file_list`, `dataset`) |
|
| `batdetect2 process` | Run inference on audio | Depends on subcommand (`directory`, `file_list`, `dataset`) |
|
||||||
| `batdetect2 data` | Inspect and convert dataset configs | Depends on subcommand (`summary`, `convert`) |
|
| `batdetect2 data` | Inspect and convert dataset configs | Depends on subcommand (`summary`, `convert`) |
|
||||||
| `batdetect2 train` | Train or fine-tune models | `TRAIN_DATASET` |
|
| `batdetect2 train` | Train a new model or continue from a checkpoint | `TRAIN_DATASET` |
|
||||||
| `batdetect2 evaluate` | Evaluate a checkpoint on a test dataset | `MODEL_PATH`, `TEST_DATASET` |
|
| `batdetect2 finetune` | Fine-tune a checkpoint on new targets | `TRAIN_DATASET` plus `--targets` |
|
||||||
|
| `batdetect2 evaluate` | Evaluate a checkpoint on a test dataset | `TEST_DATASET` |
|
||||||
| `batdetect2 detect` | Legacy compatibility workflow | `AUDIO_DIR`, `ANN_DIR`, `DETECTION_THRESHOLD` |
|
| `batdetect2 detect` | Legacy compatibility workflow | `AUDIO_DIR`, `ANN_DIR`, `DETECTION_THRESHOLD` |
|
||||||
|
|
||||||
## Global options and conventions
|
## Notes
|
||||||
|
|
||||||
- Global CLI options are documented in {doc}`base`.
|
- Global CLI options are documented in {doc}`base`.
|
||||||
- Paths with spaces should be wrapped in quotes.
|
- Paths with spaces should be wrapped in quotes.
|
||||||
- Input audio is expected to be mono.
|
- Input audio is expected to be mono.
|
||||||
- Legacy `detect` uses a required threshold argument, while `predict` uses the
|
- `process` uses the optional `--detection-threshold` override.
|
||||||
optional `--detection-threshold` override.
|
- `evaluate` takes `TEST_DATASET` as a positional argument and uses `--model`
|
||||||
|
for the checkpoint override.
|
||||||
|
- `finetune` defaults to the bundled `uk_same` checkpoint if `--model` is not
|
||||||
|
provided.
|
||||||
|
|
||||||
```{warning}
|
```{warning}
|
||||||
`batdetect2 detect` is a legacy command.
|
`batdetect2 detect` is a legacy command.
|
||||||
Prefer `batdetect2 predict directory` for new workflows.
|
Prefer `batdetect2 process directory` for new workflows.
|
||||||
```
|
```
|
||||||
|
|
||||||
## Related pages
|
## Related pages
|
||||||
@ -43,9 +41,10 @@ Prefer `batdetect2 predict directory` for new workflows.
|
|||||||
:maxdepth: 1
|
:maxdepth: 1
|
||||||
|
|
||||||
Base command and global options <base>
|
Base command and global options <base>
|
||||||
Predict command group <predict>
|
Process command group <predict>
|
||||||
Data command group <data>
|
Data command group <data>
|
||||||
Train command <train>
|
Train command <train>
|
||||||
|
Finetune command <finetune>
|
||||||
Evaluate command <evaluate>
|
Evaluate command <evaluate>
|
||||||
Legacy detect command <detect_legacy>
|
Legacy detect command <detect_legacy>
|
||||||
```
|
```
|
||||||
|
|||||||
@ -1,9 +1,17 @@
|
|||||||
Predict command
|
Process command
|
||||||
===============
|
===============
|
||||||
|
|
||||||
Run model inference from a directory, a file list, or a dataset.
|
Use ``batdetect2 process`` to run inference on audio.
|
||||||
Use ``--detection-threshold`` to override the model default per run.
|
|
||||||
|
|
||||||
.. click:: batdetect2.cli.inference:predict
|
Choose a subcommand based on how you want to provide the input:
|
||||||
:prog: batdetect2 predict
|
|
||||||
|
- ``directory`` for all supported audio files in one folder
|
||||||
|
- ``file_list`` for a text file with one audio path per line
|
||||||
|
- ``dataset`` for recordings referenced by a dataset file
|
||||||
|
|
||||||
|
Use ``--detection-threshold`` when you want to override the configured
|
||||||
|
threshold for one run.
|
||||||
|
|
||||||
|
.. click:: batdetect2.cli.inference:process
|
||||||
|
:prog: batdetect2 process
|
||||||
:nested: full
|
:nested: full
|
||||||
|
|||||||
@ -1,7 +1,11 @@
|
|||||||
Train command
|
Train command
|
||||||
=============
|
=============
|
||||||
|
|
||||||
Train a model from dataset configs or fine-tune from a checkpoint.
|
Use ``batdetect2 train`` to start from a fresh model config or continue from an
|
||||||
|
existing checkpoint.
|
||||||
|
|
||||||
|
If you want to adapt an existing checkpoint to a new target definition, use
|
||||||
|
``batdetect2 finetune`` instead.
|
||||||
|
|
||||||
.. click:: batdetect2.cli.train:train_command
|
.. click:: batdetect2.cli.train:train_command
|
||||||
:prog: batdetect2 train
|
:prog: batdetect2 train
|
||||||
|
|||||||
@ -3,7 +3,8 @@
|
|||||||
This tutorial shows how to evaluate a trained checkpoint on a held-out dataset
|
This tutorial shows how to evaluate a trained checkpoint on a held-out dataset
|
||||||
and inspect the output metrics.
|
and inspect the output metrics.
|
||||||
|
|
||||||
This tutorial is for advanced users who want to compare one trained model against a separate test dataset.
|
This tutorial is for advanced users who want to compare one trained model
|
||||||
|
against a separate test dataset.
|
||||||
|
|
||||||
## Before you start
|
## Before you start
|
||||||
|
|
||||||
@ -32,22 +33,22 @@ Use a dataset that was not used for training or tuning.
|
|||||||
|
|
||||||
A held-out dataset is simply a separate dataset kept aside for evaluation.
|
A held-out dataset is simply a separate dataset kept aside for evaluation.
|
||||||
|
|
||||||
If you tune thresholds or configs on the same dataset that you report as final evaluation, the results will be optimistic.
|
If you tune thresholds or configs on the same dataset that you report as final
|
||||||
|
evaluation, the results will be optimistic.
|
||||||
|
|
||||||
## 2. Run evaluation
|
## 2. Run evaluation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 evaluate \
|
batdetect2 evaluate \
|
||||||
path/to/model.ckpt \
|
|
||||||
path/to/test_dataset.yaml \
|
path/to/test_dataset.yaml \
|
||||||
|
--model path/to/model.ckpt \
|
||||||
--base-dir path/to/project_root \
|
--base-dir path/to/project_root \
|
||||||
--output-dir path/to/eval_outputs
|
--output-dir path/to/eval_outputs
|
||||||
```
|
```
|
||||||
|
|
||||||
This command loads the checkpoint,
|
This command loads the checkpoint, runs prediction on the test dataset, applies
|
||||||
runs prediction on the test dataset,
|
the chosen evaluation tasks, and writes metrics and result files to the output
|
||||||
applies the chosen evaluation tasks,
|
directory.
|
||||||
and writes metrics and result files to the output directory.
|
|
||||||
|
|
||||||
Use `--base-dir` whenever the dataset config contains relative paths.
|
Use `--base-dir` whenever the dataset config contains relative paths.
|
||||||
|
|
||||||
@ -73,7 +74,8 @@ Check:
|
|||||||
- which task the metric belongs to,
|
- which task the metric belongs to,
|
||||||
- which thresholding or matching assumptions were used,
|
- which thresholding or matching assumptions were used,
|
||||||
- whether class-level behavior matches your use case,
|
- whether class-level behavior matches your use case,
|
||||||
- whether the failures are concentrated in specific taxa, sites, or recording conditions.
|
- whether the failures are concentrated in specific taxa, sites, or recording
|
||||||
|
conditions.
|
||||||
|
|
||||||
## 5. Record the evaluation setup
|
## 5. Record the evaluation setup
|
||||||
|
|
||||||
@ -85,7 +87,11 @@ That matters for reproducibility and for later model comparisons.
|
|||||||
|
|
||||||
- Compare thresholds on representative files:
|
- Compare thresholds on representative files:
|
||||||
{doc}`../how_to/tune-detection-threshold`
|
{doc}`../how_to/tune-detection-threshold`
|
||||||
- Configure evaluation tasks: {doc}`../how_to/choose-and-configure-evaluation-tasks`
|
- Configure evaluation tasks:
|
||||||
- Interpret evaluation artifacts: {doc}`../how_to/interpret-evaluation-outputs`
|
{doc}`../how_to/choose-and-configure-evaluation-tasks`
|
||||||
- Learn the evaluation concepts: {doc}`../explanation/evaluation-concepts-and-matching`
|
- Interpret evaluation artifacts:
|
||||||
- Check full evaluate options: {doc}`../reference/cli/evaluate`
|
{doc}`../how_to/interpret-evaluation-outputs`
|
||||||
|
- Learn the evaluation concepts:
|
||||||
|
{doc}`../explanation/evaluation-concepts-and-matching`
|
||||||
|
- Check full evaluate options:
|
||||||
|
{doc}`../reference/cli/evaluate`
|
||||||
|
|||||||
@ -4,7 +4,8 @@ This tutorial walks through a first end-to-end inference run with the CLI.
|
|||||||
|
|
||||||
It is the default starting point for new users.
|
It is the default starting point for new users.
|
||||||
|
|
||||||
Use it when you want to run an existing model on a folder of recordings and quickly check what BatDetect2 found.
|
Use it when you want to run an existing model on a folder of recordings and
|
||||||
|
quickly check what BatDetect2 found.
|
||||||
|
|
||||||
## Before you start
|
## Before you start
|
||||||
|
|
||||||
@ -24,7 +25,7 @@ src/batdetect2/models/checkpoints/Net2DFast_UK_same.pth.tar
|
|||||||
|
|
||||||
By the end of this tutorial you will have:
|
By the end of this tutorial you will have:
|
||||||
|
|
||||||
- run `batdetect2 predict directory`,
|
- run `batdetect2 process directory`,
|
||||||
- saved predictions to disk,
|
- saved predictions to disk,
|
||||||
- checked that BatDetect2 wrote output files,
|
- checked that BatDetect2 wrote output files,
|
||||||
- identified the next pages to use for tuning or customization.
|
- identified the next pages to use for tuning or customization.
|
||||||
@ -48,12 +49,13 @@ project/
|
|||||||
outputs/
|
outputs/
|
||||||
```
|
```
|
||||||
|
|
||||||
## 2. Run prediction on the directory
|
## 2. Run processing on the directory
|
||||||
|
|
||||||
Use this command when you want BatDetect2 to scan a folder of recordings automatically.
|
Use this command when you want BatDetect2 to scan a folder of recordings
|
||||||
|
automatically.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
batdetect2 predict directory \
|
batdetect2 process directory \
|
||||||
path/to/model.pth.tar \
|
path/to/model.pth.tar \
|
||||||
path/to/audio_dir \
|
path/to/audio_dir \
|
||||||
path/to/outputs
|
path/to/outputs
|
||||||
@ -70,8 +72,7 @@ What this does:
|
|||||||
|
|
||||||
After the command completes, inspect the output directory.
|
After the command completes, inspect the output directory.
|
||||||
|
|
||||||
For a first run,
|
For a first run, the important check is simple:
|
||||||
the important check is simple:
|
|
||||||
|
|
||||||
- did BatDetect2 create result files,
|
- did BatDetect2 create result files,
|
||||||
- are they in the output directory you expected,
|
- are they in the output directory you expected,
|
||||||
@ -81,8 +82,8 @@ Different workflows can save results in different file formats.
|
|||||||
|
|
||||||
You do not need to learn those details for the first run.
|
You do not need to learn those details for the first run.
|
||||||
|
|
||||||
If you later need to choose a specific output format,
|
If you later need to choose a specific output format, go to
|
||||||
go to {doc}`../how_to/save-predictions-in-different-output-formats`.
|
{doc}`../how_to/save-predictions-in-different-output-formats`.
|
||||||
|
|
||||||
## 4. Inspect predictions
|
## 4. Inspect predictions
|
||||||
|
|
||||||
@ -103,13 +104,17 @@ Validation comes next.
|
|||||||
|
|
||||||
## 5. Tune only after you have a baseline
|
## 5. Tune only after you have a baseline
|
||||||
|
|
||||||
If the first run is too noisy or misses obvious calls, tune thresholds on a reviewed subset rather than changing settings blindly across the full dataset.
|
If the first run is too noisy or misses obvious calls, tune thresholds on a
|
||||||
|
reviewed subset rather than changing settings blindly across the full dataset.
|
||||||
|
|
||||||
Use {doc}`../how_to/tune-detection-threshold` for that process.
|
Use {doc}`../how_to/tune-detection-threshold` for that process.
|
||||||
|
|
||||||
## What to do next
|
## What to do next
|
||||||
|
|
||||||
- If you need a different input mode, use {doc}`../how_to/choose-an-inference-input-mode`.
|
- If you need a different input mode, use
|
||||||
- If you want to tune sensitivity, use {doc}`../how_to/tune-detection-threshold`.
|
{doc}`../how_to/choose-an-inference-input-mode`.
|
||||||
- If you already write code and want more control from Python, use {doc}`integrate-with-a-python-pipeline`.
|
- If you want to tune sensitivity, use
|
||||||
|
{doc}`../how_to/tune-detection-threshold`.
|
||||||
|
- If you already write code and want more control from Python, use
|
||||||
|
{doc}`integrate-with-a-python-pipeline`.
|
||||||
- If you need full command details, use {doc}`../reference/cli/predict`.
|
- If you need full command details, use {doc}`../reference/cli/predict`.
|
||||||
|
|||||||
26
justfile
26
justfile
@ -17,6 +17,10 @@ help:
|
|||||||
install:
|
install:
|
||||||
uv sync
|
uv sync
|
||||||
|
|
||||||
|
# Install full development dependencies for CI and docs builds.
|
||||||
|
install-dev:
|
||||||
|
uv sync --all-extras --dev
|
||||||
|
|
||||||
# Testing & Coverage
|
# Testing & Coverage
|
||||||
# Run tests using pytest.
|
# Run tests using pytest.
|
||||||
test:
|
test:
|
||||||
@ -50,6 +54,9 @@ coverage-serve: coverage-html
|
|||||||
docs:
|
docs:
|
||||||
uv run sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
|
uv run sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
|
||||||
|
|
||||||
|
# Check that documentation builds successfully.
|
||||||
|
check-docs: docs
|
||||||
|
|
||||||
# Serve documentation with live reload.
|
# Serve documentation with live reload.
|
||||||
docs-serve:
|
docs-serve:
|
||||||
uv run sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
|
uv run sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
|
||||||
@ -84,6 +91,25 @@ check-types:
|
|||||||
# Run all checks (format-check, lint, typecheck).
|
# Run all checks (format-check, lint, typecheck).
|
||||||
check: check-format check-lint check-types
|
check: check-format check-lint check-types
|
||||||
|
|
||||||
|
# Run the standard CI validation sequence.
|
||||||
|
ci: check test
|
||||||
|
|
||||||
|
# Build source and wheel distributions.
|
||||||
|
build-dist:
|
||||||
|
uv run --with build python -m build
|
||||||
|
|
||||||
|
# Bump the patch version, commit, and tag.
|
||||||
|
bump-patch:
|
||||||
|
uvx bump2version patch
|
||||||
|
|
||||||
|
# Bump the minor version, commit, and tag.
|
||||||
|
bump-minor:
|
||||||
|
uvx bump2version minor
|
||||||
|
|
||||||
|
# Bump the major version, commit, and tag.
|
||||||
|
bump-major:
|
||||||
|
uvx bump2version major
|
||||||
|
|
||||||
# Cleaning tasks
|
# Cleaning tasks
|
||||||
# Remove Python bytecode and cache.
|
# Remove Python bytecode and cache.
|
||||||
clean-pyc:
|
clean-pyc:
|
||||||
|
|||||||
@ -7,7 +7,6 @@ authors = [
|
|||||||
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
|
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cf-xarray>=0.9.0",
|
|
||||||
"click>=8.1.7",
|
"click>=8.1.7",
|
||||||
"deepmerge>=2.0",
|
"deepmerge>=2.0",
|
||||||
"hydra-core>=1.3.2",
|
"hydra-core>=1.3.2",
|
||||||
@ -16,21 +15,19 @@ dependencies = [
|
|||||||
"loguru>=0.7.3",
|
"loguru>=0.7.3",
|
||||||
"matplotlib>=3.7.1",
|
"matplotlib>=3.7.1",
|
||||||
"netcdf4>=1.6.5",
|
"netcdf4>=1.6.5",
|
||||||
"numba>=0.60",
|
|
||||||
"numpy>=1.23.5",
|
"numpy>=1.23.5",
|
||||||
"omegaconf>=2.3.0",
|
|
||||||
"onnx>=1.16.0",
|
|
||||||
"pandas>=1.5.3",
|
"pandas>=1.5.3",
|
||||||
|
"pydantic>=2.0.0",
|
||||||
"pyyaml>=6.0.2",
|
"pyyaml>=6.0.2",
|
||||||
"scikit-learn>=1.2.2",
|
"scikit-learn>=1.2.2",
|
||||||
"scipy>=1.10.1",
|
"scipy>=1.10.1",
|
||||||
"seaborn>=0.13.2",
|
"seaborn>=0.13.2",
|
||||||
"soundevent[audio,geometry,plot]>=2.10.0",
|
"soundevent[audio,geometry,plot]>=2.10.0",
|
||||||
|
"soundfile>=0.12.1",
|
||||||
"tensorboard>=2.16.2",
|
"tensorboard>=2.16.2",
|
||||||
"torch>=1.13.1",
|
"torch>=1.13.1",
|
||||||
"torchaudio>=1.13.1",
|
"torchaudio>=1.13.1",
|
||||||
"torchvision>=0.14.0",
|
"xarray>=2024.0.0",
|
||||||
"tqdm>=4.66.2",
|
|
||||||
]
|
]
|
||||||
requires-python = ">=3.10,<3.14"
|
requires-python = ">=3.10,<3.14"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
@ -66,6 +63,7 @@ build-backend = "hatchling.build"
|
|||||||
batdetect2 = "batdetect2.cli:cli"
|
batdetect2 = "batdetect2.cli:cli"
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
|
huggingface = ["huggingface-hub>=0.32.0"]
|
||||||
jupyter = ["ipywidgets>=8.1.5", "jupyter>=1.1.1"]
|
jupyter = ["ipywidgets>=8.1.5", "jupyter>=1.1.1"]
|
||||||
marimo = ["marimo>=0.12.2", "pyarrow>=20.0.0"]
|
marimo = ["marimo>=0.12.2", "pyarrow>=20.0.0"]
|
||||||
dev = [
|
dev = [
|
||||||
|
|||||||
@ -1,11 +1,25 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
|
|
||||||
logger.disable("batdetect2")
|
logger.disable("batdetect2")
|
||||||
|
|
||||||
|
|
||||||
numba_logger = logging.getLogger("numba")
|
numba_logger = logging.getLogger("numba")
|
||||||
numba_logger.setLevel(logging.WARNING)
|
numba_logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
__all__ = ["BatDetect2API", "__version__"]
|
||||||
__version__ = "1.1.1"
|
__version__ = "1.1.1"
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(name: str):
|
||||||
|
if name == "BatDetect2API":
|
||||||
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
|
|
||||||
|
return BatDetect2API
|
||||||
|
|
||||||
|
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||||
|
|||||||
@ -3,13 +3,12 @@ from __future__ import annotations
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Literal
|
from typing import TYPE_CHECKING, Literal
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, AudioLoader
|
from batdetect2.audio import AudioConfig, AudioLoader
|
||||||
from batdetect2.data import Dataset
|
from batdetect2.data import Dataset
|
||||||
@ -20,7 +19,8 @@ if TYPE_CHECKING:
|
|||||||
LoggerConfig,
|
LoggerConfig,
|
||||||
LoggingCallback,
|
LoggingCallback,
|
||||||
)
|
)
|
||||||
from batdetect2.models import Model, ModelConfig
|
from batdetect2.models import ModelConfig
|
||||||
|
from batdetect2.models.types import ModelProtocol
|
||||||
from batdetect2.outputs import (
|
from batdetect2.outputs import (
|
||||||
OutputFormatConfig,
|
OutputFormatConfig,
|
||||||
OutputFormatterProtocol,
|
OutputFormatterProtocol,
|
||||||
@ -48,6 +48,31 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
|||||||
|
|
||||||
|
|
||||||
class BatDetect2API:
|
class BatDetect2API:
|
||||||
|
"""High-level interface for the BatDetect2 workflow.
|
||||||
|
|
||||||
|
Use this to load a model, run inference, inspect detections,
|
||||||
|
evaluate predictions, and train or fine-tune models.
|
||||||
|
|
||||||
|
In most cases, start with :meth:`from_checkpoint` to load a trained model.
|
||||||
|
Use :meth:`from_config` when you want to build a new model with custom
|
||||||
|
configs.
|
||||||
|
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
Load the default checkpoint and run prediction on one file.
|
||||||
|
|
||||||
|
>>> from batdetect2.api_v2 import BatDetect2API
|
||||||
|
>>> api = BatDetect2API.from_checkpoint()
|
||||||
|
>>> prediction = api.process_file("recording.wav")
|
||||||
|
|
||||||
|
Load a checkpoint and save predictions for a folder of audio.
|
||||||
|
|
||||||
|
>>> from pathlib import Path
|
||||||
|
>>> api = BatDetect2API.from_checkpoint("uk_same")
|
||||||
|
>>> predictions = api.process_directory("audio")
|
||||||
|
>>> api.save_predictions(predictions, "outputs/")
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
@ -65,8 +90,49 @@ class BatDetect2API:
|
|||||||
evaluator: EvaluatorProtocol,
|
evaluator: EvaluatorProtocol,
|
||||||
formatter: OutputFormatterProtocol,
|
formatter: OutputFormatterProtocol,
|
||||||
output_transform: OutputTransformProtocol,
|
output_transform: OutputTransformProtocol,
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
):
|
):
|
||||||
|
"""Create a fully configured API instance.
|
||||||
|
|
||||||
|
This initializer is mainly for internal use.
|
||||||
|
In most cases, users should create the API with
|
||||||
|
:meth:`from_checkpoint` or :meth:`from_config`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_config : ModelConfig
|
||||||
|
Model configuration.
|
||||||
|
audio_config : AudioConfig
|
||||||
|
Audio loading configuration.
|
||||||
|
train_config : TrainingConfig
|
||||||
|
Training configuration.
|
||||||
|
evaluation_config : EvaluationConfig
|
||||||
|
Evaluation configuration.
|
||||||
|
inference_config : InferenceConfig
|
||||||
|
Inference configuration.
|
||||||
|
outputs_config : OutputsConfig
|
||||||
|
Output formatting configuration.
|
||||||
|
logging_config : AppLoggingConfig
|
||||||
|
Logging configuration.
|
||||||
|
targets : TargetProtocol
|
||||||
|
Target definition used by the model.
|
||||||
|
roi_mapper : ROIMapperProtocol
|
||||||
|
ROI mapping used for size targets.
|
||||||
|
audio_loader : AudioLoader
|
||||||
|
Audio loader.
|
||||||
|
preprocessor : PreprocessorProtocol
|
||||||
|
Preprocessor used before the detector.
|
||||||
|
postprocessor : PostprocessorProtocol
|
||||||
|
Postprocessor used after the detector.
|
||||||
|
evaluator : EvaluatorProtocol
|
||||||
|
Evaluator used for metrics.
|
||||||
|
formatter : OutputFormatterProtocol
|
||||||
|
Default formatter used to save predictions.
|
||||||
|
output_transform : OutputTransformProtocol
|
||||||
|
Transform that converts model outputs into detections.
|
||||||
|
model : ModelProtocol
|
||||||
|
Model instance.
|
||||||
|
"""
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
self.audio_config = audio_config
|
self.audio_config = audio_config
|
||||||
self.train_config = train_config
|
self.train_config = train_config
|
||||||
@ -91,6 +157,21 @@ class BatDetect2API:
|
|||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
base_dir: data.PathLike | None = None,
|
base_dir: data.PathLike | None = None,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
|
"""Load a set of annotations from a dataset config file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the dataset config file.
|
||||||
|
base_dir : data.PathLike | None, optional
|
||||||
|
Base directory used to resolve relative paths in the dataset
|
||||||
|
config.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Dataset
|
||||||
|
Loaded dataset of annotations.
|
||||||
|
"""
|
||||||
from batdetect2.data import load_dataset_from_config
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
|
||||||
return load_dataset_from_config(path, base_dir=base_dir)
|
return load_dataset_from_config(path, base_dir=base_dir)
|
||||||
@ -107,12 +188,50 @@ class BatDetect2API:
|
|||||||
num_epochs: int | None = None,
|
num_epochs: int | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
seed: int | None = None,
|
seed: int | None = None,
|
||||||
model_config: ModelConfig | None = None,
|
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
||||||
):
|
):
|
||||||
|
"""Train the current model on a set of annotations.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
train_annotations : Sequence[data.ClipAnnotation]
|
||||||
|
Training annotations.
|
||||||
|
val_annotations : Sequence[data.ClipAnnotation] | None, optional
|
||||||
|
Validation annotations. If omitted, training runs without a
|
||||||
|
validation set.
|
||||||
|
train_workers : int, optional
|
||||||
|
Number of worker processes for training data loading.
|
||||||
|
val_workers : int, optional
|
||||||
|
Number of worker processes for validation data loading.
|
||||||
|
checkpoint_dir : Path | None, optional
|
||||||
|
Directory where checkpoints are saved.
|
||||||
|
log_dir : Path | None, optional
|
||||||
|
Directory where logs are written.
|
||||||
|
experiment_name : str | None, optional
|
||||||
|
Experiment name used by the configured logger.
|
||||||
|
num_epochs : int | None, optional
|
||||||
|
Maximum number of training epochs.
|
||||||
|
run_name : str | None, optional
|
||||||
|
Run name used by the configured logger.
|
||||||
|
seed : int | None, optional
|
||||||
|
Random seed for reproducibility.
|
||||||
|
audio_config : AudioConfig | None, optional
|
||||||
|
Audio config override.
|
||||||
|
train_config : TrainingConfig | None, optional
|
||||||
|
Training config override.
|
||||||
|
logger_config : LoggerConfig | None, optional
|
||||||
|
Training logger config override.
|
||||||
|
logging_callbacks : Sequence[LoggingCallback[TrainLoggingContext]], optional
|
||||||
|
Extra logging callbacks to run during training setup.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
BatDetect2API
|
||||||
|
This API instance with the trained model.
|
||||||
|
"""
|
||||||
from batdetect2.train import run_train
|
from batdetect2.train import run_train
|
||||||
|
|
||||||
self.model.train()
|
self.model.train()
|
||||||
@ -122,7 +241,6 @@ class BatDetect2API:
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
roi_mapper=self.roi_mapper,
|
roi_mapper=self.roi_mapper,
|
||||||
model_config=model_config or self.model_config,
|
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
@ -147,7 +265,7 @@ class BatDetect2API:
|
|||||||
targets_config: TargetConfig,
|
targets_config: TargetConfig,
|
||||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||||
trainable: Literal[
|
trainable: Literal[
|
||||||
"all", "heads", "classifier_head", "bbox_head"
|
"all", "heads", "classifier_head", "size_head"
|
||||||
] = "heads",
|
] = "heads",
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
val_workers: int = 0,
|
val_workers: int = 0,
|
||||||
@ -162,7 +280,52 @@ class BatDetect2API:
|
|||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
"""Fine-tune from a checkpoint using a new target definition."""
|
"""Fine-tune the current model for new target sounds.
|
||||||
|
|
||||||
|
Use this when you want to keep the existing model weights but change
|
||||||
|
the target sounds. You can fine-tune the whole model or just the
|
||||||
|
heads.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
train_annotations : Sequence[data.ClipAnnotation]
|
||||||
|
Training annotations.
|
||||||
|
targets_config : TargetConfig
|
||||||
|
Target definition to train against.
|
||||||
|
val_annotations : Sequence[data.ClipAnnotation] | None, optional
|
||||||
|
Validation annotations.
|
||||||
|
trainable : {"all", "heads", "classifier_head", "size_head"}, optional
|
||||||
|
Which model parameters remain trainable.
|
||||||
|
train_workers : int, optional
|
||||||
|
Number of worker processes for training data loading.
|
||||||
|
val_workers : int, optional
|
||||||
|
Number of worker processes for validation data loading.
|
||||||
|
checkpoint_dir : Path | None, optional
|
||||||
|
Directory where checkpoints are saved.
|
||||||
|
log_dir : Path | None, optional
|
||||||
|
Directory where logs are written.
|
||||||
|
experiment_name : str | None, optional
|
||||||
|
Experiment name used by the configured logger.
|
||||||
|
num_epochs : int | None, optional
|
||||||
|
Maximum number of training epochs.
|
||||||
|
run_name : str | None, optional
|
||||||
|
Run name used by the configured logger.
|
||||||
|
seed : int | None, optional
|
||||||
|
Random seed for reproducibility.
|
||||||
|
audio_config : AudioConfig | None, optional
|
||||||
|
Audio config override.
|
||||||
|
train_config : TrainingConfig | None, optional
|
||||||
|
Training config override.
|
||||||
|
logger_config : LoggerConfig | None, optional
|
||||||
|
Training logger config override.
|
||||||
|
logging_callbacks : Sequence[LoggingCallback[TrainLoggingContext]], optional
|
||||||
|
Extra logging callbacks to run during training setup.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
BatDetect2API
|
||||||
|
A new API instance configured for the new targets.
|
||||||
|
"""
|
||||||
from batdetect2.evaluate import build_evaluator
|
from batdetect2.evaluate import build_evaluator
|
||||||
from batdetect2.models import build_model_with_new_targets
|
from batdetect2.models import build_model_with_new_targets
|
||||||
from batdetect2.outputs import (
|
from batdetect2.outputs import (
|
||||||
@ -225,7 +388,6 @@ class BatDetect2API:
|
|||||||
model=api.model,
|
model=api.model,
|
||||||
targets=api.targets,
|
targets=api.targets,
|
||||||
roi_mapper=api.roi_mapper,
|
roi_mapper=api.roi_mapper,
|
||||||
model_config=api.model_config,
|
|
||||||
preprocessor=api.preprocessor,
|
preprocessor=api.preprocessor,
|
||||||
audio_loader=api.audio_loader,
|
audio_loader=api.audio_loader,
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
@ -257,6 +419,36 @@ class BatDetect2API:
|
|||||||
outputs_config: OutputsConfig | None = None,
|
outputs_config: OutputsConfig | None = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
) -> tuple[dict[str, float], list[ClipDetections]]:
|
) -> tuple[dict[str, float], list[ClipDetections]]:
|
||||||
|
"""Evaluate the current model on a labelled dataset.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
test_annotations : Sequence[data.ClipAnnotation]
|
||||||
|
Labelled clips used for evaluation.
|
||||||
|
num_workers : int, optional
|
||||||
|
Number of worker processes for dataset loading.
|
||||||
|
output_dir : data.PathLike, optional
|
||||||
|
Directory where metrics and plots are written.
|
||||||
|
experiment_name : str | None, optional
|
||||||
|
Experiment name used by the configured logger.
|
||||||
|
run_name : str | None, optional
|
||||||
|
Run name used by the configured logger.
|
||||||
|
save_predictions : bool, optional
|
||||||
|
If ``True``, save formatted predictions alongside metrics.
|
||||||
|
audio_config : AudioConfig | None, optional
|
||||||
|
Audio config override.
|
||||||
|
evaluation_config : EvaluationConfig | None, optional
|
||||||
|
Evaluation config override.
|
||||||
|
outputs_config : OutputsConfig | None, optional
|
||||||
|
Output config override.
|
||||||
|
logger_config : LoggerConfig | None, optional
|
||||||
|
Evaluation logger config override.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[dict[str, float], list[ClipDetections]]
|
||||||
|
Evaluation metrics and per-clip predictions.
|
||||||
|
"""
|
||||||
from batdetect2.evaluate import run_evaluate
|
from batdetect2.evaluate import run_evaluate
|
||||||
|
|
||||||
return run_evaluate(
|
return run_evaluate(
|
||||||
@ -283,6 +475,22 @@ class BatDetect2API:
|
|||||||
predictions: Sequence[ClipDetections],
|
predictions: Sequence[ClipDetections],
|
||||||
output_dir: data.PathLike | None = None,
|
output_dir: data.PathLike | None = None,
|
||||||
):
|
):
|
||||||
|
"""Evaluate an existing set of predictions.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
annotations : Sequence[data.ClipAnnotation]
|
||||||
|
Reference annotations.
|
||||||
|
predictions : Sequence[ClipDetections]
|
||||||
|
Predictions to compare against the annotations.
|
||||||
|
output_dir : data.PathLike | None, optional
|
||||||
|
Directory where metrics and plots are written.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict[str, float]
|
||||||
|
Computed evaluation metrics.
|
||||||
|
"""
|
||||||
from batdetect2.evaluate import save_evaluation_results
|
from batdetect2.evaluate import save_evaluation_results
|
||||||
|
|
||||||
clip_evals = self.evaluator.evaluate(
|
clip_evals = self.evaluator.evaluate(
|
||||||
@ -302,16 +510,65 @@ class BatDetect2API:
|
|||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
def load_audio(self, path: data.PathLike) -> np.ndarray:
|
def load_audio(self, path: data.PathLike) -> np.ndarray:
|
||||||
|
"""Load one audio file into a waveform array.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the audio file.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray
|
||||||
|
Audio waveform loaded from disk.
|
||||||
|
"""
|
||||||
return self.audio_loader.load_file(path)
|
return self.audio_loader.load_file(path)
|
||||||
|
|
||||||
def load_recording(self, recording: data.Recording) -> np.ndarray:
|
def load_recording(self, recording: data.Recording) -> np.ndarray:
|
||||||
|
"""Load one recording object into a waveform array.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
recording : data.Recording
|
||||||
|
Recording object describing the audio to load.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray
|
||||||
|
Audio waveform for the requested recording.
|
||||||
|
"""
|
||||||
return self.audio_loader.load_recording(recording)
|
return self.audio_loader.load_recording(recording)
|
||||||
|
|
||||||
def load_clip(self, clip: data.Clip) -> np.ndarray:
|
def load_clip(self, clip: data.Clip) -> np.ndarray:
|
||||||
|
"""Load one clip object into a waveform array.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
clip : data.Clip
|
||||||
|
Clip object describing the section of audio to load.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray
|
||||||
|
Audio waveform for the requested clip.
|
||||||
|
"""
|
||||||
return self.audio_loader.load_clip(clip)
|
return self.audio_loader.load_clip(clip)
|
||||||
|
|
||||||
def get_top_class_name(self, detection: Detection) -> str:
|
def get_top_class_name(self, detection: Detection) -> str:
|
||||||
"""Get highest-confidence class name for one detection."""
|
"""Get the name of the highest-confidence class for one detection.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
detection : Detection
|
||||||
|
Detection whose class scores will be inspected.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
Class name with the highest score.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
top_index = int(np.argmax(detection.class_scores))
|
top_index = int(np.argmax(detection.class_scores))
|
||||||
return self.targets.class_names[top_index]
|
return self.targets.class_names[top_index]
|
||||||
@ -323,7 +580,22 @@ class BatDetect2API:
|
|||||||
include_top_class: bool = True,
|
include_top_class: bool = True,
|
||||||
sort_descending: bool = True,
|
sort_descending: bool = True,
|
||||||
) -> list[tuple[str, float]]:
|
) -> list[tuple[str, float]]:
|
||||||
"""Get class score list as ``(class_name, score)`` pairs."""
|
"""Get class scores as ``(class_name, score)`` pairs.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
detection : Detection
|
||||||
|
Detection whose class scores will be returned.
|
||||||
|
include_top_class : bool, optional
|
||||||
|
If ``False``, omit the highest-scoring class from the result.
|
||||||
|
sort_descending : bool, optional
|
||||||
|
If ``True``, sort scores from highest to lowest.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[tuple[str, float]]
|
||||||
|
Class-score pairs for the detection.
|
||||||
|
"""
|
||||||
|
|
||||||
scores = [
|
scores = [
|
||||||
(class_name, float(score))
|
(class_name, float(score))
|
||||||
@ -347,16 +619,22 @@ class BatDetect2API:
|
|||||||
if class_name != top_class_name
|
if class_name != top_class_name
|
||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_detection_features(detection: Detection) -> np.ndarray:
|
|
||||||
"""Get extracted feature vector for one detection."""
|
|
||||||
|
|
||||||
return detection.features
|
|
||||||
|
|
||||||
def generate_spectrogram(
|
def generate_spectrogram(
|
||||||
self,
|
self,
|
||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""Convert a waveform array into a spectrogram tensor.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio : np.ndarray
|
||||||
|
Audio waveform.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
torch.Tensor
|
||||||
|
Spectrogram tensor ready for model inference.
|
||||||
|
"""
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
tensor = torch.tensor(audio).unsqueeze(0)
|
tensor = torch.tensor(audio).unsqueeze(0)
|
||||||
@ -368,6 +646,25 @@ class BatDetect2API:
|
|||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> ClipDetections:
|
) -> ClipDetections:
|
||||||
|
"""Run inference on one audio file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio_file : data.PathLike
|
||||||
|
Path to the audio file.
|
||||||
|
batch_size : int | None, optional
|
||||||
|
Batch size override. If omitted, the inference config value is
|
||||||
|
used.
|
||||||
|
detection_threshold : float | None, optional
|
||||||
|
Detection score threshold override.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ClipDetections
|
||||||
|
Predictions for the full recording.
|
||||||
|
"""
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.postprocess import ClipDetections
|
from batdetect2.postprocess import ClipDetections
|
||||||
|
|
||||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||||
@ -402,6 +699,20 @@ class BatDetect2API:
|
|||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> list[Detection]:
|
) -> list[Detection]:
|
||||||
|
"""Run inference on a waveform array.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio : np.ndarray
|
||||||
|
Audio waveform.
|
||||||
|
detection_threshold : float | None, optional
|
||||||
|
Detection score threshold override.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[Detection]
|
||||||
|
Detected calls.
|
||||||
|
"""
|
||||||
spec = self.generate_spectrogram(audio)
|
spec = self.generate_spectrogram(audio)
|
||||||
return self.process_spectrogram(
|
return self.process_spectrogram(
|
||||||
spec,
|
spec,
|
||||||
@ -414,6 +725,27 @@ class BatDetect2API:
|
|||||||
start_time: float = 0,
|
start_time: float = 0,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> list[Detection]:
|
) -> list[Detection]:
|
||||||
|
"""Run inference on one spectrogram tensor.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
spec : torch.Tensor
|
||||||
|
Spectrogram tensor for one recording or clip.
|
||||||
|
start_time : float, optional
|
||||||
|
Start time in seconds used when creating detections.
|
||||||
|
detection_threshold : float | None, optional
|
||||||
|
Detection score threshold override.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[Detection]
|
||||||
|
Detected calls.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If a batched spectrogram with more than one item is provided.
|
||||||
|
"""
|
||||||
if spec.ndim == 4 and spec.shape[0] > 1:
|
if spec.ndim == 4 and spec.shape[0] > 1:
|
||||||
raise ValueError("Batched spectrograms not supported.")
|
raise ValueError("Batched spectrograms not supported.")
|
||||||
|
|
||||||
@ -436,6 +768,20 @@ class BatDetect2API:
|
|||||||
audio_dir: data.PathLike,
|
audio_dir: data.PathLike,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
|
"""Run inference on all supported audio files in a directory.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio_dir : data.PathLike
|
||||||
|
Directory containing audio files.
|
||||||
|
detection_threshold : float | None, optional
|
||||||
|
Detection score threshold override.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[ClipDetections]
|
||||||
|
Predictions for all supported audio files found in the directory.
|
||||||
|
"""
|
||||||
from soundevent.audio.files import get_audio_files
|
from soundevent.audio.files import get_audio_files
|
||||||
|
|
||||||
files = list(get_audio_files(audio_dir))
|
files = list(get_audio_files(audio_dir))
|
||||||
@ -454,6 +800,30 @@ class BatDetect2API:
|
|||||||
output_config: OutputsConfig | None = None,
|
output_config: OutputsConfig | None = None,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
|
"""Run inference on multiple audio files.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
audio_files : Sequence[data.PathLike]
|
||||||
|
Audio file paths.
|
||||||
|
batch_size : int | None, optional
|
||||||
|
Batch size override.
|
||||||
|
num_workers : int, optional
|
||||||
|
Number of worker processes for audio loading.
|
||||||
|
audio_config : AudioConfig | None, optional
|
||||||
|
Audio config override.
|
||||||
|
inference_config : InferenceConfig | None, optional
|
||||||
|
Inference config override.
|
||||||
|
output_config : OutputsConfig | None, optional
|
||||||
|
Output config override.
|
||||||
|
detection_threshold : float | None, optional
|
||||||
|
Detection score threshold override.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[ClipDetections]
|
||||||
|
Predictions for each input file.
|
||||||
|
"""
|
||||||
from batdetect2.inference import process_file_list
|
from batdetect2.inference import process_file_list
|
||||||
|
|
||||||
return process_file_list(
|
return process_file_list(
|
||||||
@ -482,6 +852,30 @@ class BatDetect2API:
|
|||||||
output_config: OutputsConfig | None = None,
|
output_config: OutputsConfig | None = None,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
|
"""Run inference on multiple clip objects.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
clips : Sequence[data.Clip]
|
||||||
|
Clips to process.
|
||||||
|
batch_size : int | None, optional
|
||||||
|
Batch size override.
|
||||||
|
num_workers : int, optional
|
||||||
|
Number of worker processes for audio loading.
|
||||||
|
audio_config : AudioConfig | None, optional
|
||||||
|
Audio config override.
|
||||||
|
inference_config : InferenceConfig | None, optional
|
||||||
|
Inference config override.
|
||||||
|
output_config : OutputsConfig | None, optional
|
||||||
|
Output config override.
|
||||||
|
detection_threshold : float | None, optional
|
||||||
|
Detection score threshold override.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[ClipDetections]
|
||||||
|
Predictions for each input clip.
|
||||||
|
"""
|
||||||
from batdetect2.inference import run_batch_inference
|
from batdetect2.inference import run_batch_inference
|
||||||
|
|
||||||
return run_batch_inference(
|
return run_batch_inference(
|
||||||
@ -508,6 +902,21 @@ class BatDetect2API:
|
|||||||
format: str | None = None,
|
format: str | None = None,
|
||||||
config: OutputFormatConfig | None = None,
|
config: OutputFormatConfig | None = None,
|
||||||
):
|
):
|
||||||
|
"""Save predictions to disk in one of the supported output formats.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
predictions : Sequence[ClipDetections]
|
||||||
|
Predictions to save.
|
||||||
|
path : data.PathLike
|
||||||
|
Output file or directory path, depending on the selected format.
|
||||||
|
audio_dir : data.PathLike | None, optional
|
||||||
|
Audio root directory used when writing relative paths.
|
||||||
|
format : str | None, optional
|
||||||
|
Output format name override.
|
||||||
|
config : OutputFormatConfig | None, optional
|
||||||
|
Output format config override.
|
||||||
|
"""
|
||||||
from batdetect2.outputs import get_output_formatter
|
from batdetect2.outputs import get_output_formatter
|
||||||
|
|
||||||
formatter = self.formatter
|
formatter = self.formatter
|
||||||
@ -529,6 +938,22 @@ class BatDetect2API:
|
|||||||
format: str | None = None,
|
format: str | None = None,
|
||||||
config: OutputFormatConfig | None = None,
|
config: OutputFormatConfig | None = None,
|
||||||
) -> list[object]:
|
) -> list[object]:
|
||||||
|
"""Load predictions from disk.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to a saved prediction file or directory.
|
||||||
|
format : str | None, optional
|
||||||
|
Output format name override.
|
||||||
|
config : OutputFormatConfig | None, optional
|
||||||
|
Output format config override.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[object]
|
||||||
|
Loaded prediction objects returned by the selected formatter.
|
||||||
|
"""
|
||||||
from batdetect2.outputs import get_output_formatter
|
from batdetect2.outputs import get_output_formatter
|
||||||
|
|
||||||
formatter = self.formatter
|
formatter = self.formatter
|
||||||
@ -555,6 +980,36 @@ class BatDetect2API:
|
|||||||
outputs_config: OutputsConfig | None = None,
|
outputs_config: OutputsConfig | None = None,
|
||||||
logging_config: AppLoggingConfig | None = None,
|
logging_config: AppLoggingConfig | None = None,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
|
"""Build an API instance from config objects.
|
||||||
|
|
||||||
|
Use this when you want to create a new model without loading a saved
|
||||||
|
checkpoint.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_config : ModelConfig | None, optional
|
||||||
|
Model config. If omitted, the default model config is used.
|
||||||
|
targets_config : TargetConfig | None, optional
|
||||||
|
Target config. If omitted, the default target config is used.
|
||||||
|
audio_config : AudioConfig | None, optional
|
||||||
|
Audio config. If omitted, the default audio config is used.
|
||||||
|
train_config : TrainingConfig | None, optional
|
||||||
|
Training config. If omitted, the default training config is used.
|
||||||
|
evaluation_config : EvaluationConfig | None, optional
|
||||||
|
Evaluation config. If omitted, the default evaluation config is
|
||||||
|
used.
|
||||||
|
inference_config : InferenceConfig | None, optional
|
||||||
|
Inference config. If omitted, the default inference config is used.
|
||||||
|
outputs_config : OutputsConfig | None, optional
|
||||||
|
Output config. If omitted, the default outputs config is used.
|
||||||
|
logging_config : AppLoggingConfig | None, optional
|
||||||
|
Logging config. If omitted, the default logging config is used.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
BatDetect2API
|
||||||
|
Configured API instance.
|
||||||
|
"""
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.inference import InferenceConfig
|
||||||
@ -653,7 +1108,7 @@ class BatDetect2API:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_checkpoint(
|
def from_checkpoint(
|
||||||
cls,
|
cls,
|
||||||
path: data.PathLike,
|
path: data.PathLike | str | None = None,
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
evaluation_config: EvaluationConfig | None = None,
|
evaluation_config: EvaluationConfig | None = None,
|
||||||
@ -661,6 +1116,31 @@ class BatDetect2API:
|
|||||||
outputs_config: OutputsConfig | None = None,
|
outputs_config: OutputsConfig | None = None,
|
||||||
logging_config: AppLoggingConfig | None = None,
|
logging_config: AppLoggingConfig | None = None,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
|
"""Build an API instance from a saved checkpoint.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike | str | None, optional
|
||||||
|
Checkpoint path, bundled checkpoint alias, or Hugging Face URI.
|
||||||
|
If omitted, the default bundled checkpoint is used.
|
||||||
|
audio_config : AudioConfig | None, optional
|
||||||
|
Audio config override.
|
||||||
|
train_config : TrainingConfig | None, optional
|
||||||
|
Training config override.
|
||||||
|
evaluation_config : EvaluationConfig | None, optional
|
||||||
|
Evaluation config override.
|
||||||
|
inference_config : InferenceConfig | None, optional
|
||||||
|
Inference config override.
|
||||||
|
outputs_config : OutputsConfig | None, optional
|
||||||
|
Output config override.
|
||||||
|
logging_config : AppLoggingConfig | None, optional
|
||||||
|
Logging config override.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
BatDetect2API
|
||||||
|
Configured API instance.
|
||||||
|
"""
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.inference import InferenceConfig
|
||||||
@ -759,7 +1239,7 @@ class BatDetect2API:
|
|||||||
|
|
||||||
def _set_trainable_parameters(
|
def _set_trainable_parameters(
|
||||||
self,
|
self,
|
||||||
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
|
trainable: Literal["all", "heads", "classifier_head", "size_head"],
|
||||||
) -> None:
|
) -> None:
|
||||||
detector = self.model.detector
|
detector = self.model.detector
|
||||||
|
|
||||||
@ -775,6 +1255,6 @@ class BatDetect2API:
|
|||||||
for parameter in detector.classifier_head.parameters():
|
for parameter in detector.classifier_head.parameters():
|
||||||
parameter.requires_grad = True
|
parameter.requires_grad = True
|
||||||
|
|
||||||
if trainable in {"heads", "bbox_head"}:
|
if trainable in {"heads", "size_head"}:
|
||||||
for parameter in detector.bbox_head.parameters():
|
for parameter in detector.size_head.parameters():
|
||||||
parameter.requires_grad = True
|
parameter.requires_grad = True
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from batdetect2.cli.compat import detect
|
|||||||
from batdetect2.cli.data import data
|
from batdetect2.cli.data import data
|
||||||
from batdetect2.cli.evaluate import evaluate_command
|
from batdetect2.cli.evaluate import evaluate_command
|
||||||
from batdetect2.cli.finetune import finetune_command
|
from batdetect2.cli.finetune import finetune_command
|
||||||
from batdetect2.cli.inference import predict
|
from batdetect2.cli.inference import process
|
||||||
from batdetect2.cli.train import train_command
|
from batdetect2.cli.train import train_command
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -13,7 +13,7 @@ __all__ = [
|
|||||||
"train_command",
|
"train_command",
|
||||||
"finetune_command",
|
"finetune_command",
|
||||||
"evaluate_command",
|
"evaluate_command",
|
||||||
"predict",
|
"process",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -2,35 +2,39 @@
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
from batdetect2.cli.ascii import BATDETECT_ASCII_ART
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"cli",
|
"cli",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
INFO_STR = """
|
INFO_STR = """
|
||||||
BatDetect2 - Detection and Classification
|
BatDetect2
|
||||||
Assumes audio files are mono, not stereo.
|
Wrap paths that contain spaces in quotes.
|
||||||
Spaces in the input paths will throw an error. Wrap in quotes.
|
|
||||||
Input files should be short in duration e.g. < 30 seconds.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group(invoke_without_command=True)
|
||||||
@click.option(
|
@click.option(
|
||||||
"-v",
|
"-v",
|
||||||
"--verbose",
|
"--verbose",
|
||||||
count=True,
|
count=True,
|
||||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||||
)
|
)
|
||||||
def cli(verbose: int = 0):
|
@click.pass_context
|
||||||
|
def cli(ctx: click.Context, verbose: int = 0):
|
||||||
"""Run the BatDetect2 CLI.
|
"""Run the BatDetect2 CLI.
|
||||||
|
|
||||||
This command initializes logging and exposes subcommands for prediction,
|
Use subcommands to run processing, training, evaluation, and dataset
|
||||||
training, evaluation, and dataset utilities.
|
utilities.
|
||||||
"""
|
"""
|
||||||
click.echo(INFO_STR)
|
|
||||||
|
if ctx.invoked_subcommand is None:
|
||||||
|
click.echo(BATDETECT_ASCII_ART)
|
||||||
|
click.echo(ctx.get_help())
|
||||||
|
ctx.exit()
|
||||||
|
|
||||||
from batdetect2.logging import enable_logging
|
from batdetect2.logging import enable_logging
|
||||||
|
|
||||||
enable_logging(verbose)
|
enable_logging(verbose)
|
||||||
# click.echo(BATDETECT_ASCII_ART)
|
|
||||||
|
|||||||
@ -15,7 +15,7 @@ DEFAULT_MODEL_PATH = os.path.join(
|
|||||||
@cli.command(
|
@cli.command(
|
||||||
short_help="Legacy detection command.",
|
short_help="Legacy detection command.",
|
||||||
epilog=(
|
epilog=(
|
||||||
"Deprecated workflow. Prefer `batdetect2 predict directory` for "
|
"Deprecated workflow. Prefer `batdetect2 process directory` for "
|
||||||
"new analyses."
|
"new analyses."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -91,11 +91,17 @@ def detect(
|
|||||||
Note
|
Note
|
||||||
----
|
----
|
||||||
This command is kept for backwards compatibility. Prefer
|
This command is kept for backwards compatibility. Prefer
|
||||||
`batdetect2 predict directory` for new workflows.
|
`batdetect2 process directory` for new workflows.
|
||||||
"""
|
"""
|
||||||
from batdetect2 import api
|
from batdetect2 import api
|
||||||
from batdetect2.utils.detector_utils import save_results_to_file
|
from batdetect2.utils.detector_utils import save_results_to_file
|
||||||
|
|
||||||
|
message = (
|
||||||
|
"The `batdetect2 detect` command is deprecated. Prefer "
|
||||||
|
"`batdetect2 process directory` for new analyses."
|
||||||
|
)
|
||||||
|
click.secho(f"WARNING: {message}", fg="yellow", err=True)
|
||||||
|
|
||||||
click.echo(f"Loading model: {args['model_path']}")
|
click.echo(f"Loading model: {args['model_path']}")
|
||||||
model, params = api.load_model(args["model_path"])
|
model, params = api.load_model(args["model_path"])
|
||||||
|
|
||||||
|
|||||||
@ -7,9 +7,9 @@ from batdetect2.cli.base import cli
|
|||||||
__all__ = ["data"]
|
__all__ = ["data"]
|
||||||
|
|
||||||
|
|
||||||
@cli.group(short_help="Inspect and convert datasets.")
|
@cli.group(short_help="Inspect and manage datasets.")
|
||||||
def data():
|
def data():
|
||||||
"""Inspect and convert dataset configuration files."""
|
"""Inspect and manage dataset configuration files."""
|
||||||
|
|
||||||
|
|
||||||
@data.command(short_help="Print dataset summary information.")
|
@data.command(short_help="Print dataset summary information.")
|
||||||
@ -64,7 +64,7 @@ def summary(
|
|||||||
base_dir=base_dir,
|
base_dir=base_dir,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Number of annotated clips: {len(dataset)}")
|
click.echo(f"Number of annotated clips: {len(dataset)}")
|
||||||
|
|
||||||
if targets_path is None:
|
if targets_path is None:
|
||||||
return
|
return
|
||||||
@ -73,7 +73,7 @@ def summary(
|
|||||||
|
|
||||||
summary = compute_class_summary(dataset, targets)
|
summary = compute_class_summary(dataset, targets)
|
||||||
|
|
||||||
print(summary.sort_values("class_name").to_markdown())
|
click.echo(summary.sort_values("class_name").to_markdown())
|
||||||
|
|
||||||
|
|
||||||
@data.command(short_help="Convert dataset config to annotation set.")
|
@data.command(short_help="Convert dataset config to annotation set.")
|
||||||
@ -200,6 +200,6 @@ def convert(
|
|||||||
if not audio_dir.is_absolute():
|
if not audio_dir.is_absolute():
|
||||||
audio_dir = audio_dir.resolve()
|
audio_dir = audio_dir.resolve()
|
||||||
|
|
||||||
print(f"Using audio directory: {audio_dir}")
|
click.echo(f"Using audio directory: {audio_dir}")
|
||||||
|
|
||||||
io.save(annotation_set, output, audio_dir=audio_dir)
|
io.save(annotation_set, output, audio_dir=audio_dir)
|
||||||
|
|||||||
@ -12,38 +12,40 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
|||||||
|
|
||||||
|
|
||||||
@cli.command(name="evaluate", short_help="Evaluate a model checkpoint.")
|
@cli.command(name="evaluate", short_help="Evaluate a model checkpoint.")
|
||||||
@click.argument("model_path", type=click.Path(exists=True))
|
|
||||||
@click.argument("test_dataset", type=click.Path(exists=True))
|
@click.argument("test_dataset", type=click.Path(exists=True))
|
||||||
@click.option(
|
@click.option(
|
||||||
"--targets",
|
"--model",
|
||||||
"targets_config",
|
"model_path",
|
||||||
type=click.Path(exists=True),
|
type=str,
|
||||||
help="Path to targets config file.",
|
help=(
|
||||||
|
"Path to a checkpoint, checkpoint alias, or a Hugging Face "
|
||||||
|
"URI to fine-tune from. Defaults to uk_same"
|
||||||
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--audio-config",
|
"--audio-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to audio config file.",
|
help="Path to an audio config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--evaluation-config",
|
"--evaluation-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to evaluation config file.",
|
help="Path to an evaluation config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--inference-config",
|
"--inference-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to inference config file.",
|
help="Path to an inference config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--outputs-config",
|
"--outputs-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to outputs config file.",
|
help="Path to an outputs config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--logging-config",
|
"--logging-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to logging config file.",
|
help="Path to a logging config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--base-dir",
|
"--base-dir",
|
||||||
@ -80,24 +82,23 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
|||||||
default=0,
|
default=0,
|
||||||
)
|
)
|
||||||
def evaluate_command(
|
def evaluate_command(
|
||||||
model_path: Path,
|
|
||||||
test_dataset: Path,
|
test_dataset: Path,
|
||||||
base_dir: Path,
|
model_path: str | None = None,
|
||||||
targets_config: Path | None,
|
base_dir: Path | None = None,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None = None,
|
||||||
evaluation_config: Path | None,
|
evaluation_config: Path | None = None,
|
||||||
inference_config: Path | None,
|
inference_config: Path | None = None,
|
||||||
outputs_config: Path | None,
|
outputs_config: Path | None = None,
|
||||||
logging_config: Path | None,
|
logging_config: Path | None = None,
|
||||||
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
||||||
num_workers: int = 0,
|
num_workers: int = 0,
|
||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
):
|
):
|
||||||
"""Evaluate a checkpoint against a test dataset.
|
"""Evaluate a checkpoint on a labelled test dataset.
|
||||||
|
|
||||||
Loads model and optional override configs, runs evaluation on
|
This command loads a checkpoint, runs evaluation on ``test_dataset``, and
|
||||||
`test_dataset`, and writes metrics/artifacts to `output_dir`.
|
writes metrics to ``output_dir``.
|
||||||
"""
|
"""
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.audio import AudioConfig
|
from batdetect2.audio import AudioConfig
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, cast
|
from typing import Literal
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@ -13,13 +13,6 @@ __all__ = ["finetune_command"]
|
|||||||
name="finetune", short_help="Fine-tune a checkpoint on new targets."
|
name="finetune", short_help="Fine-tune a checkpoint on new targets."
|
||||||
)
|
)
|
||||||
@click.argument("train_dataset", type=click.Path(exists=True))
|
@click.argument("train_dataset", type=click.Path(exists=True))
|
||||||
@click.option(
|
|
||||||
"--model",
|
|
||||||
"model_path",
|
|
||||||
required=True,
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help="Path to a checkpoint to fine-tune from.",
|
|
||||||
)
|
|
||||||
@click.option(
|
@click.option(
|
||||||
"--targets",
|
"--targets",
|
||||||
"targets_config",
|
"targets_config",
|
||||||
@ -27,6 +20,15 @@ __all__ = ["finetune_command"]
|
|||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to the new targets config file.",
|
help="Path to the new targets config file.",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--model",
|
||||||
|
"model_path",
|
||||||
|
type=str,
|
||||||
|
help=(
|
||||||
|
"Path to a checkpoint, checkpoint alias, or a Hugging Face "
|
||||||
|
"URI to fine-tune from. Defaults to uk_same"
|
||||||
|
),
|
||||||
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--val-dataset",
|
"--val-dataset",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
@ -57,7 +59,7 @@ __all__ = ["finetune_command"]
|
|||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--trainable",
|
"--trainable",
|
||||||
type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
|
type=click.Choice(["all", "heads", "classifier_head", "size_head"]),
|
||||||
default="heads",
|
default="heads",
|
||||||
show_default=True,
|
show_default=True,
|
||||||
help="Which model parameters remain trainable during fine-tuning.",
|
help="Which model parameters remain trainable during fine-tuning.",
|
||||||
@ -106,8 +108,8 @@ __all__ = ["finetune_command"]
|
|||||||
)
|
)
|
||||||
def finetune_command(
|
def finetune_command(
|
||||||
train_dataset: Path,
|
train_dataset: Path,
|
||||||
model_path: Path,
|
|
||||||
targets_config: Path,
|
targets_config: Path,
|
||||||
|
model_path: str | None = None,
|
||||||
val_dataset: Path | None = None,
|
val_dataset: Path | None = None,
|
||||||
ckpt_dir: Path | None = None,
|
ckpt_dir: Path | None = None,
|
||||||
log_dir: Path | None = None,
|
log_dir: Path | None = None,
|
||||||
@ -115,7 +117,9 @@ def finetune_command(
|
|||||||
training_config: Path | None = None,
|
training_config: Path | None = None,
|
||||||
audio_config: Path | None = None,
|
audio_config: Path | None = None,
|
||||||
logging_config: Path | None = None,
|
logging_config: Path | None = None,
|
||||||
trainable: str = "heads",
|
trainable: Literal[
|
||||||
|
"all", "heads", "classifier_head", "size_head"
|
||||||
|
] = "heads",
|
||||||
seed: int | None = None,
|
seed: int | None = None,
|
||||||
num_epochs: int | None = None,
|
num_epochs: int | None = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
@ -192,10 +196,7 @@ def finetune_command(
|
|||||||
train_annotations=train_annotations,
|
train_annotations=train_annotations,
|
||||||
val_annotations=val_annotations,
|
val_annotations=val_annotations,
|
||||||
targets_config=target_conf,
|
targets_config=target_conf,
|
||||||
trainable=cast(
|
trainable=trainable,
|
||||||
Literal["all", "heads", "classifier_head", "bbox_head"],
|
|
||||||
trainable,
|
|
||||||
),
|
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
val_workers=val_workers,
|
val_workers=val_workers,
|
||||||
checkpoint_dir=ckpt_dir,
|
checkpoint_dir=ckpt_dir,
|
||||||
|
|||||||
@ -13,27 +13,26 @@ if TYPE_CHECKING:
|
|||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.inference import InferenceConfig
|
||||||
from batdetect2.outputs import OutputsConfig
|
from batdetect2.outputs import OutputsConfig
|
||||||
|
|
||||||
__all__ = ["predict"]
|
__all__ = ["process"]
|
||||||
|
|
||||||
|
|
||||||
@cli.group(name="predict", short_help="Run prediction workflows.")
|
@cli.group(name="process", short_help="Run processing workflows.")
|
||||||
def predict() -> None:
|
def process() -> None:
|
||||||
"""Run model inference on audio files.
|
"""Run model inference on audio.
|
||||||
|
|
||||||
Use one of the subcommands to select inputs from a directory, a text file
|
Choose a subcommand based on how you want to provide input audio.
|
||||||
list, or an annotation dataset.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def common_predict_options(func):
|
def common_predict_options(func):
|
||||||
"""Attach options shared by all `predict` subcommands."""
|
"""Attach options shared by all ``process`` subcommands."""
|
||||||
|
|
||||||
@click.option(
|
@click.option(
|
||||||
"--audio-config",
|
"--audio-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help=(
|
help=(
|
||||||
"Path to an audio config file. Use this to override audio "
|
"Path to an audio config file. Use this to override audio "
|
||||||
"loading and preprocessing-related settings."
|
"loading settings."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
@ -41,7 +40,7 @@ def common_predict_options(func):
|
|||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help=(
|
help=(
|
||||||
"Path to an inference config file. Use this to override "
|
"Path to an inference config file. Use this to override "
|
||||||
"prediction-time thresholds and behavior."
|
"prediction settings."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
@ -49,23 +48,19 @@ def common_predict_options(func):
|
|||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help=(
|
help=(
|
||||||
"Path to an outputs config file. Use this to control the "
|
"Path to an outputs config file. Use this to control the "
|
||||||
"prediction fields written to disk."
|
"saved output format and fields."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--logging-config",
|
"--logging-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help=(
|
help=("Path to a logging config file. Use this to change log output."),
|
||||||
"Path to a logging config file. Use this to customize logging "
|
|
||||||
"format and levels."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--batch-size",
|
"--batch-size",
|
||||||
type=int,
|
type=int,
|
||||||
help=(
|
help=(
|
||||||
"Batch size for inference. If omitted, the value from the "
|
"Batch size for inference. If omitted, the config value is used."
|
||||||
"loaded config is used."
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
@ -82,7 +77,7 @@ def common_predict_options(func):
|
|||||||
type=str,
|
type=str,
|
||||||
help=(
|
help=(
|
||||||
"Output format name used by the prediction writer. If omitted, "
|
"Output format name used by the prediction writer. If omitted, "
|
||||||
"the default output format is used."
|
"the config default is used."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
@ -91,7 +86,7 @@ def common_predict_options(func):
|
|||||||
default=None,
|
default=None,
|
||||||
help=(
|
help=(
|
||||||
"Optional detection score threshold override. If omitted, "
|
"Optional detection score threshold override. If omitted, "
|
||||||
"the model default threshold is used."
|
"the configured threshold is used."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
@ -102,7 +97,7 @@ def common_predict_options(func):
|
|||||||
|
|
||||||
|
|
||||||
def _build_api(
|
def _build_api(
|
||||||
model_path: Path,
|
model_path: str,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
inference_config: Path | None,
|
inference_config: Path | None,
|
||||||
outputs_config: Path | None,
|
outputs_config: Path | None,
|
||||||
@ -144,7 +139,7 @@ def _build_api(
|
|||||||
|
|
||||||
|
|
||||||
def _run_prediction(
|
def _run_prediction(
|
||||||
model_path: Path,
|
model_path: str,
|
||||||
audio_files: list[Path],
|
audio_files: list[Path],
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
@ -191,16 +186,16 @@ def _run_prediction(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@predict.command(
|
@process.command(
|
||||||
name="directory",
|
name="directory",
|
||||||
short_help="Predict on audio files in a directory.",
|
short_help="Process audio files in a directory.",
|
||||||
)
|
)
|
||||||
@click.argument("model_path", type=click.Path(exists=True))
|
@click.argument("model_path", type=str)
|
||||||
@click.argument("audio_dir", type=click.Path(exists=True))
|
@click.argument("audio_dir", type=click.Path(exists=True))
|
||||||
@click.argument("output_path", type=click.Path())
|
@click.argument("output_path", type=click.Path())
|
||||||
@common_predict_options
|
@common_predict_options
|
||||||
def predict_directory_command(
|
def predict_directory_command(
|
||||||
model_path: Path,
|
model_path: str,
|
||||||
audio_dir: Path,
|
audio_dir: Path,
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
@ -212,10 +207,10 @@ def predict_directory_command(
|
|||||||
format_name: str | None,
|
format_name: str | None,
|
||||||
detection_threshold: float | None,
|
detection_threshold: float | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Predict on all audio files in a directory.
|
"""Run processing on all supported audio files in a directory.
|
||||||
|
|
||||||
Loads a checkpoint, scans `audio_dir` for supported audio files, runs
|
This command scans ``audio_dir`` for audio files, runs processing, and
|
||||||
inference, and saves predictions to `output_path`.
|
saves the results to ``output_path``.
|
||||||
"""
|
"""
|
||||||
from soundevent.audio.files import get_audio_files
|
from soundevent.audio.files import get_audio_files
|
||||||
|
|
||||||
@ -235,16 +230,16 @@ def predict_directory_command(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@predict.command(
|
@process.command(
|
||||||
name="file_list",
|
name="file_list",
|
||||||
short_help="Predict on paths listed in a text file.",
|
short_help="Process paths listed in a text file.",
|
||||||
)
|
)
|
||||||
@click.argument("model_path", type=click.Path(exists=True))
|
@click.argument("model_path", type=str)
|
||||||
@click.argument("file_list", type=click.Path(exists=True))
|
@click.argument("file_list", type=click.Path(exists=True))
|
||||||
@click.argument("output_path", type=click.Path())
|
@click.argument("output_path", type=click.Path())
|
||||||
@common_predict_options
|
@common_predict_options
|
||||||
def predict_file_list_command(
|
def predict_file_list_command(
|
||||||
model_path: Path,
|
model_path: str,
|
||||||
file_list: Path,
|
file_list: Path,
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
@ -256,9 +251,9 @@ def predict_file_list_command(
|
|||||||
format_name: str | None,
|
format_name: str | None,
|
||||||
detection_threshold: float | None,
|
detection_threshold: float | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Predict on audio files listed in a text file.
|
"""Run processing on audio files listed in a text file.
|
||||||
|
|
||||||
The list file should contain one audio path per line. Empty lines are
|
The text file should contain one audio path per line. Empty lines are
|
||||||
ignored.
|
ignored.
|
||||||
"""
|
"""
|
||||||
file_list = Path(file_list)
|
file_list = Path(file_list)
|
||||||
@ -283,16 +278,16 @@ def predict_file_list_command(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@predict.command(
|
@process.command(
|
||||||
name="dataset",
|
name="dataset",
|
||||||
short_help="Predict on recordings from a dataset config.",
|
short_help="Process recordings from a dataset config.",
|
||||||
)
|
)
|
||||||
@click.argument("model_path", type=click.Path(exists=True))
|
@click.argument("model_path", type=str)
|
||||||
@click.argument("dataset_path", type=click.Path(exists=True))
|
@click.argument("dataset_path", type=click.Path(exists=True))
|
||||||
@click.argument("output_path", type=click.Path())
|
@click.argument("output_path", type=click.Path())
|
||||||
@common_predict_options
|
@common_predict_options
|
||||||
def predict_dataset_command(
|
def predict_dataset_command(
|
||||||
model_path: Path,
|
model_path: str,
|
||||||
dataset_path: Path,
|
dataset_path: Path,
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
audio_config: Path | None,
|
audio_config: Path | None,
|
||||||
@ -304,10 +299,10 @@ def predict_dataset_command(
|
|||||||
format_name: str | None,
|
format_name: str | None,
|
||||||
detection_threshold: float | None,
|
detection_threshold: float | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Predict on recordings referenced in an annotation dataset.
|
"""Run processing on recordings referenced in a dataset file.
|
||||||
|
|
||||||
The dataset is read as a soundevent annotation set and unique recording
|
Recording paths are read from the dataset and each recording is processed
|
||||||
paths are extracted before inference.
|
once.
|
||||||
"""
|
"""
|
||||||
from soundevent import io
|
from soundevent import io
|
||||||
|
|
||||||
|
|||||||
@ -13,15 +13,15 @@ __all__ = ["train_command"]
|
|||||||
@click.option(
|
@click.option(
|
||||||
"--val-dataset",
|
"--val-dataset",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to validation dataset config file.",
|
help="Path to a validation dataset config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--model",
|
"--model",
|
||||||
"model_path",
|
"model_path",
|
||||||
type=click.Path(exists=True),
|
type=str,
|
||||||
help=(
|
help=(
|
||||||
"Path to a checkpoint to continue training from. If omitted, "
|
"Path to a checkpoint, bundled checkpoint alias, or Hugging Face "
|
||||||
"training starts from a fresh model config."
|
"URI. If omitted, training starts from a fresh model config."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
@ -36,7 +36,7 @@ __all__ = ["train_command"]
|
|||||||
"--targets",
|
"--targets",
|
||||||
"targets_config",
|
"targets_config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to targets config file.",
|
help="Path to a targets config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--model-config",
|
"--model-config",
|
||||||
@ -46,32 +46,32 @@ __all__ = ["train_command"]
|
|||||||
@click.option(
|
@click.option(
|
||||||
"--training-config",
|
"--training-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to training config file.",
|
help="Path to a training config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--audio-config",
|
"--audio-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to audio config file.",
|
help="Path to an audio config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--evaluation-config",
|
"--evaluation-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to evaluation config file.",
|
help="Path to an evaluation config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--inference-config",
|
"--inference-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to inference config file.",
|
help="Path to an inference config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--outputs-config",
|
"--outputs-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to outputs config file.",
|
help="Path to an outputs config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--logging-config",
|
"--logging-config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help="Path to logging config file.",
|
help="Path to a logging config file.",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--ckpt-dir",
|
"--ckpt-dir",
|
||||||
@ -118,7 +118,7 @@ __all__ = ["train_command"]
|
|||||||
def train_command(
|
def train_command(
|
||||||
train_dataset: Path,
|
train_dataset: Path,
|
||||||
val_dataset: Path | None = None,
|
val_dataset: Path | None = None,
|
||||||
model_path: Path | None = None,
|
model_path: str | None = None,
|
||||||
ckpt_dir: Path | None = None,
|
ckpt_dir: Path | None = None,
|
||||||
log_dir: Path | None = None,
|
log_dir: Path | None = None,
|
||||||
base_dir: Path | None = None,
|
base_dir: Path | None = None,
|
||||||
@ -139,9 +139,8 @@ def train_command(
|
|||||||
):
|
):
|
||||||
"""Train a BatDetect2 model.
|
"""Train a BatDetect2 model.
|
||||||
|
|
||||||
Train either from a fresh config (`--model-config`) or by fine-tuning an
|
Start from a fresh model config or continue from an existing checkpoint.
|
||||||
existing checkpoint (`--model`). Training data are loaded from
|
Training data are loaded from ``train_dataset``.
|
||||||
`train_dataset`, with optional validation data from `--val-dataset`.
|
|
||||||
"""
|
"""
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.audio import AudioConfig
|
from batdetect2.audio import AudioConfig
|
||||||
|
|||||||
@ -102,19 +102,19 @@ def convert_to_annotation_group(
|
|||||||
x_inds.append(0)
|
x_inds.append(0)
|
||||||
y_inds.append(0)
|
y_inds.append(0)
|
||||||
|
|
||||||
annotations.append(
|
annotation_entry: Annotation = {
|
||||||
Annotation(
|
"start_time": start_time,
|
||||||
start_time=start_time,
|
"end_time": end_time,
|
||||||
end_time=end_time,
|
"low_freq": low_freq,
|
||||||
low_freq=low_freq,
|
"high_freq": high_freq,
|
||||||
high_freq=high_freq,
|
"class_prob": 1.0,
|
||||||
class_prob=1.0,
|
"det_prob": 1.0,
|
||||||
det_prob=1.0,
|
"individual": "0",
|
||||||
individual="0",
|
"event": event,
|
||||||
event=event,
|
"class": get_recording_class_name(recording),
|
||||||
class_id=class_id,
|
"class_id": class_id,
|
||||||
)
|
}
|
||||||
)
|
annotations.append(annotation_entry)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": str(recording.path),
|
"id": str(recording.path),
|
||||||
|
|||||||
@ -53,7 +53,7 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
def __init__(self, name: str, discriminator: str = "name"):
|
def __init__(self, name: str, discriminator: str = "name"):
|
||||||
self._name = name
|
self._name = name
|
||||||
self._registry: dict[
|
self._registry: dict[
|
||||||
str, Callable[Concatenate[..., P_Type], T_Type]
|
str, Callable[Concatenate[Any, P_Type], T_Type]
|
||||||
] = {}
|
] = {}
|
||||||
self._discriminator = discriminator
|
self._discriminator = discriminator
|
||||||
self._config_types: dict[str, Type[BaseModel]] = {}
|
self._config_types: dict[str, Type[BaseModel]] = {}
|
||||||
@ -80,7 +80,7 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
func: Callable[Concatenate[T_Config, P_Type], T_Type],
|
func: Callable[..., T_Type],
|
||||||
):
|
):
|
||||||
self._registry[name] = func
|
self._registry[name] = func
|
||||||
return func
|
return func
|
||||||
@ -102,8 +102,8 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
config: BaseModel,
|
config: BaseModel,
|
||||||
*args: P_Type.args,
|
*args: Any,
|
||||||
**kwargs: P_Type.kwargs,
|
**kwargs: Any,
|
||||||
) -> T_Type:
|
) -> T_Type:
|
||||||
"""Builds a logic instance from a config object."""
|
"""Builds a logic instance from a config object."""
|
||||||
|
|
||||||
|
|||||||
@ -12,13 +12,15 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _default_tasks() -> list[TaskConfig]:
|
||||||
|
return [
|
||||||
|
DetectionTaskConfig(),
|
||||||
|
ClassificationTaskConfig(),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class EvaluationConfig(BaseConfig):
|
class EvaluationConfig(BaseConfig):
|
||||||
tasks: List[TaskConfig] = Field(
|
tasks: List[TaskConfig] = Field(default_factory=_default_tasks)
|
||||||
default_factory=lambda: [
|
|
||||||
DetectionTaskConfig(),
|
|
||||||
ClassificationTaskConfig(),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_default_eval_config() -> EvaluationConfig:
|
def get_default_eval_config() -> EvaluationConfig:
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from batdetect2.evaluate.dataset import build_test_loader
|
|||||||
from batdetect2.evaluate.evaluator import build_evaluator
|
from batdetect2.evaluate.evaluator import build_evaluator
|
||||||
from batdetect2.evaluate.lightning import EvaluationModule
|
from batdetect2.evaluate.lightning import EvaluationModule
|
||||||
from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
|
from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
|
||||||
from batdetect2.models import Model
|
from batdetect2.models.types import ModelProtocol
|
||||||
from batdetect2.outputs import OutputsConfig, build_output_transform
|
from batdetect2.outputs import OutputsConfig, build_output_transform
|
||||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
@ -22,7 +22,7 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
|||||||
|
|
||||||
|
|
||||||
def run_evaluate(
|
def run_evaluate(
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
test_annotations: Sequence[data.ClipAnnotation],
|
test_annotations: Sequence[data.ClipAnnotation],
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
roi_mapper: ROIMapperProtocol,
|
roi_mapper: ROIMapperProtocol,
|
||||||
|
|||||||
@ -7,14 +7,14 @@ from torch.utils.data import DataLoader
|
|||||||
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
||||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||||
from batdetect2.logging import get_image_logger
|
from batdetect2.logging import get_image_logger
|
||||||
from batdetect2.models import Model
|
from batdetect2.models.types import ModelProtocol
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
class EvaluationModule(LightningModule):
|
class EvaluationModule(LightningModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
evaluator: EvaluatorProtocol,
|
evaluator: EvaluatorProtocol,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@ -25,11 +25,15 @@ from batdetect2.postprocess.types import ClipDetections, Detection
|
|||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def _default_metrics() -> list[ClassificationMetricConfig]:
|
||||||
|
return [ClassificationAveragePrecisionConfig()]
|
||||||
|
|
||||||
|
|
||||||
class ClassificationTaskConfig(BaseSEDTaskConfig):
|
class ClassificationTaskConfig(BaseSEDTaskConfig):
|
||||||
name: Literal["sound_event_classification"] = "sound_event_classification"
|
name: Literal["sound_event_classification"] = "sound_event_classification"
|
||||||
prefix: str = "classification"
|
prefix: str = "classification"
|
||||||
metrics: list[ClassificationMetricConfig] = Field(
|
metrics: list[ClassificationMetricConfig] = Field(
|
||||||
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
|
default_factory=_default_metrics
|
||||||
)
|
)
|
||||||
plots: list[ClassificationPlotConfig] = Field(default_factory=list)
|
plots: list[ClassificationPlotConfig] = Field(default_factory=list)
|
||||||
include_generics: bool = True
|
include_generics: bool = True
|
||||||
|
|||||||
@ -23,13 +23,15 @@ from batdetect2.postprocess.types import ClipDetections
|
|||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def _default_metrics() -> list[ClipClassificationMetricConfig]:
|
||||||
|
return [ClipClassificationAveragePrecisionConfig()]
|
||||||
|
|
||||||
|
|
||||||
class ClipClassificationTaskConfig(BaseTaskConfig):
|
class ClipClassificationTaskConfig(BaseTaskConfig):
|
||||||
name: Literal["clip_classification"] = "clip_classification"
|
name: Literal["clip_classification"] = "clip_classification"
|
||||||
prefix: str = "clip_classification"
|
prefix: str = "clip_classification"
|
||||||
metrics: list[ClipClassificationMetricConfig] = Field(
|
metrics: list[ClipClassificationMetricConfig] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=_default_metrics
|
||||||
ClipClassificationAveragePrecisionConfig(),
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
|
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|||||||
@ -22,13 +22,15 @@ from batdetect2.postprocess.types import ClipDetections
|
|||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def _default_metrics() -> list[ClipDetectionMetricConfig]:
|
||||||
|
return [ClipDetectionAveragePrecisionConfig()]
|
||||||
|
|
||||||
|
|
||||||
class ClipDetectionTaskConfig(BaseTaskConfig):
|
class ClipDetectionTaskConfig(BaseTaskConfig):
|
||||||
name: Literal["clip_detection"] = "clip_detection"
|
name: Literal["clip_detection"] = "clip_detection"
|
||||||
prefix: str = "clip_detection"
|
prefix: str = "clip_detection"
|
||||||
metrics: list[ClipDetectionMetricConfig] = Field(
|
metrics: list[ClipDetectionMetricConfig] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=_default_metrics
|
||||||
ClipDetectionAveragePrecisionConfig(),
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
|
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|||||||
@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
|
|||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def _default_metrics() -> list[DetectionMetricConfig]:
|
||||||
|
return [DetectionAveragePrecisionConfig()]
|
||||||
|
|
||||||
|
|
||||||
class DetectionTaskConfig(BaseSEDTaskConfig):
|
class DetectionTaskConfig(BaseSEDTaskConfig):
|
||||||
name: Literal["sound_event_detection"] = "sound_event_detection"
|
name: Literal["sound_event_detection"] = "sound_event_detection"
|
||||||
prefix: str = "detection"
|
prefix: str = "detection"
|
||||||
metrics: list[DetectionMetricConfig] = Field(
|
metrics: list[DetectionMetricConfig] = Field(
|
||||||
default_factory=lambda: [DetectionAveragePrecisionConfig()]
|
default_factory=_default_metrics
|
||||||
)
|
)
|
||||||
plots: list[DetectionPlotConfig] = Field(default_factory=list)
|
plots: list[DetectionPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|||||||
@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
|
|||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def _default_metrics() -> list[TopClassMetricConfig]:
|
||||||
|
return [TopClassAveragePrecisionConfig()]
|
||||||
|
|
||||||
|
|
||||||
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
||||||
name: Literal["top_class_detection"] = "top_class_detection"
|
name: Literal["top_class_detection"] = "top_class_detection"
|
||||||
prefix: str = "top_class"
|
prefix: str = "top_class"
|
||||||
metrics: list[TopClassMetricConfig] = Field(
|
metrics: list[TopClassMetricConfig] = Field(
|
||||||
default_factory=lambda: [TopClassAveragePrecisionConfig()]
|
default_factory=_default_metrics
|
||||||
)
|
)
|
||||||
plots: list[TopClassPlotConfig] = Field(default_factory=list)
|
plots: list[TopClassPlotConfig] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from batdetect2.inference.clips import get_clips_from_files
|
|||||||
from batdetect2.inference.config import InferenceConfig
|
from batdetect2.inference.config import InferenceConfig
|
||||||
from batdetect2.inference.dataset import build_inference_loader
|
from batdetect2.inference.dataset import build_inference_loader
|
||||||
from batdetect2.inference.lightning import InferenceModule
|
from batdetect2.inference.lightning import InferenceModule
|
||||||
from batdetect2.models import Model
|
from batdetect2.models.types import ModelProtocol
|
||||||
from batdetect2.outputs import (
|
from batdetect2.outputs import (
|
||||||
OutputsConfig,
|
OutputsConfig,
|
||||||
OutputTransformProtocol,
|
OutputTransformProtocol,
|
||||||
@ -22,7 +22,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
|||||||
|
|
||||||
|
|
||||||
def run_batch_inference(
|
def run_batch_inference(
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
clips: Sequence[data.Clip],
|
clips: Sequence[data.Clip],
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
roi_mapper: ROIMapperProtocol | None = None,
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
@ -86,7 +86,7 @@ def run_batch_inference(
|
|||||||
|
|
||||||
|
|
||||||
def process_file_list(
|
def process_file_list(
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
paths: Sequence[data.PathLike],
|
paths: Sequence[data.PathLike],
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
roi_mapper: ROIMapperProtocol | None = None,
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from lightning import LightningModule
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
||||||
from batdetect2.models import Model
|
from batdetect2.models.types import ModelProtocol
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
@ -13,7 +13,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
|||||||
class InferenceModule(LightningModule):
|
class InferenceModule(LightningModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
roi_mapper: ROIMapperProtocol | None = None,
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
output_transform: OutputTransformProtocol | None = None,
|
output_transform: OutputTransformProtocol | None = None,
|
||||||
|
|||||||
@ -62,7 +62,7 @@ from batdetect2.models.encoder import (
|
|||||||
build_encoder,
|
build_encoder,
|
||||||
)
|
)
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||||
from batdetect2.models.types import DetectionModel
|
from batdetect2.models.types import DetectorProtocol, ModelProtocol
|
||||||
from batdetect2.postprocess.config import PostprocessConfig
|
from batdetect2.postprocess.config import PostprocessConfig
|
||||||
from batdetect2.postprocess.types import (
|
from batdetect2.postprocess.types import (
|
||||||
ClipDetectionsTensor,
|
ClipDetectionsTensor,
|
||||||
@ -149,7 +149,7 @@ class Model(torch.nn.Module):
|
|||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
detector : DetectionModel
|
detector : DetectorProtocol
|
||||||
The neural network that processes spectrograms and produces raw
|
The neural network that processes spectrograms and produces raw
|
||||||
detection, classification, and bounding-box outputs.
|
detection, classification, and bounding-box outputs.
|
||||||
preprocessor : PreprocessorProtocol
|
preprocessor : PreprocessorProtocol
|
||||||
@ -164,19 +164,21 @@ class Model(torch.nn.Module):
|
|||||||
Size-dimension names corresponding to the model size outputs.
|
Size-dimension names corresponding to the model size outputs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
detector: DetectionModel
|
detector: DetectorProtocol
|
||||||
preprocessor: PreprocessorProtocol
|
preprocessor: PreprocessorProtocol
|
||||||
postprocessor: PostprocessorProtocol
|
postprocessor: PostprocessorProtocol
|
||||||
class_names: list[str]
|
class_names: list[str]
|
||||||
dimension_names: list[str]
|
dimension_names: list[str]
|
||||||
|
_config: dict[str, object]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
detector: DetectionModel,
|
detector: DetectorProtocol,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
postprocessor: PostprocessorProtocol,
|
postprocessor: PostprocessorProtocol,
|
||||||
class_names: list[str],
|
class_names: list[str],
|
||||||
dimension_names: list[str],
|
dimension_names: list[str],
|
||||||
|
config: dict[str, object],
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.detector = detector
|
self.detector = detector
|
||||||
@ -184,6 +186,12 @@ class Model(torch.nn.Module):
|
|||||||
self.postprocessor = postprocessor
|
self.postprocessor = postprocessor
|
||||||
self.class_names = class_names
|
self.class_names = class_names
|
||||||
self.dimension_names = dimension_names
|
self.dimension_names = dimension_names
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, object]:
|
||||||
|
"""Return the model configuration as plain JSON-serializable data."""
|
||||||
|
|
||||||
|
return dict(self._config)
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
||||||
"""Run the full detection pipeline on a waveform tensor.
|
"""Run the full detection pipeline on a waveform tensor.
|
||||||
@ -216,7 +224,7 @@ def build_model(
|
|||||||
dimension_names: list[str] | None = None,
|
dimension_names: list[str] | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
postprocessor: PostprocessorProtocol | None = None,
|
postprocessor: PostprocessorProtocol | None = None,
|
||||||
) -> Model:
|
) -> ModelProtocol:
|
||||||
"""Build a complete, ready-to-use BatDetect2 model.
|
"""Build a complete, ready-to-use BatDetect2 model.
|
||||||
|
|
||||||
Assembles a ``Model`` instance from a ``ModelConfig`` and optional
|
Assembles a ``Model`` instance from a ``ModelConfig`` and optional
|
||||||
@ -248,7 +256,7 @@ def build_model(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Model
|
ModelProtocol
|
||||||
A fully assembled ``Model`` instance ready for inference or
|
A fully assembled ``Model`` instance ready for inference or
|
||||||
training.
|
training.
|
||||||
"""
|
"""
|
||||||
@ -277,8 +285,8 @@ def build_model(
|
|||||||
config=config.postprocess,
|
config=config.postprocess,
|
||||||
)
|
)
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(class_names),
|
class_names=class_names,
|
||||||
num_sizes=len(dimension_names),
|
dimension_names=dimension_names,
|
||||||
config=config.architecture,
|
config=config.architecture,
|
||||||
)
|
)
|
||||||
return Model(
|
return Model(
|
||||||
@ -287,18 +295,19 @@ def build_model(
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
class_names=class_names,
|
class_names=class_names,
|
||||||
dimension_names=dimension_names,
|
dimension_names=dimension_names,
|
||||||
|
config=config.model_dump(mode="json"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_model_with_new_targets(
|
def build_model_with_new_targets(
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
roi_mapper: ROIMapperProtocol,
|
roi_mapper: ROIMapperProtocol,
|
||||||
) -> Model:
|
) -> ModelProtocol:
|
||||||
"""Build a new model with a different target set."""
|
"""Build a new model with a different target set."""
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(targets.class_names),
|
class_names=targets.class_names,
|
||||||
num_sizes=len(roi_mapper.dimension_names),
|
dimension_names=roi_mapper.dimension_names,
|
||||||
backbone=model.detector.backbone,
|
backbone=model.detector.backbone,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -308,4 +317,5 @@ def build_model_with_new_targets(
|
|||||||
preprocessor=model.preprocessor,
|
preprocessor=model.preprocessor,
|
||||||
class_names=targets.class_names,
|
class_names=targets.class_names,
|
||||||
dimension_names=roi_mapper.dimension_names,
|
dimension_names=roi_mapper.dimension_names,
|
||||||
|
config=model.get_config(),
|
||||||
)
|
)
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from typing import Annotated, Literal
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from loguru import logger
|
||||||
from pydantic import Field, TypeAdapter
|
from pydantic import Field, TypeAdapter
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -52,7 +53,7 @@ from batdetect2.models.encoder import (
|
|||||||
build_encoder,
|
build_encoder,
|
||||||
)
|
)
|
||||||
from batdetect2.models.types import (
|
from batdetect2.models.types import (
|
||||||
BackboneModel,
|
BackboneProtocol,
|
||||||
BottleneckProtocol,
|
BottleneckProtocol,
|
||||||
DecoderProtocol,
|
DecoderProtocol,
|
||||||
EncoderProtocol,
|
EncoderProtocol,
|
||||||
@ -104,7 +105,7 @@ class UNetBackboneConfig(BaseConfig):
|
|||||||
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||||
|
|
||||||
|
|
||||||
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
|
backbone_registry: Registry[BackboneProtocol, []] = Registry("backbone")
|
||||||
|
|
||||||
|
|
||||||
@add_import_config(backbone_registry)
|
@add_import_config(backbone_registry)
|
||||||
@ -118,7 +119,7 @@ class BackboneImportConfig(ImportConfig):
|
|||||||
name: Literal["import"] = "import"
|
name: Literal["import"] = "import"
|
||||||
|
|
||||||
|
|
||||||
class UNetBackbone(BackboneModel):
|
class UNetBackbone(torch.nn.Module):
|
||||||
"""U-Net-style encoder-decoder backbone network.
|
"""U-Net-style encoder-decoder backbone network.
|
||||||
|
|
||||||
Combines an encoder, a bottleneck, and a decoder into a single module
|
Combines an encoder, a bottleneck, and a decoder into a single module
|
||||||
@ -225,7 +226,7 @@ class UNetBackbone(BackboneModel):
|
|||||||
|
|
||||||
@backbone_registry.register(UNetBackboneConfig)
|
@backbone_registry.register(UNetBackboneConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(config: UNetBackboneConfig) -> BackboneModel:
|
def from_config(config: UNetBackboneConfig) -> BackboneProtocol:
|
||||||
encoder = build_encoder(
|
encoder = build_encoder(
|
||||||
in_channels=config.in_channels,
|
in_channels=config.in_channels,
|
||||||
input_height=config.input_height,
|
input_height=config.input_height,
|
||||||
@ -266,7 +267,7 @@ BackboneConfig = Annotated[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
def build_backbone(config: BackboneConfig | None = None) -> BackboneProtocol:
|
||||||
"""Build a backbone network from configuration.
|
"""Build a backbone network from configuration.
|
||||||
|
|
||||||
Looks up the backbone class corresponding to ``config.name`` in the
|
Looks up the backbone class corresponding to ``config.name`` in the
|
||||||
@ -282,10 +283,14 @@ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
BackboneModel
|
BackboneProtocol
|
||||||
An initialised backbone module.
|
An initialised backbone module.
|
||||||
"""
|
"""
|
||||||
config = config or UNetBackboneConfig()
|
config = config or UNetBackboneConfig()
|
||||||
|
logger.opt(lazy=True).debug(
|
||||||
|
"Building model backbone with config: \n{}",
|
||||||
|
lambda: config.to_yaml_string(),
|
||||||
|
)
|
||||||
return backbone_registry.build(config)
|
return backbone_registry.build(config)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
BIN
src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt
Normal file
BIN
src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt
Normal file
Binary file not shown.
@ -6,8 +6,8 @@ bounding-box size regression.
|
|||||||
|
|
||||||
Components
|
Components
|
||||||
----------
|
----------
|
||||||
- ``Detector`` – the ``torch.nn.Module`` that wires together a backbone
|
- ``Detector`` - the ``torch.nn.Module`` that wires together a backbone
|
||||||
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
|
(``BackboneProtocol``) with a ``ClassifierHead`` and a ``BBoxHead`` to
|
||||||
produce a ``ModelOutput`` tuple from an input spectrogram.
|
produce a ``ModelOutput`` tuple from an input spectrogram.
|
||||||
- ``build_detector`` – factory function that builds a ready-to-use
|
- ``build_detector`` – factory function that builds a ready-to-use
|
||||||
``Detector`` from a backbone configuration and a target class count.
|
``Detector`` from a backbone configuration and a target class count.
|
||||||
@ -18,15 +18,16 @@ preprocessing and output postprocessing are handled by
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from batdetect2.models.backbones import (
|
from batdetect2.models.backbones import BackboneConfig, build_backbone
|
||||||
BackboneConfig,
|
|
||||||
UNetBackboneConfig,
|
|
||||||
build_backbone,
|
|
||||||
)
|
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||||
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
from batdetect2.models.types import (
|
||||||
|
BackboneProtocol,
|
||||||
|
ClassifierHeadProtocol,
|
||||||
|
DetectorProtocol,
|
||||||
|
ModelOutput,
|
||||||
|
SizeHeadProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Detector",
|
"Detector",
|
||||||
@ -34,7 +35,7 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class Detector(DetectionModel):
|
class Detector(torch.nn.Module):
|
||||||
"""Complete BatDetect2 detection and classification model.
|
"""Complete BatDetect2 detection and classification model.
|
||||||
|
|
||||||
Combines a backbone feature extractor with two prediction heads:
|
Combines a backbone feature extractor with two prediction heads:
|
||||||
@ -51,7 +52,7 @@ class Detector(DetectionModel):
|
|||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
backbone : BackboneModel
|
backbone : BackboneProtocol
|
||||||
The feature extraction backbone.
|
The feature extraction backbone.
|
||||||
num_classes : int
|
num_classes : int
|
||||||
Number of target classes (inferred from the classifier head).
|
Number of target classes (inferred from the classifier head).
|
||||||
@ -61,13 +62,13 @@ class Detector(DetectionModel):
|
|||||||
Produces duration and bandwidth predictions from backbone features.
|
Produces duration and bandwidth predictions from backbone features.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
backbone: BackboneModel
|
backbone: BackboneProtocol
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
backbone: BackboneModel,
|
backbone: BackboneProtocol,
|
||||||
classifier_head: ClassifierHead,
|
classifier_head: ClassifierHeadProtocol,
|
||||||
bbox_head: BBoxHead,
|
size_head: SizeHeadProtocol,
|
||||||
):
|
):
|
||||||
"""Initialise the Detector model.
|
"""Initialise the Detector model.
|
||||||
|
|
||||||
@ -76,7 +77,7 @@ class Detector(DetectionModel):
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
backbone : BackboneModel
|
backbone : BackboneProtocol
|
||||||
An initialised backbone module (e.g. built by
|
An initialised backbone module (e.g. built by
|
||||||
``build_backbone``).
|
``build_backbone``).
|
||||||
classifier_head : ClassifierHead
|
classifier_head : ClassifierHead
|
||||||
@ -90,7 +91,7 @@ class Detector(DetectionModel):
|
|||||||
self.backbone = backbone
|
self.backbone = backbone
|
||||||
self.num_classes = classifier_head.num_classes
|
self.num_classes = classifier_head.num_classes
|
||||||
self.classifier_head = classifier_head
|
self.classifier_head = classifier_head
|
||||||
self.bbox_head = bbox_head
|
self.size_head = size_head
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||||
"""Run the complete detection model on an input spectrogram.
|
"""Run the complete detection model on an input spectrogram.
|
||||||
@ -125,7 +126,7 @@ class Detector(DetectionModel):
|
|||||||
features = self.backbone(spec)
|
features = self.backbone(spec)
|
||||||
classification = self.classifier_head(features)
|
classification = self.classifier_head(features)
|
||||||
detection = classification.sum(dim=1, keepdim=True)
|
detection = classification.sum(dim=1, keepdim=True)
|
||||||
size_preds = self.bbox_head(features)
|
size_preds = self.size_head(features)
|
||||||
return ModelOutput(
|
return ModelOutput(
|
||||||
detection_probs=detection,
|
detection_probs=detection,
|
||||||
size_preds=size_preds,
|
size_preds=size_preds,
|
||||||
@ -135,11 +136,11 @@ class Detector(DetectionModel):
|
|||||||
|
|
||||||
|
|
||||||
def build_detector(
|
def build_detector(
|
||||||
num_classes: int,
|
class_names: list[str],
|
||||||
num_sizes: int = 2,
|
dimension_names: list[str],
|
||||||
config: BackboneConfig | None = None,
|
config: BackboneConfig | None = None,
|
||||||
backbone: BackboneModel | None = None,
|
backbone: BackboneProtocol | None = None,
|
||||||
) -> DetectionModel:
|
) -> DetectorProtocol:
|
||||||
"""Build a complete BatDetect2 detection model.
|
"""Build a complete BatDetect2 detection model.
|
||||||
|
|
||||||
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
|
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
|
||||||
@ -158,7 +159,7 @@ def build_detector(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
DetectionModel
|
DetectorProtocol
|
||||||
An initialised ``Detector`` instance ready for training or
|
An initialised ``Detector`` instance ready for training or
|
||||||
inference.
|
inference.
|
||||||
|
|
||||||
@ -168,24 +169,18 @@ def build_detector(
|
|||||||
If ``num_classes`` is not positive, or if the backbone
|
If ``num_classes`` is not positive, or if the backbone
|
||||||
configuration is invalid.
|
configuration is invalid.
|
||||||
"""
|
"""
|
||||||
if backbone is None:
|
backbone = backbone or build_backbone(config=config)
|
||||||
config = config or UNetBackboneConfig()
|
|
||||||
logger.opt(lazy=True).debug(
|
|
||||||
"Building model with config: \n{}",
|
|
||||||
lambda: config.to_yaml_string(), # type: ignore
|
|
||||||
)
|
|
||||||
backbone = build_backbone(config=config)
|
|
||||||
|
|
||||||
classifier_head = ClassifierHead(
|
classifier_head = ClassifierHead(
|
||||||
num_classes=num_classes,
|
class_names=class_names,
|
||||||
in_channels=backbone.out_channels,
|
in_channels=backbone.out_channels,
|
||||||
)
|
)
|
||||||
bbox_head = BBoxHead(
|
bbox_head = BBoxHead(
|
||||||
in_channels=backbone.out_channels,
|
in_channels=backbone.out_channels,
|
||||||
num_sizes=num_sizes,
|
dimension_names=dimension_names,
|
||||||
)
|
)
|
||||||
return Detector(
|
return Detector(
|
||||||
backbone=backbone,
|
backbone=backbone,
|
||||||
classifier_head=classifier_head,
|
classifier_head=classifier_head,
|
||||||
bbox_head=bbox_head,
|
size_head=bbox_head,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -54,12 +54,14 @@ class ClassifierHead(nn.Module):
|
|||||||
1×1 convolution with ``num_classes + 1`` output channels.
|
1×1 convolution with ``num_classes + 1`` output channels.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_classes: int, in_channels: int):
|
def __init__(self, class_names: list[str], in_channels: int):
|
||||||
"""Initialise the ClassifierHead."""
|
"""Initialise the ClassifierHead."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.num_classes = num_classes
|
self.class_names = class_names
|
||||||
|
self.num_classes = len(class_names)
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
|
|
||||||
self.classifier = nn.Conv2d(
|
self.classifier = nn.Conv2d(
|
||||||
self.in_channels,
|
self.in_channels,
|
||||||
self.num_classes + 1,
|
self.num_classes + 1,
|
||||||
@ -165,11 +167,12 @@ class BBoxHead(nn.Module):
|
|||||||
1×1 convolution with 2 output channels (duration, bandwidth).
|
1×1 convolution with 2 output channels (duration, bandwidth).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, in_channels: int, num_sizes: int = 2):
|
def __init__(self, dimension_names: list[str], in_channels: int):
|
||||||
"""Initialise the BBoxHead."""
|
"""Initialise the BBoxHead."""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.num_sizes = num_sizes
|
self.dimension_names = dimension_names
|
||||||
|
self.num_sizes = len(dimension_names)
|
||||||
|
|
||||||
self.bbox = nn.Conv2d(
|
self.bbox = nn.Conv2d(
|
||||||
in_channels=self.in_channels,
|
in_channels=self.in_channels,
|
||||||
|
|||||||
@ -1,21 +1,42 @@
|
|||||||
from abc import ABC, abstractmethod
|
from typing import Any, NamedTuple, Protocol
|
||||||
from typing import NamedTuple, Protocol
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from batdetect2.postprocess.types import PostprocessorProtocol
|
||||||
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BackboneModel",
|
"BackboneProtocol",
|
||||||
"BlockProtocol",
|
"BlockProtocol",
|
||||||
"BottleneckProtocol",
|
"BottleneckProtocol",
|
||||||
|
"ClassifierHeadProtocol",
|
||||||
"DecoderProtocol",
|
"DecoderProtocol",
|
||||||
"DetectionModel",
|
"DetectorProtocol",
|
||||||
"EncoderDecoderModel",
|
|
||||||
"EncoderProtocol",
|
"EncoderProtocol",
|
||||||
"ModelOutput",
|
"ModelOutput",
|
||||||
|
"ModelProtocol",
|
||||||
|
"ModuleProtocol",
|
||||||
|
"SizeHeadProtocol",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class BlockProtocol(Protocol):
|
class ModuleProtocol(Protocol):
|
||||||
|
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||||
|
|
||||||
|
def train(self, mode: bool = True) -> torch.nn.Module: ...
|
||||||
|
|
||||||
|
def eval(self) -> torch.nn.Module: ...
|
||||||
|
|
||||||
|
def state_dict(
|
||||||
|
self, *args: Any, **kwargs: Any
|
||||||
|
) -> dict[str, torch.Tensor]: ...
|
||||||
|
|
||||||
|
def load_state_dict(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||||
|
|
||||||
|
def parameters(self) -> Any: ...
|
||||||
|
|
||||||
|
|
||||||
|
class BlockProtocol(ModuleProtocol, Protocol):
|
||||||
in_channels: int
|
in_channels: int
|
||||||
out_channels: int
|
out_channels: int
|
||||||
|
|
||||||
@ -24,7 +45,7 @@ class BlockProtocol(Protocol):
|
|||||||
def get_output_height(self, input_height: int) -> int: ...
|
def get_output_height(self, input_height: int) -> int: ...
|
||||||
|
|
||||||
|
|
||||||
class EncoderProtocol(Protocol):
|
class EncoderProtocol(ModuleProtocol, Protocol):
|
||||||
in_channels: int
|
in_channels: int
|
||||||
out_channels: int
|
out_channels: int
|
||||||
input_height: int
|
input_height: int
|
||||||
@ -33,7 +54,7 @@ class EncoderProtocol(Protocol):
|
|||||||
def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
|
def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
|
||||||
|
|
||||||
|
|
||||||
class BottleneckProtocol(Protocol):
|
class BottleneckProtocol(ModuleProtocol, Protocol):
|
||||||
in_channels: int
|
in_channels: int
|
||||||
out_channels: int
|
out_channels: int
|
||||||
input_height: int
|
input_height: int
|
||||||
@ -41,7 +62,7 @@ class BottleneckProtocol(Protocol):
|
|||||||
def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
|
def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
class DecoderProtocol(Protocol):
|
class DecoderProtocol(ModuleProtocol, Protocol):
|
||||||
in_channels: int
|
in_channels: int
|
||||||
out_channels: int
|
out_channels: int
|
||||||
input_height: int
|
input_height: int
|
||||||
@ -62,29 +83,42 @@ class ModelOutput(NamedTuple):
|
|||||||
features: torch.Tensor
|
features: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class BackboneModel(ABC, torch.nn.Module):
|
class BackboneProtocol(ModuleProtocol, Protocol):
|
||||||
input_height: int
|
input_height: int
|
||||||
out_channels: int
|
out_channels: int
|
||||||
|
|
||||||
@abstractmethod
|
def forward(self, spec: torch.Tensor) -> torch.Tensor: ...
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class EncoderDecoderModel(BackboneModel):
|
class ClassifierHeadProtocol(ModuleProtocol, Protocol):
|
||||||
bottleneck_channels: int
|
num_classes: int
|
||||||
|
in_channels: int
|
||||||
|
class_names: list[str]
|
||||||
|
|
||||||
@abstractmethod
|
def forward(self, features: torch.Tensor) -> torch.Tensor: ...
|
||||||
def encode(self, spec: torch.Tensor) -> torch.Tensor: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def decode(self, encoded: torch.Tensor) -> torch.Tensor: ...
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionModel(ABC, torch.nn.Module):
|
class SizeHeadProtocol(ModuleProtocol, Protocol):
|
||||||
backbone: BackboneModel
|
in_channels: int
|
||||||
classifier_head: torch.nn.Module
|
num_sizes: int
|
||||||
bbox_head: torch.nn.Module
|
dimension_names: list[str]
|
||||||
|
|
||||||
|
def forward(self, features: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
|
class DetectorProtocol(ModuleProtocol, Protocol):
|
||||||
|
backbone: BackboneProtocol
|
||||||
|
classifier_head: ClassifierHeadProtocol
|
||||||
|
size_head: SizeHeadProtocol
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput: ...
|
def forward(self, spec: torch.Tensor) -> ModelOutput: ...
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProtocol(ModuleProtocol, Protocol):
|
||||||
|
detector: DetectorProtocol
|
||||||
|
preprocessor: PreprocessorProtocol
|
||||||
|
postprocessor: PostprocessorProtocol
|
||||||
|
class_names: list[str]
|
||||||
|
dimension_names: list[str]
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]: ...
|
||||||
|
|||||||
@ -154,17 +154,18 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
top_class_index = int(np.argmax(prediction.class_scores))
|
top_class_index = int(np.argmax(prediction.class_scores))
|
||||||
top_class_score = float(prediction.class_scores[top_class_index])
|
top_class_score = float(prediction.class_scores[top_class_index])
|
||||||
top_class = self.get_class_name(top_class_index)
|
top_class = self.get_class_name(top_class_index)
|
||||||
return Annotation(
|
annotation: Annotation = {
|
||||||
start_time=start_time,
|
"start_time": start_time,
|
||||||
end_time=end_time,
|
"end_time": end_time,
|
||||||
low_freq=low_freq,
|
"low_freq": low_freq,
|
||||||
high_freq=high_freq,
|
"high_freq": high_freq,
|
||||||
class_prob=top_class_score,
|
"class_prob": top_class_score,
|
||||||
det_prob=float(prediction.detection_score),
|
"det_prob": float(prediction.detection_score),
|
||||||
individual="",
|
"individual": "",
|
||||||
event=self.event_name,
|
"event": self.event_name,
|
||||||
**{"class": top_class},
|
"class": top_class,
|
||||||
)
|
}
|
||||||
|
return annotation
|
||||||
|
|
||||||
@output_formatters.register(BatDetect2OutputConfig)
|
@output_formatters.register(BatDetect2OutputConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -26,6 +26,13 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _default_spectrogram_transforms() -> list[SpectrogramTransform]:
|
||||||
|
return [
|
||||||
|
PcenConfig(),
|
||||||
|
SpectralMeanSubtractionConfig(),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class PreprocessingConfig(BaseConfig):
|
class PreprocessingConfig(BaseConfig):
|
||||||
"""Unified configuration for the audio preprocessing pipeline.
|
"""Unified configuration for the audio preprocessing pipeline.
|
||||||
|
|
||||||
@ -58,10 +65,7 @@ class PreprocessingConfig(BaseConfig):
|
|||||||
audio_transforms: List[AudioTransform] = Field(default_factory=list)
|
audio_transforms: List[AudioTransform] = Field(default_factory=list)
|
||||||
|
|
||||||
spectrogram_transforms: List[SpectrogramTransform] = Field(
|
spectrogram_transforms: List[SpectrogramTransform] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=_default_spectrogram_transforms
|
||||||
PcenConfig(),
|
|
||||||
SpectralMeanSubtractionConfig(),
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
stft: STFTConfig = Field(default_factory=STFTConfig)
|
stft: STFTConfig = Field(default_factory=STFTConfig)
|
||||||
|
|||||||
@ -71,7 +71,7 @@ class TargetClassConfig(BaseConfig):
|
|||||||
|
|
||||||
DEFAULT_DETECTION_CLASS = TargetClassConfig(
|
DEFAULT_DETECTION_CLASS = TargetClassConfig(
|
||||||
name="bat",
|
name="bat",
|
||||||
match_if=AllOfConfig( # ty: ignore[unknown-argument]
|
match_if=AllOfConfig(
|
||||||
conditions=[
|
conditions=[
|
||||||
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
|
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
|
||||||
NotConfig(
|
NotConfig(
|
||||||
|
|||||||
@ -1,4 +1,7 @@
|
|||||||
from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR
|
from batdetect2.train.checkpoints import (
|
||||||
|
DEFAULT_CHECKPOINT_DIR,
|
||||||
|
resolve_checkpoint_path,
|
||||||
|
)
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import TrainingConfig
|
||||||
from batdetect2.train.lightning import (
|
from batdetect2.train.lightning import (
|
||||||
TrainingModule,
|
TrainingModule,
|
||||||
@ -26,5 +29,6 @@ __all__ = [
|
|||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
"build_trainer",
|
"build_trainer",
|
||||||
"load_model_from_checkpoint",
|
"load_model_from_checkpoint",
|
||||||
|
"resolve_checkpoint_path",
|
||||||
"run_train",
|
"run_train",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -2,15 +2,31 @@ from pathlib import Path
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CheckpointConfig",
|
"CheckpointConfig",
|
||||||
|
"DEFAULT_CHECKPOINT",
|
||||||
"build_checkpoint_callback",
|
"build_checkpoint_callback",
|
||||||
|
"get_bundled_checkpoint_names",
|
||||||
|
"resolve_checkpoint_path",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
PACKAGE_ROOT = Path(__file__).resolve().parents[1]
|
||||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||||
|
DEFAULT_CHECKPOINT = "uk_same"
|
||||||
|
CHECKPOINT_ALIASES = {
|
||||||
|
DEFAULT_CHECKPOINT: PACKAGE_ROOT
|
||||||
|
/ "models"
|
||||||
|
/ "checkpoints"
|
||||||
|
/ "batdetect2_uk_same.ckpt",
|
||||||
|
"batdetect2_uk_same": PACKAGE_ROOT
|
||||||
|
/ "models"
|
||||||
|
/ "checkpoints"
|
||||||
|
/ "batdetect2_uk_same.ckpt",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class CheckpointConfig(BaseConfig):
|
class CheckpointConfig(BaseConfig):
|
||||||
@ -18,6 +34,8 @@ class CheckpointConfig(BaseConfig):
|
|||||||
monitor: str | None = None
|
monitor: str | None = None
|
||||||
mode: str = "max"
|
mode: str = "max"
|
||||||
save_top_k: int = 1
|
save_top_k: int = 1
|
||||||
|
# Save distributable inference checkpoints by default.
|
||||||
|
save_weights_only: bool = True
|
||||||
filename: str | None = None
|
filename: str | None = None
|
||||||
save_last: bool | Literal["link"] = "link"
|
save_last: bool | Literal["link"] = "link"
|
||||||
every_n_epochs: int | None = 1
|
every_n_epochs: int | None = 1
|
||||||
@ -47,9 +65,86 @@ def build_checkpoint_callback(
|
|||||||
return ModelCheckpoint(
|
return ModelCheckpoint(
|
||||||
dirpath=str(checkpoint_dir),
|
dirpath=str(checkpoint_dir),
|
||||||
save_top_k=config.save_top_k,
|
save_top_k=config.save_top_k,
|
||||||
|
save_weights_only=config.save_weights_only,
|
||||||
monitor=config.monitor,
|
monitor=config.monitor,
|
||||||
mode=config.mode,
|
mode=config.mode,
|
||||||
filename=config.filename,
|
filename=config.filename,
|
||||||
save_last=config.save_last,
|
save_last=config.save_last,
|
||||||
every_n_epochs=config.every_n_epochs,
|
every_n_epochs=config.every_n_epochs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_bundled_checkpoint_names() -> tuple[str, ...]:
|
||||||
|
"""Return the supported bundled checkpoint aliases."""
|
||||||
|
return tuple(CHECKPOINT_ALIASES.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_checkpoint_from_huggingface(path: str) -> Path:
|
||||||
|
"""Resolve a Hugging Face checkpoint URI."""
|
||||||
|
try:
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
except ImportError as error:
|
||||||
|
raise ValueError(
|
||||||
|
"Hugging Face checkpoint support is not installed. "
|
||||||
|
"Install it with `pip install batdetect2[huggingface]`."
|
||||||
|
) from error
|
||||||
|
|
||||||
|
repo_id, filename = _parse_huggingface_uri(path)
|
||||||
|
return Path(hf_hub_download(repo_id=repo_id, filename=filename))
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_checkpoint_path(path: PathLike | str | None = None) -> Path:
    """Turn a checkpoint reference into a filesystem path.

    Parameters
    ----------
    path : PathLike | str | None
        One of: a local checkpoint path, a checkpoint alias, or a Hugging
        Face URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``.
        When omitted, ``DEFAULT_CHECKPOINT`` is used.

    Returns
    -------
    Path
        Path to the checkpoint. Local paths are returned resolved; alias
        targets are returned as stored in ``CHECKPOINT_ALIASES`` without an
        existence check.

    Raises
    ------
    FileNotFoundError
        If the reference is neither a Hugging Face URI nor an alias and no
        file or directory exists at that location.
    """
    target = DEFAULT_CHECKPOINT if path is None else path

    # URI and alias handling only applies to plain strings; path-like
    # objects always fall through to the filesystem lookup below.
    if isinstance(target, str):
        if target.startswith("hf://"):
            return resolve_checkpoint_from_huggingface(target)
        if target in CHECKPOINT_ALIASES:
            return Path(CHECKPOINT_ALIASES[target])

    candidate = Path(target)
    if not candidate.exists():
        aliases = ", ".join(get_bundled_checkpoint_names())
        raise FileNotFoundError(
            f"Checkpoint not found: {candidate}. "
            "Expected a local path, a checkpoint alias "
            f"({aliases}), or a Hugging Face URI."
        )

    return candidate.resolve()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_huggingface_uri(uri: str) -> tuple[str, str]:
|
||||||
|
prefix = "hf://"
|
||||||
|
if not uri.startswith(prefix):
|
||||||
|
raise ValueError(
|
||||||
|
"Hugging Face checkpoint URIs must start with 'hf://'."
|
||||||
|
)
|
||||||
|
|
||||||
|
without_prefix = uri.removeprefix(prefix).strip("/")
|
||||||
|
parts = without_prefix.split("/")
|
||||||
|
|
||||||
|
if len(parts) < 3:
|
||||||
|
raise ValueError(
|
||||||
|
"Hugging Face checkpoint URIs must be in the form "
|
||||||
|
"'hf://owner/repo/path/to/checkpoint.ckpt'."
|
||||||
|
)
|
||||||
|
|
||||||
|
repo_id = "/".join(parts[:2])
|
||||||
|
filename = "/".join(parts[2:])
|
||||||
|
return repo_id, filename
|
||||||
|
|||||||
@ -1,11 +1,13 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
|
import torch
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.models import Model, ModelConfig, build_model
|
from batdetect2.models import ModelConfig, build_model
|
||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.models.types import ModelOutput, ModelProtocol
|
||||||
from batdetect2.targets import TargetConfig
|
from batdetect2.targets import TargetConfig
|
||||||
|
from batdetect2.train.checkpoints import resolve_checkpoint_path
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import TrainingConfig
|
||||||
from batdetect2.train.losses import build_loss
|
from batdetect2.train.losses import build_loss
|
||||||
from batdetect2.train.optimizers import build_optimizer
|
from batdetect2.train.optimizers import build_optimizer
|
||||||
@ -19,7 +21,7 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class TrainingModule(L.LightningModule):
|
class TrainingModule(L.LightningModule):
|
||||||
model: Model
|
model: ModelProtocol
|
||||||
loss: LossProtocol
|
loss: LossProtocol
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -30,7 +32,7 @@ class TrainingModule(L.LightningModule):
|
|||||||
dimension_names: list[str] | None = None,
|
dimension_names: list[str] | None = None,
|
||||||
train_config: dict | None = None,
|
train_config: dict | None = None,
|
||||||
loss: LossProtocol | None = None,
|
loss: LossProtocol | None = None,
|
||||||
model: Model | None = None,
|
model: ModelProtocol | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -130,23 +132,27 @@ class StoredConfig:
|
|||||||
|
|
||||||
|
|
||||||
def load_model_from_checkpoint(
|
def load_model_from_checkpoint(
|
||||||
path: PathLike,
|
path: PathLike | str | None = None,
|
||||||
) -> tuple[Model, StoredConfig]:
|
) -> tuple[ModelProtocol, StoredConfig]:
|
||||||
"""Load a model and its configuration from a Lightning checkpoint.
|
"""Load a model and its configuration from a Lightning checkpoint.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
path : PathLike
|
path : PathLike | str | None
|
||||||
Path to a ``.ckpt`` file produced by the BatDetect2 training
|
Path to a ``.ckpt`` file produced by the BatDetect2 training
|
||||||
pipeline.
|
pipeline. If omitted, the default bundled checkpoint is used.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
tuple[Model, ModelConfig]
|
tuple[ModelProtocol, ModelConfig]
|
||||||
The restored ``Model`` instance and the ``ModelConfig`` that
|
The restored ``Model`` instance and the ``ModelConfig`` that
|
||||||
describes its architecture, preprocessing, and postprocessing.
|
describes its architecture, preprocessing, and postprocessing.
|
||||||
"""
|
"""
|
||||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
resolved_path = resolve_checkpoint_path(path)
|
||||||
|
module = TrainingModule.load_from_checkpoint(
|
||||||
|
resolved_path,
|
||||||
|
map_location=torch.device("cpu"),
|
||||||
|
)
|
||||||
training_config = TrainingConfig.model_validate(module.train_config)
|
training_config = TrainingConfig.model_validate(module.train_config)
|
||||||
model_config = ModelConfig.model_validate(module.model_config)
|
model_config = ModelConfig.model_validate(module.model_config)
|
||||||
targets_config = TargetConfig.model_validate(module.targets_config)
|
targets_config = TargetConfig.model_validate(module.targets_config)
|
||||||
@ -163,7 +169,7 @@ def build_training_module(
|
|||||||
class_names: list[str] | None = None,
|
class_names: list[str] | None = None,
|
||||||
dimension_names: list[str] | None = None,
|
dimension_names: list[str] | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
model: Model | None = None,
|
model: ModelProtocol | None = None,
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
model_config = ModelConfig()
|
model_config = ModelConfig()
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from lightning.pytorch.loggers import Logger
|
from lightning.pytorch.loggers import Logger
|
||||||
@ -28,7 +29,7 @@ __all__ = [
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class TrainLoggingContext:
|
class TrainLoggingContext:
|
||||||
model_config: ModelConfig
|
model_config: dict[str, Any]
|
||||||
train_config: TrainingConfig
|
train_config: TrainingConfig
|
||||||
audio_config: AudioConfig
|
audio_config: AudioConfig
|
||||||
targets: TargetProtocol
|
targets: TargetProtocol
|
||||||
@ -49,9 +50,10 @@ class ConfigHyperparameterLogging:
|
|||||||
artifact_path: Path,
|
artifact_path: Path,
|
||||||
context: TrainLoggingContext,
|
context: TrainLoggingContext,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
model_config = ModelConfig.model_validate(context.model_config)
|
||||||
logger.log_hyperparams(
|
logger.log_hyperparams(
|
||||||
{
|
{
|
||||||
"model": context.model_config.model_dump(
|
"model": model_config.model_dump(
|
||||||
mode="json",
|
mode="json",
|
||||||
exclude_none=True,
|
exclude_none=True,
|
||||||
),
|
),
|
||||||
|
|||||||
@ -15,7 +15,8 @@ from batdetect2.logging import (
|
|||||||
TensorBoardLoggerConfig,
|
TensorBoardLoggerConfig,
|
||||||
build_logger,
|
build_logger,
|
||||||
)
|
)
|
||||||
from batdetect2.models import Model, ModelConfig, build_model
|
from batdetect2.models import ModelConfig, build_model
|
||||||
|
from batdetect2.models.types import ModelProtocol
|
||||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
ROIMapperProtocol,
|
ROIMapperProtocol,
|
||||||
@ -50,14 +51,13 @@ DEFAULT_LOG_DIR = Path("outputs") / "logs"
|
|||||||
def run_train(
|
def run_train(
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||||
model: Model | None = None,
|
model: ModelProtocol | None = None,
|
||||||
targets: Optional["TargetProtocol"] = None,
|
targets: Optional["TargetProtocol"] = None,
|
||||||
roi_mapper: Optional["ROIMapperProtocol"] = None,
|
roi_mapper: Optional["ROIMapperProtocol"] = None,
|
||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
audio_loader: Optional["AudioLoader"] = None,
|
audio_loader: Optional["AudioLoader"] = None,
|
||||||
labeller: Optional["ClipLabeller"] = None,
|
labeller: Optional["ClipLabeller"] = None,
|
||||||
audio_config: Optional[AudioConfig] = None,
|
audio_config: Optional[AudioConfig] = None,
|
||||||
model_config: Optional[ModelConfig] = None,
|
|
||||||
targets_config: TargetConfig | None = None,
|
targets_config: TargetConfig | None = None,
|
||||||
train_config: Optional[TrainingConfig] = None,
|
train_config: Optional[TrainingConfig] = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
@ -75,7 +75,11 @@ def run_train(
|
|||||||
if seed is not None:
|
if seed is not None:
|
||||||
seed_everything(seed)
|
seed_everything(seed)
|
||||||
|
|
||||||
model_config = model_config or ModelConfig()
|
model_config = (
|
||||||
|
ModelConfig()
|
||||||
|
if model is None
|
||||||
|
else ModelConfig.model_validate(model.get_config())
|
||||||
|
)
|
||||||
targets_config = targets_config or TargetConfig()
|
targets_config = targets_config or TargetConfig()
|
||||||
audio_config = audio_config or AudioConfig()
|
audio_config = audio_config or AudioConfig()
|
||||||
train_config = train_config or TrainingConfig()
|
train_config = train_config or TrainingConfig()
|
||||||
@ -172,7 +176,7 @@ def run_train(
|
|||||||
root_artifact_path.mkdir(parents=True, exist_ok=True)
|
root_artifact_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
logging_context = TrainLoggingContext(
|
logging_context = TrainLoggingContext(
|
||||||
model_config=model_config,
|
model_config=model_config.model_dump(mode="json"),
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
audio_config=audio_config,
|
audio_config=audio_config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
@ -214,7 +218,7 @@ def run_train(
|
|||||||
|
|
||||||
|
|
||||||
def _validate_model_compatibility(
|
def _validate_model_compatibility(
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
class_names: list[str],
|
class_names: list[str],
|
||||||
dimension_names: list[str],
|
dimension_names: list[str],
|
||||||
|
|||||||
@ -200,13 +200,14 @@ def test_user_can_read_extracted_features_per_detection(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""User story: inspect extracted feature vectors per detection."""
|
"""User story: inspect extracted feature vectors per detection."""
|
||||||
|
|
||||||
|
# Given
|
||||||
prediction = api_v2.process_file(example_audio_files[0])
|
prediction = api_v2.process_file(example_audio_files[0])
|
||||||
|
|
||||||
assert len(prediction.detections) > 0
|
# When
|
||||||
|
feature_vectors = [det.features for det in prediction.detections]
|
||||||
|
|
||||||
feature_vectors = [
|
# Then
|
||||||
api_v2.get_detection_features(det) for det in prediction.detections
|
assert len(prediction.detections) > 0
|
||||||
]
|
|
||||||
assert len(feature_vectors) == len(prediction.detections)
|
assert len(feature_vectors) == len(prediction.detections)
|
||||||
assert all(vec.ndim == 1 for vec in feature_vectors)
|
assert all(vec.ndim == 1 for vec in feature_vectors)
|
||||||
assert all(vec.size > 0 for vec in feature_vectors)
|
assert all(vec.size > 0 for vec in feature_vectors)
|
||||||
@ -299,14 +300,20 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
|||||||
value,
|
value,
|
||||||
)
|
)
|
||||||
|
|
||||||
for key, value in source_detector.bbox_head.state_dict().items():
|
for key, value in source_detector.size_head.state_dict().items():
|
||||||
assert key in detector.bbox_head.state_dict()
|
assert key in detector.size_head.state_dict()
|
||||||
torch.testing.assert_close(
|
torch.testing.assert_close(
|
||||||
detector.bbox_head.state_dict()[key],
|
detector.size_head.state_dict()[key],
|
||||||
value,
|
value,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_from_checkpoint_defaults_to_bundled_model() -> None:
    """Calling ``from_checkpoint`` with no arguments yields a usable model."""
    default_api = BatDetect2API.from_checkpoint()
    class_names = default_api.model.class_names

    # A non-empty class list shows that a checkpoint was actually loaded.
    assert class_names
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
||||||
api_v2: BatDetect2API,
|
api_v2: BatDetect2API,
|
||||||
|
|||||||
@ -18,7 +18,7 @@ def test_user_can_finetune_only_heads(
|
|||||||
|
|
||||||
api = BatDetect2API.from_config()
|
api = BatDetect2API.from_config()
|
||||||
source_classifier_head = api.model.detector.classifier_head
|
source_classifier_head = api.model.detector.classifier_head
|
||||||
source_bbox_head = api.model.detector.bbox_head
|
source_size_head = api.model.detector.size_head
|
||||||
source_backbone = api.model.detector.backbone
|
source_backbone = api.model.detector.backbone
|
||||||
finetune_dir = tmp_path / "heads_only"
|
finetune_dir = tmp_path / "heads_only"
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ def test_user_can_finetune_only_heads(
|
|||||||
|
|
||||||
backbone_params = list(detector.backbone.parameters())
|
backbone_params = list(detector.backbone.parameters())
|
||||||
classifier_params = list(detector.classifier_head.parameters())
|
classifier_params = list(detector.classifier_head.parameters())
|
||||||
bbox_params = list(detector.bbox_head.parameters())
|
bbox_params = list(detector.size_head.parameters())
|
||||||
|
|
||||||
assert backbone_params
|
assert backbone_params
|
||||||
assert classifier_params
|
assert classifier_params
|
||||||
@ -50,7 +50,7 @@ def test_user_can_finetune_only_heads(
|
|||||||
assert finetuned_api is not api
|
assert finetuned_api is not api
|
||||||
assert detector.backbone is source_backbone
|
assert detector.backbone is source_backbone
|
||||||
assert detector.classifier_head is not source_classifier_head
|
assert detector.classifier_head is not source_classifier_head
|
||||||
assert detector.bbox_head is not source_bbox_head
|
assert detector.size_head is not source_size_head
|
||||||
assert list(finetune_dir.rglob("*.ckpt"))
|
assert list(finetune_dir.rglob("*.ckpt"))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -11,7 +11,7 @@ def test_cli_base_help_lists_main_commands() -> None:
|
|||||||
result = CliRunner().invoke(cli, ["--help"])
|
result = CliRunner().invoke(cli, ["--help"])
|
||||||
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert "predict" in result.output
|
assert "process" in result.output
|
||||||
assert "train" in result.output
|
assert "train" in result.output
|
||||||
assert "evaluate" in result.output
|
assert "evaluate" in result.output
|
||||||
assert "data" in result.output
|
assert "data" in result.output
|
||||||
|
|||||||
@ -15,8 +15,8 @@ def test_cli_evaluate_help() -> None:
|
|||||||
result = CliRunner().invoke(cli, ["evaluate", "--help"])
|
result = CliRunner().invoke(cli, ["evaluate", "--help"])
|
||||||
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert "MODEL_PATH" in result.output
|
|
||||||
assert "TEST_DATASET" in result.output
|
assert "TEST_DATASET" in result.output
|
||||||
|
assert "--model" in result.output
|
||||||
assert "--evaluation-config" in result.output
|
assert "--evaluation-config" in result.output
|
||||||
|
|
||||||
|
|
||||||
@ -32,8 +32,9 @@ def test_cli_evaluate_writes_metrics_for_small_dataset(
|
|||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
"evaluate",
|
"evaluate",
|
||||||
str(tiny_checkpoint_path),
|
|
||||||
str(BASE_DIR / "example_data" / "dataset.yaml"),
|
str(BASE_DIR / "example_data" / "dataset.yaml"),
|
||||||
|
"--model",
|
||||||
|
str(tiny_checkpoint_path),
|
||||||
"--base-dir",
|
"--base-dir",
|
||||||
str(BASE_DIR),
|
str(BASE_DIR),
|
||||||
"--workers",
|
"--workers",
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
"""CLI tests for finetune command."""
|
"""CLI tests for finetune command."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
@ -25,8 +26,41 @@ def test_cli_finetune_help() -> None:
|
|||||||
assert "--outputs-config" not in result.output
|
assert "--outputs-config" not in result.output
|
||||||
|
|
||||||
|
|
||||||
def test_cli_finetune_requires_model() -> None:
|
def test_cli_finetune_defaults_to_bundled_model(
|
||||||
"""User story: finetune requires a checkpoint argument."""
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""User story: finetune can use the bundled checkpoint by default."""
|
||||||
|
|
||||||
|
called = {}
|
||||||
|
|
||||||
|
class FakeAPI:
|
||||||
|
def finetune(self, **kwargs):
|
||||||
|
called["finetune"] = kwargs
|
||||||
|
return None
|
||||||
|
|
||||||
|
class FakeBatDetect2API:
|
||||||
|
@classmethod
|
||||||
|
def from_checkpoint(cls, path=None, **kwargs):
|
||||||
|
called["path"] = path
|
||||||
|
called["from_checkpoint_kwargs"] = kwargs
|
||||||
|
return FakeAPI()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"batdetect2.api_v2.BatDetect2API",
|
||||||
|
FakeBatDetect2API,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"batdetect2.data.load_dataset_config",
|
||||||
|
lambda path: SimpleNamespace(path=path),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"batdetect2.data.load_dataset",
|
||||||
|
lambda config, base_dir=None: [],
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"batdetect2.targets.TargetConfig.load",
|
||||||
|
lambda path: SimpleNamespace(path=path),
|
||||||
|
)
|
||||||
|
|
||||||
result = CliRunner().invoke(
|
result = CliRunner().invoke(
|
||||||
cli,
|
cli,
|
||||||
@ -38,8 +72,9 @@ def test_cli_finetune_requires_model() -> None:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result.exit_code != 0
|
assert result.exit_code == 0
|
||||||
assert "--model" in result.output
|
assert called["path"] is None
|
||||||
|
assert "finetune" in called
|
||||||
|
|
||||||
|
|
||||||
def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:
|
def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
"""Behavior tests for predict CLI workflows."""
|
"""Behavior tests for process CLI workflows."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -9,10 +9,10 @@ from soundevent import data, io
|
|||||||
from batdetect2.cli import cli
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
|
|
||||||
def test_cli_predict_help() -> None:
|
def test_cli_process_help() -> None:
|
||||||
"""User story: discover available predict modes."""
|
"""User story: discover available process modes."""
|
||||||
|
|
||||||
result = CliRunner().invoke(cli, ["predict", "--help"])
|
result = CliRunner().invoke(cli, ["process", "--help"])
|
||||||
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
assert "directory" in result.output
|
assert "directory" in result.output
|
||||||
@ -21,19 +21,19 @@ def test_cli_predict_help() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_cli_predict_directory_runs_on_real_audio(
|
def test_cli_process_directory_runs_on_real_audio(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
tiny_checkpoint_path: Path,
|
tiny_checkpoint_path: Path,
|
||||||
single_audio_dir: Path,
|
single_audio_dir: Path,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""User story: run prediction for all files in a directory."""
|
"""User story: process all files in a directory."""
|
||||||
|
|
||||||
output_path = tmp_path / "predictions"
|
output_path = tmp_path / "predictions"
|
||||||
|
|
||||||
result = CliRunner().invoke(
|
result = CliRunner().invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
"predict",
|
"process",
|
||||||
"directory",
|
"directory",
|
||||||
str(tiny_checkpoint_path),
|
str(tiny_checkpoint_path),
|
||||||
str(single_audio_dir),
|
str(single_audio_dir),
|
||||||
@ -52,12 +52,12 @@ def test_cli_predict_directory_runs_on_real_audio(
|
|||||||
assert len(list(output_path.glob("*.json"))) == 1
|
assert len(list(output_path.glob("*.json"))) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_cli_predict_file_list_runs_on_real_audio(
|
def test_cli_process_file_list_runs_on_real_audio(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
tiny_checkpoint_path: Path,
|
tiny_checkpoint_path: Path,
|
||||||
single_audio_dir: Path,
|
single_audio_dir: Path,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""User story: run prediction from an explicit list of files."""
|
"""User story: process an explicit list of files."""
|
||||||
|
|
||||||
audio_file = next(single_audio_dir.glob("*.wav"))
|
audio_file = next(single_audio_dir.glob("*.wav"))
|
||||||
file_list = tmp_path / "files.txt"
|
file_list = tmp_path / "files.txt"
|
||||||
@ -68,7 +68,7 @@ def test_cli_predict_file_list_runs_on_real_audio(
|
|||||||
result = CliRunner().invoke(
|
result = CliRunner().invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
"predict",
|
"process",
|
||||||
"file_list",
|
"file_list",
|
||||||
str(tiny_checkpoint_path),
|
str(tiny_checkpoint_path),
|
||||||
str(file_list),
|
str(file_list),
|
||||||
@ -87,12 +87,12 @@ def test_cli_predict_file_list_runs_on_real_audio(
|
|||||||
assert len(list(output_path.glob("*.json"))) == 1
|
assert len(list(output_path.glob("*.json"))) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_cli_predict_dataset_runs_on_aoef_metadata(
|
def test_cli_process_dataset_runs_on_aoef_metadata(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
tiny_checkpoint_path: Path,
|
tiny_checkpoint_path: Path,
|
||||||
single_audio_dir: Path,
|
single_audio_dir: Path,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""User story: predict from AOEF dataset metadata file."""
|
"""User story: process from AOEF dataset metadata file."""
|
||||||
|
|
||||||
audio_file = next(single_audio_dir.glob("*.wav"))
|
audio_file = next(single_audio_dir.glob("*.wav"))
|
||||||
recording = data.Recording.from_file(audio_file)
|
recording = data.Recording.from_file(audio_file)
|
||||||
@ -103,7 +103,7 @@ def test_cli_predict_dataset_runs_on_aoef_metadata(
|
|||||||
)
|
)
|
||||||
annotation_set = data.AnnotationSet(
|
annotation_set = data.AnnotationSet(
|
||||||
name="test",
|
name="test",
|
||||||
description="predict dataset test",
|
description="process dataset test",
|
||||||
clip_annotations=[data.ClipAnnotation(clip=clip, sound_events=[])],
|
clip_annotations=[data.ClipAnnotation(clip=clip, sound_events=[])],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -115,7 +115,7 @@ def test_cli_predict_dataset_runs_on_aoef_metadata(
|
|||||||
result = CliRunner().invoke(
|
result = CliRunner().invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
"predict",
|
"process",
|
||||||
"dataset",
|
"dataset",
|
||||||
str(tiny_checkpoint_path),
|
str(tiny_checkpoint_path),
|
||||||
str(dataset_path),
|
str(dataset_path),
|
||||||
@ -142,7 +142,7 @@ def test_cli_predict_dataset_runs_on_aoef_metadata(
|
|||||||
("soundevent", "*.json", True),
|
("soundevent", "*.json", True),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_cli_predict_directory_supports_output_format_override(
|
def test_cli_process_directory_supports_output_format_override(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
tiny_checkpoint_path: Path,
|
tiny_checkpoint_path: Path,
|
||||||
single_audio_dir: Path,
|
single_audio_dir: Path,
|
||||||
@ -157,7 +157,7 @@ def test_cli_predict_directory_supports_output_format_override(
|
|||||||
result = CliRunner().invoke(
|
result = CliRunner().invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
"predict",
|
"process",
|
||||||
"directory",
|
"directory",
|
||||||
str(tiny_checkpoint_path),
|
str(tiny_checkpoint_path),
|
||||||
str(single_audio_dir),
|
str(single_audio_dir),
|
||||||
@ -180,12 +180,12 @@ def test_cli_predict_directory_supports_output_format_override(
|
|||||||
assert len(list(output_path.glob(expected_pattern))) >= 1
|
assert len(list(output_path.glob(expected_pattern))) >= 1
|
||||||
|
|
||||||
|
|
||||||
def test_cli_predict_dataset_deduplicates_recordings(
|
def test_cli_process_dataset_deduplicates_recordings(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
tiny_checkpoint_path: Path,
|
tiny_checkpoint_path: Path,
|
||||||
single_audio_dir: Path,
|
single_audio_dir: Path,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""User story: duplicated recording entries are predicted once."""
|
"""User story: duplicated recording entries are processed once."""
|
||||||
|
|
||||||
audio_file = next(single_audio_dir.glob("*.wav"))
|
audio_file = next(single_audio_dir.glob("*.wav"))
|
||||||
recording = data.Recording.from_file(audio_file)
|
recording = data.Recording.from_file(audio_file)
|
||||||
@ -215,7 +215,7 @@ def test_cli_predict_dataset_deduplicates_recordings(
|
|||||||
result = CliRunner().invoke(
|
result = CliRunner().invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
"predict",
|
"process",
|
||||||
"dataset",
|
"dataset",
|
||||||
str(tiny_checkpoint_path),
|
str(tiny_checkpoint_path),
|
||||||
str(dataset_path),
|
str(dataset_path),
|
||||||
@ -234,7 +234,7 @@ def test_cli_predict_dataset_deduplicates_recordings(
|
|||||||
assert len(list(output_path.glob("*.nc"))) == 1
|
assert len(list(output_path.glob("*.nc"))) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_cli_predict_rejects_unknown_output_format(
|
def test_cli_process_rejects_unknown_output_format(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
tiny_checkpoint_path: Path,
|
tiny_checkpoint_path: Path,
|
||||||
single_audio_dir: Path,
|
single_audio_dir: Path,
|
||||||
@ -245,7 +245,7 @@ def test_cli_predict_rejects_unknown_output_format(
|
|||||||
result = CliRunner().invoke(
|
result = CliRunner().invoke(
|
||||||
cli,
|
cli,
|
||||||
[
|
[
|
||||||
"predict",
|
"process",
|
||||||
"directory",
|
"directory",
|
||||||
str(tiny_checkpoint_path),
|
str(tiny_checkpoint_path),
|
||||||
str(single_audio_dir),
|
str(single_audio_dir),
|
||||||
|
|||||||
@ -13,7 +13,6 @@ from batdetect2.models.backbones import (
|
|||||||
build_backbone,
|
build_backbone,
|
||||||
load_backbone_config,
|
load_backbone_config,
|
||||||
)
|
)
|
||||||
from batdetect2.models.types import BackboneModel
|
|
||||||
|
|
||||||
|
|
||||||
def test_unet_backbone_config_defaults():
|
def test_unet_backbone_config_defaults():
|
||||||
@ -61,10 +60,11 @@ def test_build_backbone_custom_config():
|
|||||||
assert backbone.encoder.in_channels == 2
|
assert backbone.encoder.in_channels == 2
|
||||||
|
|
||||||
|
|
||||||
def test_build_backbone_returns_backbone_model():
|
def test_build_backbone_returns_unet_backbone():
|
||||||
"""build_backbone always returns a BackboneModel instance."""
|
"""build_backbone returns the default UNet backbone."""
|
||||||
backbone = build_backbone()
|
backbone = build_backbone()
|
||||||
assert isinstance(backbone, BackboneModel)
|
|
||||||
|
assert isinstance(backbone, UNetBackbone)
|
||||||
|
|
||||||
|
|
||||||
def test_registry_has_unet_backbone():
|
def test_registry_has_unet_backbone():
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@ -19,12 +21,15 @@ def dummy_spectrogram() -> torch.Tensor:
|
|||||||
def test_build_detector_default():
|
def test_build_detector_default():
|
||||||
"""Test building the default detector without a config."""
|
"""Test building the default detector without a config."""
|
||||||
num_classes = 5
|
num_classes = 5
|
||||||
model = build_detector(num_classes=num_classes)
|
model = build_detector(
|
||||||
|
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||||
|
dimension_names=["width", "height"],
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(model, Detector)
|
assert isinstance(model, Detector)
|
||||||
assert model.num_classes == num_classes
|
assert model.num_classes == num_classes
|
||||||
assert isinstance(model.classifier_head, ClassifierHead)
|
assert isinstance(model.classifier_head, ClassifierHead)
|
||||||
assert isinstance(model.bbox_head, BBoxHead)
|
assert isinstance(model.size_head, BBoxHead)
|
||||||
|
|
||||||
|
|
||||||
def test_build_detector_custom_config():
|
def test_build_detector_custom_config():
|
||||||
@ -32,13 +37,19 @@ def test_build_detector_custom_config():
|
|||||||
num_classes = 3
|
num_classes = 3
|
||||||
config = UNetBackboneConfig(in_channels=2, input_height=128)
|
config = UNetBackboneConfig(in_channels=2, input_height=128)
|
||||||
|
|
||||||
model = build_detector(num_classes=num_classes, config=config)
|
model = build_detector(
|
||||||
|
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||||
|
dimension_names=["width", "height"],
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
assert isinstance(model, Detector)
|
assert isinstance(model, Detector)
|
||||||
assert model.backbone.input_height == 128
|
assert model.backbone.input_height == 128
|
||||||
|
|
||||||
assert isinstance(model.backbone.encoder, Encoder)
|
backbone = cast(UNetBackbone, model.backbone)
|
||||||
assert model.backbone.encoder.in_channels == 2
|
|
||||||
|
assert isinstance(backbone.encoder, Encoder)
|
||||||
|
assert backbone.encoder.in_channels == 2
|
||||||
|
|
||||||
|
|
||||||
def test_build_detector_custom_size_channels():
|
def test_build_detector_custom_size_channels():
|
||||||
@ -47,8 +58,8 @@ def test_build_detector_custom_size_channels():
|
|||||||
config = UNetBackboneConfig(in_channels=1, input_height=128)
|
config = UNetBackboneConfig(in_channels=1, input_height=128)
|
||||||
|
|
||||||
model = build_detector(
|
model = build_detector(
|
||||||
num_classes=num_classes,
|
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||||
num_sizes=num_sizes,
|
dimension_names=[f"size_{i}" for i in range(num_sizes)],
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -62,7 +73,11 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
|
|||||||
num_classes = 4
|
num_classes = 4
|
||||||
# Build model matching the dummy input shape
|
# Build model matching the dummy input shape
|
||||||
config = UNetBackboneConfig(in_channels=1, input_height=256)
|
config = UNetBackboneConfig(in_channels=1, input_height=256)
|
||||||
model = build_detector(num_classes=num_classes, config=config)
|
model = build_detector(
|
||||||
|
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||||
|
dimension_names=["width", "height"],
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
# Process the spectrogram through the model
|
# Process the spectrogram through the model
|
||||||
# PyTorch expects shape (Batch, Channels, Height, Width)
|
# PyTorch expects shape (Batch, Channels, Height, Width)
|
||||||
@ -132,7 +147,11 @@ def test_detector_forward_pass_with_preprocessor(sample_preprocessor):
|
|||||||
config = UNetBackboneConfig(
|
config = UNetBackboneConfig(
|
||||||
in_channels=spec.shape[1], input_height=spec.shape[2]
|
in_channels=spec.shape[1], input_height=spec.shape[2]
|
||||||
)
|
)
|
||||||
model = build_detector(num_classes=3, config=config)
|
model = build_detector(
|
||||||
|
class_names=["class_0", "class_1", "class_2"],
|
||||||
|
dimension_names=["width", "height"],
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
# Process
|
# Process
|
||||||
output = model(spec)
|
output = model(spec)
|
||||||
|
|||||||
@ -1,9 +1,17 @@
|
|||||||
|
import sys
|
||||||
|
import types
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.train import TrainingConfig, run_train
|
from batdetect2.train import TrainingConfig, run_train
|
||||||
|
from batdetect2.train.checkpoints import (
|
||||||
|
DEFAULT_CHECKPOINT,
|
||||||
|
get_bundled_checkpoint_names,
|
||||||
|
resolve_checkpoint_path,
|
||||||
|
)
|
||||||
|
|
||||||
pytestmark = pytest.mark.slow
|
pytestmark = pytest.mark.slow
|
||||||
|
|
||||||
@ -92,3 +100,133 @@ def test_train_controls_which_checkpoints_are_kept(
|
|||||||
assert last_checkpoints
|
assert last_checkpoints
|
||||||
assert len(best_checkpoints) == 1
|
assert len(best_checkpoints) == 1
|
||||||
assert "epoch" in best_checkpoints[0].name
|
assert "epoch" in best_checkpoints[0].name
|
||||||
|
|
||||||
|
|
||||||
|
def test_train_saves_weights_only_checkpoints_by_default(
|
||||||
|
tmp_path: Path,
|
||||||
|
example_annotations: list[data.ClipAnnotation],
|
||||||
|
) -> None:
|
||||||
|
config = _build_fast_train_config()
|
||||||
|
|
||||||
|
run_train(
|
||||||
|
train_annotations=example_annotations[:1],
|
||||||
|
val_annotations=example_annotations[:1],
|
||||||
|
train_config=config,
|
||||||
|
num_epochs=1,
|
||||||
|
train_workers=0,
|
||||||
|
val_workers=0,
|
||||||
|
checkpoint_dir=tmp_path,
|
||||||
|
seed=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
checkpoint_path = next(tmp_path.rglob("*.ckpt"))
|
||||||
|
checkpoint = torch.load(
|
||||||
|
checkpoint_path,
|
||||||
|
map_location="cpu",
|
||||||
|
weights_only=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "state_dict" in checkpoint
|
||||||
|
assert "hyper_parameters" in checkpoint
|
||||||
|
assert "pytorch-lightning_version" in checkpoint
|
||||||
|
assert "optimizer_states" not in checkpoint
|
||||||
|
assert "lr_schedulers" not in checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_returns_local_path_unchanged(
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
local_path = tmp_path / "model.ckpt"
|
||||||
|
local_path.write_bytes(b"checkpoint")
|
||||||
|
|
||||||
|
assert resolve_checkpoint_path(local_path) == local_path
|
||||||
|
assert resolve_checkpoint_path(str(local_path)) == local_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
|
||||||
|
assert get_bundled_checkpoint_names() == (
|
||||||
|
DEFAULT_CHECKPOINT,
|
||||||
|
"batdetect2_uk_same",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None:
|
||||||
|
resolved = resolve_checkpoint_path()
|
||||||
|
|
||||||
|
assert resolved == resolve_checkpoint_path(DEFAULT_CHECKPOINT)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
|
||||||
|
resolved = resolve_checkpoint_path(DEFAULT_CHECKPOINT)
|
||||||
|
|
||||||
|
assert resolved.name == "batdetect2_uk_same.ckpt"
|
||||||
|
assert resolved.exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_prefers_existing_local_path_over_alias(
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
local_path = tmp_path / "uk_same"
|
||||||
|
local_path.write_bytes(b"checkpoint")
|
||||||
|
|
||||||
|
assert resolve_checkpoint_path(local_path) == local_path
|
||||||
|
assert resolve_checkpoint_path(str(local_path)) == local_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_downloads_huggingface_checkpoint(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
expected_path = tmp_path / "downloaded.ckpt"
|
||||||
|
|
||||||
|
def fake_hf_hub_download(repo_id: str, filename: str) -> str:
|
||||||
|
assert repo_id == "owner/repo"
|
||||||
|
assert filename == "weights/model.ckpt"
|
||||||
|
return str(expected_path)
|
||||||
|
|
||||||
|
class FakeHuggingFaceHub(types.ModuleType):
|
||||||
|
hf_hub_download = staticmethod(fake_hf_hub_download)
|
||||||
|
|
||||||
|
fake_module = FakeHuggingFaceHub("huggingface_hub")
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"huggingface_hub",
|
||||||
|
fake_module,
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved = resolve_checkpoint_path("hf://owner/repo/weights/model.ckpt")
|
||||||
|
|
||||||
|
assert resolved == expected_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_requires_huggingface_dependency(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
monkeypatch.delitem(sys.modules, "huggingface_hub", raising=False)
|
||||||
|
|
||||||
|
import builtins
|
||||||
|
|
||||||
|
original_import = builtins.__import__
|
||||||
|
|
||||||
|
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||||
|
if name == "huggingface_hub":
|
||||||
|
raise ImportError("missing")
|
||||||
|
return original_import(name, globals, locals, fromlist, level)
|
||||||
|
|
||||||
|
monkeypatch.setattr(builtins, "__import__", fake_import)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Hugging Face checkpoint support"):
|
||||||
|
resolve_checkpoint_path("hf://owner/repo/weights/model.ckpt")
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
|
||||||
|
with pytest.raises(ValueError, match="hf://owner/repo/path/to"):
|
||||||
|
resolve_checkpoint_path("hf://owner/repo")
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
|
||||||
|
with pytest.raises(
|
||||||
|
FileNotFoundError,
|
||||||
|
match="checkpoint alias",
|
||||||
|
):
|
||||||
|
resolve_checkpoint_path("missing.ckpt")
|
||||||
|
|||||||
@ -368,7 +368,7 @@ def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
|
|||||||
assert (
|
assert (
|
||||||
rebuilt_detector.classifier_head is not source_detector.classifier_head
|
rebuilt_detector.classifier_head is not source_detector.classifier_head
|
||||||
)
|
)
|
||||||
assert rebuilt_detector.bbox_head is not source_detector.bbox_head
|
assert rebuilt_detector.size_head is not source_detector.size_head
|
||||||
assert rebuilt_model.class_names == ["single_class"]
|
assert rebuilt_model.class_names == ["single_class"]
|
||||||
assert rebuilt_model.dimension_names == ["width", "height"]
|
assert rebuilt_model.dimension_names == ["width", "height"]
|
||||||
|
|
||||||
@ -451,7 +451,6 @@ def test_run_train_rejects_incompatible_model_config(
|
|||||||
model=incompatible_model,
|
model=incompatible_model,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
model_config=incompatible_config,
|
|
||||||
targets_config=targets_config,
|
targets_config=targets_config,
|
||||||
train_config=TrainingConfig(),
|
train_config=TrainingConfig(),
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user