mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Compare commits
17 Commits
5a974711b0
...
b0f85b96e3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b0f85b96e3 | ||
|
|
ce6975770e | ||
|
|
69d8e2d228 | ||
|
|
855a79853b | ||
|
|
6587c6c4e5 | ||
|
|
831925bd95 | ||
|
|
b4efcfcf0f | ||
|
|
5cc5767eff | ||
|
|
2008c8000f | ||
|
|
a27d1bbfd3 | ||
|
|
999dc93d88 | ||
|
|
7c05fb8577 | ||
|
|
31054f64f6 | ||
|
|
84918086c8 | ||
|
|
d83f801515 | ||
|
|
5526ac99fc | ||
|
|
f5afa9881c |
@ -3,6 +3,8 @@ current_version = 1.1.1
|
||||
commit = True
|
||||
tag = True
|
||||
|
||||
[bumpversion:file:batdetect2/__init__.py]
|
||||
[bumpversion:file:src/batdetect2/__init__.py]
|
||||
|
||||
[bumpversion:file:pyproject.toml]
|
||||
|
||||
[bumpversion:file:docs/source/conf.py]
|
||||
|
||||
79
.github/workflows/ci.yml
vendored
Normal file
79
.github/workflows/ci.yml
vendored
Normal file
@ -0,0 +1,79 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
concurrency:
|
||||
group: ci-${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
checks:
|
||||
name: Checks
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install just
|
||||
uses: taiki-e/install-action@just
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: |
|
||||
pyproject.toml
|
||||
uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: just install-dev
|
||||
|
||||
- name: Run formatting, lint, and type checks
|
||||
run: just check
|
||||
|
||||
tests:
|
||||
name: Tests (Python ${{ matrix.python-version }})
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version:
|
||||
- "3.10"
|
||||
- "3.11"
|
||||
- "3.12"
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install just
|
||||
uses: taiki-e/install-action@just
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: |
|
||||
pyproject.toml
|
||||
uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: just install-dev
|
||||
|
||||
- name: Run test suite
|
||||
run: just test
|
||||
69
.github/workflows/docs-pages.yml
vendored
Normal file
69
.github/workflows/docs-pages.yml
vendored
Normal file
@ -0,0 +1,69 @@
|
||||
name: Docs Pages
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: docs-pages
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build Docs
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install just
|
||||
uses: taiki-e/install-action@just
|
||||
|
||||
- name: Configure GitHub Pages
|
||||
uses: actions/configure-pages@v5
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: |
|
||||
pyproject.toml
|
||||
uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: just install-dev
|
||||
|
||||
- name: Build docs
|
||||
run: just check-docs
|
||||
|
||||
- name: Upload Pages artifact
|
||||
uses: actions/upload-pages-artifact@v4
|
||||
with:
|
||||
path: docs/build
|
||||
|
||||
deploy:
|
||||
name: Deploy Docs
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
pages: write
|
||||
id-token: write
|
||||
environment:
|
||||
name: github-pages
|
||||
url: ${{ steps.deployment.outputs.page_url }}
|
||||
|
||||
steps:
|
||||
- name: Deploy to GitHub Pages
|
||||
id: deployment
|
||||
uses: actions/deploy-pages@v4
|
||||
70
.github/workflows/publish-pypi.yml
vendored
Normal file
70
.github/workflows/publish-pypi.yml
vendored
Normal file
@ -0,0 +1,70 @@
|
||||
name: Publish PyPI
|
||||
|
||||
on:
|
||||
release:
|
||||
types:
|
||||
- published
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
group: publish-pypi
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: Build Distributions
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Install just
|
||||
uses: taiki-e/install-action@just
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: |
|
||||
pyproject.toml
|
||||
uv.lock
|
||||
|
||||
- name: Install dependencies
|
||||
run: just install-dev
|
||||
|
||||
- name: Build distributions
|
||||
run: just build-dist
|
||||
|
||||
- name: Upload distributions
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: release-dists
|
||||
path: dist/
|
||||
|
||||
publish:
|
||||
name: Publish to PyPI
|
||||
needs: build
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
id-token: write
|
||||
environment:
|
||||
name: pypi
|
||||
url: https://pypi.org/p/batdetect2
|
||||
|
||||
steps:
|
||||
- name: Download distributions
|
||||
uses: actions/download-artifact@v5
|
||||
with:
|
||||
name: release-dists
|
||||
path: dist/
|
||||
|
||||
- name: Publish to PyPI
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
29
.github/workflows/python-package.yml
vendored
29
.github/workflows/python-package.yml
vendored
@ -1,29 +0,0 @@
|
||||
name: Python package
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
pull_request:
|
||||
branches: ["main"]
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v3
|
||||
with:
|
||||
enable-cache: true
|
||||
cache-dependency-glob: "uv.lock"
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
run: uv python install ${{ matrix.python-version }}
|
||||
- name: Install the project
|
||||
run: uv sync --all-extras --dev
|
||||
- name: Test with pytest
|
||||
run: uv run pytest
|
||||
30
.github/workflows/python-publish.yml
vendored
30
.github/workflows/python-publish.yml
vendored
@ -1,30 +0,0 @@
|
||||
name: Upload Python Package
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [published]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: "3.x"
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install build
|
||||
- name: Build package
|
||||
run: python -m build
|
||||
- name: Publish package
|
||||
uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
|
||||
with:
|
||||
user: __token__
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@ -50,6 +50,7 @@ cover/
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
docs/build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
|
||||
247
README.md
247
README.md
@ -1,202 +1,137 @@
|
||||
# BatDetect2
|
||||
<img style="display: block-inline;" width="64" height="64" src="assets/bat_icon.png"> Code for detecting and classifying bat echolocation calls in high frequency audio recordings.
|
||||
|
||||
## What BatDetect2 is useful for
|
||||
<img style="display:block-inline;" width="64" height="64" src="assets/bat_icon.png">
|
||||
|
||||
BatDetect2 can help you screen recordings for bat calls,
|
||||
find recordings that need expert review,
|
||||
and compare model outputs across sites or projects with appropriate caution.
|
||||
Code for detecting and classifying bat echolocation calls in high-frequency
|
||||
audio recordings.
|
||||
|
||||
It is best used as a tool to support ecological work,
|
||||
not as a replacement for validation or expert interpretation.
|
||||
> [!WARNING]
|
||||
> `batdetect2` 2.0.1 is out.
|
||||
> There are many changes and new recommended workflows.
|
||||
> We have left the previous `batdetect2.api` module intact, but if you run
|
||||
> into issues or want to upgrade, see the
|
||||
> [migration guide](docs/source/legacy/migration-guide.md) in the docs site.
|
||||
>
|
||||
> This update also ships with a refreshed default model.
|
||||
> It was trained in the same way and on the same data as before, but you should still expect small output differences in some cases.
|
||||
|
||||
## Start here
|
||||
## What is BatDetect2
|
||||
|
||||
If you want the simplest current workflow,
|
||||
use the documentation site and start with:
|
||||
BatDetect2 is a deep learning model for detecting and classifying bat
|
||||
echolocation calls.
|
||||
The model generates multiple predictions for each input recording by providing a
|
||||
bounding box and predicted class for each individual call within it.
|
||||
|
||||
- getting started: `docs/source/getting_started.md`
|
||||
- first tutorial: `docs/source/tutorials/run-inference-on-folder.md`
|
||||
This repository also holds `batdetect2`, a Python-based tool to run, train,
|
||||
finetune and evaluate BatDetect2-type models, including the built-in model for
|
||||
detecting UK bat species.
|
||||
You can use the tool from the command line (terminal) or from Python as needed.
|
||||
|
||||
The current docs default to:
|
||||
## Getting Started
|
||||
|
||||
- the current command-line workflow: `batdetect2 predict`
|
||||
- the current Python workflow: `batdetect2.api_v2.BatDetect2API`
|
||||
We have [extensive documentation](docs/source/index.md) on how to use
|
||||
`batdetect2`.
|
||||
See our [getting started](docs/source/getting_started.md) guide and then jump
|
||||
into any of our tutorials:
|
||||
|
||||
If you need the previous workflow based on `batdetect2 detect` or `batdetect2.api`,
|
||||
use the legacy docs section and migration guide in the docs site.
|
||||
- Run the model on a folder of recordings:
|
||||
`docs/source/tutorials/run-inference-on-folder.md`
|
||||
- Train your own model:
|
||||
`docs/source/tutorials/train-a-custom-model.md`
|
||||
- Evaluate your model:
|
||||
`docs/source/tutorials/evaluate-on-a-test-set.md`
|
||||
- Fine-tune a model:
|
||||
`docs/source/tutorials/integrate-with-a-python-pipeline.md`
|
||||
|
||||
## Install BatDetect2
|
||||
### Try the model
|
||||
|
||||
If you already use Python,
|
||||
activate the environment where you want BatDetect2 to live.
|
||||
If you want to try the model for UK bat species without installing anything, you
|
||||
can try the following:
|
||||
|
||||
If not,
|
||||
create a fresh one first so BatDetect2 stays separate from other software on your machine.
|
||||
1. Demo of the model (for UK species) on
|
||||
[huggingface](https://huggingface.co/spaces/macaodha/batdetect2).
|
||||
|
||||
Two common options are:
|
||||
|
||||
* Install the Anaconda Python 3.10 distribution for your operating system from [here](https://www.continuum.io/downloads). Create a new environment and activate it:
|
||||
|
||||
```bash
|
||||
conda create -y --name batdetect2 python==3.10
|
||||
conda activate batdetect2
|
||||
```
|
||||
|
||||
* If you already have Python installed (version >= 3.10,< 3.14), you can create a fresh environment with:
|
||||
|
||||
```bash
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate
|
||||
```
|
||||
2. Alternatively, click
|
||||
[here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb)
|
||||
to run the model using Google Colab.
|
||||
You can also run this notebook locally.
|
||||
|
||||
### Installing BatDetect2
|
||||
You can use pip to install `batdetect2`:
|
||||
|
||||
If you have `uv` installed (if not, we recommend it; follow the instructions
|
||||
[here](https://docs.astral.sh/uv/getting-started/installation/)), then you can
|
||||
run `batdetect2` one-off with
|
||||
|
||||
```bash
|
||||
pip install batdetect2
|
||||
uvx batdetect2
|
||||
```
|
||||
|
||||
Alternatively, download this code from the repository (by clicking on the green button on top right) and unzip it.
|
||||
Once unzipped, run this from extracted folder.
|
||||
or if you want to install it permanently:
|
||||
|
||||
```bash
|
||||
pip install .
|
||||
uv tool install batdetect2
|
||||
```
|
||||
|
||||
Make sure you have the environment activated before installing `batdetect2`.
|
||||
and test it with
|
||||
|
||||
## Run BatDetect2 on a folder of recordings
|
||||
|
||||
Once installed,
|
||||
the simplest current workflow is to run BatDetect2 on a folder of `.wav` files.
|
||||
|
||||
If you are working from this repository checkout,
|
||||
you can use this example checkpoint path:
|
||||
|
||||
```text
|
||||
src/batdetect2/models/checkpoints/Net2DFast_UK_same.pth.tar
|
||||
```bash
|
||||
batdetect2
|
||||
```
|
||||
|
||||
### Run BatDetect2 on a folder of recordings
|
||||
|
||||
Once installed, you can run BatDetect2 on a folder of `.wav` files.
|
||||
By default it will use the model trained on UK data.
|
||||
|
||||
Example command:
|
||||
|
||||
```bash
|
||||
batdetect2 predict directory \
|
||||
src/batdetect2/models/checkpoints/Net2DFast_UK_same.pth.tar \
|
||||
example_data/audio \
|
||||
outputs
|
||||
batdetect2 process directory example_data/audio outputs
|
||||
```
|
||||
|
||||
This will scan the audio files in `example_data/audio`
|
||||
and save model outputs to `outputs`.
|
||||
This will scan the audio files in `example_data/audio` and save model outputs to
|
||||
`outputs`.
|
||||
If you have your own model checkpoint, you can use it:
|
||||
|
||||
For the full beginner walkthrough,
|
||||
use `docs/source/tutorials/run-inference-on-folder.md`.
|
||||
|
||||
## Legacy workflow
|
||||
|
||||
The sections below are kept only for people maintaining older BatDetect2 scripts and analysis pipelines.
|
||||
|
||||
If you are new to BatDetect2,
|
||||
stop here and use the current docs and command above.
|
||||
|
||||
If you really do need the older workflow,
|
||||
the reference material is below.
|
||||
|
||||
|
||||
## Try the model
|
||||
1) You can try a demo of the model (for UK species) on [huggingface](https://huggingface.co/spaces/macaodha/batdetect2).
|
||||
|
||||
2) Alternatively, click [here](https://colab.research.google.com/github/macaodha/batdetect2/blob/master/batdetect2_notebook.ipynb) to run the model using Google Colab. You can also run this notebook locally.
|
||||
|
||||
|
||||
## Running the model on your own data
|
||||
|
||||
After following the above steps to install the code you can run the model on your own data.
|
||||
|
||||
The remainder of this section is legacy reference material.
|
||||
|
||||
|
||||
### Using the command line
|
||||
|
||||
The commands below describe the legacy CLI workflow.
|
||||
|
||||
For new work, prefer the current docs and `batdetect2 predict`.
|
||||
|
||||
You can run the model by opening the command line and typing:
|
||||
```bash
|
||||
batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD
|
||||
```
|
||||
e.g.
|
||||
```bash
|
||||
batdetect2 detect example_data/audio/ example_data/anns/ 0.3
|
||||
batdetect2 process directory --model path/to/checkpoint.ckpt example_data/audio outputs
|
||||
```
|
||||
|
||||
`AUDIO_DIR` is the path on your computer to the audio wav files of interest.
|
||||
`ANN_DIR` is the path on your computer where the model predictions will be saved. The model will output both `.csv` and `.json` results for each audio file.
|
||||
`DETECTION_THRESHOLD` is a number between 0 and 1 specifying the cut-off threshold applied to the calls. A smaller number will result in more calls detected, but with the chance of introducing more mistakes.
|
||||
|
||||
There are also optional arguments, e.g. you can request that the model outputs features (i.e. estimated call parameters) such as duration, max_frequency, etc. by setting the flag `--spec_features`. These will be saved as `*_spec_features.csv` files:
|
||||
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --spec_features`
|
||||
|
||||
You can also specify which model to use by setting the `--model_path` argument. If not specified, it will default to using a model trained on UK data e.g.
|
||||
`batdetect2 detect example_data/audio/ example_data/anns/ 0.3 --model_path models/Net2DFast_UK_same.pth.tar`
|
||||
|
||||
|
||||
### Using the Python API
|
||||
|
||||
The examples below describe the legacy Python API.
|
||||
|
||||
For new work, prefer `batdetect2.api_v2.BatDetect2API` and the current docs site.
|
||||
|
||||
If you prefer to process your data within a Python script then you can use the `batdetect2` Python API.
|
||||
|
||||
```python
|
||||
from batdetect2 import api
|
||||
|
||||
AUDIO_FILE = "example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav"
|
||||
|
||||
# Process a whole file
|
||||
results = api.process_file(AUDIO_FILE)
|
||||
|
||||
# Or, load audio and compute spectrograms
|
||||
audio = api.load_audio(AUDIO_FILE)
|
||||
spec = api.generate_spectrogram(audio)
|
||||
|
||||
# And process the audio or the spectrogram with the model
|
||||
detections, features, spec = api.process_audio(audio)
|
||||
detections, features = api.process_spectrogram(spec)
|
||||
|
||||
# Do something else ...
|
||||
```
|
||||
|
||||
You can integrate the detections or the extracted features to your custom analysis pipeline.
|
||||
|
||||
|
||||
## Training the model on your own data
|
||||
Take a look at the training tutorial in the docs site first.
|
||||
|
||||
If you are working from this repository checkout,
|
||||
start with `docs/source/tutorials/train-a-custom-model.md`.
|
||||
|
||||
For the full walkthrough, use
|
||||
`docs/source/tutorials/run-inference-on-folder.md`.
|
||||
|
||||
## Data and annotations
|
||||
The raw audio data and annotations used to train the models in the paper will be added soon.
|
||||
The audio interface used to annotate audio data for training and evaluation is available [here](https://github.com/macaodha/batdetect2_GUI).
|
||||
|
||||
The raw audio data and annotations used to train the models in the paper will be
|
||||
added soon.
|
||||
`batdetect2` supports annotations in various formats and is compatible with the
|
||||
outputs of [`whombat`](https://github.com/mbsantiago/whombat/) and this
|
||||
[earlier version](https://github.com/macaodha/batdetect2_GUI).
|
||||
If you're interested in supporting another format, please reach out or submit a
|
||||
PR.
|
||||
|
||||
## Warning
|
||||
The models developed and shared as part of this repository should be used with caution.
|
||||
While they have been evaluated on held out audio data, great care should be taken when using the model outputs for any form of biodiversity assessment.
|
||||
Your data may differ, and as a result it is very strongly recommended that you validate the model first using data with known species to ensure that the outputs can be trusted.
|
||||
|
||||
The models developed and shared as part of this repository should be used with
|
||||
caution.
|
||||
While they have been evaluated on held-out audio data, great care should be
|
||||
taken when using the model outputs for any form of biodiversity assessment.
|
||||
Your data may differ, and as a result it is very strongly recommended that you
|
||||
validate the model first using data with known species to ensure that the
|
||||
outputs can be trusted.
|
||||
If you train a model, make the best effort to be transparent about its training
|
||||
and evaluation data, and inform downstream users about its limitations.
|
||||
|
||||
## FAQ
|
||||
|
||||
For more information please consult our [FAQ](docs/source/faq.md).
|
||||
|
||||
|
||||
## Reference
|
||||
If you find our work useful in your research please consider citing our paper which you can find [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1):
|
||||
|
||||
If you find our work useful in your research, please consider citing our paper,
|
||||
which you can find
|
||||
[here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1):
|
||||
|
||||
```
|
||||
@article{batdetect2_2022,
|
||||
title = {Towards a General Approach for Bat Echolocation Detection and Classification},
|
||||
@ -207,10 +142,6 @@ If you find our work useful in your research please consider citing our paper wh
|
||||
```
|
||||
|
||||
## Acknowledgements
|
||||
Thanks to all the contributors who spent time collecting and annotating audio data.
|
||||
|
||||
|
||||
### TODOs
|
||||
- [x] Release the code and pretrained model
|
||||
- [ ] Release the datasets and annotations used the experiments in the paper
|
||||
- [ ] Add the scripts used to generate the tables and figures from the paper
|
||||
Thanks to all the contributors who spent time collecting and annotating audio
|
||||
data.
|
||||
|
||||
@ -1,19 +1,20 @@
|
||||
# Getting started
|
||||
|
||||
If you want to run BatDetect2 on your recordings,
|
||||
start with the command-line route below.
|
||||
If you want to run BatDetect2 on your recordings, start with the command-line
|
||||
route below.
|
||||
|
||||
You do not need to write Python code for a standard first run.
|
||||
|
||||
BatDetect2 also has a Python interface,
|
||||
but that is mainly for users writing their own analysis scripts.
|
||||
BatDetect2 also has a Python interface, but that is mainly for users writing
|
||||
their own analysis scripts.
|
||||
|
||||
- Use the command-line route if you want to run an existing model or train your own model by typing commands in a terminal window.
|
||||
- Use the command-line route if you want to run an existing model or train your
|
||||
own model by typing commands in a terminal window.
|
||||
- Use the Python route only if you already want to work in scripts or notebooks.
|
||||
|
||||
```{note}
|
||||
If you are looking for the previous BatDetect2 workflow based on `batdetect2 detect` or `batdetect2.api`, go to {doc}`legacy/index`.
|
||||
New docs default to the current `predict` CLI and `BatDetect2API` workflow.
|
||||
New docs default to the current `process` CLI and `BatDetect2API` workflow.
|
||||
```
|
||||
|
||||
If you want to try BatDetect2 before installing anything locally:
|
||||
@ -27,15 +28,14 @@ If you want to try BatDetect2 before installing anything locally:
|
||||
2. Use a model checkpoint.
|
||||
3. Run the first tutorial on a folder of recordings.
|
||||
|
||||
If that is what you want,
|
||||
you can ignore the Python sections for now.
|
||||
If that is what you want, you can ignore the Python sections for now.
|
||||
|
||||
## Install BatDetect2
|
||||
|
||||
We recommend `uv` for both workflows.
|
||||
|
||||
`uv` is a tool that helps install Python software cleanly,
|
||||
without mixing it into the rest of your machine.
|
||||
`uv` is a tool that helps install Python software cleanly, without mixing it
|
||||
into the rest of your machine.
|
||||
|
||||
- Use `uv tool` to install the CLI.
|
||||
- Use `uv add` to add `batdetect2` as a dependency in a Python project.
|
||||
@ -70,7 +70,8 @@ Go to {doc}`tutorials/run-inference-on-folder` for a complete first run.
|
||||
|
||||
## Choose a model checkpoint
|
||||
|
||||
The current command-line and Python workflows expect an explicit checkpoint path.
|
||||
The current command-line and Python workflows expect an explicit checkpoint
|
||||
path.
|
||||
|
||||
A checkpoint is the saved model file that BatDetect2 will use for prediction.
|
||||
|
||||
@ -85,7 +86,8 @@ In this repository checkout, an example pretrained checkpoint is available at:
|
||||
src/batdetect2/models/checkpoints/Net2DFast_UK_same.pth.tar
|
||||
```
|
||||
|
||||
Use that path in the tutorial commands if you want a concrete starting point from this source tree.
|
||||
Use that path in the tutorial commands if you want a concrete starting point
|
||||
from this source tree.
|
||||
|
||||
## Python route for users writing code
|
||||
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
# How to choose an inference input mode
|
||||
|
||||
Use this guide to decide whether `predict directory`, `predict file_list`, or `predict dataset` is the right entry point for your run.
|
||||
Use this guide to decide whether `process directory`, `process file_list`, or
|
||||
`process dataset` is the right entry point for your run.
|
||||
|
||||
## Use `predict directory` when the recordings already live together
|
||||
## Use `process directory` when the recordings already live together
|
||||
|
||||
This is the simplest choice.
|
||||
|
||||
@ -13,13 +14,13 @@ Use it when:
|
||||
- you are doing a first pass over a folder of recordings.
|
||||
|
||||
```bash
|
||||
batdetect2 predict directory \
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs
|
||||
```
|
||||
|
||||
## Use `predict file_list` when you need explicit control over the file set
|
||||
## Use `process file_list` when you need explicit control over the file set
|
||||
|
||||
Use it when:
|
||||
|
||||
@ -30,13 +31,13 @@ Use it when:
|
||||
The list file should contain one path per line.
|
||||
|
||||
```bash
|
||||
batdetect2 predict file_list \
|
||||
batdetect2 process file_list \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_files.txt \
|
||||
path/to/outputs
|
||||
```
|
||||
|
||||
## Use `predict dataset` when your workflow is already annotation-set driven
|
||||
## Use `process dataset` when your workflow is already annotation-set driven
|
||||
|
||||
Use it when:
|
||||
|
||||
@ -45,13 +46,14 @@ Use it when:
|
||||
- you want BatDetect2 to resolve recording paths from the annotation set.
|
||||
|
||||
```bash
|
||||
batdetect2 predict dataset \
|
||||
batdetect2 process dataset \
|
||||
path/to/model.ckpt \
|
||||
path/to/annotation_set.json \
|
||||
path/to/outputs
|
||||
```
|
||||
|
||||
The dataset command reads a `soundevent` annotation set and extracts unique recording paths before inference.
|
||||
The dataset command reads a `soundevent` annotation set and extracts unique
|
||||
recording paths before inference.
|
||||
|
||||
## Rule of thumb
|
||||
|
||||
@ -61,6 +63,9 @@ The dataset command reads a `soundevent` annotation set and extracts unique reco
|
||||
|
||||
## Related pages
|
||||
|
||||
- Run batch predictions: {doc}`run-batch-predictions`
|
||||
- Tune inference clipping: {doc}`tune-inference-clipping`
|
||||
- Predict command reference: {doc}`../reference/cli/predict`
|
||||
- Run batch predictions:
|
||||
{doc}`run-batch-predictions`
|
||||
- Tune inference clipping:
|
||||
{doc}`tune-inference-clipping`
|
||||
- Process command reference:
|
||||
{doc}`../reference/cli/predict`
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# How to choose and configure evaluation tasks
|
||||
|
||||
Use this guide when the default evaluation tasks do not match the question you want to answer.
|
||||
Use this guide when the default evaluation tasks do not match the question you
|
||||
want to answer.
|
||||
|
||||
## Know the default first
|
||||
|
||||
@ -24,8 +25,10 @@ Common built-in task families include:
|
||||
Choose based on the question you care about.
|
||||
|
||||
- Use sound-event tasks when you care about individual call events.
|
||||
- Use clip tasks when you care about clip-level presence or clip-level class evidence.
|
||||
- Use top-class detection when you want matching based on the highest-scoring class per detection.
|
||||
- Use clip tasks when you care about clip-level presence or clip-level class
|
||||
evidence.
|
||||
- Use top-class detection when you want matching based on the highest-scoring
|
||||
class per detection.
|
||||
|
||||
## Configure tasks in `EvaluationConfig`
|
||||
|
||||
@ -45,22 +48,27 @@ Pass the config with:
|
||||
|
||||
```bash
|
||||
batdetect2 evaluate \
|
||||
path/to/model.ckpt \
|
||||
path/to/test_dataset.yaml \
|
||||
--model path/to/model.ckpt \
|
||||
--base-dir path/to/project_root \
|
||||
--evaluation-config path/to/evaluation.yaml
|
||||
```
|
||||
|
||||
Include `--base-dir` when the dataset config resolves recordings through relative paths.
|
||||
Include `--base-dir` when the dataset config resolves recordings through
|
||||
relative paths.
|
||||
|
||||
## Change one thing at a time
|
||||
|
||||
When comparing models or settings, avoid changing task definitions, thresholds, matching behavior, and datasets all at once.
|
||||
When comparing models or settings, avoid changing task definitions, thresholds,
|
||||
matching behavior, and datasets all at once.
|
||||
|
||||
Otherwise it becomes hard to explain why the metric changed.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Evaluation tutorial: {doc}`../tutorials/evaluate-on-a-test-set`
|
||||
- Evaluation config reference: {doc}`../reference/evaluation-config`
|
||||
- Evaluation concepts: {doc}`../explanation/evaluation-concepts-and-matching`
|
||||
- Evaluation tutorial:
|
||||
{doc}`../tutorials/evaluate-on-a-test-set`
|
||||
- Evaluation config reference:
|
||||
{doc}`../reference/evaluation-config`
|
||||
- Evaluation concepts:
|
||||
{doc}`../explanation/evaluation-concepts-and-matching`
|
||||
|
||||
@ -46,7 +46,7 @@ Available built-ins:
|
||||
For CLI inference/evaluation, use `--audio-config`.
|
||||
|
||||
```bash
|
||||
batdetect2 predict directory \
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs \
|
||||
@ -55,10 +55,12 @@ batdetect2 predict directory \
|
||||
|
||||
## 4) Verify quickly on a small subset
|
||||
|
||||
Run on a small folder first and confirm that outputs and runtime are as
|
||||
expected before full-batch runs.
|
||||
Run on a small folder first and confirm that outputs and runtime are as expected
|
||||
before full-batch runs.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Spectrogram settings: {doc}`configure-spectrogram-preprocessing`
|
||||
- Preprocessing config reference: {doc}`../reference/preprocessing-config`
|
||||
- Spectrogram settings:
|
||||
{doc}`configure-spectrogram-preprocessing`
|
||||
- Preprocessing config reference:
|
||||
{doc}`../reference/preprocessing-config`
|
||||
|
||||
@ -1,14 +1,15 @@
|
||||
# How to run batch predictions
|
||||
# How to run batch processing
|
||||
|
||||
This guide shows practical command patterns for directory-based and file-list
|
||||
prediction runs.
|
||||
processing runs.
|
||||
|
||||
Use it after you already know which input mode you want and need concrete command templates for a repeatable batch run.
|
||||
Use it after you already know which input mode you want and need concrete
|
||||
command templates for a repeatable batch run.
|
||||
|
||||
## Predict from a directory
|
||||
## Process a directory
|
||||
|
||||
```bash
|
||||
batdetect2 predict directory \
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs
|
||||
@ -16,27 +17,29 @@ batdetect2 predict directory \
|
||||
|
||||
Use this when BatDetect2 should discover the audio files for you.
|
||||
|
||||
## Predict from a file list
|
||||
## Process a file list
|
||||
|
||||
```bash
|
||||
batdetect2 predict file_list \
|
||||
batdetect2 process file_list \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_files.txt \
|
||||
path/to/outputs
|
||||
```
|
||||
|
||||
Use this when another part of your workflow already produced the exact recording list to process.
|
||||
Use this when another part of your workflow already produced the exact recording
|
||||
list to process.
|
||||
|
||||
## Predict from a dataset config
|
||||
## Process a dataset config
|
||||
|
||||
```bash
|
||||
batdetect2 predict dataset \
|
||||
batdetect2 process dataset \
|
||||
path/to/model.ckpt \
|
||||
path/to/annotation_set.json \
|
||||
path/to/outputs
|
||||
```
|
||||
|
||||
Use this when your project already has a `soundevent` annotation set and you want to extract unique recording paths from it.
|
||||
Use this when your project already has a `soundevent` annotation set and you
|
||||
want to extract unique recording paths from it.
|
||||
|
||||
## Useful options
|
||||
|
||||
|
||||
@ -1,22 +1,27 @@
|
||||
# How to save predictions in different output formats
|
||||
|
||||
Use this guide when you need BatDetect2 outputs in a specific representation for downstream tools.
|
||||
Use this guide when you need BatDetect2 outputs in a specific representation for
|
||||
downstream tools.
|
||||
|
||||
## Choose the format that matches the job
|
||||
|
||||
Current built-in output formats include:
|
||||
|
||||
- `raw`: one NetCDF file per clip, best for rich structured outputs,
|
||||
- `parquet`: tabular storage for data analysis workflows,
|
||||
- `soundevent`: prediction-set JSON for soundevent-style tooling,
|
||||
- `batdetect2`: legacy per-recording JSON output.
|
||||
- `raw`:
|
||||
one NetCDF file per clip, best for rich structured outputs,
|
||||
- `parquet`:
|
||||
tabular storage for data analysis workflows,
|
||||
- `soundevent`:
|
||||
prediction-set JSON for soundevent-style tooling,
|
||||
- `batdetect2`:
|
||||
legacy per-recording JSON output.
|
||||
|
||||
## Select a format from the CLI
|
||||
|
||||
Use `--format` for quick experiments.
|
||||
|
||||
```bash
|
||||
batdetect2 predict directory \
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs \
|
||||
@ -25,7 +30,8 @@ batdetect2 predict directory \
|
||||
|
||||
## Use an outputs config for repeatable runs
|
||||
|
||||
Use an outputs config when you want reproducible control over format and transforms.
|
||||
Use an outputs config when you want reproducible control over format and
|
||||
transforms.
|
||||
|
||||
Example:
|
||||
|
||||
@ -43,7 +49,7 @@ transform:
|
||||
Run with:
|
||||
|
||||
```bash
|
||||
batdetect2 predict directory \
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs \
|
||||
@ -59,6 +65,9 @@ batdetect2 predict directory \
|
||||
|
||||
## Related pages
|
||||
|
||||
- Outputs config reference: {doc}`../reference/outputs-config`
|
||||
- Output formats reference: {doc}`../reference/output-formats`
|
||||
- Output transforms reference: {doc}`../reference/output-transforms`
|
||||
- Outputs config reference:
|
||||
{doc}`../reference/outputs-config`
|
||||
- Output formats reference:
|
||||
{doc}`../reference/output-formats`
|
||||
- Output transforms reference:
|
||||
{doc}`../reference/output-transforms`
|
||||
|
||||
@ -4,7 +4,8 @@ Use this guide to compare detection outputs at different threshold values.
|
||||
|
||||
The goal is not to find a universal threshold.
|
||||
|
||||
The goal is to choose a threshold that fits your reviewed local data and the project trade-off between missed calls and false positives.
|
||||
The goal is to choose a threshold that fits your reviewed local data and the
|
||||
project trade-off between missed calls and false positives.
|
||||
|
||||
## 1) Start with a baseline run
|
||||
|
||||
@ -12,12 +13,12 @@ Run an initial prediction workflow and keep outputs in a dedicated folder.
|
||||
|
||||
## 2) Sweep threshold values
|
||||
|
||||
Run `predict` multiple times with different thresholds (for example `0.1`,
|
||||
Run `process` multiple times with different thresholds (for example `0.1`,
|
||||
`0.3`, `0.5`) and compare output counts and quality on the same validation
|
||||
subset.
|
||||
|
||||
```bash
|
||||
batdetect2 predict directory \
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs_thr_03 \
|
||||
@ -26,7 +27,8 @@ batdetect2 predict directory \
|
||||
|
||||
Keep each threshold run in a separate output directory.
|
||||
|
||||
That makes it easier to compare counts and inspect example files without mixing results.
|
||||
That makes it easier to compare counts and inspect example files without mixing
|
||||
results.
|
||||
|
||||
## 3) Validate against known calls
|
||||
|
||||
@ -38,7 +40,8 @@ Check both:
|
||||
- obvious false positives,
|
||||
- obvious missed calls.
|
||||
|
||||
If class interpretation matters downstream, inspect class ranking behavior as well, not just detection counts.
|
||||
If class interpretation matters downstream, inspect class ranking behavior as
|
||||
well, not just detection counts.
|
||||
|
||||
## 4) Record your chosen setting
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# How to tune inference clipping
|
||||
|
||||
Use this guide when long recordings need to be split into smaller clips during inference.
|
||||
Use this guide when long recordings need to be split into smaller clips during
|
||||
inference.
|
||||
|
||||
## What clipping controls
|
||||
|
||||
@ -8,14 +9,19 @@ Use this guide when long recordings need to be split into smaller clips during i
|
||||
|
||||
Key fields are:
|
||||
|
||||
- `duration`: clip duration in seconds,
|
||||
- `overlap`: overlap between adjacent clips,
|
||||
- `max_empty`: how much empty padding is allowed,
|
||||
- `discard_empty`: whether empty clips are dropped.
|
||||
- `duration`:
|
||||
clip duration in seconds,
|
||||
- `overlap`:
|
||||
overlap between adjacent clips,
|
||||
- `max_empty`:
|
||||
how much empty padding is allowed,
|
||||
- `discard_empty`:
|
||||
whether empty clips are dropped.
|
||||
|
||||
## Start from the defaults
|
||||
|
||||
Use the built-in clipping behavior first unless you already know you need something else.
|
||||
Use the built-in clipping behavior first unless you already know you need
|
||||
something else.
|
||||
|
||||
Only tune clipping when:
|
||||
|
||||
@ -25,7 +31,7 @@ Only tune clipping when:
|
||||
|
||||
## Override clipping with an inference config
|
||||
|
||||
Create an inference config file and pass it to `predict` or `evaluate`.
|
||||
Create an inference config file and pass it to `process` or `evaluate`.
|
||||
|
||||
Example:
|
||||
|
||||
@ -43,7 +49,7 @@ loader:
|
||||
Run with:
|
||||
|
||||
```bash
|
||||
batdetect2 predict directory \
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs \
|
||||
@ -52,12 +58,16 @@ batdetect2 predict directory \
|
||||
|
||||
## Validate clipping changes on a small reviewed subset
|
||||
|
||||
Changing clipping changes what the model sees per batch and can change how events near clip boundaries behave.
|
||||
Changing clipping changes what the model sees per batch and can change how
|
||||
events near clip boundaries behave.
|
||||
|
||||
Check a reviewed subset before applying clipping changes to a full project.
|
||||
|
||||
## Related pages
|
||||
|
||||
- Inference config reference: {doc}`../reference/inference-config`
|
||||
- Run batch predictions: {doc}`run-batch-predictions`
|
||||
- Understanding the pipeline: {doc}`../explanation/pipeline-overview`
|
||||
- Inference config reference:
|
||||
{doc}`../reference/inference-config`
|
||||
- Run batch predictions:
|
||||
{doc}`run-batch-predictions`
|
||||
- Understanding the pipeline:
|
||||
{doc}`../explanation/pipeline-overview`
|
||||
|
||||
@ -6,25 +6,20 @@ Welcome to the BatDetect2 documentation.
|
||||
|
||||
`batdetect2` detects bat echolocation calls in audio recordings.
|
||||
|
||||
It can help you screen large collections of recordings,
|
||||
find files that need expert review,
|
||||
and support ecology and conservation work where manual review alone would be slow.
|
||||
It can help you screen large collections of recordings, find files that need
|
||||
expert review, and support ecology and conservation work where manual review
|
||||
alone would be slow.
|
||||
|
||||
In practice,
|
||||
BatDetect2 takes recordings,
|
||||
looks for likely bat calls,
|
||||
draws a box around each detected event,
|
||||
and scores the most likely class for that event.
|
||||
In practice, BatDetect2 takes recordings, looks for likely bat calls, draws a
|
||||
box around each detected event, and scores the most likely class for that event.
|
||||
|
||||
The current default model is trained for 17 UK species.
|
||||
|
||||
The library also supports custom training,
|
||||
fine-tuning,
|
||||
evaluation,
|
||||
and more advanced use from Python.
|
||||
The library also supports custom training, fine-tuning, evaluation, and more
|
||||
advanced use from Python.
|
||||
|
||||
For details on the underlying approach, see the pre-print:
|
||||
[Towards a General Approach for Bat Echolocation Detection and Classification](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)
|
||||
[Towards a General Approach for Bat Echolocation Detection and Classification](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)
|
||||
|
||||
## A good first use for BatDetect2
|
||||
|
||||
@ -56,7 +51,7 @@ Always validate on reviewed local data before using results for ecological infer
|
||||
```{note}
|
||||
Looking for the previous BatDetect2 workflow?
|
||||
See {doc}`legacy/index`.
|
||||
The legacy docs are still available, but new workflows should use `batdetect2 predict` and `BatDetect2API`.
|
||||
The legacy docs are still available, but new workflows should use `batdetect2 process` and `BatDetect2API`.
|
||||
```
|
||||
|
||||
## How to use this site
|
||||
@ -65,8 +60,7 @@ Start with {doc}`getting_started` if you are new.
|
||||
|
||||
Then choose the section that matches what you need.
|
||||
|
||||
If you are here mainly to run the model on recordings,
|
||||
start with Tutorials.
|
||||
If you are here mainly to run the model on recordings, start with Tutorials.
|
||||
|
||||
| Section | Best for | Start here |
|
||||
| --- | --- | --- |
|
||||
@ -81,7 +75,7 @@ start with Tutorials.
|
||||
- GitHub repository:
|
||||
[macaodha/batdetect2](https://github.com/macaodha/batdetect2)
|
||||
- Questions, bug reports, and feature requests:
|
||||
[GitHub Issues](https://github.com/macaodha/batdetect2/issues)
|
||||
[GitHub Issues](https://github.com/macaodha/batdetect2/issues)
|
||||
- Common questions:
|
||||
{doc}`faq`
|
||||
- Want to contribute?
|
||||
|
||||
@ -4,7 +4,7 @@ This page documents the previous CLI workflow based on `batdetect2 detect`.
|
||||
|
||||
```{warning}
|
||||
This is legacy documentation.
|
||||
For new workflows, use `batdetect2 predict directory` instead.
|
||||
For new workflows, use `batdetect2 process directory` instead.
|
||||
If you are migrating, start with {doc}`migration-guide`.
|
||||
```
|
||||
|
||||
@ -27,7 +27,7 @@ Common legacy options included:
|
||||
The closest current CLI entry point is:
|
||||
|
||||
```bash
|
||||
batdetect2 predict directory \
|
||||
batdetect2 process directory \
|
||||
path/to/model.ckpt \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs
|
||||
@ -35,5 +35,7 @@ batdetect2 predict directory \
|
||||
|
||||
## Related pages
|
||||
|
||||
- Migration guide: {doc}`migration-guide`
|
||||
- Current predict docs: {doc}`../reference/cli/predict`
|
||||
- Migration guide:
|
||||
{doc}`migration-guide`
|
||||
- Current process docs:
|
||||
{doc}`../reference/cli/predict`
|
||||
|
||||
@ -2,12 +2,15 @@
|
||||
|
||||
This section documents the previous BatDetect2 workflow.
|
||||
|
||||
Use these pages if you need to keep working with the older `batdetect2 detect` command or the older `batdetect2.api` interface.
|
||||
Use these pages if you need to keep working with the older `batdetect2 detect`
|
||||
command or the older `batdetect2.api` interface.
|
||||
|
||||
For new projects, we recommend the current workflow:
|
||||
|
||||
- CLI: `batdetect2 predict`
|
||||
- Python: `batdetect2.api_v2.BatDetect2API`
|
||||
- CLI:
|
||||
`batdetect2 process`
|
||||
- Python:
|
||||
`batdetect2.api_v2.BatDetect2API`
|
||||
|
||||
If you are moving from the older workflow, start with {doc}`migration-guide`.
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
# Migration guide: legacy to current workflows
|
||||
|
||||
Use this guide when moving from the previous BatDetect2 workflow to the current CLI and API.
|
||||
Use this guide when moving from the previous BatDetect2 workflow to the current
|
||||
CLI and API.
|
||||
|
||||
## Who should migrate now
|
||||
|
||||
@ -9,31 +10,37 @@ You should migrate if:
|
||||
- you are starting a new workflow,
|
||||
- you want the current docs path,
|
||||
- you want the newer CLI and API surface,
|
||||
- you are maintaining code that does not depend on the exact legacy JSON or feature outputs.
|
||||
- you are maintaining code that does not depend on the exact legacy JSON or
|
||||
feature outputs.
|
||||
|
||||
You may need the legacy workflow a bit longer if:
|
||||
|
||||
- downstream tooling depends on the exact old output structure,
|
||||
- you rely on older notebooks built around `batdetect2.api`,
|
||||
- you depend on legacy feature extraction outputs without a validated replacement yet.
|
||||
- you depend on legacy feature extraction outputs without a validated
|
||||
replacement yet.
|
||||
|
||||
## CLI mapping
|
||||
|
||||
- `batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD`
|
||||
-> `batdetect2 predict directory MODEL_PATH AUDIO_DIR OUTPUT_PATH --detection-threshold ...`
|
||||
- `batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD` -> `batdetect2
|
||||
process directory MODEL_PATH AUDIO_DIR OUTPUT_PATH --detection-threshold ...`
|
||||
|
||||
Main changes:
|
||||
|
||||
- the model path is now a positional argument on the `predict` subcommand,
|
||||
- the current workflow expects an explicit checkpoint path rather than silently relying on the old default CLI behavior,
|
||||
- the model path is now a positional argument on the `process` subcommand,
|
||||
- the current workflow expects an explicit checkpoint path rather than silently
|
||||
relying on the old default CLI behavior,
|
||||
- output formatting is configurable,
|
||||
- threshold override is an option rather than a required positional argument,
|
||||
- there are separate subcommands for directory, file-list, and dataset-driven inference.
|
||||
- there are separate subcommands for directory, file-list, and dataset-driven
|
||||
inference.
|
||||
|
||||
## Python API mapping
|
||||
|
||||
- old: `import batdetect2.api as api`
|
||||
- current: `from batdetect2.api_v2 import BatDetect2API`
|
||||
- old:
|
||||
`import batdetect2.api as api`
|
||||
- current:
|
||||
`from batdetect2.api_v2 import BatDetect2API`
|
||||
|
||||
Typical migration shape:
|
||||
|
||||
@ -51,7 +58,7 @@ Useful replacements:
|
||||
- legacy `process_file` -> current `BatDetect2API.process_file`
|
||||
- legacy `process_audio` -> current `BatDetect2API.process_audio`
|
||||
- legacy `process_spectrogram` -> current `BatDetect2API.process_spectrogram`
|
||||
- legacy one-off batch loops -> current `process_files` or CLI `predict`
|
||||
- legacy one-off batch loops -> current `process_files` or CLI `process`
|
||||
|
||||
## Output and terminology changes
|
||||
|
||||
@ -78,7 +85,8 @@ Before replacing a legacy workflow in production or research analysis, validate:
|
||||
- that outputs are being saved in the right format,
|
||||
- that downstream code reads the new outputs correctly,
|
||||
- that feature-related assumptions still hold,
|
||||
- that evaluation and ecological interpretation are unchanged only where you have actually verified that.
|
||||
- that evaluation and ecological interpretation are unchanged only where you
|
||||
have actually verified that.
|
||||
|
||||
## Migration checklist
|
||||
|
||||
@ -91,6 +99,9 @@ Before replacing a legacy workflow in production or research analysis, validate:
|
||||
|
||||
## Related pages
|
||||
|
||||
- Current getting started: {doc}`../getting_started`
|
||||
- Current tutorials: {doc}`../tutorials/index`
|
||||
- Current API reference: {doc}`../reference/api`
|
||||
- Current getting started:
|
||||
{doc}`../getting_started`
|
||||
- Current tutorials:
|
||||
{doc}`../tutorials/index`
|
||||
- Current API reference:
|
||||
{doc}`../reference/api`
|
||||
|
||||
@ -1,65 +1,33 @@
|
||||
# `BatDetect2API` reference
|
||||
|
||||
`BatDetect2API` is the main entry point for the current Python workflow.
|
||||
`BatDetect2API` is the main Python entry point for BatDetect2.
|
||||
|
||||
It wraps model loading, inference, evaluation, output formatting, and
|
||||
training-related entry points behind one object.
|
||||
Use it when you want to load a model, run prediction, inspect detections,
|
||||
evaluate results, or train from Python.
|
||||
|
||||
Defined in `batdetect2.api_v2`.
|
||||
|
||||
## Create an API instance
|
||||
## Main ways to create it
|
||||
|
||||
- `BatDetect2API.from_checkpoint(path, ...)`
|
||||
- load a trained checkpoint and optional config overrides.
|
||||
- load a trained checkpoint, a bundled checkpoint alias, or a Hugging Face
|
||||
checkpoint.
|
||||
- `BatDetect2API.from_config(model_config=..., targets_config=..., ...)`
|
||||
- build a full stack from separate config objects.
|
||||
- build a full model stack from config objects.
|
||||
|
||||
## Inference methods
|
||||
## Common tasks
|
||||
|
||||
- `process_file(audio_file, ...)`
|
||||
- run inference for one recording.
|
||||
- `process_files(audio_files, ...)`
|
||||
- run batch inference across a sequence of file paths.
|
||||
- `process_directory(audio_dir, ...)`
|
||||
- run inference across the audio files found in one directory.
|
||||
- `process_clips(clips, ...)`
|
||||
- run inference on an explicit sequence of clip objects.
|
||||
- `process_audio(audio, ...)`
|
||||
- run inference starting from a waveform array.
|
||||
- `process_spectrogram(spec, ...)`
|
||||
- run inference starting from a spectrogram tensor.
|
||||
- Load a checkpoint and run prediction on one file.
|
||||
- Run prediction on many files or clips.
|
||||
- Save predictions in one of the supported output formats.
|
||||
- Evaluate a model on labelled data.
|
||||
- Fine-tune an existing checkpoint on new targets.
|
||||
|
||||
## Prediction inspection helpers
|
||||
## Generated reference
|
||||
|
||||
- `get_top_class_name(detection)`
|
||||
- return the highest-scoring class name for one detection.
|
||||
- `get_class_scores(detection, include_top_class=True, sort_descending=True)`
|
||||
- return ranked `(class_name, score)` pairs.
|
||||
- `get_detection_features(detection)`
|
||||
- return the per-detection feature vector.
|
||||
|
||||
## Audio loading helpers
|
||||
|
||||
- `load_audio(path)`
|
||||
- `load_recording(recording)`
|
||||
- `load_clip(clip)`
|
||||
- `generate_spectrogram(audio)`
|
||||
|
||||
## Output persistence helpers
|
||||
|
||||
- `save_predictions(predictions, path, audio_dir=None, format=None,
|
||||
config=None)`
|
||||
- `load_predictions(path, format=None, config=None)`
|
||||
|
||||
Use these when you want to save programmatic predictions without going through
|
||||
the CLI.
|
||||
|
||||
## Training and evaluation entry points
|
||||
|
||||
- `train(...)`
|
||||
- `finetune(...)`
|
||||
- `evaluate(...)`
|
||||
- `evaluate_predictions(...)`
|
||||
```{eval-rst}
|
||||
.. autoclass:: batdetect2.api_v2.BatDetect2API
|
||||
```
|
||||
|
||||
## Related pages
|
||||
|
||||
|
||||
@ -4,13 +4,13 @@ Legacy detect command
|
||||
.. warning::
|
||||
|
||||
``batdetect2 detect`` is a legacy compatibility command.
|
||||
Prefer ``batdetect2 predict directory`` for new workflows.
|
||||
Prefer ``batdetect2 process directory`` for new workflows.
|
||||
|
||||
Migration at a glance
|
||||
---------------------
|
||||
|
||||
- Legacy: ``batdetect2 detect AUDIO_DIR ANN_DIR DETECTION_THRESHOLD``
|
||||
- Current: ``batdetect2 predict directory MODEL_PATH AUDIO_DIR OUTPUT_PATH``
|
||||
- Current: ``batdetect2 process directory MODEL_PATH AUDIO_DIR OUTPUT_PATH``
|
||||
with optional ``--detection-threshold``
|
||||
|
||||
.. click:: batdetect2.cli.compat:detect
|
||||
|
||||
@ -1,7 +1,10 @@
|
||||
Evaluate command
|
||||
================
|
||||
|
||||
Evaluate a checkpoint against a configured test dataset.
|
||||
Use ``batdetect2 evaluate`` to compare a checkpoint against labelled test data.
|
||||
|
||||
This command writes metrics and any configured artifacts to the output
|
||||
directory.
|
||||
|
||||
.. click:: batdetect2.cli.evaluate:evaluate_command
|
||||
:prog: batdetect2 evaluate
|
||||
|
||||
11
docs/source/reference/cli/finetune.rst
Normal file
11
docs/source/reference/cli/finetune.rst
Normal file
@ -0,0 +1,11 @@
|
||||
Finetune command
|
||||
================
|
||||
|
||||
Use ``batdetect2 finetune`` to adapt an existing checkpoint to a new target
|
||||
definition.
|
||||
|
||||
If you do not pass ``--model``, the bundled ``uk_same`` checkpoint is used.
|
||||
|
||||
.. click:: batdetect2.cli.finetune:finetune_command
|
||||
:prog: batdetect2 finetune
|
||||
:nested: none
|
||||
@ -1,35 +1,33 @@
|
||||
# CLI reference
|
||||
|
||||
Use this section to find the right command quickly, then open the command page
|
||||
for full options and argument details.
|
||||
|
||||
## How to use this section
|
||||
|
||||
1. Start with {doc}`base` for options shared across the CLI.
|
||||
2. Pick the command group or command you need from the command map below.
|
||||
3. Open the linked page for complete autogenerated option reference.
|
||||
for the full option list.
|
||||
|
||||
## Command map
|
||||
|
||||
| Command | Use it for | Required positional args |
|
||||
| --- | --- | --- |
|
||||
| `batdetect2 predict` | Run inference on audio | Depends on subcommand (`directory`, `file_list`, `dataset`) |
|
||||
| `batdetect2 process` | Run inference on audio | Depends on subcommand (`directory`, `file_list`, `dataset`) |
|
||||
| `batdetect2 data` | Inspect and convert dataset configs | Depends on subcommand (`summary`, `convert`) |
|
||||
| `batdetect2 train` | Train or fine-tune models | `TRAIN_DATASET` |
|
||||
| `batdetect2 evaluate` | Evaluate a checkpoint on a test dataset | `MODEL_PATH`, `TEST_DATASET` |
|
||||
| `batdetect2 finetune` | Fine-tune a checkpoint on new targets | `TRAIN_DATASET` plus `--targets` |
|
||||
| `batdetect2 evaluate` | Evaluate a checkpoint on a test dataset | `TEST_DATASET` |
|
||||
| `batdetect2 detect` | Legacy compatibility workflow | `AUDIO_DIR`, `ANN_DIR`, `DETECTION_THRESHOLD` |
|
||||
|
||||
## Global options and conventions
|
||||
## Notes
|
||||
|
||||
- Global CLI options are documented in {doc}`base`.
|
||||
- Paths with spaces should be wrapped in quotes.
|
||||
- Input audio is expected to be mono.
|
||||
- Legacy `detect` uses a required threshold argument, while `predict` uses the
|
||||
optional `--detection-threshold` override.
|
||||
- `process` uses the optional `--detection-threshold` override.
|
||||
- `evaluate` takes `TEST_DATASET` as a positional argument and uses `--model`
|
||||
for the checkpoint override.
|
||||
- `finetune` defaults to the bundled `uk_same` checkpoint if `--model` is not
|
||||
provided.
|
||||
|
||||
```{warning}
|
||||
`batdetect2 detect` is a legacy command.
|
||||
Prefer `batdetect2 predict directory` for new workflows.
|
||||
Prefer `batdetect2 process directory` for new workflows.
|
||||
```
|
||||
|
||||
## Related pages
|
||||
@ -43,9 +41,10 @@ Prefer `batdetect2 predict directory` for new workflows.
|
||||
:maxdepth: 1
|
||||
|
||||
Base command and global options <base>
|
||||
Predict command group <predict>
|
||||
Process command group <predict>
|
||||
Data command group <data>
|
||||
Train command <train>
|
||||
Finetune command <finetune>
|
||||
Evaluate command <evaluate>
|
||||
Legacy detect command <detect_legacy>
|
||||
```
|
||||
|
||||
@ -1,9 +1,17 @@
|
||||
Predict command
|
||||
Process command
|
||||
===============
|
||||
|
||||
Run model inference from a directory, a file list, or a dataset.
|
||||
Use ``--detection-threshold`` to override the model default per run.
|
||||
Use ``batdetect2 process`` to run inference on audio.
|
||||
|
||||
.. click:: batdetect2.cli.inference:predict
|
||||
:prog: batdetect2 predict
|
||||
Choose a subcommand based on how you want to provide the input:
|
||||
|
||||
- ``directory`` for all supported audio files in one folder
|
||||
- ``file_list`` for a text file with one audio path per line
|
||||
- ``dataset`` for recordings referenced by a dataset file
|
||||
|
||||
Use ``--detection-threshold`` when you want to override the configured
|
||||
threshold for one run.
|
||||
|
||||
.. click:: batdetect2.cli.inference:process
|
||||
:prog: batdetect2 process
|
||||
:nested: full
|
||||
|
||||
@ -1,7 +1,11 @@
|
||||
Train command
|
||||
=============
|
||||
|
||||
Train a model from dataset configs or fine-tune from a checkpoint.
|
||||
Use ``batdetect2 train`` to start from a fresh model config or continue from an
|
||||
existing checkpoint.
|
||||
|
||||
If you want to adapt an existing checkpoint to a new target definition, use
|
||||
``batdetect2 finetune`` instead.
|
||||
|
||||
.. click:: batdetect2.cli.train:train_command
|
||||
:prog: batdetect2 train
|
||||
|
||||
@ -3,7 +3,8 @@
|
||||
This tutorial shows how to evaluate a trained checkpoint on a held-out dataset
|
||||
and inspect the output metrics.
|
||||
|
||||
This tutorial is for advanced users who want to compare one trained model against a separate test dataset.
|
||||
This tutorial is for advanced users who want to compare one trained model
|
||||
against a separate test dataset.
|
||||
|
||||
## Before you start
|
||||
|
||||
@ -32,22 +33,22 @@ Use a dataset that was not used for training or tuning.
|
||||
|
||||
A held-out dataset is simply a separate dataset kept aside for evaluation.
|
||||
|
||||
If you tune thresholds or configs on the same dataset that you report as final evaluation, the results will be optimistic.
|
||||
If you tune thresholds or configs on the same dataset that you report as final
|
||||
evaluation, the results will be optimistic.
|
||||
|
||||
## 2. Run evaluation
|
||||
|
||||
```bash
|
||||
batdetect2 evaluate \
|
||||
path/to/model.ckpt \
|
||||
path/to/test_dataset.yaml \
|
||||
--model path/to/model.ckpt \
|
||||
--base-dir path/to/project_root \
|
||||
--output-dir path/to/eval_outputs
|
||||
```
|
||||
|
||||
This command loads the checkpoint,
|
||||
runs prediction on the test dataset,
|
||||
applies the chosen evaluation tasks,
|
||||
and writes metrics and result files to the output directory.
|
||||
This command loads the checkpoint, runs prediction on the test dataset, applies
|
||||
the chosen evaluation tasks, and writes metrics and result files to the output
|
||||
directory.
|
||||
|
||||
Use `--base-dir` whenever the dataset config contains relative paths.
|
||||
|
||||
@ -73,7 +74,8 @@ Check:
|
||||
- which task the metric belongs to,
|
||||
- which thresholding or matching assumptions were used,
|
||||
- whether class-level behavior matches your use case,
|
||||
- whether the failures are concentrated in specific taxa, sites, or recording conditions.
|
||||
- whether the failures are concentrated in specific taxa, sites, or recording
|
||||
conditions.
|
||||
|
||||
## 5. Record the evaluation setup
|
||||
|
||||
@ -85,7 +87,11 @@ That matters for reproducibility and for later model comparisons.
|
||||
|
||||
- Compare thresholds on representative files:
|
||||
{doc}`../how_to/tune-detection-threshold`
|
||||
- Configure evaluation tasks: {doc}`../how_to/choose-and-configure-evaluation-tasks`
|
||||
- Interpret evaluation artifacts: {doc}`../how_to/interpret-evaluation-outputs`
|
||||
- Learn the evaluation concepts: {doc}`../explanation/evaluation-concepts-and-matching`
|
||||
- Check full evaluate options: {doc}`../reference/cli/evaluate`
|
||||
- Configure evaluation tasks:
|
||||
{doc}`../how_to/choose-and-configure-evaluation-tasks`
|
||||
- Interpret evaluation artifacts:
|
||||
{doc}`../how_to/interpret-evaluation-outputs`
|
||||
- Learn the evaluation concepts:
|
||||
{doc}`../explanation/evaluation-concepts-and-matching`
|
||||
- Check full evaluate options:
|
||||
{doc}`../reference/cli/evaluate`
|
||||
|
||||
@ -4,7 +4,8 @@ This tutorial walks through a first end-to-end inference run with the CLI.
|
||||
|
||||
It is the default starting point for new users.
|
||||
|
||||
Use it when you want to run an existing model on a folder of recordings and quickly check what BatDetect2 found.
|
||||
Use it when you want to run an existing model on a folder of recordings and
|
||||
quickly check what BatDetect2 found.
|
||||
|
||||
## Before you start
|
||||
|
||||
@ -24,7 +25,7 @@ src/batdetect2/models/checkpoints/Net2DFast_UK_same.pth.tar
|
||||
|
||||
By the end of this tutorial you will have:
|
||||
|
||||
- run `batdetect2 predict directory`,
|
||||
- run `batdetect2 process directory`,
|
||||
- saved predictions to disk,
|
||||
- checked that BatDetect2 wrote output files,
|
||||
- identified the next pages to use for tuning or customization.
|
||||
@ -48,12 +49,13 @@ project/
|
||||
outputs/
|
||||
```
|
||||
|
||||
## 2. Run prediction on the directory
|
||||
## 2. Run processing on the directory
|
||||
|
||||
Use this command when you want BatDetect2 to scan a folder of recordings automatically.
|
||||
Use this command when you want BatDetect2 to scan a folder of recordings
|
||||
automatically.
|
||||
|
||||
```bash
|
||||
batdetect2 predict directory \
|
||||
batdetect2 process directory \
|
||||
path/to/model.pth.tar \
|
||||
path/to/audio_dir \
|
||||
path/to/outputs
|
||||
@ -70,8 +72,7 @@ What this does:
|
||||
|
||||
After the command completes, inspect the output directory.
|
||||
|
||||
For a first run,
|
||||
the important check is simple:
|
||||
For a first run, the important check is simple:
|
||||
|
||||
- did BatDetect2 create result files,
|
||||
- are they in the output directory you expected,
|
||||
@ -81,8 +82,8 @@ Different workflows can save results in different file formats.
|
||||
|
||||
You do not need to learn those details for the first run.
|
||||
|
||||
If you later need to choose a specific output format,
|
||||
go to {doc}`../how_to/save-predictions-in-different-output-formats`.
|
||||
If you later need to choose a specific output format, go to
|
||||
{doc}`../how_to/save-predictions-in-different-output-formats`.
|
||||
|
||||
## 4. Inspect predictions
|
||||
|
||||
@ -103,13 +104,17 @@ Validation comes next.
|
||||
|
||||
## 5. Tune only after you have a baseline
|
||||
|
||||
If the first run is too noisy or misses obvious calls, tune thresholds on a reviewed subset rather than changing settings blindly across the full dataset.
|
||||
If the first run is too noisy or misses obvious calls, tune thresholds on a
|
||||
reviewed subset rather than changing settings blindly across the full dataset.
|
||||
|
||||
Use {doc}`../how_to/tune-detection-threshold` for that process.
|
||||
|
||||
## What to do next
|
||||
|
||||
- If you need a different input mode, use {doc}`../how_to/choose-an-inference-input-mode`.
|
||||
- If you want to tune sensitivity, use {doc}`../how_to/tune-detection-threshold`.
|
||||
- If you already write code and want more control from Python, use {doc}`integrate-with-a-python-pipeline`.
|
||||
- If you need a different input mode, use
|
||||
{doc}`../how_to/choose-an-inference-input-mode`.
|
||||
- If you want to tune sensitivity, use
|
||||
{doc}`../how_to/tune-detection-threshold`.
|
||||
- If you already write code and want more control from Python, use
|
||||
{doc}`integrate-with-a-python-pipeline`.
|
||||
- If you need full command details, use {doc}`../reference/cli/predict`.
|
||||
|
||||
26
justfile
26
justfile
@ -17,6 +17,10 @@ help:
|
||||
install:
|
||||
uv sync
|
||||
|
||||
# Install full development dependencies for CI and docs builds.
|
||||
install-dev:
|
||||
uv sync --all-extras --dev
|
||||
|
||||
# Testing & Coverage
|
||||
# Run tests using pytest.
|
||||
test:
|
||||
@ -50,6 +54,9 @@ coverage-serve: coverage-html
|
||||
docs:
|
||||
uv run sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
|
||||
|
||||
# Check that documentation builds successfully.
|
||||
check-docs: docs
|
||||
|
||||
# Serve documentation with live reload.
|
||||
docs-serve:
|
||||
uv run sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
|
||||
@ -84,6 +91,25 @@ check-types:
|
||||
# Run all checks (format-check, lint, typecheck).
|
||||
check: check-format check-lint check-types
|
||||
|
||||
# Run the standard CI validation sequence.
|
||||
ci: check test
|
||||
|
||||
# Build source and wheel distributions.
|
||||
build-dist:
|
||||
uv run --with build python -m build
|
||||
|
||||
# Bump the patch version, commit, and tag.
|
||||
bump-patch:
|
||||
uvx bump2version patch
|
||||
|
||||
# Bump the minor version, commit, and tag.
|
||||
bump-minor:
|
||||
uvx bump2version minor
|
||||
|
||||
# Bump the major version, commit, and tag.
|
||||
bump-major:
|
||||
uvx bump2version major
|
||||
|
||||
# Cleaning tasks
|
||||
# Remove Python bytecode and cache.
|
||||
clean-pyc:
|
||||
|
||||
@ -7,7 +7,6 @@ authors = [
|
||||
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
|
||||
]
|
||||
dependencies = [
|
||||
"cf-xarray>=0.9.0",
|
||||
"click>=8.1.7",
|
||||
"deepmerge>=2.0",
|
||||
"hydra-core>=1.3.2",
|
||||
@ -16,21 +15,19 @@ dependencies = [
|
||||
"loguru>=0.7.3",
|
||||
"matplotlib>=3.7.1",
|
||||
"netcdf4>=1.6.5",
|
||||
"numba>=0.60",
|
||||
"numpy>=1.23.5",
|
||||
"omegaconf>=2.3.0",
|
||||
"onnx>=1.16.0",
|
||||
"pandas>=1.5.3",
|
||||
"pydantic>=2.0.0",
|
||||
"pyyaml>=6.0.2",
|
||||
"scikit-learn>=1.2.2",
|
||||
"scipy>=1.10.1",
|
||||
"seaborn>=0.13.2",
|
||||
"soundevent[audio,geometry,plot]>=2.10.0",
|
||||
"soundfile>=0.12.1",
|
||||
"tensorboard>=2.16.2",
|
||||
"torch>=1.13.1",
|
||||
"torchaudio>=1.13.1",
|
||||
"torchvision>=0.14.0",
|
||||
"tqdm>=4.66.2",
|
||||
"xarray>=2024.0.0",
|
||||
]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
readme = "README.md"
|
||||
@ -66,6 +63,7 @@ build-backend = "hatchling.build"
|
||||
batdetect2 = "batdetect2.cli:cli"
|
||||
|
||||
[dependency-groups]
|
||||
huggingface = ["huggingface-hub>=0.32.0"]
|
||||
jupyter = ["ipywidgets>=8.1.5", "jupyter>=1.1.1"]
|
||||
marimo = ["marimo>=0.12.2", "pyarrow>=20.0.0"]
|
||||
dev = [
|
||||
|
||||
@ -1,11 +1,25 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from loguru import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
|
||||
logger.disable("batdetect2")
|
||||
|
||||
|
||||
numba_logger = logging.getLogger("numba")
|
||||
numba_logger.setLevel(logging.WARNING)
|
||||
|
||||
__all__ = ["BatDetect2API", "__version__"]
|
||||
__version__ = "1.1.1"
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
if name == "BatDetect2API":
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
|
||||
return BatDetect2API
|
||||
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
@ -3,13 +3,12 @@ from __future__ import annotations
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import AudioConfig, AudioLoader
|
||||
from batdetect2.data import Dataset
|
||||
@ -20,7 +19,8 @@ if TYPE_CHECKING:
|
||||
LoggerConfig,
|
||||
LoggingCallback,
|
||||
)
|
||||
from batdetect2.models import Model, ModelConfig
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.models.types import ModelProtocol
|
||||
from batdetect2.outputs import (
|
||||
OutputFormatConfig,
|
||||
OutputFormatterProtocol,
|
||||
@ -48,6 +48,31 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||
|
||||
|
||||
class BatDetect2API:
|
||||
"""High-level interface for the BatDetect2 workflow.
|
||||
|
||||
Use this to load a model, run inference, inspect detections,
|
||||
evaluate predictions, and train or fine-tune models.
|
||||
|
||||
In most cases, start with :meth:`from_checkpoint` to load a trained model.
|
||||
Use :meth:`from_config` when you want to build a new model with custom
|
||||
configs.
|
||||
|
||||
Examples
|
||||
--------
|
||||
Load the default checkpoint and run prediction on one file.
|
||||
|
||||
>>> from batdetect2.api_v2 import BatDetect2API
|
||||
>>> api = BatDetect2API.from_checkpoint()
|
||||
>>> prediction = api.process_file("recording.wav")
|
||||
|
||||
Load a checkpoint and save predictions for a folder of audio.
|
||||
|
||||
>>> from pathlib import Path
|
||||
>>> api = BatDetect2API.from_checkpoint("uk_same")
|
||||
>>> predictions = api.process_directory("audio")
|
||||
>>> api.save_predictions(predictions, "outputs/")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
@ -65,8 +90,49 @@ class BatDetect2API:
|
||||
evaluator: EvaluatorProtocol,
|
||||
formatter: OutputFormatterProtocol,
|
||||
output_transform: OutputTransformProtocol,
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
):
|
||||
"""Create a fully configured API instance.
|
||||
|
||||
This initializer is mainly for internal use.
|
||||
In most cases, users should create the API with
|
||||
:meth:`from_checkpoint` or :meth:`from_config`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_config : ModelConfig
|
||||
Model configuration.
|
||||
audio_config : AudioConfig
|
||||
Audio loading configuration.
|
||||
train_config : TrainingConfig
|
||||
Training configuration.
|
||||
evaluation_config : EvaluationConfig
|
||||
Evaluation configuration.
|
||||
inference_config : InferenceConfig
|
||||
Inference configuration.
|
||||
outputs_config : OutputsConfig
|
||||
Output formatting configuration.
|
||||
logging_config : AppLoggingConfig
|
||||
Logging configuration.
|
||||
targets : TargetProtocol
|
||||
Target definition used by the model.
|
||||
roi_mapper : ROIMapperProtocol
|
||||
ROI mapping used for size targets.
|
||||
audio_loader : AudioLoader
|
||||
Audio loader.
|
||||
preprocessor : PreprocessorProtocol
|
||||
Preprocessor used before the detector.
|
||||
postprocessor : PostprocessorProtocol
|
||||
Postprocessor used after the detector.
|
||||
evaluator : EvaluatorProtocol
|
||||
Evaluator used for metrics.
|
||||
formatter : OutputFormatterProtocol
|
||||
Default formatter used to save predictions.
|
||||
output_transform : OutputTransformProtocol
|
||||
Transform that converts model outputs into detections.
|
||||
model : ModelProtocol
|
||||
Model instance.
|
||||
"""
|
||||
self.model_config = model_config
|
||||
self.audio_config = audio_config
|
||||
self.train_config = train_config
|
||||
@ -91,6 +157,21 @@ class BatDetect2API:
|
||||
path: data.PathLike,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> Dataset:
|
||||
"""Load a set of annotations from a dataset config file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the dataset config file.
|
||||
base_dir : data.PathLike | None, optional
|
||||
Base directory used to resolve relative paths in the dataset
|
||||
config.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dataset
|
||||
Loaded dataset of annotations.
|
||||
"""
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
|
||||
return load_dataset_from_config(path, base_dir=base_dir)
|
||||
@ -107,12 +188,50 @@ class BatDetect2API:
|
||||
num_epochs: int | None = None,
|
||||
run_name: str | None = None,
|
||||
seed: int | None = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
||||
):
|
||||
"""Train the current model on a set of annotations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
train_annotations : Sequence[data.ClipAnnotation]
|
||||
Training annotations.
|
||||
val_annotations : Sequence[data.ClipAnnotation] | None, optional
|
||||
Validation annotations. If omitted, training runs without a
|
||||
validation set.
|
||||
train_workers : int, optional
|
||||
Number of worker processes for training data loading.
|
||||
val_workers : int, optional
|
||||
Number of worker processes for validation data loading.
|
||||
checkpoint_dir : Path | None, optional
|
||||
Directory where checkpoints are saved.
|
||||
log_dir : Path | None, optional
|
||||
Directory where logs are written.
|
||||
experiment_name : str | None, optional
|
||||
Experiment name used by the configured logger.
|
||||
num_epochs : int | None, optional
|
||||
Maximum number of training epochs.
|
||||
run_name : str | None, optional
|
||||
Run name used by the configured logger.
|
||||
seed : int | None, optional
|
||||
Random seed for reproducibility.
|
||||
audio_config : AudioConfig | None, optional
|
||||
Audio config override.
|
||||
train_config : TrainingConfig | None, optional
|
||||
Training config override.
|
||||
logger_config : LoggerConfig | None, optional
|
||||
Training logger config override.
|
||||
logging_callbacks : Sequence[LoggingCallback[TrainLoggingContext]], optional
|
||||
Extra logging callbacks to run during training setup.
|
||||
|
||||
Returns
|
||||
-------
|
||||
BatDetect2API
|
||||
This API instance with the trained model.
|
||||
"""
|
||||
from batdetect2.train import run_train
|
||||
|
||||
self.model.train()
|
||||
@ -122,7 +241,6 @@ class BatDetect2API:
|
||||
model=self.model,
|
||||
targets=self.targets,
|
||||
roi_mapper=self.roi_mapper,
|
||||
model_config=model_config or self.model_config,
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
train_workers=train_workers,
|
||||
@ -147,7 +265,7 @@ class BatDetect2API:
|
||||
targets_config: TargetConfig,
|
||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||
trainable: Literal[
|
||||
"all", "heads", "classifier_head", "bbox_head"
|
||||
"all", "heads", "classifier_head", "size_head"
|
||||
] = "heads",
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
@ -162,7 +280,52 @@ class BatDetect2API:
|
||||
logger_config: LoggerConfig | None = None,
|
||||
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
||||
) -> "BatDetect2API":
|
||||
"""Fine-tune from a checkpoint using a new target definition."""
|
||||
"""Fine-tune the current model for new target sounds.
|
||||
|
||||
Use this when you want to keep the existing model weights but change
|
||||
the target sounds. You can fine-tune the whole model or just the
|
||||
heads.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
train_annotations : Sequence[data.ClipAnnotation]
|
||||
Training annotations.
|
||||
targets_config : TargetConfig
|
||||
Target definition to train against.
|
||||
val_annotations : Sequence[data.ClipAnnotation] | None, optional
|
||||
Validation annotations.
|
||||
trainable : {"all", "heads", "classifier_head", "size_head"}, optional
|
||||
Which model parameters remain trainable.
|
||||
train_workers : int, optional
|
||||
Number of worker processes for training data loading.
|
||||
val_workers : int, optional
|
||||
Number of worker processes for validation data loading.
|
||||
checkpoint_dir : Path | None, optional
|
||||
Directory where checkpoints are saved.
|
||||
log_dir : Path | None, optional
|
||||
Directory where logs are written.
|
||||
experiment_name : str | None, optional
|
||||
Experiment name used by the configured logger.
|
||||
num_epochs : int | None, optional
|
||||
Maximum number of training epochs.
|
||||
run_name : str | None, optional
|
||||
Run name used by the configured logger.
|
||||
seed : int | None, optional
|
||||
Random seed for reproducibility.
|
||||
audio_config : AudioConfig | None, optional
|
||||
Audio config override.
|
||||
train_config : TrainingConfig | None, optional
|
||||
Training config override.
|
||||
logger_config : LoggerConfig | None, optional
|
||||
Training logger config override.
|
||||
logging_callbacks : Sequence[LoggingCallback[TrainLoggingContext]], optional
|
||||
Extra logging callbacks to run during training setup.
|
||||
|
||||
Returns
|
||||
-------
|
||||
BatDetect2API
|
||||
A new API instance configured for the new targets.
|
||||
"""
|
||||
from batdetect2.evaluate import build_evaluator
|
||||
from batdetect2.models import build_model_with_new_targets
|
||||
from batdetect2.outputs import (
|
||||
@ -225,7 +388,6 @@ class BatDetect2API:
|
||||
model=api.model,
|
||||
targets=api.targets,
|
||||
roi_mapper=api.roi_mapper,
|
||||
model_config=api.model_config,
|
||||
preprocessor=api.preprocessor,
|
||||
audio_loader=api.audio_loader,
|
||||
train_workers=train_workers,
|
||||
@ -257,6 +419,36 @@ class BatDetect2API:
|
||||
outputs_config: OutputsConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
) -> tuple[dict[str, float], list[ClipDetections]]:
|
||||
"""Evaluate the current model on a labelled dataset.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
test_annotations : Sequence[data.ClipAnnotation]
|
||||
Labelled clips used for evaluation.
|
||||
num_workers : int, optional
|
||||
Number of worker processes for dataset loading.
|
||||
output_dir : data.PathLike, optional
|
||||
Directory where metrics and plots are written.
|
||||
experiment_name : str | None, optional
|
||||
Experiment name used by the configured logger.
|
||||
run_name : str | None, optional
|
||||
Run name used by the configured logger.
|
||||
save_predictions : bool, optional
|
||||
If ``True``, save formatted predictions alongside metrics.
|
||||
audio_config : AudioConfig | None, optional
|
||||
Audio config override.
|
||||
evaluation_config : EvaluationConfig | None, optional
|
||||
Evaluation config override.
|
||||
outputs_config : OutputsConfig | None, optional
|
||||
Output config override.
|
||||
logger_config : LoggerConfig | None, optional
|
||||
Evaluation logger config override.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[dict[str, float], list[ClipDetections]]
|
||||
Evaluation metrics and per-clip predictions.
|
||||
"""
|
||||
from batdetect2.evaluate import run_evaluate
|
||||
|
||||
return run_evaluate(
|
||||
@ -283,6 +475,22 @@ class BatDetect2API:
|
||||
predictions: Sequence[ClipDetections],
|
||||
output_dir: data.PathLike | None = None,
|
||||
):
|
||||
"""Evaluate an existing set of predictions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
annotations : Sequence[data.ClipAnnotation]
|
||||
Reference annotations.
|
||||
predictions : Sequence[ClipDetections]
|
||||
Predictions to compare against the annotations.
|
||||
output_dir : data.PathLike | None, optional
|
||||
Directory where metrics and plots are written.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, float]
|
||||
Computed evaluation metrics.
|
||||
"""
|
||||
from batdetect2.evaluate import save_evaluation_results
|
||||
|
||||
clip_evals = self.evaluator.evaluate(
|
||||
@ -302,16 +510,65 @@ class BatDetect2API:
|
||||
return metrics
|
||||
|
||||
def load_audio(self, path: data.PathLike) -> np.ndarray:
|
||||
"""Load one audio file into a waveform array.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the audio file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Audio waveform loaded from disk.
|
||||
"""
|
||||
return self.audio_loader.load_file(path)
|
||||
|
||||
def load_recording(self, recording: data.Recording) -> np.ndarray:
|
||||
"""Load one recording object into a waveform array.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recording : data.Recording
|
||||
Recording object describing the audio to load.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Audio waveform for the requested recording.
|
||||
"""
|
||||
return self.audio_loader.load_recording(recording)
|
||||
|
||||
def load_clip(self, clip: data.Clip) -> np.ndarray:
|
||||
"""Load one clip object into a waveform array.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clip : data.Clip
|
||||
Clip object describing the section of audio to load.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Audio waveform for the requested clip.
|
||||
"""
|
||||
return self.audio_loader.load_clip(clip)
|
||||
|
||||
def get_top_class_name(self, detection: Detection) -> str:
|
||||
"""Get highest-confidence class name for one detection."""
|
||||
"""Get the name of the highest-confidence class for one detection.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection : Detection
|
||||
Detection whose class scores will be inspected.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Class name with the highest score.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
top_index = int(np.argmax(detection.class_scores))
|
||||
return self.targets.class_names[top_index]
|
||||
@ -323,7 +580,22 @@ class BatDetect2API:
|
||||
include_top_class: bool = True,
|
||||
sort_descending: bool = True,
|
||||
) -> list[tuple[str, float]]:
|
||||
"""Get class score list as ``(class_name, score)`` pairs."""
|
||||
"""Get class scores as ``(class_name, score)`` pairs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection : Detection
|
||||
Detection whose class scores will be returned.
|
||||
include_top_class : bool, optional
|
||||
If ``False``, omit the highest-scoring class from the result.
|
||||
sort_descending : bool, optional
|
||||
If ``True``, sort scores from highest to lowest.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[tuple[str, float]]
|
||||
Class-score pairs for the detection.
|
||||
"""
|
||||
|
||||
scores = [
|
||||
(class_name, float(score))
|
||||
@ -347,16 +619,22 @@ class BatDetect2API:
|
||||
if class_name != top_class_name
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_detection_features(detection: Detection) -> np.ndarray:
|
||||
"""Get extracted feature vector for one detection."""
|
||||
|
||||
return detection.features
|
||||
|
||||
def generate_spectrogram(
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
) -> torch.Tensor:
|
||||
"""Convert a waveform array into a spectrogram tensor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio : np.ndarray
|
||||
Audio waveform.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Spectrogram tensor ready for model inference.
|
||||
"""
|
||||
import torch
|
||||
|
||||
tensor = torch.tensor(audio).unsqueeze(0)
|
||||
@ -368,6 +646,25 @@ class BatDetect2API:
|
||||
batch_size: int | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
) -> ClipDetections:
|
||||
"""Run inference on one audio file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_file : data.PathLike
|
||||
Path to the audio file.
|
||||
batch_size : int | None, optional
|
||||
Batch size override. If omitted, the inference config value is
|
||||
used.
|
||||
detection_threshold : float | None, optional
|
||||
Detection score threshold override.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ClipDetections
|
||||
Predictions for the full recording.
|
||||
"""
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.postprocess import ClipDetections
|
||||
|
||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||
@ -402,6 +699,20 @@ class BatDetect2API:
|
||||
audio: np.ndarray,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[Detection]:
|
||||
"""Run inference on a waveform array.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio : np.ndarray
|
||||
Audio waveform.
|
||||
detection_threshold : float | None, optional
|
||||
Detection score threshold override.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[Detection]
|
||||
Detected calls.
|
||||
"""
|
||||
spec = self.generate_spectrogram(audio)
|
||||
return self.process_spectrogram(
|
||||
spec,
|
||||
@ -414,6 +725,27 @@ class BatDetect2API:
|
||||
start_time: float = 0,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[Detection]:
|
||||
"""Run inference on one spectrogram tensor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Spectrogram tensor for one recording or clip.
|
||||
start_time : float, optional
|
||||
Start time in seconds used when creating detections.
|
||||
detection_threshold : float | None, optional
|
||||
Detection score threshold override.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[Detection]
|
||||
Detected calls.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If a batched spectrogram with more than one item is provided.
|
||||
"""
|
||||
if spec.ndim == 4 and spec.shape[0] > 1:
|
||||
raise ValueError("Batched spectrograms not supported.")
|
||||
|
||||
@ -436,6 +768,20 @@ class BatDetect2API:
|
||||
audio_dir: data.PathLike,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
"""Run inference on all supported audio files in a directory.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_dir : data.PathLike
|
||||
Directory containing audio files.
|
||||
detection_threshold : float | None, optional
|
||||
Detection score threshold override.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ClipDetections]
|
||||
Predictions for all supported audio files found in the directory.
|
||||
"""
|
||||
from soundevent.audio.files import get_audio_files
|
||||
|
||||
files = list(get_audio_files(audio_dir))
|
||||
@ -454,6 +800,30 @@ class BatDetect2API:
|
||||
output_config: OutputsConfig | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
"""Run inference on multiple audio files.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_files : Sequence[data.PathLike]
|
||||
Audio file paths.
|
||||
batch_size : int | None, optional
|
||||
Batch size override.
|
||||
num_workers : int, optional
|
||||
Number of worker processes for audio loading.
|
||||
audio_config : AudioConfig | None, optional
|
||||
Audio config override.
|
||||
inference_config : InferenceConfig | None, optional
|
||||
Inference config override.
|
||||
output_config : OutputsConfig | None, optional
|
||||
Output config override.
|
||||
detection_threshold : float | None, optional
|
||||
Detection score threshold override.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ClipDetections]
|
||||
Predictions for each input file.
|
||||
"""
|
||||
from batdetect2.inference import process_file_list
|
||||
|
||||
return process_file_list(
|
||||
@ -482,6 +852,30 @@ class BatDetect2API:
|
||||
output_config: OutputsConfig | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
"""Run inference on multiple clip objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clips : Sequence[data.Clip]
|
||||
Clips to process.
|
||||
batch_size : int | None, optional
|
||||
Batch size override.
|
||||
num_workers : int, optional
|
||||
Number of worker processes for audio loading.
|
||||
audio_config : AudioConfig | None, optional
|
||||
Audio config override.
|
||||
inference_config : InferenceConfig | None, optional
|
||||
Inference config override.
|
||||
output_config : OutputsConfig | None, optional
|
||||
Output config override.
|
||||
detection_threshold : float | None, optional
|
||||
Detection score threshold override.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ClipDetections]
|
||||
Predictions for each input clip.
|
||||
"""
|
||||
from batdetect2.inference import run_batch_inference
|
||||
|
||||
return run_batch_inference(
|
||||
@ -508,6 +902,21 @@ class BatDetect2API:
|
||||
format: str | None = None,
|
||||
config: OutputFormatConfig | None = None,
|
||||
):
|
||||
"""Save predictions to disk in one of the supported output formats.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
predictions : Sequence[ClipDetections]
|
||||
Predictions to save.
|
||||
path : data.PathLike
|
||||
Output file or directory path, depending on the selected format.
|
||||
audio_dir : data.PathLike | None, optional
|
||||
Audio root directory used when writing relative paths.
|
||||
format : str | None, optional
|
||||
Output format name override.
|
||||
config : OutputFormatConfig | None, optional
|
||||
Output format config override.
|
||||
"""
|
||||
from batdetect2.outputs import get_output_formatter
|
||||
|
||||
formatter = self.formatter
|
||||
@ -529,6 +938,22 @@ class BatDetect2API:
|
||||
format: str | None = None,
|
||||
config: OutputFormatConfig | None = None,
|
||||
) -> list[object]:
|
||||
"""Load predictions from disk.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to a saved prediction file or directory.
|
||||
format : str | None, optional
|
||||
Output format name override.
|
||||
config : OutputFormatConfig | None, optional
|
||||
Output format config override.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[object]
|
||||
Loaded prediction objects returned by the selected formatter.
|
||||
"""
|
||||
from batdetect2.outputs import get_output_formatter
|
||||
|
||||
formatter = self.formatter
|
||||
@ -555,6 +980,36 @@ class BatDetect2API:
|
||||
outputs_config: OutputsConfig | None = None,
|
||||
logging_config: AppLoggingConfig | None = None,
|
||||
) -> "BatDetect2API":
|
||||
"""Build an API instance from config objects.
|
||||
|
||||
Use this when you want to create a new model without loading a saved
|
||||
checkpoint.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_config : ModelConfig | None, optional
|
||||
Model config. If omitted, the default model config is used.
|
||||
targets_config : TargetConfig | None, optional
|
||||
Target config. If omitted, the default target config is used.
|
||||
audio_config : AudioConfig | None, optional
|
||||
Audio config. If omitted, the default audio config is used.
|
||||
train_config : TrainingConfig | None, optional
|
||||
Training config. If omitted, the default training config is used.
|
||||
evaluation_config : EvaluationConfig | None, optional
|
||||
Evaluation config. If omitted, the default evaluation config is
|
||||
used.
|
||||
inference_config : InferenceConfig | None, optional
|
||||
Inference config. If omitted, the default inference config is used.
|
||||
outputs_config : OutputsConfig | None, optional
|
||||
Output config. If omitted, the default outputs config is used.
|
||||
logging_config : AppLoggingConfig | None, optional
|
||||
Logging config. If omitted, the default logging config is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
BatDetect2API
|
||||
Configured API instance.
|
||||
"""
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
||||
from batdetect2.inference import InferenceConfig
|
||||
@ -653,7 +1108,7 @@ class BatDetect2API:
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
path: data.PathLike,
|
||||
path: data.PathLike | str | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
evaluation_config: EvaluationConfig | None = None,
|
||||
@ -661,6 +1116,31 @@ class BatDetect2API:
|
||||
outputs_config: OutputsConfig | None = None,
|
||||
logging_config: AppLoggingConfig | None = None,
|
||||
) -> "BatDetect2API":
|
||||
"""Build an API instance from a saved checkpoint.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike | str | None, optional
|
||||
Checkpoint path, bundled checkpoint alias, or Hugging Face URI.
|
||||
If omitted, the default bundled checkpoint is used.
|
||||
audio_config : AudioConfig | None, optional
|
||||
Audio config override.
|
||||
train_config : TrainingConfig | None, optional
|
||||
Training config override.
|
||||
evaluation_config : EvaluationConfig | None, optional
|
||||
Evaluation config override.
|
||||
inference_config : InferenceConfig | None, optional
|
||||
Inference config override.
|
||||
outputs_config : OutputsConfig | None, optional
|
||||
Output config override.
|
||||
logging_config : AppLoggingConfig | None, optional
|
||||
Logging config override.
|
||||
|
||||
Returns
|
||||
-------
|
||||
BatDetect2API
|
||||
Configured API instance.
|
||||
"""
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
||||
from batdetect2.inference import InferenceConfig
|
||||
@ -759,7 +1239,7 @@ class BatDetect2API:
|
||||
|
||||
def _set_trainable_parameters(
|
||||
self,
|
||||
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
|
||||
trainable: Literal["all", "heads", "classifier_head", "size_head"],
|
||||
) -> None:
|
||||
detector = self.model.detector
|
||||
|
||||
@ -775,6 +1255,6 @@ class BatDetect2API:
|
||||
for parameter in detector.classifier_head.parameters():
|
||||
parameter.requires_grad = True
|
||||
|
||||
if trainable in {"heads", "bbox_head"}:
|
||||
for parameter in detector.bbox_head.parameters():
|
||||
if trainable in {"heads", "size_head"}:
|
||||
for parameter in detector.size_head.parameters():
|
||||
parameter.requires_grad = True
|
||||
|
||||
@ -3,7 +3,7 @@ from batdetect2.cli.compat import detect
|
||||
from batdetect2.cli.data import data
|
||||
from batdetect2.cli.evaluate import evaluate_command
|
||||
from batdetect2.cli.finetune import finetune_command
|
||||
from batdetect2.cli.inference import predict
|
||||
from batdetect2.cli.inference import process
|
||||
from batdetect2.cli.train import train_command
|
||||
|
||||
__all__ = [
|
||||
@ -13,7 +13,7 @@ __all__ = [
|
||||
"train_command",
|
||||
"finetune_command",
|
||||
"evaluate_command",
|
||||
"predict",
|
||||
"process",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -2,35 +2,39 @@
|
||||
|
||||
import click
|
||||
|
||||
from batdetect2.cli.ascii import BATDETECT_ASCII_ART
|
||||
|
||||
__all__ = [
|
||||
"cli",
|
||||
]
|
||||
|
||||
|
||||
INFO_STR = """
|
||||
BatDetect2 - Detection and Classification
|
||||
Assumes audio files are mono, not stereo.
|
||||
Spaces in the input paths will throw an error. Wrap in quotes.
|
||||
Input files should be short in duration e.g. < 30 seconds.
|
||||
BatDetect2
|
||||
Wrap paths that contain spaces in quotes.
|
||||
"""
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.group(invoke_without_command=True)
|
||||
@click.option(
|
||||
"-v",
|
||||
"--verbose",
|
||||
count=True,
|
||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||
)
|
||||
def cli(verbose: int = 0):
|
||||
@click.pass_context
|
||||
def cli(ctx: click.Context, verbose: int = 0):
|
||||
"""Run the BatDetect2 CLI.
|
||||
|
||||
This command initializes logging and exposes subcommands for prediction,
|
||||
training, evaluation, and dataset utilities.
|
||||
Use subcommands to run processing, training, evaluation, and dataset
|
||||
utilities.
|
||||
"""
|
||||
click.echo(INFO_STR)
|
||||
|
||||
if ctx.invoked_subcommand is None:
|
||||
click.echo(BATDETECT_ASCII_ART)
|
||||
click.echo(ctx.get_help())
|
||||
ctx.exit()
|
||||
|
||||
from batdetect2.logging import enable_logging
|
||||
|
||||
enable_logging(verbose)
|
||||
# click.echo(BATDETECT_ASCII_ART)
|
||||
|
||||
@ -15,7 +15,7 @@ DEFAULT_MODEL_PATH = os.path.join(
|
||||
@cli.command(
|
||||
short_help="Legacy detection command.",
|
||||
epilog=(
|
||||
"Deprecated workflow. Prefer `batdetect2 predict directory` for "
|
||||
"Deprecated workflow. Prefer `batdetect2 process directory` for "
|
||||
"new analyses."
|
||||
),
|
||||
)
|
||||
@ -91,11 +91,17 @@ def detect(
|
||||
Note
|
||||
----
|
||||
This command is kept for backwards compatibility. Prefer
|
||||
`batdetect2 predict directory` for new workflows.
|
||||
`batdetect2 process directory` for new workflows.
|
||||
"""
|
||||
from batdetect2 import api
|
||||
from batdetect2.utils.detector_utils import save_results_to_file
|
||||
|
||||
message = (
|
||||
"The `batdetect2 detect` command is deprecated. Prefer "
|
||||
"`batdetect2 process directory` for new analyses."
|
||||
)
|
||||
click.secho(f"WARNING: {message}", fg="yellow", err=True)
|
||||
|
||||
click.echo(f"Loading model: {args['model_path']}")
|
||||
model, params = api.load_model(args["model_path"])
|
||||
|
||||
|
||||
@ -7,9 +7,9 @@ from batdetect2.cli.base import cli
|
||||
__all__ = ["data"]
|
||||
|
||||
|
||||
@cli.group(short_help="Inspect and convert datasets.")
|
||||
@cli.group(short_help="Inspect and manage datasets.")
|
||||
def data():
|
||||
"""Inspect and convert dataset configuration files."""
|
||||
"""Inspect and manage dataset configuration files."""
|
||||
|
||||
|
||||
@data.command(short_help="Print dataset summary information.")
|
||||
@ -64,7 +64,7 @@ def summary(
|
||||
base_dir=base_dir,
|
||||
)
|
||||
|
||||
print(f"Number of annotated clips: {len(dataset)}")
|
||||
click.echo(f"Number of annotated clips: {len(dataset)}")
|
||||
|
||||
if targets_path is None:
|
||||
return
|
||||
@ -73,7 +73,7 @@ def summary(
|
||||
|
||||
summary = compute_class_summary(dataset, targets)
|
||||
|
||||
print(summary.sort_values("class_name").to_markdown())
|
||||
click.echo(summary.sort_values("class_name").to_markdown())
|
||||
|
||||
|
||||
@data.command(short_help="Convert dataset config to annotation set.")
|
||||
@ -200,6 +200,6 @@ def convert(
|
||||
if not audio_dir.is_absolute():
|
||||
audio_dir = audio_dir.resolve()
|
||||
|
||||
print(f"Using audio directory: {audio_dir}")
|
||||
click.echo(f"Using audio directory: {audio_dir}")
|
||||
|
||||
io.save(annotation_set, output, audio_dir=audio_dir)
|
||||
|
||||
@ -12,38 +12,40 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
||||
|
||||
|
||||
@cli.command(name="evaluate", short_help="Evaluate a model checkpoint.")
|
||||
@click.argument("model_path", type=click.Path(exists=True))
|
||||
@click.argument("test_dataset", type=click.Path(exists=True))
|
||||
@click.option(
|
||||
"--targets",
|
||||
"targets_config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to targets config file.",
|
||||
"--model",
|
||||
"model_path",
|
||||
type=str,
|
||||
help=(
|
||||
"Path to a checkpoint, checkpoint alias, or a Hugging Face "
|
||||
"URI to fine-tune from. Defaults to uk_same"
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--audio-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to audio config file.",
|
||||
help="Path to an audio config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--evaluation-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to evaluation config file.",
|
||||
help="Path to an evaluation config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--inference-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to inference config file.",
|
||||
help="Path to an inference config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--outputs-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to outputs config file.",
|
||||
help="Path to an outputs config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--logging-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to logging config file.",
|
||||
help="Path to a logging config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--base-dir",
|
||||
@ -80,24 +82,23 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
||||
default=0,
|
||||
)
|
||||
def evaluate_command(
|
||||
model_path: Path,
|
||||
test_dataset: Path,
|
||||
base_dir: Path,
|
||||
targets_config: Path | None,
|
||||
audio_config: Path | None,
|
||||
evaluation_config: Path | None,
|
||||
inference_config: Path | None,
|
||||
outputs_config: Path | None,
|
||||
logging_config: Path | None,
|
||||
model_path: str | None = None,
|
||||
base_dir: Path | None = None,
|
||||
audio_config: Path | None = None,
|
||||
evaluation_config: Path | None = None,
|
||||
inference_config: Path | None = None,
|
||||
outputs_config: Path | None = None,
|
||||
logging_config: Path | None = None,
|
||||
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
||||
num_workers: int = 0,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
):
|
||||
"""Evaluate a checkpoint against a test dataset.
|
||||
"""Evaluate a checkpoint on a labelled test dataset.
|
||||
|
||||
Loads model and optional override configs, runs evaluation on
|
||||
`test_dataset`, and writes metrics/artifacts to `output_dir`.
|
||||
This command loads a checkpoint, runs evaluation on ``test_dataset``, and
|
||||
writes metrics to ``output_dir``.
|
||||
"""
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.audio import AudioConfig
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal, cast
|
||||
from typing import Literal
|
||||
|
||||
import click
|
||||
from loguru import logger
|
||||
@ -13,13 +13,6 @@ __all__ = ["finetune_command"]
|
||||
name="finetune", short_help="Fine-tune a checkpoint on new targets."
|
||||
)
|
||||
@click.argument("train_dataset", type=click.Path(exists=True))
|
||||
@click.option(
|
||||
"--model",
|
||||
"model_path",
|
||||
required=True,
|
||||
type=click.Path(exists=True),
|
||||
help="Path to a checkpoint to fine-tune from.",
|
||||
)
|
||||
@click.option(
|
||||
"--targets",
|
||||
"targets_config",
|
||||
@ -27,6 +20,15 @@ __all__ = ["finetune_command"]
|
||||
type=click.Path(exists=True),
|
||||
help="Path to the new targets config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--model",
|
||||
"model_path",
|
||||
type=str,
|
||||
help=(
|
||||
"Path to a checkpoint, checkpoint alias, or a Hugging Face "
|
||||
"URI to fine-tune from. Defaults to uk_same"
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--val-dataset",
|
||||
type=click.Path(exists=True),
|
||||
@ -57,7 +59,7 @@ __all__ = ["finetune_command"]
|
||||
)
|
||||
@click.option(
|
||||
"--trainable",
|
||||
type=click.Choice(["all", "heads", "classifier_head", "bbox_head"]),
|
||||
type=click.Choice(["all", "heads", "classifier_head", "size_head"]),
|
||||
default="heads",
|
||||
show_default=True,
|
||||
help="Which model parameters remain trainable during fine-tuning.",
|
||||
@ -106,8 +108,8 @@ __all__ = ["finetune_command"]
|
||||
)
|
||||
def finetune_command(
|
||||
train_dataset: Path,
|
||||
model_path: Path,
|
||||
targets_config: Path,
|
||||
model_path: str | None = None,
|
||||
val_dataset: Path | None = None,
|
||||
ckpt_dir: Path | None = None,
|
||||
log_dir: Path | None = None,
|
||||
@ -115,7 +117,9 @@ def finetune_command(
|
||||
training_config: Path | None = None,
|
||||
audio_config: Path | None = None,
|
||||
logging_config: Path | None = None,
|
||||
trainable: str = "heads",
|
||||
trainable: Literal[
|
||||
"all", "heads", "classifier_head", "size_head"
|
||||
] = "heads",
|
||||
seed: int | None = None,
|
||||
num_epochs: int | None = None,
|
||||
train_workers: int = 0,
|
||||
@ -192,10 +196,7 @@ def finetune_command(
|
||||
train_annotations=train_annotations,
|
||||
val_annotations=val_annotations,
|
||||
targets_config=target_conf,
|
||||
trainable=cast(
|
||||
Literal["all", "heads", "classifier_head", "bbox_head"],
|
||||
trainable,
|
||||
),
|
||||
trainable=trainable,
|
||||
train_workers=train_workers,
|
||||
val_workers=val_workers,
|
||||
checkpoint_dir=ckpt_dir,
|
||||
|
||||
@ -13,27 +13,26 @@ if TYPE_CHECKING:
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
|
||||
__all__ = ["predict"]
|
||||
__all__ = ["process"]
|
||||
|
||||
|
||||
@cli.group(name="predict", short_help="Run prediction workflows.")
|
||||
def predict() -> None:
|
||||
"""Run model inference on audio files.
|
||||
@cli.group(name="process", short_help="Run processing workflows.")
|
||||
def process() -> None:
|
||||
"""Run model inference on audio.
|
||||
|
||||
Use one of the subcommands to select inputs from a directory, a text file
|
||||
list, or an annotation dataset.
|
||||
Choose a subcommand based on how you want to provide input audio.
|
||||
"""
|
||||
|
||||
|
||||
def common_predict_options(func):
|
||||
"""Attach options shared by all `predict` subcommands."""
|
||||
"""Attach options shared by all ``process`` subcommands."""
|
||||
|
||||
@click.option(
|
||||
"--audio-config",
|
||||
type=click.Path(exists=True),
|
||||
help=(
|
||||
"Path to an audio config file. Use this to override audio "
|
||||
"loading and preprocessing-related settings."
|
||||
"loading settings."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
@ -41,7 +40,7 @@ def common_predict_options(func):
|
||||
type=click.Path(exists=True),
|
||||
help=(
|
||||
"Path to an inference config file. Use this to override "
|
||||
"prediction-time thresholds and behavior."
|
||||
"prediction settings."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
@ -49,23 +48,19 @@ def common_predict_options(func):
|
||||
type=click.Path(exists=True),
|
||||
help=(
|
||||
"Path to an outputs config file. Use this to control the "
|
||||
"prediction fields written to disk."
|
||||
"saved output format and fields."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--logging-config",
|
||||
type=click.Path(exists=True),
|
||||
help=(
|
||||
"Path to a logging config file. Use this to customize logging "
|
||||
"format and levels."
|
||||
),
|
||||
help=("Path to a logging config file. Use this to change log output."),
|
||||
)
|
||||
@click.option(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
help=(
|
||||
"Batch size for inference. If omitted, the value from the "
|
||||
"loaded config is used."
|
||||
"Batch size for inference. If omitted, the config value is used."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
@ -82,7 +77,7 @@ def common_predict_options(func):
|
||||
type=str,
|
||||
help=(
|
||||
"Output format name used by the prediction writer. If omitted, "
|
||||
"the default output format is used."
|
||||
"the config default is used."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
@ -91,7 +86,7 @@ def common_predict_options(func):
|
||||
default=None,
|
||||
help=(
|
||||
"Optional detection score threshold override. If omitted, "
|
||||
"the model default threshold is used."
|
||||
"the configured threshold is used."
|
||||
),
|
||||
)
|
||||
@wraps(func)
|
||||
@ -102,7 +97,7 @@ def common_predict_options(func):
|
||||
|
||||
|
||||
def _build_api(
|
||||
model_path: Path,
|
||||
model_path: str,
|
||||
audio_config: Path | None,
|
||||
inference_config: Path | None,
|
||||
outputs_config: Path | None,
|
||||
@ -144,7 +139,7 @@ def _build_api(
|
||||
|
||||
|
||||
def _run_prediction(
|
||||
model_path: Path,
|
||||
model_path: str,
|
||||
audio_files: list[Path],
|
||||
output_path: Path,
|
||||
audio_config: Path | None,
|
||||
@ -191,16 +186,16 @@ def _run_prediction(
|
||||
)
|
||||
|
||||
|
||||
@predict.command(
|
||||
@process.command(
|
||||
name="directory",
|
||||
short_help="Predict on audio files in a directory.",
|
||||
short_help="Process audio files in a directory.",
|
||||
)
|
||||
@click.argument("model_path", type=click.Path(exists=True))
|
||||
@click.argument("model_path", type=str)
|
||||
@click.argument("audio_dir", type=click.Path(exists=True))
|
||||
@click.argument("output_path", type=click.Path())
|
||||
@common_predict_options
|
||||
def predict_directory_command(
|
||||
model_path: Path,
|
||||
model_path: str,
|
||||
audio_dir: Path,
|
||||
output_path: Path,
|
||||
audio_config: Path | None,
|
||||
@ -212,10 +207,10 @@ def predict_directory_command(
|
||||
format_name: str | None,
|
||||
detection_threshold: float | None,
|
||||
) -> None:
|
||||
"""Predict on all audio files in a directory.
|
||||
"""Run processing on all supported audio files in a directory.
|
||||
|
||||
Loads a checkpoint, scans `audio_dir` for supported audio files, runs
|
||||
inference, and saves predictions to `output_path`.
|
||||
This command scans ``audio_dir`` for audio files, runs processing, and
|
||||
saves the results to ``output_path``.
|
||||
"""
|
||||
from soundevent.audio.files import get_audio_files
|
||||
|
||||
@ -235,16 +230,16 @@ def predict_directory_command(
|
||||
)
|
||||
|
||||
|
||||
@predict.command(
|
||||
@process.command(
|
||||
name="file_list",
|
||||
short_help="Predict on paths listed in a text file.",
|
||||
short_help="Process paths listed in a text file.",
|
||||
)
|
||||
@click.argument("model_path", type=click.Path(exists=True))
|
||||
@click.argument("model_path", type=str)
|
||||
@click.argument("file_list", type=click.Path(exists=True))
|
||||
@click.argument("output_path", type=click.Path())
|
||||
@common_predict_options
|
||||
def predict_file_list_command(
|
||||
model_path: Path,
|
||||
model_path: str,
|
||||
file_list: Path,
|
||||
output_path: Path,
|
||||
audio_config: Path | None,
|
||||
@ -256,9 +251,9 @@ def predict_file_list_command(
|
||||
format_name: str | None,
|
||||
detection_threshold: float | None,
|
||||
) -> None:
|
||||
"""Predict on audio files listed in a text file.
|
||||
"""Run processing on audio files listed in a text file.
|
||||
|
||||
The list file should contain one audio path per line. Empty lines are
|
||||
The text file should contain one audio path per line. Empty lines are
|
||||
ignored.
|
||||
"""
|
||||
file_list = Path(file_list)
|
||||
@ -283,16 +278,16 @@ def predict_file_list_command(
|
||||
)
|
||||
|
||||
|
||||
@predict.command(
|
||||
@process.command(
|
||||
name="dataset",
|
||||
short_help="Predict on recordings from a dataset config.",
|
||||
short_help="Process recordings from a dataset config.",
|
||||
)
|
||||
@click.argument("model_path", type=click.Path(exists=True))
|
||||
@click.argument("model_path", type=str)
|
||||
@click.argument("dataset_path", type=click.Path(exists=True))
|
||||
@click.argument("output_path", type=click.Path())
|
||||
@common_predict_options
|
||||
def predict_dataset_command(
|
||||
model_path: Path,
|
||||
model_path: str,
|
||||
dataset_path: Path,
|
||||
output_path: Path,
|
||||
audio_config: Path | None,
|
||||
@ -304,10 +299,10 @@ def predict_dataset_command(
|
||||
format_name: str | None,
|
||||
detection_threshold: float | None,
|
||||
) -> None:
|
||||
"""Predict on recordings referenced in an annotation dataset.
|
||||
"""Run processing on recordings referenced in a dataset file.
|
||||
|
||||
The dataset is read as a soundevent annotation set and unique recording
|
||||
paths are extracted before inference.
|
||||
Recording paths are read from the dataset and each recording is processed
|
||||
once.
|
||||
"""
|
||||
from soundevent import io
|
||||
|
||||
|
||||
@ -13,15 +13,15 @@ __all__ = ["train_command"]
|
||||
@click.option(
|
||||
"--val-dataset",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to validation dataset config file.",
|
||||
help="Path to a validation dataset config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--model",
|
||||
"model_path",
|
||||
type=click.Path(exists=True),
|
||||
type=str,
|
||||
help=(
|
||||
"Path to a checkpoint to continue training from. If omitted, "
|
||||
"training starts from a fresh model config."
|
||||
"Path to a checkpoint, bundled checkpoint alias, or Hugging Face "
|
||||
"URI. If omitted, training starts from a fresh model config."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
@ -36,7 +36,7 @@ __all__ = ["train_command"]
|
||||
"--targets",
|
||||
"targets_config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to targets config file.",
|
||||
help="Path to a targets config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--model-config",
|
||||
@ -46,32 +46,32 @@ __all__ = ["train_command"]
|
||||
@click.option(
|
||||
"--training-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to training config file.",
|
||||
help="Path to a training config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--audio-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to audio config file.",
|
||||
help="Path to an audio config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--evaluation-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to evaluation config file.",
|
||||
help="Path to an evaluation config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--inference-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to inference config file.",
|
||||
help="Path to an inference config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--outputs-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to outputs config file.",
|
||||
help="Path to an outputs config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--logging-config",
|
||||
type=click.Path(exists=True),
|
||||
help="Path to logging config file.",
|
||||
help="Path to a logging config file.",
|
||||
)
|
||||
@click.option(
|
||||
"--ckpt-dir",
|
||||
@ -118,7 +118,7 @@ __all__ = ["train_command"]
|
||||
def train_command(
|
||||
train_dataset: Path,
|
||||
val_dataset: Path | None = None,
|
||||
model_path: Path | None = None,
|
||||
model_path: str | None = None,
|
||||
ckpt_dir: Path | None = None,
|
||||
log_dir: Path | None = None,
|
||||
base_dir: Path | None = None,
|
||||
@ -139,9 +139,8 @@ def train_command(
|
||||
):
|
||||
"""Train a BatDetect2 model.
|
||||
|
||||
Train either from a fresh config (`--model-config`) or by fine-tuning an
|
||||
existing checkpoint (`--model`). Training data are loaded from
|
||||
`train_dataset`, with optional validation data from `--val-dataset`.
|
||||
Start from a fresh model config or continue from an existing checkpoint.
|
||||
Training data are loaded from ``train_dataset``.
|
||||
"""
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.audio import AudioConfig
|
||||
|
||||
@ -102,19 +102,19 @@ def convert_to_annotation_group(
|
||||
x_inds.append(0)
|
||||
y_inds.append(0)
|
||||
|
||||
annotations.append(
|
||||
Annotation(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
low_freq=low_freq,
|
||||
high_freq=high_freq,
|
||||
class_prob=1.0,
|
||||
det_prob=1.0,
|
||||
individual="0",
|
||||
event=event,
|
||||
class_id=class_id,
|
||||
)
|
||||
)
|
||||
annotation_entry: Annotation = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"low_freq": low_freq,
|
||||
"high_freq": high_freq,
|
||||
"class_prob": 1.0,
|
||||
"det_prob": 1.0,
|
||||
"individual": "0",
|
||||
"event": event,
|
||||
"class": get_recording_class_name(recording),
|
||||
"class_id": class_id,
|
||||
}
|
||||
annotations.append(annotation_entry)
|
||||
|
||||
return {
|
||||
"id": str(recording.path),
|
||||
|
||||
@ -53,7 +53,7 @@ class Registry(Generic[T_Type, P_Type]):
|
||||
def __init__(self, name: str, discriminator: str = "name"):
|
||||
self._name = name
|
||||
self._registry: dict[
|
||||
str, Callable[Concatenate[..., P_Type], T_Type]
|
||||
str, Callable[Concatenate[Any, P_Type], T_Type]
|
||||
] = {}
|
||||
self._discriminator = discriminator
|
||||
self._config_types: dict[str, Type[BaseModel]] = {}
|
||||
@ -80,7 +80,7 @@ class Registry(Generic[T_Type, P_Type]):
|
||||
)
|
||||
|
||||
def decorator(
|
||||
func: Callable[Concatenate[T_Config, P_Type], T_Type],
|
||||
func: Callable[..., T_Type],
|
||||
):
|
||||
self._registry[name] = func
|
||||
return func
|
||||
@ -102,8 +102,8 @@ class Registry(Generic[T_Type, P_Type]):
|
||||
def build(
|
||||
self,
|
||||
config: BaseModel,
|
||||
*args: P_Type.args,
|
||||
**kwargs: P_Type.kwargs,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> T_Type:
|
||||
"""Builds a logic instance from a config object."""
|
||||
|
||||
|
||||
@ -12,13 +12,15 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def _default_tasks() -> list[TaskConfig]:
|
||||
return [
|
||||
DetectionTaskConfig(),
|
||||
ClassificationTaskConfig(),
|
||||
]
|
||||
|
||||
|
||||
class EvaluationConfig(BaseConfig):
|
||||
tasks: List[TaskConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
DetectionTaskConfig(),
|
||||
ClassificationTaskConfig(),
|
||||
]
|
||||
)
|
||||
tasks: List[TaskConfig] = Field(default_factory=_default_tasks)
|
||||
|
||||
|
||||
def get_default_eval_config() -> EvaluationConfig:
|
||||
|
||||
@ -11,7 +11,7 @@ from batdetect2.evaluate.dataset import build_test_loader
|
||||
from batdetect2.evaluate.evaluator import build_evaluator
|
||||
from batdetect2.evaluate.lightning import EvaluationModule
|
||||
from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.models.types import ModelProtocol
|
||||
from batdetect2.outputs import OutputsConfig, build_output_transform
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
@ -22,7 +22,7 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||
|
||||
|
||||
def run_evaluate(
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
test_annotations: Sequence[data.ClipAnnotation],
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
|
||||
@ -7,14 +7,14 @@ from torch.utils.data import DataLoader
|
||||
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.logging import get_image_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.models.types import ModelProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
|
||||
|
||||
class EvaluationModule(LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
evaluator: EvaluatorProtocol,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -25,11 +25,15 @@ from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def _default_metrics() -> list[ClassificationMetricConfig]:
|
||||
return [ClassificationAveragePrecisionConfig()]
|
||||
|
||||
|
||||
class ClassificationTaskConfig(BaseSEDTaskConfig):
|
||||
name: Literal["sound_event_classification"] = "sound_event_classification"
|
||||
prefix: str = "classification"
|
||||
metrics: list[ClassificationMetricConfig] = Field(
|
||||
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
|
||||
default_factory=_default_metrics
|
||||
)
|
||||
plots: list[ClassificationPlotConfig] = Field(default_factory=list)
|
||||
include_generics: bool = True
|
||||
|
||||
@ -23,13 +23,15 @@ from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def _default_metrics() -> list[ClipClassificationMetricConfig]:
|
||||
return [ClipClassificationAveragePrecisionConfig()]
|
||||
|
||||
|
||||
class ClipClassificationTaskConfig(BaseTaskConfig):
|
||||
name: Literal["clip_classification"] = "clip_classification"
|
||||
prefix: str = "clip_classification"
|
||||
metrics: list[ClipClassificationMetricConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
ClipClassificationAveragePrecisionConfig(),
|
||||
]
|
||||
default_factory=_default_metrics
|
||||
)
|
||||
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
@ -22,13 +22,15 @@ from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def _default_metrics() -> list[ClipDetectionMetricConfig]:
|
||||
return [ClipDetectionAveragePrecisionConfig()]
|
||||
|
||||
|
||||
class ClipDetectionTaskConfig(BaseTaskConfig):
|
||||
name: Literal["clip_detection"] = "clip_detection"
|
||||
prefix: str = "clip_detection"
|
||||
metrics: list[ClipDetectionMetricConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
ClipDetectionAveragePrecisionConfig(),
|
||||
]
|
||||
default_factory=_default_metrics
|
||||
)
|
||||
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def _default_metrics() -> list[DetectionMetricConfig]:
|
||||
return [DetectionAveragePrecisionConfig()]
|
||||
|
||||
|
||||
class DetectionTaskConfig(BaseSEDTaskConfig):
|
||||
name: Literal["sound_event_detection"] = "sound_event_detection"
|
||||
prefix: str = "detection"
|
||||
metrics: list[DetectionMetricConfig] = Field(
|
||||
default_factory=lambda: [DetectionAveragePrecisionConfig()]
|
||||
default_factory=_default_metrics
|
||||
)
|
||||
plots: list[DetectionPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
@ -24,11 +24,15 @@ from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def _default_metrics() -> list[TopClassMetricConfig]:
|
||||
return [TopClassAveragePrecisionConfig()]
|
||||
|
||||
|
||||
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
||||
name: Literal["top_class_detection"] = "top_class_detection"
|
||||
prefix: str = "top_class"
|
||||
metrics: list[TopClassMetricConfig] = Field(
|
||||
default_factory=lambda: [TopClassAveragePrecisionConfig()]
|
||||
default_factory=_default_metrics
|
||||
)
|
||||
plots: list[TopClassPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ from batdetect2.inference.clips import get_clips_from_files
|
||||
from batdetect2.inference.config import InferenceConfig
|
||||
from batdetect2.inference.dataset import build_inference_loader
|
||||
from batdetect2.inference.lightning import InferenceModule
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.models.types import ModelProtocol
|
||||
from batdetect2.outputs import (
|
||||
OutputsConfig,
|
||||
OutputTransformProtocol,
|
||||
@ -22,7 +22,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
|
||||
|
||||
def run_batch_inference(
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
clips: Sequence[data.Clip],
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
@ -86,7 +86,7 @@ def run_batch_inference(
|
||||
|
||||
|
||||
def process_file_list(
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
paths: Sequence[data.PathLike],
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
|
||||
@ -4,7 +4,7 @@ from lightning import LightningModule
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.models.types import ModelProtocol
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
@ -13,7 +13,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
class InferenceModule(LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
output_transform: OutputTransformProtocol | None = None,
|
||||
|
||||
@ -62,7 +62,7 @@ from batdetect2.models.encoder import (
|
||||
build_encoder,
|
||||
)
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||
from batdetect2.models.types import DetectionModel
|
||||
from batdetect2.models.types import DetectorProtocol, ModelProtocol
|
||||
from batdetect2.postprocess.config import PostprocessConfig
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetectionsTensor,
|
||||
@ -149,7 +149,7 @@ class Model(torch.nn.Module):
|
||||
|
||||
Attributes
|
||||
----------
|
||||
detector : DetectionModel
|
||||
detector : DetectorProtocol
|
||||
The neural network that processes spectrograms and produces raw
|
||||
detection, classification, and bounding-box outputs.
|
||||
preprocessor : PreprocessorProtocol
|
||||
@ -164,19 +164,21 @@ class Model(torch.nn.Module):
|
||||
Size-dimension names corresponding to the model size outputs.
|
||||
"""
|
||||
|
||||
detector: DetectionModel
|
||||
detector: DetectorProtocol
|
||||
preprocessor: PreprocessorProtocol
|
||||
postprocessor: PostprocessorProtocol
|
||||
class_names: list[str]
|
||||
dimension_names: list[str]
|
||||
_config: dict[str, object]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
detector: DetectionModel,
|
||||
detector: DetectorProtocol,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
class_names: list[str],
|
||||
dimension_names: list[str],
|
||||
config: dict[str, object],
|
||||
):
|
||||
super().__init__()
|
||||
self.detector = detector
|
||||
@ -184,6 +186,12 @@ class Model(torch.nn.Module):
|
||||
self.postprocessor = postprocessor
|
||||
self.class_names = class_names
|
||||
self.dimension_names = dimension_names
|
||||
self._config = config
|
||||
|
||||
def get_config(self) -> dict[str, object]:
|
||||
"""Return the model configuration as plain JSON-serializable data."""
|
||||
|
||||
return dict(self._config)
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
||||
"""Run the full detection pipeline on a waveform tensor.
|
||||
@ -216,7 +224,7 @@ def build_model(
|
||||
dimension_names: list[str] | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
postprocessor: PostprocessorProtocol | None = None,
|
||||
) -> Model:
|
||||
) -> ModelProtocol:
|
||||
"""Build a complete, ready-to-use BatDetect2 model.
|
||||
|
||||
Assembles a ``Model`` instance from a ``ModelConfig`` and optional
|
||||
@ -248,7 +256,7 @@ def build_model(
|
||||
|
||||
Returns
|
||||
-------
|
||||
Model
|
||||
ModelProtocol
|
||||
A fully assembled ``Model`` instance ready for inference or
|
||||
training.
|
||||
"""
|
||||
@ -277,8 +285,8 @@ def build_model(
|
||||
config=config.postprocess,
|
||||
)
|
||||
detector = build_detector(
|
||||
num_classes=len(class_names),
|
||||
num_sizes=len(dimension_names),
|
||||
class_names=class_names,
|
||||
dimension_names=dimension_names,
|
||||
config=config.architecture,
|
||||
)
|
||||
return Model(
|
||||
@ -287,18 +295,19 @@ def build_model(
|
||||
preprocessor=preprocessor,
|
||||
class_names=class_names,
|
||||
dimension_names=dimension_names,
|
||||
config=config.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
|
||||
def build_model_with_new_targets(
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
) -> Model:
|
||||
) -> ModelProtocol:
|
||||
"""Build a new model with a different target set."""
|
||||
detector = build_detector(
|
||||
num_classes=len(targets.class_names),
|
||||
num_sizes=len(roi_mapper.dimension_names),
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
backbone=model.detector.backbone,
|
||||
)
|
||||
|
||||
@ -308,4 +317,5 @@ def build_model_with_new_targets(
|
||||
preprocessor=model.preprocessor,
|
||||
class_names=targets.class_names,
|
||||
dimension_names=roi_mapper.dimension_names,
|
||||
config=model.get_config(),
|
||||
)
|
||||
|
||||
@ -27,6 +27,7 @@ from typing import Annotated, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from loguru import logger
|
||||
from pydantic import Field, TypeAdapter
|
||||
from soundevent import data
|
||||
|
||||
@ -52,7 +53,7 @@ from batdetect2.models.encoder import (
|
||||
build_encoder,
|
||||
)
|
||||
from batdetect2.models.types import (
|
||||
BackboneModel,
|
||||
BackboneProtocol,
|
||||
BottleneckProtocol,
|
||||
DecoderProtocol,
|
||||
EncoderProtocol,
|
||||
@ -104,7 +105,7 @@ class UNetBackboneConfig(BaseConfig):
|
||||
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||
|
||||
|
||||
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
|
||||
backbone_registry: Registry[BackboneProtocol, []] = Registry("backbone")
|
||||
|
||||
|
||||
@add_import_config(backbone_registry)
|
||||
@ -118,7 +119,7 @@ class BackboneImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class UNetBackbone(BackboneModel):
|
||||
class UNetBackbone(torch.nn.Module):
|
||||
"""U-Net-style encoder-decoder backbone network.
|
||||
|
||||
Combines an encoder, a bottleneck, and a decoder into a single module
|
||||
@ -225,7 +226,7 @@ class UNetBackbone(BackboneModel):
|
||||
|
||||
@backbone_registry.register(UNetBackboneConfig)
|
||||
@staticmethod
|
||||
def from_config(config: UNetBackboneConfig) -> BackboneModel:
|
||||
def from_config(config: UNetBackboneConfig) -> BackboneProtocol:
|
||||
encoder = build_encoder(
|
||||
in_channels=config.in_channels,
|
||||
input_height=config.input_height,
|
||||
@ -266,7 +267,7 @@ BackboneConfig = Annotated[
|
||||
]
|
||||
|
||||
|
||||
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
||||
def build_backbone(config: BackboneConfig | None = None) -> BackboneProtocol:
|
||||
"""Build a backbone network from configuration.
|
||||
|
||||
Looks up the backbone class corresponding to ``config.name`` in the
|
||||
@ -282,10 +283,14 @@ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
||||
|
||||
Returns
|
||||
-------
|
||||
BackboneModel
|
||||
BackboneProtocol
|
||||
An initialised backbone module.
|
||||
"""
|
||||
config = config or UNetBackboneConfig()
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building model backbone with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
)
|
||||
return backbone_registry.build(config)
|
||||
|
||||
|
||||
|
||||
BIN
src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt
Normal file
BIN
src/batdetect2/models/checkpoints/batdetect2_uk_same.ckpt
Normal file
Binary file not shown.
@ -6,8 +6,8 @@ bounding-box size regression.
|
||||
|
||||
Components
|
||||
----------
|
||||
- ``Detector`` – the ``torch.nn.Module`` that wires together a backbone
|
||||
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
|
||||
- ``Detector`` - the ``torch.nn.Module`` that wires together a backbone
|
||||
(``BackboneProtocol``) with a ``ClassifierHead`` and a ``BBoxHead`` to
|
||||
produce a ``ModelOutput`` tuple from an input spectrogram.
|
||||
- ``build_detector`` – factory function that builds a ready-to-use
|
||||
``Detector`` from a backbone configuration and a target class count.
|
||||
@ -18,15 +18,16 @@ preprocessing and output postprocessing are handled by
|
||||
"""
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.models.backbones import (
|
||||
BackboneConfig,
|
||||
UNetBackboneConfig,
|
||||
build_backbone,
|
||||
)
|
||||
from batdetect2.models.backbones import BackboneConfig, build_backbone
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
||||
from batdetect2.models.types import (
|
||||
BackboneProtocol,
|
||||
ClassifierHeadProtocol,
|
||||
DetectorProtocol,
|
||||
ModelOutput,
|
||||
SizeHeadProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Detector",
|
||||
@ -34,7 +35,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
class Detector(DetectionModel):
|
||||
class Detector(torch.nn.Module):
|
||||
"""Complete BatDetect2 detection and classification model.
|
||||
|
||||
Combines a backbone feature extractor with two prediction heads:
|
||||
@ -51,7 +52,7 @@ class Detector(DetectionModel):
|
||||
|
||||
Attributes
|
||||
----------
|
||||
backbone : BackboneModel
|
||||
backbone : BackboneProtocol
|
||||
The feature extraction backbone.
|
||||
num_classes : int
|
||||
Number of target classes (inferred from the classifier head).
|
||||
@ -61,13 +62,13 @@ class Detector(DetectionModel):
|
||||
Produces duration and bandwidth predictions from backbone features.
|
||||
"""
|
||||
|
||||
backbone: BackboneModel
|
||||
backbone: BackboneProtocol
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backbone: BackboneModel,
|
||||
classifier_head: ClassifierHead,
|
||||
bbox_head: BBoxHead,
|
||||
backbone: BackboneProtocol,
|
||||
classifier_head: ClassifierHeadProtocol,
|
||||
size_head: SizeHeadProtocol,
|
||||
):
|
||||
"""Initialise the Detector model.
|
||||
|
||||
@ -76,7 +77,7 @@ class Detector(DetectionModel):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backbone : BackboneModel
|
||||
backbone : BackboneProtocol
|
||||
An initialised backbone module (e.g. built by
|
||||
``build_backbone``).
|
||||
classifier_head : ClassifierHead
|
||||
@ -90,7 +91,7 @@ class Detector(DetectionModel):
|
||||
self.backbone = backbone
|
||||
self.num_classes = classifier_head.num_classes
|
||||
self.classifier_head = classifier_head
|
||||
self.bbox_head = bbox_head
|
||||
self.size_head = size_head
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
"""Run the complete detection model on an input spectrogram.
|
||||
@ -125,7 +126,7 @@ class Detector(DetectionModel):
|
||||
features = self.backbone(spec)
|
||||
classification = self.classifier_head(features)
|
||||
detection = classification.sum(dim=1, keepdim=True)
|
||||
size_preds = self.bbox_head(features)
|
||||
size_preds = self.size_head(features)
|
||||
return ModelOutput(
|
||||
detection_probs=detection,
|
||||
size_preds=size_preds,
|
||||
@ -135,11 +136,11 @@ class Detector(DetectionModel):
|
||||
|
||||
|
||||
def build_detector(
|
||||
num_classes: int,
|
||||
num_sizes: int = 2,
|
||||
class_names: list[str],
|
||||
dimension_names: list[str],
|
||||
config: BackboneConfig | None = None,
|
||||
backbone: BackboneModel | None = None,
|
||||
) -> DetectionModel:
|
||||
backbone: BackboneProtocol | None = None,
|
||||
) -> DetectorProtocol:
|
||||
"""Build a complete BatDetect2 detection model.
|
||||
|
||||
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
|
||||
@ -158,7 +159,7 @@ def build_detector(
|
||||
|
||||
Returns
|
||||
-------
|
||||
DetectionModel
|
||||
DetectorProtocol
|
||||
An initialised ``Detector`` instance ready for training or
|
||||
inference.
|
||||
|
||||
@ -168,24 +169,18 @@ def build_detector(
|
||||
If ``num_classes`` is not positive, or if the backbone
|
||||
configuration is invalid.
|
||||
"""
|
||||
if backbone is None:
|
||||
config = config or UNetBackboneConfig()
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building model with config: \n{}",
|
||||
lambda: config.to_yaml_string(), # type: ignore
|
||||
)
|
||||
backbone = build_backbone(config=config)
|
||||
backbone = backbone or build_backbone(config=config)
|
||||
|
||||
classifier_head = ClassifierHead(
|
||||
num_classes=num_classes,
|
||||
class_names=class_names,
|
||||
in_channels=backbone.out_channels,
|
||||
)
|
||||
bbox_head = BBoxHead(
|
||||
in_channels=backbone.out_channels,
|
||||
num_sizes=num_sizes,
|
||||
dimension_names=dimension_names,
|
||||
)
|
||||
return Detector(
|
||||
backbone=backbone,
|
||||
classifier_head=classifier_head,
|
||||
bbox_head=bbox_head,
|
||||
size_head=bbox_head,
|
||||
)
|
||||
|
||||
@ -54,12 +54,14 @@ class ClassifierHead(nn.Module):
|
||||
1×1 convolution with ``num_classes + 1`` output channels.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes: int, in_channels: int):
|
||||
def __init__(self, class_names: list[str], in_channels: int):
|
||||
"""Initialise the ClassifierHead."""
|
||||
super().__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.class_names = class_names
|
||||
self.num_classes = len(class_names)
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.classifier = nn.Conv2d(
|
||||
self.in_channels,
|
||||
self.num_classes + 1,
|
||||
@ -165,11 +167,12 @@ class BBoxHead(nn.Module):
|
||||
1×1 convolution with 2 output channels (duration, bandwidth).
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int, num_sizes: int = 2):
|
||||
def __init__(self, dimension_names: list[str], in_channels: int):
|
||||
"""Initialise the BBoxHead."""
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.num_sizes = num_sizes
|
||||
self.dimension_names = dimension_names
|
||||
self.num_sizes = len(dimension_names)
|
||||
|
||||
self.bbox = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
|
||||
@ -1,21 +1,42 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import NamedTuple, Protocol
|
||||
from typing import Any, NamedTuple, Protocol
|
||||
|
||||
import torch
|
||||
|
||||
from batdetect2.postprocess.types import PostprocessorProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"BackboneModel",
|
||||
"BackboneProtocol",
|
||||
"BlockProtocol",
|
||||
"BottleneckProtocol",
|
||||
"ClassifierHeadProtocol",
|
||||
"DecoderProtocol",
|
||||
"DetectionModel",
|
||||
"EncoderDecoderModel",
|
||||
"DetectorProtocol",
|
||||
"EncoderProtocol",
|
||||
"ModelOutput",
|
||||
"ModelProtocol",
|
||||
"ModuleProtocol",
|
||||
"SizeHeadProtocol",
|
||||
]
|
||||
|
||||
|
||||
class BlockProtocol(Protocol):
|
||||
class ModuleProtocol(Protocol):
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
def train(self, mode: bool = True) -> torch.nn.Module: ...
|
||||
|
||||
def eval(self) -> torch.nn.Module: ...
|
||||
|
||||
def state_dict(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> dict[str, torch.Tensor]: ...
|
||||
|
||||
def load_state_dict(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
def parameters(self) -> Any: ...
|
||||
|
||||
|
||||
class BlockProtocol(ModuleProtocol, Protocol):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
|
||||
@ -24,7 +45,7 @@ class BlockProtocol(Protocol):
|
||||
def get_output_height(self, input_height: int) -> int: ...
|
||||
|
||||
|
||||
class EncoderProtocol(Protocol):
|
||||
class EncoderProtocol(ModuleProtocol, Protocol):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
@ -33,7 +54,7 @@ class EncoderProtocol(Protocol):
|
||||
def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
|
||||
|
||||
|
||||
class BottleneckProtocol(Protocol):
|
||||
class BottleneckProtocol(ModuleProtocol, Protocol):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
@ -41,7 +62,7 @@ class BottleneckProtocol(Protocol):
|
||||
def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class DecoderProtocol(Protocol):
|
||||
class DecoderProtocol(ModuleProtocol, Protocol):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
@ -62,29 +83,42 @@ class ModelOutput(NamedTuple):
|
||||
features: torch.Tensor
|
||||
|
||||
|
||||
class BackboneModel(ABC, torch.nn.Module):
|
||||
class BackboneProtocol(ModuleProtocol, Protocol):
|
||||
input_height: int
|
||||
out_channels: int
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class EncoderDecoderModel(BackboneModel):
|
||||
bottleneck_channels: int
|
||||
class ClassifierHeadProtocol(ModuleProtocol, Protocol):
|
||||
num_classes: int
|
||||
in_channels: int
|
||||
class_names: list[str]
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, spec: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, encoded: torch.Tensor) -> torch.Tensor: ...
|
||||
def forward(self, features: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class DetectionModel(ABC, torch.nn.Module):
|
||||
backbone: BackboneModel
|
||||
classifier_head: torch.nn.Module
|
||||
bbox_head: torch.nn.Module
|
||||
class SizeHeadProtocol(ModuleProtocol, Protocol):
|
||||
in_channels: int
|
||||
num_sizes: int
|
||||
dimension_names: list[str]
|
||||
|
||||
def forward(self, features: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class DetectorProtocol(ModuleProtocol, Protocol):
|
||||
backbone: BackboneProtocol
|
||||
classifier_head: ClassifierHeadProtocol
|
||||
size_head: SizeHeadProtocol
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput: ...
|
||||
|
||||
|
||||
class ModelProtocol(ModuleProtocol, Protocol):
|
||||
detector: DetectorProtocol
|
||||
preprocessor: PreprocessorProtocol
|
||||
postprocessor: PostprocessorProtocol
|
||||
class_names: list[str]
|
||||
dimension_names: list[str]
|
||||
|
||||
def get_config(self) -> dict[str, Any]: ...
|
||||
|
||||
@ -154,17 +154,18 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
top_class_index = int(np.argmax(prediction.class_scores))
|
||||
top_class_score = float(prediction.class_scores[top_class_index])
|
||||
top_class = self.get_class_name(top_class_index)
|
||||
return Annotation(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
low_freq=low_freq,
|
||||
high_freq=high_freq,
|
||||
class_prob=top_class_score,
|
||||
det_prob=float(prediction.detection_score),
|
||||
individual="",
|
||||
event=self.event_name,
|
||||
**{"class": top_class},
|
||||
)
|
||||
annotation: Annotation = {
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"low_freq": low_freq,
|
||||
"high_freq": high_freq,
|
||||
"class_prob": top_class_score,
|
||||
"det_prob": float(prediction.detection_score),
|
||||
"individual": "",
|
||||
"event": self.event_name,
|
||||
"class": top_class,
|
||||
}
|
||||
return annotation
|
||||
|
||||
@output_formatters.register(BatDetect2OutputConfig)
|
||||
@staticmethod
|
||||
|
||||
@ -26,6 +26,13 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def _default_spectrogram_transforms() -> list[SpectrogramTransform]:
|
||||
return [
|
||||
PcenConfig(),
|
||||
SpectralMeanSubtractionConfig(),
|
||||
]
|
||||
|
||||
|
||||
class PreprocessingConfig(BaseConfig):
|
||||
"""Unified configuration for the audio preprocessing pipeline.
|
||||
|
||||
@ -58,10 +65,7 @@ class PreprocessingConfig(BaseConfig):
|
||||
audio_transforms: List[AudioTransform] = Field(default_factory=list)
|
||||
|
||||
spectrogram_transforms: List[SpectrogramTransform] = Field(
|
||||
default_factory=lambda: [
|
||||
PcenConfig(),
|
||||
SpectralMeanSubtractionConfig(),
|
||||
]
|
||||
default_factory=_default_spectrogram_transforms
|
||||
)
|
||||
|
||||
stft: STFTConfig = Field(default_factory=STFTConfig)
|
||||
|
||||
@ -71,7 +71,7 @@ class TargetClassConfig(BaseConfig):
|
||||
|
||||
DEFAULT_DETECTION_CLASS = TargetClassConfig(
|
||||
name="bat",
|
||||
match_if=AllOfConfig( # ty: ignore[unknown-argument]
|
||||
match_if=AllOfConfig(
|
||||
conditions=[
|
||||
HasTagConfig(tag=data.Tag(term=call_type, value="Echolocation")),
|
||||
NotConfig(
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR
|
||||
from batdetect2.train.checkpoints import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
resolve_checkpoint_path,
|
||||
)
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.lightning import (
|
||||
TrainingModule,
|
||||
@ -26,5 +29,6 @@ __all__ = [
|
||||
"TrainingModule",
|
||||
"build_trainer",
|
||||
"load_model_from_checkpoint",
|
||||
"resolve_checkpoint_path",
|
||||
"run_train",
|
||||
]
|
||||
|
||||
@ -2,15 +2,31 @@ from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"CheckpointConfig",
|
||||
"DEFAULT_CHECKPOINT",
|
||||
"build_checkpoint_callback",
|
||||
"get_bundled_checkpoint_names",
|
||||
"resolve_checkpoint_path",
|
||||
]
|
||||
|
||||
PACKAGE_ROOT = Path(__file__).resolve().parents[1]
|
||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||
DEFAULT_CHECKPOINT = "uk_same"
|
||||
CHECKPOINT_ALIASES = {
|
||||
DEFAULT_CHECKPOINT: PACKAGE_ROOT
|
||||
/ "models"
|
||||
/ "checkpoints"
|
||||
/ "batdetect2_uk_same.ckpt",
|
||||
"batdetect2_uk_same": PACKAGE_ROOT
|
||||
/ "models"
|
||||
/ "checkpoints"
|
||||
/ "batdetect2_uk_same.ckpt",
|
||||
}
|
||||
|
||||
|
||||
class CheckpointConfig(BaseConfig):
|
||||
@ -18,6 +34,8 @@ class CheckpointConfig(BaseConfig):
|
||||
monitor: str | None = None
|
||||
mode: str = "max"
|
||||
save_top_k: int = 1
|
||||
# Save distributable inference checkpoints by default.
|
||||
save_weights_only: bool = True
|
||||
filename: str | None = None
|
||||
save_last: bool | Literal["link"] = "link"
|
||||
every_n_epochs: int | None = 1
|
||||
@ -47,9 +65,86 @@ def build_checkpoint_callback(
|
||||
return ModelCheckpoint(
|
||||
dirpath=str(checkpoint_dir),
|
||||
save_top_k=config.save_top_k,
|
||||
save_weights_only=config.save_weights_only,
|
||||
monitor=config.monitor,
|
||||
mode=config.mode,
|
||||
filename=config.filename,
|
||||
save_last=config.save_last,
|
||||
every_n_epochs=config.every_n_epochs,
|
||||
)
|
||||
|
||||
|
||||
def get_bundled_checkpoint_names() -> tuple[str, ...]:
|
||||
"""Return the supported bundled checkpoint aliases."""
|
||||
return tuple(CHECKPOINT_ALIASES.keys())
|
||||
|
||||
|
||||
def resolve_checkpoint_from_huggingface(path: str) -> Path:
|
||||
"""Resolve a Hugging Face checkpoint URI."""
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
except ImportError as error:
|
||||
raise ValueError(
|
||||
"Hugging Face checkpoint support is not installed. "
|
||||
"Install it with `pip install batdetect2[huggingface]`."
|
||||
) from error
|
||||
|
||||
repo_id, filename = _parse_huggingface_uri(path)
|
||||
return Path(hf_hub_download(repo_id=repo_id, filename=filename))
|
||||
|
||||
|
||||
def resolve_checkpoint_path(path: PathLike | str | None = None) -> Path:
|
||||
"""Resolve a local path, alias, or Hugging Face checkpoint URI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike | str | None
|
||||
Local checkpoint path, checkpoint alias, or a Hugging Face
|
||||
URI of the form ``hf://owner/repo/path/to/checkpoint.ckpt``. If
|
||||
omitted, the default alias checkpoint is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Path
|
||||
Resolved local filesystem path to the checkpoint.
|
||||
"""
|
||||
if path is None:
|
||||
path = DEFAULT_CHECKPOINT
|
||||
|
||||
if isinstance(path, str) and path.startswith("hf://"):
|
||||
return resolve_checkpoint_from_huggingface(path)
|
||||
|
||||
if isinstance(path, str) and path in CHECKPOINT_ALIASES:
|
||||
return Path(CHECKPOINT_ALIASES[path])
|
||||
|
||||
path = Path(path)
|
||||
if path.exists():
|
||||
return path.resolve()
|
||||
|
||||
bundled_names = ", ".join(get_bundled_checkpoint_names())
|
||||
raise FileNotFoundError(
|
||||
f"Checkpoint not found: {path}. "
|
||||
"Expected a local path, a checkpoint alias "
|
||||
f"({bundled_names}), or a Hugging Face URI."
|
||||
)
|
||||
|
||||
|
||||
def _parse_huggingface_uri(uri: str) -> tuple[str, str]:
|
||||
prefix = "hf://"
|
||||
if not uri.startswith(prefix):
|
||||
raise ValueError(
|
||||
"Hugging Face checkpoint URIs must start with 'hf://'."
|
||||
)
|
||||
|
||||
without_prefix = uri.removeprefix(prefix).strip("/")
|
||||
parts = without_prefix.split("/")
|
||||
|
||||
if len(parts) < 3:
|
||||
raise ValueError(
|
||||
"Hugging Face checkpoint URIs must be in the form "
|
||||
"'hf://owner/repo/path/to/checkpoint.ckpt'."
|
||||
)
|
||||
|
||||
repo_id = "/".join(parts[:2])
|
||||
filename = "/".join(parts[2:])
|
||||
return repo_id, filename
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import lightning as L
|
||||
import torch
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.models import Model, ModelConfig, build_model
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.models import ModelConfig, build_model
|
||||
from batdetect2.models.types import ModelOutput, ModelProtocol
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train.checkpoints import resolve_checkpoint_path
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.losses import build_loss
|
||||
from batdetect2.train.optimizers import build_optimizer
|
||||
@ -19,7 +21,7 @@ __all__ = [
|
||||
|
||||
|
||||
class TrainingModule(L.LightningModule):
|
||||
model: Model
|
||||
model: ModelProtocol
|
||||
loss: LossProtocol
|
||||
|
||||
def __init__(
|
||||
@ -30,7 +32,7 @@ class TrainingModule(L.LightningModule):
|
||||
dimension_names: list[str] | None = None,
|
||||
train_config: dict | None = None,
|
||||
loss: LossProtocol | None = None,
|
||||
model: Model | None = None,
|
||||
model: ModelProtocol | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -130,23 +132,27 @@ class StoredConfig:
|
||||
|
||||
|
||||
def load_model_from_checkpoint(
|
||||
path: PathLike,
|
||||
) -> tuple[Model, StoredConfig]:
|
||||
path: PathLike | str | None = None,
|
||||
) -> tuple[ModelProtocol, StoredConfig]:
|
||||
"""Load a model and its configuration from a Lightning checkpoint.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
path : PathLike | str | None
|
||||
Path to a ``.ckpt`` file produced by the BatDetect2 training
|
||||
pipeline.
|
||||
pipeline. If omitted, the default bundled checkpoint is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[Model, ModelConfig]
|
||||
tuple[ModelProtocol, ModelConfig]
|
||||
The restored ``Model`` instance and the ``ModelConfig`` that
|
||||
describes its architecture, preprocessing, and postprocessing.
|
||||
"""
|
||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||
resolved_path = resolve_checkpoint_path(path)
|
||||
module = TrainingModule.load_from_checkpoint(
|
||||
resolved_path,
|
||||
map_location=torch.device("cpu"),
|
||||
)
|
||||
training_config = TrainingConfig.model_validate(module.train_config)
|
||||
model_config = ModelConfig.model_validate(module.model_config)
|
||||
targets_config = TargetConfig.model_validate(module.targets_config)
|
||||
@ -163,7 +169,7 @@ def build_training_module(
|
||||
class_names: list[str] | None = None,
|
||||
dimension_names: list[str] | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
model: Model | None = None,
|
||||
model: ModelProtocol | None = None,
|
||||
) -> TrainingModule:
|
||||
if model_config is None:
|
||||
model_config = ModelConfig()
|
||||
|
||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
from lightning.pytorch.loggers import Logger
|
||||
@ -28,7 +29,7 @@ __all__ = [
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TrainLoggingContext:
|
||||
model_config: ModelConfig
|
||||
model_config: dict[str, Any]
|
||||
train_config: TrainingConfig
|
||||
audio_config: AudioConfig
|
||||
targets: TargetProtocol
|
||||
@ -49,9 +50,10 @@ class ConfigHyperparameterLogging:
|
||||
artifact_path: Path,
|
||||
context: TrainLoggingContext,
|
||||
) -> None:
|
||||
model_config = ModelConfig.model_validate(context.model_config)
|
||||
logger.log_hyperparams(
|
||||
{
|
||||
"model": context.model_config.model_dump(
|
||||
"model": model_config.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
),
|
||||
|
||||
@ -15,7 +15,8 @@ from batdetect2.logging import (
|
||||
TensorBoardLoggerConfig,
|
||||
build_logger,
|
||||
)
|
||||
from batdetect2.models import Model, ModelConfig, build_model
|
||||
from batdetect2.models import ModelConfig, build_model
|
||||
from batdetect2.models.types import ModelProtocol
|
||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||
from batdetect2.targets import (
|
||||
ROIMapperProtocol,
|
||||
@ -50,14 +51,13 @@ DEFAULT_LOG_DIR = Path("outputs") / "logs"
|
||||
def run_train(
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||
model: Model | None = None,
|
||||
model: ModelProtocol | None = None,
|
||||
targets: Optional["TargetProtocol"] = None,
|
||||
roi_mapper: Optional["ROIMapperProtocol"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
labeller: Optional["ClipLabeller"] = None,
|
||||
audio_config: Optional[AudioConfig] = None,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
targets_config: TargetConfig | None = None,
|
||||
train_config: Optional[TrainingConfig] = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
@ -75,7 +75,11 @@ def run_train(
|
||||
if seed is not None:
|
||||
seed_everything(seed)
|
||||
|
||||
model_config = model_config or ModelConfig()
|
||||
model_config = (
|
||||
ModelConfig()
|
||||
if model is None
|
||||
else ModelConfig.model_validate(model.get_config())
|
||||
)
|
||||
targets_config = targets_config or TargetConfig()
|
||||
audio_config = audio_config or AudioConfig()
|
||||
train_config = train_config or TrainingConfig()
|
||||
@ -172,7 +176,7 @@ def run_train(
|
||||
root_artifact_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging_context = TrainLoggingContext(
|
||||
model_config=model_config,
|
||||
model_config=model_config.model_dump(mode="json"),
|
||||
train_config=train_config,
|
||||
audio_config=audio_config,
|
||||
targets=targets,
|
||||
@ -214,7 +218,7 @@ def run_train(
|
||||
|
||||
|
||||
def _validate_model_compatibility(
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
model_config: ModelConfig,
|
||||
class_names: list[str],
|
||||
dimension_names: list[str],
|
||||
|
||||
@ -200,13 +200,14 @@ def test_user_can_read_extracted_features_per_detection(
|
||||
) -> None:
|
||||
"""User story: inspect extracted feature vectors per detection."""
|
||||
|
||||
# Given
|
||||
prediction = api_v2.process_file(example_audio_files[0])
|
||||
|
||||
assert len(prediction.detections) > 0
|
||||
# When
|
||||
feature_vectors = [det.features for det in prediction.detections]
|
||||
|
||||
feature_vectors = [
|
||||
api_v2.get_detection_features(det) for det in prediction.detections
|
||||
]
|
||||
# Then
|
||||
assert len(prediction.detections) > 0
|
||||
assert len(feature_vectors) == len(prediction.detections)
|
||||
assert all(vec.ndim == 1 for vec in feature_vectors)
|
||||
assert all(vec.size > 0 for vec in feature_vectors)
|
||||
@ -299,14 +300,20 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
||||
value,
|
||||
)
|
||||
|
||||
for key, value in source_detector.bbox_head.state_dict().items():
|
||||
assert key in detector.bbox_head.state_dict()
|
||||
for key, value in source_detector.size_head.state_dict().items():
|
||||
assert key in detector.size_head.state_dict()
|
||||
torch.testing.assert_close(
|
||||
detector.bbox_head.state_dict()[key],
|
||||
detector.size_head.state_dict()[key],
|
||||
value,
|
||||
)
|
||||
|
||||
|
||||
def test_api_from_checkpoint_defaults_to_bundled_model() -> None:
|
||||
api = BatDetect2API.from_checkpoint()
|
||||
|
||||
assert api.model.class_names
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
||||
api_v2: BatDetect2API,
|
||||
|
||||
@ -18,7 +18,7 @@ def test_user_can_finetune_only_heads(
|
||||
|
||||
api = BatDetect2API.from_config()
|
||||
source_classifier_head = api.model.detector.classifier_head
|
||||
source_bbox_head = api.model.detector.bbox_head
|
||||
source_size_head = api.model.detector.size_head
|
||||
source_backbone = api.model.detector.backbone
|
||||
finetune_dir = tmp_path / "heads_only"
|
||||
|
||||
@ -39,7 +39,7 @@ def test_user_can_finetune_only_heads(
|
||||
|
||||
backbone_params = list(detector.backbone.parameters())
|
||||
classifier_params = list(detector.classifier_head.parameters())
|
||||
bbox_params = list(detector.bbox_head.parameters())
|
||||
bbox_params = list(detector.size_head.parameters())
|
||||
|
||||
assert backbone_params
|
||||
assert classifier_params
|
||||
@ -50,7 +50,7 @@ def test_user_can_finetune_only_heads(
|
||||
assert finetuned_api is not api
|
||||
assert detector.backbone is source_backbone
|
||||
assert detector.classifier_head is not source_classifier_head
|
||||
assert detector.bbox_head is not source_bbox_head
|
||||
assert detector.size_head is not source_size_head
|
||||
assert list(finetune_dir.rglob("*.ckpt"))
|
||||
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ def test_cli_base_help_lists_main_commands() -> None:
|
||||
result = CliRunner().invoke(cli, ["--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "predict" in result.output
|
||||
assert "process" in result.output
|
||||
assert "train" in result.output
|
||||
assert "evaluate" in result.output
|
||||
assert "data" in result.output
|
||||
|
||||
@ -15,8 +15,8 @@ def test_cli_evaluate_help() -> None:
|
||||
result = CliRunner().invoke(cli, ["evaluate", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "MODEL_PATH" in result.output
|
||||
assert "TEST_DATASET" in result.output
|
||||
assert "--model" in result.output
|
||||
assert "--evaluation-config" in result.output
|
||||
|
||||
|
||||
@ -32,8 +32,9 @@ def test_cli_evaluate_writes_metrics_for_small_dataset(
|
||||
cli,
|
||||
[
|
||||
"evaluate",
|
||||
str(tiny_checkpoint_path),
|
||||
str(BASE_DIR / "example_data" / "dataset.yaml"),
|
||||
"--model",
|
||||
str(tiny_checkpoint_path),
|
||||
"--base-dir",
|
||||
str(BASE_DIR),
|
||||
"--workers",
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
"""CLI tests for finetune command."""
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
@ -25,8 +26,41 @@ def test_cli_finetune_help() -> None:
|
||||
assert "--outputs-config" not in result.output
|
||||
|
||||
|
||||
def test_cli_finetune_requires_model() -> None:
|
||||
"""User story: finetune requires a checkpoint argument."""
|
||||
def test_cli_finetune_defaults_to_bundled_model(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""User story: finetune can use the bundled checkpoint by default."""
|
||||
|
||||
called = {}
|
||||
|
||||
class FakeAPI:
|
||||
def finetune(self, **kwargs):
|
||||
called["finetune"] = kwargs
|
||||
return None
|
||||
|
||||
class FakeBatDetect2API:
|
||||
@classmethod
|
||||
def from_checkpoint(cls, path=None, **kwargs):
|
||||
called["path"] = path
|
||||
called["from_checkpoint_kwargs"] = kwargs
|
||||
return FakeAPI()
|
||||
|
||||
monkeypatch.setattr(
|
||||
"batdetect2.api_v2.BatDetect2API",
|
||||
FakeBatDetect2API,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"batdetect2.data.load_dataset_config",
|
||||
lambda path: SimpleNamespace(path=path),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"batdetect2.data.load_dataset",
|
||||
lambda config, base_dir=None: [],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"batdetect2.targets.TargetConfig.load",
|
||||
lambda path: SimpleNamespace(path=path),
|
||||
)
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
@ -38,8 +72,9 @@ def test_cli_finetune_requires_model() -> None:
|
||||
],
|
||||
)
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "--model" in result.output
|
||||
assert result.exit_code == 0
|
||||
assert called["path"] is None
|
||||
assert "finetune" in called
|
||||
|
||||
|
||||
def test_cli_finetune_requires_targets(tiny_checkpoint_path: Path) -> None:
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
"""Behavior tests for predict CLI workflows."""
|
||||
"""Behavior tests for process CLI workflows."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
@ -9,10 +9,10 @@ from soundevent import data, io
|
||||
from batdetect2.cli import cli
|
||||
|
||||
|
||||
def test_cli_predict_help() -> None:
|
||||
"""User story: discover available predict modes."""
|
||||
def test_cli_process_help() -> None:
|
||||
"""User story: discover available process modes."""
|
||||
|
||||
result = CliRunner().invoke(cli, ["predict", "--help"])
|
||||
result = CliRunner().invoke(cli, ["process", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "directory" in result.output
|
||||
@ -21,19 +21,19 @@ def test_cli_predict_help() -> None:
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_cli_predict_directory_runs_on_real_audio(
|
||||
def test_cli_process_directory_runs_on_real_audio(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
single_audio_dir: Path,
|
||||
) -> None:
|
||||
"""User story: run prediction for all files in a directory."""
|
||||
"""User story: process all files in a directory."""
|
||||
|
||||
output_path = tmp_path / "predictions"
|
||||
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"predict",
|
||||
"process",
|
||||
"directory",
|
||||
str(tiny_checkpoint_path),
|
||||
str(single_audio_dir),
|
||||
@ -52,12 +52,12 @@ def test_cli_predict_directory_runs_on_real_audio(
|
||||
assert len(list(output_path.glob("*.json"))) == 1
|
||||
|
||||
|
||||
def test_cli_predict_file_list_runs_on_real_audio(
|
||||
def test_cli_process_file_list_runs_on_real_audio(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
single_audio_dir: Path,
|
||||
) -> None:
|
||||
"""User story: run prediction from an explicit list of files."""
|
||||
"""User story: process an explicit list of files."""
|
||||
|
||||
audio_file = next(single_audio_dir.glob("*.wav"))
|
||||
file_list = tmp_path / "files.txt"
|
||||
@ -68,7 +68,7 @@ def test_cli_predict_file_list_runs_on_real_audio(
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"predict",
|
||||
"process",
|
||||
"file_list",
|
||||
str(tiny_checkpoint_path),
|
||||
str(file_list),
|
||||
@ -87,12 +87,12 @@ def test_cli_predict_file_list_runs_on_real_audio(
|
||||
assert len(list(output_path.glob("*.json"))) == 1
|
||||
|
||||
|
||||
def test_cli_predict_dataset_runs_on_aoef_metadata(
|
||||
def test_cli_process_dataset_runs_on_aoef_metadata(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
single_audio_dir: Path,
|
||||
) -> None:
|
||||
"""User story: predict from AOEF dataset metadata file."""
|
||||
"""User story: process from AOEF dataset metadata file."""
|
||||
|
||||
audio_file = next(single_audio_dir.glob("*.wav"))
|
||||
recording = data.Recording.from_file(audio_file)
|
||||
@ -103,7 +103,7 @@ def test_cli_predict_dataset_runs_on_aoef_metadata(
|
||||
)
|
||||
annotation_set = data.AnnotationSet(
|
||||
name="test",
|
||||
description="predict dataset test",
|
||||
description="process dataset test",
|
||||
clip_annotations=[data.ClipAnnotation(clip=clip, sound_events=[])],
|
||||
)
|
||||
|
||||
@ -115,7 +115,7 @@ def test_cli_predict_dataset_runs_on_aoef_metadata(
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"predict",
|
||||
"process",
|
||||
"dataset",
|
||||
str(tiny_checkpoint_path),
|
||||
str(dataset_path),
|
||||
@ -142,7 +142,7 @@ def test_cli_predict_dataset_runs_on_aoef_metadata(
|
||||
("soundevent", "*.json", True),
|
||||
],
|
||||
)
|
||||
def test_cli_predict_directory_supports_output_format_override(
|
||||
def test_cli_process_directory_supports_output_format_override(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
single_audio_dir: Path,
|
||||
@ -157,7 +157,7 @@ def test_cli_predict_directory_supports_output_format_override(
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"predict",
|
||||
"process",
|
||||
"directory",
|
||||
str(tiny_checkpoint_path),
|
||||
str(single_audio_dir),
|
||||
@ -180,12 +180,12 @@ def test_cli_predict_directory_supports_output_format_override(
|
||||
assert len(list(output_path.glob(expected_pattern))) >= 1
|
||||
|
||||
|
||||
def test_cli_predict_dataset_deduplicates_recordings(
|
||||
def test_cli_process_dataset_deduplicates_recordings(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
single_audio_dir: Path,
|
||||
) -> None:
|
||||
"""User story: duplicated recording entries are predicted once."""
|
||||
"""User story: duplicated recording entries are processed once."""
|
||||
|
||||
audio_file = next(single_audio_dir.glob("*.wav"))
|
||||
recording = data.Recording.from_file(audio_file)
|
||||
@ -215,7 +215,7 @@ def test_cli_predict_dataset_deduplicates_recordings(
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"predict",
|
||||
"process",
|
||||
"dataset",
|
||||
str(tiny_checkpoint_path),
|
||||
str(dataset_path),
|
||||
@ -234,7 +234,7 @@ def test_cli_predict_dataset_deduplicates_recordings(
|
||||
assert len(list(output_path.glob("*.nc"))) == 1
|
||||
|
||||
|
||||
def test_cli_predict_rejects_unknown_output_format(
|
||||
def test_cli_process_rejects_unknown_output_format(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
single_audio_dir: Path,
|
||||
@ -245,7 +245,7 @@ def test_cli_predict_rejects_unknown_output_format(
|
||||
result = CliRunner().invoke(
|
||||
cli,
|
||||
[
|
||||
"predict",
|
||||
"process",
|
||||
"directory",
|
||||
str(tiny_checkpoint_path),
|
||||
str(single_audio_dir),
|
||||
|
||||
@ -13,7 +13,6 @@ from batdetect2.models.backbones import (
|
||||
build_backbone,
|
||||
load_backbone_config,
|
||||
)
|
||||
from batdetect2.models.types import BackboneModel
|
||||
|
||||
|
||||
def test_unet_backbone_config_defaults():
|
||||
@ -61,10 +60,11 @@ def test_build_backbone_custom_config():
|
||||
assert backbone.encoder.in_channels == 2
|
||||
|
||||
|
||||
def test_build_backbone_returns_backbone_model():
|
||||
"""build_backbone always returns a BackboneModel instance."""
|
||||
def test_build_backbone_returns_unet_backbone():
|
||||
"""build_backbone returns the default UNet backbone."""
|
||||
backbone = build_backbone()
|
||||
assert isinstance(backbone, BackboneModel)
|
||||
|
||||
assert isinstance(backbone, UNetBackbone)
|
||||
|
||||
|
||||
def test_registry_has_unet_backbone():
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from typing import cast
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
@ -19,12 +21,15 @@ def dummy_spectrogram() -> torch.Tensor:
|
||||
def test_build_detector_default():
|
||||
"""Test building the default detector without a config."""
|
||||
num_classes = 5
|
||||
model = build_detector(num_classes=num_classes)
|
||||
model = build_detector(
|
||||
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||
dimension_names=["width", "height"],
|
||||
)
|
||||
|
||||
assert isinstance(model, Detector)
|
||||
assert model.num_classes == num_classes
|
||||
assert isinstance(model.classifier_head, ClassifierHead)
|
||||
assert isinstance(model.bbox_head, BBoxHead)
|
||||
assert isinstance(model.size_head, BBoxHead)
|
||||
|
||||
|
||||
def test_build_detector_custom_config():
|
||||
@ -32,13 +37,19 @@ def test_build_detector_custom_config():
|
||||
num_classes = 3
|
||||
config = UNetBackboneConfig(in_channels=2, input_height=128)
|
||||
|
||||
model = build_detector(num_classes=num_classes, config=config)
|
||||
model = build_detector(
|
||||
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||
dimension_names=["width", "height"],
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert isinstance(model, Detector)
|
||||
assert model.backbone.input_height == 128
|
||||
|
||||
assert isinstance(model.backbone.encoder, Encoder)
|
||||
assert model.backbone.encoder.in_channels == 2
|
||||
backbone = cast(UNetBackbone, model.backbone)
|
||||
|
||||
assert isinstance(backbone.encoder, Encoder)
|
||||
assert backbone.encoder.in_channels == 2
|
||||
|
||||
|
||||
def test_build_detector_custom_size_channels():
|
||||
@ -47,8 +58,8 @@ def test_build_detector_custom_size_channels():
|
||||
config = UNetBackboneConfig(in_channels=1, input_height=128)
|
||||
|
||||
model = build_detector(
|
||||
num_classes=num_classes,
|
||||
num_sizes=num_sizes,
|
||||
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||
dimension_names=[f"size_{i}" for i in range(num_sizes)],
|
||||
config=config,
|
||||
)
|
||||
|
||||
@ -62,7 +73,11 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
|
||||
num_classes = 4
|
||||
# Build model matching the dummy input shape
|
||||
config = UNetBackboneConfig(in_channels=1, input_height=256)
|
||||
model = build_detector(num_classes=num_classes, config=config)
|
||||
model = build_detector(
|
||||
class_names=[f"class_{i}" for i in range(num_classes)],
|
||||
dimension_names=["width", "height"],
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Process the spectrogram through the model
|
||||
# PyTorch expects shape (Batch, Channels, Height, Width)
|
||||
@ -132,7 +147,11 @@ def test_detector_forward_pass_with_preprocessor(sample_preprocessor):
|
||||
config = UNetBackboneConfig(
|
||||
in_channels=spec.shape[1], input_height=spec.shape[2]
|
||||
)
|
||||
model = build_detector(num_classes=3, config=config)
|
||||
model = build_detector(
|
||||
class_names=["class_0", "class_1", "class_2"],
|
||||
dimension_names=["width", "height"],
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Process
|
||||
output = model(spec)
|
||||
|
||||
@ -1,9 +1,17 @@
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.train import TrainingConfig, run_train
|
||||
from batdetect2.train.checkpoints import (
|
||||
DEFAULT_CHECKPOINT,
|
||||
get_bundled_checkpoint_names,
|
||||
resolve_checkpoint_path,
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.slow
|
||||
|
||||
@ -92,3 +100,133 @@ def test_train_controls_which_checkpoints_are_kept(
|
||||
assert last_checkpoints
|
||||
assert len(best_checkpoints) == 1
|
||||
assert "epoch" in best_checkpoints[0].name
|
||||
|
||||
|
||||
def test_train_saves_weights_only_checkpoints_by_default(
|
||||
tmp_path: Path,
|
||||
example_annotations: list[data.ClipAnnotation],
|
||||
) -> None:
|
||||
config = _build_fast_train_config()
|
||||
|
||||
run_train(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=example_annotations[:1],
|
||||
train_config=config,
|
||||
num_epochs=1,
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
checkpoint_dir=tmp_path,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
checkpoint_path = next(tmp_path.rglob("*.ckpt"))
|
||||
checkpoint = torch.load(
|
||||
checkpoint_path,
|
||||
map_location="cpu",
|
||||
weights_only=False,
|
||||
)
|
||||
|
||||
assert "state_dict" in checkpoint
|
||||
assert "hyper_parameters" in checkpoint
|
||||
assert "pytorch-lightning_version" in checkpoint
|
||||
assert "optimizer_states" not in checkpoint
|
||||
assert "lr_schedulers" not in checkpoint
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_returns_local_path_unchanged(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
local_path = tmp_path / "model.ckpt"
|
||||
local_path.write_bytes(b"checkpoint")
|
||||
|
||||
assert resolve_checkpoint_path(local_path) == local_path
|
||||
assert resolve_checkpoint_path(str(local_path)) == local_path
|
||||
|
||||
|
||||
def test_get_bundled_checkpoint_names_lists_supported_aliases() -> None:
|
||||
assert get_bundled_checkpoint_names() == (
|
||||
DEFAULT_CHECKPOINT,
|
||||
"batdetect2_uk_same",
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_uses_default_bundled_alias() -> None:
|
||||
resolved = resolve_checkpoint_path()
|
||||
|
||||
assert resolved == resolve_checkpoint_path(DEFAULT_CHECKPOINT)
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_accepts_bundled_alias() -> None:
|
||||
resolved = resolve_checkpoint_path(DEFAULT_CHECKPOINT)
|
||||
|
||||
assert resolved.name == "batdetect2_uk_same.ckpt"
|
||||
assert resolved.exists()
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_prefers_existing_local_path_over_alias(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
local_path = tmp_path / "uk_same"
|
||||
local_path.write_bytes(b"checkpoint")
|
||||
|
||||
assert resolve_checkpoint_path(local_path) == local_path
|
||||
assert resolve_checkpoint_path(str(local_path)) == local_path
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_downloads_huggingface_checkpoint(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
expected_path = tmp_path / "downloaded.ckpt"
|
||||
|
||||
def fake_hf_hub_download(repo_id: str, filename: str) -> str:
|
||||
assert repo_id == "owner/repo"
|
||||
assert filename == "weights/model.ckpt"
|
||||
return str(expected_path)
|
||||
|
||||
class FakeHuggingFaceHub(types.ModuleType):
|
||||
hf_hub_download = staticmethod(fake_hf_hub_download)
|
||||
|
||||
fake_module = FakeHuggingFaceHub("huggingface_hub")
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"huggingface_hub",
|
||||
fake_module,
|
||||
)
|
||||
|
||||
resolved = resolve_checkpoint_path("hf://owner/repo/weights/model.ckpt")
|
||||
|
||||
assert resolved == expected_path
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_requires_huggingface_dependency(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.delitem(sys.modules, "huggingface_hub", raising=False)
|
||||
|
||||
import builtins
|
||||
|
||||
original_import = builtins.__import__
|
||||
|
||||
def fake_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
if name == "huggingface_hub":
|
||||
raise ImportError("missing")
|
||||
return original_import(name, globals, locals, fromlist, level)
|
||||
|
||||
monkeypatch.setattr(builtins, "__import__", fake_import)
|
||||
|
||||
with pytest.raises(ValueError, match="Hugging Face checkpoint support"):
|
||||
resolve_checkpoint_path("hf://owner/repo/weights/model.ckpt")
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_rejects_incomplete_huggingface_uri() -> None:
|
||||
with pytest.raises(ValueError, match="hf://owner/repo/path/to"):
|
||||
resolve_checkpoint_path("hf://owner/repo")
|
||||
|
||||
|
||||
def test_resolve_checkpoint_path_rejects_missing_local_path() -> None:
|
||||
with pytest.raises(
|
||||
FileNotFoundError,
|
||||
match="checkpoint alias",
|
||||
):
|
||||
resolve_checkpoint_path("missing.ckpt")
|
||||
|
||||
@ -368,7 +368,7 @@ def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
|
||||
assert (
|
||||
rebuilt_detector.classifier_head is not source_detector.classifier_head
|
||||
)
|
||||
assert rebuilt_detector.bbox_head is not source_detector.bbox_head
|
||||
assert rebuilt_detector.size_head is not source_detector.size_head
|
||||
assert rebuilt_model.class_names == ["single_class"]
|
||||
assert rebuilt_model.dimension_names == ["width", "height"]
|
||||
|
||||
@ -451,7 +451,6 @@ def test_run_train_rejects_incompatible_model_config(
|
||||
model=incompatible_model,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
model_config=incompatible_config,
|
||||
targets_config=targets_config,
|
||||
train_config=TrainingConfig(),
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user