Compare commits

...

73 Commits

Author SHA1 Message Date
mbsantiago
5a14b29281 Do not crash on failure to plot 2026-03-19 01:31:21 +00:00
mbsantiago
13a31d9de9 Do not sync with model in loading from checkpoint otherwise device clash 2026-03-19 01:28:05 +00:00
mbsantiago
a1fad6d7d7 Add test for checkpointing 2026-03-19 01:26:37 +00:00
Santiago Martinez Balvanera
b8acd86c71 By default only save the last checkpoint 2026-03-19 01:26:11 +00:00
Santiago Martinez Balvanera
875751d340 Add save last config to checkpoint 2026-03-19 00:36:45 +00:00
Santiago Martinez Balvanera
32d8c4a9e5 Create separate preproc/postproc/target instances for model in api 2026-03-19 00:23:55 +00:00
mbsantiago
fb3dc3eaf0 Ensure preprocessor is in CPU 2026-03-19 00:09:30 +00:00
mbsantiago
0d90cb5cc3 Create duplicate preprocessor for data input pipeline 2026-03-19 00:03:29 +00:00
mbsantiago
91806aa01e Update default scheduler params 2026-03-18 23:54:16 +00:00
mbsantiago
23ac619c50 Add option to override targets when loading the model config 2026-03-18 23:36:24 +00:00
mbsantiago
99b9e55c0e Add extra and strict arguments to load config functions, migrate example config 2026-03-18 22:06:45 +00:00
mbsantiago
652670b01d Name changes 2026-03-18 20:48:35 +00:00
mbsantiago
0163a572cb Expanded cli tests 2026-03-18 20:35:08 +00:00
mbsantiago
f0af5dd79e Change inference command to predict 2026-03-18 20:07:53 +00:00
mbsantiago
2f03abe8f6 Remove stale load functions 2026-03-18 19:43:54 +00:00
mbsantiago
22a3d18d45 Move the logging config out of the train/eval configs 2026-03-18 19:32:19 +00:00
mbsantiago
bf5b88016a Polish evaluate and train CLI 2026-03-18 19:15:57 +00:00
mbsantiago
f9056eb19a Ensure config is source of truth 2026-03-18 18:33:51 +00:00
mbsantiago
ebe7e134e9 Allow building new head 2026-03-18 17:44:35 +00:00
mbsantiago
8e35956007 Create save evaluation results 2026-03-18 16:49:22 +00:00
mbsantiago
a332c5c3bd Allow loading predictions in different formats 2026-03-18 16:17:50 +00:00
mbsantiago
9fa703b34b Allow training on existing model 2026-03-18 13:58:52 +00:00
mbsantiago
0bf809e376 Exported types at module level 2026-03-18 13:08:28 +00:00
mbsantiago
6276a8884e Add roundtrip test for encoding decoding geometries 2026-03-18 12:09:03 +00:00
mbsantiago
7b1cb402b4 Add a few clip and detection transforms for outputs 2026-03-18 11:24:22 +00:00
mbsantiago
45ae15eed5 Cleanup postprocessing 2026-03-18 09:07:13 +00:00
mbsantiago
b3af70761e Create EvaluateTaskProtocol 2026-03-18 01:53:28 +00:00
mbsantiago
daff74fdde Moved decoding to outputs 2026-03-18 01:35:34 +00:00
mbsantiago
31d4f92359 reformat 2026-03-18 00:41:45 +00:00
mbsantiago
d56b9f02ae Remove num_worker from config 2026-03-18 00:22:59 +00:00
mbsantiago
573d8e38d6 Ran formatter 2026-03-18 00:03:26 +00:00
mbsantiago
751be53edf Moving types around to each submodule 2026-03-18 00:01:35 +00:00
mbsantiago
c226dc3f2b Add outputs module 2026-03-17 23:00:26 +00:00
mbsantiago
3b47c688dd Update api_v2 2026-03-17 22:25:41 +00:00
mbsantiago
feee2bdfa3 Add scheduler and optimizer module 2026-03-17 21:16:41 +00:00
mbsantiago
615c7d78fb Added a full test of training and saving 2026-03-17 15:38:07 +00:00
mbsantiago
56f6affc72 use train config in training module 2026-03-17 14:16:40 +00:00
mbsantiago
65bd0dc6ae Restructure model config 2026-03-17 13:33:13 +00:00
mbsantiago
1a7c0b4b3a Removing legacy types 2026-03-17 12:53:03 +00:00
mbsantiago
8ac4f4c44d Add dynamic imports to existing registries 2026-03-16 10:04:34 +00:00
mbsantiago
038d58ed99 Add config for dynamic imports, and tests 2026-03-16 09:30:23 +00:00
mbsantiago
47f418a63c Add test for annotated dataset load function 2026-03-15 21:33:14 +00:00
mbsantiago
197cc38e3e Add tests for registry 2026-03-15 21:17:25 +00:00
mbsantiago
e0503487ec Remove deprecated types 2026-03-15 21:17:20 +00:00
mbsantiago
4b7d23abde Create data annotation loader registry 2026-03-15 20:53:59 +00:00
mbsantiago
3c337a06cb Add architecture document 2026-03-11 19:08:18 +00:00
mbsantiago
0d590a26cc Update tests 2026-03-08 18:41:05 +00:00
mbsantiago
46c02962f3 Add test for preprocessing 2026-03-08 17:11:27 +00:00
mbsantiago
bfc88a4a0f Update audio loader docstrings 2026-03-08 17:02:35 +00:00
mbsantiago
ce952e364b Update audio module docstrings 2026-03-08 17:02:25 +00:00
mbsantiago
ef3348d651 Update model docstrings 2026-03-08 16:34:17 +00:00
mbsantiago
4207661da4 Add test for backbones 2026-03-08 16:04:54 +00:00
mbsantiago
45e3cf1434 Modify example config to add name 2026-03-08 15:23:14 +00:00
mbsantiago
f2d5088bec Run formatter 2026-03-08 15:18:21 +00:00
mbsantiago
652d076b46 Add backbone registry 2026-03-08 15:02:56 +00:00
mbsantiago
e393709258 Add interfaces for encoder/decoder/bottleneck 2026-03-08 14:43:16 +00:00
mbsantiago
54605ef269 Add blocks and detector tests 2026-03-08 14:17:47 +00:00
mbsantiago
b8b8a68f49 Add gradio as optional group 2026-03-08 13:17:32 +00:00
mbsantiago
6812e1c515 Update default detection class config 2026-03-08 13:17:23 +00:00
mbsantiago
0b344003a1 Ignore scripts for now 2026-03-08 13:17:12 +00:00
mbsantiago
c9bcaebcde Ignore notebooks for now 2026-03-08 13:17:03 +00:00
mbsantiago
d52e988b8f Fix type errors 2026-03-08 12:55:36 +00:00
mbsantiago
cce1b49a8d Run formatting 2026-03-08 08:59:28 +00:00
mbsantiago
8313fe1484 Minor formatting 2026-03-07 16:09:43 +00:00
mbsantiago
4509602e70 Run automated fixes 2025-12-12 21:29:25 +00:00
mbsantiago
0adb58e039 Run formatter 2025-12-12 21:28:28 +00:00
mbsantiago
531ff69974 Cleanup evaluate 2025-12-12 21:14:31 +00:00
mbsantiago
750f9e43c4 Make sure threshold is used 2025-12-12 19:53:15 +00:00
mbsantiago
f71fe0c2e2 Using matching and affinity functions from soundevent 2025-12-12 19:25:01 +00:00
mbsantiago
113f438e74 Run lint fixes 2025-12-08 17:19:33 +00:00
mbsantiago
2563f26ed3 Update type hints to python 3.10 2025-12-08 17:14:50 +00:00
mbsantiago
9c72537ddd Add parquet format for outputs 2025-12-08 17:11:35 +00:00
mbsantiago
72278d75ec Upgrade soundevent to 2.10 2025-12-08 17:11:22 +00:00
216 changed files with 11302 additions and 9484 deletions

7
.gitignore vendored
View File

@ -102,7 +102,7 @@ experiments/*
DvcLiveLogger/checkpoints DvcLiveLogger/checkpoints
logs/ logs/
mlruns/ mlruns/
outputs/ /outputs/
notebooks/lightning_logs notebooks/lightning_logs
# Jupiter notebooks # Jupiter notebooks
@ -123,3 +123,8 @@ example_data/preprocessed
# Dev notebooks # Dev notebooks
notebooks/tmp notebooks/tmp
/tmp
/.agents/skills
/notebooks
/AGENTS.md
/scripts

View File

@ -0,0 +1,93 @@
# BatDetect2 Architecture Overview
This document provides a comprehensive map of the `batdetect2` codebase architecture. It is intended to serve as a deep-dive reference for developers, agents, and contributors navigating the project.
`batdetect2` is designed as a modular deep learning pipeline for detecting and classifying bat echolocation calls in high-frequency audio recordings. It heavily utilizes **PyTorch**, **PyTorch Lightning** for training, and the **Soundevent** library for standardized audio and geometry data classes.
The repository follows a configuration-driven design pattern, heavily utilizing `pydantic`/`omegaconf` (via `BaseConfig`) and the Factory/Registry patterns for dependency injection and modularity. The entire pipeline can be orchestrated via the high-level API `BatDetect2API` (`src/batdetect2/api_v2.py`).
---
## 1. Data Flow Pipeline
The standard lifecycle of a prediction request follows these sequential stages, each handled by an isolated, replaceable module:
1. **Audio Loading (`batdetect2.audio`)**: Reads raw `.wav` files into standard NumPy arrays or `soundevent.data.Clip` objects. Handles resampling.
2. **Preprocessing (`batdetect2.preprocess`)**: Converts raw 1D waveforms into 2D Spectrogram tensors.
3. **Forward Pass (`batdetect2.models`)**: A PyTorch neural network processes the spectrogram and outputs dense prediction tensors (e.g., detection heatmaps, bounding box sizes, class probabilities).
4. **Postprocessing (`batdetect2.postprocess`)**: Decodes the raw output tensors back into explicit geometry bounding boxes and runs Non-Maximum Suppression (NMS) to filter redundant predictions.
5. **Formatting (`batdetect2.data`)**: Transforms the predictions into standard formats (`.csv`, `.json`, `.parquet`) using `OutputFormatterProtocol`.
---
## 2. Core Modules Breakdown
### 2.1 Audio and Preprocessing
- **`audio/`**:
- Centralizes audio I/O using `AudioLoader`. It abstracts over the `soundevent` library, efficiently handling full `Recording` files or smaller `Clip` segments, standardizing the sample rate.
- **`preprocess/`**:
- Its interface is defined by the `PreprocessorProtocol`.
- Its primary responsibility is spectrogram generation via Short-Time Fourier Transform (STFT).
- During training, it incorporates data augmentation layers (e.g., amplitude scaling, time masking, frequency masking, spectral mean subtraction) configured via `PreprocessingConfig`.
### 2.2 Deep Learning Models (`models/`)
The `models` directory contains all PyTorch neural network architectures. The default architecture is an Encoder-Decoder (U-Net style) network.
- **`blocks.py`**: Reusable neural network blocks, including standard Convolutions (`ConvBlock`) and specialized layers like `FreqCoordConvDownBlock`/`FreqCoordConvUpBlock` which append normalized spatial frequency coordinates to explicitly grant convolutional filters frequency-awareness.
- **`encoder.py`**: The downsampling path (feature extraction). Builds a sequential list of blocks and captures skip connections.
- **`bottleneck.py`**: The deepest, lowest-resolution segment connecting the Encoder and Decoder. Features an optional `SelfAttention` mechanism to weigh global temporal contexts.
- **`decoder.py`**: The upsampling path (reconstruction), actively integrating skip connections (residuals) from the Encoder.
- **`heads.py`**: Prediction heads that attach to the backbone's feature map to output specific predictions:
- `BBoxHead`: Predicts bounding box sizes.
- `ClassifierHead`: Predicts species classes.
- `DetectorHead`: Predicts detection probability heatmaps.
- **`backbones.py` & `detectors.py`**: Assemble the encoder, bottleneck, decoder, and heads into a cohesive `Detector` model.
- **`__init__.py:Model`**: The overarching wrapper `torch.nn.Module` containing the `detector`, `preprocessor`, `postprocessor`, and `targets`.
### 2.3 Targets and Regions of Interest (`targets/`)
Crucial for training, this module translates physical annotations (Regions of Interest / ROIs) into training targets (tensors).
- **`rois.py`**: Implements `ROITargetMapper`. Maps a geometric bounding box into a 2D reference `Position` (time, freq) and a `Size` array. Includes strategies like:
- `AnchorBBoxMapper`: Maps based on a fixed bounding box corner/center.
- `PeakEnergyBBoxMapper`: Identifies the physical coordinate of peak acoustic energy inside the bounding box and calculates offsets to the box edges.
- **`targets.py`**: Reconstructs complete multi-channel target heatmaps and coordinate tensors from the ROIs to compute losses during training.
### 2.4 Postprocessing (`postprocess/`)
- Implements `PostprocessorProtocol`.
- Reverses the logic from `targets`. It scans the model's output detection heatmaps for peaks, extracts the predicted sizes and class probabilities at those peaks, and decodes them back into physical `soundevent.data.Geometry` (Bounding Boxes).
- Automatically applies Non-Maximum Suppression (NMS) configured via `PostprocessConfig` to remove highly overlapping predictions.
### 2.5 Data Management (`data/`)
- **`annotations/`**: Utilities to load dataset annotations supporting multiple standardized schemas (`AOEF`, `BatDetect2` formats).
- **`datasets.py`**: Aggregates recordings and annotations into memory.
- **`predictions/`**: Handles the exporting of model results via `OutputFormatterProtocol`. Includes formatters for `RawOutput`, `.parquet`, `.json`, etc.
### 2.6 Evaluation (`evaluate/`)
- Computes scientific metrics using `EvaluatorProtocol`.
- Provides specific testing environments for tasks like `Clip Classification`, `Clip Detection`, and `Top Class` predictions.
- Generates precision-recall curves and scatter plots.
### 2.7 Training (`train/`)
- Implements the distributed PyTorch training loop via PyTorch Lightning.
- **`lightning.py`**: Contains `TrainingModule`, the `LightningModule` that orchestrates the optimizer, learning rate scheduler, forward passes, and backpropagation using the generated `targets`.
---
## 3. Interfaces and Tooling
### 3.1 APIs
- **`api_v2.py` (`BatDetect2API`)**: The modern API object. It is deeply integrated with dependency injection using `BatDetect2Config`. It instantiates the loader, targets, preprocessor, postprocessor, and model, exposing easy-to-use methods like `process_file`, `evaluate`, and `train`.
- **`api.py`**: The legacy API. Kept for backwards compatibility. Uses hardcoded default instances rather than configuration objects.
### 3.2 Command Line Interface (`cli/`)
- Implements terminal commands utilizing `click`. Commands include `batdetect2 detect`, `evaluate`, and `train`.
### 3.3 Core and Configuration (`core/`, `config.py`)
- **`core/registries.py`**: A string-based Registry pattern (e.g., `block_registry`, `roi_mapper_registry`) that allows developers to dynamically swap components (like a custom neural network block) via configuration files without modifying python code.
- **`config.py`**: Aggregates all modular `BaseConfig` objects (`AudioConfig`, `PreprocessingConfig`, `BackboneConfig`) into the monolithic `BatDetect2Config`.
---
## Summary
To navigate this codebase effectively:
1. Follow **`api_v2.py`** to see how high-level operations invoke individual components.
2. Rely heavily on the typed **Protocols** located in each subsystem's `types.py` module (for example `src/batdetect2/preprocess/types.py` and `src/batdetect2/postprocess/types.py`) to understand inputs and outputs without needing to read each implementation.
3. Understand that data flows structurally as `soundevent` primitives externally, and as pure `torch.Tensor` internally through the network.

View File

@ -6,6 +6,7 @@ Hi!
:maxdepth: 1 :maxdepth: 1
:caption: Contents: :caption: Contents:
architecture
data/index data/index
preprocessing/index preprocessing/index
postprocessing postprocessing

View File

@ -1,70 +1,79 @@
config_version: v1
audio: audio:
samplerate: 256000 samplerate: 256000
resample: resample:
enabled: True enabled: true
method: "poly" method: poly
preprocess:
stft:
window_duration: 0.002
window_overlap: 0.75
window_fn: hann
frequencies:
max_freq: 120000
min_freq: 10000
size:
height: 128
resize_factor: 0.5
spectrogram_transforms:
- name: pcen
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
- name: spectral_mean_substraction
postprocess:
nms_kernel_size: 9
detection_threshold: 0.01
top_k_per_sec: 200
model: model:
input_height: 128 samplerate: 256000
in_channels: 1
out_channels: 32 preprocess:
encoder: stft:
layers: window_duration: 0.002
- name: FreqCoordConvDown window_overlap: 0.75
out_channels: 32 window_fn: hann
- name: FreqCoordConvDown frequencies:
out_channels: 64 max_freq: 120000
- name: LayerGroup min_freq: 10000
layers: size:
- name: FreqCoordConvDown height: 128
out_channels: 128 resize_factor: 0.5
- name: ConvBlock spectrogram_transforms:
out_channels: 256 - name: pcen
bottleneck: time_constant: 0.1
channels: 256 gain: 0.98
layers: bias: 2
- name: SelfAttention power: 0.5
attention_channels: 256 - name: spectral_mean_subtraction
decoder:
layers: architecture:
- name: FreqCoordConvUp name: UNetBackbone
out_channels: 64 input_height: 128
- name: FreqCoordConvUp in_channels: 1
out_channels: 32 encoder:
- name: LayerGroup layers:
layers: - name: FreqCoordConvDown
- name: FreqCoordConvUp out_channels: 32
out_channels: 32 - name: FreqCoordConvDown
- name: ConvBlock out_channels: 64
out_channels: 32 - name: LayerGroup
layers:
- name: FreqCoordConvDown
out_channels: 128
- name: ConvBlock
out_channels: 256
bottleneck:
channels: 256
layers:
- name: SelfAttention
attention_channels: 256
decoder:
layers:
- name: FreqCoordConvUp
out_channels: 64
- name: FreqCoordConvUp
out_channels: 32
- name: LayerGroup
layers:
- name: FreqCoordConvUp
out_channels: 32
- name: ConvBlock
out_channels: 32
postprocess:
nms_kernel_size: 9
detection_threshold: 0.01
top_k_per_sec: 200
train: train:
optimizer: optimizer:
name: adam
learning_rate: 0.001 learning_rate: 0.001
scheduler:
name: cosine_annealing
t_max: 100 t_max: 100
labels: labels:
@ -76,10 +85,7 @@ train:
train_loader: train_loader:
batch_size: 8 batch_size: 8
shuffle: true
num_workers: 2
shuffle: True
clipping_strategy: clipping_strategy:
name: random_subclip name: random_subclip
@ -115,7 +121,6 @@ train:
max_masks: 3 max_masks: 3
val_loader: val_loader:
num_workers: 2
clipping_strategy: clipping_strategy:
name: whole_audio_padded name: whole_audio_padded
chunk_size: 0.256 chunk_size: 0.256
@ -134,9 +139,6 @@ train:
size: size:
weight: 0.1 weight: 0.1
logger:
name: csv
validation: validation:
tasks: tasks:
- name: sound_event_detection - name: sound_event_detection
@ -146,6 +148,10 @@ train:
metrics: metrics:
- name: average_precision - name: average_precision
logging:
train:
name: csv
evaluation: evaluation:
tasks: tasks:
- name: sound_event_detection - name: sound_event_detection

View File

@ -14,60 +14,67 @@ HTML_COVERAGE_DIR := "htmlcov"
help: help:
@just --list @just --list
install:
uv sync
# Testing & Coverage # Testing & Coverage
# Run tests using pytest. # Run tests using pytest.
test: test:
pytest {{TESTS_DIR}} uv run pytest {{TESTS_DIR}}
# Run tests and generate coverage data. # Run tests and generate coverage data.
coverage: coverage:
pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}} uv run pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}
# Generate an HTML coverage report. # Generate an HTML coverage report.
coverage-html: coverage coverage-html: coverage
@echo "Generating HTML coverage report..." @echo "Generating HTML coverage report..."
coverage html -d {{HTML_COVERAGE_DIR}} uv run coverage html -d {{HTML_COVERAGE_DIR}}
@echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/" @echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/"
# Serve the HTML coverage report locally. # Serve the HTML coverage report locally.
coverage-serve: coverage-html coverage-serve: coverage-html
@echo "Serving report at http://localhost:8000/ ..." @echo "Serving report at http://localhost:8000/ ..."
python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000 uv run python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000
# Documentation # Documentation
# Build documentation using Sphinx. # Build documentation using Sphinx.
docs: docs:
sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}} uv run sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
# Serve documentation with live reload. # Serve documentation with live reload.
docs-serve: docs-serve:
sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser uv run sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
# Formatting & Linting # Formatting & Linting
# Format code using ruff. # Format code using ruff.
format: fix-format:
ruff format {{PYTHON_DIRS}} uv run ruff format {{PYTHON_DIRS}}
# Check code formatting using ruff.
format-check:
ruff format --check {{PYTHON_DIRS}}
# Lint code using ruff.
lint:
ruff check {{PYTHON_DIRS}}
# Lint code using ruff and apply automatic fixes. # Lint code using ruff and apply automatic fixes.
lint-fix: fix-lint:
ruff check --fix {{PYTHON_DIRS}} uv run ruff check --fix {{PYTHON_DIRS}}
# Combined Formatting & Linting
fix: fix-format fix-lint
# Checking tasks
# Check code formatting using ruff.
check-format:
uv run ruff format --check {{PYTHON_DIRS}}
# Lint code using ruff.
check-lint:
uv run ruff check {{PYTHON_DIRS}}
# Type Checking # Type Checking
# Type check code using pyright. # Type check code using ty.
typecheck: check-types:
pyright {{PYTHON_DIRS}} uv run ty check {{PYTHON_DIRS}}
# Combined Checks # Combined Checks
# Run all checks (format-check, lint, typecheck). # Run all checks (format-check, lint, typecheck).
check: format-check lint typecheck test check: check-format check-lint check-types
# Cleaning tasks # Cleaning tasks
# Remove Python bytecode and cache. # Remove Python bytecode and cache.
@ -95,7 +102,7 @@ clean: clean-build clean-pyc clean-test clean-docs
# Train on example data. # Train on example data.
example-train OPTIONS="": example-train OPTIONS="":
batdetect2 train \ uv run batdetect2 train \
--val-dataset example_data/dataset.yaml \ --val-dataset example_data/dataset.yaml \
--config example_data/config.yaml \ --config example_data/config.yaml \
{{OPTIONS}} \ {{OPTIONS}} \

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,440 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "cfb0b360-a204-4c27-a18f-3902e8758879",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:02.699871Z",
"iopub.status.busy": "2024-11-19T17:33:02.699590Z",
"iopub.status.idle": "2024-11-19T17:33:02.710312Z",
"shell.execute_reply": "2024-11-19T17:33:02.709798Z",
"shell.execute_reply.started": "2024-11-19T17:33:02.699839Z"
}
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "326c5432-94e6-4abf-a332-fe902559461b",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:02.711324Z",
"iopub.status.busy": "2024-11-19T17:33:02.711067Z",
"iopub.status.idle": "2024-11-19T17:33:09.092380Z",
"shell.execute_reply": "2024-11-19T17:33:09.091830Z",
"shell.execute_reply.started": "2024-11-19T17:33:02.711304Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from pathlib import Path\n",
"from typing import List, Optional\n",
"import torch\n",
"\n",
"import pytorch_lightning as pl\n",
"from batdetect2.train.modules import DetectorModel\n",
"from batdetect2.train.augmentations import (\n",
" add_echo,\n",
" select_random_subclip,\n",
" warp_spectrogram,\n",
")\n",
"from batdetect2.train.dataset import LabeledDataset, get_files\n",
"from batdetect2.train.preprocess import PreprocessingConfig\n",
"from soundevent import data\n",
"import matplotlib.pyplot as plt\n",
"from soundevent.types import ClassMapper\n",
"from torch.utils.data import DataLoader"
]
},
{
"cell_type": "markdown",
"id": "9402a473-0b25-4123-9fa8-ad1f71a4237a",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-18T22:39:12.395329Z",
"iopub.status.busy": "2024-11-18T22:39:12.393444Z",
"iopub.status.idle": "2024-11-18T22:39:12.405938Z",
"shell.execute_reply": "2024-11-18T22:39:12.402980Z",
"shell.execute_reply.started": "2024-11-18T22:39:12.395236Z"
}
},
"source": [
"## Training Datasets"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cfd97d83-8c2b-46c8-9eae-cea59f53bc61",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.093487Z",
"iopub.status.busy": "2024-11-19T17:33:09.092990Z",
"iopub.status.idle": "2024-11-19T17:33:09.121636Z",
"shell.execute_reply": "2024-11-19T17:33:09.121143Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.093459Z"
}
},
"outputs": [],
"source": [
"data_dir = Path.cwd().parent / \"example_data\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d5131ae9-2efd-4758-b6e5-189a6d90789b",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.122685Z",
"iopub.status.busy": "2024-11-19T17:33:09.122270Z",
"iopub.status.idle": "2024-11-19T17:33:09.151386Z",
"shell.execute_reply": "2024-11-19T17:33:09.150788Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.122661Z"
}
},
"outputs": [],
"source": [
"files = get_files(data_dir / \"preprocessed\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "bc733d3d-7829-4e90-896d-a0dc76b33288",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.152327Z",
"iopub.status.busy": "2024-11-19T17:33:09.152060Z",
"iopub.status.idle": "2024-11-19T17:33:09.184041Z",
"shell.execute_reply": "2024-11-19T17:33:09.183372Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.152305Z"
}
},
"outputs": [],
"source": [
"train_dataset = LabeledDataset(files)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "dfbb94ab-7b12-4689-9c15-4dc34cd17cb2",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.186393Z",
"iopub.status.busy": "2024-11-19T17:33:09.186117Z",
"iopub.status.idle": "2024-11-19T17:33:09.220175Z",
"shell.execute_reply": "2024-11-19T17:33:09.219322Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.186375Z"
}
},
"outputs": [],
"source": [
"train_dataloader = DataLoader(\n",
" train_dataset,\n",
" shuffle=True,\n",
" batch_size=32,\n",
" num_workers=4,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e2eedaa9-6be3-481a-8786-7618515d98f8",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.221653Z",
"iopub.status.busy": "2024-11-19T17:33:09.221242Z",
"iopub.status.idle": "2024-11-19T17:33:09.260977Z",
"shell.execute_reply": "2024-11-19T17:33:09.260375Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.221616Z"
}
},
"outputs": [],
"source": [
"# List of all possible classes\n",
"class Mapper(ClassMapper):\n",
" class_labels = [\n",
" \"Eptesicus serotinus\",\n",
" \"Myotis mystacinus\",\n",
" \"Pipistrellus pipistrellus\",\n",
" \"Rhinolophus ferrumequinum\",\n",
" ]\n",
"\n",
" def encode(self, x: data.SoundEventAnnotation) -> Optional[str]:\n",
" event_tag = data.find_tag(x.tags, \"event\")\n",
"\n",
" if event_tag.value == \"Social\":\n",
" return \"social\"\n",
"\n",
" if event_tag.value != \"Echolocation\":\n",
" # Ignore all other types of calls\n",
" return None\n",
"\n",
" species_tag = data.find_tag(x.tags, \"class\")\n",
" return species_tag.value\n",
"\n",
" def decode(self, class_name: str) -> List[data.Tag]:\n",
" if class_name == \"social\":\n",
" return [data.Tag(key=\"event\", value=\"social\")]\n",
"\n",
" return [data.Tag(key=\"class\", value=class_name)]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1ff6072c-511e-42fe-a74f-282f269b80f0",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.262337Z",
"iopub.status.busy": "2024-11-19T17:33:09.261775Z",
"iopub.status.idle": "2024-11-19T17:33:09.309793Z",
"shell.execute_reply": "2024-11-19T17:33:09.309216Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.262307Z"
}
},
"outputs": [],
"source": [
"detector = DetectorModel(class_mapper=Mapper())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3a763ee6-15bc-4105-a409-f06e0ad21a06",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.310695Z",
"iopub.status.busy": "2024-11-19T17:33:09.310438Z",
"iopub.status.idle": "2024-11-19T17:33:09.366636Z",
"shell.execute_reply": "2024-11-19T17:33:09.366059Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.310669Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: False, used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n"
]
}
],
"source": [
"trainer = pl.Trainer(\n",
" limit_train_batches=100,\n",
" max_epochs=2,\n",
" log_every_n_steps=1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0b86d49d-3314-4257-94f5-f964855be385",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.367499Z",
"iopub.status.busy": "2024-11-19T17:33:09.367242Z",
"iopub.status.idle": "2024-11-19T17:33:10.811300Z",
"shell.execute_reply": "2024-11-19T17:33:10.809823Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.367473Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
" | Name | Type | Params | Mode \n",
"--------------------------------------------------------\n",
"0 | feature_extractor | Net2DFast | 119 K | train\n",
"1 | classifier | Conv2d | 54 | train\n",
"2 | bbox | Conv2d | 18 | train\n",
"--------------------------------------------------------\n",
"119 K Trainable params\n",
"448 Non-trainable params\n",
"119 K Total params\n",
"0.480 Total estimated model params size (MB)\n",
"32 Modules in train mode\n",
"0 Modules in eval mode\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0: 0%| | 0/1 [00:00<?, ?it/s]class heatmap shape torch.Size([3, 4, 128, 512])\n",
"class props shape torch.Size([3, 5, 128, 512])\n"
]
},
{
"ename": "RuntimeError",
"evalue": "The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdetector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:538\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 536\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m 537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 538\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 539\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 540\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:47\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 47\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 50\u001b[0m _call_teardown_hook(trainer)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:574\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 568\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 569\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 570\u001b[0m ckpt_path,\n\u001b[1;32m 571\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 572\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 573\u001b[0m )\n\u001b[0;32m--> 574\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 576\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 577\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:981\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 976\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 978\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 981\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 983\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 984\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 985\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 986\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:1025\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1023\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1024\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1025\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1026\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1027\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:205\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:363\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# in automatic optimization, there can only be one optimizer\u001b[39;00m\n\u001b[0;32m--> 250\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_optimization\u001b[38;5;241m.\u001b[39mrun(kwargs)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:190\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m closure()\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated gradients\u001b[39;00m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 190\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 192\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:268\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_ready()\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 268\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_completed()\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:167\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m hook_name\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 167\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 170\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/module.py:1306\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m 1275\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimizer_step\u001b[39m(\n\u001b[1;32m 1276\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1277\u001b[0m epoch: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1280\u001b[0m optimizer_closure: Optional[Callable[[], Any]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1281\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;124;03m the optimizer.\u001b[39;00m\n\u001b[1;32m 1284\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1304\u001b[0m \n\u001b[1;32m 1305\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1306\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py:153\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhen `optimizer.step(closure)` is called, the closure should be callable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 153\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_output\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:238\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;66;03m# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 238\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:122\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[39;00m\n\u001b[1;32m 121\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:130\u001b[0m, in \u001b[0;36mLRScheduler.__init__.<locals>.patch_track_step_called.<locals>.wrap_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 128\u001b[0m opt \u001b[38;5;241m=\u001b[39m opt_ref()\n\u001b[1;32m 129\u001b[0m opt\u001b[38;5;241m.\u001b[39m_opt_called \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m \u001b[38;5;66;03m# type: ignore[union-attr]\u001b[39;00m\n\u001b[0;32m--> 130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__get__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__class__\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:484\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 482\u001b[0m )\n\u001b[0;32m--> 484\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 485\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 487\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:89\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 87\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 88\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 89\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 91\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/adam.py:205\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 205\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n\u001b[1;32m 208\u001b[0m params_with_grad: List[Tensor] \u001b[38;5;241m=\u001b[39m []\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:108\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, closure)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrap_closure\u001b[39m(\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 97\u001b[0m model: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpl.LightningModule\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 98\u001b[0m optimizer: Steppable,\n\u001b[1;32m 99\u001b[0m closure: Callable[[], Any],\n\u001b[1;32m 100\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 101\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;124;03m hook is called.\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m closure_result\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:144\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:129\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 129\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwarning_cache\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`training_step` returned `None`. If this was on purpose, ignore this warning...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:317\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Performs the actual train step with the tied hooks.\u001b[39;00m\n\u001b[1;32m 307\u001b[0m \n\u001b[1;32m 308\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 313\u001b[0m \n\u001b[1;32m 314\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 315\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[0;32m--> 317\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step() \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m training_step_output \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mworld_size \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:319\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 319\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 322\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:167\u001b[0m, in \u001b[0;36mDetectorModel.training_step\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtraining_step\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch: TrainExample):\n\u001b[1;32m 166\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(batch\u001b[38;5;241m.\u001b[39mspec)\n\u001b[0;32m--> 167\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:150\u001b[0m, in \u001b[0;36mDetectorModel.compute_loss\u001b[0;34m(self, outputs, batch)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mclass props shape\u001b[39m\u001b[38;5;124m\"\u001b[39m, outputs\u001b[38;5;241m.\u001b[39mclass_probs\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 149\u001b[0m valid_mask \u001b[38;5;241m=\u001b[39m batch\u001b[38;5;241m.\u001b[39mclass_heatmap\u001b[38;5;241m.\u001b[39many(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m--> 150\u001b[0m classification_loss \u001b[38;5;241m=\u001b[39m \u001b[43mlosses\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal_loss\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[43m \u001b[49m\u001b[43moutputs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_probs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_heatmap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[43m \u001b[49m\u001b[43mweights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalid_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalid_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbeta\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 156\u001b[0m 
\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 157\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 160\u001b[0m detection_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mdetection\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 161\u001b[0m \u001b[38;5;241m+\u001b[39m size_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39msize\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 162\u001b[0m \u001b[38;5;241m+\u001b[39m classification_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mclassification\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 163\u001b[0m )\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/losses.py:38\u001b[0m, in \u001b[0;36mfocal_loss\u001b[0;34m(pred, gt, weights, valid_mask, eps, beta, alpha)\u001b[0m\n\u001b[1;32m 35\u001b[0m pos_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39meq(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[1;32m 36\u001b[0m neg_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39mlt(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m---> 38\u001b[0m pos_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpred\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpos_inds\u001b[49m\n\u001b[1;32m 39\u001b[0m neg_loss \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 40\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m pred \u001b[38;5;241m+\u001b[39m eps)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(pred, alpha)\n\u001b[1;32m 42\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m gt, beta)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;241m*\u001b[39m neg_inds\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m weights \u001b[38;5;129;01mis\u001b[39;00m 
\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1"
]
}
],
"source": [
"trainer.fit(detector, train_dataloaders=train_dataloader)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f6924db-e520-49a1-bbe8-6c4956e46314",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.811729Z",
"iopub.status.idle": "2024-11-19T17:33:10.811955Z",
"shell.execute_reply": "2024-11-19T17:33:10.811858Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.811849Z"
}
},
"outputs": [],
"source": [
"clip_annotation = train_dataset.get_clip_annotation(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23943e13-6875-49b8-9f18-2ba6528aa673",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.812924Z",
"iopub.status.idle": "2024-11-19T17:33:10.813260Z",
"shell.execute_reply": "2024-11-19T17:33:10.813104Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.813087Z"
}
},
"outputs": [],
"source": [
"spec = detector.compute_spectrogram(clip_annotation.clip)\n",
"outputs = detector(torch.tensor(spec.values).unsqueeze(0).unsqueeze(0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd1fe346-0873-4b14-ae1b-92ef1f4f27a5",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.814343Z",
"iopub.status.idle": "2024-11-19T17:33:10.814806Z",
"shell.execute_reply": "2024-11-19T17:33:10.814628Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.814611Z"
}
},
"outputs": [],
"source": [
"_, ax= plt.subplots(figsize=(15, 5))\n",
"spec.plot(ax=ax, add_colorbar=False)\n",
"ax.pcolormesh(spec.time, spec.frequency, outputs.detection_probs.detach().squeeze())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eadd36ef-a04a-4665-b703-cec84cf1673b",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.815603Z",
"iopub.status.idle": "2024-11-19T17:33:10.816065Z",
"shell.execute_reply": "2024-11-19T17:33:10.815894Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.815877Z"
}
},
"outputs": [],
"source": [
"print(f\"Num predicted soundevents: {len(predictions.sound_events)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4e54f3e-6ddc-4fe5-8ce0-b527ff6f18ae",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "batdetect2-dev",
"language": "python",
"name": "batdetect2-dev"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -1,194 +0,0 @@
# Marimo notebook: explore a batdetect2 dataset, split it into train/val sets,
# and export the splits as soundevent AnnotationSets for hyper-parameter tuning.
# NOTE(review): indentation was lost in extraction; reconstructed here following
# marimo's standard generated-file layout (one `@app.cell` function per cell).
import marimo

__generated_with = "0.14.16"

app = marimo.App(width="medium")


@app.cell
def _():
    import marimo as mo
    return (mo,)


@app.cell
def _():
    from batdetect2.data import (
        load_dataset_config,
        load_dataset,
        extract_recordings_df,
        extract_sound_events_df,
        compute_class_summary,
    )
    return (
        compute_class_summary,
        extract_recordings_df,
        extract_sound_events_df,
        load_dataset,
        load_dataset_config,
    )


@app.cell
def _(mo):
    # Interactive file picker for the dataset config YAML.
    dataset_config_browser = mo.ui.file_browser(
        selection_mode="file",
        multiple=False,
    )
    dataset_config_browser
    return (dataset_config_browser,)


@app.cell
def _(dataset_config_browser, load_dataset_config, mo):
    # Halt downstream cells until a config file has been selected.
    mo.stop(dataset_config_browser.path() is None)
    dataset_config = load_dataset_config(dataset_config_browser.path())
    return (dataset_config,)


@app.cell
def _(dataset_config, load_dataset):
    dataset = load_dataset(dataset_config, base_dir="../paper/")
    return (dataset,)


@app.cell
def _():
    from batdetect2.targets import load_target_config, build_targets
    return build_targets, load_target_config


@app.cell
def _(mo):
    # Interactive file picker for the targets config.
    targets_config_browser = mo.ui.file_browser(
        selection_mode="file",
        multiple=False,
    )
    targets_config_browser
    return (targets_config_browser,)


@app.cell
def _(load_target_config, mo, targets_config_browser):
    mo.stop(targets_config_browser.path() is None)
    targets_config = load_target_config(targets_config_browser.path())
    return (targets_config,)


@app.cell
def _(build_targets, targets_config):
    targets = build_targets(targets_config)
    return (targets,)


@app.cell
def _():
    import pandas as pd
    from soundevent.geometry import compute_bounds
    return


@app.cell
def _(dataset, extract_recordings_df):
    recordings = extract_recordings_df(dataset)
    recordings
    return


@app.cell
def _(dataset, extract_sound_events_df, targets):
    sound_events = extract_sound_events_df(dataset, targets)
    sound_events
    return


@app.cell
def _(compute_class_summary, dataset, targets):
    compute_class_summary(dataset, targets)
    return


@app.cell
def _():
    from batdetect2.data.split import split_dataset_by_recordings
    return (split_dataset_by_recordings,)


@app.cell
def _(dataset, split_dataset_by_recordings, targets):
    # Fixed random_state so the train/val split is reproducible.
    train_dataset, val_dataset = split_dataset_by_recordings(
        dataset, targets, random_state=42
    )
    return train_dataset, val_dataset


@app.cell
def _(compute_class_summary, targets, train_dataset):
    compute_class_summary(train_dataset, targets)
    return


@app.cell
def _(compute_class_summary, targets, val_dataset):
    compute_class_summary(val_dataset, targets)
    return


@app.cell
def _():
    from soundevent import io, data
    from pathlib import Path
    return Path, data, io


@app.cell
def _(Path, data, io, train_dataset):
    io.save(
        data.AnnotationSet(
            name="batdetect2_tuning_train",
            description="Set of annotations used as the train dataset for the hyper-parameter tuning stage.",
            clip_annotations=train_dataset,
        ),
        Path("../paper/data/datasets/annotation_sets/tuning_train.json"),
        audio_dir=Path("../paper/data/datasets/"),
    )
    return


@app.cell
def _(Path, data, io, val_dataset):
    io.save(
        data.AnnotationSet(
            name="batdetect2_tuning_val",
            description="Set of annotations used as the validation dataset for the hyper-parameter tuning stage.",
            clip_annotations=val_dataset,
        ),
        Path("../paper/data/datasets/annotation_sets/tuning_val.json"),
        audio_dir=Path("../paper/data/datasets/"),
    )
    return


@app.cell
def _(load_dataset, load_dataset_config):
    # Sanity check: load an independent dataset config and inspect it below.
    config = load_dataset_config("../paper/conf/datasets/train/uk_tune.yaml")
    rec = load_dataset(config, base_dir="../paper/")
    return (rec,)


@app.cell
def _(rec):
    # Peek at the term of the first tag of the first sound event.
    dict(rec[0].sound_events[0].tags[0].term)
    return


@app.cell
def _(compute_class_summary, rec, targets):
    compute_class_summary(rec, targets)
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()

View File

@ -1,273 +0,0 @@
# Marimo notebook: run the batdetect2 API on a clip, perturb the ground-truth
# annotations with noise/drops, match predictions to annotations, and plot
# true-positive / false-positive / false-negative / cross-trigger cases.
# NOTE(review): indentation was lost in extraction; reconstructed here following
# marimo's standard generated-file layout (one `@app.cell` function per cell).
import marimo

__generated_with = "0.14.16"

app = marimo.App(width="medium")


@app.cell
def _():
    import marimo as mo
    return


@app.cell
def _():
    from batdetect2.data import load_dataset_config, load_dataset
    from batdetect2.preprocess import load_preprocessing_config, build_preprocessor
    from batdetect2 import api
    from soundevent import data
    from batdetect2.evaluate.types import MatchEvaluation
    from batdetect2.types import Annotation
    from batdetect2.compat import annotation_to_sound_event_prediction
    from batdetect2.plotting import (
        plot_clip,
        plot_clip_annotation,
        plot_clip_prediction,
        plot_matches,
        plot_false_positive_match,
        plot_false_negative_match,
        plot_true_positive_match,
        plot_cross_trigger_match,
    )
    return (
        MatchEvaluation,
        annotation_to_sound_event_prediction,
        api,
        build_preprocessor,
        data,
        load_dataset,
        load_dataset_config,
        load_preprocessing_config,
        plot_clip_annotation,
        plot_clip_prediction,
        plot_cross_trigger_match,
        plot_false_negative_match,
        plot_false_positive_match,
        plot_matches,
        plot_true_positive_match,
    )


@app.cell
def _(build_preprocessor, load_dataset_config, load_preprocessing_config):
    dataset_config = load_dataset_config(
        path="example_data/config.yaml", field="datasets.train"
    )
    preprocessor_config = load_preprocessing_config(
        path="example_data/config.yaml", field="preprocess"
    )
    preprocessor = build_preprocessor(preprocessor_config)
    return dataset_config, preprocessor


@app.cell
def _(dataset_config, load_dataset):
    dataset = load_dataset(dataset_config)
    return (dataset,)


@app.cell
def _(dataset):
    clip_annotation = dataset[1]
    return (clip_annotation,)


@app.cell
def _(clip_annotation, plot_clip_annotation, preprocessor):
    plot_clip_annotation(
        clip_annotation, preprocessor=preprocessor, figsize=(15, 5)
    )
    return


@app.cell
def _(annotation_to_sound_event_prediction, api, clip_annotation, data):
    # Run inference on the clip's audio and wrap the detections into a
    # soundevent ClipPrediction so they can be matched and plotted below.
    audio = api.load_audio(clip_annotation.clip.recording.path)
    detections, features, spec = api.process_audio(audio)
    clip_prediction = data.ClipPrediction(
        clip=clip_annotation.clip,
        sound_events=[
            annotation_to_sound_event_prediction(
                prediction, clip_annotation.clip.recording
            )
            for prediction in detections
        ],
    )
    return (clip_prediction,)


@app.cell
def _(clip_prediction, plot_clip_prediction):
    plot_clip_prediction(clip_prediction, figsize=(15, 5))
    return


@app.cell
def _():
    from batdetect2.evaluate import match_predictions_and_annotations
    import random
    return match_predictions_and_annotations, random


@app.cell
def _(data, random):
    def add_noise(clip_annotation, time_buffer=0.003, freq_buffer=1000):
        # Jitter every bounding box by up to +/- time_buffer seconds and
        # +/- freq_buffer Hz to simulate imperfect predictions.
        def _add_bbox_noise(bbox):
            start_time, low_freq, end_time, high_freq = bbox.coordinates
            return data.BoundingBox(
                coordinates=[
                    start_time + random.uniform(-time_buffer, time_buffer),
                    low_freq + random.uniform(-freq_buffer, freq_buffer),
                    end_time + random.uniform(-time_buffer, time_buffer),
                    high_freq + random.uniform(-freq_buffer, freq_buffer),
                ]
            )

        def _add_noise(se):
            return se.model_copy(
                update=dict(
                    sound_event=se.sound_event.model_copy(
                        update=dict(
                            geometry=_add_bbox_noise(se.sound_event.geometry)
                        )
                    )
                )
            )

        return clip_annotation.model_copy(
            update=dict(
                sound_events=[
                    _add_noise(se) for se in clip_annotation.sound_events
                ]
            )
        )

    def drop_random(obj, p=0.5):
        # Randomly drop sound events with probability p (simulated misses).
        return obj.model_copy(
            update=dict(
                sound_events=[
                    se for se in obj.sound_events if random.random() > p
                ]
            )
        )

    return add_noise, drop_random


@app.cell
def _(
    add_noise,
    clip_annotation,
    clip_prediction,
    drop_random,
    match_predictions_and_annotations,
):
    matches = match_predictions_and_annotations(
        drop_random(add_noise(clip_annotation), p=0.2),
        drop_random(clip_prediction),
    )
    return (matches,)


@app.cell
def _(clip_annotation, matches, plot_matches):
    plot_matches(matches, clip_annotation.clip, figsize=(15, 5))
    return


@app.cell
def _(matches):
    # Bucket matches: no source -> missed annotation (FN); no target ->
    # spurious prediction (FP); both present -> true positive.
    true_positives = []
    false_positives = []
    false_negatives = []
    for match in matches:
        if match.source is None and match.target is not None:
            false_negatives.append(match)
        elif match.target is None and match.source is not None:
            false_positives.append(match)
        elif match.target is not None and match.source is not None:
            true_positives.append(match)
        else:
            continue
    return false_negatives, false_positives, true_positives


@app.cell
def _(MatchEvaluation, false_positives, plot_false_positive_match):
    false_positive = false_positives[0]
    false_positive_eval = MatchEvaluation(
        match=false_positive,
        gt_det=False,
        gt_class=None,
        pred_score=false_positive.source.score,
        pred_class_scores={
            "myomyo": 0.2
        },
    )
    plot_false_positive_match(false_positive_eval)
    return


@app.cell
def _(MatchEvaluation, false_negatives, plot_false_negative_match):
    false_negative = false_negatives[0]
    false_negative_eval = MatchEvaluation(
        match=false_negative,
        gt_det=True,
        gt_class="myomyo",
        pred_score=None,
        pred_class_scores={},
    )
    plot_false_negative_match(false_negative_eval)
    return


@app.cell
def _(MatchEvaluation, plot_true_positive_match, true_positives):
    true_positive = true_positives[0]
    true_positive_eval = MatchEvaluation(
        match=true_positive,
        gt_det=True,
        gt_class="myomyo",
        pred_score=0.87,
        pred_class_scores={
            "pyomyo": 0.84,
            "pippip": 0.84,
        },
    )
    plot_true_positive_match(true_positive_eval)
    return (true_positive,)


@app.cell
def _(MatchEvaluation, plot_cross_trigger_match, true_positive):
    cross_trigger_eval = MatchEvaluation(
        match=true_positive,
        gt_det=True,
        gt_class="myomyo",
        pred_score=0.87,
        pred_class_scores={
            "pippip": 0.84,
            "myomyo": 0.84,
        },
    )
    plot_cross_trigger_match(cross_trigger_eval)
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()

View File

@ -1,19 +0,0 @@
# Marimo notebook stub: imports the preprocessor builder but has no cells
# with content yet.
# NOTE(review): indentation was lost in extraction; reconstructed here
# following marimo's standard generated-file layout.
import marimo

__generated_with = "0.13.15"

app = marimo.App(width="medium")


@app.cell
def _():
    from batdetect2.preprocess import build_preprocessor
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()

View File

@ -1,97 +0,0 @@
# Marimo notebook: browse a batdetect2 dataset's sound-event tags with
# interactive key/value substring filters.
# NOTE(review): indentation was lost in extraction; reconstructed here
# following marimo's standard generated-file layout.
import marimo

__generated_with = "0.14.5"

app = marimo.App(width="medium")


@app.cell
def _():
    import marimo as mo
    return (mo,)


@app.cell
def _():
    from batdetect2.data import load_dataset, load_dataset_config
    return load_dataset, load_dataset_config


@app.cell
def _(mo):
    dataset_input = mo.ui.file_browser(label="dataset config file")
    dataset_input
    return (dataset_input,)


@app.cell
def _(mo):
    audio_dir_input = mo.ui.file_browser(
        label="audio directory", selection_mode="directory"
    )
    audio_dir_input
    return (audio_dir_input,)


@app.cell
def _(dataset_input, load_dataset_config):
    dataset_config = load_dataset_config(
        path=dataset_input.path(), field="datasets.train",
    )
    return (dataset_config,)


@app.cell
def _(audio_dir_input, dataset_config, load_dataset):
    dataset = load_dataset(dataset_config, base_dir=audio_dir_input.path())
    return (dataset,)


@app.cell
def _(dataset):
    len(dataset)
    return


@app.cell
def _(dataset):
    # Flatten all sound-event tags in the dataset into a single list.
    tag_groups = [
        se.tags
        for clip in dataset
        for se in clip.sound_events
        if se.tags
    ]
    all_tags = [
        tag for group in tag_groups for tag in group
    ]
    return (all_tags,)


@app.cell
def _(mo):
    key_search = mo.ui.text(label="key", debounce=0.1)
    value_search = mo.ui.text(label="value", debounce=0.1)
    mo.hstack([key_search, value_search]).left()
    return key_search, value_search


@app.cell
def _(all_tags, key_search, mo, value_search):
    # Deduplicate, then apply case-insensitive substring filters; show the
    # first five hits.
    filtered_tags = list(set(all_tags))
    if key_search.value:
        filtered_tags = [
            tag
            for tag in filtered_tags
            if key_search.value.lower() in tag.key.lower()
        ]
    if value_search.value:
        filtered_tags = [
            tag
            for tag in filtered_tags
            if value_search.value.lower() in tag.value.lower()
        ]
    mo.vstack(
        [mo.md(f"key={tag.key} value={tag.value}") for tag in filtered_tags[:5]]
    )
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()

View File

@ -25,14 +25,14 @@ dependencies = [
"scikit-learn>=1.2.2", "scikit-learn>=1.2.2",
"scipy>=1.10.1", "scipy>=1.10.1",
"seaborn>=0.13.2", "seaborn>=0.13.2",
"soundevent[audio,geometry,plot]>=2.9.1", "soundevent[audio,geometry,plot]>=2.10.0",
"tensorboard>=2.16.2", "tensorboard>=2.16.2",
"torch>=1.13.1", "torch>=1.13.1",
"torchaudio>=1.13.1", "torchaudio>=1.13.1",
"torchvision>=0.14.0", "torchvision>=0.14.0",
"tqdm>=4.66.2", "tqdm>=4.66.2",
] ]
requires-python = ">=3.9,<3.13" requires-python = ">=3.10,<3.14"
readme = "README.md" readme = "README.md"
license = { text = "CC-by-nc-4" } license = { text = "CC-by-nc-4" }
classifiers = [ classifiers = [
@ -75,7 +75,6 @@ dev = [
"ruff>=0.7.3", "ruff>=0.7.3",
"ipykernel>=6.29.4", "ipykernel>=6.29.4",
"setuptools>=69.5.1", "setuptools>=69.5.1",
"basedpyright>=1.31.0",
"myst-parser>=3.0.1", "myst-parser>=3.0.1",
"sphinx-autobuild>=2024.10.3", "sphinx-autobuild>=2024.10.3",
"numpydoc>=1.8.0", "numpydoc>=1.8.0",
@ -87,13 +86,24 @@ dev = [
"rust-just>=1.40.0", "rust-just>=1.40.0",
"pandas-stubs>=2.2.2.240807", "pandas-stubs>=2.2.2.240807",
"python-lsp-server>=1.13.0", "python-lsp-server>=1.13.0",
"deepdiff>=8.6.1",
] ]
dvclive = ["dvclive>=3.48.2"] dvclive = ["dvclive>=3.48.2"]
mlflow = ["mlflow>=3.1.1"] mlflow = ["mlflow>=3.1.1"]
gradio = [
"gradio>=6.9.0",
]
[tool.ruff] [tool.ruff]
line-length = 79 line-length = 79
target-version = "py39" target-version = "py310"
exclude = [
"src/batdetect2/train/legacy",
"src/batdetect2/plotting/legacy",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/finetune",
"src/batdetect2/utils",
]
[tool.ruff.format] [tool.ruff.format]
docstring-code-format = true docstring-code-format = true
@ -105,15 +115,12 @@ select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
[tool.ruff.lint.pydocstyle] [tool.ruff.lint.pydocstyle]
convention = "numpy" convention = "numpy"
[tool.pyright] [tool.ty.src]
include = ["src", "tests"] include = ["src", "tests"]
pythonVersion = "3.9"
pythonPlatform = "All"
exclude = [ exclude = [
"src/batdetect2/detector/", "src/batdetect2/train/legacy",
"src/batdetect2/plotting/legacy",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/finetune", "src/batdetect2/finetune",
"src/batdetect2/utils", "src/batdetect2/utils",
"src/batdetect2/plot",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/train/legacy",
] ]

View File

@ -98,7 +98,6 @@ consult the API documentation in the code.
""" """
import warnings import warnings
from typing import List, Optional, Tuple
import numpy as np import numpy as np
import torch import torch
@ -165,7 +164,7 @@ def load_audio(
time_exp_fact: float = 1, time_exp_fact: float = 1,
target_samp_rate: int = TARGET_SAMPLERATE_HZ, target_samp_rate: int = TARGET_SAMPLERATE_HZ,
scale: bool = False, scale: bool = False,
max_duration: Optional[float] = None, max_duration: float | None = None,
) -> np.ndarray: ) -> np.ndarray:
"""Load audio from file. """Load audio from file.
@ -203,7 +202,7 @@ def load_audio(
def generate_spectrogram( def generate_spectrogram(
audio: np.ndarray, audio: np.ndarray,
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
config: Optional[SpectrogramParameters] = None, config: SpectrogramParameters | None = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
) -> torch.Tensor: ) -> torch.Tensor:
"""Generate spectrogram from audio array. """Generate spectrogram from audio array.
@ -240,7 +239,7 @@ def generate_spectrogram(
def process_file( def process_file(
audio_file: str, audio_file: str,
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None, config: ProcessingConfiguration | None = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
) -> du.RunResults: ) -> du.RunResults:
"""Process audio file with model. """Process audio file with model.
@ -271,8 +270,8 @@ def process_spectrogram(
spec: torch.Tensor, spec: torch.Tensor,
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None, config: ProcessingConfiguration | None = None,
) -> Tuple[List[Annotation], np.ndarray]: ) -> tuple[list[Annotation], np.ndarray]:
"""Process spectrogram with model. """Process spectrogram with model.
Parameters Parameters
@ -312,9 +311,9 @@ def process_audio(
audio: np.ndarray, audio: np.ndarray,
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL, model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None, config: ProcessingConfiguration | None = None,
device: torch.device = DEVICE, device: torch.device = DEVICE,
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]: ) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
"""Process audio array with model. """Process audio array with model.
Parameters Parameters
@ -356,8 +355,8 @@ def process_audio(
def postprocess( def postprocess(
outputs: ModelOutput, outputs: ModelOutput,
samp_rate: int = TARGET_SAMPLERATE_HZ, samp_rate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ProcessingConfiguration] = None, config: ProcessingConfiguration | None = None,
) -> Tuple[List[Annotation], np.ndarray]: ) -> tuple[list[Annotation], np.ndarray]:
"""Postprocess model outputs. """Postprocess model outputs.
Convert model tensor outputs to predicted bounding boxes and Convert model tensor outputs to predicted bounding boxes and

View File

@ -1,59 +1,90 @@
import json
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple from typing import Literal, Sequence, cast
import numpy as np import numpy as np
import torch import torch
from soundevent import data from soundevent import data
from soundevent.audio.files import get_audio_files from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs from batdetect2.data import Dataset, load_dataset_from_config
from batdetect2.data import ( from batdetect2.evaluate import (
OutputFormatConfig, DEFAULT_EVAL_DIR,
build_output_formatter, EvaluationConfig,
get_output_formatter, EvaluatorProtocol,
load_dataset_from_config, build_evaluator,
run_evaluate,
save_evaluation_results,
) )
from batdetect2.data.datasets import Dataset from batdetect2.inference import (
from batdetect2.data.predictions.base import OutputFormatterProtocol InferenceConfig,
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate process_file_list,
from batdetect2.inference import process_file_list, run_batch_inference run_batch_inference,
from batdetect2.logging import DEFAULT_LOGS_DIR )
from batdetect2.models import Model, build_model from batdetect2.logging import (
from batdetect2.postprocess import build_postprocessor, to_raw_predictions DEFAULT_LOGS_DIR,
from batdetect2.preprocess import build_preprocessor AppLoggingConfig,
from batdetect2.targets import build_targets LoggerConfig,
)
from batdetect2.models import (
Model,
ModelConfig,
build_model,
build_model_with_new_targets,
)
from batdetect2.models.detectors import Detector
from batdetect2.outputs import (
OutputFormatConfig,
OutputFormatterProtocol,
OutputsConfig,
OutputTransformProtocol,
build_output_formatter,
build_output_transform,
get_output_formatter,
)
from batdetect2.postprocess import (
ClipDetections,
Detection,
PostprocessorProtocol,
build_postprocessor,
)
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
from batdetect2.train import ( from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR, DEFAULT_CHECKPOINT_DIR,
TrainingConfig,
load_model_from_checkpoint, load_model_from_checkpoint,
train, run_train,
)
from batdetect2.typing import (
AudioLoader,
BatDetect2Prediction,
EvaluatorProtocol,
PostprocessorProtocol,
PreprocessorProtocol,
RawPrediction,
TargetProtocol,
) )
class BatDetect2API: class BatDetect2API:
def __init__( def __init__(
self, self,
config: BatDetect2Config, model_config: ModelConfig,
audio_config: AudioConfig,
train_config: TrainingConfig,
evaluation_config: EvaluationConfig,
inference_config: InferenceConfig,
outputs_config: OutputsConfig,
logging_config: AppLoggingConfig,
targets: TargetProtocol, targets: TargetProtocol,
audio_loader: AudioLoader, audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol, postprocessor: PostprocessorProtocol,
evaluator: EvaluatorProtocol, evaluator: EvaluatorProtocol,
formatter: OutputFormatterProtocol, formatter: OutputFormatterProtocol,
output_transform: OutputTransformProtocol,
model: Model, model: Model,
): ):
self.config = config self.model_config = model_config
self.audio_config = audio_config
self.train_config = train_config
self.evaluation_config = evaluation_config
self.inference_config = inference_config
self.outputs_config = outputs_config
self.logging_config = logging_config
self.targets = targets self.targets = targets
self.audio_loader = audio_loader self.audio_loader = audio_loader
self.preprocessor = preprocessor self.preprocessor = preprocessor
@ -61,34 +92,40 @@ class BatDetect2API:
self.evaluator = evaluator self.evaluator = evaluator
self.model = model self.model = model
self.formatter = formatter self.formatter = formatter
self.output_transform = output_transform
self.model.eval() self.model.eval()
def load_annotations( def load_annotations(
self, self,
path: data.PathLike, path: data.PathLike,
base_dir: Optional[data.PathLike] = None, base_dir: data.PathLike | None = None,
) -> Dataset: ) -> Dataset:
return load_dataset_from_config(path, base_dir=base_dir) return load_dataset_from_config(path, base_dir=base_dir)
def train( def train(
self, self,
train_annotations: Sequence[data.ClipAnnotation], train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None, val_annotations: Sequence[data.ClipAnnotation] | None = None,
train_workers: Optional[int] = None, train_workers: int = 0,
val_workers: Optional[int] = None, val_workers: int = 0,
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR, checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
log_dir: Optional[Path] = DEFAULT_LOGS_DIR, log_dir: Path | None = DEFAULT_LOGS_DIR,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
num_epochs: Optional[int] = None, num_epochs: int | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
seed: Optional[int] = None, seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
): ):
train( run_train(
train_annotations=train_annotations, train_annotations=train_annotations,
val_annotations=val_annotations, val_annotations=val_annotations,
model=self.model,
targets=self.targets, targets=self.targets,
config=self.config, model_config=model_config or self.model_config,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
train_workers=train_workers, train_workers=train_workers,
@ -99,25 +136,81 @@ class BatDetect2API:
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
seed=seed, seed=seed,
train_config=train_config or self.train_config,
audio_config=audio_config or self.audio_config,
logger_config=logger_config or self.logging_config.train,
)
return self
def finetune(
self,
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Sequence[data.ClipAnnotation] | None = None,
trainable: Literal[
"all", "heads", "classifier_head", "bbox_head"
] = "heads",
train_workers: int = 0,
val_workers: int = 0,
checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
log_dir: Path | None = DEFAULT_LOGS_DIR,
experiment_name: str | None = None,
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
) -> "BatDetect2API":
"""Fine-tune the model with trainable-parameter selection."""
self._set_trainable_parameters(trainable)
run_train(
train_annotations=train_annotations,
val_annotations=val_annotations,
model=self.model,
targets=self.targets,
model_config=model_config or self.model_config,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
train_workers=train_workers,
val_workers=val_workers,
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
experiment_name=experiment_name,
num_epochs=num_epochs,
run_name=run_name,
seed=seed,
audio_config=audio_config or self.audio_config,
train_config=train_config or self.train_config,
logger_config=logger_config or self.logging_config.train,
) )
return self return self
def evaluate( def evaluate(
self, self,
test_annotations: Sequence[data.ClipAnnotation], test_annotations: Sequence[data.ClipAnnotation],
num_workers: Optional[int] = None, num_workers: int = 0,
output_dir: data.PathLike = DEFAULT_EVAL_DIR, output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
save_predictions: bool = True, save_predictions: bool = True,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]: audio_config: AudioConfig | None = None,
return evaluate( evaluation_config: EvaluationConfig | None = None,
outputs_config: OutputsConfig | None = None,
logger_config: LoggerConfig | None = None,
) -> tuple[dict[str, float], list[ClipDetections]]:
return run_evaluate(
self.model, self.model,
test_annotations, test_annotations,
targets=self.targets, targets=self.targets,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
config=self.config, audio_config=audio_config or self.audio_config,
evaluation_config=evaluation_config or self.evaluation_config,
output_config=outputs_config or self.outputs_config,
logger_config=logger_config or self.logging_config.evaluation,
num_workers=num_workers, num_workers=num_workers,
output_dir=output_dir, output_dir=output_dir,
experiment_name=experiment_name, experiment_name=experiment_name,
@ -128,8 +221,8 @@ class BatDetect2API:
def evaluate_predictions( def evaluate_predictions(
self, self,
annotations: Sequence[data.ClipAnnotation], annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
output_dir: Optional[data.PathLike] = None, output_dir: data.PathLike | None = None,
): ):
clip_evals = self.evaluator.evaluate( clip_evals = self.evaluator.evaluate(
annotations, annotations,
@ -139,30 +232,66 @@ class BatDetect2API:
metrics = self.evaluator.compute_metrics(clip_evals) metrics = self.evaluator.compute_metrics(clip_evals)
if output_dir is not None: if output_dir is not None:
output_dir = Path(output_dir) save_evaluation_results(
metrics=metrics,
if not output_dir.is_dir(): plots=self.evaluator.generate_plots(clip_evals),
output_dir.mkdir(parents=True) output_dir=output_dir,
)
metrics_path = output_dir / "metrics.json"
metrics_path.write_text(json.dumps(metrics))
for figure_name, fig in self.evaluator.generate_plots(clip_evals):
fig_path = output_dir / figure_name
if not fig_path.parent.is_dir():
fig_path.parent.mkdir(parents=True)
fig.savefig(fig_path)
return metrics return metrics
def load_audio(self, path: data.PathLike) -> np.ndarray: def load_audio(self, path: data.PathLike) -> np.ndarray:
return self.audio_loader.load_file(path) return self.audio_loader.load_file(path)
def load_recording(self, recording: data.Recording) -> np.ndarray:
return self.audio_loader.load_recording(recording)
def load_clip(self, clip: data.Clip) -> np.ndarray: def load_clip(self, clip: data.Clip) -> np.ndarray:
return self.audio_loader.load_clip(clip) return self.audio_loader.load_clip(clip)
def get_top_class_name(self, detection: Detection) -> str:
"""Get highest-confidence class name for one detection."""
top_index = int(np.argmax(detection.class_scores))
return self.targets.class_names[top_index]
def get_class_scores(
self,
detection: Detection,
*,
include_top_class: bool = True,
sort_descending: bool = True,
) -> list[tuple[str, float]]:
"""Get class score list as ``(class_name, score)`` pairs."""
scores = [
(class_name, float(score))
for class_name, score in zip(
self.targets.class_names,
detection.class_scores,
strict=True,
)
]
if sort_descending:
scores.sort(key=lambda item: item[1], reverse=True)
if include_top_class:
return scores
top_class_name = self.get_top_class_name(detection)
return [
(class_name, score)
for class_name, score in scores
if class_name != top_class_name
]
@staticmethod
def get_detection_features(detection: Detection) -> np.ndarray:
"""Get extracted feature vector for one detection."""
return detection.features
def generate_spectrogram( def generate_spectrogram(
self, self,
audio: np.ndarray, audio: np.ndarray,
@ -170,24 +299,41 @@ class BatDetect2API:
tensor = torch.tensor(audio).unsqueeze(0) tensor = torch.tensor(audio).unsqueeze(0)
return self.preprocessor(tensor) return self.preprocessor(tensor)
def process_file(self, audio_file: str) -> BatDetect2Prediction: def process_file(
self,
audio_file: data.PathLike,
batch_size: int | None = None,
) -> ClipDetections:
recording = data.Recording.from_file(audio_file, compute_hash=False) recording = data.Recording.from_file(audio_file, compute_hash=False)
wav = self.audio_loader.load_recording(recording)
detections = self.process_audio(wav) predictions = self.process_files(
return BatDetect2Prediction( [audio_file],
batch_size=(
batch_size
if batch_size is not None
else self.inference_config.loader.batch_size
),
)
detections = [
detection
for prediction in predictions
for detection in prediction.detections
]
return ClipDetections(
clip=data.Clip( clip=data.Clip(
uuid=recording.uuid, uuid=recording.uuid,
recording=recording, recording=recording,
start_time=0, start_time=0,
end_time=recording.duration, end_time=recording.duration,
), ),
predictions=detections, detections=detections,
) )
def process_audio( def process_audio(
self, self,
audio: np.ndarray, audio: np.ndarray,
) -> List[RawPrediction]: ) -> list[Detection]:
spec = self.generate_spectrogram(audio) spec = self.generate_spectrogram(audio)
return self.process_spectrogram(spec) return self.process_spectrogram(spec)
@ -195,7 +341,7 @@ class BatDetect2API:
self, self,
spec: torch.Tensor, spec: torch.Tensor,
start_time: float = 0, start_time: float = 0,
) -> List[RawPrediction]: ) -> list[Detection]:
if spec.ndim == 4 and spec.shape[0] > 1: if spec.ndim == 4 and spec.shape[0] > 1:
raise ValueError("Batched spectrograms not supported.") raise ValueError("Batched spectrograms not supported.")
@ -204,59 +350,74 @@ class BatDetect2API:
outputs = self.model.detector(spec) outputs = self.model.detector(spec)
detections = self.model.postprocessor( detections = self.postprocessor(
outputs, outputs,
start_times=[start_time],
)[0] )[0]
return self.output_transform.to_detections(
return to_raw_predictions(detections.numpy(), targets=self.targets) detections=detections,
start_time=start_time,
)
def process_directory( def process_directory(
self, self,
audio_dir: data.PathLike, audio_dir: data.PathLike,
) -> List[BatDetect2Prediction]: ) -> list[ClipDetections]:
files = list(get_audio_files(audio_dir)) files = list(get_audio_files(audio_dir))
return self.process_files(files) return self.process_files(files)
def process_files( def process_files(
self, self,
audio_files: Sequence[data.PathLike], audio_files: Sequence[data.PathLike],
num_workers: Optional[int] = None, batch_size: int | None = None,
) -> List[BatDetect2Prediction]: num_workers: int = 0,
audio_config: AudioConfig | None = None,
inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None,
) -> list[ClipDetections]:
return process_file_list( return process_file_list(
self.model, self.model,
audio_files, audio_files,
config=self.config,
targets=self.targets, targets=self.targets,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
output_transform=self.output_transform,
batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
audio_config=audio_config or self.audio_config,
inference_config=inference_config or self.inference_config,
output_config=output_config or self.outputs_config,
) )
def process_clips( def process_clips(
self, self,
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
batch_size: Optional[int] = None, batch_size: int | None = None,
num_workers: Optional[int] = None, num_workers: int = 0,
) -> List[BatDetect2Prediction]: audio_config: AudioConfig | None = None,
inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None,
) -> list[ClipDetections]:
return run_batch_inference( return run_batch_inference(
self.model, self.model,
clips, clips,
targets=self.targets, targets=self.targets,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
config=self.config, output_transform=self.output_transform,
batch_size=batch_size, batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
audio_config=audio_config or self.audio_config,
inference_config=inference_config or self.inference_config,
output_config=output_config or self.outputs_config,
) )
def save_predictions( def save_predictions(
self, self,
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
path: data.PathLike, path: data.PathLike,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
format: Optional[str] = None, format: str | None = None,
config: Optional[OutputFormatConfig] = None, config: OutputFormatConfig | None = None,
): ):
formatter = self.formatter formatter = self.formatter
@ -274,50 +435,78 @@ class BatDetect2API:
def load_predictions( def load_predictions(
self, self,
path: data.PathLike, path: data.PathLike,
) -> List[BatDetect2Prediction]: format: str | None = None,
return self.formatter.load(path) config: OutputFormatConfig | None = None,
) -> list[object]:
formatter = self.formatter
if format is not None or config is not None:
format = format or config.name # type: ignore
formatter = get_output_formatter(
name=format,
targets=self.targets,
config=config,
)
return formatter.load(path)
@classmethod @classmethod
def from_config( def from_config(
cls, cls,
config: BatDetect2Config, config: BatDetect2Config,
): ) -> "BatDetect2API":
targets = build_targets(config=config.targets) targets = build_targets(config=config.model.targets)
audio_loader = build_audio_loader(config=config.audio) audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor( preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate, input_samplerate=audio_loader.samplerate,
config=config.preprocess, config=config.model.preprocess,
) )
postprocessor = build_postprocessor( postprocessor = build_postprocessor(
preprocessor, preprocessor,
config=config.postprocess, config=config.model.postprocess,
) )
evaluator = build_evaluator(config=config.evaluation, targets=targets) formatter = build_output_formatter(
targets,
config=config.outputs.format,
)
output_transform = build_output_transform(
config=config.outputs.transform,
targets=targets,
)
# NOTE: Better to have a separate instance of evaluator = build_evaluator(
# preprocessor and postprocessor as these may be moved config=config.evaluation,
# to another device. targets=targets,
transform=output_transform,
)
# NOTE: Build separate instances of preprocessor and postprocessor
# to avoid device mismatch errors
model = build_model( model = build_model(
config=config.model, config=config.model,
targets=targets, targets=build_targets(config=config.model.targets),
preprocessor=build_preprocessor( preprocessor=build_preprocessor(
input_samplerate=audio_loader.samplerate, input_samplerate=audio_loader.samplerate,
config=config.preprocess, config=config.model.preprocess,
), ),
postprocessor=build_postprocessor( postprocessor=build_postprocessor(
preprocessor, preprocessor,
config=config.postprocess, config=config.model.postprocess,
), ),
) )
formatter = build_output_formatter(targets, config=config.output)
return cls( return cls(
config=config, model_config=config.model,
audio_config=config.audio,
train_config=config.train,
evaluation_config=config.evaluation,
inference_config=config.inference,
outputs_config=config.outputs,
logging_config=config.logging,
targets=targets, targets=targets,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,
@ -325,40 +514,83 @@ class BatDetect2API:
evaluator=evaluator, evaluator=evaluator,
model=model, model=model,
formatter=formatter, formatter=formatter,
output_transform=output_transform,
) )
@classmethod @classmethod
def from_checkpoint( def from_checkpoint(
cls, cls,
path: data.PathLike, path: data.PathLike,
config: Optional[BatDetect2Config] = None, targets_config: TargetConfig | None = None,
): audio_config: AudioConfig | None = None,
model, stored_config = load_model_from_checkpoint(path) train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
inference_config: InferenceConfig | None = None,
outputs_config: OutputsConfig | None = None,
logging_config: AppLoggingConfig | None = None,
) -> "BatDetect2API":
model, model_config = load_model_from_checkpoint(path)
config = ( audio_config = audio_config or AudioConfig(
merge_configs(stored_config, config) if config else stored_config samplerate=model_config.samplerate,
) )
train_config = train_config or TrainingConfig()
evaluation_config = evaluation_config or EvaluationConfig()
inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig()
logging_config = logging_config or AppLoggingConfig()
targets = build_targets(config=config.targets) if (
targets_config is not None
and targets_config != model_config.targets
):
targets = build_targets(config=targets_config)
model = build_model_with_new_targets(
model=model,
targets=targets,
)
model_config = model_config.model_copy(
update={"targets": targets_config}
)
audio_loader = build_audio_loader(config=config.audio) targets = build_targets(config=model_config.targets)
audio_loader = build_audio_loader(config=audio_config)
preprocessor = build_preprocessor( preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate, input_samplerate=audio_loader.samplerate,
config=config.preprocess, config=model_config.preprocess,
) )
postprocessor = build_postprocessor( postprocessor = build_postprocessor(
preprocessor, preprocessor,
config=config.postprocess, config=model_config.postprocess,
) )
evaluator = build_evaluator(config=config.evaluation, targets=targets) formatter = build_output_formatter(
targets,
config=outputs_config.format,
)
formatter = build_output_formatter(targets, config=config.output) output_transform = build_output_transform(
config=outputs_config.transform,
targets=targets,
)
evaluator = build_evaluator(
config=evaluation_config,
targets=targets,
transform=output_transform,
)
return cls( return cls(
config=config, model_config=model_config,
audio_config=audio_config,
train_config=train_config,
evaluation_config=evaluation_config,
inference_config=inference_config,
outputs_config=outputs_config,
logging_config=logging_config,
targets=targets, targets=targets,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,
@ -366,4 +598,27 @@ class BatDetect2API:
evaluator=evaluator, evaluator=evaluator,
model=model, model=model,
formatter=formatter, formatter=formatter,
output_transform=output_transform,
) )
def _set_trainable_parameters(
self,
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
) -> None:
detector = cast(Detector, self.model.detector)
for parameter in detector.parameters():
parameter.requires_grad = False
if trainable == "all":
for parameter in detector.parameters():
parameter.requires_grad = True
return
if trainable in {"heads", "classifier_head"}:
for parameter in detector.classifier_head.parameters():
parameter.requires_grad = True
if trainable in {"heads", "bbox_head"}:
for parameter in detector.bbox_head.parameters():
parameter.requires_grad = True

View File

@ -5,8 +5,11 @@ from batdetect2.audio.loader import (
SoundEventAudioLoader, SoundEventAudioLoader,
build_audio_loader, build_audio_loader,
) )
from batdetect2.audio.types import AudioLoader, ClipperProtocol
__all__ = [ __all__ = [
"AudioLoader",
"ClipperProtocol",
"TARGET_SAMPLERATE_HZ", "TARGET_SAMPLERATE_HZ",
"AudioConfig", "AudioConfig",
"SoundEventAudioLoader", "SoundEventAudioLoader",

View File

@ -1,4 +1,4 @@
from typing import Annotated, List, Literal, Optional, Union from typing import Annotated, List, Literal
import numpy as np import numpy as np
from loguru import logger from loguru import logger
@ -6,8 +6,13 @@ from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.geometry import compute_bounds, intervals_overlap from soundevent.geometry import compute_bounds, intervals_overlap
from batdetect2.core import BaseConfig, Registry from batdetect2.audio.types import ClipperProtocol
from batdetect2.typing import ClipperProtocol from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
DEFAULT_TRAIN_CLIP_DURATION = 0.256 DEFAULT_TRAIN_CLIP_DURATION = 0.256
DEFAULT_MAX_EMPTY_CLIP = 0.1 DEFAULT_MAX_EMPTY_CLIP = 0.1
@ -16,12 +21,24 @@ DEFAULT_MAX_EMPTY_CLIP = 0.1
__all__ = [ __all__ = [
"build_clipper", "build_clipper",
"ClipConfig", "ClipConfig",
"ClipperImportConfig",
] ]
clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper") clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")
@add_import_config(clipper_registry)
class ClipperImportConfig(ImportConfig):
"""Use any callable as a clipper.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class RandomClipConfig(BaseConfig): class RandomClipConfig(BaseConfig):
name: Literal["random_subclip"] = "random_subclip" name: Literal["random_subclip"] = "random_subclip"
duration: float = DEFAULT_TRAIN_CLIP_DURATION duration: float = DEFAULT_TRAIN_CLIP_DURATION
@ -245,16 +262,12 @@ class FixedDurationClip:
ClipConfig = Annotated[ ClipConfig = Annotated[
Union[ RandomClipConfig | PaddedClipConfig | FixedDurationClipConfig,
RandomClipConfig,
PaddedClipConfig,
FixedDurationClipConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol: def build_clipper(config: ClipConfig | None = None) -> ClipperProtocol:
config = config or RandomClipConfig() config = config or RandomClipConfig()
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(

View File

@ -1,5 +1,3 @@
from typing import Optional
import numpy as np import numpy as np
from numpy.typing import DTypeLike from numpy.typing import DTypeLike
from pydantic import Field from pydantic import Field
@ -7,8 +5,8 @@ from scipy.signal import resample, resample_poly
from soundevent import audio, data from soundevent import audio, data
from soundfile import LibsndfileError from soundfile import LibsndfileError
from batdetect2.audio.types import AudioLoader
from batdetect2.core import BaseConfig from batdetect2.core import BaseConfig
from batdetect2.typing import AudioLoader
__all__ = [ __all__ = [
"SoundEventAudioLoader", "SoundEventAudioLoader",
@ -28,15 +26,17 @@ class ResampleConfig(BaseConfig):
Attributes Attributes
---------- ----------
samplerate : int, default=256000 enabled : bool, default=True
The target sample rate in Hz to resample the audio to. Must be > 0. Whether to resample the audio to the target sample rate. If
``False``, the audio is returned at its original sample rate.
method : str, default="poly" method : str, default="poly"
The resampling algorithm to use. Options: The resampling algorithm to use. Options:
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
Generally fast. - ``"poly"``: Polyphase resampling via
- "fourier": Resampling via Fourier method using ``scipy.signal.resample_poly``. Generally fast and accurate.
`scipy.signal.resample`. May handle non-integer - ``"fourier"``: FFT-based resampling via
resampling factors differently. ``scipy.signal.resample``. May be preferred for non-integer
resampling ratios.
""" """
enabled: bool = True enabled: bool = True
@ -50,7 +50,7 @@ class AudioConfig(BaseConfig):
resample: ResampleConfig = Field(default_factory=ResampleConfig) resample: ResampleConfig = Field(default_factory=ResampleConfig)
def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader: def build_audio_loader(config: AudioConfig | None = None) -> AudioLoader:
"""Factory function to create an AudioLoader based on configuration.""" """Factory function to create an AudioLoader based on configuration."""
config = config or AudioConfig() config = config or AudioConfig()
return SoundEventAudioLoader( return SoundEventAudioLoader(
@ -65,7 +65,7 @@ class SoundEventAudioLoader(AudioLoader):
def __init__( def __init__(
self, self,
samplerate: int = TARGET_SAMPLERATE_HZ, samplerate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ResampleConfig] = None, config: ResampleConfig | None = None,
): ):
self.samplerate = samplerate self.samplerate = samplerate
self.config = config or ResampleConfig() self.config = config or ResampleConfig()
@ -73,7 +73,7 @@ class SoundEventAudioLoader(AudioLoader):
def load_file( def load_file(
self, self,
path: data.PathLike, path: data.PathLike,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess audio directly from a file path.""" """Load and preprocess audio directly from a file path."""
return load_file_audio( return load_file_audio(
@ -86,7 +86,7 @@ class SoundEventAudioLoader(AudioLoader):
def load_recording( def load_recording(
self, self,
recording: data.Recording, recording: data.Recording,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object.""" """Load and preprocess the entire audio for a Recording object."""
return load_recording_audio( return load_recording_audio(
@ -99,7 +99,7 @@ class SoundEventAudioLoader(AudioLoader):
def load_clip( def load_clip(
self, self,
clip: data.Clip, clip: data.Clip,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object.""" """Load and preprocess the audio segment defined by a Clip object."""
return load_clip_audio( return load_clip_audio(
@ -112,10 +112,10 @@ class SoundEventAudioLoader(AudioLoader):
def load_file_audio( def load_file_audio(
path: data.PathLike, path: data.PathLike,
samplerate: Optional[int] = None, samplerate: int | None = None,
config: Optional[ResampleConfig] = None, config: ResampleConfig | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore dtype: DTypeLike = np.float32,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess audio from a file path using specified config.""" """Load and preprocess audio from a file path using specified config."""
try: try:
@ -136,10 +136,10 @@ def load_file_audio(
def load_recording_audio( def load_recording_audio(
recording: data.Recording, recording: data.Recording,
samplerate: Optional[int] = None, samplerate: int | None = None,
config: Optional[ResampleConfig] = None, config: ResampleConfig | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore dtype: DTypeLike = np.float32,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess the entire audio content of a recording using config.""" """Load and preprocess the entire audio content of a recording using config."""
clip = data.Clip( clip = data.Clip(
@ -158,10 +158,10 @@ def load_recording_audio(
def load_clip_audio( def load_clip_audio(
clip: data.Clip, clip: data.Clip,
samplerate: Optional[int] = None, samplerate: int | None = None,
config: Optional[ResampleConfig] = None, config: ResampleConfig | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32, # type: ignore dtype: DTypeLike = np.float32,
) -> np.ndarray: ) -> np.ndarray:
"""Load and preprocess a specific audio clip segment based on config.""" """Load and preprocess a specific audio clip segment based on config."""
try: try:
@ -194,7 +194,31 @@ def resample_audio(
samplerate: int = TARGET_SAMPLERATE_HZ, samplerate: int = TARGET_SAMPLERATE_HZ,
method: str = "poly", method: str = "poly",
) -> np.ndarray: ) -> np.ndarray:
"""Resample an audio waveform DataArray to a target sample rate.""" """Resample an audio waveform to a target sample rate.
Parameters
----------
wav : np.ndarray
Input waveform array. The last axis is assumed to be time.
sr : int
Original sample rate of ``wav`` in Hz.
samplerate : int, default=256000
Target sample rate in Hz.
method : str, default="poly"
Resampling algorithm: ``"poly"`` (polyphase) or
``"fourier"`` (FFT-based).
Returns
-------
np.ndarray
Resampled waveform. If ``sr == samplerate`` the input array is
returned unchanged.
Raises
------
NotImplementedError
If ``method`` is not ``"poly"`` or ``"fourier"``.
"""
if sr == samplerate: if sr == samplerate:
return wav return wav
@ -264,7 +288,7 @@ def resample_audio_fourier(
sr_new: int, sr_new: int,
axis: int = -1, axis: int = -1,
) -> np.ndarray: ) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample`. """Resample a numpy array using ``scipy.signal.resample``.
This method uses FFTs to resample the signal. This method uses FFTs to resample the signal.
@ -272,23 +296,20 @@ def resample_audio_fourier(
---------- ----------
array : np.ndarray array : np.ndarray
The input array to resample. The input array to resample.
num : int sr_orig : int
The desired number of samples in the output array along `axis`. The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1 axis : int, default=-1
The axis of `array` along which to resample. The axis of ``array`` along which to resample.
Returns Returns
------- -------
np.ndarray np.ndarray
The array resampled to have `num` samples along `axis`. The array resampled to the target sample rate.
Raises
------
ValueError
If `num` is negative.
""" """
ratio = sr_new / sr_orig ratio = sr_new / sr_orig
return resample( # type: ignore return resample(
array, array,
int(array.shape[axis] * ratio), int(array.shape[axis] * ratio),
axis=axis, axis=axis,

View File

@ -0,0 +1,40 @@
from typing import Protocol
import numpy as np
from soundevent import data
__all__ = [
"AudioLoader",
"ClipperProtocol",
]
class AudioLoader(Protocol):
samplerate: int
def load_file(
self,
path: data.PathLike,
audio_dir: data.PathLike | None = None,
) -> np.ndarray: ...
def load_recording(
self,
recording: data.Recording,
audio_dir: data.PathLike | None = None,
) -> np.ndarray: ...
def load_clip(
self,
clip: data.Clip,
audio_dir: data.PathLike | None = None,
) -> np.ndarray: ...
class ClipperProtocol(Protocol):
def __call__(
self,
clip_annotation: data.ClipAnnotation,
) -> data.ClipAnnotation: ...
def get_subclip(self, clip: data.Clip) -> data.Clip: ...

View File

@ -2,6 +2,7 @@ from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect from batdetect2.cli.compat import detect
from batdetect2.cli.data import data from batdetect2.cli.data import data
from batdetect2.cli.evaluate import evaluate_command from batdetect2.cli.evaluate import evaluate_command
from batdetect2.cli.inference import predict
from batdetect2.cli.train import train_command from batdetect2.cli.train import train_command
__all__ = [ __all__ = [
@ -10,6 +11,7 @@ __all__ = [
"data", "data",
"train_command", "train_command",
"evaluate_command", "evaluate_command",
"predict",
] ]

View File

@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from typing import Optional
import click import click
@ -34,9 +33,9 @@ def data(): ...
) )
def summary( def summary(
dataset_config: Path, dataset_config: Path,
field: Optional[str] = None, field: str | None = None,
targets_path: Optional[Path] = None, targets_path: Path | None = None,
base_dir: Optional[Path] = None, base_dir: Path | None = None,
): ):
from batdetect2.data import compute_class_summary, load_dataset_from_config from batdetect2.data import compute_class_summary, load_dataset_from_config
from batdetect2.targets import load_targets from batdetect2.targets import load_targets
@ -83,9 +82,9 @@ def summary(
) )
def convert( def convert(
dataset_config: Path, dataset_config: Path,
field: Optional[str] = None, field: str | None = None,
output: Path = Path("annotations.json"), output: Path = Path("annotations.json"),
base_dir: Optional[Path] = None, base_dir: Path | None = None,
): ):
"""Convert a dataset config file to soundevent format.""" """Convert a dataset config file to soundevent format."""
from soundevent import data, io from soundevent import data, io

View File

@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from typing import Optional
import click import click
from loguru import logger from loguru import logger
@ -13,9 +12,14 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@cli.command(name="evaluate") @cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True)) @click.argument("model_path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True)) @click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--config", "config_path", type=click.Path()) @click.option("--targets", "targets_config", type=click.Path(exists=True))
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--evaluation-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--base-dir", type=click.Path(), default=Path.cwd()) @click.option("--base-dir", type=click.Path(), default=Path.cwd())
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR) @click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
@click.option("--experiment-name", type=str) @click.option("--experiment-name", type=str)
@ -25,15 +29,25 @@ def evaluate_command(
model_path: Path, model_path: Path,
test_dataset: Path, test_dataset: Path,
base_dir: Path, base_dir: Path,
config_path: Optional[Path], targets_config: Path | None,
audio_config: Path | None,
evaluation_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
output_dir: Path = DEFAULT_OUTPUT_DIR, output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: Optional[int] = None, num_workers: int = 0,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
): ):
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import load_full_config from batdetect2.audio import AudioConfig
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import TargetConfig
logger.info("Initiating evaluation process...") logger.info("Initiating evaluation process...")
@ -47,11 +61,44 @@ def evaluate_command(
num_annotations=len(test_annotations), num_annotations=len(test_annotations),
) )
config = None target_conf = (
if config_path is not None: TargetConfig.load(targets_config)
config = load_full_config(config_path) if targets_config is not None
else None
)
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
eval_conf = (
EvaluationConfig.load(evaluation_config)
if evaluation_config is not None
else None
)
inference_conf = (
InferenceConfig.load(inference_config)
if inference_config is not None
else None
)
outputs_conf = (
OutputsConfig.load(outputs_config)
if outputs_config is not None
else None
)
logging_conf = (
AppLoggingConfig.load(logging_config)
if logging_config is not None
else None
)
api = BatDetect2API.from_checkpoint(model_path, config=config) api = BatDetect2API.from_checkpoint(
model_path,
targets_config=target_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
logging_config=logging_conf,
)
api.evaluate( api.evaluate(
test_annotations, test_annotations,

View File

@ -0,0 +1,231 @@
from pathlib import Path
import click
from loguru import logger
from soundevent import io
from soundevent.audio.files import get_audio_files
from batdetect2.cli.base import cli
__all__ = ["predict"]
@cli.group(name="predict")
def predict() -> None:
"""Run prediction with BatDetect2 API v2."""
def _build_api(
model_path: Path,
audio_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.outputs import OutputsConfig
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
inference_conf = (
InferenceConfig.load(inference_config)
if inference_config is not None
else None
)
outputs_conf = (
OutputsConfig.load(outputs_config)
if outputs_config is not None
else None
)
logging_conf = (
AppLoggingConfig.load(logging_config)
if logging_config is not None
else None
)
api = BatDetect2API.from_checkpoint(
model_path,
audio_config=audio_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
logging_config=logging_conf,
)
return api, audio_conf, inference_conf, outputs_conf
def _run_prediction(
model_path: Path,
audio_files: list[Path],
output_path: Path,
audio_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
batch_size: int | None,
num_workers: int,
format_name: str | None,
) -> None:
logger.info("Initiating prediction process...")
api, audio_conf, inference_conf, outputs_conf = _build_api(
model_path,
audio_config,
inference_config,
outputs_config,
logging_config,
)
logger.info("Found {num_files} audio files", num_files=len(audio_files))
predictions = api.process_files(
audio_files,
batch_size=batch_size,
num_workers=num_workers,
audio_config=audio_conf,
inference_config=inference_conf,
output_config=outputs_conf,
)
common_path = audio_files[0].parent if audio_files else None
api.save_predictions(
predictions,
path=output_path,
audio_dir=common_path,
format=format_name,
)
logger.info(
"Prediction complete. Results saved to {path}", path=output_path
)
@predict.command(name="directory")
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("audio_dir", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--batch-size", type=int)
@click.option("--workers", "num_workers", type=int, default=0)
@click.option("--format", "format_name", type=str)
def predict_directory_command(
model_path: Path,
audio_dir: Path,
output_path: Path,
audio_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
batch_size: int | None,
num_workers: int,
format_name: str | None,
) -> None:
audio_files = list(get_audio_files(audio_dir))
_run_prediction(
model_path=model_path,
audio_files=audio_files,
output_path=output_path,
audio_config=audio_config,
inference_config=inference_config,
outputs_config=outputs_config,
logging_config=logging_config,
batch_size=batch_size,
num_workers=num_workers,
format_name=format_name,
)
@predict.command(name="file_list")
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("file_list", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--batch-size", type=int)
@click.option("--workers", "num_workers", type=int, default=0)
@click.option("--format", "format_name", type=str)
def predict_file_list_command(
model_path: Path,
file_list: Path,
output_path: Path,
audio_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
batch_size: int | None,
num_workers: int,
format_name: str | None,
) -> None:
file_list = Path(file_list)
audio_files = [
Path(line.strip())
for line in file_list.read_text().splitlines()
if line.strip()
]
_run_prediction(
model_path=model_path,
audio_files=audio_files,
output_path=output_path,
audio_config=audio_config,
inference_config=inference_config,
outputs_config=outputs_config,
logging_config=logging_config,
batch_size=batch_size,
num_workers=num_workers,
format_name=format_name,
)
@predict.command(name="dataset")
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("dataset_path", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--batch-size", type=int)
@click.option("--workers", "num_workers", type=int, default=0)
@click.option("--format", "format_name", type=str)
def predict_dataset_command(
model_path: Path,
dataset_path: Path,
output_path: Path,
audio_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
batch_size: int | None,
num_workers: int,
format_name: str | None,
) -> None:
dataset_path = Path(dataset_path)
dataset = io.load(dataset_path, type="annotation_set")
audio_files = sorted(
{
Path(clip_annotation.clip.recording.path)
for clip_annotation in dataset.clip_annotations
}
)
_run_prediction(
model_path=model_path,
audio_files=audio_files,
output_path=output_path,
audio_config=audio_config,
inference_config=inference_config,
outputs_config=outputs_config,
logging_config=logging_config,
batch_size=batch_size,
num_workers=num_workers,
format_name=format_name,
)

View File

@ -1,5 +1,4 @@
from pathlib import Path from pathlib import Path
from typing import Optional
import click import click
from loguru import logger from loguru import logger
@ -14,10 +13,15 @@ __all__ = ["train_command"]
@click.option("--val-dataset", type=click.Path(exists=True)) @click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model", "model_path", type=click.Path(exists=True)) @click.option("--model", "model_path", type=click.Path(exists=True))
@click.option("--targets", "targets_config", type=click.Path(exists=True)) @click.option("--targets", "targets_config", type=click.Path(exists=True))
@click.option("--model-config", type=click.Path(exists=True))
@click.option("--training-config", type=click.Path(exists=True))
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--evaluation-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True)) @click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True)) @click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int) @click.option("--train-workers", type=int)
@click.option("--val-workers", type=int) @click.option("--val-workers", type=int)
@click.option("--num-epochs", type=int) @click.option("--num-epochs", type=int)
@ -26,42 +30,82 @@ __all__ = ["train_command"]
@click.option("--seed", type=int) @click.option("--seed", type=int)
def train_command( def train_command(
train_dataset: Path, train_dataset: Path,
val_dataset: Optional[Path] = None, val_dataset: Path | None = None,
model_path: Optional[Path] = None, model_path: Path | None = None,
ckpt_dir: Optional[Path] = None, ckpt_dir: Path | None = None,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
config: Optional[Path] = None, targets_config: Path | None = None,
targets_config: Optional[Path] = None, model_config: Path | None = None,
config_field: Optional[str] = None, training_config: Path | None = None,
seed: Optional[int] = None, audio_config: Path | None = None,
num_epochs: Optional[int] = None, evaluation_config: Path | None = None,
inference_config: Path | None = None,
outputs_config: Path | None = None,
logging_config: Path | None = None,
seed: int | None = None,
num_epochs: int | None = None,
train_workers: int = 0, train_workers: int = 0,
val_workers: int = 0, val_workers: int = 0,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
): ):
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import ( from batdetect2.audio import AudioConfig
BatDetect2Config, from batdetect2.config import BatDetect2Config
load_full_config,
)
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config from batdetect2.evaluate import EvaluationConfig
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import TargetConfig
from batdetect2.train import TrainingConfig
logger.info("Initiating training process...") logger.info("Initiating training process...")
logger.info("Loading configuration...") logger.info("Loading configuration...")
conf = ( target_conf = (
load_full_config(config, field=config_field) TargetConfig.load(targets_config)
if config is not None if targets_config is not None
else BatDetect2Config() else None
)
model_conf = (
ModelConfig.load(model_config) if model_config is not None else None
)
train_conf = (
TrainingConfig.load(training_config)
if training_config is not None
else None
)
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
eval_conf = (
EvaluationConfig.load(evaluation_config)
if evaluation_config is not None
else None
)
inference_conf = (
InferenceConfig.load(inference_config)
if inference_config is not None
else None
)
outputs_conf = (
OutputsConfig.load(outputs_config)
if outputs_config is not None
else None
)
logging_conf = (
AppLoggingConfig.load(logging_config)
if logging_config is not None
else None
) )
if targets_config is not None: if target_conf is not None:
logger.info("Loading targets configuration...") logger.info("Loaded targets configuration.")
conf = conf.model_copy(
update=dict(targets=load_target_config(targets_config)) if model_conf is not None and target_conf is not None:
) model_conf = model_conf.model_copy(update={"targets": target_conf})
logger.info("Loading training dataset...") logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config(train_dataset) train_annotations = load_dataset_from_config(train_dataset)
@ -82,12 +126,43 @@ def train_command(
logger.info("Configuration and data loaded. Starting training...") logger.info("Configuration and data loaded. Starting training...")
if model_path is not None and model_conf is not None:
raise click.UsageError(
"--model-config cannot be used with --model. "
"Checkpoint model configuration is loaded from the checkpoint."
)
if model_path is None: if model_path is None:
conf = BatDetect2Config()
if model_conf is not None:
conf.model = model_conf
elif target_conf is not None:
conf.model = conf.model.model_copy(update={"targets": target_conf})
if train_conf is not None:
conf.train = train_conf
if audio_conf is not None:
conf.audio = audio_conf
if eval_conf is not None:
conf.evaluation = eval_conf
if inference_conf is not None:
conf.inference = inference_conf
if outputs_conf is not None:
conf.outputs = outputs_conf
if logging_conf is not None:
conf.logging = logging_conf
api = BatDetect2API.from_config(conf) api = BatDetect2API.from_config(conf)
else: else:
api = BatDetect2API.from_checkpoint( api = BatDetect2API.from_checkpoint(
model_path, model_path,
config=conf if config is not None else None, targets_config=target_conf,
train_config=train_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
logging_config=logging_conf,
) )
return api.train( return api.train(

View File

@ -4,7 +4,7 @@ import json
import os import os
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Union from typing import Callable, List
import numpy as np import numpy as np
from soundevent import data from soundevent import data
@ -17,7 +17,7 @@ from batdetect2.types import (
FileAnnotation, FileAnnotation,
) )
PathLike = Union[Path, str, os.PathLike] PathLike = Path | str | os.PathLike
__all__ = [ __all__ = [
"convert_to_annotation_group", "convert_to_annotation_group",
@ -33,7 +33,7 @@ UNKNOWN_CLASS = "__UNKNOWN__"
NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242") NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]] EventFn = Callable[[data.SoundEventAnnotation], str | None]
ClassFn = Callable[[data.Recording], int] ClassFn = Callable[[data.Recording], int]
@ -103,17 +103,17 @@ def convert_to_annotation_group(
y_inds.append(0) y_inds.append(0)
annotations.append( annotations.append(
{ Annotation(
"start_time": start_time, start_time=start_time,
"end_time": end_time, end_time=end_time,
"low_freq": low_freq, low_freq=low_freq,
"high_freq": high_freq, high_freq=high_freq,
"class_prob": 1.0, class_prob=1.0,
"det_prob": 1.0, det_prob=1.0,
"individual": "0", individual="0",
"event": event, event=event,
"class_id": class_id, # type: ignore class_id=class_id,
} )
) )
return { return {
@ -221,7 +221,7 @@ def annotation_to_sound_event_prediction(
def file_annotation_to_clip( def file_annotation_to_clip(
file_annotation: FileAnnotation, file_annotation: FileAnnotation,
audio_dir: Optional[PathLike] = None, audio_dir: PathLike | None = None,
label_key: str = "class", label_key: str = "class",
) -> data.Clip: ) -> data.Clip:
"""Convert file annotation to recording.""" """Convert file annotation to recording."""

View File

@ -1,28 +1,20 @@
from typing import Literal, Optional from typing import Literal
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike
from batdetect2.audio import AudioConfig from batdetect2.audio import AudioConfig
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig
from batdetect2.data.predictions import OutputFormatConfig
from batdetect2.data.predictions.raw import RawOutputConfig
from batdetect2.evaluate.config import ( from batdetect2.evaluate.config import (
EvaluationConfig, EvaluationConfig,
get_default_eval_config, get_default_eval_config,
) )
from batdetect2.inference.config import InferenceConfig from batdetect2.inference.config import InferenceConfig
from batdetect2.models.config import BackboneConfig from batdetect2.logging import AppLoggingConfig
from batdetect2.postprocess.config import PostprocessConfig from batdetect2.models import ModelConfig
from batdetect2.preprocess.config import PreprocessingConfig from batdetect2.outputs import OutputsConfig
from batdetect2.targets.config import TargetConfig
from batdetect2.train.config import TrainingConfig from batdetect2.train.config import TrainingConfig
__all__ = [ __all__ = ["BatDetect2Config"]
"BatDetect2Config",
"load_full_config",
"validate_config",
]
class BatDetect2Config(BaseConfig): class BatDetect2Config(BaseConfig):
@ -32,26 +24,8 @@ class BatDetect2Config(BaseConfig):
evaluation: EvaluationConfig = Field( evaluation: EvaluationConfig = Field(
default_factory=get_default_eval_config default_factory=get_default_eval_config
) )
model: BackboneConfig = Field(default_factory=BackboneConfig) model: ModelConfig = Field(default_factory=ModelConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
audio: AudioConfig = Field(default_factory=AudioConfig) audio: AudioConfig = Field(default_factory=AudioConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig) inference: InferenceConfig = Field(default_factory=InferenceConfig)
output: OutputFormatConfig = Field(default_factory=RawOutputConfig) outputs: OutputsConfig = Field(default_factory=OutputsConfig)
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
def validate_config(config: Optional[dict]) -> BatDetect2Config:
if config is None:
return BatDetect2Config()
return BatDetect2Config.model_validate(config)
def load_full_config(
path: PathLike,
field: Optional[str] = None,
) -> BatDetect2Config:
return load_config(path, schema=BatDetect2Config, field=field)

View File

@ -1,8 +1,14 @@
from batdetect2.core.configs import BaseConfig, load_config, merge_configs from batdetect2.core.configs import BaseConfig, load_config, merge_configs
from batdetect2.core.registries import Registry from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
__all__ = [ __all__ = [
"add_import_config",
"BaseConfig", "BaseConfig",
"ImportConfig",
"load_config", "load_config",
"Registry", "Registry",
"merge_configs", "merge_configs",

View File

@ -1,5 +1,3 @@
from typing import Optional
import numpy as np import numpy as np
import torch import torch
import xarray as xr import xarray as xr
@ -86,8 +84,8 @@ def adjust_width(
def slice_tensor( def slice_tensor(
tensor: torch.Tensor, tensor: torch.Tensor,
start: Optional[int] = None, start: int | None = None,
end: Optional[int] = None, end: int | None = None,
dim: int = -1, dim: int = -1,
) -> torch.Tensor: ) -> torch.Tensor:
slices = [slice(None)] * tensor.ndim slices = [slice(None)] * tensor.ndim

View File

@ -8,11 +8,11 @@ configuration data from files, with optional support for accessing nested
configuration sections. configuration sections.
""" """
from typing import Any, Optional, Type, TypeVar from typing import Any, Literal, Type, TypeVar, overload
import yaml import yaml
from deepmerge.merger import Merger from deepmerge.merger import Merger
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict, TypeAdapter
from soundevent.data import PathLike from soundevent.data import PathLike
__all__ = [ __all__ = [
@ -21,6 +21,8 @@ __all__ = [
"merge_configs", "merge_configs",
] ]
C = TypeVar("C", bound="BaseConfig")
class BaseConfig(BaseModel): class BaseConfig(BaseModel):
"""Base class for all configuration models in BatDetect2. """Base class for all configuration models in BatDetect2.
@ -62,8 +64,30 @@ class BaseConfig(BaseModel):
) )
) )
@classmethod
def from_yaml(cls, yaml_str: str):
return cls.model_validate(yaml.safe_load(yaml_str))
T = TypeVar("T", bound=BaseModel) @classmethod
def load(
cls: Type[C],
path: PathLike,
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
) -> C:
return load_config(
path,
schema=cls,
field=field,
extra=extra,
strict=strict,
)
T = TypeVar("T")
T_Model = TypeVar("T_Model", bound=BaseModel)
Schema = Type[T_Model] | TypeAdapter[T]
def get_object_field(obj: dict, current_key: str) -> Any: def get_object_field(obj: dict, current_key: str) -> Any:
@ -125,35 +149,69 @@ def get_object_field(obj: dict, current_key: str) -> Any:
return get_object_field(subobj, rest) return get_object_field(subobj, rest)
@overload
def load_config( def load_config(
path: PathLike, path: PathLike,
schema: Type[T], schema: Type[T_Model],
field: Optional[str] = None, field: str | None = None,
) -> T: extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
) -> T_Model: ...
@overload
def load_config(
path: PathLike,
schema: TypeAdapter[T],
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
) -> T: ...
def load_config(
path: PathLike,
schema: Type[T_Model] | TypeAdapter[T],
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
) -> T_Model | T:
"""Load and validate configuration data from a file against a schema. """Load and validate configuration data from a file against a schema.
Reads a YAML file, optionally extracts a specific section using dot Reads a YAML file, optionally extracts a specific section using dot
notation, and then validates the resulting data against the provided notation, and then validates the resulting data against the provided
Pydantic `schema`. Pydantic schema.
Parameters Parameters
---------- ----------
path : PathLike path : PathLike
The path to the configuration file (typically `.yaml`). The path to the configuration file (typically `.yaml`).
schema : Type[T] schema : Type[T_Model] | TypeAdapter[T]
The Pydantic `BaseModel` subclass that defines the expected structure Either a Pydantic `BaseModel` subclass or a `TypeAdapter` instance
and types for the configuration data. that defines the expected structure and types for the configuration
data.
field : str, optional field : str, optional
A dot-separated string indicating a nested section within the YAML A dot-separated string indicating a nested section within the YAML
file to extract before validation. If None (default), the entire file to extract before validation. If None (default), the entire
file content is validated against the schema. file content is validated against the schema.
Example: `"training.optimizer"` would extract the `optimizer` section Example: `"training.optimizer"` would extract the `optimizer` section
within the `training` section. within the `training` section.
extra : Literal["ignore", "allow", "forbid"], optional
How to handle extra keys in the configuration data. If None (default),
the default behaviour of the schema is used. If "ignore", extra keys
are ignored. If "allow", extra keys are allowed and will be accessible
as attributes on the resulting model instance. If "forbid", extra
keys are forbidden and an exception is raised. See pydantic
documentation for more details.
strict : bool, optional
Whether to enforce types strictly. If None (default), the default
behaviour of the schema is used. See pydantic documentation for more
details.
Returns Returns
------- -------
T T_Model | T
An instance of the provided `schema`, populated and validated with An instance of the schema type, populated and validated with
data from the configuration file. data from the configuration file.
Raises Raises
@ -179,7 +237,10 @@ def load_config(
if field: if field:
config = get_object_field(config, field) config = get_object_field(config, field)
return schema.model_validate(config or {}) if isinstance(schema, TypeAdapter):
return schema.validate_python(config or {}, extra=extra, strict=strict)
return schema.model_validate(config or {}, extra=extra, strict=strict)
default_merger = Merger( default_merger = Merger(
@ -189,7 +250,7 @@ default_merger = Merger(
) )
def merge_configs(config1: T, config2: T) -> T: def merge_configs(config1: T_Model, config2: T_Model) -> T_Model:
"""Merge two configuration objects.""" """Merge two configuration objects."""
model = type(config1) model = type(config1)
dict1 = config1.model_dump() dict1 = config1.model_dump()

View File

@ -1,23 +1,28 @@
import sys from typing import (
from typing import Callable, Dict, Generic, Tuple, Type, TypeVar Any,
Callable,
Concatenate,
Generic,
ParamSpec,
Sequence,
Type,
TypeVar,
)
from pydantic import BaseModel from hydra.utils import instantiate
from pydantic import BaseModel, Field
if sys.version_info >= (3, 10):
from typing import Concatenate, ParamSpec
else:
from typing_extensions import Concatenate, ParamSpec
__all__ = [ __all__ = [
"add_import_config",
"ImportConfig",
"Registry", "Registry",
"SimpleRegistry", "SimpleRegistry",
] ]
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True) T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
T_Type = TypeVar("T_Type", covariant=True) T_Type = TypeVar("T_Type", covariant=True)
P_Type = ParamSpec("P_Type") P_Type = ParamSpec("P_Type")
T = TypeVar("T") T = TypeVar("T")
@ -43,12 +48,13 @@ class SimpleRegistry(Generic[T]):
class Registry(Generic[T_Type, P_Type]): class Registry(Generic[T_Type, P_Type]):
"""A generic class to create and manage a registry of items.""" """A generic class to create and manage a registry of items."""
def __init__(self, name: str): def __init__(self, name: str, discriminator: str = "name"):
self._name = name self._name = name
self._registry: Dict[ self._registry: dict[
str, Callable[Concatenate[..., P_Type], T_Type] str, Callable[Concatenate[..., P_Type], T_Type]
] = {} ] = {}
self._config_types: Dict[str, Type[BaseModel]] = {} self._discriminator = discriminator
self._config_types: dict[str, Type[BaseModel]] = {}
def register( def register(
self, self,
@ -56,15 +62,20 @@ class Registry(Generic[T_Type, P_Type]):
): ):
fields = config_cls.model_fields fields = config_cls.model_fields
if "name" not in fields: if self._discriminator not in fields:
raise ValueError("Configuration object must have a 'name' field.") raise ValueError(
"Configuration object must have "
f"a '{self._discriminator}' field."
)
name = fields["name"].default name = fields[self._discriminator].default
self._config_types[name] = config_cls self._config_types[name] = config_cls
if not isinstance(name, str): if not isinstance(name, str):
raise ValueError("'name' field must be a string literal.") raise ValueError(
f"'{self._discriminator}' field must be a string literal."
)
def decorator( def decorator(
func: Callable[Concatenate[T_Config, P_Type], T_Type], func: Callable[Concatenate[T_Config, P_Type], T_Type],
@ -74,7 +85,7 @@ class Registry(Generic[T_Type, P_Type]):
return decorator return decorator
def get_config_types(self) -> Tuple[Type[BaseModel], ...]: def get_config_types(self) -> tuple[Type[BaseModel], ...]:
return tuple(self._config_types.values()) return tuple(self._config_types.values())
def get_config_type(self, name: str) -> Type[BaseModel]: def get_config_type(self, name: str) -> Type[BaseModel]:
@ -94,10 +105,12 @@ class Registry(Generic[T_Type, P_Type]):
) -> T_Type: ) -> T_Type:
"""Builds a logic instance from a config object.""" """Builds a logic instance from a config object."""
name = getattr(config, "name") # noqa: B009 name = getattr(config, self._discriminator) # noqa: B009
if name is None: if name is None:
raise ValueError("Config does not have a name field") raise ValueError(
f"Config does not have a '{self._discriminator}' field"
)
if name not in self._registry: if name not in self._registry:
raise NotImplementedError( raise NotImplementedError(
@ -105,3 +118,92 @@ class Registry(Generic[T_Type, P_Type]):
) )
return self._registry[name](config, *args, **kwargs) return self._registry[name](config, *args, **kwargs)
class ImportConfig(BaseModel):
"""Base config for dynamic instantiation via Hydra.
Subclass this to create a registry-specific import escape hatch.
The subclass must add a discriminator field whose name matches the
registry's own discriminator key, with its value fixed to
``Literal["import"]``.
Attributes
----------
target : str
Fully-qualified dotted path to the callable to instantiate,
e.g. ``"mypackage.module.MyClass"``.
arguments : dict[str, Any]
Base keyword arguments forwarded to the callable. When the
same key also appears in ``kwargs`` passed to ``build()``,
the ``kwargs`` value takes priority.
"""
target: str
arguments: dict[str, Any] = Field(default_factory=dict)
T_Import = TypeVar("T_Import", bound=ImportConfig)
def add_import_config(
registry: Registry[T_Type, P_Type],
arg_names: Sequence[str] | None = None,
) -> Callable[[Type[T_Import]], Type[T_Import]]:
"""Decorator that registers an ImportConfig subclass as an escape hatch.
Wraps the decorated class in a builder that calls
``hydra.utils.instantiate`` using ``config.target`` and
``config.arguments``. The builder is registered on *registry*
under the discriminator value ``"import"``.
Parameters
----------
registry : Registry
The registry instance on which the config should be registered.
Returns
-------
Callable[[type[ImportConfig]], type[ImportConfig]]
A class decorator that registers the class and returns it
unchanged.
Examples
--------
Define a per-registry import escape hatch::
@add_import_config(my_registry)
class MyRegistryImportConfig(ImportConfig):
name: Literal["import"] = "import"
"""
def decorator(config_cls: Type[T_Import]) -> Type[T_Import]:
def builder(
config: T_Import,
*args: P_Type.args,
**kwargs: P_Type.kwargs,
) -> T_Type:
_arg_names = arg_names or []
if len(args) != len(_arg_names):
raise ValueError(
"Positional arguments are not supported "
"for import escape hatch unless you specify "
"the argument names. Use `arg_names` to specify "
"the names of the positional arguments."
)
args_dict = {_arg_names[i]: args[i] for i in range(len(args))}
hydra_cfg = {
"_target_": config.target,
**config.arguments,
**args_dict,
**kwargs,
}
return instantiate(hydra_cfg)
registry.register(config_cls)(builder)
return config_cls
return decorator

View File

@ -7,20 +7,12 @@ from batdetect2.data.annotations import (
load_annotated_dataset, load_annotated_dataset,
) )
from batdetect2.data.datasets import ( from batdetect2.data.datasets import (
Dataset,
DatasetConfig, DatasetConfig,
load_dataset, load_dataset,
load_dataset_config, load_dataset_config,
load_dataset_from_config, load_dataset_from_config,
) )
from batdetect2.data.predictions import (
BatDetect2OutputConfig,
OutputFormatConfig,
RawOutputConfig,
SoundEventOutputConfig,
build_output_formatter,
get_output_formatter,
load_predictions,
)
from batdetect2.data.summary import ( from batdetect2.data.summary import (
compute_class_summary, compute_class_summary,
extract_recordings_df, extract_recordings_df,
@ -28,6 +20,7 @@ from batdetect2.data.summary import (
) )
__all__ = [ __all__ = [
"Dataset",
"AOEFAnnotations", "AOEFAnnotations",
"AnnotatedDataset", "AnnotatedDataset",
"AnnotationFormats", "AnnotationFormats",
@ -36,6 +29,7 @@ __all__ = [
"BatDetect2OutputConfig", "BatDetect2OutputConfig",
"DatasetConfig", "DatasetConfig",
"OutputFormatConfig", "OutputFormatConfig",
"ParquetOutputConfig",
"RawOutputConfig", "RawOutputConfig",
"SoundEventOutputConfig", "SoundEventOutputConfig",
"build_output_formatter", "build_output_formatter",

View File

@ -13,22 +13,18 @@ format-specific loading function to retrieve the annotations as a standard
`soundevent.data.AnnotationSet`. `soundevent.data.AnnotationSet`.
""" """
from typing import Annotated, Optional, Union from typing import Annotated
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.data.annotations.aoef import ( from batdetect2.data.annotations.aoef import AOEFAnnotations
AOEFAnnotations,
load_aoef_annotated_dataset,
)
from batdetect2.data.annotations.batdetect2 import ( from batdetect2.data.annotations.batdetect2 import (
AnnotationFilter, AnnotationFilter,
BatDetect2FilesAnnotations, BatDetect2FilesAnnotations,
BatDetect2MergedAnnotations, BatDetect2MergedAnnotations,
load_batdetect2_files_annotated_dataset,
load_batdetect2_merged_annotated_dataset,
) )
from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import AnnotatedDataset from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [ __all__ = [
@ -43,11 +39,7 @@ __all__ = [
AnnotationFormats = Annotated[ AnnotationFormats = Annotated[
Union[ BatDetect2MergedAnnotations | BatDetect2FilesAnnotations | AOEFAnnotations,
BatDetect2MergedAnnotations,
BatDetect2FilesAnnotations,
AOEFAnnotations,
],
Field(discriminator="format"), Field(discriminator="format"),
] ]
"""Type Alias representing all supported data source configurations. """Type Alias representing all supported data source configurations.
@ -63,24 +55,24 @@ source configuration represents.
def load_annotated_dataset( def load_annotated_dataset(
dataset: AnnotatedDataset, dataset: AnnotatedDataset,
base_dir: Optional[data.PathLike] = None, base_dir: data.PathLike | None = None,
) -> data.AnnotationSet: ) -> data.AnnotationSet:
"""Load annotations for a single data source based on its configuration. """Load annotations for a single data source based on its configuration.
This function acts as a dispatcher. It inspects the type of the input This function acts as a dispatcher. It inspects the format of the input
`source_config` object (which corresponds to a specific annotation format) `dataset` object and delegates to the appropriate format-specific loader
and calls the appropriate loading function (e.g., registered in the `annotation_format_registry` (e.g.,
`load_aoef_annotated_dataset` for `AOEFAnnotations`). `AOEFLoader` for `AOEFAnnotations`).
Parameters Parameters
---------- ----------
source_config : AnnotationFormats dataset : AnnotatedDataset
The configuration object for the data source, specifying its format The configuration object for the data source, specifying its format
and necessary details (like paths). Must be an instance of one of the and necessary details (like paths). Must be an instance of one of the
types included in the `AnnotationFormats` union. types included in the `AnnotationFormats` union.
base_dir : Path, optional base_dir : Path, optional
An optional base directory path. If provided, relative paths within An optional base directory path. If provided, relative paths within
the `source_config` might be resolved relative to this directory by the `dataset` will be resolved relative to this directory by
the underlying loading functions. Defaults to None. the underlying loading functions. Defaults to None.
Returns Returns
@ -92,23 +84,8 @@ def load_annotated_dataset(
Raises Raises
------ ------
NotImplementedError NotImplementedError
If the type of the `source_config` object does not match any of the If the `format` field of `dataset` does not match any registered
known format-specific loading functions implemented in the dispatch annotation format loader.
logic.
""" """
loader = annotation_format_registry.build(dataset)
if isinstance(dataset, AOEFAnnotations): return loader.load(base_dir=base_dir)
return load_aoef_annotated_dataset(dataset, base_dir=base_dir)
if isinstance(dataset, BatDetect2MergedAnnotations):
return load_batdetect2_merged_annotated_dataset(
dataset, base_dir=base_dir
)
if isinstance(dataset, BatDetect2FilesAnnotations):
return load_batdetect2_files_annotated_dataset(
dataset,
base_dir=base_dir,
)
raise NotImplementedError(f"Unknown annotation format: {dataset.name}")

View File

@ -12,17 +12,22 @@ that meet specific status criteria (e.g., completed, verified, without issues).
""" """
from pathlib import Path from pathlib import Path
from typing import Literal, Optional from typing import Literal
from uuid import uuid5 from uuid import uuid5
from pydantic import Field from pydantic import Field
from soundevent import data, io from soundevent import data, io
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.types import AnnotatedDataset from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import (
AnnotatedDataset,
AnnotationLoader,
)
__all__ = [ __all__ = [
"AOEFAnnotations", "AOEFAnnotations",
"AOEFLoader",
"load_aoef_annotated_dataset", "load_aoef_annotated_dataset",
"AnnotationTaskFilter", "AnnotationTaskFilter",
] ]
@ -77,14 +82,30 @@ class AOEFAnnotations(AnnotatedDataset):
annotations_path: Path annotations_path: Path
filter: Optional[AnnotationTaskFilter] = Field( filter: AnnotationTaskFilter | None = Field(
default_factory=AnnotationTaskFilter default_factory=AnnotationTaskFilter
) )
class AOEFLoader(AnnotationLoader):
def __init__(self, config: AOEFAnnotations):
self.config = config
def load(
self,
base_dir: data.PathLike | None = None,
) -> data.AnnotationSet:
return load_aoef_annotated_dataset(self.config, base_dir=base_dir)
@annotation_format_registry.register(AOEFAnnotations)
@staticmethod
def from_config(config: AOEFAnnotations):
return AOEFLoader(config)
def load_aoef_annotated_dataset( def load_aoef_annotated_dataset(
dataset: AOEFAnnotations, dataset: AOEFAnnotations,
base_dir: Optional[data.PathLike] = None, base_dir: data.PathLike | None = None,
) -> data.AnnotationSet: ) -> data.AnnotationSet:
"""Load annotations from an AnnotationSet or AnnotationProject file. """Load annotations from an AnnotationSet or AnnotationProject file.

View File

@ -27,7 +27,7 @@ aggregated into a `soundevent.data.AnnotationSet`.
import json import json
import os import os
from pathlib import Path from pathlib import Path
from typing import Literal, Optional, Union from typing import Literal
from loguru import logger from loguru import logger
from pydantic import Field, ValidationError from pydantic import Field, ValidationError
@ -41,9 +41,13 @@ from batdetect2.data.annotations.legacy import (
list_file_annotations, list_file_annotations,
load_file_annotation, load_file_annotation,
) )
from batdetect2.data.annotations.types import AnnotatedDataset from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import (
AnnotatedDataset,
AnnotationLoader,
)
PathLike = Union[Path, str, os.PathLike] PathLike = Path | str | os.PathLike
__all__ = [ __all__ = [
@ -102,7 +106,7 @@ class BatDetect2FilesAnnotations(AnnotatedDataset):
format: Literal["batdetect2"] = "batdetect2" format: Literal["batdetect2"] = "batdetect2"
annotations_dir: Path annotations_dir: Path
filter: Optional[AnnotationFilter] = Field( filter: AnnotationFilter | None = Field(
default_factory=AnnotationFilter, default_factory=AnnotationFilter,
) )
@ -133,14 +137,14 @@ class BatDetect2MergedAnnotations(AnnotatedDataset):
format: Literal["batdetect2_file"] = "batdetect2_file" format: Literal["batdetect2_file"] = "batdetect2_file"
annotations_path: Path annotations_path: Path
filter: Optional[AnnotationFilter] = Field( filter: AnnotationFilter | None = Field(
default_factory=AnnotationFilter, default_factory=AnnotationFilter,
) )
def load_batdetect2_files_annotated_dataset( def load_batdetect2_files_annotated_dataset(
dataset: BatDetect2FilesAnnotations, dataset: BatDetect2FilesAnnotations,
base_dir: Optional[PathLike] = None, base_dir: PathLike | None = None,
) -> data.AnnotationSet: ) -> data.AnnotationSet:
"""Load and convert 'batdetect2_file' annotations into an AnnotationSet. """Load and convert 'batdetect2_file' annotations into an AnnotationSet.
@ -244,7 +248,7 @@ def load_batdetect2_files_annotated_dataset(
def load_batdetect2_merged_annotated_dataset( def load_batdetect2_merged_annotated_dataset(
dataset: BatDetect2MergedAnnotations, dataset: BatDetect2MergedAnnotations,
base_dir: Optional[PathLike] = None, base_dir: PathLike | None = None,
) -> data.AnnotationSet: ) -> data.AnnotationSet:
"""Load and convert 'batdetect2_merged' annotations into an AnnotationSet. """Load and convert 'batdetect2_merged' annotations into an AnnotationSet.
@ -302,7 +306,7 @@ def load_batdetect2_merged_annotated_dataset(
try: try:
ann = FileAnnotation.model_validate(ann) ann = FileAnnotation.model_validate(ann)
except ValueError as err: except ValueError as err:
logger.warning(f"Invalid annotation file: {err}") logger.warning("Invalid annotation file: {err}", err=err)
continue continue
if ( if (
@ -310,17 +314,23 @@ def load_batdetect2_merged_annotated_dataset(
and dataset.filter.only_annotated and dataset.filter.only_annotated
and not ann.annotated and not ann.annotated
): ):
logger.debug(f"Skipping incomplete annotation {ann.id}") logger.debug(
"Skipping incomplete annotation {ann_id}",
ann_id=ann.id,
)
continue continue
if dataset.filter and dataset.filter.exclude_issues and ann.issues: if dataset.filter and dataset.filter.exclude_issues and ann.issues:
logger.debug(f"Skipping annotation with issues {ann.id}") logger.debug(
"Skipping annotation with issues {ann_id}",
ann_id=ann.id,
)
continue continue
try: try:
clip = file_annotation_to_clip(ann, audio_dir=audio_dir) clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
except FileNotFoundError as err: except FileNotFoundError as err:
logger.warning(f"Error loading annotations: {err}") logger.warning("Error loading annotations: {err}", err=err)
continue continue
annotations.append(file_annotation_to_clip_annotation(ann, clip)) annotations.append(file_annotation_to_clip_annotation(ann, clip))
@ -330,3 +340,41 @@ def load_batdetect2_merged_annotated_dataset(
description=dataset.description, description=dataset.description,
clip_annotations=annotations, clip_annotations=annotations,
) )
class BatDetect2MergedLoader(AnnotationLoader):
def __init__(self, config: BatDetect2MergedAnnotations):
self.config = config
def load(
self,
base_dir: PathLike | None = None,
) -> data.AnnotationSet:
return load_batdetect2_merged_annotated_dataset(
self.config,
base_dir=base_dir,
)
@annotation_format_registry.register(BatDetect2MergedAnnotations)
@staticmethod
def from_config(config: BatDetect2MergedAnnotations):
return BatDetect2MergedLoader(config)
class BatDetect2FilesLoader(AnnotationLoader):
def __init__(self, config: BatDetect2FilesAnnotations):
self.config = config
def load(
self,
base_dir: PathLike | None = None,
) -> data.AnnotationSet:
return load_batdetect2_files_annotated_dataset(
self.config,
base_dir=base_dir,
)
@annotation_format_registry.register(BatDetect2FilesAnnotations)
@staticmethod
def from_config(config: BatDetect2FilesAnnotations):
return BatDetect2FilesLoader(config)

View File

@ -3,12 +3,12 @@
import os import os
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional, Union from typing import Callable, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from soundevent import data from soundevent import data
PathLike = Union[Path, str, os.PathLike] PathLike = Path | str | os.PathLike
__all__ = [] __all__ = []
@ -27,7 +27,7 @@ SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5(
) )
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]] EventFn = Callable[[data.SoundEventAnnotation], str | None]
ClassFn = Callable[[data.Recording], int] ClassFn = Callable[[data.Recording], int]
@ -130,7 +130,7 @@ def get_sound_event_tags(
def file_annotation_to_clip( def file_annotation_to_clip(
file_annotation: FileAnnotation, file_annotation: FileAnnotation,
audio_dir: Optional[PathLike] = None, audio_dir: PathLike | None = None,
label_key: str = "class", label_key: str = "class",
) -> data.Clip: ) -> data.Clip:
"""Convert file annotation to recording.""" """Convert file annotation to recording."""

View File

@ -0,0 +1,35 @@
from typing import Literal
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.data.annotations.types import AnnotationLoader
__all__ = [
"AnnotationFormatImportConfig",
"annotation_format_registry",
]
annotation_format_registry: Registry[AnnotationLoader, []] = Registry(
"annotation_format",
discriminator="format",
)
@add_import_config(annotation_format_registry)
class AnnotationFormatImportConfig(ImportConfig):
"""Import escape hatch for the annotation format registry.
Use this config to dynamically instantiate any callable as an
annotation loader without registering it in
``annotation_format_registry`` ahead of time.
Parameters
----------
format : Literal["import"]
Discriminator value; must always be ``"import"``.
target : str
Fully-qualified dotted path to the callable to instantiate.
arguments : dict[str, Any]
Keyword arguments forwarded to the callable.
"""
format: Literal["import"] = "import"

View File

@ -1,9 +1,13 @@
from pathlib import Path from pathlib import Path
from typing import Protocol
from soundevent import data
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
__all__ = [ __all__ = [
"AnnotatedDataset", "AnnotatedDataset",
"AnnotationLoader",
] ]
@ -34,3 +38,10 @@ class AnnotatedDataset(BaseConfig):
name: str name: str
audio_dir: Path audio_dir: Path
description: str = "" description: str = ""
class AnnotationLoader(Protocol):
def load(
self,
base_dir: data.PathLike | None = None,
) -> data.AnnotationSet: ...

View File

@ -1,18 +1,33 @@
from collections.abc import Callable from collections.abc import Callable
from typing import Annotated, List, Literal, Sequence, Union from typing import Annotated, List, Literal, Sequence
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool] SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
conditions: Registry[SoundEventCondition, []] = Registry("condition") conditions: Registry[SoundEventCondition, []] = Registry("condition")
@add_import_config(conditions)
class SoundEventConditionImportConfig(ImportConfig):
"""Use any callable as a sound event condition.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class HasTagConfig(BaseConfig): class HasTagConfig(BaseConfig):
name: Literal["has_tag"] = "has_tag" name: Literal["has_tag"] = "has_tag"
tag: data.Tag tag: data.Tag
@ -264,16 +279,14 @@ class Not:
SoundEventConditionConfig = Annotated[ SoundEventConditionConfig = Annotated[
Union[ HasTagConfig
HasTagConfig, | HasAllTagsConfig
HasAllTagsConfig, | HasAnyTagConfig
HasAnyTagConfig, | DurationConfig
DurationConfig, | FrequencyConfig
FrequencyConfig, | AllOfConfig
AllOfConfig, | AnyOfConfig
AnyOfConfig, | NotConfig,
NotConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -19,7 +19,7 @@ The core components are:
""" """
from pathlib import Path from pathlib import Path
from typing import List, Optional, Sequence from typing import List, Sequence
from loguru import logger from loguru import logger
from pydantic import Field from pydantic import Field
@ -69,7 +69,7 @@ class DatasetConfig(BaseConfig):
description: str description: str
sources: List[AnnotationFormats] sources: List[AnnotationFormats]
sound_event_filter: Optional[SoundEventConditionConfig] = None sound_event_filter: SoundEventConditionConfig | None = None
sound_event_transforms: List[SoundEventTransformConfig] = Field( sound_event_transforms: List[SoundEventTransformConfig] = Field(
default_factory=list default_factory=list
) )
@ -77,7 +77,7 @@ class DatasetConfig(BaseConfig):
def load_dataset( def load_dataset(
config: DatasetConfig, config: DatasetConfig,
base_dir: Optional[data.PathLike] = None, base_dir: data.PathLike | None = None,
) -> Dataset: ) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig.""" """Load all clip annotations from the sources defined in a DatasetConfig."""
clip_annotations = [] clip_annotations = []
@ -161,14 +161,14 @@ def insert_source_tag(
) )
def load_dataset_config(path: data.PathLike, field: Optional[str] = None): def load_dataset_config(path: data.PathLike, field: str | None = None):
return load_config(path=path, schema=DatasetConfig, field=field) return load_config(path=path, schema=DatasetConfig, field=field)
def load_dataset_from_config( def load_dataset_from_config(
path: data.PathLike, path: data.PathLike,
field: Optional[str] = None, field: str | None = None,
base_dir: Optional[data.PathLike] = None, base_dir: data.PathLike | None = None,
) -> Dataset: ) -> Dataset:
"""Load dataset annotation metadata from a configuration file. """Load dataset annotation metadata from a configuration file.
@ -215,9 +215,9 @@ def load_dataset_from_config(
def save_dataset( def save_dataset(
dataset: Dataset, dataset: Dataset,
path: data.PathLike, path: data.PathLike,
name: Optional[str] = None, name: str | None = None,
description: Optional[str] = None, description: str | None = None,
audio_dir: Optional[Path] = None, audio_dir: Path | None = None,
) -> None: ) -> None:
"""Save a loaded dataset (list of ClipAnnotations) to a file. """Save a loaded dataset (list of ClipAnnotations) to a file.

View File

@ -1,16 +1,15 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Optional, Tuple
from soundevent import data from soundevent import data
from batdetect2.data.datasets import Dataset from batdetect2.data.datasets import Dataset
from batdetect2.typing.targets import TargetProtocol from batdetect2.targets.types import TargetProtocol
def iterate_over_sound_events( def iterate_over_sound_events(
dataset: Dataset, dataset: Dataset,
targets: TargetProtocol, targets: TargetProtocol,
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]: ) -> Generator[tuple[str | None, data.SoundEventAnnotation], None, None]:
"""Iterate over sound events in a dataset. """Iterate over sound events in a dataset.
Parameters Parameters
@ -24,7 +23,7 @@ def iterate_over_sound_events(
Yields Yields
------ ------
Tuple[Optional[str], data.SoundEventAnnotation] tuple[Optional[str], data.SoundEventAnnotation]
A tuple containing: A tuple containing:
- The encoded class name (str) for the sound event, or None if it - The encoded class name (str) for the sound event, or None if it
cannot be encoded to a specific class. cannot be encoded to a specific class.

View File

@ -1,29 +0,0 @@
from pathlib import Path
from soundevent.data import PathLike
from batdetect2.core import Registry
from batdetect2.typing import (
OutputFormatterProtocol,
TargetProtocol,
)
def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
path = Path(path)
audio_dir = Path(audio_dir)
if path.is_absolute():
if not path.is_relative_to(audio_dir):
raise ValueError(
f"Audio file {path} is not in audio_dir {audio_dir}"
)
return path.relative_to(audio_dir)
return path
prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
Registry(name="output_formatter")
)

View File

@ -1,5 +1,3 @@
from typing import Optional, Tuple
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from batdetect2.data.datasets import Dataset from batdetect2.data.datasets import Dataset
@ -7,15 +5,15 @@ from batdetect2.data.summary import (
extract_recordings_df, extract_recordings_df,
extract_sound_events_df, extract_sound_events_df,
) )
from batdetect2.typing.targets import TargetProtocol from batdetect2.targets.types import TargetProtocol
def split_dataset_by_recordings( def split_dataset_by_recordings(
dataset: Dataset, dataset: Dataset,
targets: TargetProtocol, targets: TargetProtocol,
train_size: float = 0.75, train_size: float = 0.75,
random_state: Optional[int] = None, random_state: int | None = None,
) -> Tuple[Dataset, Dataset]: ) -> tuple[Dataset, Dataset]:
recordings = extract_recordings_df(dataset) recordings = extract_recordings_df(dataset)
sound_events = extract_sound_events_df( sound_events = extract_sound_events_df(
@ -26,13 +24,15 @@ def split_dataset_by_recordings(
) )
majority_class = ( majority_class = (
sound_events.groupby("recording_id") sound_events.groupby("recording_id") # type: ignore
.apply( .apply(
lambda group: group["class_name"] # type: ignore lambda group: (
.value_counts() group["class_name"]
.sort_values(ascending=False) .value_counts()
.index[0], .sort_values(ascending=False)
include_groups=False, # type: ignore .index[0]
),
include_groups=False,
) )
.rename("class_name") .rename("class_name")
.to_frame() .to_frame()
@ -46,8 +46,8 @@ def split_dataset_by_recordings(
random_state=random_state, random_state=random_state,
) )
train_ids_set = set(train.values) # type: ignore train_ids_set = set(train.values)
test_ids_set = set(test.values) # type: ignore test_ids_set = set(test.values)
extra = set(recordings["recording_id"]) - train_ids_set - test_ids_set extra = set(recordings["recording_id"]) - train_ids_set - test_ids_set

View File

@ -2,7 +2,7 @@ import pandas as pd
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.data.datasets import Dataset from batdetect2.data.datasets import Dataset
from batdetect2.typing.targets import TargetProtocol from batdetect2.targets.types import TargetProtocol
__all__ = [ __all__ = [
"extract_recordings_df", "extract_recordings_df",
@ -175,14 +175,14 @@ def compute_class_summary(
.rename("num recordings") .rename("num recordings")
) )
durations = ( durations = (
sound_events.groupby("class_name") sound_events.groupby("class_name") # ty: ignore[no-matching-overload]
.apply( .apply(
lambda group: recordings[ lambda group: recordings[
recordings["clip_annotation_id"].isin( recordings["clip_annotation_id"].isin(
group["clip_annotation_id"] # type: ignore group["clip_annotation_id"]
) )
]["duration"].sum(), ]["duration"].sum(),
include_groups=False, # type: ignore include_groups=False,
) )
.sort_values(ascending=False) .sort_values(ascending=False)
.rename("duration") .rename("duration")

View File

@ -1,11 +1,15 @@
from collections.abc import Callable from collections.abc import Callable
from typing import Annotated, Dict, List, Literal, Optional, Union from typing import Annotated, Dict, List, Literal
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.data.conditions import ( from batdetect2.data.conditions import (
SoundEventCondition, SoundEventCondition,
SoundEventConditionConfig, SoundEventConditionConfig,
@ -20,6 +24,17 @@ SoundEventTransform = Callable[
transforms: Registry[SoundEventTransform, []] = Registry("transform") transforms: Registry[SoundEventTransform, []] = Registry("transform")
@add_import_config(transforms)
class SoundEventTransformImportConfig(ImportConfig):
"""Use any callable as a sound event transform.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class SetFrequencyBoundConfig(BaseConfig): class SetFrequencyBoundConfig(BaseConfig):
name: Literal["set_frequency"] = "set_frequency" name: Literal["set_frequency"] = "set_frequency"
boundary: Literal["low", "high"] = "low" boundary: Literal["low", "high"] = "low"
@ -142,7 +157,7 @@ class MapTagValueConfig(BaseConfig):
name: Literal["map_tag_value"] = "map_tag_value" name: Literal["map_tag_value"] = "map_tag_value"
tag_key: str tag_key: str
value_mapping: Dict[str, str] value_mapping: Dict[str, str]
target_key: Optional[str] = None target_key: str | None = None
class MapTagValue: class MapTagValue:
@ -150,7 +165,7 @@ class MapTagValue:
self, self,
tag_key: str, tag_key: str,
value_mapping: Dict[str, str], value_mapping: Dict[str, str],
target_key: Optional[str] = None, target_key: str | None = None,
): ):
self.tag_key = tag_key self.tag_key = tag_key
self.value_mapping = value_mapping self.value_mapping = value_mapping
@ -176,12 +191,7 @@ class MapTagValue:
if self.target_key is None: if self.target_key is None:
tags.append(tag.model_copy(update=dict(value=value))) tags.append(tag.model_copy(update=dict(value=value)))
else: else:
tags.append( tags.append(data.Tag(key=self.target_key, value=value))
data.Tag(
key=self.target_key, # type: ignore
value=value,
)
)
return sound_event_annotation.model_copy(update=dict(tags=tags)) return sound_event_annotation.model_copy(update=dict(tags=tags))
@ -221,13 +231,11 @@ class ApplyAll:
SoundEventTransformConfig = Annotated[ SoundEventTransformConfig = Annotated[
Union[ SetFrequencyBoundConfig
SetFrequencyBoundConfig, | ReplaceTagConfig
ReplaceTagConfig, | MapTagValueConfig
MapTagValueConfig, | ApplyIfConfig
ApplyIfConfig, | ApplyAllConfig,
ApplyAllConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -1,6 +1,6 @@
"""Functions to compute features from predictions.""" """Functions to compute features from predictions."""
from typing import Dict, List, Optional from typing import Dict, List
import numpy as np import numpy as np
@ -86,7 +86,7 @@ def compute_bandwidth(
def compute_max_power_bb( def compute_max_power_bb(
prediction: types.Prediction, prediction: types.Prediction,
spec: Optional[np.ndarray] = None, spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ, min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ, max_freq: int = MAX_FREQ_HZ,
**_, **_,
@ -131,7 +131,7 @@ def compute_max_power_bb(
def compute_max_power( def compute_max_power(
prediction: types.Prediction, prediction: types.Prediction,
spec: Optional[np.ndarray] = None, spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ, min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ, max_freq: int = MAX_FREQ_HZ,
**_, **_,
@ -157,7 +157,7 @@ def compute_max_power(
def compute_max_power_first( def compute_max_power_first(
prediction: types.Prediction, prediction: types.Prediction,
spec: Optional[np.ndarray] = None, spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ, min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ, max_freq: int = MAX_FREQ_HZ,
**_, **_,
@ -184,7 +184,7 @@ def compute_max_power_first(
def compute_max_power_second( def compute_max_power_second(
prediction: types.Prediction, prediction: types.Prediction,
spec: Optional[np.ndarray] = None, spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ, min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ, max_freq: int = MAX_FREQ_HZ,
**_, **_,
@ -211,7 +211,7 @@ def compute_max_power_second(
def compute_call_interval( def compute_call_interval(
prediction: types.Prediction, prediction: types.Prediction,
previous: Optional[types.Prediction] = None, previous: types.Prediction | None = None,
**_, **_,
) -> float: ) -> float:
"""Compute time between this call and the previous call in seconds.""" """Compute time between this call and the previous call in seconds."""

View File

@ -1,7 +1,7 @@
import datetime import datetime
import os import os
from pathlib import Path from pathlib import Path
from typing import List, Optional, Union from typing import List
from pydantic import BaseModel, Field, computed_field from pydantic import BaseModel, Field, computed_field
@ -198,8 +198,8 @@ class TrainingParameters(BaseModel):
def get_params( def get_params(
make_dirs: bool = False, make_dirs: bool = False,
exps_dir: str = "../../experiments/", exps_dir: str = "../../experiments/",
model_name: Optional[str] = None, model_name: str | None = None,
experiment: Union[Path, str, None] = None, experiment: Path | str | None = None,
**kwargs, **kwargs,
) -> TrainingParameters: ) -> TrainingParameters:
experiments_dir = Path(exps_dir) experiments_dir = Path(exps_dir)

View File

@ -1,7 +1,5 @@
"""Post-processing of the output of the model.""" """Post-processing of the output of the model."""
from typing import List, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
@ -45,7 +43,7 @@ def run_nms(
outputs: ModelOutput, outputs: ModelOutput,
params: NonMaximumSuppressionConfig, params: NonMaximumSuppressionConfig,
sampling_rate: np.ndarray, sampling_rate: np.ndarray,
) -> Tuple[List[PredictionResults], List[np.ndarray]]: ) -> tuple[list[PredictionResults], list[np.ndarray]]:
"""Run non-maximum suppression on the output of the model. """Run non-maximum suppression on the output of the model.
Model outputs processed are expected to have a batch dimension. Model outputs processed are expected to have a batch dimension.
@ -73,8 +71,8 @@ def run_nms(
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k) scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
# loop over batch to save outputs # loop over batch to save outputs
preds: List[PredictionResults] = [] preds: list[PredictionResults] = []
feats: List[np.ndarray] = [] feats: list[np.ndarray] = []
for num_detection in range(pred_det_nms.shape[0]): for num_detection in range(pred_det_nms.shape[0]):
# get valid indices # get valid indices
inds_ord = torch.argsort(x_pos[num_detection, :]) inds_ord = torch.argsort(x_pos[num_detection, :])
@ -151,7 +149,7 @@ def run_nms(
def non_max_suppression( def non_max_suppression(
heat: torch.Tensor, heat: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]], kernel_size: int | tuple[int, int],
): ):
# kernel can be an int or list/tuple # kernel can be an int or list/tuple
if isinstance(kernel_size, int): if isinstance(kernel_size, int):

View File

@ -1,15 +1,32 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.evaluate.results import save_evaluation_results
from batdetect2.evaluate.tasks import TaskConfig, build_task from batdetect2.evaluate.tasks import TaskConfig, build_task
from batdetect2.evaluate.types import (
AffinityFunction,
ClipMatches,
EvaluationTaskProtocol,
EvaluatorProtocol,
MetricsProtocol,
PlotterProtocol,
)
__all__ = [ __all__ = [
"AffinityFunction",
"ClipMatches",
"DEFAULT_EVAL_DIR",
"EvaluationConfig", "EvaluationConfig",
"EvaluationTaskProtocol",
"Evaluator", "Evaluator",
"EvaluatorProtocol",
"MatchEvaluation",
"MatcherProtocol",
"MetricsProtocol",
"PlotterProtocol",
"TaskConfig", "TaskConfig",
"build_evaluator", "build_evaluator",
"build_task", "build_task",
"evaluate", "run_evaluate",
"load_evaluation_config", "save_evaluation_results",
"DEFAULT_EVAL_DIR",
] ]

View File

@ -1,76 +1,116 @@
from typing import Annotated, Literal, Optional, Union from typing import Annotated, Literal
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import compute_affinity from soundevent.geometry import (
from soundevent.geometry import compute_interval_overlap buffer_geometry,
compute_bbox_iou,
compute_geometric_iou,
compute_temporal_closeness,
compute_temporal_iou,
)
from batdetect2.core.configs import BaseConfig from batdetect2.core import (
from batdetect2.core.registries import Registry BaseConfig,
from batdetect2.typing.evaluate import AffinityFunction ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.types import AffinityFunction
from batdetect2.postprocess.types import Detection
affinity_functions: Registry[AffinityFunction, []] = Registry( affinity_functions: Registry[AffinityFunction, []] = Registry(
"matching_strategy" "affinity_function"
) )
@add_import_config(affinity_functions)
class AffinityFunctionImportConfig(ImportConfig):
"""Use any callable as an affinity function.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class TimeAffinityConfig(BaseConfig): class TimeAffinityConfig(BaseConfig):
name: Literal["time_affinity"] = "time_affinity" name: Literal["time_affinity"] = "time_affinity"
time_buffer: float = 0.01 position: Literal["start", "end", "center"] | float = "start"
max_distance: float = 0.01
class TimeAffinity(AffinityFunction): class TimeAffinity(AffinityFunction):
def __init__(self, time_buffer: float): def __init__(
self.time_buffer = time_buffer self,
max_distance: float = 0.01,
position: Literal["start", "end", "center"] | float = "start",
):
if position == "start":
position = 0
elif position == "end":
position = 1
elif position == "center":
position = 0.5
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): self.position = position
return compute_timestamp_affinity( self.max_distance = max_distance
geometry1, geometry2, time_buffer=self.time_buffer
def __call__(
self,
detection: Detection,
ground_truth: data.SoundEventAnnotation,
) -> float:
target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
return compute_temporal_closeness(
target_geometry,
source_geometry,
ratio=self.position,
max_distance=self.max_distance,
) )
@affinity_functions.register(TimeAffinityConfig) @affinity_functions.register(TimeAffinityConfig)
@staticmethod @staticmethod
def from_config(config: TimeAffinityConfig): def from_config(config: TimeAffinityConfig):
return TimeAffinity(time_buffer=config.time_buffer) return TimeAffinity(
max_distance=config.max_distance,
position=config.position,
def compute_timestamp_affinity( )
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
) -> float:
assert isinstance(geometry1, data.TimeStamp)
assert isinstance(geometry2, data.TimeStamp)
start_time1 = geometry1.coordinates
start_time2 = geometry2.coordinates
a = min(start_time1, start_time2)
b = max(start_time1, start_time2)
if b - a >= 2 * time_buffer:
return 0
intersection = a - b + 2 * time_buffer
union = b - a + 2 * time_buffer
return intersection / union
class IntervalIOUConfig(BaseConfig): class IntervalIOUConfig(BaseConfig):
name: Literal["interval_iou"] = "interval_iou" name: Literal["interval_iou"] = "interval_iou"
time_buffer: float = 0.01 time_buffer: float = 0.0
class IntervalIOU(AffinityFunction): class IntervalIOU(AffinityFunction):
def __init__(self, time_buffer: float): def __init__(self, time_buffer: float):
if time_buffer < 0:
raise ValueError("time_buffer must be non-negative")
self.time_buffer = time_buffer self.time_buffer = time_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): def __call__(
return compute_interval_iou( self,
geometry1, detection: Detection,
geometry2, ground_truth: data.SoundEventAnnotation,
time_buffer=self.time_buffer, ) -> float:
) target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
if self.time_buffer > 0:
target_geometry = buffer_geometry(
target_geometry,
time=self.time_buffer,
)
source_geometry = buffer_geometry(
source_geometry,
time=self.time_buffer,
)
return compute_temporal_iou(target_geometry, source_geometry)
@affinity_functions.register(IntervalIOUConfig) @affinity_functions.register(IntervalIOUConfig)
@staticmethod @staticmethod
@ -78,64 +118,44 @@ class IntervalIOU(AffinityFunction):
return IntervalIOU(time_buffer=config.time_buffer) return IntervalIOU(time_buffer=config.time_buffer)
def compute_interval_iou(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
) -> float:
assert isinstance(geometry1, data.TimeInterval)
assert isinstance(geometry2, data.TimeInterval)
start_time1, end_time1 = geometry1.coordinates
start_time2, end_time2 = geometry1.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
)
union = (
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
)
if union == 0:
return 0
return intersection / union
class BBoxIOUConfig(BaseConfig): class BBoxIOUConfig(BaseConfig):
name: Literal["bbox_iou"] = "bbox_iou" name: Literal["bbox_iou"] = "bbox_iou"
time_buffer: float = 0.01 time_buffer: float = 0.0
freq_buffer: float = 1000 freq_buffer: float = 0.0
class BBoxIOU(AffinityFunction): class BBoxIOU(AffinityFunction):
def __init__(self, time_buffer: float, freq_buffer: float): def __init__(self, time_buffer: float, freq_buffer: float):
if time_buffer < 0:
raise ValueError("time_buffer must be non-negative")
if freq_buffer < 0:
raise ValueError("freq_buffer must be non-negative")
self.time_buffer = time_buffer self.time_buffer = time_buffer
self.freq_buffer = freq_buffer self.freq_buffer = freq_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): def __call__(
if not isinstance(geometry1, data.BoundingBox): self,
raise TypeError( detection: Detection,
f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}" ground_truth: data.SoundEventAnnotation,
):
target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
if self.time_buffer > 0 or self.freq_buffer > 0:
target_geometry = buffer_geometry(
target_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
)
source_geometry = buffer_geometry(
source_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
) )
if not isinstance(geometry2, data.BoundingBox): return compute_bbox_iou(target_geometry, source_geometry)
raise TypeError(
f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
)
return bbox_iou(
geometry1,
geometry2,
time_buffer=self.time_buffer,
freq_buffer=self.freq_buffer,
)
@affinity_functions.register(BBoxIOUConfig) @affinity_functions.register(BBoxIOUConfig)
@staticmethod @staticmethod
@ -146,65 +166,44 @@ class BBoxIOU(AffinityFunction):
) )
def bbox_iou(
geometry1: data.BoundingBox,
geometry2: data.BoundingBox,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float:
start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
low_freq1 -= freq_buffer
low_freq2 -= freq_buffer
high_freq1 += freq_buffer
high_freq2 += freq_buffer
time_intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
)
freq_intersection = max(
0,
min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
)
intersection = time_intersection * freq_intersection
if intersection == 0:
return 0
union = (
(end_time1 - start_time1) * (high_freq1 - low_freq1)
+ (end_time2 - start_time2) * (high_freq2 - low_freq2)
- intersection
)
return intersection / union
class GeometricIOUConfig(BaseConfig): class GeometricIOUConfig(BaseConfig):
name: Literal["geometric_iou"] = "geometric_iou" name: Literal["geometric_iou"] = "geometric_iou"
time_buffer: float = 0.01 time_buffer: float = 0.0
freq_buffer: float = 1000 freq_buffer: float = 0.0
class GeometricIOU(AffinityFunction): class GeometricIOU(AffinityFunction):
def __init__(self, time_buffer: float): def __init__(self, time_buffer: float = 0, freq_buffer: float = 0):
self.time_buffer = time_buffer if time_buffer < 0:
raise ValueError("time_buffer must be non-negative")
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry): if freq_buffer < 0:
return compute_affinity( raise ValueError("freq_buffer must be non-negative")
geometry1,
geometry2, self.time_buffer = time_buffer
time_buffer=self.time_buffer, self.freq_buffer = freq_buffer
)
def __call__(
self,
detection: Detection,
ground_truth: data.SoundEventAnnotation,
):
target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
if self.time_buffer > 0 or self.freq_buffer > 0:
target_geometry = buffer_geometry(
target_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
)
source_geometry = buffer_geometry(
source_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
)
return compute_geometric_iou(target_geometry, source_geometry)
@affinity_functions.register(GeometricIOUConfig) @affinity_functions.register(GeometricIOUConfig)
@staticmethod @staticmethod
@ -213,18 +212,16 @@ class GeometricIOU(AffinityFunction):
AffinityConfig = Annotated[ AffinityConfig = Annotated[
Union[ TimeAffinityConfig
TimeAffinityConfig, | IntervalIOUConfig
IntervalIOUConfig, | BBoxIOUConfig
BBoxIOUConfig, | GeometricIOUConfig,
GeometricIOUConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
def build_affinity_function( def build_affinity_function(
config: Optional[AffinityConfig] = None, config: AffinityConfig | None = None,
) -> AffinityFunction: ) -> AffinityFunction:
config = config or GeometricIOUConfig() config = config or GeometricIOUConfig()
return affinity_functions.build(config) return affinity_functions.build(config)

View File

@ -1,19 +1,14 @@
from typing import List, Optional from typing import List
from pydantic import Field from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig
from batdetect2.evaluate.tasks import ( from batdetect2.evaluate.tasks import TaskConfig
TaskConfig,
)
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.logging import CSVLoggerConfig, LoggerConfig
__all__ = [ __all__ = [
"EvaluationConfig", "EvaluationConfig",
"load_evaluation_config",
] ]
@ -24,7 +19,6 @@ class EvaluationConfig(BaseConfig):
ClassificationTaskConfig(), ClassificationTaskConfig(),
] ]
) )
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
def get_default_eval_config() -> EvaluationConfig: def get_default_eval_config() -> EvaluationConfig:
@ -47,10 +41,3 @@ def get_default_eval_config() -> EvaluationConfig:
] ]
} }
) )
def load_evaluation_config(
path: data.PathLike,
field: Optional[str] = None,
) -> EvaluationConfig:
return load_config(path, schema=EvaluationConfig, field=field)

View File

@ -1,4 +1,4 @@
from typing import List, NamedTuple, Optional, Sequence from typing import List, NamedTuple, Sequence
import torch import torch
from loguru import logger from loguru import logger
@ -8,14 +8,11 @@ from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.audio.types import AudioLoader, ClipperProtocol
from batdetect2.core import BaseConfig from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import ( from batdetect2.preprocess.types import PreprocessorProtocol
AudioLoader,
ClipperProtocol,
PreprocessorProtocol,
)
__all__ = [ __all__ = [
"TestDataset", "TestDataset",
@ -39,8 +36,8 @@ class TestDataset(Dataset[TestExample]):
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader, audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
clipper: Optional[ClipperProtocol] = None, clipper: ClipperProtocol | None = None,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
): ):
self.clip_annotations = list(clip_annotations) self.clip_annotations = list(clip_annotations)
self.clipper = clipper self.clipper = clipper
@ -51,8 +48,8 @@ class TestDataset(Dataset[TestExample]):
def __len__(self): def __len__(self):
return len(self.clip_annotations) return len(self.clip_annotations)
def __getitem__(self, idx: int) -> TestExample: def __getitem__(self, index: int) -> TestExample:
clip_annotation = self.clip_annotations[idx] clip_annotation = self.clip_annotations[index]
if self.clipper is not None: if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation) clip_annotation = self.clipper(clip_annotation)
@ -63,14 +60,13 @@ class TestDataset(Dataset[TestExample]):
spectrogram = self.preprocessor(wav_tensor) spectrogram = self.preprocessor(wav_tensor)
return TestExample( return TestExample(
spec=spectrogram, spec=spectrogram,
idx=torch.tensor(idx), idx=torch.tensor(index),
start_time=torch.tensor(clip.start_time), start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time), end_time=torch.tensor(clip.end_time),
) )
class TestLoaderConfig(BaseConfig): class TestLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field( clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig() default_factory=lambda: PaddedClipConfig()
) )
@ -78,10 +74,10 @@ class TestLoaderConfig(BaseConfig):
def build_test_loader( def build_test_loader(
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional[TestLoaderConfig] = None, config: TestLoaderConfig | None = None,
num_workers: Optional[int] = None, num_workers: int = 0,
) -> DataLoader[TestExample]: ) -> DataLoader[TestExample]:
logger.info("Building test data loader...") logger.info("Building test data loader...")
config = config or TestLoaderConfig() config = config or TestLoaderConfig()
@ -97,7 +93,6 @@ def build_test_loader(
config=config, config=config,
) )
num_workers = num_workers or config.num_workers
return DataLoader( return DataLoader(
test_dataset, test_dataset,
batch_size=1, batch_size=1,
@ -109,9 +104,9 @@ def build_test_loader(
def build_test_dataset( def build_test_dataset(
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional[TestLoaderConfig] = None, config: TestLoaderConfig | None = None,
) -> TestDataset: ) -> TestDataset:
logger.info("Building training dataset...") logger.info("Building training dataset...")
config = config or TestLoaderConfig() config = config or TestLoaderConfig()

View File

@ -1,56 +1,51 @@
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple from typing import Sequence
from lightning import Trainer from lightning import Trainer
from soundevent import data from soundevent import data
from batdetect2.audio import build_audio_loader from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.evaluate import EvaluationConfig
from batdetect2.evaluate.dataset import build_test_loader from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.preprocess import build_preprocessor from batdetect2.outputs import OutputsConfig, build_output_transform
from batdetect2.targets import build_targets from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.typing.postprocess import RawPrediction from batdetect2.postprocess.types import ClipDetections
from batdetect2.preprocess.types import PreprocessorProtocol
if TYPE_CHECKING: from batdetect2.targets.types import TargetProtocol
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
OutputFormatterProtocol,
PreprocessorProtocol,
TargetProtocol,
)
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations" DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
def evaluate( def run_evaluate(
model: Model, model: Model,
test_annotations: Sequence[data.ClipAnnotation], test_annotations: Sequence[data.ClipAnnotation],
targets: Optional["TargetProtocol"] = None, targets: TargetProtocol | None = None,
audio_loader: Optional["AudioLoader"] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional["BatDetect2Config"] = None, audio_config: AudioConfig | None = None,
formatter: Optional["OutputFormatterProtocol"] = None, evaluation_config: EvaluationConfig | None = None,
num_workers: Optional[int] = None, output_config: OutputsConfig | None = None,
logger_config: LoggerConfig | None = None,
formatter: OutputFormatterProtocol | None = None,
num_workers: int = 0,
output_dir: data.PathLike = DEFAULT_EVAL_DIR, output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]: ) -> tuple[dict[str, float], list[ClipDetections]]:
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config() audio_config = audio_config or AudioConfig()
evaluation_config = evaluation_config or EvaluationConfig()
output_config = output_config or OutputsConfig()
audio_loader = audio_loader or build_audio_loader(config=config.audio) audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or build_preprocessor( preprocessor = preprocessor or model.preprocessor
config=config.preprocess, targets = targets or model.targets
input_samplerate=audio_loader.samplerate,
)
targets = targets or build_targets(config=config.targets)
loader = build_test_loader( loader = build_test_loader(
test_annotations, test_annotations,
@ -59,15 +54,26 @@ def evaluate(
num_workers=num_workers, num_workers=num_workers,
) )
evaluator = build_evaluator(config=config.evaluation, targets=targets) output_transform = build_output_transform(
config=output_config.transform,
targets=targets,
)
evaluator = build_evaluator(
config=evaluation_config,
targets=targets,
transform=output_transform,
)
logger = build_logger( logger = build_logger(
config.evaluation.logger, logger_config or CSVLoggerConfig(),
log_dir=Path(output_dir), log_dir=Path(output_dir),
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
) )
module = EvaluationModule(model, evaluator) module = EvaluationModule(
model,
evaluator,
)
trainer = Trainer(logger=logger, enable_checkpointing=False) trainer = Trainer(logger=logger, enable_checkpointing=False)
metrics = trainer.test(module, loader) metrics = trainer.test(module, loader)

View File

@ -1,13 +1,15 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union from typing import Any, Dict, Iterable, List, Sequence, Tuple
from matplotlib.figure import Figure from matplotlib.figure import Figure
from soundevent import data from soundevent import data
from batdetect2.evaluate.config import EvaluationConfig from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, TargetProtocol from batdetect2.targets.types import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
__all__ = [ __all__ = [
"Evaluator", "Evaluator",
@ -19,15 +21,27 @@ class Evaluator:
def __init__( def __init__(
self, self,
targets: TargetProtocol, targets: TargetProtocol,
tasks: Sequence[EvaluatorProtocol], transform: OutputTransformProtocol,
tasks: Sequence[EvaluationTaskProtocol],
): ):
self.targets = targets self.targets = targets
self.transform = transform
self.tasks = tasks self.tasks = tasks
def to_clip_detections_batch(
self,
clip_detections: Sequence[ClipDetectionsTensor],
clips: Sequence[data.Clip],
) -> list[ClipDetections]:
return [
self.transform.to_clip_detections(detections=dets, clip=clip)
for dets, clip in zip(clip_detections, clips, strict=False)
]
def evaluate( def evaluate(
self, self,
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
) -> List[Any]: ) -> List[Any]:
return [ return [
task.evaluate(clip_annotations, predictions) for task in self.tasks task.evaluate(clip_annotations, predictions) for task in self.tasks
@ -36,7 +50,7 @@ class Evaluator:
def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]: def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
results = {} results = {}
for task, outputs in zip(self.tasks, eval_outputs): for task, outputs in zip(self.tasks, eval_outputs, strict=False):
results.update(task.compute_metrics(outputs)) results.update(task.compute_metrics(outputs))
return results return results
@ -45,14 +59,15 @@ class Evaluator:
self, self,
eval_outputs: List[Any], eval_outputs: List[Any],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[Tuple[str, Figure]]:
for task, outputs in zip(self.tasks, eval_outputs): for task, outputs in zip(self.tasks, eval_outputs, strict=False):
for name, fig in task.generate_plots(outputs): for name, fig in task.generate_plots(outputs):
yield name, fig yield name, fig
def build_evaluator( def build_evaluator(
config: Optional[Union[EvaluationConfig, dict]] = None, config: EvaluationConfig | dict | None = None,
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
transform: OutputTransformProtocol | None = None,
) -> EvaluatorProtocol: ) -> EvaluatorProtocol:
targets = targets or build_targets() targets = targets or build_targets()
@ -62,7 +77,10 @@ def build_evaluator(
if not isinstance(config, EvaluationConfig): if not isinstance(config, EvaluationConfig):
config = EvaluationConfig.model_validate(config) config = EvaluationConfig.model_validate(config)
transform = transform or build_output_transform(targets=targets)
return Evaluator( return Evaluator(
targets=targets, targets=targets,
transform=transform,
tasks=[build_task(task, targets=targets) for task in config.tasks], tasks=[build_task(task, targets=targets) for task in config.tasks],
) )

View File

@ -357,7 +357,7 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001):
clf = RandomForestClassifier(random_state=seed, n_jobs=-1) clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
clf.fit(x_train, y_train) clf.fit(x_train, y_train)
y_pred = clf.predict(x_train) y_pred = clf.predict(x_train)
tr_acc = (y_pred == y_train).mean() (y_pred == y_train).mean()
# print('Train acc', round(tr_acc*100, 2)) # print('Train acc', round(tr_acc*100, 2))
return clf, un_train_class return clf, un_train_class
@ -450,7 +450,7 @@ def add_root_path_back(data_sets, ann_path, wav_path):
def check_classes_in_train(gt_list, class_names): def check_classes_in_train(gt_list, class_names):
num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list]) np.sum([gg["start_times"].shape[0] for gg in gt_list])
num_with_no_class = 0 num_with_no_class = 0
for gt in gt_list: for gt in gt_list:
for cc in gt["class_names"]: for cc in gt["class_names"]:
@ -569,7 +569,7 @@ if __name__ == "__main__":
num_with_no_class = check_classes_in_train(gt_test, class_names) num_with_no_class = check_classes_in_train(gt_test, class_names)
if total_num_calls == num_with_no_class: if total_num_calls == num_with_no_class:
print("Classes from the test set are not in the train set.") print("Classes from the test set are not in the train set.")
assert False raise AssertionError()
# only need the train data if evaluating Sonobat or Tadarida # only need the train data if evaluating Sonobat or Tadarida
if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "": if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
@ -743,7 +743,7 @@ if __name__ == "__main__":
# check if the class names are the same # check if the class names are the same
if params_bd["class_names"] != class_names: if params_bd["class_names"] != class_names:
print("Warning: Class names are not the same as the trained model") print("Warning: Class names are not the same as the trained model")
assert False raise AssertionError()
run_config = { run_config = {
**bd_args, **bd_args,
@ -753,7 +753,7 @@ if __name__ == "__main__":
preds_bd = [] preds_bd = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for ii, gg in enumerate(gt_test): for gg in gt_test:
pred = du.process_file( pred = du.process_file(
gg["file_path"], gg["file_path"],
model, model,

View File

@ -5,11 +5,10 @@ from soundevent import data
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger from batdetect2.logging import get_image_logger
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions from batdetect2.postprocess.types import ClipDetections
from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
class EvaluationModule(LightningModule): class EvaluationModule(LightningModule):
@ -24,7 +23,7 @@ class EvaluationModule(LightningModule):
self.evaluator = evaluator self.evaluator = evaluator
self.clip_annotations: List[data.ClipAnnotation] = [] self.clip_annotations: List[data.ClipAnnotation] = []
self.predictions: List[BatDetect2Prediction] = [] self.predictions: List[ClipDetections] = []
def test_step(self, batch: TestExample, batch_idx: int): def test_step(self, batch: TestExample, batch_idx: int):
dataset = self.get_dataset() dataset = self.get_dataset()
@ -34,22 +33,11 @@ class EvaluationModule(LightningModule):
] ]
outputs = self.model.detector(batch.spec) outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor( clip_detections = self.model.postprocessor(outputs)
outputs, predictions = self.evaluator.to_clip_detections_batch(
start_times=[ca.clip.start_time for ca in clip_annotations], clip_detections,
[clip_annotation.clip for clip_annotation in clip_annotations],
) )
predictions = [
BatDetect2Prediction(
clip=clip_annotation.clip,
predictions=to_raw_predictions(
clip_dets.numpy(),
targets=self.evaluator.targets,
),
)
for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections
)
]
self.clip_annotations.extend(clip_annotations) self.clip_annotations.extend(clip_annotations)
self.predictions.extend(predictions) self.predictions.extend(predictions)

View File

@ -1,617 +0,0 @@
from collections.abc import Callable, Iterable, Mapping
from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
from pydantic import Field
from scipy.optimize import linear_sum_assignment
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.geometry import buffer_geometry, compute_bounds, scale_geometry
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.affinity import (
AffinityConfig,
BBoxIOUConfig,
GeometricIOUConfig,
build_affinity_function,
)
from batdetect2.targets import build_targets
from batdetect2.typing import (
AffinityFunction,
MatcherProtocol,
MatchEvaluation,
RawPrediction,
TargetProtocol,
)
from batdetect2.typing.evaluate import ClipMatches
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching."""
matching_strategies = Registry("matching_strategy")
def match(
sound_event_annotations: Sequence[data.SoundEventAnnotation],
raw_predictions: Sequence[RawPrediction],
clip: data.Clip,
scores: Optional[Sequence[float]] = None,
targets: Optional[TargetProtocol] = None,
matcher: Optional[MatcherProtocol] = None,
) -> ClipMatches:
if matcher is None:
matcher = build_matcher()
if targets is None:
targets = build_targets()
target_geometries: List[data.Geometry] = [ # type: ignore
sound_event_annotation.sound_event.geometry
for sound_event_annotation in sound_event_annotations
]
predicted_geometries = [
raw_prediction.geometry for raw_prediction in raw_predictions
]
if scores is None:
scores = [
raw_prediction.detection_score
for raw_prediction in raw_predictions
]
matches = []
for source_idx, target_idx, affinity in matcher(
ground_truth=target_geometries,
predictions=predicted_geometries,
scores=scores,
):
target = (
sound_event_annotations[target_idx]
if target_idx is not None
else None
)
prediction = (
raw_predictions[source_idx] if source_idx is not None else None
)
gt_det = target_idx is not None
gt_class = targets.encode_class(target) if target is not None else None
gt_geometry = (
target_geometries[target_idx] if target_idx is not None else None
)
pred_score = float(prediction.detection_score) if prediction else 0
pred_geometry = (
predicted_geometries[source_idx]
if source_idx is not None
else None
)
class_scores = (
{
class_name: score
for class_name, score in zip(
targets.class_names,
prediction.class_scores,
)
}
if prediction is not None
else {}
)
matches.append(
MatchEvaluation(
clip=clip,
sound_event_annotation=target,
gt_det=gt_det,
gt_class=gt_class,
gt_geometry=gt_geometry,
pred_score=pred_score,
pred_class_scores=class_scores,
pred_geometry=pred_geometry,
affinity=affinity,
)
)
return ClipMatches(clip=clip, matches=matches)
class StartTimeMatchConfig(BaseConfig):
    """Configuration for the start-time matching strategy."""

    name: Literal["start_time_match"] = "start_time_match"
    # Maximum onset difference for a prediction to match a ground-truth
    # event (presumably seconds — confirm against clip time units).
    distance_threshold: float = 0.01
class StartTimeMatcher(MatcherProtocol):
    """Match predictions to ground truth by onset (start time) proximity."""

    def __init__(self, distance_threshold: float):
        # Maximum allowed onset difference for a valid match.
        self.distance_threshold = distance_threshold

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        """Yield (prediction_index, ground_truth_index, affinity) triples."""
        return match_start_times(
            ground_truth,
            predictions,
            scores,
            distance_threshold=self.distance_threshold,
        )
@matching_strategies.register(StartTimeMatchConfig)
@staticmethod
def from_config(config: StartTimeMatchConfig):
    """Build a `StartTimeMatcher` from its configuration."""
    return StartTimeMatcher(config.distance_threshold)
def match_start_times(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    scores: Sequence[float],
    distance_threshold: float = 0.01,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Greedily pair predictions with ground truth by onset proximity.

    Predictions are visited in descending score order; each is paired with
    its single closest ground-truth onset, provided that onset is still
    free and within ``distance_threshold``. Yields
    ``(prediction_index, ground_truth_index, affinity)`` triples, with
    ``None`` standing in for the missing side of an unmatched item.
    """
    if not ground_truth:
        # No ground truth: every prediction is a false positive.
        for pred_index in range(len(predictions)):
            yield pred_index, None, 0
        return

    if not predictions:
        # No predictions: every ground-truth event is a false negative.
        for gt_index in range(len(ground_truth)):
            yield None, gt_index, 0
        return

    gt_onsets = np.array([compute_bounds(geom)[0] for geom in ground_truth])
    pred_onsets = np.array([compute_bounds(geom)[0] for geom in predictions])

    # Pairwise onset distances, shape (num_predictions, num_ground_truth).
    onset_distances = np.abs(gt_onsets[None, :] - pred_onsets[:, None])
    nearest_gt = np.argmin(onset_distances, axis=-1)

    available = set(range(len(gt_onsets)))

    # Highest-scoring predictions get first pick.
    for pred_index in np.argsort(np.array(scores))[::-1]:
        candidate = nearest_gt[pred_index]

        if (
            candidate not in available
            or onset_distances[pred_index, candidate] > distance_threshold
        ):
            # Closest onset is already claimed, or lies too far away.
            yield pred_index, None, 0
            continue

        # Affinity decays linearly from 1 (identical onsets) to 0 (at the
        # threshold distance).
        affinity = np.interp(
            onset_distances[pred_index, candidate],
            [0, distance_threshold],
            [1, 0],
            left=1,
            right=0,
        )
        available.remove(candidate)
        yield pred_index, candidate, affinity

    for gt_index in available:
        yield None, gt_index, 0
def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
    """Convert any geometry to its full time-frequency bounding box."""
    bounds = compute_bounds(geometry)
    return data.BoundingBox(coordinates=list(bounds))
def _to_interval(geometry: data.Geometry) -> data.TimeInterval:
    """Reduce a geometry to its temporal extent only."""
    bounds = compute_bounds(geometry)
    return data.TimeInterval(coordinates=[bounds[0], bounds[2]])
def _to_timestamp(geometry: data.Geometry) -> data.TimeStamp:
    """Collapse a geometry to its onset (start time) alone."""
    start_time, _, _, _ = compute_bounds(geometry)
    return data.TimeStamp(coordinates=start_time)
# Dispatch table mapping each `MatchingGeometry` option to the function
# that casts an arbitrary geometry into that representation.
_geometry_cast_functions: Mapping[
    MatchingGeometry, Callable[[data.Geometry], data.Geometry]
] = {
    "bbox": _to_bbox,
    "interval": _to_interval,
    "timestamp": _to_timestamp,
}
class GreedyMatchConfig(BaseConfig):
    """Configuration for the greedy, per-prediction matching strategy."""

    name: Literal["greedy_match"] = "greedy_match"
    # Geometry representation geometries are cast to before computing
    # affinities.
    geometry: MatchingGeometry = "timestamp"
    # Minimum affinity for a pairing to count as a match.
    affinity_threshold: float = 0.5
    affinity_function: AffinityConfig = Field(
        default_factory=GeometricIOUConfig
    )
class GreedyMatcher(MatcherProtocol):
    """Greedy matcher that compares geometries via an affinity function.

    Geometries are first cast to the configured representation (bounding
    box, time interval, or timestamp) before affinities are computed.
    """

    def __init__(
        self,
        geometry: MatchingGeometry,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
    ):
        self.geometry = geometry
        self.affinity_function = affinity_function
        self.affinity_threshold = affinity_threshold
        # Resolve the cast function once, at construction time.
        self.cast_geometry = _geometry_cast_functions[self.geometry]

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        cast = self.cast_geometry
        return greedy_match(
            ground_truth=list(map(cast, ground_truth)),
            predictions=list(map(cast, predictions)),
            scores=scores,
            affinity_function=self.affinity_function,
            affinity_threshold=self.affinity_threshold,
        )
@matching_strategies.register(GreedyMatchConfig)
@staticmethod
def from_config(config: GreedyMatchConfig):
    """Build a `GreedyMatcher` from its configuration."""
    return GreedyMatcher(
        geometry=config.geometry,
        affinity_threshold=config.affinity_threshold,
        affinity_function=build_affinity_function(config.affinity_function),
    )
def greedy_match(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    scores: Sequence[float],
    affinity_threshold: float = 0.5,
    affinity_function: AffinityFunction = compute_affinity,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Perform a greedy, one-to-one matching of predictions to ground truth.

    Iterates through predictions, prioritized by descending score. Each
    prediction is matched to the ground-truth geometry with the highest
    affinity, provided the affinity exceeds the threshold and that
    geometry has not already been assigned.

    Parameters
    ----------
    ground_truth
        Ground-truth geometries to match against.
    predictions
        Predicted geometries.
    scores
        Confidence scores for each prediction, used for prioritization.
    affinity_threshold
        The minimum affinity score required for a valid match.
    affinity_function
        Callable computing the affinity between a predicted geometry and
        a ground-truth geometry.

    Yields
    ------
    Tuple[Optional[int], Optional[int], float]
        A 3-element tuple describing a match or a miss. There are three
        possible formats:

        - Successful Match: `(prediction_idx, ground_truth_idx, affinity)`
        - Unmatched Prediction (False Positive): `(prediction_idx, None, 0)`
        - Unmatched Ground Truth (False Negative): `(None, ground_truth_idx, 0)`
    """
    if not predictions:
        for gt_idx in range(len(ground_truth)):
            yield None, gt_idx, 0
        return

    if not ground_truth:
        for pred_idx in range(len(predictions)):
            yield pred_idx, None, 0
        return

    unassigned_gt = set(range(len(ground_truth)))

    # Highest-scoring predictions get first pick of the ground truth.
    for pred_idx in np.argsort(scores)[::-1]:
        pred_idx = int(pred_idx)
        source_geometry = predictions[pred_idx]

        affinities = np.array(
            [
                affinity_function(source_geometry, target_geometry)
                for target_geometry in ground_truth
            ]
        )

        closest_target = int(np.argmax(affinities))
        affinity = float(affinities[closest_target])

        if affinity <= affinity_threshold:
            # Even the best candidate falls below the threshold.
            yield pred_idx, None, 0
            continue

        if closest_target not in unassigned_gt:
            # Best candidate was already claimed by a higher-scoring
            # prediction.
            yield pred_idx, None, 0
            continue

        unassigned_gt.remove(closest_target)
        yield pred_idx, closest_target, affinity

    for gt_idx in unassigned_gt:
        yield None, gt_idx, 0
class GreedyAffinityMatchConfig(BaseConfig):
    """Configuration for greedy matching over a full affinity matrix."""

    name: Literal["greedy_affinity_match"] = "greedy_affinity_match"
    affinity_function: AffinityConfig = Field(default_factory=BBoxIOUConfig)
    # Minimum affinity for a pairing to count as a match.
    affinity_threshold: float = 0.5
    # Buffers applied to both geometries before matching (presumably
    # seconds / Hz — confirm with `buffer_geometry`).
    time_buffer: float = 0
    frequency_buffer: float = 0
    # Axis scaling applied before the affinity is computed.
    time_scale: float = 1.0
    frequency_scale: float = 1.0
class GreedyAffinityMatcher(MatcherProtocol):
    """Greedy matcher driven by a precomputed affinity matrix.

    Optionally buffers and rescales geometries before computing the
    pairwise affinities. Prediction scores are ignored: ground-truth
    events claim their best prediction in ground-truth order (see
    `select_greedy_matches`).
    """

    def __init__(
        self,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
        time_buffer: float = 0,
        frequency_buffer: float = 0,
        time_scale: float = 1.0,
        frequency_scale: float = 1.0,
    ):
        self.affinity_threshold = affinity_threshold
        self.affinity_function = affinity_function
        self.time_buffer = time_buffer
        self.frequency_buffer = frequency_buffer
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def _apply_buffer(self, geometries):
        # Expand each geometry by the configured time/frequency buffers.
        return [
            buffer_geometry(
                geometry,
                time_buffer=self.time_buffer,
                freq_buffer=self.frequency_buffer,
            )
            for geometry in geometries
        ]

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        if self.time_buffer != 0 or self.frequency_buffer != 0:
            ground_truth = self._apply_buffer(ground_truth)
            predictions = self._apply_buffer(predictions)

        affinity_matrix = compute_affinity_matrix(
            ground_truth,
            predictions,
            self.affinity_function,
            time_scale=self.time_scale,
            frequency_scale=self.frequency_scale,
        )
        return select_greedy_matches(
            affinity_matrix,
            affinity_threshold=self.affinity_threshold,
        )
@matching_strategies.register(GreedyAffinityMatchConfig)
@staticmethod
def from_config(config: GreedyAffinityMatchConfig):
    """Build a `GreedyAffinityMatcher` from its configuration."""
    affinity_function = build_affinity_function(config.affinity_function)
    return GreedyAffinityMatcher(
        affinity_threshold=config.affinity_threshold,
        affinity_function=affinity_function,
        # Forward the buffers: omitting them silently disabled the
        # configured buffering (they defaulted to 0 in the matcher).
        time_buffer=config.time_buffer,
        frequency_buffer=config.frequency_buffer,
        time_scale=config.time_scale,
        frequency_scale=config.frequency_scale,
    )
class OptimalMatchConfig(BaseConfig):
    """Configuration for the globally optimal matching strategy."""

    name: Literal["optimal_affinity_match"] = "optimal_affinity_match"
    affinity_function: AffinityConfig = Field(default_factory=BBoxIOUConfig)
    # Minimum affinity for a pairing to count as a match.
    affinity_threshold: float = 0.5
    # Buffers applied to both geometries before matching (presumably
    # seconds / Hz — confirm with `buffer_geometry`).
    time_buffer: float = 0
    frequency_buffer: float = 0
    # Axis scaling applied before the affinity is computed.
    time_scale: float = 1.0
    frequency_scale: float = 1.0
class OptimalMatcher(MatcherProtocol):
    """Matcher that selects the globally optimal one-to-one assignment.

    Optionally buffers and rescales geometries, builds the full pairwise
    affinity matrix, and delegates the assignment to
    `select_optimal_matches`. Prediction scores are ignored.
    """

    def __init__(
        self,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
        time_buffer: float = 0,
        frequency_buffer: float = 0,
        time_scale: float = 1.0,
        frequency_scale: float = 1.0,
    ):
        self.affinity_threshold = affinity_threshold
        self.affinity_function = affinity_function
        self.time_buffer = time_buffer
        self.frequency_buffer = frequency_buffer
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def _apply_buffer(self, geometries):
        # Expand each geometry by the configured time/frequency buffers.
        return [
            buffer_geometry(
                geometry,
                time_buffer=self.time_buffer,
                freq_buffer=self.frequency_buffer,
            )
            for geometry in geometries
        ]

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        if self.time_buffer != 0 or self.frequency_buffer != 0:
            ground_truth = self._apply_buffer(ground_truth)
            predictions = self._apply_buffer(predictions)

        affinity_matrix = compute_affinity_matrix(
            ground_truth,
            predictions,
            self.affinity_function,
            time_scale=self.time_scale,
            frequency_scale=self.frequency_scale,
        )
        return select_optimal_matches(
            affinity_matrix,
            affinity_threshold=self.affinity_threshold,
        )
@matching_strategies.register(OptimalMatchConfig)
@staticmethod
def from_config(config: OptimalMatchConfig):
    """Build an `OptimalMatcher` from its configuration."""
    return OptimalMatcher(
        affinity_threshold=config.affinity_threshold,
        affinity_function=build_affinity_function(config.affinity_function),
        time_buffer=config.time_buffer,
        frequency_buffer=config.frequency_buffer,
        time_scale=config.time_scale,
        frequency_scale=config.frequency_scale,
    )
# Discriminated union of all matching-strategy configurations; the `name`
# field selects the concrete strategy.
MatchConfig = Annotated[
    Union[
        GreedyMatchConfig,
        StartTimeMatchConfig,
        OptimalMatchConfig,
        GreedyAffinityMatchConfig,
    ],
    Field(discriminator="name"),
]
def compute_affinity_matrix(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    affinity_function: AffinityFunction,
    time_scale: float = 1,
    frequency_scale: float = 1,
) -> np.ndarray:
    """Compute pairwise affinities between ground truth and predictions.

    Returns an array of shape ``(len(ground_truth), len(predictions))``
    where entry ``[i, j]`` is ``affinity_function(ground_truth[i],
    predictions[j])``. Geometries are rescaled first whenever either
    scale factor differs from 1.
    """
    if time_scale != 1 or frequency_scale != 1:
        ground_truth = [
            scale_geometry(geometry, time_scale, frequency_scale)
            for geometry in ground_truth
        ]
        predictions = [
            scale_geometry(geometry, time_scale, frequency_scale)
            for geometry in predictions
        ]

    values = [
        affinity_function(gt_geometry, pred_geometry)
        for gt_geometry in ground_truth
        for pred_geometry in predictions
    ]
    # reshape (rather than np.array of nested lists) keeps the correct
    # 2-D shape even when either input sequence is empty.
    return np.asarray(values, dtype=float).reshape(
        len(ground_truth), len(predictions)
    )
def select_optimal_matches(
    affinity_matrix: np.ndarray,
    affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Select a globally optimal one-to-one matching from an affinity matrix.

    Uses ``scipy.optimize.linear_sum_assignment`` to maximize total
    affinity, then discards assignments at or below the threshold.

    Parameters
    ----------
    affinity_matrix
        Array of shape ``(num_ground_truth, num_predictions)`` whose entry
        ``[i, j]`` is the affinity between ground truth ``i`` and
        prediction ``j``.
    affinity_threshold
        Minimum affinity (exclusive) for an assignment to count as a match.

    Yields
    ------
    Tuple[Optional[int], Optional[int], float]
        ``(prediction_idx, ground_truth_idx, affinity)`` triples. Unmatched
        ground truth yields ``(None, ground_truth_idx, 0)`` and unmatched
        predictions yield ``(prediction_idx, None, 0)``.
    """
    num_gt, num_pred = affinity_matrix.shape
    unmatched_gt = set(range(num_gt))
    unmatched_pred = set(range(num_pred))

    assigned_rows, assigned_columns = linear_sum_assignment(
        affinity_matrix,
        maximize=True,
    )

    for gt_idx, pred_idx in zip(assigned_rows, assigned_columns):
        affinity = float(affinity_matrix[gt_idx, pred_idx])

        if affinity <= affinity_threshold:
            # Below threshold: leave both sides unmatched (reported below).
            continue

        # NOTE: yield the prediction index first — every matcher in this
        # module follows the (prediction_idx, ground_truth_idx, affinity)
        # convention that `match` expects. The previous version yielded
        # (gt_idx, pred_idx, ...), swapping predictions and ground truth.
        yield int(pred_idx), int(gt_idx), affinity
        unmatched_gt.remove(gt_idx)
        unmatched_pred.remove(pred_idx)

    for gt_idx in unmatched_gt:
        yield None, gt_idx, 0

    for pred_idx in unmatched_pred:
        yield pred_idx, None, 0
def select_greedy_matches(
    affinity_matrix: np.ndarray,
    affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Greedily match each ground-truth row to its best free prediction.

    Ground-truth rows are processed in order; each claims the prediction
    with the highest affinity in its row, unless that affinity is at or
    below the threshold or the prediction was already claimed. Yields
    ``(prediction_idx, ground_truth_idx, affinity)`` triples, with
    ``None`` standing in for the missing side of an unmatched item.
    """
    num_gt, num_pred = affinity_matrix.shape
    free_predictions = set(range(num_pred))

    for gt_idx in range(num_gt):
        affinities = affinity_matrix[gt_idx]
        best_pred = int(np.argmax(affinities))
        best_affinity = float(affinities[best_pred])

        claimable = (
            best_affinity > affinity_threshold
            and best_pred in free_predictions
        )

        if not claimable:
            # Ground truth stays unmatched (false negative).
            yield None, gt_idx, 0
            continue

        free_predictions.discard(best_pred)
        yield best_pred, gt_idx, best_affinity

    # Remaining unclaimed predictions are false positives.
    for pred_idx in free_predictions:
        yield pred_idx, None, 0
def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
    """Instantiate a matcher from config, defaulting to start-time matching."""
    if config is None:
        config = StartTimeMatchConfig()
    return matching_strategies.build(config)

View File

@ -7,10 +7,8 @@ from typing import (
List, List,
Literal, Literal,
Mapping, Mapping,
Optional,
Sequence, Sequence,
Tuple, Tuple,
Union,
) )
import numpy as np import numpy as np
@ -18,16 +16,23 @@ from pydantic import Field
from sklearn import metrics from sklearn import metrics
from soundevent import data from soundevent import data
from batdetect2.core import BaseConfig, Registry from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import ( from batdetect2.evaluate.metrics.common import (
average_precision, average_precision,
compute_precision_recall, compute_precision_recall,
) )
from batdetect2.typing import RawPrediction, TargetProtocol from batdetect2.postprocess.types import Detection
from batdetect2.targets.types import TargetProtocol
__all__ = [ __all__ = [
"ClassificationMetric", "ClassificationMetric",
"ClassificationMetricConfig", "ClassificationMetricConfig",
"ClassificationMetricImportConfig",
"build_classification_metric", "build_classification_metric",
"compute_precision_recall_curves", "compute_precision_recall_curves",
] ]
@ -36,13 +41,13 @@ __all__ = [
@dataclass @dataclass
class MatchEval: class MatchEval:
clip: data.Clip clip: data.Clip
gt: Optional[data.SoundEventAnnotation] gt: data.SoundEventAnnotation | None
pred: Optional[RawPrediction] pred: Detection | None
is_prediction: bool is_prediction: bool
is_ground_truth: bool is_ground_truth: bool
is_generic: bool is_generic: bool
true_class: Optional[str] true_class: str | None
score: float score: float
@ -60,17 +65,28 @@ classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
) )
@add_import_config(classification_metrics)
class ClassificationMetricImportConfig(ImportConfig):
"""Use any callable as a classification metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class BaseClassificationConfig(BaseConfig): class BaseClassificationConfig(BaseConfig):
include: Optional[List[str]] = None include: List[str] | None = None
exclude: Optional[List[str]] = None exclude: List[str] | None = None
class BaseClassificationMetric: class BaseClassificationMetric:
def __init__( def __init__(
self, self,
targets: TargetProtocol, targets: TargetProtocol,
include: Optional[List[str]] = None, include: List[str] | None = None,
exclude: Optional[List[str]] = None, exclude: List[str] | None = None,
): ):
self.targets = targets self.targets = targets
self.include = include self.include = include
@ -100,8 +116,8 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
ignore_non_predictions: bool = True, ignore_non_predictions: bool = True,
ignore_generic: bool = True, ignore_generic: bool = True,
label: str = "average_precision", label: str = "average_precision",
include: Optional[List[str]] = None, include: List[str] | None = None,
exclude: Optional[List[str]] = None, exclude: List[str] | None = None,
): ):
super().__init__(include=include, exclude=exclude, targets=targets) super().__init__(include=include, exclude=exclude, targets=targets)
self.ignore_non_predictions = ignore_non_predictions self.ignore_non_predictions = ignore_non_predictions
@ -169,8 +185,8 @@ class ClassificationROCAUC(BaseClassificationMetric):
ignore_non_predictions: bool = True, ignore_non_predictions: bool = True,
ignore_generic: bool = True, ignore_generic: bool = True,
label: str = "roc_auc", label: str = "roc_auc",
include: Optional[List[str]] = None, include: List[str] | None = None,
exclude: Optional[List[str]] = None, exclude: List[str] | None = None,
): ):
self.targets = targets self.targets = targets
self.ignore_non_predictions = ignore_non_predictions self.ignore_non_predictions = ignore_non_predictions
@ -225,10 +241,7 @@ class ClassificationROCAUC(BaseClassificationMetric):
ClassificationMetricConfig = Annotated[ ClassificationMetricConfig = Annotated[
Union[ ClassificationAveragePrecisionConfig | ClassificationROCAUCConfig,
ClassificationAveragePrecisionConfig,
ClassificationROCAUCConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -1,13 +1,17 @@
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union from typing import Annotated, Callable, Dict, Literal, Sequence, Set
import numpy as np import numpy as np
from pydantic import Field from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision from batdetect2.evaluate.metrics.common import average_precision
@ -24,6 +28,17 @@ clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
) )
@add_import_config(clip_classification_metrics)
class ClipClassificationMetricImportConfig(ImportConfig):
"""Use any callable as a clip classification metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class ClipClassificationAveragePrecisionConfig(BaseConfig): class ClipClassificationAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision" name: Literal["average_precision"] = "average_precision"
label: str = "average_precision" label: str = "average_precision"
@ -123,10 +138,7 @@ class ClipClassificationROCAUC:
ClipClassificationMetricConfig = Annotated[ ClipClassificationMetricConfig = Annotated[
Union[ ClipClassificationAveragePrecisionConfig | ClipClassificationROCAUCConfig,
ClipClassificationAveragePrecisionConfig,
ClipClassificationROCAUCConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -1,12 +1,16 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Union from typing import Annotated, Callable, Dict, Literal, Sequence
import numpy as np import numpy as np
from pydantic import Field from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision from batdetect2.evaluate.metrics.common import average_precision
@ -23,6 +27,17 @@ clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
) )
@add_import_config(clip_detection_metrics)
class ClipDetectionMetricImportConfig(ImportConfig):
"""Use any callable as a clip detection metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class ClipDetectionAveragePrecisionConfig(BaseConfig): class ClipDetectionAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision" name: Literal["average_precision"] = "average_precision"
label: str = "average_precision" label: str = "average_precision"
@ -159,12 +174,10 @@ class ClipDetectionPrecision:
ClipDetectionMetricConfig = Annotated[ ClipDetectionMetricConfig = Annotated[
Union[ ClipDetectionAveragePrecisionConfig
ClipDetectionAveragePrecisionConfig, | ClipDetectionROCAUCConfig
ClipDetectionROCAUCConfig, | ClipDetectionRecallConfig
ClipDetectionRecallConfig, | ClipDetectionPrecisionConfig,
ClipDetectionPrecisionConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple from typing import Tuple
import numpy as np import numpy as np
@ -11,7 +11,7 @@ __all__ = [
def compute_precision_recall( def compute_precision_recall(
y_true, y_true,
y_score, y_score,
num_positives: Optional[int] = None, num_positives: int | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
y_true = np.array(y_true) y_true = np.array(y_true)
y_score = np.array(y_score) y_score = np.array(y_score)
@ -41,7 +41,7 @@ def compute_precision_recall(
def average_precision( def average_precision(
y_true, y_true,
y_score, y_score,
num_positives: Optional[int] = None, num_positives: int | None = None,
) -> float: ) -> float:
if num_positives == 0: if num_positives == 0:
return np.nan return np.nan

View File

@ -5,9 +5,7 @@ from typing import (
Dict, Dict,
List, List,
Literal, Literal,
Optional,
Sequence, Sequence,
Union,
) )
import numpy as np import numpy as np
@ -15,21 +13,27 @@ from pydantic import Field
from sklearn import metrics from sklearn import metrics
from soundevent import data from soundevent import data
from batdetect2.core import BaseConfig, Registry from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction from batdetect2.postprocess.types import Detection
__all__ = [ __all__ = [
"DetectionMetricConfig", "DetectionMetricConfig",
"DetectionMetric", "DetectionMetric",
"DetectionMetricImportConfig",
"build_detection_metric", "build_detection_metric",
] ]
@dataclass @dataclass
class MatchEval: class MatchEval:
gt: Optional[data.SoundEventAnnotation] gt: data.SoundEventAnnotation | None
pred: Optional[RawPrediction] pred: Detection | None
is_prediction: bool is_prediction: bool
is_ground_truth: bool is_ground_truth: bool
@ -48,6 +52,17 @@ DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric") detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")
@add_import_config(detection_metrics)
class DetectionMetricImportConfig(ImportConfig):
"""Use any callable as a detection metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class DetectionAveragePrecisionConfig(BaseConfig): class DetectionAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision" name: Literal["average_precision"] = "average_precision"
label: str = "average_precision" label: str = "average_precision"
@ -79,7 +94,7 @@ class DetectionAveragePrecision:
y_score.append(m.score) y_score.append(m.score)
ap = average_precision(y_true, y_score, num_positives=num_positives) ap = average_precision(y_true, y_score, num_positives=num_positives)
return {self.label: ap} return {self.label: float(ap)}
@detection_metrics.register(DetectionAveragePrecisionConfig) @detection_metrics.register(DetectionAveragePrecisionConfig)
@staticmethod @staticmethod
@ -212,12 +227,10 @@ class DetectionPrecision:
DetectionMetricConfig = Annotated[ DetectionMetricConfig = Annotated[
Union[ DetectionAveragePrecisionConfig
DetectionAveragePrecisionConfig, | DetectionROCAUCConfig
DetectionROCAUCConfig, | DetectionRecallConfig
DetectionRecallConfig, | DetectionPrecisionConfig,
DetectionPrecisionConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -1,4 +1,3 @@
from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import (
Annotated, Annotated,
@ -6,9 +5,7 @@ from typing import (
Dict, Dict,
List, List,
Literal, Literal,
Optional,
Sequence, Sequence,
Union,
) )
import numpy as np import numpy as np
@ -16,14 +13,20 @@ from pydantic import Field
from sklearn import metrics, preprocessing from sklearn import metrics, preprocessing
from soundevent import data from soundevent import data
from batdetect2.core import BaseConfig, Registry from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction from batdetect2.postprocess.types import Detection
from batdetect2.typing.targets import TargetProtocol from batdetect2.targets.types import TargetProtocol
__all__ = [ __all__ = [
"TopClassMetricConfig", "TopClassMetricConfig",
"TopClassMetric", "TopClassMetric",
"TopClassMetricImportConfig",
"build_top_class_metric", "build_top_class_metric",
] ]
@ -31,14 +34,14 @@ __all__ = [
@dataclass @dataclass
class MatchEval: class MatchEval:
clip: data.Clip clip: data.Clip
gt: Optional[data.SoundEventAnnotation] gt: data.SoundEventAnnotation | None
pred: Optional[RawPrediction] pred: Detection | None
is_ground_truth: bool is_ground_truth: bool
is_generic: bool is_generic: bool
is_prediction: bool is_prediction: bool
pred_class: Optional[str] pred_class: str | None
true_class: Optional[str] true_class: str | None
score: float score: float
@ -54,6 +57,17 @@ TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric") top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")
@add_import_config(top_class_metrics)
class TopClassMetricImportConfig(ImportConfig):
"""Use any callable as a top-class metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class TopClassAveragePrecisionConfig(BaseConfig): class TopClassAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision" name: Literal["average_precision"] = "average_precision"
label: str = "average_precision" label: str = "average_precision"
@ -301,13 +315,11 @@ class BalancedAccuracy:
TopClassMetricConfig = Annotated[ TopClassMetricConfig = Annotated[
Union[ TopClassAveragePrecisionConfig
TopClassAveragePrecisionConfig, | TopClassROCAUCConfig
TopClassROCAUCConfig, | TopClassRecallConfig
TopClassRecallConfig, | TopClassPrecisionConfig
TopClassPrecisionConfig, | BalancedAccuracyConfig,
BalancedAccuracyConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -1,16 +1,14 @@
from typing import Optional
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.figure import Figure from matplotlib.figure import Figure
from batdetect2.core import BaseConfig from batdetect2.core import BaseConfig
from batdetect2.typing import TargetProtocol from batdetect2.targets.types import TargetProtocol
class BasePlotConfig(BaseConfig): class BasePlotConfig(BaseConfig):
label: str = "plot" label: str = "plot"
theme: str = "default" theme: str = "default"
title: Optional[str] = None title: str | None = None
figsize: tuple[int, int] = (10, 10) figsize: tuple[int, int] = (10, 10)
dpi: int = 100 dpi: int = 100
@ -21,7 +19,7 @@ class BasePlot:
targets: TargetProtocol, targets: TargetProtocol,
label: str = "plot", label: str = "plot",
figsize: tuple[int, int] = (10, 10), figsize: tuple[int, int] = (10, 10),
title: Optional[str] = None, title: str | None = None,
dpi: int = 100, dpi: int = 100,
theme: str = "default", theme: str = "default",
): ):

View File

@ -3,10 +3,8 @@ from typing import (
Callable, Callable,
Iterable, Iterable,
Literal, Literal,
Optional,
Sequence, Sequence,
Tuple, Tuple,
Union,
) )
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -14,7 +12,7 @@ from matplotlib.figure import Figure
from pydantic import Field from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.core import Registry from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.classification import ( from batdetect2.evaluate.metrics.classification import (
ClipEval, ClipEval,
_extract_per_class_metric_data, _extract_per_class_metric_data,
@ -31,7 +29,7 @@ from batdetect2.plotting.metrics import (
plot_threshold_recall_curve, plot_threshold_recall_curve,
plot_threshold_recall_curves, plot_threshold_recall_curves,
) )
from batdetect2.typing import TargetProtocol from batdetect2.targets.types import TargetProtocol
ClassificationPlotter = Callable[ ClassificationPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]] [Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
@ -42,10 +40,21 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
) )
@add_import_config(classification_plots)
class ClassificationPlotImportConfig(ImportConfig):
"""Use any callable as a classification plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"
title: Optional[str] = "Classification Precision-Recall Curve" title: str | None = "Classification Precision-Recall Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
separate_figures: bool = False separate_figures: bool = False
@ -88,9 +97,7 @@ class PRCurve(BasePlot):
ax = plot_pr_curve(precision, recall, thresholds, ax=ax) ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
ax.set_title(class_name) ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig yield f"{self.label}/{class_name}", fig
plt.close(fig) plt.close(fig)
@classification_plots.register(PRCurveConfig) @classification_plots.register(PRCurveConfig)
@ -108,7 +115,7 @@ class PRCurve(BasePlot):
class ThresholdPrecisionCurveConfig(BasePlotConfig): class ThresholdPrecisionCurveConfig(BasePlotConfig):
name: Literal["threshold_precision_curve"] = "threshold_precision_curve" name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
label: str = "threshold_precision_curve" label: str = "threshold_precision_curve"
title: Optional[str] = "Classification Threshold-Precision Curve" title: str | None = "Classification Threshold-Precision Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
separate_figures: bool = False separate_figures: bool = False
@ -181,7 +188,7 @@ class ThresholdPrecisionCurve(BasePlot):
class ThresholdRecallCurveConfig(BasePlotConfig): class ThresholdRecallCurveConfig(BasePlotConfig):
name: Literal["threshold_recall_curve"] = "threshold_recall_curve" name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
label: str = "threshold_recall_curve" label: str = "threshold_recall_curve"
title: Optional[str] = "Classification Threshold-Recall Curve" title: str | None = "Classification Threshold-Recall Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
separate_figures: bool = False separate_figures: bool = False
@ -254,7 +261,7 @@ class ThresholdRecallCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig): class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve" name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve" label: str = "roc_curve"
title: Optional[str] = "Classification ROC Curve" title: str | None = "Classification ROC Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
separate_figures: bool = False separate_figures: bool = False
@ -326,12 +333,10 @@ class ROCCurve(BasePlot):
ClassificationPlotConfig = Annotated[ ClassificationPlotConfig = Annotated[
Union[ PRCurveConfig
PRCurveConfig, | ROCCurveConfig
ROCCurveConfig, | ThresholdPrecisionCurveConfig
ThresholdPrecisionCurveConfig, | ThresholdRecallCurveConfig,
ThresholdRecallCurveConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -3,10 +3,8 @@ from typing import (
Callable, Callable,
Iterable, Iterable,
Literal, Literal,
Optional,
Sequence, Sequence,
Tuple, Tuple,
Union,
) )
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -14,7 +12,7 @@ from matplotlib.figure import Figure
from pydantic import Field from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.core import Registry from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.clip_classification import ClipEval from batdetect2.evaluate.metrics.clip_classification import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
@ -24,10 +22,11 @@ from batdetect2.plotting.metrics import (
plot_roc_curve, plot_roc_curve,
plot_roc_curves, plot_roc_curves,
) )
from batdetect2.typing import TargetProtocol from batdetect2.targets.types import TargetProtocol
__all__ = [ __all__ = [
"ClipClassificationPlotConfig", "ClipClassificationPlotConfig",
"ClipClassificationPlotImportConfig",
"ClipClassificationPlotter", "ClipClassificationPlotter",
"build_clip_classification_plotter", "build_clip_classification_plotter",
] ]
@ -41,10 +40,21 @@ clip_classification_plots: Registry[
] = Registry("clip_classification_plot") ] = Registry("clip_classification_plot")
@add_import_config(clip_classification_plots)
class ClipClassificationPlotImportConfig(ImportConfig):
"""Use any callable as a clip classification plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"
title: Optional[str] = "Clip Classification Precision-Recall Curve" title: str | None = "Clip Classification Precision-Recall Curve"
separate_figures: bool = False separate_figures: bool = False
@ -111,7 +121,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig): class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve" name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve" label: str = "roc_curve"
title: Optional[str] = "Clip Classification ROC Curve" title: str | None = "Clip Classification ROC Curve"
separate_figures: bool = False separate_figures: bool = False
@ -174,10 +184,7 @@ class ROCCurve(BasePlot):
ClipClassificationPlotConfig = Annotated[ ClipClassificationPlotConfig = Annotated[
Union[ PRCurveConfig | ROCCurveConfig,
PRCurveConfig,
ROCCurveConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -3,10 +3,8 @@ from typing import (
Callable, Callable,
Iterable, Iterable,
Literal, Literal,
Optional,
Sequence, Sequence,
Tuple, Tuple,
Union,
) )
import pandas as pd import pandas as pd
@ -15,15 +13,16 @@ from matplotlib.figure import Figure
from pydantic import Field from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.core import Registry from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.clip_detection import ClipEval from batdetect2.evaluate.metrics.clip_detection import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol from batdetect2.targets.types import TargetProtocol
__all__ = [ __all__ = [
"ClipDetectionPlotConfig", "ClipDetectionPlotConfig",
"ClipDetectionPlotImportConfig",
"ClipDetectionPlotter", "ClipDetectionPlotter",
"build_clip_detection_plotter", "build_clip_detection_plotter",
] ]
@ -38,10 +37,21 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
) )
@add_import_config(clip_detection_plots)
class ClipDetectionPlotImportConfig(ImportConfig):
"""Use any callable as a clip detection plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"
title: Optional[str] = "Clip Detection Precision-Recall Curve" title: str | None = "Clip Detection Precision-Recall Curve"
class PRCurve(BasePlot): class PRCurve(BasePlot):
@ -74,7 +84,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig): class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve" name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve" label: str = "roc_curve"
title: Optional[str] = "Clip Detection ROC Curve" title: str | None = "Clip Detection ROC Curve"
class ROCCurve(BasePlot): class ROCCurve(BasePlot):
@ -107,7 +117,7 @@ class ROCCurve(BasePlot):
class ScoreDistributionPlotConfig(BasePlotConfig): class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution" name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution" label: str = "score_distribution"
title: Optional[str] = "Clip Detection Score Distribution" title: str | None = "Clip Detection Score Distribution"
class ScoreDistributionPlot(BasePlot): class ScoreDistributionPlot(BasePlot):
@ -147,11 +157,7 @@ class ScoreDistributionPlot(BasePlot):
ClipDetectionPlotConfig = Annotated[ ClipDetectionPlotConfig = Annotated[
Union[ PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig,
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -4,10 +4,8 @@ from typing import (
Callable, Callable,
Iterable, Iterable,
Literal, Literal,
Optional,
Sequence, Sequence,
Tuple, Tuple,
Union,
) )
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -18,14 +16,16 @@ from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry from batdetect2.audio.types import AudioLoader
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.detections import plot_clip_detections from batdetect2.plotting.detections import plot_clip_detections
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]] DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
@ -34,10 +34,21 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
) )
@add_import_config(detection_plots)
class DetectionPlotImportConfig(ImportConfig):
"""Use any callable as a detection plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"
title: Optional[str] = "Detection Precision-Recall Curve" title: str | None = "Detection Precision-Recall Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
@ -100,7 +111,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig): class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve" name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve" label: str = "roc_curve"
title: Optional[str] = "Detection ROC Curve" title: str | None = "Detection ROC Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
@ -159,7 +170,7 @@ class ROCCurve(BasePlot):
class ScoreDistributionPlotConfig(BasePlotConfig): class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution" name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution" label: str = "score_distribution"
title: Optional[str] = "Detection Score Distribution" title: str | None = "Detection Score Distribution"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
@ -226,7 +237,7 @@ class ScoreDistributionPlot(BasePlot):
class ExampleDetectionPlotConfig(BasePlotConfig): class ExampleDetectionPlotConfig(BasePlotConfig):
name: Literal["example_detection"] = "example_detection" name: Literal["example_detection"] = "example_detection"
label: str = "example_detection" label: str = "example_detection"
title: Optional[str] = "Example Detection" title: str | None = "Example Detection"
figsize: tuple[int, int] = (10, 4) figsize: tuple[int, int] = (10, 4)
num_examples: int = 5 num_examples: int = 5
threshold: float = 0.2 threshold: float = 0.2
@ -292,12 +303,10 @@ class ExampleDetectionPlot(BasePlot):
DetectionPlotConfig = Annotated[ DetectionPlotConfig = Annotated[
Union[ PRCurveConfig
PRCurveConfig, | ROCCurveConfig
ROCCurveConfig, | ScoreDistributionPlotConfig
ScoreDistributionPlotConfig, | ExampleDetectionPlotConfig,
ExampleDetectionPlotConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]

View File

@ -4,14 +4,9 @@ from dataclasses import dataclass, field
from typing import ( from typing import (
Annotated, Annotated,
Callable, Callable,
Dict,
Iterable, Iterable,
List,
Literal, Literal,
Optional,
Sequence, Sequence,
Tuple,
Union,
) )
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -21,7 +16,8 @@ from pydantic import Field
from sklearn import metrics from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry from batdetect2.audio.types import AudioLoader
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.common import compute_precision_recall from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import ( from batdetect2.evaluate.metrics.top_class import (
ClipEval, ClipEval,
@ -32,19 +28,31 @@ from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.gallery import plot_match_gallery from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]] TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]]
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry( top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
name="top_class_plot" name="top_class_plot"
) )
@add_import_config(top_class_plots)
class TopClassPlotImportConfig(ImportConfig):
"""Use any callable as a top-class plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig): class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve" name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve" label: str = "pr_curve"
title: Optional[str] = "Top Class Precision-Recall Curve" title: str | None = "Top Class Precision-Recall Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
@ -64,7 +72,7 @@ class PRCurve(BasePlot):
def __call__( def __call__(
self, self,
clip_evaluations: Sequence[ClipEval], clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[tuple[str, Figure]]:
y_true = [] y_true = []
y_score = [] y_score = []
num_positives = 0 num_positives = 0
@ -111,7 +119,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig): class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve" name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve" label: str = "roc_curve"
title: Optional[str] = "Top Class ROC Curve" title: str | None = "Top Class ROC Curve"
ignore_non_predictions: bool = True ignore_non_predictions: bool = True
ignore_generic: bool = True ignore_generic: bool = True
@ -131,7 +139,7 @@ class ROCCurve(BasePlot):
def __call__( def __call__(
self, self,
clip_evaluations: Sequence[ClipEval], clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[tuple[str, Figure]]:
y_true = [] y_true = []
y_score = [] y_score = []
@ -173,7 +181,7 @@ class ROCCurve(BasePlot):
class ConfusionMatrixConfig(BasePlotConfig): class ConfusionMatrixConfig(BasePlotConfig):
name: Literal["confusion_matrix"] = "confusion_matrix" name: Literal["confusion_matrix"] = "confusion_matrix"
title: Optional[str] = "Top Class Confusion Matrix" title: str | None = "Top Class Confusion Matrix"
figsize: tuple[int, int] = (10, 10) figsize: tuple[int, int] = (10, 10)
label: str = "confusion_matrix" label: str = "confusion_matrix"
exclude_generic: bool = True exclude_generic: bool = True
@ -214,7 +222,7 @@ class ConfusionMatrix(BasePlot):
def __call__( def __call__(
self, self,
clip_evaluations: Sequence[ClipEval], clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[tuple[str, Figure]]:
cm, labels = compute_confusion_matrix( cm, labels = compute_confusion_matrix(
clip_evaluations, clip_evaluations,
self.targets, self.targets,
@ -257,7 +265,7 @@ class ConfusionMatrix(BasePlot):
class ExampleClassificationPlotConfig(BasePlotConfig): class ExampleClassificationPlotConfig(BasePlotConfig):
name: Literal["example_classification"] = "example_classification" name: Literal["example_classification"] = "example_classification"
label: str = "example_classification" label: str = "example_classification"
title: Optional[str] = "Example Classification" title: str | None = "Example Classification"
num_examples: int = 4 num_examples: int = 4
threshold: float = 0.2 threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig) audio: AudioConfig = Field(default_factory=AudioConfig)
@ -286,26 +294,26 @@ class ExampleClassificationPlot(BasePlot):
def __call__( def __call__(
self, self,
clip_evaluations: Sequence[ClipEval], clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[tuple[str, Figure]]:
grouped = group_matches(clip_evaluations, threshold=self.threshold) grouped = group_matches(clip_evaluations, threshold=self.threshold)
for class_name, matches in grouped.items(): for class_name, matches in grouped.items():
true_positives: List[MatchEval] = get_binned_sample( true_positives: list[MatchEval] = get_binned_sample(
matches.true_positives, matches.true_positives,
n_examples=self.num_examples, n_examples=self.num_examples,
) )
false_positives: List[MatchEval] = get_binned_sample( false_positives: list[MatchEval] = get_binned_sample(
matches.false_positives, matches.false_positives,
n_examples=self.num_examples, n_examples=self.num_examples,
) )
false_negatives: List[MatchEval] = random.sample( false_negatives: list[MatchEval] = random.sample(
matches.false_negatives, matches.false_negatives,
k=min(self.num_examples, len(matches.false_negatives)), k=min(self.num_examples, len(matches.false_negatives)),
) )
cross_triggers: List[MatchEval] = get_binned_sample( cross_triggers: list[MatchEval] = get_binned_sample(
matches.cross_triggers, n_examples=self.num_examples matches.cross_triggers, n_examples=self.num_examples
) )
@ -348,12 +356,10 @@ class ExampleClassificationPlot(BasePlot):
TopClassPlotConfig = Annotated[ TopClassPlotConfig = Annotated[
Union[ PRCurveConfig
PRCurveConfig, | ROCCurveConfig
ROCCurveConfig, | ConfusionMatrixConfig
ConfusionMatrixConfig, | ExampleClassificationPlotConfig,
ExampleClassificationPlotConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
@ -367,16 +373,16 @@ def build_top_class_plotter(
@dataclass @dataclass
class ClassMatches: class ClassMatches:
false_positives: List[MatchEval] = field(default_factory=list) false_positives: list[MatchEval] = field(default_factory=list)
false_negatives: List[MatchEval] = field(default_factory=list) false_negatives: list[MatchEval] = field(default_factory=list)
true_positives: List[MatchEval] = field(default_factory=list) true_positives: list[MatchEval] = field(default_factory=list)
cross_triggers: List[MatchEval] = field(default_factory=list) cross_triggers: list[MatchEval] = field(default_factory=list)
def group_matches( def group_matches(
clip_evals: Sequence[ClipEval], clip_evals: Sequence[ClipEval],
threshold: float = 0.2, threshold: float = 0.2,
) -> Dict[str, ClassMatches]: ) -> dict[str, ClassMatches]:
class_examples = defaultdict(ClassMatches) class_examples = defaultdict(ClassMatches)
for clip_eval in clip_evals: for clip_eval in clip_evals:
@ -405,12 +411,13 @@ def group_matches(
return class_examples return class_examples
def get_binned_sample(matches: List[MatchEval], n_examples: int = 5): def get_binned_sample(matches: list[MatchEval], n_examples: int = 5):
if len(matches) < n_examples: if len(matches) < n_examples:
return matches return matches
indices, pred_scores = zip( indices, pred_scores = zip(
*[(index, match.score) for index, match in enumerate(matches)] *[(index, match.score) for index, match in enumerate(matches)],
strict=False,
) )
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop") bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")

View File

@ -0,0 +1,27 @@
import json
from pathlib import Path
from typing import Iterable
from matplotlib.figure import Figure
from soundevent import data
__all__ = ["save_evaluation_results"]
def save_evaluation_results(
metrics: dict[str, float],
plots: Iterable[tuple[str, Figure]],
output_dir: data.PathLike,
) -> None:
"""Save evaluation metrics and plots to disk."""
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
metrics_path = output_path / "metrics.json"
metrics_path.write_text(json.dumps(metrics))
for figure_name, figure in plots:
figure_path = output_path / figure_name
figure_path.parent.mkdir(parents=True, exist_ok=True)
figure.savefig(figure_path)

View File

@ -1,106 +0,0 @@
from typing import Annotated, Callable, Literal, Sequence, Union
import pandas as pd
from pydantic import Field
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipMatches
EvaluationTableGenerator = Callable[[Sequence[ClipMatches]], pd.DataFrame]
tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
"evaluation_table"
)
class FullEvaluationTableConfig(BaseConfig):
name: Literal["full_evaluation"] = "full_evaluation"
class FullEvaluationTable:
def __call__(
self, clip_evaluations: Sequence[ClipMatches]
) -> pd.DataFrame:
return extract_matches_dataframe(clip_evaluations)
@tables_registry.register(FullEvaluationTableConfig)
@staticmethod
def from_config(config: FullEvaluationTableConfig):
return FullEvaluationTable()
def extract_matches_dataframe(
clip_evaluations: Sequence[ClipMatches],
) -> pd.DataFrame:
data = []
for clip_evaluation in clip_evaluations:
for match in clip_evaluation.matches:
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
pred_start_time = pred_low_freq = pred_end_time = (
pred_high_freq
) = None
sound_event_annotation = match.sound_event_annotation
if sound_event_annotation is not None:
geometry = sound_event_annotation.sound_event.geometry
assert geometry is not None
gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
compute_bounds(geometry)
)
if match.pred_geometry is not None:
(
pred_start_time,
pred_low_freq,
pred_end_time,
pred_high_freq,
) = compute_bounds(match.pred_geometry)
data.append(
{
("recording", "uuid"): match.clip.recording.uuid,
("clip", "uuid"): match.clip.uuid,
("clip", "start_time"): match.clip.start_time,
("clip", "end_time"): match.clip.end_time,
("gt", "uuid"): match.sound_event_annotation.uuid
if match.sound_event_annotation is not None
else None,
("gt", "class"): match.gt_class,
("gt", "det"): match.gt_det,
("gt", "start_time"): gt_start_time,
("gt", "end_time"): gt_end_time,
("gt", "low_freq"): gt_low_freq,
("gt", "high_freq"): gt_high_freq,
("pred", "score"): match.pred_score,
("pred", "class"): match.top_class,
("pred", "class_score"): match.top_class_score,
("pred", "start_time"): pred_start_time,
("pred", "end_time"): pred_end_time,
("pred", "low_freq"): pred_low_freq,
("pred", "high_freq"): pred_high_freq,
("match", "affinity"): match.affinity,
**{
("pred_class_score", key): value
for key, value in match.pred_class_scores.items()
},
}
)
df = pd.DataFrame(data)
df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore
return df
EvaluationTableConfig = Annotated[
Union[FullEvaluationTableConfig,], Field(discriminator="name")
]
def build_table_generator(
config: EvaluationTableConfig,
) -> EvaluationTableGenerator:
return tables_registry.build(config)

View File

@ -1,4 +1,4 @@
from typing import Annotated, Optional, Sequence, Union from typing import Annotated, Sequence
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
@ -11,12 +11,10 @@ from batdetect2.evaluate.tasks.clip_classification import (
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.evaluate.types import EvaluationTaskProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
from batdetect2.typing import ( from batdetect2.targets.types import TargetProtocol
BatDetect2Prediction,
EvaluatorProtocol,
TargetProtocol,
)
__all__ = [ __all__ = [
"TaskConfig", "TaskConfig",
@ -26,31 +24,29 @@ __all__ = [
TaskConfig = Annotated[ TaskConfig = Annotated[
Union[ ClassificationTaskConfig
ClassificationTaskConfig, | DetectionTaskConfig
DetectionTaskConfig, | ClipDetectionTaskConfig
ClipDetectionTaskConfig, | ClipClassificationTaskConfig
ClipClassificationTaskConfig, | TopClassDetectionTaskConfig,
TopClassDetectionTaskConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
def build_task( def build_task(
config: TaskConfig, config: TaskConfig,
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
) -> EvaluatorProtocol: ) -> EvaluationTaskProtocol:
targets = targets or build_targets() targets = targets or build_targets()
return tasks_registry.build(config, targets) return tasks_registry.build(config, targets)
def evaluate_task( def evaluate_task(
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
task: Optional["str"] = None, task: str | None = None,
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
config: Optional[Union[TaskConfig, dict]] = None, config: TaskConfig | dict | None = None,
): ):
if isinstance(config, BaseTaskConfig): if isinstance(config, BaseTaskConfig):
task_obj = build_task(config, targets) task_obj = build_task(config, targets)

View File

@ -4,78 +4,93 @@ from typing import (
Generic, Generic,
Iterable, Iterable,
List, List,
Optional, Literal,
Sequence, Sequence,
Tuple, Tuple,
TypeVar, TypeVar,
) )
from loguru import logger
from matplotlib.figure import Figure from matplotlib.figure import Figure
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig from batdetect2.core import (
from batdetect2.core.registries import Registry BaseConfig,
from batdetect2.evaluate.match import ( ImportConfig,
MatchConfig, Registry,
StartTimeMatchConfig, add_import_config,
build_matcher,
) )
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol from batdetect2.evaluate.affinity import (
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction AffinityConfig,
from batdetect2.typing.targets import TargetProtocol TimeAffinityConfig,
build_affinity_function,
)
from batdetect2.evaluate.types import (
AffinityFunction,
EvaluationTaskProtocol,
)
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
__all__ = [ __all__ = [
"BaseTaskConfig", "BaseTaskConfig",
"BaseTask", "BaseTask",
"TaskImportConfig",
] ]
tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry( tasks_registry: Registry[EvaluationTaskProtocol, [TargetProtocol]] = Registry(
"tasks" "tasks"
) )
@add_import_config(tasks_registry)
class TaskImportConfig(ImportConfig):
"""Use any callable as an evaluation task.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
T_Output = TypeVar("T_Output") T_Output = TypeVar("T_Output")
class BaseTaskConfig(BaseConfig): class BaseTaskConfig(BaseConfig):
prefix: str prefix: str
ignore_start_end: float = 0.01 ignore_start_end: float = 0.01
matching_strategy: MatchConfig = Field(
default_factory=StartTimeMatchConfig
)
class BaseTask(EvaluatorProtocol, Generic[T_Output]): class BaseTask(EvaluationTaskProtocol, Generic[T_Output]):
targets: TargetProtocol targets: TargetProtocol
matcher: MatcherProtocol
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]] metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]] plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
ignore_start_end: float
prefix: str prefix: str
ignore_start_end: float
def __init__( def __init__(
self, self,
matcher: MatcherProtocol,
targets: TargetProtocol, targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
prefix: str, prefix: str,
plots: List[
Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
]
| None = None,
ignore_start_end: float = 0.01, ignore_start_end: float = 0.01,
plots: Optional[
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
): ):
self.matcher = matcher self.prefix = prefix
self.targets = targets
self.metrics = metrics self.metrics = metrics
self.plots = plots or [] self.plots = plots or []
self.targets = targets
self.prefix = prefix
self.ignore_start_end = ignore_start_end self.ignore_start_end = ignore_start_end
def compute_metrics( def compute_metrics(
@ -93,24 +108,30 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
self, eval_outputs: List[T_Output] self, eval_outputs: List[T_Output]
) -> Iterable[Tuple[str, Figure]]: ) -> Iterable[Tuple[str, Figure]]:
for plot in self.plots: for plot in self.plots:
for name, fig in plot(eval_outputs): try:
yield f"{self.prefix}/{name}", fig for name, fig in plot(eval_outputs):
yield f"{self.prefix}/{name}", fig
except Exception as e:
logger.error(f"Error plotting {self.prefix}: {e}")
continue
def evaluate( def evaluate(
self, self,
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction], predictions: Sequence[ClipDetections],
) -> List[T_Output]: ) -> List[T_Output]:
return [ return [
self.evaluate_clip(clip_annotation, preds) self.evaluate_clip(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions) for clip_annotation, preds in zip(
clip_annotations, predictions, strict=False
)
] ]
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction, prediction: ClipDetections,
) -> T_Output: ... ) -> T_Output: ... # ty: ignore[empty-body]
def include_sound_event_annotation( def include_sound_event_annotation(
self, self,
@ -121,9 +142,6 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
return False return False
geometry = sound_event_annotation.sound_event.geometry geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
return is_in_bounds( return is_in_bounds(
geometry, geometry,
clip, clip,
@ -132,7 +150,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
def include_prediction( def include_prediction(
self, self,
prediction: RawPrediction, prediction: Detection,
clip: data.Clip, clip: data.Clip,
) -> bool: ) -> bool:
return is_in_bounds( return is_in_bounds(
@ -141,25 +159,56 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
self.ignore_start_end, self.ignore_start_end,
) )
class BaseSEDTaskConfig(BaseTaskConfig):
affinity: AffinityConfig = Field(default_factory=TimeAffinityConfig)
affinity_threshold: float = 0
strict_match: bool = True
class BaseSEDTask(BaseTask[T_Output]):
affinity: AffinityFunction
def __init__(
self,
prefix: str,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
affinity: AffinityFunction,
plots: List[
Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
]
| None = None,
affinity_threshold: float = 0,
ignore_start_end: float = 0.01,
strict_match: bool = True,
):
super().__init__(
prefix=prefix,
metrics=metrics,
plots=plots,
targets=targets,
ignore_start_end=ignore_start_end,
)
self.affinity = affinity
self.affinity_threshold = affinity_threshold
self.strict_match = strict_match
@classmethod @classmethod
def build( def build(
cls, cls,
config: BaseTaskConfig, config: BaseSEDTaskConfig,
targets: TargetProtocol, targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
plots: Optional[
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
**kwargs, **kwargs,
): ):
matcher = build_matcher(config.matching_strategy) affinity = build_affinity_function(config.affinity)
return cls( return cls(
matcher=matcher, affinity=affinity,
targets=targets, affinity_threshold=config.affinity_threshold,
metrics=metrics,
plots=plots,
prefix=config.prefix, prefix=config.prefix,
ignore_start_end=config.ignore_start_end, ignore_start_end=config.ignore_start_end,
strict_match=config.strict_match,
targets=targets,
**kwargs, **kwargs,
) )

View File

@ -1,10 +1,9 @@
from typing import ( from functools import partial
List, from typing import Literal
Literal,
)
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.metrics.classification import ( from batdetect2.evaluate.metrics.classification import (
ClassificationAveragePrecisionConfig, ClassificationAveragePrecisionConfig,
@ -18,24 +17,25 @@ from batdetect2.evaluate.plots.classification import (
build_classification_plotter, build_classification_plotter,
) )
from batdetect2.evaluate.tasks.base import ( from batdetect2.evaluate.tasks.base import (
BaseTask, BaseSEDTask,
BaseTaskConfig, BaseSEDTaskConfig,
tasks_registry, tasks_registry,
) )
from batdetect2.typing import BatDetect2Prediction, TargetProtocol from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
class ClassificationTaskConfig(BaseTaskConfig): class ClassificationTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_classification"] = "sound_event_classification" name: Literal["sound_event_classification"] = "sound_event_classification"
prefix: str = "classification" prefix: str = "classification"
metrics: List[ClassificationMetricConfig] = Field( metrics: list[ClassificationMetricConfig] = Field(
default_factory=lambda: [ClassificationAveragePrecisionConfig()] default_factory=lambda: [ClassificationAveragePrecisionConfig()]
) )
plots: List[ClassificationPlotConfig] = Field(default_factory=list) plots: list[ClassificationPlotConfig] = Field(default_factory=list)
include_generics: bool = True include_generics: bool = True
class ClassificationTask(BaseTask[ClipEval]): class ClassificationTask(BaseSEDTask[ClipEval]):
def __init__( def __init__(
self, self,
*args, *args,
@ -48,13 +48,13 @@ class ClassificationTask(BaseTask[ClipEval]):
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction, prediction: ClipDetections,
) -> ClipEval: ) -> ClipEval:
clip = clip_annotation.clip clip = clip_annotation.clip
preds = [ preds = [
pred pred
for pred in prediction.predictions for pred in prediction.detections
if self.include_prediction(pred, clip) if self.include_prediction(pred, clip)
] ]
@ -73,40 +73,40 @@ class ClassificationTask(BaseTask[ClipEval]):
gts = [ gts = [
sound_event sound_event
for sound_event in all_gts for sound_event in all_gts
if self.is_class(sound_event, class_name) if is_target_class(
sound_event,
class_name,
self.targets,
include_generics=self.include_generics,
)
] ]
scores = [float(pred.class_scores[class_idx]) for pred in preds]
matches = [] matches = []
for pred_idx, gt_idx, _ in self.matcher( for match in match_detections_and_gts(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore detections=preds,
predictions=[pred.geometry for pred in preds], ground_truths=gts,
scores=scores, affinity=self.affinity,
score=partial(get_class_score, class_idx=class_idx),
strict_match=self.strict_match,
affinity_threshold=self.affinity_threshold,
): ):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
true_class = ( true_class = (
self.targets.encode_class(gt) if gt is not None else None self.targets.encode_class(match.annotation)
if match.annotation is not None
else None
) )
score = (
float(pred.class_scores[class_idx])
if pred is not None
else 0
)
matches.append( matches.append(
MatchEval( MatchEval(
clip=clip, clip=clip,
gt=gt, gt=match.annotation,
pred=pred, pred=match.prediction,
is_prediction=pred is not None, is_prediction=match.prediction is not None,
is_ground_truth=gt is not None, is_ground_truth=match.annotation is not None,
is_generic=gt is not None and true_class is None, is_generic=match.annotation is not None
and true_class is None,
true_class=true_class, true_class=true_class,
score=score, score=match.prediction_score,
) )
) )
@ -114,20 +114,6 @@ class ClassificationTask(BaseTask[ClipEval]):
return ClipEval(clip=clip, matches=per_class_matches) return ClipEval(clip=clip, matches=per_class_matches)
def is_class(
self,
sound_event: data.SoundEventAnnotation,
class_name: str,
) -> bool:
sound_event_class = self.targets.encode_class(sound_event)
if sound_event_class is None and self.include_generics:
# Sound events that are generic could be of the given
# class
return True
return sound_event_class == class_name
@tasks_registry.register(ClassificationTaskConfig) @tasks_registry.register(ClassificationTaskConfig)
@staticmethod @staticmethod
def from_config( def from_config(
@ -147,4 +133,25 @@ class ClassificationTask(BaseTask[ClipEval]):
plots=plots, plots=plots,
targets=targets, targets=targets,
metrics=metrics, metrics=metrics,
include_generics=config.include_generics,
) )
def get_class_score(pred: Detection, class_idx: int) -> float:
return pred.class_scores[class_idx]
def is_target_class(
sound_event: data.SoundEventAnnotation,
class_name: str,
targets: TargetProtocol,
include_generics: bool = True,
) -> bool:
sound_event_class = targets.encode_class(sound_event)
if sound_event_class is None and include_generics:
# Sound events that are generic could be of the given
# class
return True
return sound_event_class == class_name

View File

@ -1,5 +1,5 @@
from collections import defaultdict from collections import defaultdict
from typing import List, Literal from typing import Literal
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
@ -19,26 +19,26 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig, BaseTaskConfig,
tasks_registry, tasks_registry,
) )
from batdetect2.typing import TargetProtocol from batdetect2.postprocess.types import ClipDetections
from batdetect2.typing.postprocess import BatDetect2Prediction from batdetect2.targets.types import TargetProtocol
class ClipClassificationTaskConfig(BaseTaskConfig): class ClipClassificationTaskConfig(BaseTaskConfig):
name: Literal["clip_classification"] = "clip_classification" name: Literal["clip_classification"] = "clip_classification"
prefix: str = "clip_classification" prefix: str = "clip_classification"
metrics: List[ClipClassificationMetricConfig] = Field( metrics: list[ClipClassificationMetricConfig] = Field(
default_factory=lambda: [ default_factory=lambda: [
ClipClassificationAveragePrecisionConfig(), ClipClassificationAveragePrecisionConfig(),
] ]
) )
plots: List[ClipClassificationPlotConfig] = Field(default_factory=list) plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
class ClipClassificationTask(BaseTask[ClipEval]): class ClipClassificationTask(BaseTask[ClipEval]):
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction, prediction: ClipDetections,
) -> ClipEval: ) -> ClipEval:
clip = clip_annotation.clip clip = clip_annotation.clip
@ -55,7 +55,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
gt_classes.add(class_name) gt_classes.add(class_name)
pred_scores = defaultdict(float) pred_scores = defaultdict(float)
for pred in prediction.predictions: for pred in prediction.detections:
if not self.include_prediction(pred, clip): if not self.include_prediction(pred, clip):
continue continue
@ -78,8 +78,8 @@ class ClipClassificationTask(BaseTask[ClipEval]):
build_clip_classification_plotter(plot, targets) build_clip_classification_plotter(plot, targets)
for plot in config.plots for plot in config.plots
] ]
return ClipClassificationTask.build( return ClipClassificationTask(
config=config, prefix=config.prefix,
plots=plots, plots=plots,
metrics=metrics, metrics=metrics,
targets=targets, targets=targets,

View File

@ -1,4 +1,4 @@
from typing import List, Literal from typing import Literal
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
@ -18,26 +18,26 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig, BaseTaskConfig,
tasks_registry, tasks_registry,
) )
from batdetect2.typing import TargetProtocol from batdetect2.postprocess.types import ClipDetections
from batdetect2.typing.postprocess import BatDetect2Prediction from batdetect2.targets.types import TargetProtocol
class ClipDetectionTaskConfig(BaseTaskConfig): class ClipDetectionTaskConfig(BaseTaskConfig):
name: Literal["clip_detection"] = "clip_detection" name: Literal["clip_detection"] = "clip_detection"
prefix: str = "clip_detection" prefix: str = "clip_detection"
metrics: List[ClipDetectionMetricConfig] = Field( metrics: list[ClipDetectionMetricConfig] = Field(
default_factory=lambda: [ default_factory=lambda: [
ClipDetectionAveragePrecisionConfig(), ClipDetectionAveragePrecisionConfig(),
] ]
) )
plots: List[ClipDetectionPlotConfig] = Field(default_factory=list) plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
class ClipDetectionTask(BaseTask[ClipEval]): class ClipDetectionTask(BaseTask[ClipEval]):
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction, prediction: ClipDetections,
) -> ClipEval: ) -> ClipEval:
clip = clip_annotation.clip clip = clip_annotation.clip
@ -47,7 +47,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
) )
pred_score = 0 pred_score = 0
for pred in prediction.predictions: for pred in prediction.detections:
if not self.include_prediction(pred, clip): if not self.include_prediction(pred, clip):
continue continue
@ -69,8 +69,8 @@ class ClipDetectionTask(BaseTask[ClipEval]):
build_clip_detection_plotter(plot, targets) build_clip_detection_plotter(plot, targets)
for plot in config.plots for plot in config.plots
] ]
return ClipDetectionTask.build( return ClipDetectionTask(
config=config, prefix=config.prefix,
metrics=metrics, metrics=metrics,
targets=targets, targets=targets,
plots=plots, plots=plots,

View File

@ -1,7 +1,8 @@
from typing import List, Literal from typing import Literal
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.metrics.detection import ( from batdetect2.evaluate.metrics.detection import (
ClipEval, ClipEval,
@ -15,28 +16,28 @@ from batdetect2.evaluate.plots.detection import (
build_detection_plotter, build_detection_plotter,
) )
from batdetect2.evaluate.tasks.base import ( from batdetect2.evaluate.tasks.base import (
BaseTask, BaseSEDTask,
BaseTaskConfig, BaseSEDTaskConfig,
tasks_registry, tasks_registry,
) )
from batdetect2.typing import TargetProtocol from batdetect2.postprocess.types import ClipDetections
from batdetect2.typing.postprocess import BatDetect2Prediction from batdetect2.targets.types import TargetProtocol
class DetectionTaskConfig(BaseTaskConfig): class DetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_detection"] = "sound_event_detection" name: Literal["sound_event_detection"] = "sound_event_detection"
prefix: str = "detection" prefix: str = "detection"
metrics: List[DetectionMetricConfig] = Field( metrics: list[DetectionMetricConfig] = Field(
default_factory=lambda: [DetectionAveragePrecisionConfig()] default_factory=lambda: [DetectionAveragePrecisionConfig()]
) )
plots: List[DetectionPlotConfig] = Field(default_factory=list) plots: list[DetectionPlotConfig] = Field(default_factory=list)
class DetectionTask(BaseTask[ClipEval]): class DetectionTask(BaseSEDTask[ClipEval]):
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction, prediction: ClipDetections,
) -> ClipEval: ) -> ClipEval:
clip = clip_annotation.clip clip = clip_annotation.clip
@ -47,27 +48,26 @@ class DetectionTask(BaseTask[ClipEval]):
] ]
preds = [ preds = [
pred pred
for pred in prediction.predictions for pred in prediction.detections
if self.include_prediction(pred, clip) if self.include_prediction(pred, clip)
] ]
scores = [pred.detection_score for pred in preds]
matches = [] matches = []
for pred_idx, gt_idx, _ in self.matcher( for match in match_detections_and_gts(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore detections=preds,
predictions=[pred.geometry for pred in preds], ground_truths=gts,
scores=scores, affinity=self.affinity,
score=lambda pred: pred.detection_score,
strict_match=self.strict_match,
affinity_threshold=self.affinity_threshold,
): ):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
matches.append( matches.append(
MatchEval( MatchEval(
gt=gt, gt=match.annotation,
pred=pred, pred=match.prediction,
is_prediction=pred is not None, is_prediction=match.prediction is not None,
is_ground_truth=gt is not None, is_ground_truth=match.annotation is not None,
score=pred.detection_score if pred is not None else 0, score=match.prediction_score,
) )
) )

View File

@ -1,7 +1,8 @@
from typing import List, Literal from typing import Literal
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.metrics.top_class import ( from batdetect2.evaluate.metrics.top_class import (
ClipEval, ClipEval,
@ -15,28 +16,28 @@ from batdetect2.evaluate.plots.top_class import (
build_top_class_plotter, build_top_class_plotter,
) )
from batdetect2.evaluate.tasks.base import ( from batdetect2.evaluate.tasks.base import (
BaseTask, BaseSEDTask,
BaseTaskConfig, BaseSEDTaskConfig,
tasks_registry, tasks_registry,
) )
from batdetect2.typing import TargetProtocol from batdetect2.postprocess.types import ClipDetections
from batdetect2.typing.postprocess import BatDetect2Prediction from batdetect2.targets.types import TargetProtocol
class TopClassDetectionTaskConfig(BaseTaskConfig): class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["top_class_detection"] = "top_class_detection" name: Literal["top_class_detection"] = "top_class_detection"
prefix: str = "top_class" prefix: str = "top_class"
metrics: List[TopClassMetricConfig] = Field( metrics: list[TopClassMetricConfig] = Field(
default_factory=lambda: [TopClassAveragePrecisionConfig()] default_factory=lambda: [TopClassAveragePrecisionConfig()]
) )
plots: List[TopClassPlotConfig] = Field(default_factory=list) plots: list[TopClassPlotConfig] = Field(default_factory=list)
class TopClassDetectionTask(BaseTask[ClipEval]): class TopClassDetectionTask(BaseSEDTask[ClipEval]):
def evaluate_clip( def evaluate_clip(
self, self,
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction, prediction: ClipDetections,
) -> ClipEval: ) -> ClipEval:
clip = clip_annotation.clip clip = clip_annotation.clip
@ -47,21 +48,21 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
] ]
preds = [ preds = [
pred pred
for pred in prediction.predictions for pred in prediction.detections
if self.include_prediction(pred, clip) if self.include_prediction(pred, clip)
] ]
# Take the highest score for each prediction
scores = [pred.class_scores.max() for pred in preds]
matches = [] matches = []
for pred_idx, gt_idx, _ in self.matcher( for match in match_detections_and_gts(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore ground_truths=gts,
predictions=[pred.geometry for pred in preds], detections=preds,
scores=scores, affinity=self.affinity,
score=lambda pred: pred.class_scores.max(),
strict_match=self.strict_match,
affinity_threshold=self.affinity_threshold,
): ):
gt = gts[gt_idx] if gt_idx is not None else None gt = match.annotation
pred = preds[pred_idx] if pred_idx is not None else None pred = match.prediction
true_class = ( true_class = (
self.targets.encode_class(gt) if gt is not None else None self.targets.encode_class(gt) if gt is not None else None
) )
@ -69,11 +70,6 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
class_idx = ( class_idx = (
pred.class_scores.argmax() if pred is not None else None pred.class_scores.argmax() if pred is not None else None
) )
score = (
float(pred.class_scores[class_idx]) if pred is not None else 0
)
pred_class = ( pred_class = (
self.targets.class_names[class_idx] self.targets.class_names[class_idx]
if class_idx is not None if class_idx is not None
@ -90,7 +86,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
true_class=true_class, true_class=true_class,
is_generic=gt is not None and true_class is None, is_generic=gt is not None and true_class is None,
pred_class=pred_class, pred_class=pred_class,
score=score, score=match.prediction_score,
) )
) )

View File

@ -0,0 +1,147 @@
from dataclasses import dataclass
from typing import Generic, Iterable, Protocol, Sequence, TypeVar
from matplotlib.figure import Figure
from soundevent import data
from batdetect2.outputs.types import OutputTransformProtocol
from batdetect2.postprocess.types import (
ClipDetections,
ClipDetectionsTensor,
Detection,
)
from batdetect2.targets.types import TargetProtocol
__all__ = [
"AffinityFunction",
"ClipMatches",
"EvaluationTaskProtocol",
"EvaluatorProtocol",
"MatchEvaluation",
"MatcherProtocol",
"MetricsProtocol",
"PlotterProtocol",
]
@dataclass
class MatchEvaluation:
clip: data.Clip
sound_event_annotation: data.SoundEventAnnotation | None
gt_det: bool
gt_class: str | None
gt_geometry: data.Geometry | None
pred_score: float
pred_class_scores: dict[str, float]
pred_geometry: data.Geometry | None
affinity: float
@property
def top_class(self) -> str | None:
if not self.pred_class_scores:
return None
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
@property
def is_prediction(self) -> bool:
return self.pred_geometry is not None
@property
def is_generic(self) -> bool:
return self.gt_det and self.gt_class is None
@property
def top_class_score(self) -> float:
pred_class = self.top_class
if pred_class is None:
return 0
return self.pred_class_scores[pred_class]
@dataclass
class ClipMatches:
clip: data.Clip
matches: list[MatchEvaluation]
class MatcherProtocol(Protocol):
def __call__(
self,
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
) -> Iterable[tuple[int | None, int | None, float]]: ...
class AffinityFunction(Protocol):
def __call__(
self,
detection: Detection,
ground_truth: data.SoundEventAnnotation,
) -> float: ...
class MetricsProtocol(Protocol):
def __call__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[Detection]],
) -> dict[str, float]: ...
class PlotterProtocol(Protocol):
def __call__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[Detection]],
) -> Iterable[tuple[str, Figure]]: ...
EvaluationOutput = TypeVar("EvaluationOutput")
class EvaluationTaskProtocol(Protocol, Generic[EvaluationOutput]):
targets: TargetProtocol
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[ClipDetections],
) -> EvaluationOutput: ...
def compute_metrics(
self,
eval_outputs: EvaluationOutput,
) -> dict[str, float]: ...
def generate_plots(
self,
eval_outputs: EvaluationOutput,
) -> Iterable[tuple[str, Figure]]: ...
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
targets: TargetProtocol
transform: OutputTransformProtocol
def to_clip_detections_batch(
self,
clip_detections: Sequence[ClipDetectionsTensor],
clips: Sequence[data.Clip],
) -> list[ClipDetections]: ...
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[ClipDetections],
) -> EvaluationOutput: ...
def compute_metrics(
self,
eval_outputs: EvaluationOutput,
) -> dict[str, float]: ...
def generate_plots(
self,
eval_outputs: EvaluationOutput,
) -> Iterable[tuple[str, Figure]]: ...

View File

@ -1,7 +1,7 @@
import argparse import argparse
import os import os
import warnings import warnings
from typing import List, Optional from typing import List
import torch import torch
import torch.utils.data import torch.utils.data
@ -88,7 +88,8 @@ def select_device(warn=True) -> str:
if warn: if warn:
warnings.warn( warnings.warn(
"No GPU available, using the CPU instead. Please consider using a GPU " "No GPU available, using the CPU instead. Please consider using a GPU "
"to speed up training." "to speed up training.",
stacklevel=2,
) )
return "cpu" return "cpu"
@ -98,8 +99,8 @@ def load_annotations(
dataset_name: str, dataset_name: str,
ann_path: str, ann_path: str,
audio_path: str, audio_path: str,
classes_to_ignore: Optional[List[str]] = None, classes_to_ignore: List[str] | None = None,
events_of_interest: Optional[List[str]] = None, events_of_interest: List[str] | None = None,
) -> List[types.FileAnnotation]: ) -> List[types.FileAnnotation]:
train_sets: List[types.DatasetDict] = [] train_sets: List[types.DatasetDict] = []
train_sets.append( train_sets.append(

View File

@ -2,7 +2,6 @@ import argparse
import json import json
import os import os
from collections import Counter from collections import Counter
from typing import List, Optional, Tuple
import numpy as np import numpy as np
from sklearn.model_selection import StratifiedGroupKFold from sklearn.model_selection import StratifiedGroupKFold
@ -12,8 +11,8 @@ from batdetect2 import types
def print_dataset_stats( def print_dataset_stats(
data: List[types.FileAnnotation], data: list[types.FileAnnotation],
classes_to_ignore: Optional[List[str]] = None, classes_to_ignore: list[str] | None = None,
) -> Counter[str]: ) -> Counter[str]:
print("Num files:", len(data)) print("Num files:", len(data))
counts, _ = tu.get_class_names(data, classes_to_ignore) counts, _ = tu.get_class_names(data, classes_to_ignore)
@ -22,7 +21,7 @@ def print_dataset_stats(
return counts return counts
def load_file_names(file_name: str) -> List[str]: def load_file_names(file_name: str) -> list[str]:
if not os.path.isfile(file_name): if not os.path.isfile(file_name):
raise FileNotFoundError(f"Input file not found - {file_name}") raise FileNotFoundError(f"Input file not found - {file_name}")
@ -100,12 +99,12 @@ def parse_args():
def split_data( def split_data(
data: List[types.FileAnnotation], data: list[types.FileAnnotation],
train_file: str, train_file: str,
test_file: str, test_file: str,
n_splits: int = 5, n_splits: int = 5,
random_state: int = 0, random_state: int = 0,
) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]: ) -> tuple[list[types.FileAnnotation], list[types.FileAnnotation]]:
if train_file != "" and test_file != "": if train_file != "" and test_file != "":
# user has specifed the train / test split # user has specifed the train / test split
mapping = { mapping = {
@ -162,7 +161,7 @@ def main():
# change the names of the classes # change the names of the classes
ip_names = args.input_class_names.split(";") ip_names = args.input_class_names.split(";")
op_names = args.output_class_names.split(";") op_names = args.output_class_names.split(";")
name_dict = dict(zip(ip_names, op_names)) name_dict = dict(zip(ip_names, op_names, strict=False))
# load annotations # load annotations
data_all = tu.load_set_of_anns( data_all = tu.load_set_of_anns(

View File

@ -1,58 +1,68 @@
from typing import TYPE_CHECKING, List, Optional, Sequence from typing import Sequence
from lightning import Trainer from lightning import Trainer
from soundevent import data from soundevent import data
from batdetect2.audio import AudioConfig
from batdetect2.audio.loader import build_audio_loader from batdetect2.audio.loader import build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.inference.clips import get_clips_from_files from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.config import InferenceConfig
from batdetect2.inference.dataset import build_inference_loader from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.preprocess.preprocessor import build_preprocessor from batdetect2.outputs import (
from batdetect2.targets.targets import build_targets OutputsConfig,
from batdetect2.typing.postprocess import BatDetect2Prediction OutputTransformProtocol,
build_output_transform,
if TYPE_CHECKING: )
from batdetect2.config import BatDetect2Config from batdetect2.postprocess.types import ClipDetections
from batdetect2.typing import ( from batdetect2.preprocess.types import PreprocessorProtocol
AudioLoader, from batdetect2.targets.types import TargetProtocol
PreprocessorProtocol,
TargetProtocol,
)
def run_batch_inference( def run_batch_inference(
model, model: Model,
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
targets: Optional["TargetProtocol"] = None, targets: TargetProtocol | None = None,
audio_loader: Optional["AudioLoader"] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional["BatDetect2Config"] = None, audio_config: AudioConfig | None = None,
num_workers: Optional[int] = None, output_transform: OutputTransformProtocol | None = None,
batch_size: Optional[int] = None, output_config: OutputsConfig | None = None,
) -> List[BatDetect2Prediction]: inference_config: InferenceConfig | None = None,
from batdetect2.config import BatDetect2Config num_workers: int = 1,
batch_size: int | None = None,
config = config or BatDetect2Config() ) -> list[ClipDetections]:
audio_config = audio_config or AudioConfig(
audio_loader = audio_loader or build_audio_loader() samplerate=model.preprocessor.input_samplerate,
preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate,
) )
output_config = output_config or OutputsConfig()
inference_config = inference_config or InferenceConfig()
targets = targets or build_targets() audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
output_transform = output_transform or build_output_transform(
config=output_config.transform,
targets=targets,
)
loader = build_inference_loader( loader = build_inference_loader(
clips, clips,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,
config=config.inference.loader, config=inference_config.loader,
num_workers=num_workers, num_workers=num_workers,
batch_size=batch_size, batch_size=batch_size,
) )
module = InferenceModule(model) module = InferenceModule(
model,
output_transform=output_transform,
)
trainer = Trainer(enable_checkpointing=False, logger=False) trainer = Trainer(enable_checkpointing=False, logger=False)
outputs = trainer.predict(module, loader) outputs = trainer.predict(module, loader)
return [ return [
@ -65,13 +75,18 @@ def run_batch_inference(
def process_file_list( def process_file_list(
model: Model, model: Model,
paths: Sequence[data.PathLike], paths: Sequence[data.PathLike],
config: "BatDetect2Config", targets: TargetProtocol | None = None,
targets: Optional["TargetProtocol"] = None, audio_loader: AudioLoader | None = None,
audio_loader: Optional["AudioLoader"] = None, audio_config: AudioConfig | None = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: PreprocessorProtocol | None = None,
num_workers: Optional[int] = None, inference_config: InferenceConfig | None = None,
) -> List[BatDetect2Prediction]: output_config: OutputsConfig | None = None,
clip_config = config.inference.clipping output_transform: OutputTransformProtocol | None = None,
batch_size: int | None = None,
num_workers: int = 0,
) -> list[ClipDetections]:
inference_config = inference_config or InferenceConfig()
clip_config = inference_config.clipping
clips = get_clips_from_files( clips = get_clips_from_files(
paths, paths,
duration=clip_config.duration, duration=clip_config.duration,
@ -85,6 +100,10 @@ def process_file_list(
targets=targets, targets=targets,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,
config=config, batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
output_config=output_config,
audio_config=audio_config,
output_transform=output_transform,
inference_config=inference_config,
) )

View File

@ -38,10 +38,10 @@ def get_recording_clips(
discard_empty: bool = True, discard_empty: bool = True,
) -> Sequence[data.Clip]: ) -> Sequence[data.Clip]:
start_time = 0 start_time = 0
duration = recording.duration recording_duration = recording.duration
hop = duration * (1 - overlap) hop = duration * (1 - overlap)
num_clips = int(np.ceil(duration / hop)) num_clips = int(np.ceil(recording_duration / hop))
if num_clips == 0: if num_clips == 0:
# This should only happen if the clip's duration is zero, # This should only happen if the clip's duration is zero,
@ -53,8 +53,8 @@ def get_recording_clips(
start = start_time + i * hop start = start_time + i * hop
end = start + duration end = start + duration
if end > duration: if end > recording_duration:
empty_duration = end - duration empty_duration = end - recording_duration
if empty_duration > max_empty and discard_empty: if empty_duration > max_empty and discard_empty:
# Discard clips that contain too much empty space # Discard clips that contain too much empty space

View File

@ -1,4 +1,4 @@
from typing import List, NamedTuple, Optional, Sequence from typing import NamedTuple, Sequence
import torch import torch
from loguru import logger from loguru import logger
@ -6,10 +6,11 @@ from soundevent import data
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import build_audio_loader from batdetect2.audio import build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.core import BaseConfig from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [ __all__ = [
"InferenceDataset", "InferenceDataset",
@ -29,14 +30,14 @@ class DatasetItem(NamedTuple):
class InferenceDataset(Dataset[DatasetItem]): class InferenceDataset(Dataset[DatasetItem]):
clips: List[data.Clip] clips: list[data.Clip]
def __init__( def __init__(
self, self,
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
audio_loader: AudioLoader, audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
audio_dir: Optional[data.PathLike] = None, audio_dir: data.PathLike | None = None,
): ):
self.clips = list(clips) self.clips = list(clips)
self.preprocessor = preprocessor self.preprocessor = preprocessor
@ -46,31 +47,30 @@ class InferenceDataset(Dataset[DatasetItem]):
def __len__(self): def __len__(self):
return len(self.clips) return len(self.clips)
def __getitem__(self, idx: int) -> DatasetItem: def __getitem__(self, index: int) -> DatasetItem:
clip = self.clips[idx] clip = self.clips[index]
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir) wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
wav_tensor = torch.tensor(wav).unsqueeze(0) wav_tensor = torch.tensor(wav).unsqueeze(0)
spectrogram = self.preprocessor(wav_tensor) spectrogram = self.preprocessor(wav_tensor)
return DatasetItem( return DatasetItem(
spec=spectrogram, spec=spectrogram,
idx=torch.tensor(idx), idx=torch.tensor(index),
start_time=torch.tensor(clip.start_time), start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time), end_time=torch.tensor(clip.end_time),
) )
class InferenceLoaderConfig(BaseConfig): class InferenceLoaderConfig(BaseConfig):
num_workers: int = 0
batch_size: int = 8 batch_size: int = 8
def build_inference_loader( def build_inference_loader(
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional[InferenceLoaderConfig] = None, config: InferenceLoaderConfig | None = None,
num_workers: Optional[int] = None, num_workers: int = 0,
batch_size: Optional[int] = None, batch_size: int | None = None,
) -> DataLoader[DatasetItem]: ) -> DataLoader[DatasetItem]:
logger.info("Building inference data loader...") logger.info("Building inference data loader...")
config = config or InferenceLoaderConfig() config = config or InferenceLoaderConfig()
@ -83,20 +83,19 @@ def build_inference_loader(
batch_size = batch_size or config.batch_size batch_size = batch_size or config.batch_size
num_workers = num_workers or config.num_workers
return DataLoader( return DataLoader(
inference_dataset, inference_dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=False, shuffle=False,
num_workers=config.num_workers, num_workers=num_workers,
collate_fn=_collate_fn, collate_fn=_collate_fn,
) )
def build_inference_dataset( def build_inference_dataset(
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
audio_loader: Optional[AudioLoader] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
) -> InferenceDataset: ) -> InferenceDataset:
if audio_loader is None: if audio_loader is None:
audio_loader = build_audio_loader() audio_loader = build_audio_loader()
@ -111,7 +110,7 @@ def build_inference_dataset(
) )
def _collate_fn(batch: List[DatasetItem]) -> DatasetItem: def _collate_fn(batch: list[DatasetItem]) -> DatasetItem:
max_width = max(item.spec.shape[-1] for item in batch) max_width = max(item.spec.shape[-1] for item in batch)
return DatasetItem( return DatasetItem(
spec=torch.stack( spec=torch.stack(

View File

@ -5,45 +5,44 @@ from torch.utils.data import DataLoader
from batdetect2.inference.dataset import DatasetItem, InferenceDataset from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.typing.postprocess import BatDetect2Prediction from batdetect2.postprocess.types import ClipDetections
class InferenceModule(LightningModule): class InferenceModule(LightningModule):
def __init__(self, model: Model): def __init__(
self,
model: Model,
output_transform: OutputTransformProtocol | None = None,
):
super().__init__() super().__init__()
self.model = model self.model = model
self.output_transform = output_transform or build_output_transform(
targets=model.targets
)
def predict_step( def predict_step(
self, self,
batch: DatasetItem, batch: DatasetItem,
batch_idx: int, batch_idx: int,
dataloader_idx: int = 0, dataloader_idx: int = 0,
) -> Sequence[BatDetect2Prediction]: ) -> Sequence[ClipDetections]:
dataset = self.get_dataset() dataset = self.get_dataset()
clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx] clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
outputs = self.model.detector(batch.spec) outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor( clip_detections = self.model.postprocessor(outputs)
outputs,
start_times=[clip.start_time for clip in clips],
)
predictions = [ return [
BatDetect2Prediction( self.output_transform.to_clip_detections(
detections=clip_dets,
clip=clip, clip=clip,
predictions=to_raw_predictions(
clip_dets.numpy(),
targets=self.model.targets,
),
) )
for clip, clip_dets in zip(clips, clip_detections) for clip, clip_dets in zip(clips, clip_detections, strict=True)
] ]
return predictions
def get_dataset(self) -> InferenceDataset: def get_dataset(self) -> InferenceDataset:
dataloaders = self.trainer.predict_dataloaders dataloaders = self.trainer.predict_dataloaders
assert isinstance(dataloaders, DataLoader) assert isinstance(dataloaders, DataLoader)

View File

@ -9,10 +9,8 @@ from typing import (
Dict, Dict,
Generic, Generic,
Literal, Literal,
Optional,
Protocol, Protocol,
TypeVar, TypeVar,
Union,
) )
import numpy as np import numpy as np
@ -32,6 +30,21 @@ from batdetect2.core.configs import BaseConfig
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs" DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
__all__ = [
"AppLoggingConfig",
"BaseLoggerConfig",
"CSVLoggerConfig",
"DEFAULT_LOGS_DIR",
"DVCLiveConfig",
"LoggerConfig",
"MLFlowLoggerConfig",
"TensorBoardLoggerConfig",
"build_logger",
"enable_logging",
"get_image_logger",
"get_table_logger",
]
def enable_logging(level: int): def enable_logging(level: int):
logger.remove() logger.remove()
@ -49,14 +62,14 @@ def enable_logging(level: int):
class BaseLoggerConfig(BaseConfig): class BaseLoggerConfig(BaseConfig):
log_dir: Path = DEFAULT_LOGS_DIR log_dir: Path = DEFAULT_LOGS_DIR
experiment_name: Optional[str] = None experiment_name: str | None = None
run_name: Optional[str] = None run_name: str | None = None
class DVCLiveConfig(BaseLoggerConfig): class DVCLiveConfig(BaseLoggerConfig):
name: Literal["dvclive"] = "dvclive" name: Literal["dvclive"] = "dvclive"
prefix: str = "" prefix: str = ""
log_model: Union[bool, Literal["all"]] = False log_model: bool | Literal["all"] = False
monitor_system: bool = False monitor_system: bool = False
@ -72,22 +85,26 @@ class TensorBoardLoggerConfig(BaseLoggerConfig):
class MLFlowLoggerConfig(BaseLoggerConfig): class MLFlowLoggerConfig(BaseLoggerConfig):
name: Literal["mlflow"] = "mlflow" name: Literal["mlflow"] = "mlflow"
tracking_uri: Optional[str] = "http://localhost:5000" tracking_uri: str | None = "http://localhost:5000"
tags: Optional[dict[str, Any]] = None tags: dict[str, Any] | None = None
log_model: bool = False log_model: bool = False
LoggerConfig = Annotated[ LoggerConfig = Annotated[
Union[ DVCLiveConfig
DVCLiveConfig, | CSVLoggerConfig
CSVLoggerConfig, | TensorBoardLoggerConfig
TensorBoardLoggerConfig, | MLFlowLoggerConfig,
MLFlowLoggerConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
class AppLoggingConfig(BaseConfig):
train: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
evaluation: LoggerConfig = Field(default_factory=CSVLoggerConfig)
inference: LoggerConfig = Field(default_factory=CSVLoggerConfig)
T = TypeVar("T", bound=LoggerConfig, contravariant=True) T = TypeVar("T", bound=LoggerConfig, contravariant=True)
@ -95,20 +112,20 @@ class LoggerBuilder(Protocol, Generic[T]):
def __call__( def __call__(
self, self,
config: T, config: T,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Logger: ... ) -> Logger: ...
def create_dvclive_logger( def create_dvclive_logger(
config: DVCLiveConfig, config: DVCLiveConfig,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Logger: ) -> Logger:
try: try:
from dvclive.lightning import DVCLiveLogger # type: ignore from dvclive.lightning import DVCLiveLogger
except ImportError as error: except ImportError as error:
raise ValueError( raise ValueError(
"DVCLive is not installed and cannot be used for logging" "DVCLive is not installed and cannot be used for logging"
@ -130,9 +147,9 @@ def create_dvclive_logger(
def create_csv_logger( def create_csv_logger(
config: CSVLoggerConfig, config: CSVLoggerConfig,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Logger: ) -> Logger:
from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.loggers import CSVLogger
@ -159,9 +176,9 @@ def create_csv_logger(
def create_tensorboard_logger( def create_tensorboard_logger(
config: TensorBoardLoggerConfig, config: TensorBoardLoggerConfig,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Logger: ) -> Logger:
from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers import TensorBoardLogger
@ -191,9 +208,9 @@ def create_tensorboard_logger(
def create_mlflow_logger( def create_mlflow_logger(
config: MLFlowLoggerConfig, config: MLFlowLoggerConfig,
log_dir: Optional[data.PathLike] = None, log_dir: data.PathLike | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Logger: ) -> Logger:
try: try:
from lightning.pytorch.loggers import MLFlowLogger from lightning.pytorch.loggers import MLFlowLogger
@ -232,9 +249,9 @@ LOGGER_FACTORY: Dict[str, LoggerBuilder] = {
def build_logger( def build_logger(
config: LoggerConfig, config: LoggerConfig,
log_dir: Optional[Path] = None, log_dir: Path | None = None,
experiment_name: Optional[str] = None, experiment_name: str | None = None,
run_name: Optional[str] = None, run_name: str | None = None,
) -> Logger: ) -> Logger:
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
"Building logger with config: \n{}", "Building logger with config: \n{}",
@ -257,7 +274,7 @@ def build_logger(
PlotLogger = Callable[[str, Figure, int], None] PlotLogger = Callable[[str, Figure, int], None]
def get_image_logger(logger: Logger) -> Optional[PlotLogger]: def get_image_logger(logger: Logger) -> PlotLogger | None:
if isinstance(logger, TensorBoardLogger): if isinstance(logger, TensorBoardLogger):
return logger.experiment.add_figure return logger.experiment.add_figure
@ -282,7 +299,7 @@ def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
TableLogger = Callable[[str, pd.DataFrame, int], None] TableLogger = Callable[[str, pd.DataFrame, int], None]
def get_table_logger(logger: Logger) -> Optional[TableLogger]: def get_table_logger(logger: Logger) -> TableLogger | None:
if isinstance(logger, TensorBoardLogger): if isinstance(logger, TensorBoardLogger):
return partial(save_table, dir=Path(logger.log_dir)) return partial(save_table, dir=Path(logger.log_dir))

View File

@ -1,36 +1,46 @@
"""Defines and builds the neural network models used in BatDetect2. """Neural network model definitions and builders for BatDetect2.
This package (`batdetect2.models`) contains the PyTorch implementations of the This package contains the PyTorch implementations of the deep neural network
deep neural network architectures used for detecting and classifying bat calls architectures used to detect and classify bat echolocation calls in
from spectrograms. It provides modular components and configuration-driven spectrograms. Components are designed to be combined through configuration
assembly, allowing for experimentation and use of different architectural objects, making it easy to experiment with different architectures.
variants.
Key Submodules: Key submodules
- `.types`: Defines core data structures (`ModelOutput`) and abstract base --------------
classes (`BackboneModel`, `DetectionModel`) establishing interfaces. - ``blocks``: Reusable convolutional building blocks (downsampling,
- `.blocks`: Provides reusable neural network building blocks. upsampling, attention, coord-conv variants).
- `.encoder`: Defines and builds the downsampling path (encoder) of the network. - ``encoder``: The downsampling path; reduces spatial resolution whilst
- `.bottleneck`: Defines and builds the central bottleneck component. extracting increasingly abstract features.
- `.decoder`: Defines and builds the upsampling path (decoder) of the network. - ``bottleneck``: The central component connecting encoder to decoder;
- `.backbone`: Assembles the encoder, bottleneck, and decoder into a complete optionally applies self-attention along the time axis.
feature extraction backbone (e.g., a U-Net like structure). - ``decoder``: The upsampling path; reconstructs high-resolution feature
- `.heads`: Defines simple prediction heads (detection, classification, size) maps using bottleneck output and skip connections from the encoder.
that attach to the backbone features. - ``backbones``: Assembles encoder, bottleneck, and decoder into a complete
- `.detectors`: Assembles the backbone and prediction heads into the final, U-Net-style feature extraction backbone.
end-to-end `Detector` model. - ``heads``: Lightweight 1×1 convolutional heads that produce detection,
classification, and bounding-box size predictions from backbone features.
- ``detectors``: Combines a backbone with prediction heads into the final
end-to-end ``Detector`` model.
This module re-exports the most important classes, configurations, and builder The primary entry point for building a full, ready-to-use BatDetect2 model
functions from these submodules for convenient access. The primary entry point is the ``build_model`` factory function exported from this module.
for creating a standard BatDetect2 model instance is the `build_model` function
provided here.
""" """
from typing import List, Optional from typing import Literal
import torch import torch
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.models.backbones import Backbone, build_backbone from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig
from batdetect2.models.backbones import (
BackboneConfig,
UNetBackbone,
UNetBackboneConfig,
build_backbone,
load_backbone_config,
)
from batdetect2.models.blocks import ( from batdetect2.models.blocks import (
ConvConfig, ConvConfig,
FreqCoordConvDownConfig, FreqCoordConvDownConfig,
@ -43,10 +53,6 @@ from batdetect2.models.bottleneck import (
BottleneckConfig, BottleneckConfig,
build_bottleneck, build_bottleneck,
) )
from batdetect2.models.config import (
BackboneConfig,
load_backbone_config,
)
from batdetect2.models.decoder import ( from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG, DEFAULT_DECODER_CONFIG,
DecoderConfig, DecoderConfig,
@ -59,17 +65,20 @@ from batdetect2.models.encoder import (
build_encoder, build_encoder,
) )
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.typing import ( from batdetect2.models.types import DetectionModel
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.postprocess.types import (
ClipDetectionsTensor, ClipDetectionsTensor,
DetectionModel,
PostprocessorProtocol, PostprocessorProtocol,
PreprocessorProtocol,
TargetProtocol,
) )
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.config import TargetConfig
from batdetect2.targets.types import TargetProtocol
__all__ = [ __all__ = [
"BBoxHead", "BBoxHead",
"Backbone", "UNetBackbone",
"BackboneConfig", "BackboneConfig",
"Bottleneck", "Bottleneck",
"BottleneckConfig", "BottleneckConfig",
@ -92,11 +101,93 @@ __all__ = [
"build_detector", "build_detector",
"load_backbone_config", "load_backbone_config",
"Model", "Model",
"ModelConfig",
"build_model", "build_model",
"build_model_with_new_targets",
] ]
class ModelConfig(BaseConfig):
"""Complete configuration describing a BatDetect2 model.
Bundles every parameter that defines a model's behaviour: the input
sample rate, backbone architecture, preprocessing pipeline,
postprocessing pipeline, and detection targets.
Attributes
----------
samplerate : int
Expected input audio sample rate in Hz. Audio must be resampled
to this rate before being passed to the model. Defaults to
``TARGET_SAMPLERATE_HZ`` (256 000 Hz).
architecture : BackboneConfig
Configuration for the encoder-decoder backbone network. Defaults
to ``UNetBackboneConfig()``.
preprocess : PreprocessingConfig
Parameters for the audio-to-spectrogram preprocessing pipeline
(STFT, frequency crop, transforms, resize). Defaults to
``PreprocessingConfig()``.
postprocess : PostprocessConfig
Parameters for converting raw model outputs into detections (NMS
kernel, thresholds, top-k limit). Defaults to
``PostprocessConfig()``.
targets : TargetConfig
Detection and classification target definitions (class list,
detection target, bounding-box mapper). Defaults to
``TargetConfig()``.
"""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
architecture: BackboneConfig = Field(default_factory=UNetBackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
@classmethod
def load(
cls,
path: PathLike,
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
targets: TargetConfig | None = None,
) -> "ModelConfig":
config = super().load(path, field, extra, strict)
if targets is None:
return config
return config.model_copy(update={"targets": targets})
class Model(torch.nn.Module): class Model(torch.nn.Module):
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
Combines a preprocessor, a detection model, and a postprocessor into a
single PyTorch module. Calling ``forward`` on a raw waveform tensor
returns a list of detection tensors ready for downstream use.
This class is the top-level object produced by ``build_model``. Most
users will not need to construct it directly.
Attributes
----------
detector : DetectionModel
The neural network that processes spectrograms and produces raw
detection, classification, and bounding-box outputs.
preprocessor : PreprocessorProtocol
Converts a raw waveform tensor into a spectrogram tensor accepted by
``detector``.
postprocessor : PostprocessorProtocol
Converts the raw ``ModelOutput`` from ``detector`` into a list of
per-clip detection tensors.
targets : TargetProtocol
Describes the set of target classes; used when building heads and
during training target construction.
"""
detector: DetectionModel detector: DetectionModel
preprocessor: PreprocessorProtocol preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol postprocessor: PostprocessorProtocol
@ -115,31 +206,87 @@ class Model(torch.nn.Module):
self.postprocessor = postprocessor self.postprocessor = postprocessor
self.targets = targets self.targets = targets
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]: def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
"""Run the full detection pipeline on a waveform tensor.
Converts the waveform to a spectrogram, passes it through the
detector, and postprocesses the raw outputs into detection tensors.
Parameters
----------
wav : torch.Tensor
Raw audio waveform tensor. The exact expected shape depends on
the preprocessor, but is typically ``(batch, samples)`` or
``(batch, channels, samples)``.
Returns
-------
list[ClipDetectionsTensor]
One detection tensor per clip in the batch. Each tensor encodes
the detected events (locations, class scores, sizes) for that
clip.
"""
spec = self.preprocessor(wav) spec = self.preprocessor(wav)
outputs = self.detector(spec) outputs = self.detector(spec)
return self.postprocessor(outputs) return self.postprocessor(outputs)
def build_model( def build_model(
config: Optional[BackboneConfig] = None, config: ModelConfig | None = None,
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
preprocessor: Optional[PreprocessorProtocol] = None, preprocessor: PreprocessorProtocol | None = None,
postprocessor: Optional[PostprocessorProtocol] = None, postprocessor: PostprocessorProtocol | None = None,
): ) -> Model:
"""Build a complete, ready-to-use BatDetect2 model.
Assembles a ``Model`` instance from a ``ModelConfig`` and optional
component overrides. Any component argument left as ``None`` is built
from the configuration. Passing a pre-built component overrides the
corresponding config fields for that component only.
Parameters
----------
config : ModelConfig, optional
Full model configuration (samplerate, architecture, preprocessing,
postprocessing, targets). Defaults to ``ModelConfig()`` if not
provided.
targets : TargetProtocol, optional
Pre-built targets object. If given, overrides
``config.targets``.
preprocessor : PreprocessorProtocol, optional
Pre-built preprocessor. If given, overrides
``config.preprocess`` and ``config.samplerate`` for the
preprocessing step.
postprocessor : PostprocessorProtocol, optional
Pre-built postprocessor. If given, overrides
``config.postprocess``. When omitted and a custom
``preprocessor`` is supplied, the default postprocessor is built
using that preprocessor so that frequency and time scaling remain
consistent.
Returns
-------
Model
A fully assembled ``Model`` instance ready for inference or
training.
"""
from batdetect2.postprocess import build_postprocessor from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
config = config or BackboneConfig() config = config or ModelConfig()
targets = targets or build_targets() targets = targets or build_targets(config=config.targets)
preprocessor = preprocessor or build_preprocessor() preprocessor = preprocessor or build_preprocessor(
config=config.preprocess,
input_samplerate=config.samplerate,
)
postprocessor = postprocessor or build_postprocessor( postprocessor = postprocessor or build_postprocessor(
preprocessor=preprocessor, preprocessor=preprocessor,
config=config.postprocess,
) )
detector = build_detector( detector = build_detector(
num_classes=len(targets.class_names), num_classes=len(targets.class_names),
config=config, config=config.architecture,
) )
return Model( return Model(
detector=detector, detector=detector,
@ -147,3 +294,21 @@ def build_model(
preprocessor=preprocessor, preprocessor=preprocessor,
targets=targets, targets=targets,
) )
def build_model_with_new_targets(
model: Model,
targets: TargetProtocol,
) -> Model:
"""Build a new model with a different target set."""
detector = build_detector(
num_classes=len(targets.class_names),
backbone=model.detector.backbone,
)
return Model(
detector=detector,
postprocessor=model.postprocessor,
preprocessor=model.preprocessor,
targets=targets,
)

View File

@ -1,99 +1,176 @@
"""Assembles a complete Encoder-Decoder Backbone network. """Assembles a complete encoder-decoder backbone network.
This module defines the configuration (`BackboneConfig`) and implementation This module defines ``UNetBackboneConfig`` and the ``UNetBackbone``
(`Backbone`) for a standard encoder-decoder style neural network backbone. ``nn.Module``, together with the ``build_backbone`` and
``load_backbone_config`` helpers.
It orchestrates the connection between three main components, built using their A backbone combines three components built from the sibling modules:
respective configurations and factory functions from sibling modules:
1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features
at multiple resolutions and provides skip connections.
2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the
lowest resolution, optionally applying self-attention.
3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs high-
resolution features using bottleneck features and skip connections.
The resulting `Backbone` module takes a spectrogram as input and outputs a 1. **Encoder** (``batdetect2.models.encoder``) reduces spatial resolution
final feature map, typically used by subsequent prediction heads. It includes while extracting hierarchical features and storing skip-connection tensors.
automatic padding to handle input sizes not perfectly divisible by the 2. **Bottleneck** (``batdetect2.models.bottleneck``) processes the
network's total downsampling factor. lowest-resolution features, optionally applying self-attention.
3. **Decoder** (``batdetect2.models.decoder``) restores spatial resolution
using bottleneck features and skip connections from the encoder.
The resulting ``UNetBackbone`` takes a spectrogram tensor as input and returns
a high-resolution feature map consumed by the prediction heads in
``batdetect2.models.detectors``.
Input padding is handled automatically: the backbone pads the input to be
divisible by the total downsampling factor and strips the padding from the
output so that the output spatial dimensions always match the input spatial
dimensions.
""" """
from typing import Tuple from typing import Annotated, Literal
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from pydantic import Field, TypeAdapter
from soundevent import data
from batdetect2.models.bottleneck import build_bottleneck from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.config import BackboneConfig from batdetect2.core.registries import (
from batdetect2.models.decoder import Decoder, build_decoder ImportConfig,
from batdetect2.models.encoder import Encoder, build_encoder Registry,
from batdetect2.typing.models import BackboneModel add_import_config,
)
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
build_decoder,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
build_encoder,
)
from batdetect2.models.types import (
BackboneModel,
BottleneckProtocol,
DecoderProtocol,
EncoderProtocol,
)
__all__ = [ __all__ = [
"Backbone", "BackboneImportConfig",
"UNetBackbone",
"BackboneConfig",
"load_backbone_config",
"build_backbone", "build_backbone",
] ]
class Backbone(BackboneModel): class UNetBackboneConfig(BaseConfig):
"""Encoder-Decoder Backbone Network Implementation. """Configuration for a U-Net-style encoder-decoder backbone.
Combines an Encoder, Bottleneck, and Decoder module sequentially, using All fields have sensible defaults that reproduce the standard BatDetect2
skip connections between the Encoder and Decoder. Implements the standard architecture, so you can start with ``UNetBackboneConfig()`` and override
U-Net style forward pass. Includes automatic input padding to handle only the fields you want to change.
various input sizes and a final convolutional block to adjust the output
channels.
This class inherits from `BackboneModel` and implements its `forward` Attributes
method. Instances are typically created using the `build_backbone` factory ----------
function. name : str
Discriminator field used by the backbone registry; always
``"UNetBackbone"``.
input_height : int
Number of frequency bins in the input spectrogram. Defaults to
``128``.
in_channels : int
Number of channels in the input spectrogram (e.g. ``1`` for a
standard mel-spectrogram). Defaults to ``1``.
encoder : EncoderConfig
Configuration for the downsampling path. Defaults to
``DEFAULT_ENCODER_CONFIG``.
bottleneck : BottleneckConfig
Configuration for the bottleneck. Defaults to
``DEFAULT_BOTTLENECK_CONFIG``.
decoder : DecoderConfig
Configuration for the upsampling path. Defaults to
``DEFAULT_DECODER_CONFIG``.
"""
name: Literal["UNetBackbone"] = "UNetBackbone"
input_height: int = 128
in_channels: int = 1
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
@add_import_config(backbone_registry)
class BackboneImportConfig(ImportConfig):
"""Use any callable as a backbone model.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class UNetBackbone(BackboneModel):
"""U-Net-style encoder-decoder backbone network.
Combines an encoder, a bottleneck, and a decoder into a single module
that produces a high-resolution feature map from an input spectrogram.
Skip connections from each encoder stage are added element-wise to the
corresponding decoder stage input.
Input spectrograms of arbitrary width are handled automatically: the
backbone pads the input so that its dimensions are divisible by
``divide_factor`` and removes the padding from the output.
Instances are typically created via ``build_backbone``.
Attributes Attributes
---------- ----------
input_height : int input_height : int
Expected height of the input spectrogram. Expected height (frequency bins) of the input spectrogram.
out_channels : int out_channels : int
Number of channels in the final output feature map. Number of channels in the output feature map (taken from the
encoder : Encoder decoder's output channel count).
encoder : EncoderProtocol
The instantiated encoder module. The instantiated encoder module.
decoder : Decoder decoder : DecoderProtocol
The instantiated decoder module. The instantiated decoder module.
bottleneck : nn.Module bottleneck : BottleneckProtocol
The instantiated bottleneck module. The instantiated bottleneck module.
final_conv : ConvBlock
Final convolutional block applied after the decoder.
divide_factor : int divide_factor : int
The total downsampling factor (2^depth) applied by the encoder, The total spatial downsampling factor applied by the encoder
used for automatic input padding. (``input_height // encoder.output_height``). The input width is
padded to be a multiple of this value before processing.
""" """
def __init__( def __init__(
self, self,
input_height: int, input_height: int,
encoder: Encoder, encoder: EncoderProtocol,
decoder: Decoder, decoder: DecoderProtocol,
bottleneck: nn.Module, bottleneck: BottleneckProtocol,
): ):
"""Initialize the Backbone network. """Initialise the backbone network.
Parameters Parameters
---------- ----------
input_height : int input_height : int
Expected height of the input spectrogram. Expected height (frequency bins) of the input spectrogram.
out_channels : int encoder : EncoderProtocol
Desired number of output channels for the backbone's feature map. An initialised encoder module.
encoder : Encoder decoder : DecoderProtocol
An initialized Encoder module. An initialised decoder module. Its ``output_height`` must equal
decoder : Decoder ``input_height``; a ``ValueError`` is raised otherwise.
An initialized Decoder module. bottleneck : BottleneckProtocol
bottleneck : nn.Module An initialised bottleneck module.
An initialized Bottleneck module.
Raises
------
ValueError
If component output/input channels or heights are incompatible.
""" """
super().__init__() super().__init__()
self.input_height = input_height self.input_height = input_height
@ -110,22 +187,25 @@ class Backbone(BackboneModel):
self.divide_factor = input_height // self.encoder.output_height self.divide_factor = input_height // self.encoder.output_height
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Perform the forward pass through the encoder-decoder backbone. """Produce a feature map from an input spectrogram.
Applies padding, runs encoder, bottleneck, decoder (with skip Pads the input if necessary, runs it through the encoder, then
connections), removes padding, and applies a final convolution. the bottleneck, then the decoder (incorporating encoder skip
connections), and finally removes any padding added earlier.
Parameters Parameters
---------- ----------
spec : torch.Tensor spec : torch.Tensor
Input spectrogram tensor, shape `(B, C_in, H_in, W_in)`. Must match Input spectrogram tensor, shape
`self.encoder.input_channels` and `self.input_height`. ``(B, C_in, H_in, W_in)``. ``H_in`` must equal
``self.input_height``; ``W_in`` can be any positive integer.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Output feature map tensor, shape `(B, C_out, H_in, W_in)`, where Feature map tensor, shape ``(B, C_out, H_in, W_in)``, where
`C_out` is `self.out_channels`. ``C_out`` is ``self.out_channels``. The spatial dimensions
always match those of the input.
""" """
spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor) spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor)
@ -143,95 +223,97 @@ class Backbone(BackboneModel):
return x return x
@backbone_registry.register(UNetBackboneConfig)
@staticmethod
def from_config(config: UNetBackboneConfig) -> BackboneModel:
encoder = build_encoder(
in_channels=config.in_channels,
input_height=config.input_height,
config=config.encoder,
)
def build_backbone(config: BackboneConfig) -> BackboneModel: bottleneck = build_bottleneck(
"""Factory function to build a Backbone from configuration. input_height=encoder.output_height,
in_channels=encoder.out_channels,
config=config.bottleneck,
)
Constructs the `Encoder`, `Bottleneck`, and `Decoder` components based on decoder = build_decoder(
the provided `BackboneConfig`, validates their compatibility, and assembles in_channels=bottleneck.out_channels,
them into a `Backbone` instance. input_height=encoder.output_height,
config=config.decoder,
)
if decoder.output_height != config.input_height:
raise ValueError(
"Invalid configuration: Decoder output height "
f"({decoder.output_height}) must match the Backbone input height "
f"({config.input_height}). Check encoder/decoder layer "
"configurations and input/bottleneck heights."
)
return UNetBackbone(
input_height=config.input_height,
encoder=encoder,
decoder=decoder,
bottleneck=bottleneck,
)
BackboneConfig = Annotated[
UNetBackboneConfig | BackboneImportConfig,
Field(discriminator="name"),
]
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
"""Build a backbone network from configuration.
Looks up the backbone class corresponding to ``config.name`` in the
backbone registry and calls its ``from_config`` method. If no
configuration is provided, a default ``UNetBackbone`` is returned.
Parameters Parameters
---------- ----------
config : BackboneConfig config : BackboneConfig, optional
The configuration object detailing the backbone architecture, including A configuration object describing the desired backbone. Currently
input dimensions and configurations for encoder, bottleneck, and ``UNetBackboneConfig`` is the only supported type. Defaults to
decoder. ``UNetBackboneConfig()`` if not provided.
Returns Returns
------- -------
BackboneModel BackboneModel
An initialized `Backbone` module ready for use. An initialised backbone module.
Raises
------
ValueError
If sub-component configurations are incompatible
(e.g., channel mismatches, decoder output height doesn't match backbone
input height).
NotImplementedError
If an unknown block type is specified in sub-configs.
""" """
encoder = build_encoder( config = config or UNetBackboneConfig()
in_channels=config.in_channels, return backbone_registry.build(config)
input_height=config.input_height,
config=config.encoder,
)
bottleneck = build_bottleneck(
input_height=encoder.output_height,
in_channels=encoder.out_channels,
config=config.bottleneck,
)
decoder = build_decoder(
in_channels=bottleneck.out_channels,
input_height=encoder.output_height,
config=config.decoder,
)
if decoder.output_height != config.input_height:
raise ValueError(
"Invalid configuration: Decoder output height "
f"({decoder.output_height}) must match the Backbone input height "
f"({config.input_height}). Check encoder/decoder layer "
"configurations and input/bottleneck heights."
)
return Backbone(
input_height=config.input_height,
encoder=encoder,
decoder=decoder,
bottleneck=bottleneck,
)
def _pad_adjust( def _pad_adjust(
spec: torch.Tensor, spec: torch.Tensor,
factor: int = 32, factor: int = 32,
) -> Tuple[torch.Tensor, int, int]: ) -> tuple[torch.Tensor, int, int]:
"""Pad tensor height and width to be divisible by a factor. """Pad a tensor's height and width to be divisible by ``factor``.
Calculates the required padding for the last two dimensions (H, W) to make Adds zero-padding to the bottom and right edges of the tensor so that
them divisible by `factor` and applies right/bottom padding using both dimensions are exact multiples of ``factor``. If both dimensions
`torch.nn.functional.pad`. are already divisible, the tensor is returned unchanged.
Parameters Parameters
---------- ----------
spec : torch.Tensor spec : torch.Tensor
Input tensor, typically shape `(B, C, H, W)`. Input tensor, typically shape ``(B, C, H, W)``.
factor : int, default=32 factor : int, default=32
The factor to make height and width divisible by. The factor that both H and W should be divisible by after padding.
Returns Returns
------- -------
Tuple[torch.Tensor, int, int] tuple[torch.Tensor, int, int]
A tuple containing: - Padded tensor.
- The padded tensor. - Number of rows added to the height (``h_pad``).
- The amount of padding added to height (`h_pad`). - Number of columns added to the width (``w_pad``).
- The amount of padding added to width (`w_pad`).
""" """
h, w = spec.shape[2:] h, w = spec.shape[-2:]
h_pad = -h % factor h_pad = -h % factor
w_pad = -w % factor w_pad = -w % factor
@ -244,28 +326,71 @@ def _pad_adjust(
def _restore_pad( def _restore_pad(
x: torch.Tensor, h_pad: int = 0, w_pad: int = 0 x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor: ) -> torch.Tensor:
"""Remove padding added by _pad_adjust. """Remove padding previously added by ``_pad_adjust``.
Removes padding from the bottom and right edges of the tensor. Trims ``h_pad`` rows from the bottom and ``w_pad`` columns from the
right of the tensor, restoring its original spatial dimensions.
Parameters Parameters
---------- ----------
x : torch.Tensor x : torch.Tensor
Padded tensor, typically shape `(B, C, H_padded, W_padded)`. Padded tensor, typically shape ``(B, C, H_padded, W_padded)``.
h_pad : int, default=0 h_pad : int, default=0
Amount of padding previously added to the height (bottom). Number of rows to remove from the bottom.
w_pad : int, default=0 w_pad : int, default=0
Amount of padding previously added to the width (right). Number of columns to remove from the right.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Tensor with padding removed, shape `(B, C, H_original, W_original)`. Tensor with padding removed, shape
``(B, C, H_padded - h_pad, W_padded - w_pad)``.
""" """
if h_pad > 0: if h_pad > 0:
x = x[:, :, :-h_pad, :] x = x[..., :-h_pad, :]
if w_pad > 0: if w_pad > 0:
x = x[:, :, :, :-w_pad] x = x[..., :-w_pad]
return x return x
def load_backbone_config(
path: data.PathLike,
field: str | None = None,
) -> BackboneConfig:
"""Load a backbone configuration from a YAML or JSON file.
Reads the file at ``path``, optionally descends into a named sub-field,
and validates the result against the ``BackboneConfig`` discriminated
union.
Parameters
----------
path : PathLike
Path to the configuration file. Both YAML and JSON formats are
supported.
field : str, optional
Dot-separated key path to the sub-field that contains the backbone
configuration (e.g. ``"model"``). If ``None``, the root of the
file is used.
Returns
-------
BackboneConfig
A validated backbone configuration object (currently always a
``UNetBackboneConfig`` instance).
Raises
------
FileNotFoundError
If ``path`` does not exist.
ValidationError
If the loaded data does not conform to a known ``BackboneConfig``
schema.
"""
return load_config(
path,
schema=TypeAdapter(BackboneConfig),
field=field,
)

View File

@ -1,42 +1,63 @@
"""Commonly used neural network building blocks for BatDetect2 models. """Reusable convolutional building blocks for BatDetect2 models.
This module provides various reusable `torch.nn.Module` subclasses that form This module provides a collection of ``torch.nn.Module`` subclasses that form
the fundamental building blocks for constructing convolutional neural network the fundamental building blocks for the encoder-decoder backbone used in
architectures, particularly encoder-decoder backbones used in BatDetect2. BatDetect2. All blocks follow a consistent interface: they store
``in_channels`` and ``out_channels`` as attributes and implement a
``get_output_height`` method that reports how a given input height maps to an
output height (e.g., halved by downsampling blocks, doubled by upsampling
blocks).
It includes standard components like basic convolutional blocks (`ConvBlock`), Available block families
blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with ------------------------
upsampling (`StandardConvUpBlock`). Standard blocks
``ConvBlock`` convolution + batch normalisation + ReLU, no change in
spatial resolution.
Additionally, it features specialized layers investigated in BatDetect2 Downsampling blocks
research: ``StandardConvDownBlock`` convolution then 2×2 max-pooling, halves H
and W.
``FreqCoordConvDownBlock`` like ``StandardConvDownBlock`` but prepends
a normalised frequency-coordinate channel before the convolution
(CoordConv concept), helping filters learn frequency-position-dependent
patterns.
- `SelfAttention`: Applies self-attention along the time dimension, enabling Upsampling blocks
the model to weigh information across the entire temporal context, often ``StandardConvUpBlock`` bilinear interpolation then convolution,
used in the bottleneck of an encoder-decoder. doubles H and W.
- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv" ``FreqCoordConvUpBlock`` like ``StandardConvUpBlock`` but prepends a
concept by concatenating normalized frequency coordinate information as an frequency-coordinate channel after upsampling.
extra channel to the input of convolutional layers. This explicitly provides
spatial frequency information to filters, potentially enabling them to learn
frequency-dependent patterns more effectively.
These blocks can be utilized directly in custom PyTorch model definitions or Bottleneck blocks
assembled into larger architectures. ``VerticalConv`` 1-D convolution whose kernel spans the entire
frequency axis, collapsing H to 1 whilst preserving W.
``SelfAttention`` scaled dot-product self-attention along the time
axis; typically follows a ``VerticalConv``.
A unified factory function `build_layer_from_config` allows creating instances Group block
of these blocks based on configuration objects. ``LayerGroup`` chains several blocks sequentially into one unit,
useful when a single encoder or decoder "stage" requires more than one
operation.
Factory function
----------------
``build_layer`` creates any of the above blocks from the matching
configuration object (one of the ``*Config`` classes exported here), using
a discriminated-union ``name`` field to dispatch to the correct class.
""" """
from typing import Annotated, List, Literal, Tuple, Union from typing import Annotated, Literal
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from pydantic import Field from pydantic import Field
from torch import nn from torch import nn
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
__all__ = [ __all__ = [
"BlockImportConfig",
"ConvBlock", "ConvBlock",
"LayerGroupConfig", "LayerGroupConfig",
"VerticalConv", "VerticalConv",
@ -51,63 +72,125 @@ __all__ = [
"FreqCoordConvUpConfig", "FreqCoordConvUpConfig",
"StandardConvUpConfig", "StandardConvUpConfig",
"LayerConfig", "LayerConfig",
"build_layer_from_config", "build_layer",
] ]
class Block(nn.Module):
"""Abstract base class for all BatDetect2 building blocks.
Subclasses must set ``in_channels`` and ``out_channels`` as integer
attributes so that factory functions can wire blocks together without
inspecting configuration objects at runtime. They may also override
``get_output_height`` when the block changes the height dimension (e.g.
downsampling or upsampling blocks).
Attributes
----------
in_channels : int
Number of channels expected in the input tensor.
out_channels : int
Number of channels produced in the output tensor.
"""
in_channels: int
out_channels: int
def get_output_height(self, input_height: int) -> int:
"""Return the output height for a given input height.
The default implementation returns ``input_height`` unchanged,
which is correct for blocks that do not alter spatial resolution.
Override this in downsampling (returns ``input_height // 2``) or
upsampling (returns ``input_height * 2``) subclasses.
Parameters
----------
input_height : int
Height (number of frequency bins) of the input feature map.
Returns
-------
int
Height of the output feature map.
"""
return input_height
block_registry: Registry[Block, [int, int]] = Registry("block")
@add_import_config(block_registry)
class BlockImportConfig(ImportConfig):
"""Use any callable as a model block.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class SelfAttentionConfig(BaseConfig): class SelfAttentionConfig(BaseConfig):
"""Configuration for a ``SelfAttention`` block.
Attributes
----------
name : str
Discriminator field; always ``"SelfAttention"``.
attention_channels : int
Dimensionality of the query, key, and value projections.
temperature : float
Scaling factor applied to the weighted values before the final
linear projection. Defaults to ``1``.
"""
name: Literal["SelfAttention"] = "SelfAttention" name: Literal["SelfAttention"] = "SelfAttention"
attention_channels: int attention_channels: int
temperature: float = 1 temperature: float = 1
class SelfAttention(nn.Module): class SelfAttention(Block):
"""Self-Attention mechanism operating along the time dimension. """Self-attention block operating along the time axis.
This module implements a scaled dot-product self-attention mechanism, Applies a scaled dot-product self-attention mechanism across the time
specifically designed here to operate across the time steps of an input steps of an input feature map. Before attention is computed the height
feature map, typically after spatial dimensions (like frequency) have been dimension (frequency axis) is expected to have been reduced to 1, e.g.
condensed or squeezed. by a preceding ``VerticalConv`` layer.
By calculating attention weights between all pairs of time steps, it allows For each time step the block computes query, key, and value projections
the model to capture long-range temporal dependencies and focus on relevant with learned linear weights, then calculates attention weights from the
parts of the sequence. It's often employed in the bottleneck or querykey dot products divided by ``temperature × attention_channels``.
intermediate layers of an encoder-decoder architecture to integrate global The weighted sum of values is projected back to ``in_channels`` via a
temporal context. final linear layer, and the height dimension is restored so that the
output shape matches the input shape.
The implementation uses linear projections to create query, key, and value
representations, computes scaled dot-product attention scores, applies
softmax, and produces an output by weighting the values according to the
attention scores, followed by a final linear projection. Positional encoding
is not explicitly included in this block.
Parameters Parameters
---------- ----------
in_channels : int in_channels : int
Number of input channels (features per time step after spatial squeeze). Number of input channels (features per time step). The output will
also have ``in_channels`` channels.
attention_channels : int attention_channels : int
Number of channels for the query, key, and value projections. Also the Dimensionality of the query, key, and value projections.
dimension of the output projection's input.
temperature : float, default=1.0 temperature : float, default=1.0
Scaling factor applied *before* the final projection layer. Can be used Divisor applied together with ``attention_channels`` when scaling
to adjust the sharpness or focus of the attention mechanism, although the dot-product scores before softmax. Larger values produce softer
scaling within the softmax (dividing by sqrt(dim)) is more common for (more uniform) attention distributions.
standard transformers. Here it scales the weighted values.
Attributes Attributes
---------- ----------
key_fun : nn.Linear key_fun : nn.Linear
Linear layer for key projection. Linear projection for keys.
value_fun : nn.Linear value_fun : nn.Linear
Linear layer for value projection. Linear projection for values.
query_fun : nn.Linear query_fun : nn.Linear
Linear layer for query projection. Linear projection for queries.
pro_fun : nn.Linear pro_fun : nn.Linear
Final linear projection layer applied after attention weighting. Final linear projection applied to the attended values.
temperature : float temperature : float
Scaling factor applied before final projection. Scaling divisor used when computing attention scores.
att_dim : int att_dim : int
Dimensionality of the attention space (`attention_channels`). Dimensionality of the attention space (``attention_channels``).
""" """
def __init__( def __init__(
@ -117,10 +200,13 @@ class SelfAttention(nn.Module):
temperature: float = 1.0, temperature: float = 1.0,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
# Note, does not encode position information (absolute or relative) # Note, does not encode position information (absolute or relative)
self.temperature = temperature self.temperature = temperature
self.att_dim = attention_channels self.att_dim = attention_channels
self.output_channels = in_channels
self.key_fun = nn.Linear(in_channels, attention_channels) self.key_fun = nn.Linear(in_channels, attention_channels)
self.value_fun = nn.Linear(in_channels, attention_channels) self.value_fun = nn.Linear(in_channels, attention_channels)
@ -133,20 +219,16 @@ class SelfAttention(nn.Module):
Parameters Parameters
---------- ----------
x : torch.Tensor x : torch.Tensor
Input tensor, expected shape `(B, C, H, W)`, where H is typically Input tensor with shape ``(B, C, 1, W)``. The height dimension
squeezed (e.g., H=1 after a `VerticalConv` or pooling) before must be 1 (i.e. the frequency axis should already have been
applying attention along the W (time) dimension. collapsed by a preceding ``VerticalConv`` layer).
Returns Returns
------- -------
torch.Tensor torch.Tensor
Output tensor of the same shape as the input `(B, C, H, W)`, where Output tensor with the same shape ``(B, C, 1, W)`` as the
attention has been applied across the W dimension. input, with each time step updated by attended context from all
other time steps.
Raises
------
RuntimeError
If input tensor dimensions are incompatible with operations.
""" """
x = x.squeeze(2).permute(0, 2, 1) x = x.squeeze(2).permute(0, 2, 1)
@ -175,6 +257,22 @@ class SelfAttention(nn.Module):
return op return op
def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor: def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
"""Return the softmax attention weight matrix.
Useful for visualising which time steps attend to which others.
Parameters
----------
x : torch.Tensor
Input tensor with shape ``(B, C, 1, W)``.
Returns
-------
torch.Tensor
Attention weight matrix with shape ``(B, W, W)``. Entry
``[b, i, j]`` is the attention weight that time step ``i``
assigns to time step ``j`` in batch item ``b``.
"""
x = x.squeeze(2).permute(0, 2, 1) x = x.squeeze(2).permute(0, 2, 1)
key = torch.matmul( key = torch.matmul(
@ -190,6 +288,19 @@ class SelfAttention(nn.Module):
att_weights = F.softmax(kk_qq, 1) att_weights = F.softmax(kk_qq, 1)
return att_weights return att_weights
@block_registry.register(SelfAttentionConfig)
@staticmethod
def from_config(
config: SelfAttentionConfig,
input_channels: int,
input_height: int,
) -> "SelfAttention":
return SelfAttention(
in_channels=input_channels,
attention_channels=config.attention_channels,
temperature=config.temperature,
)
class ConvConfig(BaseConfig): class ConvConfig(BaseConfig):
"""Configuration for a basic ConvBlock.""" """Configuration for a basic ConvBlock."""
@ -207,7 +318,7 @@ class ConvConfig(BaseConfig):
"""Padding size.""" """Padding size."""
class ConvBlock(nn.Module): class ConvBlock(Block):
"""Basic Convolutional Block. """Basic Convolutional Block.
A standard building block consisting of a 2D convolution, followed by A standard building block consisting of a 2D convolution, followed by
@ -235,6 +346,8 @@ class ConvBlock(nn.Module):
pad_size: int = 1, pad_size: int = 1,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
in_channels, in_channels,
out_channels, out_channels,
@ -258,8 +371,37 @@ class ConvBlock(nn.Module):
""" """
return F.relu_(self.batch_norm(self.conv(x))) return F.relu_(self.batch_norm(self.conv(x)))
@block_registry.register(ConvConfig)
@staticmethod
def from_config(
config: ConvConfig,
input_channels: int,
input_height: int,
):
return ConvBlock(
in_channels=input_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
)
class VerticalConv(nn.Module):
class VerticalConvConfig(BaseConfig):
"""Configuration for a ``VerticalConv`` block.
Attributes
----------
name : str
Discriminator field; always ``"VerticalConv"``.
channels : int
Number of output channels produced by the vertical convolution.
"""
name: Literal["VerticalConv"] = "VerticalConv"
channels: int
class VerticalConv(Block):
"""Convolutional layer that aggregates features across the entire height. """Convolutional layer that aggregates features across the entire height.
Applies a 2D convolution using a kernel with shape `(input_height, 1)`. Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
@ -288,6 +430,8 @@ class VerticalConv(nn.Module):
input_height: int, input_height: int,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
in_channels=in_channels, in_channels=in_channels,
out_channels=out_channels, out_channels=out_channels,
@ -312,6 +456,19 @@ class VerticalConv(nn.Module):
""" """
return F.relu_(self.bn(self.conv(x))) return F.relu_(self.bn(self.conv(x)))
@block_registry.register(VerticalConvConfig)
@staticmethod
def from_config(
config: VerticalConvConfig,
input_channels: int,
input_height: int,
):
return VerticalConv(
in_channels=input_channels,
out_channels=config.channels,
input_height=input_height,
)
class FreqCoordConvDownConfig(BaseConfig): class FreqCoordConvDownConfig(BaseConfig):
"""Configuration for a FreqCoordConvDownBlock.""" """Configuration for a FreqCoordConvDownBlock."""
@ -329,7 +486,7 @@ class FreqCoordConvDownConfig(BaseConfig):
"""Padding size.""" """Padding size."""
class FreqCoordConvDownBlock(nn.Module): class FreqCoordConvDownBlock(Block):
"""Downsampling Conv Block incorporating Frequency Coordinate features. """Downsampling Conv Block incorporating Frequency Coordinate features.
This block implements a downsampling step (Conv2d + MaxPool2d) commonly This block implements a downsampling step (Conv2d + MaxPool2d) commonly
@ -368,6 +525,8 @@ class FreqCoordConvDownBlock(nn.Module):
pad_size: int = 1, pad_size: int = 1,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.coords = nn.Parameter( self.coords = nn.Parameter(
torch.linspace(-1, 1, input_height)[None, None, ..., None], torch.linspace(-1, 1, input_height)[None, None, ..., None],
@ -402,6 +561,24 @@ class FreqCoordConvDownBlock(nn.Module):
x = F.relu(self.batch_norm(x), inplace=True) x = F.relu(self.batch_norm(x), inplace=True)
return x return x
def get_output_height(self, input_height: int) -> int:
return input_height // 2
@block_registry.register(FreqCoordConvDownConfig)
@staticmethod
def from_config(
config: FreqCoordConvDownConfig,
input_channels: int,
input_height: int,
):
return FreqCoordConvDownBlock(
in_channels=input_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
)
class StandardConvDownConfig(BaseConfig): class StandardConvDownConfig(BaseConfig):
"""Configuration for a StandardConvDownBlock.""" """Configuration for a StandardConvDownBlock."""
@ -419,7 +596,7 @@ class StandardConvDownConfig(BaseConfig):
"""Padding size.""" """Padding size."""
class StandardConvDownBlock(nn.Module): class StandardConvDownBlock(Block):
"""Standard Downsampling Convolutional Block. """Standard Downsampling Convolutional Block.
A basic downsampling block consisting of a 2D convolution, followed by A basic downsampling block consisting of a 2D convolution, followed by
@ -447,6 +624,8 @@ class StandardConvDownBlock(nn.Module):
pad_size: int = 1, pad_size: int = 1,
): ):
super(StandardConvDownBlock, self).__init__() super(StandardConvDownBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
in_channels, in_channels,
out_channels, out_channels,
@ -472,6 +651,23 @@ class StandardConvDownBlock(nn.Module):
x = F.max_pool2d(self.conv(x), 2, 2) x = F.max_pool2d(self.conv(x), 2, 2)
return F.relu(self.batch_norm(x), inplace=True) return F.relu(self.batch_norm(x), inplace=True)
def get_output_height(self, input_height: int) -> int:
return input_height // 2
@block_registry.register(StandardConvDownConfig)
@staticmethod
def from_config(
config: StandardConvDownConfig,
input_channels: int,
input_height: int,
):
return StandardConvDownBlock(
in_channels=input_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
)
class FreqCoordConvUpConfig(BaseConfig): class FreqCoordConvUpConfig(BaseConfig):
"""Configuration for a FreqCoordConvUpBlock.""" """Configuration for a FreqCoordConvUpBlock."""
@ -488,8 +684,14 @@ class FreqCoordConvUpConfig(BaseConfig):
pad_size: int = 1 pad_size: int = 1
"""Padding size.""" """Padding size."""
up_mode: str = "bilinear"
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
class FreqCoordConvUpBlock(nn.Module): up_scale: tuple[int, int] = (2, 2)
"""Scaling factor for height and width during upsampling."""
class FreqCoordConvUpBlock(Block):
"""Upsampling Conv Block incorporating Frequency Coordinate features. """Upsampling Conv Block incorporating Frequency Coordinate features.
This block implements an upsampling step followed by a convolution, This block implements an upsampling step followed by a convolution,
@ -504,22 +706,22 @@ class FreqCoordConvUpBlock(nn.Module):
Parameters Parameters
---------- ----------
in_channels : int in_channels
Number of channels in the input tensor (before upsampling). Number of channels in the input tensor (before upsampling).
out_channels : int out_channels
Number of output channels after the convolution. Number of output channels after the convolution.
input_height : int input_height
Height (H dimension, frequency bins) of the tensor *before* upsampling. Height (H dimension, frequency bins) of the tensor *before* upsampling.
Used to calculate the height for coordinate feature generation after Used to calculate the height for coordinate feature generation after
upsampling. upsampling.
kernel_size : int, default=3 kernel_size
Size of the square convolutional kernel. Size of the square convolutional kernel.
pad_size : int, default=1 pad_size
Padding added before convolution. Padding added before convolution.
up_mode : str, default="bilinear" up_mode
Interpolation mode for upsampling (e.g., "nearest", "bilinear", Interpolation mode for upsampling (e.g., "nearest", "bilinear",
"bicubic"). "bicubic").
up_scale : Tuple[int, int], default=(2, 2) up_scale
Scaling factor for height and width during upsampling Scaling factor for height and width during upsampling
(typically (2, 2)). (typically (2, 2)).
""" """
@ -532,9 +734,11 @@ class FreqCoordConvUpBlock(nn.Module):
kernel_size: int = 3, kernel_size: int = 3,
pad_size: int = 1, pad_size: int = 1,
up_mode: str = "bilinear", up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2), up_scale: tuple[int, int] = (2, 2),
): ):
super().__init__() super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.up_scale = up_scale self.up_scale = up_scale
self.up_mode = up_mode self.up_mode = up_mode
@ -581,6 +785,26 @@ class FreqCoordConvUpBlock(nn.Module):
op = F.relu(self.batch_norm(op), inplace=True) op = F.relu(self.batch_norm(op), inplace=True)
return op return op
def get_output_height(self, input_height: int) -> int:
return input_height * 2
@block_registry.register(FreqCoordConvUpConfig)
@staticmethod
def from_config(
config: FreqCoordConvUpConfig,
input_channels: int,
input_height: int,
):
return FreqCoordConvUpBlock(
in_channels=input_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
up_mode=config.up_mode,
up_scale=config.up_scale,
)
class StandardConvUpConfig(BaseConfig): class StandardConvUpConfig(BaseConfig):
"""Configuration for a StandardConvUpBlock.""" """Configuration for a StandardConvUpBlock."""
@ -597,8 +821,14 @@ class StandardConvUpConfig(BaseConfig):
pad_size: int = 1 pad_size: int = 1
"""Padding size.""" """Padding size."""
up_mode: str = "bilinear"
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
class StandardConvUpBlock(nn.Module): up_scale: tuple[int, int] = (2, 2)
"""Scaling factor for height and width during upsampling."""
class StandardConvUpBlock(Block):
"""Standard Upsampling Convolutional Block. """Standard Upsampling Convolutional Block.
A basic upsampling block used in CNN decoders. It first upsamples the input A basic upsampling block used in CNN decoders. It first upsamples the input
@ -609,17 +839,17 @@ class StandardConvUpBlock(nn.Module):
Parameters Parameters
---------- ----------
in_channels : int in_channels
Number of channels in the input tensor (before upsampling). Number of channels in the input tensor (before upsampling).
out_channels : int out_channels
Number of output channels after the convolution. Number of output channels after the convolution.
kernel_size : int, default=3 kernel_size
Size of the square convolutional kernel. Size of the square convolutional kernel.
pad_size : int, default=1 pad_size
Padding added before convolution. Padding added before convolution.
up_mode : str, default="bilinear" up_mode
Interpolation mode for upsampling (e.g., "nearest", "bilinear"). Interpolation mode for upsampling (e.g., "nearest", "bilinear").
up_scale : Tuple[int, int], default=(2, 2) up_scale
Scaling factor for height and width during upsampling. Scaling factor for height and width during upsampling.
""" """
@ -630,9 +860,11 @@ class StandardConvUpBlock(nn.Module):
kernel_size: int = 3, kernel_size: int = 3,
pad_size: int = 1, pad_size: int = 1,
up_mode: str = "bilinear", up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2), up_scale: tuple[int, int] = (2, 2),
): ):
super(StandardConvUpBlock, self).__init__() super(StandardConvUpBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.up_scale = up_scale self.up_scale = up_scale
self.up_mode = up_mode self.up_mode = up_mode
self.conv = nn.Conv2d( self.conv = nn.Conv2d(
@ -669,155 +901,195 @@ class StandardConvUpBlock(nn.Module):
op = F.relu(self.batch_norm(op), inplace=True) op = F.relu(self.batch_norm(op), inplace=True)
return op return op
def get_output_height(self, input_height: int) -> int:
return input_height * 2
@block_registry.register(StandardConvUpConfig)
@staticmethod
def from_config(
config: StandardConvUpConfig,
input_channels: int,
input_height: int,
):
return StandardConvUpBlock(
in_channels=input_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
up_mode=config.up_mode,
up_scale=config.up_scale,
)
class LayerGroupConfig(BaseConfig):
"""Configuration for a ``LayerGroup`` — a sequential chain of blocks.
Use this when a single encoder or decoder stage needs more than one
block. The blocks are executed in the order they appear in ``layers``,
with channel counts and heights propagated automatically.
Attributes
----------
name : str
Discriminator field; always ``"LayerGroup"``.
layers : List[LayerConfig]
Ordered list of block configurations to chain together.
"""
name: Literal["LayerGroup"] = "LayerGroup"
layers: list["LayerConfig"]
LayerConfig = Annotated[ LayerConfig = Annotated[
Union[ ConvConfig
ConvConfig, | FreqCoordConvDownConfig
FreqCoordConvDownConfig, | StandardConvDownConfig
StandardConvDownConfig, | FreqCoordConvUpConfig
FreqCoordConvUpConfig, | StandardConvUpConfig
StandardConvUpConfig, | SelfAttentionConfig
SelfAttentionConfig, | LayerGroupConfig,
"LayerGroupConfig",
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Type alias for the discriminated union of block configuration models.""" """Type alias for the discriminated union of block configuration models."""
class LayerGroupConfig(BaseConfig): class LayerGroup(nn.Module):
name: Literal["LayerGroup"] = "LayerGroup" """Sequential chain of blocks that acts as a single composite block.
layers: List[LayerConfig]
Wraps multiple ``Block`` instances in an ``nn.Sequential`` container,
exposing the same ``in_channels``, ``out_channels``, and
``get_output_height`` interface as a regular ``Block`` so it can be
used transparently wherever a single block is expected.
Instances are typically constructed by ``build_layer`` when given a
``LayerGroupConfig``; you rarely need to create them directly.
Parameters
----------
layers : list[Block]
Pre-built block instances to chain, in execution order.
input_height : int
Height of the tensor entering the first block.
input_channels : int
Number of channels in the tensor entering the first block.
Attributes
----------
in_channels : int
Number of input channels (taken from the first block).
out_channels : int
Number of output channels (taken from the last block).
layers : nn.Sequential
The wrapped sequence of block modules.
"""
def __init__(
self,
layers: list[Block],
input_height: int,
input_channels: int,
):
super().__init__()
self.in_channels = input_channels
self.out_channels = (
layers[-1].out_channels if layers else input_channels
)
self.layers = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Pass input through all blocks in sequence.
Parameters
----------
x : torch.Tensor
Input feature map, shape ``(B, C_in, H, W)``.
Returns
-------
torch.Tensor
Output feature map after all blocks have been applied.
"""
return self.layers(x)
def get_output_height(self, input_height: int) -> int:
"""Compute the output height by propagating through all blocks.
Parameters
----------
input_height : int
Height of the input feature map.
Returns
-------
int
Height after all blocks in the group have been applied.
"""
for block in self.layers:
input_height = block.get_output_height(input_height) # type: ignore
return input_height
@block_registry.register(LayerGroupConfig)
@staticmethod
def from_config(
config: LayerGroupConfig,
input_channels: int,
input_height: int,
):
layers = []
for layer_config in config.layers:
layer = build_layer(
input_height=input_height,
in_channels=input_channels,
config=layer_config,
)
layers.append(layer)
input_height = layer.get_output_height(input_height)
input_channels = layer.out_channels
return LayerGroup(
layers=layers,
input_height=input_height,
input_channels=input_channels,
)
def build_layer_from_config( def build_layer(
input_height: int, input_height: int,
in_channels: int, in_channels: int,
config: LayerConfig, config: LayerConfig,
) -> Tuple[nn.Module, int, int]: ) -> Block:
"""Factory function to build a specific nn.Module block from its config. """Build a block from its configuration object.
Takes configuration object (one of the types included in the `LayerConfig` Looks up the block class corresponding to ``config.name`` in the
union) and instantiates the corresponding nn.Module block with the correct internal block registry and instantiates it with the given input
parameters derived from the config and the current pipeline state dimensions. This is the standard way to construct blocks when
(`input_height`, `in_channels`). assembling an encoder or decoder from a configuration file.
It uses the `name` field within the `config` object to determine
which block class to instantiate.
Parameters Parameters
---------- ----------
input_height : int input_height : int
Height (frequency bins) of the input tensor *to this layer*. Height (number of frequency bins) of the input tensor to this
block. Required for blocks whose kernel size depends on the input
height (e.g. ``VerticalConv``) and for coordinate-aware blocks.
in_channels : int in_channels : int
Number of channels in the input tensor *to this layer*. Number of channels in the input tensor to this block.
config : LayerConfig config : LayerConfig
A Pydantic configuration object for the desired block (e.g., an A configuration object for the desired block type. The ``name``
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified field selects the block class; remaining fields supply its
by its `name` field. parameters.
Returns Returns
------- -------
Tuple[nn.Module, int, int] Block
A tuple containing: An initialised block module ready to be added to an
- The instantiated `nn.Module` block. ``nn.Sequential`` or ``nn.ModuleList``.
- The number of output channels produced by the block.
- The calculated height of the output produced by the block.
Raises Raises
------ ------
NotImplementedError KeyError
If the `config.name` does not correspond to a known block type. If ``config.name`` does not correspond to a registered block type.
ValueError ValueError
If parameters derived from the config are invalid for the block. If the configuration parameters are invalid for the chosen block.
""" """
if config.name == "ConvBlock": return block_registry.build(config, in_channels, input_height)
return (
ConvBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height,
)
if config.name == "FreqCoordConvDown":
return (
FreqCoordConvDownBlock(
in_channels=in_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height // 2,
)
if config.name == "StandardConvDown":
return (
StandardConvDownBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height // 2,
)
if config.name == "FreqCoordConvUp":
return (
FreqCoordConvUpBlock(
in_channels=in_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height * 2,
)
if config.name == "StandardConvUp":
return (
StandardConvUpBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height * 2,
)
if config.name == "SelfAttention":
return (
SelfAttention(
in_channels=in_channels,
attention_channels=config.attention_channels,
temperature=config.temperature,
),
config.attention_channels,
input_height,
)
if config.name == "LayerGroup":
current_channels = in_channels
current_height = input_height
blocks = []
for block_config in config.layers:
block, current_channels, current_height = build_layer_from_config(
input_height=current_height,
in_channels=current_channels,
config=block_config,
)
blocks.append(block)
return nn.Sequential(*blocks), current_channels, current_height
raise NotImplementedError(f"Unknown block type {config.name}")

View File

@ -1,20 +1,24 @@
"""Defines the Bottleneck component of an Encoder-Decoder architecture. """Bottleneck component for encoder-decoder network architectures.
This module provides the configuration (`BottleneckConfig`) and The bottleneck sits between the encoder (downsampling path) and the decoder
`torch.nn.Module` implementations (`Bottleneck`, `BottleneckAttn`) for the (upsampling path) and processes the lowest-resolution, highest-channel feature
bottleneck layer(s) that typically connect the Encoder (downsampling path) and map produced by the encoder.
Decoder (upsampling path) in networks like U-Nets.
The bottleneck processes the lowest-resolution, highest-dimensionality feature This module provides:
map produced by the Encoder. This module offers a configurable option to include
a `SelfAttention` layer within the bottleneck, allowing the model to capture
global temporal context before features are passed to the Decoder.
A factory function `build_bottleneck` constructs the appropriate bottleneck - ``BottleneckConfig`` configuration dataclass describing the number of
module based on the provided configuration. internal channels and an optional sequence of additional layers (currently
only ``SelfAttention`` is supported).
- ``Bottleneck`` the ``torch.nn.Module`` implementation. It first applies a
``VerticalConv`` to collapse the frequency axis to a single bin, optionally
runs one or more additional layers (e.g. self-attention along the time axis),
then repeats the output along the height dimension to restore the original
frequency resolution before passing features to the decoder.
- ``build_bottleneck`` factory function that constructs a ``Bottleneck``
instance from a ``BottleneckConfig`` and the encoder's output dimensions.
""" """
from typing import Annotated, List, Optional, Union from typing import Annotated, List
import torch import torch
from pydantic import Field from pydantic import Field
@ -22,10 +26,12 @@ from torch import nn
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import ( from batdetect2.models.blocks import (
Block,
SelfAttentionConfig, SelfAttentionConfig,
VerticalConv, VerticalConv,
build_layer_from_config, build_layer,
) )
from batdetect2.models.types import BottleneckProtocol
__all__ = [ __all__ = [
"BottleneckConfig", "BottleneckConfig",
@ -34,43 +40,52 @@ __all__ = [
] ]
class Bottleneck(nn.Module): class Bottleneck(Block):
"""Base Bottleneck module for Encoder-Decoder architectures. """Bottleneck module for encoder-decoder architectures.
This implementation represents the simplest bottleneck structure Processes the lowest-resolution feature map that links the encoder and
considered, primarily consisting of a `VerticalConv` layer. This layer decoder. The sequence of operations is:
collapses the frequency dimension (height) to 1, summarizing information
across frequencies at each time step. The output is then repeated along the
height dimension to match the original bottleneck input height before being
passed to the decoder.
This base version does *not* include self-attention. 1. ``VerticalConv`` collapses the frequency axis (height) to a single
bin by applying a convolution whose kernel spans the full height.
2. Optional additional layers (e.g. ``SelfAttention``) applied while
the feature map has height 1, so they operate purely along the time
axis.
3. Height restoration the single-bin output is repeated along the
height axis to restore the original frequency resolution, producing
a tensor that the decoder can accept.
Parameters Parameters
---------- ----------
input_height : int input_height : int
Height (frequency bins) of the input tensor. Must be positive. Height (number of frequency bins) of the input tensor. Must be
positive.
in_channels : int in_channels : int
Number of channels in the input tensor from the encoder. Must be Number of channels in the input tensor from the encoder. Must be
positive. positive.
out_channels : int out_channels : int
Number of output channels. Must be positive. Number of output channels after the bottleneck. Must be positive.
bottleneck_channels : int, optional
Number of internal channels used by the ``VerticalConv`` layer.
Defaults to ``out_channels`` if not provided.
layers : List[torch.nn.Module], optional
Additional modules (e.g. ``SelfAttention``) to apply after the
``VerticalConv`` and before height restoration.
Attributes Attributes
---------- ----------
in_channels : int in_channels : int
Number of input channels accepted by the bottleneck. Number of input channels accepted by the bottleneck.
out_channels : int
Number of output channels produced by the bottleneck.
input_height : int input_height : int
Expected height of the input tensor. Expected height of the input tensor.
channels : int bottleneck_channels : int
Number of output channels. Number of channels used internally by the vertical convolution.
conv_vert : VerticalConv conv_vert : VerticalConv
The vertical convolution layer. The vertical convolution layer.
layers : nn.ModuleList
Raises Additional layers applied after the vertical convolution.
------
ValueError
If `input_height`, `in_channels`, or `out_channels` are not positive.
""" """
def __init__( def __init__(
@ -78,14 +93,31 @@ class Bottleneck(nn.Module):
input_height: int, input_height: int,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
bottleneck_channels: Optional[int] = None, bottleneck_channels: int | None = None,
layers: Optional[List[torch.nn.Module]] = None, layers: List[torch.nn.Module] | None = None,
) -> None: ) -> None:
"""Initialize the base Bottleneck layer.""" """Initialise the Bottleneck layer.
Parameters
----------
input_height : int
Height (number of frequency bins) of the input tensor.
in_channels : int
Number of channels in the input tensor.
out_channels : int
Number of channels in the output tensor.
bottleneck_channels : int, optional
Number of internal channels for the ``VerticalConv``. Defaults
to ``out_channels``.
layers : List[torch.nn.Module], optional
Additional modules applied after the ``VerticalConv``, such as
a ``SelfAttention`` block.
"""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.input_height = input_height self.input_height = input_height
self.out_channels = out_channels self.out_channels = out_channels
self.bottleneck_channels = ( self.bottleneck_channels = (
bottleneck_channels bottleneck_channels
if bottleneck_channels is not None if bottleneck_channels is not None
@ -100,23 +132,24 @@ class Bottleneck(nn.Module):
) )
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input features through the bottleneck. """Process the encoder's bottleneck features.
Applies vertical convolution and repeats the output height. Applies vertical convolution, optional additional layers, then
restores the height dimension by repetition.
Parameters Parameters
---------- ----------
x : torch.Tensor x : torch.Tensor
Input tensor from the encoder bottleneck, shape Input tensor from the encoder, shape
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`, ``(B, C_in, H_in, W)``. ``C_in`` must match
`H_in` must match `self.input_height`. ``self.in_channels`` and ``H_in`` must match
``self.input_height``.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Output tensor, shape `(B, C_out, H_in, W)`. Note that the height Output tensor with shape ``(B, C_out, H_in, W)``. The height
dimension `H_in` is restored via repetition after the vertical ``H_in`` is restored by repeating the single-bin result.
convolution.
""" """
x = self.conv_vert(x) x = self.conv_vert(x)
@ -127,37 +160,29 @@ class Bottleneck(nn.Module):
BottleneckLayerConfig = Annotated[ BottleneckLayerConfig = Annotated[
Union[SelfAttentionConfig,], SelfAttentionConfig,
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Type alias for the discriminated union of block configs usable in Decoder.""" """Type alias for the discriminated union of block configs usable in the Bottleneck."""
class BottleneckConfig(BaseConfig): class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s). """Configuration for the bottleneck component.
Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
Attributes Attributes
---------- ----------
channels : int channels : int
The number of output channels produced by the main convolutional layer Number of output channels produced by the bottleneck. This value
within the bottleneck. This often matches the number of channels coming is also used as the dimensionality of any optional layers (e.g.
from the last encoder stage, but can be different. Must be positive. self-attention). Must be positive.
This also defines the channel dimensions used within the optional layers : List[BottleneckLayerConfig]
`SelfAttention` layer. Ordered list of additional block configurations to apply after the
self_attention : bool initial ``VerticalConv``. Currently only ``SelfAttentionConfig`` is
If True, includes a `SelfAttention` layer operating on the time supported. Defaults to an empty list (no extra layers).
dimension after an initial `VerticalConv` layer within the bottleneck.
If False, only the initial `VerticalConv` (and height repetition) is
performed.
""" """
channels: int channels: int
layers: List[BottleneckLayerConfig] = Field( layers: List[BottleneckLayerConfig] = Field(default_factory=list)
default_factory=list,
)
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig( DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
@ -171,32 +196,39 @@ DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
def build_bottleneck( def build_bottleneck(
input_height: int, input_height: int,
in_channels: int, in_channels: int,
config: Optional[BottleneckConfig] = None, config: BottleneckConfig | None = None,
) -> nn.Module: ) -> BottleneckProtocol:
"""Factory function to build the Bottleneck module from configuration. """Build a ``Bottleneck`` module from configuration.
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based Constructs a ``Bottleneck`` instance whose internal channel count and
on the `config.self_attention` flag. optional extra layers (e.g. self-attention) are controlled by
``config``. If no configuration is provided, the default
``DEFAULT_BOTTLENECK_CONFIG`` is used, which includes a
``SelfAttention`` layer.
Parameters Parameters
---------- ----------
input_height : int input_height : int
Height (frequency bins) of the input tensor. Must be positive. Height (number of frequency bins) of the input tensor from the
encoder. Must be positive.
in_channels : int in_channels : int
Number of channels in the input tensor. Must be positive. Number of channels in the input tensor from the encoder. Must be
positive.
config : BottleneckConfig, optional config : BottleneckConfig, optional
Configuration object specifying the bottleneck channels and whether Configuration specifying the output channel count and any
to use self-attention. Uses `DEFAULT_BOTTLENECK_CONFIG` if None. additional layers. Uses ``DEFAULT_BOTTLENECK_CONFIG`` if ``None``.
Returns Returns
------- -------
nn.Module BottleneckProtocol
An initialized bottleneck module (`Bottleneck` or `BottleneckAttn`). An initialised ``Bottleneck`` module.
Raises Raises
------ ------
ValueError AssertionError
If `input_height` or `in_channels` are not positive. If any configured layer changes the height of the feature map
(bottleneck layers must preserve height so that it can be restored
by repetition).
""" """
config = config or DEFAULT_BOTTLENECK_CONFIG config = config or DEFAULT_BOTTLENECK_CONFIG
@ -206,11 +238,13 @@ def build_bottleneck(
layers = [] layers = []
for layer_config in config.layers: for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config( layer = build_layer(
input_height=current_height, input_height=current_height,
in_channels=current_channels, in_channels=current_channels,
config=layer_config, config=layer_config,
) )
current_height = layer.get_output_height(current_height)
current_channels = layer.out_channels
assert current_height == input_height, ( assert current_height == input_height, (
"Bottleneck layers should not change the spectrogram height" "Bottleneck layers should not change the spectrogram height"
) )

View File

@ -1,98 +0,0 @@
from typing import Optional
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
)
__all__ = [
"BackboneConfig",
"load_backbone_config",
]
class BackboneConfig(BaseConfig):
"""Configuration for the Encoder-Decoder Backbone network.
Aggregates configurations for the encoder, bottleneck, and decoder
components, along with defining the input and final output dimensions
for the complete backbone.
Attributes
----------
input_height : int, default=128
Expected height (frequency bins) of the input spectrograms to the
backbone. Must be positive.
in_channels : int, default=1
Expected number of channels in the input spectrograms (e.g., 1 for
mono). Must be positive.
encoder : EncoderConfig, optional
Configuration for the encoder. If None or omitted,
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
encoder module) will be used.
bottleneck : BottleneckConfig, optional
Configuration for the bottleneck layer connecting encoder and decoder.
If None or omitted, the default bottleneck configuration will be used.
decoder : DecoderConfig, optional
Configuration for the decoder. If None or omitted,
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
decoder module) will be used.
out_channels : int, default=32
Desired number of channels in the final feature map output by the
backbone. Must be positive.
"""
input_height: int = 128
in_channels: int = 1
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32
def load_backbone_config(
path: data.PathLike,
field: Optional[str] = None,
) -> BackboneConfig:
"""Load the backbone configuration from a file.
Reads a configuration file (YAML) and validates it against the
`BackboneConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
backbone configuration (e.g., "model.backbone"). If None, the entire
file content is used.
Returns
-------
BackboneConfig
The loaded and validated backbone configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to `BackboneConfig`.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=BackboneConfig, field=field)

View File

@ -1,24 +1,24 @@
"""Constructs the Decoder part of an Encoder-Decoder neural network. """Decoder (upsampling path) for the BatDetect2 backbone.
This module defines the configuration structure (`DecoderConfig`) for the layer This module defines ``DecoderConfig`` and the ``Decoder`` ``nn.Module``,
sequence and provides the `Decoder` class (an `nn.Module`) along with a factory together with the ``build_decoder`` factory function.
function (`build_decoder`). Decoders typically form the upsampling path in
architectures like U-Nets, taking bottleneck features
(usually from an `Encoder`) and skip connections to reconstruct
higher-resolution feature maps.
The decoder is built dynamically by stacking neural network blocks based on a In a U-Net-style network the decoder progressively restores the spatial
list of configuration objects provided in `DecoderConfig.layers`. Each config resolution of the feature map back towards the input resolution. At each
object specifies the type of block (e.g., standard convolution, stage it combines the upsampled features with the corresponding skip-connection
coordinate-feature convolution with upsampling) and its parameters. This allows tensor from the encoder (the residual) by element-wise addition before passing
flexible definition of decoder architectures via configuration files. the result to the upsampling block.
The `Decoder`'s `forward` method is designed to accept skip connection tensors The decoder is fully configurable: the type, number, and parameters of the
(`residuals`) from the encoder, merging them with the upsampled feature maps upsampling blocks are described by a ``DecoderConfig`` object containing an
at each stage. ordered list of block configuration objects (see ``batdetect2.models.blocks``
for available block types).
A default configuration ``DEFAULT_DECODER_CONFIG`` is provided and used by
``build_decoder`` when no explicit configuration is supplied.
""" """
from typing import Annotated, List, Optional, Union from typing import Annotated, List
import torch import torch
from pydantic import Field from pydantic import Field
@ -30,7 +30,7 @@ from batdetect2.models.blocks import (
FreqCoordConvUpConfig, FreqCoordConvUpConfig,
LayerGroupConfig, LayerGroupConfig,
StandardConvUpConfig, StandardConvUpConfig,
build_layer_from_config, build_layer,
) )
__all__ = [ __all__ = [
@ -41,63 +41,57 @@ __all__ = [
] ]
DecoderLayerConfig = Annotated[ DecoderLayerConfig = Annotated[
Union[ ConvConfig
ConvConfig, | FreqCoordConvUpConfig
FreqCoordConvUpConfig, | StandardConvUpConfig
StandardConvUpConfig, | LayerGroupConfig,
LayerGroupConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Type alias for the discriminated union of block configs usable in Decoder.""" """Type alias for the discriminated union of block configs usable in Decoder."""
class DecoderConfig(BaseConfig): class DecoderConfig(BaseConfig):
"""Configuration for the sequence of layers in the Decoder module. """Configuration for the sequential ``Decoder`` module.
Defines the types and parameters of the neural network blocks that
constitute the decoder's upsampling path.
Attributes Attributes
---------- ----------
layers : List[DecoderLayerConfig] layers : List[DecoderLayerConfig]
An ordered list of configuration objects, each defining one layer or Ordered list of block configuration objects defining the decoder's
block in the decoder sequence. Each item must be a valid block upsampling stages (from deepest to shallowest). Each entry
config including a `name` field and necessary parameters like specifies the block type (via its ``name`` field) and any
`out_channels`. Input channels for each layer are inferred sequentially. block-specific parameters such as ``out_channels``. Input channels
The list must contain at least one layer. for each block are inferred automatically from the output of the
previous block. Must contain at least one entry.
""" """
layers: List[DecoderLayerConfig] = Field(min_length=1) layers: List[DecoderLayerConfig] = Field(min_length=1)
class Decoder(nn.Module): class Decoder(nn.Module):
"""Sequential Decoder module composed of configurable upsampling layers. """Sequential decoder module composed of configurable upsampling layers.
Constructs the upsampling path of an encoder-decoder network by stacking Executes a series of upsampling blocks in order, adding the
multiple blocks (e.g., `StandardConvUpBlock`, `FreqCoordConvUpBlock`) corresponding encoder skip-connection tensor (residual) to the feature
based on a list of layer modules provided during initialization (typically map before each block. The residuals are consumed in reverse order (from
created by the `build_decoder` factory function). deepest encoder layer to shallowest) to match the spatial resolutions at
each decoder stage.
The `forward` method is designed to integrate skip connection tensors Instances are typically created by ``build_decoder``.
(`residuals`) from the corresponding encoder stages, by adding them
element-wise to the input of each decoder layer before processing.
Attributes Attributes
---------- ----------
in_channels : int in_channels : int
Number of channels expected in the input tensor. Number of channels expected in the input tensor (bottleneck output).
out_channels : int out_channels : int
Number of channels in the final output tensor produced by the last Number of channels in the final output feature map.
layer.
input_height : int input_height : int
Height (frequency bins) expected in the input tensor. Height (frequency bins) of the input tensor.
output_height : int output_height : int
Height (frequency bins) expected in the output tensor. Height (frequency bins) of the output tensor.
layers : nn.ModuleList layers : nn.ModuleList
The sequence of instantiated upscaling layer modules. Sequence of instantiated upsampling block modules.
depth : int depth : int
The number of upscaling layers (depth) in the decoder. Number of upsampling layers.
""" """
def __init__( def __init__(
@ -108,23 +102,24 @@ class Decoder(nn.Module):
output_height: int, output_height: int,
layers: List[nn.Module], layers: List[nn.Module],
): ):
"""Initialize the Decoder module. """Initialise the Decoder module.
Note: This constructor is typically called internally by the This constructor is typically called by the ``build_decoder``
`build_decoder` factory function. factory function.
Parameters Parameters
---------- ----------
in_channels : int
Number of channels in the input tensor (bottleneck output).
out_channels : int out_channels : int
Number of channels produced by the final layer. Number of channels produced by the final layer.
input_height : int input_height : int
Expected height of the input tensor (bottleneck). Height of the input tensor (bottleneck output height).
in_channels : int output_height : int
Expected number of channels in the input tensor (bottleneck). Height of the output tensor after all layers have been applied.
layers : List[nn.Module] layers : List[nn.Module]
A list of pre-instantiated upscaling layer modules (e.g., Pre-built upsampling block modules in execution order (deepest
`StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired stage first).
sequence (from bottleneck towards output resolution).
""" """
super().__init__() super().__init__()
@ -142,43 +137,35 @@ class Decoder(nn.Module):
x: torch.Tensor, x: torch.Tensor,
residuals: List[torch.Tensor], residuals: List[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
"""Pass input through decoder layers, incorporating skip connections. """Pass input through all decoder layers, incorporating skip connections.
Processes the input tensor `x` sequentially through the upscaling At each stage the corresponding residual tensor is added
layers. At each stage, the corresponding skip connection tensor from element-wise to ``x`` before it is passed to the upsampling block.
the `residuals` list is added element-wise to the input before passing Residuals are consumed in reverse order the last element of
it to the upscaling block. ``residuals`` (the output of the shallowest encoder layer) is added
at the first decoder stage, and the first element (output of the
deepest encoder layer) is added at the last decoder stage.
Parameters Parameters
---------- ----------
x : torch.Tensor x : torch.Tensor
Input tensor from the previous stage (e.g., encoder bottleneck). Bottleneck feature map, shape ``(B, C_in, H_in, W)``.
Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
`self.in_channels`.
residuals : List[torch.Tensor] residuals : List[torch.Tensor]
List containing the skip connection tensors from the corresponding Skip-connection tensors from the encoder, ordered from shallowest
encoder stages. Should be ordered from the deepest encoder layer (index 0) to deepest (index -1). Must contain exactly
output (lowest resolution) to the shallowest (highest resolution ``self.depth`` tensors. Each tensor must have the same spatial
near input). The number of tensors in this list must match the dimensions and channel count as ``x`` at the corresponding
number of decoder layers (`self.depth`). Each residual tensor's decoder stage.
channel count must be compatible with the input tensor `x` for
element-wise addition (or concatenation if the blocks were designed
for it).
Returns Returns
------- -------
torch.Tensor torch.Tensor
The final decoded feature map tensor produced by the last layer. Decoded feature map, shape ``(B, C_out, H_out, W)``.
Shape `(B, C_out, H_out, W_out)`.
Raises Raises
------ ------
ValueError ValueError
If the number of `residuals` provided does not match the decoder If the number of ``residuals`` does not equal ``self.depth``.
depth.
RuntimeError
If shapes mismatch during skip connection addition or layer
processing.
""" """
if len(residuals) != len(self.layers): if len(residuals) != len(self.layers):
raise ValueError( raise ValueError(
@ -187,7 +174,7 @@ class Decoder(nn.Module):
f"but got {len(residuals)}." f"but got {len(residuals)}."
) )
for layer, res in zip(self.layers, residuals[::-1]): for layer, res in zip(self.layers, residuals[::-1], strict=False):
x = layer(x + res) x = layer(x + res)
return x return x
@ -205,53 +192,55 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
), ),
], ],
) )
"""A default configuration for the Decoder's *layer sequence*. """Default decoder configuration used in standard BatDetect2 models.
Specifies an architecture often used in BatDetect2, consisting of three Mirrors ``DEFAULT_ENCODER_CONFIG`` in reverse. Assumes the bottleneck
frequency coordinate-aware upsampling blocks followed by a standard output has 256 channels and height 16, and produces:
convolutional block.
- Stage 1 (``FreqCoordConvUp``): 64 channels, height 32.
- Stage 2 (``FreqCoordConvUp``): 32 channels, height 64.
- Stage 3 (``LayerGroup``):
- ``FreqCoordConvUp``: 32 channels, height 128.
- ``ConvBlock``: 32 channels, height 128 (final feature map).
""" """
def build_decoder( def build_decoder(
in_channels: int, in_channels: int,
input_height: int, input_height: int,
config: Optional[DecoderConfig] = None, config: DecoderConfig | None = None,
) -> Decoder: ) -> Decoder:
"""Factory function to build a Decoder instance from configuration. """Build a ``Decoder`` from configuration.
Constructs a sequential `Decoder` module based on the layer sequence Constructs a sequential ``Decoder`` by iterating over the block
defined in a `DecoderConfig` object and the provided input dimensions configurations in ``config.layers``, building each block with
(bottleneck channels and height). If no config is provided, uses the ``build_layer``, and tracking the channel count and feature-map height
default layer sequence from `DEFAULT_DECODER_CONFIG`. as they change through the sequence.
It iteratively builds the layers using the unified `build_layer_from_config`
factory (from `.blocks`), tracking the changing number of channels and
feature map height required for each subsequent layer.
Parameters Parameters
---------- ----------
in_channels : int in_channels : int
The number of channels in the input tensor to the decoder. Must be > 0. Number of channels in the input tensor (bottleneck output). Must
be positive.
input_height : int input_height : int
The height (frequency bins) of the input tensor to the decoder. Must be Height (number of frequency bins) of the input tensor. Must be
> 0. positive.
config : DecoderConfig, optional config : DecoderConfig, optional
The configuration object detailing the sequence of layers and their Configuration specifying the layer sequence. Defaults to
parameters. If None, `DEFAULT_DECODER_CONFIG` is used. ``DEFAULT_DECODER_CONFIG`` if not provided.
Returns Returns
------- -------
Decoder Decoder
An initialized `Decoder` module. An initialised ``Decoder`` module.
Raises Raises
------ ------
ValueError ValueError
If `in_channels` or `input_height` are not positive, or if the layer If ``in_channels`` or ``input_height`` are not positive.
configuration is invalid (e.g., empty list, unknown `name`). KeyError
NotImplementedError If a layer configuration specifies an unknown block type.
If `build_layer_from_config` encounters an unknown `name`.
""" """
config = config or DEFAULT_DECODER_CONFIG config = config or DEFAULT_DECODER_CONFIG
@ -261,11 +250,13 @@ def build_decoder(
layers = [] layers = []
for layer_config in config.layers: for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config( layer = build_layer(
in_channels=current_channels, in_channels=current_channels,
input_height=current_height, input_height=current_height,
config=layer_config, config=layer_config,
) )
current_height = layer.get_output_height(current_height)
current_channels = layer.out_channels
layers.append(layer) layers.append(layer)
return Decoder( return Decoder(

View File

@ -1,27 +1,32 @@
"""Assembles the complete BatDetect2 Detection Model. """Assembles the complete BatDetect2 detection model.
This module defines the concrete `Detector` class, which implements the This module defines the ``Detector`` class, which combines a backbone
`DetectionModel` interface defined in `.types`. It combines a feature feature extractor with prediction heads for detection, classification, and
extraction backbone with specific prediction heads to create the end-to-end bounding-box size regression.
neural network used for detecting bat calls, predicting their size, and
classifying them.
The primary components are: Components
- `Detector`: The `torch.nn.Module` subclass representing the complete model. ----------
- ``Detector`` the ``torch.nn.Module`` that wires together a backbone
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
produce a ``ModelOutput`` tuple from an input spectrogram.
- ``build_detector`` factory function that builds a ready-to-use
``Detector`` from a backbone configuration and a target class count.
This module focuses purely on the neural network architecture definition. The Note that ``Detector`` operates purely on spectrogram tensors; raw audio
logic for preprocessing inputs and postprocessing/decoding outputs resides in preprocessing and output postprocessing are handled by
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively. ``batdetect2.preprocess`` and ``batdetect2.postprocess`` respectively.
""" """
from typing import Optional
import torch import torch
from loguru import logger from loguru import logger
from batdetect2.models.backbones import BackboneConfig, build_backbone from batdetect2.models.backbones import (
BackboneConfig,
UNetBackboneConfig,
build_backbone,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
__all__ = [ __all__ = [
"Detector", "Detector",
@ -30,25 +35,30 @@ __all__ = [
class Detector(DetectionModel): class Detector(DetectionModel):
"""Concrete implementation of the BatDetect2 Detection Model. """Complete BatDetect2 detection and classification model.
Assembles a complete detection and classification model by combining a Combines a backbone feature extractor with two prediction heads:
feature extraction backbone network with specific prediction heads for
detection probability, bounding box size regression, and class - ``ClassifierHead``: predicts per-class probabilities at each
probabilities. timefrequency location.
- ``BBoxHead``: predicts call duration and bandwidth at each location.
The detection probability map is derived from the class probabilities by
summing across the class dimension (i.e. the probability that *any* class
is present), rather than from a separate detection head.
Instances are typically created via ``build_detector``.
Attributes Attributes
---------- ----------
backbone : BackboneModel backbone : BackboneModel
The feature extraction backbone network module. The feature extraction backbone.
num_classes : int num_classes : int
The number of specific target classes the model predicts (derived from Number of target classes (inferred from the classifier head).
the `classifier_head`).
classifier_head : ClassifierHead classifier_head : ClassifierHead
The prediction head responsible for generating class probabilities. Produces per-class probability maps from backbone features.
bbox_head : BBoxHead bbox_head : BBoxHead
The prediction head responsible for generating bounding box size Produces duration and bandwidth predictions from backbone features.
predictions.
""" """
backbone: BackboneModel backbone: BackboneModel
@ -59,26 +69,21 @@ class Detector(DetectionModel):
classifier_head: ClassifierHead, classifier_head: ClassifierHead,
bbox_head: BBoxHead, bbox_head: BBoxHead,
): ):
"""Initialize the Detector model. """Initialise the Detector model.
Note: Instances are typically created using the `build_detector` This constructor is typically called by the ``build_detector``
factory function. factory function.
Parameters Parameters
---------- ----------
backbone : BackboneModel backbone : BackboneModel
An initialized feature extraction backbone module (e.g., built by An initialised backbone module (e.g. built by
`build_backbone` from the `.backbone` module). ``build_backbone``).
classifier_head : ClassifierHead classifier_head : ClassifierHead
An initialized classification head module. The number of classes An initialised classification head. The ``num_classes``
is inferred from this head. attribute is read from this head.
bbox_head : BBoxHead bbox_head : BBoxHead
An initialized bounding box size prediction head module. An initialised bounding-box size prediction head.
Raises
------
TypeError
If the provided modules are not of the expected types.
""" """
super().__init__() super().__init__()
@ -88,31 +93,34 @@ class Detector(DetectionModel):
self.bbox_head = bbox_head self.bbox_head = bbox_head
def forward(self, spec: torch.Tensor) -> ModelOutput: def forward(self, spec: torch.Tensor) -> ModelOutput:
"""Perform the forward pass of the complete detection model. """Run the complete detection model on an input spectrogram.
Processes the input spectrogram through the backbone to extract Passes the spectrogram through the backbone to produce a feature
features, then passes these features through the separate prediction map, then applies the classifier and bounding-box heads. The
heads to generate detection probabilities, class probabilities, and detection probability map is derived by summing the per-class
size predictions. probability maps across the class dimension; no separate detection
head is used.
Parameters Parameters
---------- ----------
spec : torch.Tensor spec : torch.Tensor
Input spectrogram tensor, typically with shape Input spectrogram tensor, shape
`(batch_size, input_channels, frequency_bins, time_bins)`. The ``(batch_size, channels, frequency_bins, time_bins)``.
shape must be compatible with the `self.backbone` input
requirements.
Returns Returns
------- -------
ModelOutput ModelOutput
A NamedTuple containing the four output tensors: A named tuple with four fields:
- `detection_probs`: Detection probability heatmap `(B, 1, H, W)`.
- `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`. - ``detection_probs`` ``(B, 1, H, W)`` probability that a
- `class_probs`: Class probabilities (excluding background) call of any class is present at each location. Derived by
`(B, num_classes, H, W)`. summing ``class_probs`` over the class dimension.
- `features`: Output feature map from the backbone - ``size_preds`` ``(B, 2, H, W)`` scaled duration (channel
`(B, C_out, H, W)`. 0) and bandwidth (channel 1) predictions at each location.
- ``class_probs`` ``(B, num_classes, H, W)`` per-class
probabilities at each location.
- ``features`` ``(B, C_out, H, W)`` raw backbone feature
map.
""" """
features = self.backbone(spec) features = self.backbone(spec)
classification = self.classifier_head(features) classification = self.classifier_head(features)
@ -127,40 +135,46 @@ class Detector(DetectionModel):
def build_detector( def build_detector(
num_classes: int, config: Optional[BackboneConfig] = None num_classes: int,
config: BackboneConfig | None = None,
backbone: BackboneModel | None = None,
) -> DetectionModel: ) -> DetectionModel:
"""Build the complete BatDetect2 detection model. """Build a complete BatDetect2 detection model.
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
and a ``BBoxHead`` sized to the backbone's output channel count, and
returns them wrapped in a ``Detector``.
Parameters Parameters
---------- ----------
num_classes : int num_classes : int
The number of specific target classes the model should predict Number of target bat species or call types to predict. Must be
(required for the `ClassifierHead`). Must be positive. positive.
config : BackboneConfig, optional config : BackboneConfig, optional
Configuration object specifying the architecture of the backbone Backbone architecture configuration. Defaults to
(encoder, bottleneck, decoder). If None, default configurations defined ``UNetBackboneConfig()`` (the standard BatDetect2 architecture) if
within the respective builder functions (`build_encoder`, etc.) will be not provided.
used to construct a default backbone architecture.
Returns Returns
------- -------
DetectionModel DetectionModel
An initialized `Detector` model instance. An initialised ``Detector`` instance ready for training or
inference.
Raises Raises
------ ------
ValueError ValueError
If `num_classes` is not positive, or if errors occur during the If ``num_classes`` is not positive, or if the backbone
construction of the backbone or detector components (e.g., incompatible configuration is invalid.
configurations, invalid parameters).
""" """
config = config or BackboneConfig() if backbone is None:
config = config or UNetBackboneConfig()
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(), # type: ignore
)
backbone = build_backbone(config=config)
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(),
)
backbone = build_backbone(config=config)
classifier_head = ClassifierHead( classifier_head = ClassifierHead(
num_classes=num_classes, num_classes=num_classes,
in_channels=backbone.out_channels, in_channels=backbone.out_channels,

View File

@ -1,26 +1,27 @@
"""Constructs the Encoder part of a configurable neural network backbone. """Encoder (downsampling path) for the BatDetect2 backbone.
This module defines the configuration structure (`EncoderConfig`) and provides This module defines ``EncoderConfig`` and the ``Encoder`` ``nn.Module``,
the `Encoder` class (an `nn.Module`) along with a factory function together with the ``build_encoder`` factory function.
(`build_encoder`) to create sequential encoders. Encoders typically form the
downsampling path in architectures like U-Nets, processing input feature maps
(like spectrograms) to produce lower-resolution, higher-dimensionality feature
representations (bottleneck features).
The encoder is built dynamically by stacking neural network blocks based on a In a U-Net-style network the encoder progressively reduces the spatial
list of configuration objects provided in `EncoderConfig.layers`. Each resolution of the spectrogram whilst increasing the number of feature
configuration object specifies the type of block (e.g., standard convolution, channels. Each layer in the encoder produces a feature map that is stored
coordinate-feature convolution with downsampling) and its parameters for use as a skip connection in the corresponding decoder layer.
(e.g., output channels). This allows for flexible definition of encoder
architectures via configuration files.
The `Encoder`'s `forward` method returns outputs from all intermediate layers, The encoder is fully configurable: the type, number, and parameters of the
suitable for skip connections, while the `encode` method returns only the final downsampling blocks are described by an ``EncoderConfig`` object containing
bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also an ordered list of block configuration objects (see ``batdetect2.models.blocks``
provided. for available block types).
``Encoder.forward`` returns the outputs of *all* encoder layers as a list,
so that skip connections are available to the decoder.
``Encoder.encode`` returns only the final output (the input to the bottleneck).
A default configuration ``DEFAULT_ENCODER_CONFIG`` is provided and used by
``build_encoder`` when no explicit configuration is supplied.
""" """
from typing import Annotated, List, Optional, Union from typing import Annotated, List
import torch import torch
from pydantic import Field from pydantic import Field
@ -32,7 +33,7 @@ from batdetect2.models.blocks import (
FreqCoordConvDownConfig, FreqCoordConvDownConfig,
LayerGroupConfig, LayerGroupConfig,
StandardConvDownConfig, StandardConvDownConfig,
build_layer_from_config, build_layer,
) )
__all__ = [ __all__ = [
@ -43,47 +44,42 @@ __all__ = [
] ]
EncoderLayerConfig = Annotated[ EncoderLayerConfig = Annotated[
Union[ ConvConfig
ConvConfig, | FreqCoordConvDownConfig
FreqCoordConvDownConfig, | StandardConvDownConfig
StandardConvDownConfig, | LayerGroupConfig,
LayerGroupConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""Type alias for the discriminated union of block configs usable in Encoder.""" """Type alias for the discriminated union of block configs usable in Encoder."""
class EncoderConfig(BaseConfig): class EncoderConfig(BaseConfig):
"""Configuration for building the sequential Encoder module. """Configuration for the sequential ``Encoder`` module.
Defines the sequence of neural network blocks that constitute the encoder
(downsampling path).
Attributes Attributes
---------- ----------
layers : List[EncoderLayerConfig] layers : List[EncoderLayerConfig]
An ordered list of configuration objects, each defining one layer or Ordered list of block configuration objects defining the encoder's
block in the encoder sequence. Each item must be a valid block config downsampling stages. Each entry specifies the block type (via its
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`, ``name`` field) and any block-specific parameters such as
`StandardConvDownConfig`) including a `name` field and necessary ``out_channels``. Input channels for each block are inferred
parameters like `out_channels`. Input channels for each layer are automatically from the output of the previous block. Must contain
inferred sequentially. The list must contain at least one layer. at least one entry.
""" """
layers: List[EncoderLayerConfig] = Field(min_length=1) layers: List[EncoderLayerConfig] = Field(min_length=1)
class Encoder(nn.Module): class Encoder(nn.Module):
"""Sequential Encoder module composed of configurable downscaling layers. """Sequential encoder module composed of configurable downsampling layers.
Constructs the downsampling path of an encoder-decoder network by stacking Executes a series of downsampling blocks in order, storing the output of
multiple downscaling blocks. each block so that it can be passed as a skip connection to the
corresponding decoder layer.
The `forward` method executes the sequence and returns the output feature ``forward`` returns the outputs of *all* layers (useful when skip
map from *each* downscaling stage, facilitating the implementation of skip connections are needed). ``encode`` returns only the final output
connections in U-Net-like architectures. The `encode` method returns only (the input to the bottleneck).
the final output tensor (bottleneck features).
Attributes Attributes
---------- ----------
@ -91,14 +87,14 @@ class Encoder(nn.Module):
Number of channels expected in the input tensor. Number of channels expected in the input tensor.
input_height : int input_height : int
Height (frequency bins) expected in the input tensor. Height (frequency bins) expected in the input tensor.
output_channels : int out_channels : int
Number of channels in the final output tensor (bottleneck). Number of channels in the final output tensor (bottleneck input).
output_height : int output_height : int
Height (frequency bins) expected in the output tensor. Height (frequency bins) of the final output tensor.
layers : nn.ModuleList layers : nn.ModuleList
The sequence of instantiated downscaling layer modules. Sequence of instantiated downsampling block modules.
depth : int depth : int
The number of downscaling layers in the encoder. Number of downsampling layers.
""" """
def __init__( def __init__(
@ -109,23 +105,22 @@ class Encoder(nn.Module):
input_height: int = 128, input_height: int = 128,
in_channels: int = 1, in_channels: int = 1,
): ):
"""Initialize the Encoder module. """Initialise the Encoder module.
Note: This constructor is typically called internally by the This constructor is typically called by the ``build_encoder`` factory
`build_encoder` factory function, which prepares the `layers` list. function, which takes care of building the ``layers`` list from a
configuration object.
Parameters Parameters
---------- ----------
output_channels : int output_channels : int
Number of channels produced by the final layer. Number of channels produced by the final layer.
output_height : int output_height : int
The expected height of the output tensor. Height of the output tensor after all layers have been applied.
layers : List[nn.Module] layers : List[nn.Module]
A list of pre-instantiated downscaling layer modules (e.g., Pre-built downsampling block modules in execution order.
`StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
sequence.
input_height : int, default=128 input_height : int, default=128
Expected height of the input tensor. Expected height of the input tensor (frequency bins).
in_channels : int, default=1 in_channels : int, default=1
Expected number of channels in the input tensor. Expected number of channels in the input tensor.
""" """
@ -140,29 +135,30 @@ class Encoder(nn.Module):
self.depth = len(self.layers) self.depth = len(self.layers)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]: def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Pass input through encoder layers, returns all intermediate outputs. """Pass input through all encoder layers and return every output.
This method is typically used when the Encoder is part of a U-Net or Used when skip connections are needed (e.g. in a U-Net decoder).
similar architecture requiring skip connections.
Parameters Parameters
---------- ----------
x : torch.Tensor x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
`self.in_channels` and `H_in` must match `self.input_height`. ``C_in`` must match ``self.in_channels`` and ``H_in`` must
match ``self.input_height``.
Returns Returns
------- -------
List[torch.Tensor] List[torch.Tensor]
A list containing the output tensors from *each* downscaling layer Output tensors from every layer in order.
in the sequence. `outputs[0]` is the output of the first layer, ``outputs[0]`` is the output of the first (shallowest) layer;
`outputs[-1]` is the final output (bottleneck) of the encoder. ``outputs[-1]`` is the output of the last (deepest) layer,
which serves as the input to the bottleneck.
Raises Raises
------ ------
ValueError ValueError
If input tensor channel count or height does not match expected If the input channel count or height does not match the
values. expected values.
""" """
if x.shape[1] != self.in_channels: if x.shape[1] != self.in_channels:
raise ValueError( raise ValueError(
@ -185,28 +181,29 @@ class Encoder(nn.Module):
return outputs return outputs
def encode(self, x: torch.Tensor) -> torch.Tensor: def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Pass input through encoder layers, returning only the final output. """Pass input through all encoder layers and return only the final output.
This method provides access to the bottleneck features produced after Use this when skip connections are not needed and you only require
the last downscaling layer. the bottleneck feature map.
Parameters Parameters
---------- ----------
x : torch.Tensor x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`. Must match expected Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
`in_channels` and `input_height`. Must satisfy the same shape requirements as ``forward``.
Returns Returns
------- -------
torch.Tensor torch.Tensor
The final output tensor (bottleneck features) from the last layer Output of the last encoder layer, shape
of the encoder. Shape `(B, C_out, H_out, W_out)`. ``(B, C_out, H_out, W)``, where ``C_out`` is
``self.out_channels`` and ``H_out`` is ``self.output_height``.
Raises Raises
------ ------
ValueError ValueError
If input tensor channel count or height does not match expected If the input channel count or height does not match the
values. expected values.
""" """
if x.shape[1] != self.in_channels: if x.shape[1] != self.in_channels:
raise ValueError( raise ValueError(
@ -238,58 +235,57 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
), ),
], ],
) )
"""Default configuration for the Encoder. """Default encoder configuration used in standard BatDetect2 models.
Specifies an architecture typically used in BatDetect2: Assumes a 1-channel input with 128 frequency bins and produces the
- Input: 1 channel, 128 frequency bins. following feature maps:
- Layer 1: FreqCoordConvDown -> 32 channels, H=64
- Layer 2: FreqCoordConvDown -> 64 channels, H=32 - Stage 1 (``FreqCoordConvDown``): 32 channels, height 64.
- Layer 3: FreqCoordConvDown -> 128 channels, H=16 - Stage 2 (``FreqCoordConvDown``): 64 channels, height 32.
- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck) - Stage 3 (``LayerGroup``):
- ``FreqCoordConvDown``: 128 channels, height 16.
- ``ConvBlock``: 256 channels, height 16 (bottleneck input).
""" """
def build_encoder( def build_encoder(
in_channels: int, in_channels: int,
input_height: int, input_height: int,
config: Optional[EncoderConfig] = None, config: EncoderConfig | None = None,
) -> Encoder: ) -> Encoder:
"""Factory function to build an Encoder instance from configuration. """Build an ``Encoder`` from configuration.
Constructs a sequential `Encoder` module based on the layer sequence Constructs a sequential ``Encoder`` by iterating over the block
defined in an `EncoderConfig` object and the provided input dimensions. configurations in ``config.layers``, building each block with
If no config is provided, uses the default layer sequence from ``build_layer``, and tracking the channel count and feature-map height
`DEFAULT_ENCODER_CONFIG`. as they change through the sequence.
It iteratively builds the layers using the unified
`build_layer_from_config` factory (from `.blocks`), tracking the changing
number of channels and feature map height required for each subsequent
layer, especially for coordinate- aware blocks.
Parameters Parameters
---------- ----------
in_channels : int in_channels : int
The number of channels expected in the input tensor to the encoder. Number of channels in the input spectrogram tensor. Must be
Must be > 0. positive.
input_height : int input_height : int
The height (frequency bins) expected in the input tensor. Must be > 0. Height (number of frequency bins) of the input spectrogram.
Crucial for initializing coordinate-aware layers correctly. Must be positive and should be divisible by
``2 ** (number of downsampling stages)`` to avoid size mismatches
later in the network.
config : EncoderConfig, optional config : EncoderConfig, optional
The configuration object detailing the sequence of layers and their Configuration specifying the layer sequence. Defaults to
parameters. If None, `DEFAULT_ENCODER_CONFIG` is used. ``DEFAULT_ENCODER_CONFIG`` if not provided.
Returns Returns
------- -------
Encoder Encoder
An initialized `Encoder` module. An initialised ``Encoder`` module.
Raises Raises
------ ------
ValueError ValueError
If `in_channels` or `input_height` are not positive, or if the layer If ``in_channels`` or ``input_height`` are not positive.
configuration is invalid (e.g., empty list, unknown `name`). KeyError
NotImplementedError If a layer configuration specifies an unknown block type.
If `build_layer_from_config` encounters an unknown `name`.
""" """
if in_channels <= 0 or input_height <= 0: if in_channels <= 0 or input_height <= 0:
raise ValueError("in_channels and input_height must be positive.") raise ValueError("in_channels and input_height must be positive.")
@ -302,12 +298,14 @@ def build_encoder(
layers = [] layers = []
for layer_config in config.layers: for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config( layer = build_layer(
in_channels=current_channels, in_channels=current_channels,
input_height=current_height, input_height=current_height,
config=layer_config, config=layer_config,
) )
layers.append(layer) layers.append(layer)
current_height = layer.get_output_height(current_height)
current_channels = layer.out_channels
return Encoder( return Encoder(
input_height=input_height, input_height=input_height,

View File

@ -1,20 +1,19 @@
"""Prediction Head modules for BatDetect2 models. """Prediction heads attached to the backbone feature map.
This module defines simple `torch.nn.Module` subclasses that serve as Each head is a lightweight ``torch.nn.Module`` that applies a 1×1
prediction heads, typically attached to the output feature map of a backbone convolution to map backbone feature channels to one specific type of
network output required by BatDetect2:
Each head is responsible for generating one specific type of output required - ``DetectorHead``: single-channel detection probability heatmap (sigmoid
by the BatDetect2 task: activation).
- `DetectorHead`: Predicts the probability of sound event presence. - ``ClassifierHead``: multi-class probability map over the target bat
- `ClassifierHead`: Predicts the probability distribution over target classes. species / call types (softmax activation).
- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding - ``BBoxHead``: two-channel map of predicted call duration (time axis) and
box. bandwidth (frequency axis) at each location (no activation; raw
regression output).
These heads use 1x1 convolutions to map the backbone feature channels All three heads share the same input feature map produced by the backbone,
to the desired number of output channels for each prediction task at each so they can be evaluated in parallel in a single forward pass.
spatial location, followed by an appropriate activation function (e.g., sigmoid
for detection, softmax for classification, none for size regression).
""" """
import torch import torch
@ -28,42 +27,35 @@ __all__ = [
class ClassifierHead(nn.Module): class ClassifierHead(nn.Module):
"""Prediction head for multi-class classification probabilities. """Prediction head for species / call-type classification probabilities.
Takes an input feature map and produces a probability map where each Takes a backbone feature map and produces a probability map where each
channel corresponds to a specific target class. It uses a 1x1 convolution channel corresponds to a target class. Internally the 1×1 convolution
to map input channels to `num_classes + 1` outputs (one for each target maps ``in_channels`` to ``num_classes + 1`` logits (the extra channel
class plus an assumed background/generic class), applies softmax across the represents a generic background / unknown category); a softmax is then
channels, and returns the probabilities for the specific target classes applied across the channel dimension and the background channel is
(excluding the last background/generic channel). discarded before returning.
Parameters Parameters
---------- ----------
num_classes : int num_classes : int
The number of specific target classes the model should predict Number of target classes (bat species or call types) to predict,
(excluding any background or generic category). Must be positive. excluding the background category. Must be positive.
in_channels : int in_channels : int
Number of channels in the input feature map tensor from the backbone. Number of channels in the backbone feature map. Must be positive.
Must be positive.
Attributes Attributes
---------- ----------
num_classes : int num_classes : int
Number of specific output classes. Number of specific output classes (background excluded).
in_channels : int in_channels : int
Number of input channels expected. Number of input channels expected.
classifier : nn.Conv2d classifier : nn.Conv2d
The 1x1 convolutional layer used for prediction. 1×1 convolution with ``num_classes + 1`` output channels.
Output channels = num_classes + 1.
Raises
------
ValueError
If `num_classes` or `in_channels` are not positive.
""" """
def __init__(self, num_classes: int, in_channels: int): def __init__(self, num_classes: int, in_channels: int):
"""Initialize the ClassifierHead.""" """Initialise the ClassifierHead."""
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
@ -76,20 +68,20 @@ class ClassifierHead(nn.Module):
) )
def forward(self, features: torch.Tensor) -> torch.Tensor: def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute class probabilities from input features. """Compute per-class probabilities from backbone features.
Parameters Parameters
---------- ----------
features : torch.Tensor features : torch.Tensor
Input feature map tensor from the backbone, typically with shape Backbone feature map, shape ``(B, C_in, H, W)``.
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Class probability map tensor with shape `(B, num_classes, H, W)`. Class probability map, shape ``(B, num_classes, H, W)``.
Contains probabilities for the specific target classes after Values are softmax probabilities in the range [0, 1] and
softmax, excluding the implicit background/generic class channel. sum to less than 1 per location (the background probability
is discarded).
""" """
logits = self.classifier(features) logits = self.classifier(features)
probs = torch.softmax(logits, dim=1) probs = torch.softmax(logits, dim=1)
@ -97,36 +89,30 @@ class ClassifierHead(nn.Module):
class DetectorHead(nn.Module): class DetectorHead(nn.Module):
"""Prediction head for sound event detection probability. """Prediction head for detection probability (is a call present here?).
Takes an input feature map and produces a single-channel heatmap where Produces a single-channel heatmap where each value indicates the
each value represents the probability ([0, 1]) of a relevant sound event probability ([0, 1]) that a bat call of *any* species is present at
(of any class) being present at that spatial location. that time–frequency location in the spectrogram.
Uses a 1x1 convolution to map input channels to 1 output channel, followed Applies a 1×1 convolution mapping ``in_channels`` → 1, followed by
by a sigmoid activation function. sigmoid activation.
Parameters Parameters
---------- ----------
in_channels : int in_channels : int
Number of channels in the input feature map tensor from the backbone. Number of channels in the backbone feature map. Must be positive.
Must be positive.
Attributes Attributes
---------- ----------
in_channels : int in_channels : int
Number of input channels expected. Number of input channels expected.
detector : nn.Conv2d detector : nn.Conv2d
The 1x1 convolutional layer mapping to a single output channel. 1×1 convolution with a single output channel.
Raises
------
ValueError
If `in_channels` is not positive.
""" """
def __init__(self, in_channels: int): def __init__(self, in_channels: int):
"""Initialize the DetectorHead.""" """Initialise the DetectorHead."""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -138,62 +124,49 @@ class DetectorHead(nn.Module):
) )
def forward(self, features: torch.Tensor) -> torch.Tensor: def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute detection probabilities from input features. """Compute detection probabilities from backbone features.
Parameters Parameters
---------- ----------
features : torch.Tensor features : torch.Tensor
Input feature map tensor from the backbone, typically with shape Backbone feature map, shape ``(B, C_in, H, W)``.
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Detection probability heatmap tensor with shape `(B, 1, H, W)`. Detection probability heatmap, shape ``(B, 1, H, W)``.
Values are in the range [0, 1] due to the sigmoid activation. Values are in the range [0, 1].
Raises
------
RuntimeError
If input channel count does not match `self.in_channels`.
""" """
return torch.sigmoid(self.detector(features)) return torch.sigmoid(self.detector(features))
class BBoxHead(nn.Module): class BBoxHead(nn.Module):
"""Prediction head for bounding box size dimensions. """Prediction head for bounding box size (duration and bandwidth).
Takes an input feature map and produces a two-channel map where each Produces a two-channel map where channel 0 predicts the scaled duration
channel represents a predicted size dimension (typically width/duration and (time-axis extent) and channel 1 predicts the scaled bandwidth
height/bandwidth) for a potential sound event at that spatial location. (frequency-axis extent) of the call at each spectrogram location.
Uses a 1x1 convolution to map input channels to 2 output channels. No Applies a 1×1 convolution mapping ``in_channels`` → 2 with no
activation function is typically applied, as size prediction is often activation function (raw regression output). The predicted values are
treated as a direct regression task. The output values usually represent in a scaled space and must be converted to real units (seconds and Hz)
*scaled* dimensions that need to be un-scaled during postprocessing. during postprocessing.
Parameters Parameters
---------- ----------
in_channels : int in_channels : int
Number of channels in the input feature map tensor from the backbone. Number of channels in the backbone feature map. Must be positive.
Must be positive.
Attributes Attributes
---------- ----------
in_channels : int in_channels : int
Number of input channels expected. Number of input channels expected.
bbox : nn.Conv2d bbox : nn.Conv2d
The 1x1 convolutional layer mapping to 2 output channels 1×1 convolution with 2 output channels (duration, bandwidth).
(width, height).
Raises
------
ValueError
If `in_channels` is not positive.
""" """
def __init__(self, in_channels: int): def __init__(self, in_channels: int):
"""Initialize the BBoxHead.""" """Initialise the BBoxHead."""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -205,19 +178,19 @@ class BBoxHead(nn.Module):
) )
def forward(self, features: torch.Tensor) -> torch.Tensor: def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute predicted bounding box dimensions from input features. """Predict call duration and bandwidth from backbone features.
Parameters Parameters
---------- ----------
features : torch.Tensor features : torch.Tensor
Input feature map tensor from the backbone, typically with shape Backbone feature map, shape ``(B, C_in, H, W)``.
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Returns Returns
------- -------
torch.Tensor torch.Tensor
Predicted size tensor with shape `(B, 2, H, W)`. Channel 0 usually Size prediction tensor, shape ``(B, 2, H, W)``. Channel 0 is
represents scaled width, Channel 1 scaled height. These values the predicted scaled duration; channel 1 is the predicted
need to be un-scaled during postprocessing. scaled bandwidth. Values must be rescaled to real units during
postprocessing.
""" """
return self.bbox(features) return self.bbox(features)

View File

@ -0,0 +1,90 @@
from abc import ABC, abstractmethod
from typing import NamedTuple, Protocol
import torch
# Public API of this module: structural protocols, abstract base classes,
# and the named tuple bundling raw model outputs.
__all__ = [
    "BackboneModel",
    "BlockProtocol",
    "BottleneckProtocol",
    "DecoderProtocol",
    "DetectionModel",
    "EncoderDecoderModel",
    "EncoderProtocol",
    "ModelOutput",
]
class BlockProtocol(Protocol):
    """Structural interface for a single network block.

    Any object exposing the channel attributes, a tensor-to-tensor
    ``__call__``, and ``get_output_height`` satisfies this protocol.
    """

    # Number of channels in the tensor the block consumes.
    in_channels: int
    # Number of channels in the tensor the block produces.
    out_channels: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...

    # Returns the height of the output feature map given the input height
    # (blocks may downsample or upsample along the height axis).
    def get_output_height(self, input_height: int) -> int: ...
class EncoderProtocol(Protocol):
    """Structural interface for the encoder (downsampling) stage.

    Calling the encoder returns a list of intermediate feature maps —
    presumably one per stage, for use as skip connections by the decoder
    (see ``DecoderProtocol.__call__``); confirm against concrete
    implementations.
    """

    # Channels of the input tensor.
    in_channels: int
    # Channels of the final (deepest) feature map.
    out_channels: int
    # Height (frequency bins) expected in the input tensor.
    input_height: int
    # Height of the final feature map after downsampling.
    output_height: int

    def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
class BottleneckProtocol(Protocol):
    """Structural interface for the bottleneck between encoder and decoder."""

    # Channels of the tensor received from the encoder.
    in_channels: int
    # Channels of the tensor passed on to the decoder.
    out_channels: int
    # Height of the tensor received from the encoder.
    input_height: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
class DecoderProtocol(Protocol):
    """Structural interface for the decoder (upsampling) stage.

    The decoder consumes the bottleneck output together with the list of
    residual (skip-connection) feature maps produced by the encoder.
    """

    # Channels of the tensor received from the bottleneck.
    in_channels: int
    # Channels of the decoded output feature map.
    out_channels: int
    # Height of the tensor received from the bottleneck.
    input_height: int
    # Height of the decoded output feature map.
    output_height: int
    # Number of decoder stages; presumably matches the number of residual
    # feature maps supplied to ``__call__`` — TODO confirm against
    # concrete decoder implementations.
    depth: int

    def __call__(
        self,
        x: torch.Tensor,
        residuals: list[torch.Tensor],
    ) -> torch.Tensor: ...
class ModelOutput(NamedTuple):
    """Raw tensors produced by a detection model's forward pass."""

    # Detection probability heatmap (sigmoid output, values in [0, 1]).
    detection_probs: torch.Tensor
    # Predicted scaled bounding-box sizes per spatial location; must be
    # un-scaled during postprocessing.
    size_preds: torch.Tensor
    # Per-class probability map over the target classes (softmax output).
    class_probs: torch.Tensor
    # Backbone feature map from which the head outputs were computed.
    features: torch.Tensor
class BackboneModel(ABC, torch.nn.Module):
    """Abstract base class for feature-extraction backbones.

    Concrete subclasses map an input spectrogram tensor to a feature map
    that prediction heads consume.
    """

    # Height (frequency bins) the backbone expects in the input tensor.
    input_height: int
    # Channels of the feature map the backbone produces.
    out_channels: int

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Compute the feature map for a batch of spectrograms."""
        raise NotImplementedError
class EncoderDecoderModel(BackboneModel):
    """Backbone with an explicit encoder/decoder split.

    Exposes the intermediate encoded representation so the two halves
    can be run separately.
    """

    # Channels of the bottleneck (encoded) representation.
    bottleneck_channels: int

    @abstractmethod
    def encode(self, spec: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def decode(self, encoded: torch.Tensor) -> torch.Tensor: ...
class DetectionModel(ABC, torch.nn.Module):
    """Abstract base class for complete detection models.

    Combines a backbone with per-task prediction heads; ``forward``
    returns all raw outputs bundled in a single ``ModelOutput``.
    """

    # Feature extractor shared by all prediction heads.
    backbone: BackboneModel
    # Head producing per-class probabilities.
    classifier_head: torch.nn.Module
    # Head producing bounding-box size predictions.
    bbox_head: torch.nn.Module
    # NOTE(review): ModelOutput carries detection_probs, but no detection
    # head attribute is declared here — confirm whether one is expected.

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> ModelOutput: ...

View File

@ -0,0 +1,35 @@
from batdetect2.outputs.config import OutputsConfig
from batdetect2.outputs.formats import (
BatDetect2OutputConfig,
OutputFormatConfig,
ParquetOutputConfig,
RawOutputConfig,
SoundEventOutputConfig,
build_output_formatter,
get_output_formatter,
load_predictions,
)
from batdetect2.outputs.transforms import (
OutputTransformConfig,
build_output_transform,
)
from batdetect2.outputs.types import (
OutputFormatterProtocol,
OutputTransformProtocol,
)
# Public names re-exported from the outputs subpackage.
__all__ = [
    "BatDetect2OutputConfig",
    "OutputFormatConfig",
    "OutputFormatterProtocol",
    "OutputTransformConfig",
    "OutputTransformProtocol",
    "OutputsConfig",
    "ParquetOutputConfig",
    "RawOutputConfig",
    "SoundEventOutputConfig",
    "build_output_formatter",
    "build_output_transform",
    "get_output_formatter",
    "load_predictions",
]

View File

@ -0,0 +1,15 @@
from pydantic import Field
from batdetect2.core.configs import BaseConfig
from batdetect2.outputs.formats import OutputFormatConfig
from batdetect2.outputs.formats.raw import RawOutputConfig
from batdetect2.outputs.transforms import OutputTransformConfig
__all__ = ["OutputsConfig"]
class OutputsConfig(BaseConfig):
    """Configuration for how model predictions are formatted and saved."""

    # Output format configuration; defaults to the raw output format.
    format: OutputFormatConfig = Field(default_factory=RawOutputConfig)
    # Transform configuration for predictions — presumably applied before
    # formatting; confirm against the prediction pipeline.
    transform: OutputTransformConfig = Field(
        default_factory=OutputTransformConfig
    )

View File

@ -1,39 +1,42 @@
from typing import Annotated, Optional, Union from typing import Annotated
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike from soundevent.data import PathLike
from batdetect2.data.predictions.base import ( from batdetect2.outputs.formats.base import (
OutputFormatterProtocol, OutputFormatterProtocol,
prediction_formatters, output_formatters,
) )
from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig from batdetect2.outputs.formats.batdetect2 import BatDetect2OutputConfig
from batdetect2.data.predictions.raw import RawOutputConfig from batdetect2.outputs.formats.parquet import ParquetOutputConfig
from batdetect2.data.predictions.soundevent import SoundEventOutputConfig from batdetect2.outputs.formats.raw import RawOutputConfig
from batdetect2.typing import TargetProtocol from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig
from batdetect2.targets.types import TargetProtocol
__all__ = [ __all__ = [
"build_output_formatter",
"get_output_formatter",
"BatDetect2OutputConfig", "BatDetect2OutputConfig",
"OutputFormatConfig",
"ParquetOutputConfig",
"RawOutputConfig", "RawOutputConfig",
"SoundEventOutputConfig", "SoundEventOutputConfig",
"build_output_formatter",
"get_output_formatter",
"load_predictions",
] ]
OutputFormatConfig = Annotated[ OutputFormatConfig = Annotated[
Union[ BatDetect2OutputConfig
BatDetect2OutputConfig, | ParquetOutputConfig
SoundEventOutputConfig, | SoundEventOutputConfig
RawOutputConfig, | RawOutputConfig,
],
Field(discriminator="name"), Field(discriminator="name"),
] ]
def build_output_formatter( def build_output_formatter(
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
config: Optional[OutputFormatConfig] = None, config: OutputFormatConfig | None = None,
) -> OutputFormatterProtocol: ) -> OutputFormatterProtocol:
"""Construct the final output formatter.""" """Construct the final output formatter."""
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
@ -41,13 +44,13 @@ def build_output_formatter(
config = config or RawOutputConfig() config = config or RawOutputConfig()
targets = targets or build_targets() targets = targets or build_targets()
return prediction_formatters.build(config, targets) return output_formatters.build(config, targets)
def get_output_formatter( def get_output_formatter(
name: Optional[str] = None, name: str | None = None,
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
config: Optional[OutputFormatConfig] = None, config: OutputFormatConfig | None = None,
) -> OutputFormatterProtocol: ) -> OutputFormatterProtocol:
"""Get the output formatter by name.""" """Get the output formatter by name."""
@ -55,7 +58,7 @@ def get_output_formatter(
if name is None: if name is None:
raise ValueError("Either config or name must be provided.") raise ValueError("Either config or name must be provided.")
config_class = prediction_formatters.get_config_type(name) config_class = output_formatters.get_config_type(name)
config = config_class() # type: ignore config = config_class() # type: ignore
if config.name != name: # type: ignore if config.name != name: # type: ignore
@ -68,9 +71,9 @@ def get_output_formatter(
def load_predictions( def load_predictions(
path: PathLike, path: PathLike,
format: Optional[str] = "raw", format: str | None = "raw",
config: Optional[OutputFormatConfig] = None, config: OutputFormatConfig | None = None,
targets: Optional[TargetProtocol] = None, targets: TargetProtocol | None = None,
): ):
"""Load predictions from a file.""" """Load predictions from a file."""
from batdetect2.targets import build_targets from batdetect2.targets import build_targets

View File

@ -0,0 +1,46 @@
from pathlib import Path
from typing import Literal
from soundevent.data import PathLike
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.targets.types import TargetProtocol
# Public names exported by this module.
__all__ = [
    "OutputFormatterProtocol",
    "PredictionFormatterImportConfig",
    "make_path_relative",
    "output_formatters",
]
def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
    """Express an audio file path relative to ``audio_dir``.

    Relative inputs are returned unchanged; absolute inputs are rebased
    onto ``audio_dir``.

    Parameters
    ----------
    path : PathLike
        Path to an audio file; may be absolute or relative.
    audio_dir : PathLike
        Root directory that absolute paths must live under.

    Returns
    -------
    Path
        ``path`` unchanged when it was relative, otherwise its location
        relative to ``audio_dir``.

    Raises
    ------
    ValueError
        If ``path`` is absolute but not located under ``audio_dir``.
    """
    candidate = Path(path)
    root = Path(audio_dir)

    # Relative paths are trusted as-is; only absolute ones are rebased.
    if not candidate.is_absolute():
        return candidate

    if candidate.is_relative_to(root):
        return candidate.relative_to(root)

    raise ValueError(f"Audio file {candidate} is not in audio_dir {root}")
# Registry of available output formatters; entries are built with a
# TargetProtocol argument and keyed by their config's name.
output_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
    Registry(name="output_formatter")
)
@add_import_config(output_formatters)
class PredictionFormatterImportConfig(ImportConfig):
    """Use any callable as a prediction formatter.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    # Discriminator value identifying this config variant.
    name: Literal["import"] = "import"

Some files were not shown because too many files have changed in this diff Show More