Compare commits

...

73 Commits

Author SHA1 Message Date
mbsantiago
5a14b29281 Do not crash on failure to plot 2026-03-19 01:31:21 +00:00
mbsantiago
13a31d9de9 Do not sync with model in loading from checkpoint otherwise device clash 2026-03-19 01:28:05 +00:00
mbsantiago
a1fad6d7d7 Add test for checkpointing 2026-03-19 01:26:37 +00:00
Santiago Martinez Balvanera
b8acd86c71 By default only save the last checkpoint 2026-03-19 01:26:11 +00:00
Santiago Martinez Balvanera
875751d340 Add save last config to checkpoint 2026-03-19 00:36:45 +00:00
Santiago Martinez Balvanera
32d8c4a9e5 Create separate preproc/postproc/target instances for model in api 2026-03-19 00:23:55 +00:00
mbsantiago
fb3dc3eaf0 Ensure preprocessor is in CPU 2026-03-19 00:09:30 +00:00
mbsantiago
0d90cb5cc3 Create duplicate preprocessor for data input pipeline 2026-03-19 00:03:29 +00:00
mbsantiago
91806aa01e Update default scheduler params 2026-03-18 23:54:16 +00:00
mbsantiago
23ac619c50 Add option to override targets when loading the model config 2026-03-18 23:36:24 +00:00
mbsantiago
99b9e55c0e Add extra and strict arguments to load config functions, migrate example config 2026-03-18 22:06:45 +00:00
mbsantiago
652670b01d Name changes 2026-03-18 20:48:35 +00:00
mbsantiago
0163a572cb Expanded cli tests 2026-03-18 20:35:08 +00:00
mbsantiago
f0af5dd79e Change inference command to predict 2026-03-18 20:07:53 +00:00
mbsantiago
2f03abe8f6 Remove stale load functions 2026-03-18 19:43:54 +00:00
mbsantiago
22a3d18d45 Move the logging config out of the train/eval configs 2026-03-18 19:32:19 +00:00
mbsantiago
bf5b88016a Polish evaluate and train CLI 2026-03-18 19:15:57 +00:00
mbsantiago
f9056eb19a Ensure config is source of truth 2026-03-18 18:33:51 +00:00
mbsantiago
ebe7e134e9 Allow building new head 2026-03-18 17:44:35 +00:00
mbsantiago
8e35956007 Create save evaluation results 2026-03-18 16:49:22 +00:00
mbsantiago
a332c5c3bd Allow loading predictions in different formats 2026-03-18 16:17:50 +00:00
mbsantiago
9fa703b34b Allow training on existing model 2026-03-18 13:58:52 +00:00
mbsantiago
0bf809e376 Exported types at module level 2026-03-18 13:08:28 +00:00
mbsantiago
6276a8884e Add roundtrip test for encoding decoding geometries 2026-03-18 12:09:03 +00:00
mbsantiago
7b1cb402b4 Add a few clip and detection transforms for outputs 2026-03-18 11:24:22 +00:00
mbsantiago
45ae15eed5 Cleanup postprocessing 2026-03-18 09:07:13 +00:00
mbsantiago
b3af70761e Create EvaluateTaskProtocol 2026-03-18 01:53:28 +00:00
mbsantiago
daff74fdde Moved decoding to outputs 2026-03-18 01:35:34 +00:00
mbsantiago
31d4f92359 reformat 2026-03-18 00:41:45 +00:00
mbsantiago
d56b9f02ae Remove num_worker from config 2026-03-18 00:22:59 +00:00
mbsantiago
573d8e38d6 Ran formatter 2026-03-18 00:03:26 +00:00
mbsantiago
751be53edf Moving types around to each submodule 2026-03-18 00:01:35 +00:00
mbsantiago
c226dc3f2b Add outputs module 2026-03-17 23:00:26 +00:00
mbsantiago
3b47c688dd Update api_v2 2026-03-17 22:25:41 +00:00
mbsantiago
feee2bdfa3 Add scheduler and optimizer module 2026-03-17 21:16:41 +00:00
mbsantiago
615c7d78fb Added a full test of training and saving 2026-03-17 15:38:07 +00:00
mbsantiago
56f6affc72 use train config in training module 2026-03-17 14:16:40 +00:00
mbsantiago
65bd0dc6ae Restructure model config 2026-03-17 13:33:13 +00:00
mbsantiago
1a7c0b4b3a Removing legacy types 2026-03-17 12:53:03 +00:00
mbsantiago
8ac4f4c44d Add dynamic imports to existing registries 2026-03-16 10:04:34 +00:00
mbsantiago
038d58ed99 Add config for dynamic imports, and tests 2026-03-16 09:30:23 +00:00
mbsantiago
47f418a63c Add test for annotated dataset load function 2026-03-15 21:33:14 +00:00
mbsantiago
197cc38e3e Add tests for registry 2026-03-15 21:17:25 +00:00
mbsantiago
e0503487ec Remove deprecated types 2026-03-15 21:17:20 +00:00
mbsantiago
4b7d23abde Create data annotation loader registry 2026-03-15 20:53:59 +00:00
mbsantiago
3c337a06cb Add architecture document 2026-03-11 19:08:18 +00:00
mbsantiago
0d590a26cc Update tests 2026-03-08 18:41:05 +00:00
mbsantiago
46c02962f3 Add test for preprocessing 2026-03-08 17:11:27 +00:00
mbsantiago
bfc88a4a0f Update audio loader docstrings 2026-03-08 17:02:35 +00:00
mbsantiago
ce952e364b Update audio module docstrings 2026-03-08 17:02:25 +00:00
mbsantiago
ef3348d651 Update model docstrings 2026-03-08 16:34:17 +00:00
mbsantiago
4207661da4 Add test for backbones 2026-03-08 16:04:54 +00:00
mbsantiago
45e3cf1434 Modify example config to add name 2026-03-08 15:23:14 +00:00
mbsantiago
f2d5088bec Run formatter 2026-03-08 15:18:21 +00:00
mbsantiago
652d076b46 Add backbone registry 2026-03-08 15:02:56 +00:00
mbsantiago
e393709258 Add interfaces for encoder/decoder/bottleneck 2026-03-08 14:43:16 +00:00
mbsantiago
54605ef269 Add blocks and detector tests 2026-03-08 14:17:47 +00:00
mbsantiago
b8b8a68f49 Add gradio as optional group 2026-03-08 13:17:32 +00:00
mbsantiago
6812e1c515 Update default detection class config 2026-03-08 13:17:23 +00:00
mbsantiago
0b344003a1 Ignore scripts for now 2026-03-08 13:17:12 +00:00
mbsantiago
c9bcaebcde Ignore notebooks for now 2026-03-08 13:17:03 +00:00
mbsantiago
d52e988b8f Fix type errors 2026-03-08 12:55:36 +00:00
mbsantiago
cce1b49a8d Run formatting 2026-03-08 08:59:28 +00:00
mbsantiago
8313fe1484 Minor formatting 2026-03-07 16:09:43 +00:00
mbsantiago
4509602e70 Run automated fixes 2025-12-12 21:29:25 +00:00
mbsantiago
0adb58e039 Run formatter 2025-12-12 21:28:28 +00:00
mbsantiago
531ff69974 Cleanup evaluate 2025-12-12 21:14:31 +00:00
mbsantiago
750f9e43c4 Make sure threshold is used 2025-12-12 19:53:15 +00:00
mbsantiago
f71fe0c2e2 Using matching and affinity functions from soundevent 2025-12-12 19:25:01 +00:00
mbsantiago
113f438e74 Run lint fixes 2025-12-08 17:19:33 +00:00
mbsantiago
2563f26ed3 Update type hints to python 3.10 2025-12-08 17:14:50 +00:00
mbsantiago
9c72537ddd Add parquet format for outputs 2025-12-08 17:11:35 +00:00
mbsantiago
72278d75ec Upgrade soundevent to 2.10 2025-12-08 17:11:22 +00:00
216 changed files with 11302 additions and 9484 deletions

7
.gitignore vendored
View File

@@ -102,7 +102,7 @@ experiments/*
DvcLiveLogger/checkpoints
logs/
mlruns/
outputs/
/outputs/
notebooks/lightning_logs
# Jupiter notebooks
@@ -123,3 +123,8 @@ example_data/preprocessed
# Dev notebooks
notebooks/tmp
/tmp
/.agents/skills
/notebooks
/AGENTS.md
/scripts

View File

@@ -0,0 +1,93 @@
# BatDetect2 Architecture Overview
This document provides a comprehensive map of the `batdetect2` codebase architecture. It is intended to serve as a deep-dive reference for developers, agents, and contributors navigating the project.
`batdetect2` is designed as a modular deep learning pipeline for detecting and classifying bat echolocation calls in high-frequency audio recordings. It heavily utilizes **PyTorch**, **PyTorch Lightning** for training, and the **Soundevent** library for standardized audio and geometry data classes.
The repository follows a configuration-driven design pattern, heavily utilizing `pydantic`/`omegaconf` (via `BaseConfig`) and the Factory/Registry patterns for dependency injection and modularity. The entire pipeline can be orchestrated via the high-level API `BatDetect2API` (`src/batdetect2/api_v2.py`).
---
## 1. Data Flow Pipeline
The standard lifecycle of a prediction request follows these sequential stages, each handled by an isolated, replaceable module:
1. **Audio Loading (`batdetect2.audio`)**: Read raw `.wav` files into standard NumPy arrays or `soundevent.data.Clip` objects. Handles resampling.
2. **Preprocessing (`batdetect2.preprocess`)**: Converts raw 1D waveforms into 2D Spectrogram tensors.
3. **Forward Pass (`batdetect2.models`)**: A PyTorch neural network processes the spectrogram and outputs dense prediction tensors (e.g., detection heatmaps, bounding box sizes, class probabilities).
4. **Postprocessing (`batdetect2.postprocess`)**: Decodes the raw output tensors back into explicit geometry bounding boxes and runs Non-Maximum Suppression (NMS) to filter redundant predictions.
5. **Formatting (`batdetect2.data`)**: Transforms the predictions into standard formats (`.csv`, `.json`, `.parquet`) using `OutputFormatterProtocol`.
---
## 2. Core Modules Breakdown
### 2.1 Audio and Preprocessing
- **`audio/`**:
- Centralizes audio I/O using `AudioLoader`. It abstracts over the `soundevent` library, efficiently handling full `Recording` files or smaller `Clip` segments, standardizing the sample rate.
- **`preprocess/`**:
- Dictated by the `PreprocessorProtocol`.
- Its primary responsibility is spectrogram generation via Short-Time Fourier Transform (STFT).
- During training, it incorporates data augmentation layers (e.g., amplitude scaling, time masking, frequency masking, spectral mean subtraction) configured via `PreprocessingConfig`.
### 2.2 Deep Learning Models (`models/`)
The `models` directory contains all PyTorch neural network architectures. The default architecture is an Encoder-Decoder (U-Net style) network.
- **`blocks.py`**: Reusable neural network blocks, including standard Convolutions (`ConvBlock`) and specialized layers like `FreqCoordConvDownBlock`/`FreqCoordConvUpBlock` which append normalized spatial frequency coordinates to explicitly grant convolutional filters frequency-awareness.
- **`encoder.py`**: The downsampling path (feature extraction). Builds a sequential list of blocks and captures skip connections.
- **`bottleneck.py`**: The deepest, lowest-resolution segment connecting the Encoder and Decoder. Features an optional `SelfAttention` mechanism to weigh global temporal contexts.
- **`decoder.py`**: The upsampling path (reconstruction), actively integrating skip connections (residuals) from the Encoder.
- **`heads.py`**: Attach to the backbone's feature map to output specific predictions:
- `BBoxHead`: Predicts bounding box sizes.
- `ClassifierHead`: Predicts species classes.
- `DetectorHead`: Predicts detection probability heatmaps.
- **`backbones.py` & `detectors.py`**: Assemble the encoder, bottleneck, decoder, and heads into a cohesive `Detector` model.
- **`__init__.py:Model`**: The overarching wrapper `torch.nn.Module` containing the `detector`, `preprocessor`, `postprocessor`, and `targets`.
### 2.3 Targets and Regions of Interest (`targets/`)
Crucial for training, this module translates physical annotations (Regions of Interest / ROIs) into training targets (tensors).
- **`rois.py`**: Implements `ROITargetMapper`. Maps a geometric bounding box into a 2D reference `Position` (time, freq) and a `Size` array. Includes strategies like:
- `AnchorBBoxMapper`: Maps based on a fixed bounding box corner/center.
- `PeakEnergyBBoxMapper`: Identifies the physical coordinate of peak acoustic energy inside the bounding box and calculates offsets to the box edges.
- **`targets.py`**: Reconstructs complete multi-channel target heatmaps and coordinate tensors from the ROIs to compute losses during training.
### 2.4 Postprocessing (`postprocess/`)
- Implements `PostprocessorProtocol`.
- Reverses the logic from `targets`. It scans the model's output detection heatmaps for peaks, extracts the predicted sizes and class probabilities at those peaks, and decodes them back into physical `soundevent.data.Geometry` (Bounding Boxes).
- Automatically applies Non-Maximum Suppression (NMS) configured via `PostprocessConfig` to remove highly overlapping predictions.
### 2.5 Data Management (`data/`)
- **`annotations/`**: Utilities to load dataset annotations supporting multiple standardized schemas (`AOEF`, `BatDetect2` formats).
- **`datasets.py`**: Aggregates recordings and annotations into memory.
- **`predictions/`**: Handles the exporting of model results via `OutputFormatterProtocol`. Includes formatters for `RawOutput`, `.parquet`, `.json`, etc.
### 2.6 Evaluation (`evaluate/`)
- Computes scientific metrics using `EvaluatorProtocol`.
- Provides specific testing environments for tasks like `Clip Classification`, `Clip Detection`, and `Top Class` predictions.
- Generates precision-recall curves and scatter plots.
### 2.7 Training (`train/`)
- Implements the distributed PyTorch training loop via PyTorch Lightning.
- **`lightning.py`**: Contains `TrainingModule`, the `LightningModule` that orchestrates the optimizer, learning rate scheduler, forward passes, and backpropagation using the generated `targets`.
---
## 3. Interfaces and Tooling
### 3.1 APIs
- **`api_v2.py` (`BatDetect2API`)**: The modern API object. It is deeply integrated with dependency injection using `BatDetect2Config`. It instantiates the loader, targets, preprocessor, postprocessor, and model, exposing easy-to-use methods like `process_file`, `evaluate`, and `train`.
- **`api.py`**: The legacy API. Kept for backwards compatibility. Uses hardcoded default instances rather than configuration objects.
### 3.2 Command Line Interface (`cli/`)
- Implements terminal commands utilizing `click`. Commands include `batdetect2 predict`, `evaluate`, and `train` (the inference command was renamed from `detect` to `predict`).
### 3.3 Core and Configuration (`core/`, `config.py`)
- **`core/registries.py`**: A string-based Registry pattern (e.g., `block_registry`, `roi_mapper_registry`) that allows developers to dynamically swap components (like a custom neural network block) via configuration files without modifying python code.
- **`config.py`**: Aggregates all modular `BaseConfig` objects (`AudioConfig`, `PreprocessingConfig`, `BackboneConfig`) into the monolithic `BatDetect2Config`.
---
## Summary
To navigate this codebase effectively:
1. Follow **`api_v2.py`** to see how high-level operations invoke individual components.
2. Rely heavily on the typed **Protocols** located in each subsystem's `types.py` module (for example `src/batdetect2/preprocess/types.py` and `src/batdetect2/postprocess/types.py`) to understand inputs and outputs without needing to read each implementation.
3. Understand that data flows structurally as `soundevent` primitives externally, and as pure `torch.Tensor` internally through the network.

View File

@@ -6,6 +6,7 @@ Hi!
:maxdepth: 1
:caption: Contents:
architecture
data/index
preprocessing/index
postprocessing

View File

@@ -1,70 +1,79 @@
config_version: v1
audio:
samplerate: 256000
resample:
enabled: True
method: "poly"
preprocess:
stft:
window_duration: 0.002
window_overlap: 0.75
window_fn: hann
frequencies:
max_freq: 120000
min_freq: 10000
size:
height: 128
resize_factor: 0.5
spectrogram_transforms:
- name: pcen
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
- name: spectral_mean_substraction
postprocess:
nms_kernel_size: 9
detection_threshold: 0.01
top_k_per_sec: 200
enabled: true
method: poly
model:
input_height: 128
in_channels: 1
out_channels: 32
encoder:
layers:
- name: FreqCoordConvDown
out_channels: 32
- name: FreqCoordConvDown
out_channels: 64
- name: LayerGroup
layers:
- name: FreqCoordConvDown
out_channels: 128
- name: ConvBlock
out_channels: 256
bottleneck:
channels: 256
layers:
- name: SelfAttention
attention_channels: 256
decoder:
layers:
- name: FreqCoordConvUp
out_channels: 64
- name: FreqCoordConvUp
out_channels: 32
- name: LayerGroup
layers:
- name: FreqCoordConvUp
out_channels: 32
- name: ConvBlock
out_channels: 32
samplerate: 256000
preprocess:
stft:
window_duration: 0.002
window_overlap: 0.75
window_fn: hann
frequencies:
max_freq: 120000
min_freq: 10000
size:
height: 128
resize_factor: 0.5
spectrogram_transforms:
- name: pcen
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
- name: spectral_mean_subtraction
architecture:
name: UNetBackbone
input_height: 128
in_channels: 1
encoder:
layers:
- name: FreqCoordConvDown
out_channels: 32
- name: FreqCoordConvDown
out_channels: 64
- name: LayerGroup
layers:
- name: FreqCoordConvDown
out_channels: 128
- name: ConvBlock
out_channels: 256
bottleneck:
channels: 256
layers:
- name: SelfAttention
attention_channels: 256
decoder:
layers:
- name: FreqCoordConvUp
out_channels: 64
- name: FreqCoordConvUp
out_channels: 32
- name: LayerGroup
layers:
- name: FreqCoordConvUp
out_channels: 32
- name: ConvBlock
out_channels: 32
postprocess:
nms_kernel_size: 9
detection_threshold: 0.01
top_k_per_sec: 200
train:
optimizer:
name: adam
learning_rate: 0.001
scheduler:
name: cosine_annealing
t_max: 100
labels:
@@ -76,10 +85,7 @@ train:
train_loader:
batch_size: 8
num_workers: 2
shuffle: True
shuffle: true
clipping_strategy:
name: random_subclip
@@ -115,7 +121,6 @@ train:
max_masks: 3
val_loader:
num_workers: 2
clipping_strategy:
name: whole_audio_padded
chunk_size: 0.256
@@ -134,9 +139,6 @@ train:
size:
weight: 0.1
logger:
name: csv
validation:
tasks:
- name: sound_event_detection
@@ -146,6 +148,10 @@ train:
metrics:
- name: average_precision
logging:
train:
name: csv
evaluation:
tasks:
- name: sound_event_detection

View File

@ -14,60 +14,67 @@ HTML_COVERAGE_DIR := "htmlcov"
help:
@just --list
install:
uv sync
# Testing & Coverage
# Run tests using pytest.
test:
pytest {{TESTS_DIR}}
uv run pytest {{TESTS_DIR}}
# Run tests and generate coverage data.
coverage:
pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}
uv run pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}
# Generate an HTML coverage report.
coverage-html: coverage
@echo "Generating HTML coverage report..."
coverage html -d {{HTML_COVERAGE_DIR}}
uv run coverage html -d {{HTML_COVERAGE_DIR}}
@echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/"
# Serve the HTML coverage report locally.
coverage-serve: coverage-html
@echo "Serving report at http://localhost:8000/ ..."
python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000
uv run python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000
# Documentation
# Build documentation using Sphinx.
docs:
sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
uv run sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
# Serve documentation with live reload.
docs-serve:
sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
uv run sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
# Formatting & Linting
# Format code using ruff.
format:
ruff format {{PYTHON_DIRS}}
# Check code formatting using ruff.
format-check:
ruff format --check {{PYTHON_DIRS}}
# Lint code using ruff.
lint:
ruff check {{PYTHON_DIRS}}
fix-format:
uv run ruff format {{PYTHON_DIRS}}
# Lint code using ruff and apply automatic fixes.
lint-fix:
ruff check --fix {{PYTHON_DIRS}}
fix-lint:
uv run ruff check --fix {{PYTHON_DIRS}}
# Combined Formatting & Linting
fix: fix-format fix-lint
# Checking tasks
# Check code formatting using ruff.
check-format:
uv run ruff format --check {{PYTHON_DIRS}}
# Lint code using ruff.
check-lint:
uv run ruff check {{PYTHON_DIRS}}
# Type Checking
# Type check code using pyright.
typecheck:
pyright {{PYTHON_DIRS}}
# Type check code using ty.
check-types:
uv run ty check {{PYTHON_DIRS}}
# Combined Checks
# Run all checks (format-check, lint, typecheck).
check: format-check lint typecheck test
check: check-format check-lint check-types
# Cleaning tasks
# Remove Python bytecode and cache.
@@ -95,7 +102,7 @@ clean: clean-build clean-pyc clean-test clean-docs
# Train on example data.
example-train OPTIONS="":
batdetect2 train \
uv run batdetect2 train \
--val-dataset example_data/dataset.yaml \
--config example_data/config.yaml \
{{OPTIONS}} \

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,440 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "cfb0b360-a204-4c27-a18f-3902e8758879",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:02.699871Z",
"iopub.status.busy": "2024-11-19T17:33:02.699590Z",
"iopub.status.idle": "2024-11-19T17:33:02.710312Z",
"shell.execute_reply": "2024-11-19T17:33:02.709798Z",
"shell.execute_reply.started": "2024-11-19T17:33:02.699839Z"
}
},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "326c5432-94e6-4abf-a332-fe902559461b",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:02.711324Z",
"iopub.status.busy": "2024-11-19T17:33:02.711067Z",
"iopub.status.idle": "2024-11-19T17:33:09.092380Z",
"shell.execute_reply": "2024-11-19T17:33:09.091830Z",
"shell.execute_reply.started": "2024-11-19T17:33:02.711304Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from pathlib import Path\n",
"from typing import List, Optional\n",
"import torch\n",
"\n",
"import pytorch_lightning as pl\n",
"from batdetect2.train.modules import DetectorModel\n",
"from batdetect2.train.augmentations import (\n",
" add_echo,\n",
" select_random_subclip,\n",
" warp_spectrogram,\n",
")\n",
"from batdetect2.train.dataset import LabeledDataset, get_files\n",
"from batdetect2.train.preprocess import PreprocessingConfig\n",
"from soundevent import data\n",
"import matplotlib.pyplot as plt\n",
"from soundevent.types import ClassMapper\n",
"from torch.utils.data import DataLoader"
]
},
{
"cell_type": "markdown",
"id": "9402a473-0b25-4123-9fa8-ad1f71a4237a",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-18T22:39:12.395329Z",
"iopub.status.busy": "2024-11-18T22:39:12.393444Z",
"iopub.status.idle": "2024-11-18T22:39:12.405938Z",
"shell.execute_reply": "2024-11-18T22:39:12.402980Z",
"shell.execute_reply.started": "2024-11-18T22:39:12.395236Z"
}
},
"source": [
"## Training Datasets"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cfd97d83-8c2b-46c8-9eae-cea59f53bc61",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.093487Z",
"iopub.status.busy": "2024-11-19T17:33:09.092990Z",
"iopub.status.idle": "2024-11-19T17:33:09.121636Z",
"shell.execute_reply": "2024-11-19T17:33:09.121143Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.093459Z"
}
},
"outputs": [],
"source": [
"data_dir = Path.cwd().parent / \"example_data\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d5131ae9-2efd-4758-b6e5-189a6d90789b",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.122685Z",
"iopub.status.busy": "2024-11-19T17:33:09.122270Z",
"iopub.status.idle": "2024-11-19T17:33:09.151386Z",
"shell.execute_reply": "2024-11-19T17:33:09.150788Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.122661Z"
}
},
"outputs": [],
"source": [
"files = get_files(data_dir / \"preprocessed\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "bc733d3d-7829-4e90-896d-a0dc76b33288",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.152327Z",
"iopub.status.busy": "2024-11-19T17:33:09.152060Z",
"iopub.status.idle": "2024-11-19T17:33:09.184041Z",
"shell.execute_reply": "2024-11-19T17:33:09.183372Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.152305Z"
}
},
"outputs": [],
"source": [
"train_dataset = LabeledDataset(files)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "dfbb94ab-7b12-4689-9c15-4dc34cd17cb2",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.186393Z",
"iopub.status.busy": "2024-11-19T17:33:09.186117Z",
"iopub.status.idle": "2024-11-19T17:33:09.220175Z",
"shell.execute_reply": "2024-11-19T17:33:09.219322Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.186375Z"
}
},
"outputs": [],
"source": [
"train_dataloader = DataLoader(\n",
" train_dataset,\n",
" shuffle=True,\n",
" batch_size=32,\n",
" num_workers=4,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "e2eedaa9-6be3-481a-8786-7618515d98f8",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.221653Z",
"iopub.status.busy": "2024-11-19T17:33:09.221242Z",
"iopub.status.idle": "2024-11-19T17:33:09.260977Z",
"shell.execute_reply": "2024-11-19T17:33:09.260375Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.221616Z"
}
},
"outputs": [],
"source": [
"# List of all possible classes\n",
"class Mapper(ClassMapper):\n",
" class_labels = [\n",
" \"Eptesicus serotinus\",\n",
" \"Myotis mystacinus\",\n",
" \"Pipistrellus pipistrellus\",\n",
" \"Rhinolophus ferrumequinum\",\n",
" ]\n",
"\n",
" def encode(self, x: data.SoundEventAnnotation) -> Optional[str]:\n",
" event_tag = data.find_tag(x.tags, \"event\")\n",
"\n",
" if event_tag.value == \"Social\":\n",
" return \"social\"\n",
"\n",
" if event_tag.value != \"Echolocation\":\n",
" # Ignore all other types of calls\n",
" return None\n",
"\n",
" species_tag = data.find_tag(x.tags, \"class\")\n",
" return species_tag.value\n",
"\n",
" def decode(self, class_name: str) -> List[data.Tag]:\n",
" if class_name == \"social\":\n",
" return [data.Tag(key=\"event\", value=\"social\")]\n",
"\n",
" return [data.Tag(key=\"class\", value=class_name)]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1ff6072c-511e-42fe-a74f-282f269b80f0",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.262337Z",
"iopub.status.busy": "2024-11-19T17:33:09.261775Z",
"iopub.status.idle": "2024-11-19T17:33:09.309793Z",
"shell.execute_reply": "2024-11-19T17:33:09.309216Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.262307Z"
}
},
"outputs": [],
"source": [
"detector = DetectorModel(class_mapper=Mapper())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3a763ee6-15bc-4105-a409-f06e0ad21a06",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.310695Z",
"iopub.status.busy": "2024-11-19T17:33:09.310438Z",
"iopub.status.idle": "2024-11-19T17:33:09.366636Z",
"shell.execute_reply": "2024-11-19T17:33:09.366059Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.310669Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: False, used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n"
]
}
],
"source": [
"trainer = pl.Trainer(\n",
" limit_train_batches=100,\n",
" max_epochs=2,\n",
" log_every_n_steps=1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "0b86d49d-3314-4257-94f5-f964855be385",
"metadata": {
"execution": {
"iopub.execute_input": "2024-11-19T17:33:09.367499Z",
"iopub.status.busy": "2024-11-19T17:33:09.367242Z",
"iopub.status.idle": "2024-11-19T17:33:10.811300Z",
"shell.execute_reply": "2024-11-19T17:33:10.809823Z",
"shell.execute_reply.started": "2024-11-19T17:33:09.367473Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
" | Name | Type | Params | Mode \n",
"--------------------------------------------------------\n",
"0 | feature_extractor | Net2DFast | 119 K | train\n",
"1 | classifier | Conv2d | 54 | train\n",
"2 | bbox | Conv2d | 18 | train\n",
"--------------------------------------------------------\n",
"119 K Trainable params\n",
"448 Non-trainable params\n",
"119 K Total params\n",
"0.480 Total estimated model params size (MB)\n",
"32 Modules in train mode\n",
"0 Modules in eval mode\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0: 0%| | 0/1 [00:00<?, ?it/s]class heatmap shape torch.Size([3, 4, 128, 512])\n",
"class props shape torch.Size([3, 5, 128, 512])\n"
]
},
{
"ename": "RuntimeError",
"evalue": "The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdetector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:538\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 536\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m 537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 538\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 539\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 540\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:47\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 47\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 50\u001b[0m _call_teardown_hook(trainer)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:574\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 568\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 569\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 570\u001b[0m ckpt_path,\n\u001b[1;32m 571\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 572\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 573\u001b[0m )\n\u001b[0;32m--> 574\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 576\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 577\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:981\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 976\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 978\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 981\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 983\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 984\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 985\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 986\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:1025\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1023\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1024\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1025\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1026\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1027\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:205\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:363\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# in automatic optimization, there can only be one optimizer\u001b[39;00m\n\u001b[0;32m--> 250\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_optimization\u001b[38;5;241m.\u001b[39mrun(kwargs)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:190\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m closure()\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated gradients\u001b[39;00m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 190\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 192\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:268\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_ready()\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 268\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_completed()\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:167\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m hook_name\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 167\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 170\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/module.py:1306\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m 1275\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimizer_step\u001b[39m(\n\u001b[1;32m 1276\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1277\u001b[0m epoch: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1280\u001b[0m optimizer_closure: Optional[Callable[[], Any]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1281\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;124;03m the optimizer.\u001b[39;00m\n\u001b[1;32m 1284\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1304\u001b[0m \n\u001b[1;32m 1305\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1306\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py:153\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhen `optimizer.step(closure)` is called, the closure should be callable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 153\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_output\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:238\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;66;03m# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 238\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:122\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[39;00m\n\u001b[1;32m 121\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:130\u001b[0m, in \u001b[0;36mLRScheduler.__init__.<locals>.patch_track_step_called.<locals>.wrap_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 128\u001b[0m opt \u001b[38;5;241m=\u001b[39m opt_ref()\n\u001b[1;32m 129\u001b[0m opt\u001b[38;5;241m.\u001b[39m_opt_called \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m \u001b[38;5;66;03m# type: ignore[union-attr]\u001b[39;00m\n\u001b[0;32m--> 130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__get__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__class__\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:484\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 482\u001b[0m )\n\u001b[0;32m--> 484\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 485\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 487\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:89\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 87\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 88\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 89\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 91\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/adam.py:205\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 205\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n\u001b[1;32m 208\u001b[0m params_with_grad: List[Tensor] \u001b[38;5;241m=\u001b[39m []\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:108\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, closure)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrap_closure\u001b[39m(\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 97\u001b[0m model: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpl.LightningModule\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 98\u001b[0m optimizer: Steppable,\n\u001b[1;32m 99\u001b[0m closure: Callable[[], Any],\n\u001b[1;32m 100\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 101\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;124;03m hook is called.\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m closure_result\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:144\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:129\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 129\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwarning_cache\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`training_step` returned `None`. If this was on purpose, ignore this warning...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:317\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Performs the actual train step with the tied hooks.\u001b[39;00m\n\u001b[1;32m 307\u001b[0m \n\u001b[1;32m 308\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 313\u001b[0m \n\u001b[1;32m 314\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 315\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[0;32m--> 317\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step() \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m training_step_output \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mworld_size \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:319\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 319\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 322\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:167\u001b[0m, in \u001b[0;36mDetectorModel.training_step\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtraining_step\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch: TrainExample):\n\u001b[1;32m 166\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(batch\u001b[38;5;241m.\u001b[39mspec)\n\u001b[0;32m--> 167\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:150\u001b[0m, in \u001b[0;36mDetectorModel.compute_loss\u001b[0;34m(self, outputs, batch)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mclass props shape\u001b[39m\u001b[38;5;124m\"\u001b[39m, outputs\u001b[38;5;241m.\u001b[39mclass_probs\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 149\u001b[0m valid_mask \u001b[38;5;241m=\u001b[39m batch\u001b[38;5;241m.\u001b[39mclass_heatmap\u001b[38;5;241m.\u001b[39many(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m--> 150\u001b[0m classification_loss \u001b[38;5;241m=\u001b[39m \u001b[43mlosses\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal_loss\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[43m \u001b[49m\u001b[43moutputs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_probs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_heatmap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[43m \u001b[49m\u001b[43mweights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalid_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalid_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbeta\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 156\u001b[0m 
\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 157\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 160\u001b[0m detection_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mdetection\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 161\u001b[0m \u001b[38;5;241m+\u001b[39m size_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39msize\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 162\u001b[0m \u001b[38;5;241m+\u001b[39m classification_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mclassification\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 163\u001b[0m )\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/losses.py:38\u001b[0m, in \u001b[0;36mfocal_loss\u001b[0;34m(pred, gt, weights, valid_mask, eps, beta, alpha)\u001b[0m\n\u001b[1;32m 35\u001b[0m pos_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39meq(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[1;32m 36\u001b[0m neg_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39mlt(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m---> 38\u001b[0m pos_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpred\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpos_inds\u001b[49m\n\u001b[1;32m 39\u001b[0m neg_loss \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 40\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m pred \u001b[38;5;241m+\u001b[39m eps)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(pred, alpha)\n\u001b[1;32m 42\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m gt, beta)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;241m*\u001b[39m neg_inds\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m weights \u001b[38;5;129;01mis\u001b[39;00m 
\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1"
]
}
],
"source": [
"trainer.fit(detector, train_dataloaders=train_dataloader)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f6924db-e520-49a1-bbe8-6c4956e46314",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.811729Z",
"iopub.status.idle": "2024-11-19T17:33:10.811955Z",
"shell.execute_reply": "2024-11-19T17:33:10.811858Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.811849Z"
}
},
"outputs": [],
"source": [
"clip_annotation = train_dataset.get_clip_annotation(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23943e13-6875-49b8-9f18-2ba6528aa673",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.812924Z",
"iopub.status.idle": "2024-11-19T17:33:10.813260Z",
"shell.execute_reply": "2024-11-19T17:33:10.813104Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.813087Z"
}
},
"outputs": [],
"source": [
"spec = detector.compute_spectrogram(clip_annotation.clip)\n",
"outputs = detector(torch.tensor(spec.values).unsqueeze(0).unsqueeze(0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd1fe346-0873-4b14-ae1b-92ef1f4f27a5",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.814343Z",
"iopub.status.idle": "2024-11-19T17:33:10.814806Z",
"shell.execute_reply": "2024-11-19T17:33:10.814628Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.814611Z"
}
},
"outputs": [],
"source": [
"_, ax= plt.subplots(figsize=(15, 5))\n",
"spec.plot(ax=ax, add_colorbar=False)\n",
"ax.pcolormesh(spec.time, spec.frequency, outputs.detection_probs.detach().squeeze())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eadd36ef-a04a-4665-b703-cec84cf1673b",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.815603Z",
"iopub.status.idle": "2024-11-19T17:33:10.816065Z",
"shell.execute_reply": "2024-11-19T17:33:10.815894Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.815877Z"
}
},
"outputs": [],
"source": [
"print(f\"Num predicted soundevents: {len(predictions.sound_events)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4e54f3e-6ddc-4fe5-8ce0-b527ff6f18ae",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "batdetect2-dev",
"language": "python",
"name": "batdetect2-dev"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -1,194 +0,0 @@
import marimo
__generated_with = "0.14.16"
app = marimo.App(width="medium")
@app.cell
def _():
import marimo as mo
return (mo,)
@app.cell
def _():
from batdetect2.data import (
load_dataset_config,
load_dataset,
extract_recordings_df,
extract_sound_events_df,
compute_class_summary,
)
return (
compute_class_summary,
extract_recordings_df,
extract_sound_events_df,
load_dataset,
load_dataset_config,
)
@app.cell
def _(mo):
dataset_config_browser = mo.ui.file_browser(
selection_mode="file",
multiple=False,
)
dataset_config_browser
return (dataset_config_browser,)
@app.cell
def _(dataset_config_browser, load_dataset_config, mo):
mo.stop(dataset_config_browser.path() is None)
dataset_config = load_dataset_config(dataset_config_browser.path())
return (dataset_config,)
@app.cell
def _(dataset_config, load_dataset):
dataset = load_dataset(dataset_config, base_dir="../paper/")
return (dataset,)
@app.cell
def _():
from batdetect2.targets import load_target_config, build_targets
return build_targets, load_target_config
@app.cell
def _(mo):
targets_config_browser = mo.ui.file_browser(
selection_mode="file",
multiple=False,
)
targets_config_browser
return (targets_config_browser,)
@app.cell
def _(load_target_config, mo, targets_config_browser):
mo.stop(targets_config_browser.path() is None)
targets_config = load_target_config(targets_config_browser.path())
return (targets_config,)
@app.cell
def _(build_targets, targets_config):
targets = build_targets(targets_config)
return (targets,)
@app.cell
def _():
import pandas as pd
from soundevent.geometry import compute_bounds
return
@app.cell
def _(dataset, extract_recordings_df):
recordings = extract_recordings_df(dataset)
recordings
return
@app.cell
def _(dataset, extract_sound_events_df, targets):
sound_events = extract_sound_events_df(dataset, targets)
sound_events
return
@app.cell
def _(compute_class_summary, dataset, targets):
compute_class_summary(dataset, targets)
return
@app.cell
def _():
from batdetect2.data.split import split_dataset_by_recordings
return (split_dataset_by_recordings,)
@app.cell
def _(dataset, split_dataset_by_recordings, targets):
train_dataset, val_dataset = split_dataset_by_recordings(dataset, targets, random_state=42)
return train_dataset, val_dataset
@app.cell
def _(compute_class_summary, targets, train_dataset):
compute_class_summary(train_dataset, targets)
return
@app.cell
def _(compute_class_summary, targets, val_dataset):
compute_class_summary(val_dataset, targets)
return
@app.cell
def _():
from soundevent import io, data
from pathlib import Path
return Path, data, io
@app.cell
def _(Path, data, io, train_dataset):
io.save(
data.AnnotationSet(
name="batdetect2_tuning_train",
description="Set of annotations used as the train dataset for the hyper-parameter tuning stage.",
clip_annotations=train_dataset,
),
Path("../paper/data/datasets/annotation_sets/tuning_train.json"),
audio_dir=Path("../paper/data/datasets/"),
)
return
@app.cell
def _(Path, data, io, val_dataset):
io.save(
data.AnnotationSet(
name="batdetect2_tuning_val",
description="Set of annotations used as the validation dataset for the hyper-parameter tuning stage.",
clip_annotations=val_dataset,
),
Path("../paper/data/datasets/annotation_sets/tuning_val.json"),
audio_dir=Path("../paper/data/datasets/"),
)
return
@app.cell
def _(load_dataset, load_dataset_config):
config = load_dataset_config("../paper/conf/datasets/train/uk_tune.yaml")
rec = load_dataset(config, base_dir="../paper/")
return (rec,)
@app.cell
def _(rec):
dict(rec[0].sound_events[0].tags[0].term)
return
@app.cell
def _(compute_class_summary, rec, targets):
compute_class_summary(rec,targets)
return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()

View File

@ -1,273 +0,0 @@
import marimo
__generated_with = "0.14.16"
app = marimo.App(width="medium")
@app.cell
def _():
import marimo as mo
return
@app.cell
def _():
from batdetect2.data import load_dataset_config, load_dataset
from batdetect2.preprocess import load_preprocessing_config, build_preprocessor
from batdetect2 import api
from soundevent import data
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.types import Annotation
from batdetect2.compat import annotation_to_sound_event_prediction
from batdetect2.plotting import (
plot_clip,
plot_clip_annotation,
plot_clip_prediction,
plot_matches,
plot_false_positive_match,
plot_false_negative_match,
plot_true_positive_match,
plot_cross_trigger_match,
)
return (
MatchEvaluation,
annotation_to_sound_event_prediction,
api,
build_preprocessor,
data,
load_dataset,
load_dataset_config,
load_preprocessing_config,
plot_clip_annotation,
plot_clip_prediction,
plot_cross_trigger_match,
plot_false_negative_match,
plot_false_positive_match,
plot_matches,
plot_true_positive_match,
)
@app.cell
def _(build_preprocessor, load_dataset_config, load_preprocessing_config):
dataset_config = load_dataset_config(
path="example_data/config.yaml", field="datasets.train"
)
preprocessor_config = load_preprocessing_config(
path="example_data/config.yaml", field="preprocess"
)
preprocessor = build_preprocessor(preprocessor_config)
return dataset_config, preprocessor
@app.cell
def _(dataset_config, load_dataset):
dataset = load_dataset(dataset_config)
return (dataset,)
@app.cell
def _(dataset):
clip_annotation = dataset[1]
return (clip_annotation,)
@app.cell
def _(clip_annotation, plot_clip_annotation, preprocessor):
plot_clip_annotation(
clip_annotation, preprocessor=preprocessor, figsize=(15, 5)
)
return
@app.cell
def _(annotation_to_sound_event_prediction, api, clip_annotation, data):
audio = api.load_audio(clip_annotation.clip.recording.path)
detections, features, spec = api.process_audio(audio)
clip_prediction = data.ClipPrediction(
clip=clip_annotation.clip,
sound_events=[
annotation_to_sound_event_prediction(
prediction, clip_annotation.clip.recording
)
for prediction in detections
],
)
return (clip_prediction,)
@app.cell
def _(clip_prediction, plot_clip_prediction):
plot_clip_prediction(clip_prediction, figsize=(15, 5))
return
@app.cell
def _():
from batdetect2.evaluate import match_predictions_and_annotations
import random
return match_predictions_and_annotations, random
@app.cell
def _(data, random):
    def add_noise(clip_annotation, time_buffer=0.003, freq_buffer=1000):
        """Return a copy of *clip_annotation* with jittered bounding boxes.

        Every box corner is shifted by an independent uniform offset of up
        to ``time_buffer`` seconds in time and ``freq_buffer`` Hz in
        frequency; all other annotation fields are preserved.
        """

        def _jitter_bbox(bbox):
            # Coordinates are stored as [start_time, low_freq, end_time, high_freq].
            start_time, low_freq, end_time, high_freq = bbox.coordinates
            jittered = [
                start_time + random.uniform(-time_buffer, time_buffer),
                low_freq + random.uniform(-freq_buffer, freq_buffer),
                end_time + random.uniform(-time_buffer, time_buffer),
                high_freq + random.uniform(-freq_buffer, freq_buffer),
            ]
            return data.BoundingBox(coordinates=jittered)

        def _jitter_sound_event(se):
            # Copy the sound event with only its geometry replaced, then
            # copy the annotation with only its sound event replaced.
            new_sound_event = se.sound_event.model_copy(
                update=dict(geometry=_jitter_bbox(se.sound_event.geometry))
            )
            return se.model_copy(update=dict(sound_event=new_sound_event))

        noisy_events = [
            _jitter_sound_event(se) for se in clip_annotation.sound_events
        ]
        return clip_annotation.model_copy(
            update=dict(sound_events=noisy_events)
        )

    def drop_random(obj, p=0.5):
        """Return a copy of *obj* where each sound event is dropped with probability *p*."""
        kept = [se for se in obj.sound_events if random.random() > p]
        return obj.model_copy(update=dict(sound_events=kept))

    return add_noise, drop_random
@app.cell
def _(
add_noise,
clip_annotation,
clip_prediction,
drop_random,
match_predictions_and_annotations,
):
matches = match_predictions_and_annotations(
drop_random(add_noise(clip_annotation), p=0.2),
drop_random(clip_prediction),
)
return (matches,)
@app.cell
def _(clip_annotation, matches, plot_matches):
plot_matches(matches, clip_annotation.clip, figsize=(15, 5))
return
@app.cell
def _(matches):
    # Partition matches by which side of the match is present:
    #   - target only  -> false negative (annotation the model missed)
    #   - source only  -> false positive (prediction with no annotation)
    #   - both present -> true positive
    # Matches with neither side are skipped. The original trailing
    # `else: continue` was redundant (it was the last statement of the
    # loop body) and has been removed; behavior is unchanged.
    true_positives = []
    false_positives = []
    false_negatives = []
    for match in matches:
        has_source = match.source is not None
        has_target = match.target is not None
        if has_source and has_target:
            true_positives.append(match)
        elif has_target:
            false_negatives.append(match)
        elif has_source:
            false_positives.append(match)
    return false_negatives, false_positives, true_positives
@app.cell
def _(MatchEvaluation, false_positives, plot_false_positive_match):
false_positive = false_positives[0]
false_positive_eval = MatchEvaluation(
match=false_positive,
gt_det=False,
gt_class=None,
pred_score=false_positive.source.score,
pred_class_scores={
"myomyo": 0.2
}
)
plot_false_positive_match(false_positive_eval)
return
@app.cell
def _(MatchEvaluation, false_negatives, plot_false_negative_match):
false_negative = false_negatives[0]
false_negative_eval = MatchEvaluation(
match=false_negative,
gt_det=True,
gt_class="myomyo",
pred_score=None,
pred_class_scores={}
)
plot_false_negative_match(false_negative_eval)
return
@app.cell
def _(MatchEvaluation, plot_true_positive_match, true_positives):
true_positive = true_positives[0]
true_positive_eval = MatchEvaluation(
match=true_positive,
gt_det=True,
gt_class="myomyo",
pred_score=0.87,
pred_class_scores={
"pyomyo": 0.84,
"pippip": 0.84,
}
)
plot_true_positive_match(true_positive_eval)
return (true_positive,)
@app.cell
def _(MatchEvaluation, plot_cross_trigger_match, true_positive):
cross_trigger_eval = MatchEvaluation(
match=true_positive,
gt_det=True,
gt_class="myomyo",
pred_score=0.87,
pred_class_scores={
"pippip": 0.84,
"myomyo": 0.84,
}
)
plot_cross_trigger_match(cross_trigger_eval)
return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()

View File

@ -1,19 +0,0 @@
import marimo

# Marimo-generated scaffold notebook; records the marimo release that
# produced this file.
__generated_with = "0.13.15"

app = marimo.App(width="medium")


@app.cell
def _():
    # Import is present but nothing is exported to the notebook namespace
    # yet — presumably a stub awaiting further cells (TODO confirm).
    from batdetect2.preprocess import build_preprocessor

    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()

View File

@ -1,97 +0,0 @@
import marimo
__generated_with = "0.14.5"
app = marimo.App(width="medium")
@app.cell
def _():
import marimo as mo
return (mo,)
@app.cell
def _():
from batdetect2.data import load_dataset, load_dataset_config
return load_dataset, load_dataset_config
@app.cell
def _(mo):
dataset_input = mo.ui.file_browser(label="dataset config file")
dataset_input
return (dataset_input,)
@app.cell
def _(mo):
audio_dir_input = mo.ui.file_browser(
label="audio directory", selection_mode="directory"
)
audio_dir_input
return (audio_dir_input,)
@app.cell
def _(dataset_input, load_dataset_config):
dataset_config = load_dataset_config(
path=dataset_input.path(), field="datasets.train",
)
return (dataset_config,)
@app.cell
def _(audio_dir_input, dataset_config, load_dataset):
dataset = load_dataset(dataset_config, base_dir=audio_dir_input.path())
return (dataset,)
@app.cell
def _(dataset):
len(dataset)
return
@app.cell
def _(dataset):
tag_groups = [
se.tags
for clip in dataset
for se in clip.sound_events
if se.tags
]
all_tags = [
tag for group in tag_groups for tag in group
]
return (all_tags,)
@app.cell
def _(mo):
key_search = mo.ui.text(label="key", debounce=0.1)
value_search = mo.ui.text(label="value", debounce=0.1)
mo.hstack([key_search, value_search]).left()
return key_search, value_search
@app.cell
def _(all_tags, key_search, mo, value_search):
    # Deduplicate the tag list, then narrow it with the case-insensitive
    # key/value search boxes; only the first five matches are rendered.
    candidates = list(set(all_tags))
    if key_search.value:
        needle = key_search.value.lower()
        candidates = [t for t in candidates if needle in t.key.lower()]
    if value_search.value:
        needle = value_search.value.lower()
        candidates = [t for t in candidates if needle in t.value.lower()]
    mo.vstack(
        [mo.md(f"key={t.key} value={t.value}") for t in candidates[:5]]
    )
    return
@app.cell
def _():
return
if __name__ == "__main__":
app.run()

View File

@ -25,14 +25,14 @@ dependencies = [
"scikit-learn>=1.2.2",
"scipy>=1.10.1",
"seaborn>=0.13.2",
"soundevent[audio,geometry,plot]>=2.9.1",
"soundevent[audio,geometry,plot]>=2.10.0",
"tensorboard>=2.16.2",
"torch>=1.13.1",
"torchaudio>=1.13.1",
"torchvision>=0.14.0",
"tqdm>=4.66.2",
]
requires-python = ">=3.9,<3.13"
requires-python = ">=3.10,<3.14"
readme = "README.md"
license = { text = "CC-by-nc-4" }
classifiers = [
@ -75,7 +75,6 @@ dev = [
"ruff>=0.7.3",
"ipykernel>=6.29.4",
"setuptools>=69.5.1",
"basedpyright>=1.31.0",
"myst-parser>=3.0.1",
"sphinx-autobuild>=2024.10.3",
"numpydoc>=1.8.0",
@ -87,13 +86,24 @@ dev = [
"rust-just>=1.40.0",
"pandas-stubs>=2.2.2.240807",
"python-lsp-server>=1.13.0",
"deepdiff>=8.6.1",
]
dvclive = ["dvclive>=3.48.2"]
mlflow = ["mlflow>=3.1.1"]
gradio = [
"gradio>=6.9.0",
]
[tool.ruff]
line-length = 79
target-version = "py39"
target-version = "py310"
exclude = [
"src/batdetect2/train/legacy",
"src/batdetect2/plotting/legacy",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/finetune",
"src/batdetect2/utils",
]
[tool.ruff.format]
docstring-code-format = true
@ -105,15 +115,12 @@ select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
[tool.ruff.lint.pydocstyle]
convention = "numpy"
[tool.pyright]
[tool.ty.src]
include = ["src", "tests"]
pythonVersion = "3.9"
pythonPlatform = "All"
exclude = [
"src/batdetect2/detector/",
"src/batdetect2/train/legacy",
"src/batdetect2/plotting/legacy",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/finetune",
"src/batdetect2/utils",
"src/batdetect2/plot",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/train/legacy",
]

View File

@ -98,7 +98,6 @@ consult the API documentation in the code.
"""
import warnings
from typing import List, Optional, Tuple
import numpy as np
import torch
@ -165,7 +164,7 @@ def load_audio(
time_exp_fact: float = 1,
target_samp_rate: int = TARGET_SAMPLERATE_HZ,
scale: bool = False,
max_duration: Optional[float] = None,
max_duration: float | None = None,
) -> np.ndarray:
"""Load audio from file.
@ -203,7 +202,7 @@ def load_audio(
def generate_spectrogram(
audio: np.ndarray,
samp_rate: int = TARGET_SAMPLERATE_HZ,
config: Optional[SpectrogramParameters] = None,
config: SpectrogramParameters | None = None,
device: torch.device = DEVICE,
) -> torch.Tensor:
"""Generate spectrogram from audio array.
@ -240,7 +239,7 @@ def generate_spectrogram(
def process_file(
audio_file: str,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
config: ProcessingConfiguration | None = None,
device: torch.device = DEVICE,
) -> du.RunResults:
"""Process audio file with model.
@ -271,8 +270,8 @@ def process_spectrogram(
spec: torch.Tensor,
samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], np.ndarray]:
config: ProcessingConfiguration | None = None,
) -> tuple[list[Annotation], np.ndarray]:
"""Process spectrogram with model.
Parameters
@ -312,9 +311,9 @@ def process_audio(
audio: np.ndarray,
samp_rate: int = TARGET_SAMPLERATE_HZ,
model: DetectionModel = MODEL,
config: Optional[ProcessingConfiguration] = None,
config: ProcessingConfiguration | None = None,
device: torch.device = DEVICE,
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
"""Process audio array with model.
Parameters
@ -356,8 +355,8 @@ def process_audio(
def postprocess(
outputs: ModelOutput,
samp_rate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], np.ndarray]:
config: ProcessingConfiguration | None = None,
) -> tuple[list[Annotation], np.ndarray]:
"""Postprocess model outputs.
Convert model tensor outputs to predicted bounding boxes and

View File

@ -1,59 +1,90 @@
import json
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Literal, Sequence, cast
import numpy as np
import torch
from soundevent import data
from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs
from batdetect2.data import (
OutputFormatConfig,
build_output_formatter,
get_output_formatter,
load_dataset_from_config,
from batdetect2.data import Dataset, load_dataset_from_config
from batdetect2.evaluate import (
DEFAULT_EVAL_DIR,
EvaluationConfig,
EvaluatorProtocol,
build_evaluator,
run_evaluate,
save_evaluation_results,
)
from batdetect2.data.datasets import Dataset
from batdetect2.data.predictions.base import OutputFormatterProtocol
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.inference import (
InferenceConfig,
process_file_list,
run_batch_inference,
)
from batdetect2.logging import (
DEFAULT_LOGS_DIR,
AppLoggingConfig,
LoggerConfig,
)
from batdetect2.models import (
Model,
ModelConfig,
build_model,
build_model_with_new_targets,
)
from batdetect2.models.detectors import Detector
from batdetect2.outputs import (
OutputFormatConfig,
OutputFormatterProtocol,
OutputsConfig,
OutputTransformProtocol,
build_output_formatter,
build_output_transform,
get_output_formatter,
)
from batdetect2.postprocess import (
ClipDetections,
Detection,
PostprocessorProtocol,
build_postprocessor,
)
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR,
TrainingConfig,
load_model_from_checkpoint,
train,
)
from batdetect2.typing import (
AudioLoader,
BatDetect2Prediction,
EvaluatorProtocol,
PostprocessorProtocol,
PreprocessorProtocol,
RawPrediction,
TargetProtocol,
run_train,
)
class BatDetect2API:
def __init__(
self,
config: BatDetect2Config,
model_config: ModelConfig,
audio_config: AudioConfig,
train_config: TrainingConfig,
evaluation_config: EvaluationConfig,
inference_config: InferenceConfig,
outputs_config: OutputsConfig,
logging_config: AppLoggingConfig,
targets: TargetProtocol,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
evaluator: EvaluatorProtocol,
formatter: OutputFormatterProtocol,
output_transform: OutputTransformProtocol,
model: Model,
):
self.config = config
self.model_config = model_config
self.audio_config = audio_config
self.train_config = train_config
self.evaluation_config = evaluation_config
self.inference_config = inference_config
self.outputs_config = outputs_config
self.logging_config = logging_config
self.targets = targets
self.audio_loader = audio_loader
self.preprocessor = preprocessor
@ -61,34 +92,40 @@ class BatDetect2API:
self.evaluator = evaluator
self.model = model
self.formatter = formatter
self.output_transform = output_transform
self.model.eval()
def load_annotations(
self,
path: data.PathLike,
base_dir: Optional[data.PathLike] = None,
base_dir: data.PathLike | None = None,
) -> Dataset:
return load_dataset_from_config(path, base_dir=base_dir)
def train(
self,
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
train_workers: Optional[int] = None,
val_workers: Optional[int] = None,
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
experiment_name: Optional[str] = None,
num_epochs: Optional[int] = None,
run_name: Optional[str] = None,
seed: Optional[int] = None,
val_annotations: Sequence[data.ClipAnnotation] | None = None,
train_workers: int = 0,
val_workers: int = 0,
checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
log_dir: Path | None = DEFAULT_LOGS_DIR,
experiment_name: str | None = None,
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
):
train(
run_train(
train_annotations=train_annotations,
val_annotations=val_annotations,
model=self.model,
targets=self.targets,
config=self.config,
model_config=model_config or self.model_config,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
train_workers=train_workers,
@ -99,25 +136,81 @@ class BatDetect2API:
experiment_name=experiment_name,
run_name=run_name,
seed=seed,
train_config=train_config or self.train_config,
audio_config=audio_config or self.audio_config,
logger_config=logger_config or self.logging_config.train,
)
return self
def finetune(
self,
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Sequence[data.ClipAnnotation] | None = None,
trainable: Literal[
"all", "heads", "classifier_head", "bbox_head"
] = "heads",
train_workers: int = 0,
val_workers: int = 0,
checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
log_dir: Path | None = DEFAULT_LOGS_DIR,
experiment_name: str | None = None,
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
) -> "BatDetect2API":
"""Fine-tune the model with trainable-parameter selection."""
self._set_trainable_parameters(trainable)
run_train(
train_annotations=train_annotations,
val_annotations=val_annotations,
model=self.model,
targets=self.targets,
model_config=model_config or self.model_config,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
train_workers=train_workers,
val_workers=val_workers,
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
experiment_name=experiment_name,
num_epochs=num_epochs,
run_name=run_name,
seed=seed,
audio_config=audio_config or self.audio_config,
train_config=train_config or self.train_config,
logger_config=logger_config or self.logging_config.train,
)
return self
def evaluate(
self,
test_annotations: Sequence[data.ClipAnnotation],
num_workers: Optional[int] = None,
num_workers: int = 0,
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
experiment_name: str | None = None,
run_name: str | None = None,
save_predictions: bool = True,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
return evaluate(
audio_config: AudioConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
outputs_config: OutputsConfig | None = None,
logger_config: LoggerConfig | None = None,
) -> tuple[dict[str, float], list[ClipDetections]]:
return run_evaluate(
self.model,
test_annotations,
targets=self.targets,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
config=self.config,
audio_config=audio_config or self.audio_config,
evaluation_config=evaluation_config or self.evaluation_config,
output_config=outputs_config or self.outputs_config,
logger_config=logger_config or self.logging_config.evaluation,
num_workers=num_workers,
output_dir=output_dir,
experiment_name=experiment_name,
@ -128,8 +221,8 @@ class BatDetect2API:
def evaluate_predictions(
self,
annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction],
output_dir: Optional[data.PathLike] = None,
predictions: Sequence[ClipDetections],
output_dir: data.PathLike | None = None,
):
clip_evals = self.evaluator.evaluate(
annotations,
@ -139,30 +232,66 @@ class BatDetect2API:
metrics = self.evaluator.compute_metrics(clip_evals)
if output_dir is not None:
output_dir = Path(output_dir)
if not output_dir.is_dir():
output_dir.mkdir(parents=True)
metrics_path = output_dir / "metrics.json"
metrics_path.write_text(json.dumps(metrics))
for figure_name, fig in self.evaluator.generate_plots(clip_evals):
fig_path = output_dir / figure_name
if not fig_path.parent.is_dir():
fig_path.parent.mkdir(parents=True)
fig.savefig(fig_path)
save_evaluation_results(
metrics=metrics,
plots=self.evaluator.generate_plots(clip_evals),
output_dir=output_dir,
)
return metrics
def load_audio(self, path: data.PathLike) -> np.ndarray:
return self.audio_loader.load_file(path)
def load_recording(self, recording: data.Recording) -> np.ndarray:
return self.audio_loader.load_recording(recording)
def load_clip(self, clip: data.Clip) -> np.ndarray:
return self.audio_loader.load_clip(clip)
def get_top_class_name(self, detection: Detection) -> str:
    """Return the name of the class with the largest score for *detection*."""
    best_index = int(np.argmax(detection.class_scores))
    return self.targets.class_names[best_index]
def get_class_scores(
self,
detection: Detection,
*,
include_top_class: bool = True,
sort_descending: bool = True,
) -> list[tuple[str, float]]:
"""Get class score list as ``(class_name, score)`` pairs."""
scores = [
(class_name, float(score))
for class_name, score in zip(
self.targets.class_names,
detection.class_scores,
strict=True,
)
]
if sort_descending:
scores.sort(key=lambda item: item[1], reverse=True)
if include_top_class:
return scores
top_class_name = self.get_top_class_name(detection)
return [
(class_name, score)
for class_name, score in scores
if class_name != top_class_name
]
@staticmethod
def get_detection_features(detection: Detection) -> np.ndarray:
"""Get extracted feature vector for one detection."""
return detection.features
def generate_spectrogram(
self,
audio: np.ndarray,
@ -170,24 +299,41 @@ class BatDetect2API:
tensor = torch.tensor(audio).unsqueeze(0)
return self.preprocessor(tensor)
def process_file(self, audio_file: str) -> BatDetect2Prediction:
def process_file(
self,
audio_file: data.PathLike,
batch_size: int | None = None,
) -> ClipDetections:
recording = data.Recording.from_file(audio_file, compute_hash=False)
wav = self.audio_loader.load_recording(recording)
detections = self.process_audio(wav)
return BatDetect2Prediction(
predictions = self.process_files(
[audio_file],
batch_size=(
batch_size
if batch_size is not None
else self.inference_config.loader.batch_size
),
)
detections = [
detection
for prediction in predictions
for detection in prediction.detections
]
return ClipDetections(
clip=data.Clip(
uuid=recording.uuid,
recording=recording,
start_time=0,
end_time=recording.duration,
),
predictions=detections,
detections=detections,
)
def process_audio(
self,
audio: np.ndarray,
) -> List[RawPrediction]:
) -> list[Detection]:
spec = self.generate_spectrogram(audio)
return self.process_spectrogram(spec)
@ -195,7 +341,7 @@ class BatDetect2API:
self,
spec: torch.Tensor,
start_time: float = 0,
) -> List[RawPrediction]:
) -> list[Detection]:
if spec.ndim == 4 and spec.shape[0] > 1:
raise ValueError("Batched spectrograms not supported.")
@ -204,59 +350,74 @@ class BatDetect2API:
outputs = self.model.detector(spec)
detections = self.model.postprocessor(
detections = self.postprocessor(
outputs,
start_times=[start_time],
)[0]
return to_raw_predictions(detections.numpy(), targets=self.targets)
return self.output_transform.to_detections(
detections=detections,
start_time=start_time,
)
def process_directory(
self,
audio_dir: data.PathLike,
) -> List[BatDetect2Prediction]:
) -> list[ClipDetections]:
files = list(get_audio_files(audio_dir))
return self.process_files(files)
def process_files(
self,
audio_files: Sequence[data.PathLike],
num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
batch_size: int | None = None,
num_workers: int = 0,
audio_config: AudioConfig | None = None,
inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None,
) -> list[ClipDetections]:
return process_file_list(
self.model,
audio_files,
config=self.config,
targets=self.targets,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
output_transform=self.output_transform,
batch_size=batch_size,
num_workers=num_workers,
audio_config=audio_config or self.audio_config,
inference_config=inference_config or self.inference_config,
output_config=output_config or self.outputs_config,
)
def process_clips(
self,
clips: Sequence[data.Clip],
batch_size: Optional[int] = None,
num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
batch_size: int | None = None,
num_workers: int = 0,
audio_config: AudioConfig | None = None,
inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None,
) -> list[ClipDetections]:
return run_batch_inference(
self.model,
clips,
targets=self.targets,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
config=self.config,
output_transform=self.output_transform,
batch_size=batch_size,
num_workers=num_workers,
audio_config=audio_config or self.audio_config,
inference_config=inference_config or self.inference_config,
output_config=output_config or self.outputs_config,
)
def save_predictions(
self,
predictions: Sequence[BatDetect2Prediction],
predictions: Sequence[ClipDetections],
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
format: Optional[str] = None,
config: Optional[OutputFormatConfig] = None,
audio_dir: data.PathLike | None = None,
format: str | None = None,
config: OutputFormatConfig | None = None,
):
formatter = self.formatter
@ -274,50 +435,78 @@ class BatDetect2API:
def load_predictions(
self,
path: data.PathLike,
) -> List[BatDetect2Prediction]:
return self.formatter.load(path)
format: str | None = None,
config: OutputFormatConfig | None = None,
) -> list[object]:
formatter = self.formatter
if format is not None or config is not None:
format = format or config.name # type: ignore
formatter = get_output_formatter(
name=format,
targets=self.targets,
config=config,
)
return formatter.load(path)
@classmethod
def from_config(
cls,
config: BatDetect2Config,
):
targets = build_targets(config=config.targets)
) -> "BatDetect2API":
targets = build_targets(config=config.model.targets)
audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
config=config.model.preprocess,
)
postprocessor = build_postprocessor(
preprocessor,
config=config.postprocess,
config=config.model.postprocess,
)
evaluator = build_evaluator(config=config.evaluation, targets=targets)
formatter = build_output_formatter(
targets,
config=config.outputs.format,
)
output_transform = build_output_transform(
config=config.outputs.transform,
targets=targets,
)
# NOTE: Better to have a separate instance of
# preprocessor and postprocessor as these may be moved
# to another device.
evaluator = build_evaluator(
config=config.evaluation,
targets=targets,
transform=output_transform,
)
# NOTE: Build separate instances of preprocessor and postprocessor
# to avoid device mismatch errors
model = build_model(
config=config.model,
targets=targets,
targets=build_targets(config=config.model.targets),
preprocessor=build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
config=config.model.preprocess,
),
postprocessor=build_postprocessor(
preprocessor,
config=config.postprocess,
config=config.model.postprocess,
),
)
formatter = build_output_formatter(targets, config=config.output)
return cls(
config=config,
model_config=config.model,
audio_config=config.audio,
train_config=config.train,
evaluation_config=config.evaluation,
inference_config=config.inference,
outputs_config=config.outputs,
logging_config=config.logging,
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,
@ -325,40 +514,83 @@ class BatDetect2API:
evaluator=evaluator,
model=model,
formatter=formatter,
output_transform=output_transform,
)
@classmethod
def from_checkpoint(
cls,
path: data.PathLike,
config: Optional[BatDetect2Config] = None,
):
model, stored_config = load_model_from_checkpoint(path)
targets_config: TargetConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
inference_config: InferenceConfig | None = None,
outputs_config: OutputsConfig | None = None,
logging_config: AppLoggingConfig | None = None,
) -> "BatDetect2API":
model, model_config = load_model_from_checkpoint(path)
config = (
merge_configs(stored_config, config) if config else stored_config
audio_config = audio_config or AudioConfig(
samplerate=model_config.samplerate,
)
train_config = train_config or TrainingConfig()
evaluation_config = evaluation_config or EvaluationConfig()
inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig()
logging_config = logging_config or AppLoggingConfig()
targets = build_targets(config=config.targets)
if (
targets_config is not None
and targets_config != model_config.targets
):
targets = build_targets(config=targets_config)
model = build_model_with_new_targets(
model=model,
targets=targets,
)
model_config = model_config.model_copy(
update={"targets": targets_config}
)
audio_loader = build_audio_loader(config=config.audio)
targets = build_targets(config=model_config.targets)
audio_loader = build_audio_loader(config=audio_config)
preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
config=model_config.preprocess,
)
postprocessor = build_postprocessor(
preprocessor,
config=config.postprocess,
config=model_config.postprocess,
)
evaluator = build_evaluator(config=config.evaluation, targets=targets)
formatter = build_output_formatter(
targets,
config=outputs_config.format,
)
formatter = build_output_formatter(targets, config=config.output)
output_transform = build_output_transform(
config=outputs_config.transform,
targets=targets,
)
evaluator = build_evaluator(
config=evaluation_config,
targets=targets,
transform=output_transform,
)
return cls(
config=config,
model_config=model_config,
audio_config=audio_config,
train_config=train_config,
evaluation_config=evaluation_config,
inference_config=inference_config,
outputs_config=outputs_config,
logging_config=logging_config,
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,
@ -366,4 +598,27 @@ class BatDetect2API:
evaluator=evaluator,
model=model,
formatter=formatter,
output_transform=output_transform,
)
def _set_trainable_parameters(
self,
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
) -> None:
detector = cast(Detector, self.model.detector)
for parameter in detector.parameters():
parameter.requires_grad = False
if trainable == "all":
for parameter in detector.parameters():
parameter.requires_grad = True
return
if trainable in {"heads", "classifier_head"}:
for parameter in detector.classifier_head.parameters():
parameter.requires_grad = True
if trainable in {"heads", "bbox_head"}:
for parameter in detector.bbox_head.parameters():
parameter.requires_grad = True

View File

@ -5,8 +5,11 @@ from batdetect2.audio.loader import (
SoundEventAudioLoader,
build_audio_loader,
)
from batdetect2.audio.types import AudioLoader, ClipperProtocol
__all__ = [
"AudioLoader",
"ClipperProtocol",
"TARGET_SAMPLERATE_HZ",
"AudioConfig",
"SoundEventAudioLoader",

View File

@ -1,4 +1,4 @@
from typing import Annotated, List, Literal, Optional, Union
from typing import Annotated, List, Literal
import numpy as np
from loguru import logger
@ -6,8 +6,13 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds, intervals_overlap
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipperProtocol
from batdetect2.audio.types import ClipperProtocol
from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
DEFAULT_TRAIN_CLIP_DURATION = 0.256
DEFAULT_MAX_EMPTY_CLIP = 0.1
@ -16,12 +21,24 @@ DEFAULT_MAX_EMPTY_CLIP = 0.1
__all__ = [
"build_clipper",
"ClipConfig",
"ClipperImportConfig",
]
clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")
@add_import_config(clipper_registry)
class ClipperImportConfig(ImportConfig):
"""Use any callable as a clipper.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class RandomClipConfig(BaseConfig):
name: Literal["random_subclip"] = "random_subclip"
duration: float = DEFAULT_TRAIN_CLIP_DURATION
@ -245,16 +262,12 @@ class FixedDurationClip:
ClipConfig = Annotated[
Union[
RandomClipConfig,
PaddedClipConfig,
FixedDurationClipConfig,
],
RandomClipConfig | PaddedClipConfig | FixedDurationClipConfig,
Field(discriminator="name"),
]
def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
def build_clipper(config: ClipConfig | None = None) -> ClipperProtocol:
config = config or RandomClipConfig()
logger.opt(lazy=True).debug(

View File

@ -1,5 +1,3 @@
from typing import Optional
import numpy as np
from numpy.typing import DTypeLike
from pydantic import Field
@ -7,8 +5,8 @@ from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError
from batdetect2.audio.types import AudioLoader
from batdetect2.core import BaseConfig
from batdetect2.typing import AudioLoader
__all__ = [
"SoundEventAudioLoader",
@ -28,15 +26,17 @@ class ResampleConfig(BaseConfig):
Attributes
----------
samplerate : int, default=256000
The target sample rate in Hz to resample the audio to. Must be > 0.
enabled : bool, default=True
Whether to resample the audio to the target sample rate. If
``False``, the audio is returned at its original sample rate.
method : str, default="poly"
The resampling algorithm to use. Options:
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
Generally fast.
- "fourier": Resampling via Fourier method using
`scipy.signal.resample`. May handle non-integer
resampling factors differently.
- ``"poly"``: Polyphase resampling via
``scipy.signal.resample_poly``. Generally fast and accurate.
- ``"fourier"``: FFT-based resampling via
``scipy.signal.resample``. May be preferred for non-integer
resampling ratios.
"""
enabled: bool = True
@ -50,7 +50,7 @@ class AudioConfig(BaseConfig):
resample: ResampleConfig = Field(default_factory=ResampleConfig)
def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
def build_audio_loader(config: AudioConfig | None = None) -> AudioLoader:
"""Factory function to create an AudioLoader based on configuration."""
config = config or AudioConfig()
return SoundEventAudioLoader(
@ -65,7 +65,7 @@ class SoundEventAudioLoader(AudioLoader):
def __init__(
self,
samplerate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ResampleConfig] = None,
config: ResampleConfig | None = None,
):
self.samplerate = samplerate
self.config = config or ResampleConfig()
@ -73,7 +73,7 @@ class SoundEventAudioLoader(AudioLoader):
def load_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> np.ndarray:
"""Load and preprocess audio directly from a file path."""
return load_file_audio(
@ -86,7 +86,7 @@ class SoundEventAudioLoader(AudioLoader):
def load_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object."""
return load_recording_audio(
@ -99,7 +99,7 @@ class SoundEventAudioLoader(AudioLoader):
def load_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object."""
return load_clip_audio(
@ -112,10 +112,10 @@ class SoundEventAudioLoader(AudioLoader):
def load_file_audio(
path: data.PathLike,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
samplerate: int | None = None,
config: ResampleConfig | None = None,
audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32,
) -> np.ndarray:
"""Load and preprocess audio from a file path using specified config."""
try:
@ -136,10 +136,10 @@ def load_file_audio(
def load_recording_audio(
recording: data.Recording,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
samplerate: int | None = None,
config: ResampleConfig | None = None,
audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32,
) -> np.ndarray:
"""Load and preprocess the entire audio content of a recording using config."""
clip = data.Clip(
@ -158,10 +158,10 @@ def load_recording_audio(
def load_clip_audio(
clip: data.Clip,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
samplerate: int | None = None,
config: ResampleConfig | None = None,
audio_dir: data.PathLike | None = None,
dtype: DTypeLike = np.float32,
) -> np.ndarray:
"""Load and preprocess a specific audio clip segment based on config."""
try:
@ -194,7 +194,31 @@ def resample_audio(
samplerate: int = TARGET_SAMPLERATE_HZ,
method: str = "poly",
) -> np.ndarray:
"""Resample an audio waveform DataArray to a target sample rate."""
"""Resample an audio waveform to a target sample rate.
Parameters
----------
wav : np.ndarray
Input waveform array. The last axis is assumed to be time.
sr : int
Original sample rate of ``wav`` in Hz.
samplerate : int, default=256000
Target sample rate in Hz.
method : str, default="poly"
Resampling algorithm: ``"poly"`` (polyphase) or
``"fourier"`` (FFT-based).
Returns
-------
np.ndarray
Resampled waveform. If ``sr == samplerate`` the input array is
returned unchanged.
Raises
------
NotImplementedError
If ``method`` is not ``"poly"`` or ``"fourier"``.
"""
if sr == samplerate:
return wav
@ -264,7 +288,7 @@ def resample_audio_fourier(
sr_new: int,
axis: int = -1,
) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample`.
"""Resample a numpy array using ``scipy.signal.resample``.
This method uses FFTs to resample the signal.
@ -272,23 +296,20 @@ def resample_audio_fourier(
----------
array : np.ndarray
The input array to resample.
num : int
The desired number of samples in the output array along `axis`.
sr_orig : int
The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1
The axis of `array` along which to resample.
The axis of ``array`` along which to resample.
Returns
-------
np.ndarray
The array resampled to have `num` samples along `axis`.
Raises
------
ValueError
If `num` is negative.
The array resampled to the target sample rate.
"""
ratio = sr_new / sr_orig
return resample( # type: ignore
return resample(
array,
int(array.shape[axis] * ratio),
axis=axis,

View File

@ -0,0 +1,40 @@
from typing import Protocol
import numpy as np
from soundevent import data
__all__ = [
"AudioLoader",
"ClipperProtocol",
]
class AudioLoader(Protocol):
samplerate: int
def load_file(
self,
path: data.PathLike,
audio_dir: data.PathLike | None = None,
) -> np.ndarray: ...
def load_recording(
self,
recording: data.Recording,
audio_dir: data.PathLike | None = None,
) -> np.ndarray: ...
def load_clip(
self,
clip: data.Clip,
audio_dir: data.PathLike | None = None,
) -> np.ndarray: ...
class ClipperProtocol(Protocol):
def __call__(
self,
clip_annotation: data.ClipAnnotation,
) -> data.ClipAnnotation: ...
def get_subclip(self, clip: data.Clip) -> data.Clip: ...

View File

@ -2,6 +2,7 @@ from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect
from batdetect2.cli.data import data
from batdetect2.cli.evaluate import evaluate_command
from batdetect2.cli.inference import predict
from batdetect2.cli.train import train_command
__all__ = [
@ -10,6 +11,7 @@ __all__ = [
"data",
"train_command",
"evaluate_command",
"predict",
]

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import Optional
import click
@ -34,9 +33,9 @@ def data(): ...
)
def summary(
dataset_config: Path,
field: Optional[str] = None,
targets_path: Optional[Path] = None,
base_dir: Optional[Path] = None,
field: str | None = None,
targets_path: Path | None = None,
base_dir: Path | None = None,
):
from batdetect2.data import compute_class_summary, load_dataset_from_config
from batdetect2.targets import load_targets
@ -83,9 +82,9 @@ def summary(
)
def convert(
dataset_config: Path,
field: Optional[str] = None,
field: str | None = None,
output: Path = Path("annotations.json"),
base_dir: Optional[Path] = None,
base_dir: Path | None = None,
):
"""Convert a dataset config file to soundevent format."""
from soundevent import data, io

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import Optional
import click
from loguru import logger
@ -13,9 +12,14 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--config", "config_path", type=click.Path())
@click.option("--targets", "targets_config", type=click.Path(exists=True))
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--evaluation-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--base-dir", type=click.Path(), default=Path.cwd())
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
@click.option("--experiment-name", type=str)
@ -25,15 +29,25 @@ def evaluate_command(
model_path: Path,
test_dataset: Path,
base_dir: Path,
config_path: Optional[Path],
targets_config: Path | None,
audio_config: Path | None,
evaluation_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: Optional[int] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
num_workers: int = 0,
experiment_name: str | None = None,
run_name: str | None = None,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import load_full_config
from batdetect2.audio import AudioConfig
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import TargetConfig
logger.info("Initiating evaluation process...")
@ -47,11 +61,44 @@ def evaluate_command(
num_annotations=len(test_annotations),
)
config = None
if config_path is not None:
config = load_full_config(config_path)
target_conf = (
TargetConfig.load(targets_config)
if targets_config is not None
else None
)
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
eval_conf = (
EvaluationConfig.load(evaluation_config)
if evaluation_config is not None
else None
)
inference_conf = (
InferenceConfig.load(inference_config)
if inference_config is not None
else None
)
outputs_conf = (
OutputsConfig.load(outputs_config)
if outputs_config is not None
else None
)
logging_conf = (
AppLoggingConfig.load(logging_config)
if logging_config is not None
else None
)
api = BatDetect2API.from_checkpoint(model_path, config=config)
api = BatDetect2API.from_checkpoint(
model_path,
targets_config=target_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
logging_config=logging_conf,
)
api.evaluate(
test_annotations,

View File

@ -0,0 +1,231 @@
from pathlib import Path
import click
from loguru import logger
from soundevent import io
from soundevent.audio.files import get_audio_files
from batdetect2.cli.base import cli
__all__ = ["predict"]
@cli.group(name="predict")
def predict() -> None:
"""Run prediction with BatDetect2 API v2."""
def _build_api(
model_path: Path,
audio_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.outputs import OutputsConfig
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
inference_conf = (
InferenceConfig.load(inference_config)
if inference_config is not None
else None
)
outputs_conf = (
OutputsConfig.load(outputs_config)
if outputs_config is not None
else None
)
logging_conf = (
AppLoggingConfig.load(logging_config)
if logging_config is not None
else None
)
api = BatDetect2API.from_checkpoint(
model_path,
audio_config=audio_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
logging_config=logging_conf,
)
return api, audio_conf, inference_conf, outputs_conf
def _run_prediction(
model_path: Path,
audio_files: list[Path],
output_path: Path,
audio_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
batch_size: int | None,
num_workers: int,
format_name: str | None,
) -> None:
logger.info("Initiating prediction process...")
api, audio_conf, inference_conf, outputs_conf = _build_api(
model_path,
audio_config,
inference_config,
outputs_config,
logging_config,
)
logger.info("Found {num_files} audio files", num_files=len(audio_files))
predictions = api.process_files(
audio_files,
batch_size=batch_size,
num_workers=num_workers,
audio_config=audio_conf,
inference_config=inference_conf,
output_config=outputs_conf,
)
common_path = audio_files[0].parent if audio_files else None
api.save_predictions(
predictions,
path=output_path,
audio_dir=common_path,
format=format_name,
)
logger.info(
"Prediction complete. Results saved to {path}", path=output_path
)
@predict.command(name="directory")
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("audio_dir", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--batch-size", type=int)
@click.option("--workers", "num_workers", type=int, default=0)
@click.option("--format", "format_name", type=str)
def predict_directory_command(
model_path: Path,
audio_dir: Path,
output_path: Path,
audio_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
batch_size: int | None,
num_workers: int,
format_name: str | None,
) -> None:
audio_files = list(get_audio_files(audio_dir))
_run_prediction(
model_path=model_path,
audio_files=audio_files,
output_path=output_path,
audio_config=audio_config,
inference_config=inference_config,
outputs_config=outputs_config,
logging_config=logging_config,
batch_size=batch_size,
num_workers=num_workers,
format_name=format_name,
)
@predict.command(name="file_list")
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("file_list", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--batch-size", type=int)
@click.option("--workers", "num_workers", type=int, default=0)
@click.option("--format", "format_name", type=str)
def predict_file_list_command(
model_path: Path,
file_list: Path,
output_path: Path,
audio_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
batch_size: int | None,
num_workers: int,
format_name: str | None,
) -> None:
file_list = Path(file_list)
audio_files = [
Path(line.strip())
for line in file_list.read_text().splitlines()
if line.strip()
]
_run_prediction(
model_path=model_path,
audio_files=audio_files,
output_path=output_path,
audio_config=audio_config,
inference_config=inference_config,
outputs_config=outputs_config,
logging_config=logging_config,
batch_size=batch_size,
num_workers=num_workers,
format_name=format_name,
)
@predict.command(name="dataset")
@click.argument("model_path", type=click.Path(exists=True))
@click.argument("dataset_path", type=click.Path(exists=True))
@click.argument("output_path", type=click.Path())
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--batch-size", type=int)
@click.option("--workers", "num_workers", type=int, default=0)
@click.option("--format", "format_name", type=str)
def predict_dataset_command(
model_path: Path,
dataset_path: Path,
output_path: Path,
audio_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
batch_size: int | None,
num_workers: int,
format_name: str | None,
) -> None:
dataset_path = Path(dataset_path)
dataset = io.load(dataset_path, type="annotation_set")
audio_files = sorted(
{
Path(clip_annotation.clip.recording.path)
for clip_annotation in dataset.clip_annotations
}
)
_run_prediction(
model_path=model_path,
audio_files=audio_files,
output_path=output_path,
audio_config=audio_config,
inference_config=inference_config,
outputs_config=outputs_config,
logging_config=logging_config,
batch_size=batch_size,
num_workers=num_workers,
format_name=format_name,
)

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import Optional
import click
from loguru import logger
@ -14,10 +13,15 @@ __all__ = ["train_command"]
@click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model", "model_path", type=click.Path(exists=True))
@click.option("--targets", "targets_config", type=click.Path(exists=True))
@click.option("--model-config", type=click.Path(exists=True))
@click.option("--training-config", type=click.Path(exists=True))
@click.option("--audio-config", type=click.Path(exists=True))
@click.option("--evaluation-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int)
@click.option("--val-workers", type=int)
@click.option("--num-epochs", type=int)
@ -26,42 +30,82 @@ __all__ = ["train_command"]
@click.option("--seed", type=int)
def train_command(
train_dataset: Path,
val_dataset: Optional[Path] = None,
model_path: Optional[Path] = None,
ckpt_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
config: Optional[Path] = None,
targets_config: Optional[Path] = None,
config_field: Optional[str] = None,
seed: Optional[int] = None,
num_epochs: Optional[int] = None,
val_dataset: Path | None = None,
model_path: Path | None = None,
ckpt_dir: Path | None = None,
log_dir: Path | None = None,
targets_config: Path | None = None,
model_config: Path | None = None,
training_config: Path | None = None,
audio_config: Path | None = None,
evaluation_config: Path | None = None,
inference_config: Path | None = None,
outputs_config: Path | None = None,
logging_config: Path | None = None,
seed: int | None = None,
num_epochs: int | None = None,
train_workers: int = 0,
val_workers: int = 0,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
experiment_name: str | None = None,
run_name: str | None = None,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import (
BatDetect2Config,
load_full_config,
)
from batdetect2.audio import AudioConfig
from batdetect2.config import BatDetect2Config
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import TargetConfig
from batdetect2.train import TrainingConfig
logger.info("Initiating training process...")
logger.info("Loading configuration...")
conf = (
load_full_config(config, field=config_field)
if config is not None
else BatDetect2Config()
target_conf = (
TargetConfig.load(targets_config)
if targets_config is not None
else None
)
model_conf = (
ModelConfig.load(model_config) if model_config is not None else None
)
train_conf = (
TrainingConfig.load(training_config)
if training_config is not None
else None
)
audio_conf = (
AudioConfig.load(audio_config) if audio_config is not None else None
)
eval_conf = (
EvaluationConfig.load(evaluation_config)
if evaluation_config is not None
else None
)
inference_conf = (
InferenceConfig.load(inference_config)
if inference_config is not None
else None
)
outputs_conf = (
OutputsConfig.load(outputs_config)
if outputs_config is not None
else None
)
logging_conf = (
AppLoggingConfig.load(logging_config)
if logging_config is not None
else None
)
if targets_config is not None:
logger.info("Loading targets configuration...")
conf = conf.model_copy(
update=dict(targets=load_target_config(targets_config))
)
if target_conf is not None:
logger.info("Loaded targets configuration.")
if model_conf is not None and target_conf is not None:
model_conf = model_conf.model_copy(update={"targets": target_conf})
logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config(train_dataset)
@ -82,12 +126,43 @@ def train_command(
logger.info("Configuration and data loaded. Starting training...")
if model_path is not None and model_conf is not None:
raise click.UsageError(
"--model-config cannot be used with --model. "
"Checkpoint model configuration is loaded from the checkpoint."
)
if model_path is None:
conf = BatDetect2Config()
if model_conf is not None:
conf.model = model_conf
elif target_conf is not None:
conf.model = conf.model.model_copy(update={"targets": target_conf})
if train_conf is not None:
conf.train = train_conf
if audio_conf is not None:
conf.audio = audio_conf
if eval_conf is not None:
conf.evaluation = eval_conf
if inference_conf is not None:
conf.inference = inference_conf
if outputs_conf is not None:
conf.outputs = outputs_conf
if logging_conf is not None:
conf.logging = logging_conf
api = BatDetect2API.from_config(conf)
else:
api = BatDetect2API.from_checkpoint(
model_path,
config=conf if config is not None else None,
targets_config=target_conf,
train_config=train_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
logging_config=logging_conf,
)
return api.train(

View File

@ -4,7 +4,7 @@ import json
import os
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, List
import numpy as np
from soundevent import data
@ -17,7 +17,7 @@ from batdetect2.types import (
FileAnnotation,
)
PathLike = Union[Path, str, os.PathLike]
PathLike = Path | str | os.PathLike
__all__ = [
"convert_to_annotation_group",
@ -33,7 +33,7 @@ UNKNOWN_CLASS = "__UNKNOWN__"
NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
EventFn = Callable[[data.SoundEventAnnotation], str | None]
ClassFn = Callable[[data.Recording], int]
@ -103,17 +103,17 @@ def convert_to_annotation_group(
y_inds.append(0)
annotations.append(
{
"start_time": start_time,
"end_time": end_time,
"low_freq": low_freq,
"high_freq": high_freq,
"class_prob": 1.0,
"det_prob": 1.0,
"individual": "0",
"event": event,
"class_id": class_id, # type: ignore
}
Annotation(
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
class_prob=1.0,
det_prob=1.0,
individual="0",
event=event,
class_id=class_id,
)
)
return {
@ -221,7 +221,7 @@ def annotation_to_sound_event_prediction(
def file_annotation_to_clip(
file_annotation: FileAnnotation,
audio_dir: Optional[PathLike] = None,
audio_dir: PathLike | None = None,
label_key: str = "class",
) -> data.Clip:
"""Convert file annotation to recording."""

View File

@ -1,28 +1,20 @@
from typing import Literal, Optional
from typing import Literal
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.audio import AudioConfig
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.data.predictions import OutputFormatConfig
from batdetect2.data.predictions.raw import RawOutputConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.evaluate.config import (
EvaluationConfig,
get_default_eval_config,
)
from batdetect2.inference.config import InferenceConfig
from batdetect2.models.config import BackboneConfig
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.targets.config import TargetConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.train.config import TrainingConfig
__all__ = [
"BatDetect2Config",
"load_full_config",
"validate_config",
]
__all__ = ["BatDetect2Config"]
class BatDetect2Config(BaseConfig):
@ -32,26 +24,8 @@ class BatDetect2Config(BaseConfig):
evaluation: EvaluationConfig = Field(
default_factory=get_default_eval_config
)
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
model: ModelConfig = Field(default_factory=ModelConfig)
audio: AudioConfig = Field(default_factory=AudioConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig)
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
def validate_config(config: Optional[dict]) -> BatDetect2Config:
if config is None:
return BatDetect2Config()
return BatDetect2Config.model_validate(config)
def load_full_config(
path: PathLike,
field: Optional[str] = None,
) -> BatDetect2Config:
return load_config(path, schema=BatDetect2Config, field=field)
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)

View File

@ -1,8 +1,14 @@
from batdetect2.core.configs import BaseConfig, load_config, merge_configs
from batdetect2.core.registries import Registry
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
__all__ = [
"add_import_config",
"BaseConfig",
"ImportConfig",
"load_config",
"Registry",
"merge_configs",

View File

@ -1,5 +1,3 @@
from typing import Optional
import numpy as np
import torch
import xarray as xr
@ -86,8 +84,8 @@ def adjust_width(
def slice_tensor(
tensor: torch.Tensor,
start: Optional[int] = None,
end: Optional[int] = None,
start: int | None = None,
end: int | None = None,
dim: int = -1,
) -> torch.Tensor:
slices = [slice(None)] * tensor.ndim

View File

@ -8,11 +8,11 @@ configuration data from files, with optional support for accessing nested
configuration sections.
"""
from typing import Any, Optional, Type, TypeVar
from typing import Any, Literal, Type, TypeVar, overload
import yaml
from deepmerge.merger import Merger
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, TypeAdapter
from soundevent.data import PathLike
__all__ = [
@ -21,6 +21,8 @@ __all__ = [
"merge_configs",
]
C = TypeVar("C", bound="BaseConfig")
class BaseConfig(BaseModel):
"""Base class for all configuration models in BatDetect2.
@ -62,8 +64,30 @@ class BaseConfig(BaseModel):
)
)
@classmethod
def from_yaml(cls, yaml_str: str):
    """Parse a YAML string and validate it into an instance of this class.

    Parameters
    ----------
    yaml_str : str
        Raw YAML document text.

    Returns
    -------
    An instance of ``cls`` validated from the parsed YAML data.
    """
    parsed = yaml.safe_load(yaml_str)
    return cls.model_validate(parsed)
T = TypeVar("T", bound=BaseModel)
@classmethod
def load(
    cls: Type[C],
    path: PathLike,
    field: str | None = None,
    extra: Literal["ignore", "allow", "forbid"] | None = None,
    strict: bool | None = None,
) -> C:
    """Load and validate an instance of this config class from a file.

    Thin convenience wrapper around ``load_config`` with ``schema=cls``,
    so every ``BaseConfig`` subclass gains a uniform ``.load()``
    constructor.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file (typically YAML).
    field : str, optional
        Dot-separated path to a nested section of the file to validate
        instead of the whole document.
    extra : Literal["ignore", "allow", "forbid"], optional
        Override for how unknown keys are handled; None keeps the
        schema's default behaviour.
    strict : bool, optional
        Override for strict type enforcement; None keeps the schema's
        default behaviour.

    Returns
    -------
    C
        A validated instance of this class.
    """
    return load_config(
        path,
        schema=cls,
        field=field,
        extra=extra,
        strict=strict,
    )
T = TypeVar("T")
T_Model = TypeVar("T_Model", bound=BaseModel)
Schema = Type[T_Model] | TypeAdapter[T]
def get_object_field(obj: dict, current_key: str) -> Any:
@ -125,35 +149,69 @@ def get_object_field(obj: dict, current_key: str) -> Any:
return get_object_field(subobj, rest)
@overload
def load_config(
path: PathLike,
schema: Type[T],
field: Optional[str] = None,
) -> T:
schema: Type[T_Model],
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
) -> T_Model: ...
@overload
def load_config(
path: PathLike,
schema: TypeAdapter[T],
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
) -> T: ...
def load_config(
path: PathLike,
schema: Type[T_Model] | TypeAdapter[T],
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
) -> T_Model | T:
"""Load and validate configuration data from a file against a schema.
Reads a YAML file, optionally extracts a specific section using dot
notation, and then validates the resulting data against the provided
Pydantic `schema`.
Pydantic schema.
Parameters
----------
path : PathLike
The path to the configuration file (typically `.yaml`).
schema : Type[T]
The Pydantic `BaseModel` subclass that defines the expected structure
and types for the configuration data.
schema : Type[T_Model] | TypeAdapter[T]
Either a Pydantic `BaseModel` subclass or a `TypeAdapter` instance
that defines the expected structure and types for the configuration
data.
field : str, optional
A dot-separated string indicating a nested section within the YAML
file to extract before validation. If None (default), the entire
file content is validated against the schema.
Example: `"training.optimizer"` would extract the `optimizer` section
within the `training` section.
extra : Literal["ignore", "allow", "forbid"], optional
How to handle extra keys in the configuration data. If None (default),
the default behaviour of the schema is used. If "ignore", extra keys
are ignored. If "allow", extra keys are allowed and will be accessible
as attributes on the resulting model instance. If "forbid", extra
keys are forbidden and an exception is raised. See pydantic
documentation for more details.
strict : bool, optional
Whether to enforce types strictly. If None (default), the default
behaviour of the schema is used. See pydantic documentation for more
details.
Returns
-------
T
An instance of the provided `schema`, populated and validated with
T_Model | T
An instance of the schema type, populated and validated with
data from the configuration file.
Raises
@ -179,7 +237,10 @@ def load_config(
if field:
config = get_object_field(config, field)
return schema.model_validate(config or {})
if isinstance(schema, TypeAdapter):
return schema.validate_python(config or {}, extra=extra, strict=strict)
return schema.model_validate(config or {}, extra=extra, strict=strict)
default_merger = Merger(
@ -189,7 +250,7 @@ default_merger = Merger(
)
def merge_configs(config1: T, config2: T) -> T:
def merge_configs(config1: T_Model, config2: T_Model) -> T_Model:
"""Merge two configuration objects."""
model = type(config1)
dict1 = config1.model_dump()

View File

@ -1,23 +1,28 @@
import sys
from typing import Callable, Dict, Generic, Tuple, Type, TypeVar
from typing import (
Any,
Callable,
Concatenate,
Generic,
ParamSpec,
Sequence,
Type,
TypeVar,
)
from pydantic import BaseModel
if sys.version_info >= (3, 10):
from typing import Concatenate, ParamSpec
else:
from typing_extensions import Concatenate, ParamSpec
from hydra.utils import instantiate
from pydantic import BaseModel, Field
__all__ = [
"add_import_config",
"ImportConfig",
"Registry",
"SimpleRegistry",
]
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
T_Type = TypeVar("T_Type", covariant=True)
P_Type = ParamSpec("P_Type")
T = TypeVar("T")
@ -43,12 +48,13 @@ class SimpleRegistry(Generic[T]):
class Registry(Generic[T_Type, P_Type]):
"""A generic class to create and manage a registry of items."""
def __init__(self, name: str):
def __init__(self, name: str, discriminator: str = "name"):
self._name = name
self._registry: Dict[
self._registry: dict[
str, Callable[Concatenate[..., P_Type], T_Type]
] = {}
self._config_types: Dict[str, Type[BaseModel]] = {}
self._discriminator = discriminator
self._config_types: dict[str, Type[BaseModel]] = {}
def register(
self,
@ -56,15 +62,20 @@ class Registry(Generic[T_Type, P_Type]):
):
fields = config_cls.model_fields
if "name" not in fields:
raise ValueError("Configuration object must have a 'name' field.")
if self._discriminator not in fields:
raise ValueError(
"Configuration object must have "
f"a '{self._discriminator}' field."
)
name = fields["name"].default
name = fields[self._discriminator].default
self._config_types[name] = config_cls
if not isinstance(name, str):
raise ValueError("'name' field must be a string literal.")
raise ValueError(
f"'{self._discriminator}' field must be a string literal."
)
def decorator(
func: Callable[Concatenate[T_Config, P_Type], T_Type],
@ -74,7 +85,7 @@ class Registry(Generic[T_Type, P_Type]):
return decorator
def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
def get_config_types(self) -> tuple[Type[BaseModel], ...]:
return tuple(self._config_types.values())
def get_config_type(self, name: str) -> Type[BaseModel]:
@ -94,10 +105,12 @@ class Registry(Generic[T_Type, P_Type]):
) -> T_Type:
"""Builds a logic instance from a config object."""
name = getattr(config, "name") # noqa: B009
name = getattr(config, self._discriminator) # noqa: B009
if name is None:
raise ValueError("Config does not have a name field")
raise ValueError(
f"Config does not have a '{self._discriminator}' field"
)
if name not in self._registry:
raise NotImplementedError(
@ -105,3 +118,92 @@ class Registry(Generic[T_Type, P_Type]):
)
return self._registry[name](config, *args, **kwargs)
class ImportConfig(BaseModel):
    """Base config for dynamic instantiation via Hydra.

    Subclass this to create a registry-specific import escape hatch.
    The subclass must add a discriminator field whose name matches the
    registry's own discriminator key, with its value fixed to
    ``Literal["import"]``.

    Attributes
    ----------
    target : str
        Fully-qualified dotted path to the callable to instantiate,
        e.g. ``"mypackage.module.MyClass"``. Passed to
        ``hydra.utils.instantiate`` as ``_target_`` by the builder that
        ``add_import_config`` registers.
    arguments : dict[str, Any]
        Base keyword arguments forwarded to the callable. When the
        same key also appears in ``kwargs`` passed to ``build()``,
        the ``kwargs`` value takes priority.
    """

    # Dotted import path; resolution happens at build time, not at
    # config-validation time.
    target: str
    arguments: dict[str, Any] = Field(default_factory=dict)
T_Import = TypeVar("T_Import", bound=ImportConfig)
def add_import_config(
    registry: Registry[T_Type, P_Type],
    arg_names: Sequence[str] | None = None,
) -> Callable[[Type[T_Import]], Type[T_Import]]:
    """Decorator that registers an ImportConfig subclass as an escape hatch.

    Wraps the decorated class in a builder that calls
    ``hydra.utils.instantiate`` using ``config.target`` and
    ``config.arguments``. The builder is registered on *registry*
    under the discriminator value ``"import"``.

    Parameters
    ----------
    registry : Registry
        The registry instance on which the config should be registered.
    arg_names : Sequence[str], optional
        Names under which positional arguments passed to the builder are
        forwarded to the instantiated callable as keyword arguments.
        If None (default), the builder rejects all positional arguments.

    Returns
    -------
    Callable[[type[ImportConfig]], type[ImportConfig]]
        A class decorator that registers the class and returns it
        unchanged.

    Raises
    ------
    ValueError
        Raised by the builder at build time when the number of
        positional arguments does not match ``arg_names``.

    Examples
    --------
    Define a per-registry import escape hatch::

        @add_import_config(my_registry)
        class MyRegistryImportConfig(ImportConfig):
            name: Literal["import"] = "import"
    """

    def decorator(config_cls: Type[T_Import]) -> Type[T_Import]:
        def builder(
            config: T_Import,
            *args: P_Type.args,
            **kwargs: P_Type.kwargs,
        ) -> T_Type:
            # Positional args can only be forwarded to hydra as keyword
            # args, so each one must have a declared name; reject any
            # mismatch in count.
            _arg_names = arg_names or []
            if len(args) != len(_arg_names):
                raise ValueError(
                    "Positional arguments are not supported "
                    "for import escape hatch unless you specify "
                    "the argument names. Use `arg_names` to specify "
                    "the names of the positional arguments."
                )
            args_dict = {_arg_names[i]: args[i] for i in range(len(args))}
            # Later entries win on key collision: explicit kwargs
            # override positional args, which override config.arguments.
            hydra_cfg = {
                "_target_": config.target,
                **config.arguments,
                **args_dict,
                **kwargs,
            }
            return instantiate(hydra_cfg)

        registry.register(config_cls)(builder)
        return config_cls

    return decorator

View File

@ -7,20 +7,12 @@ from batdetect2.data.annotations import (
load_annotated_dataset,
)
from batdetect2.data.datasets import (
Dataset,
DatasetConfig,
load_dataset,
load_dataset_config,
load_dataset_from_config,
)
from batdetect2.data.predictions import (
BatDetect2OutputConfig,
OutputFormatConfig,
RawOutputConfig,
SoundEventOutputConfig,
build_output_formatter,
get_output_formatter,
load_predictions,
)
from batdetect2.data.summary import (
compute_class_summary,
extract_recordings_df,
@ -28,6 +20,7 @@ from batdetect2.data.summary import (
)
__all__ = [
"Dataset",
"AOEFAnnotations",
"AnnotatedDataset",
"AnnotationFormats",
@ -36,6 +29,7 @@ __all__ = [
"BatDetect2OutputConfig",
"DatasetConfig",
"OutputFormatConfig",
"ParquetOutputConfig",
"RawOutputConfig",
"SoundEventOutputConfig",
"build_output_formatter",

View File

@ -13,22 +13,18 @@ format-specific loading function to retrieve the annotations as a standard
`soundevent.data.AnnotationSet`.
"""
from typing import Annotated, Optional, Union
from typing import Annotated
from pydantic import Field
from soundevent import data
from batdetect2.data.annotations.aoef import (
AOEFAnnotations,
load_aoef_annotated_dataset,
)
from batdetect2.data.annotations.aoef import AOEFAnnotations
from batdetect2.data.annotations.batdetect2 import (
AnnotationFilter,
BatDetect2FilesAnnotations,
BatDetect2MergedAnnotations,
load_batdetect2_files_annotated_dataset,
load_batdetect2_merged_annotated_dataset,
)
from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [
@ -43,11 +39,7 @@ __all__ = [
AnnotationFormats = Annotated[
Union[
BatDetect2MergedAnnotations,
BatDetect2FilesAnnotations,
AOEFAnnotations,
],
BatDetect2MergedAnnotations | BatDetect2FilesAnnotations | AOEFAnnotations,
Field(discriminator="format"),
]
"""Type Alias representing all supported data source configurations.
@ -63,24 +55,24 @@ source configuration represents.
def load_annotated_dataset(
dataset: AnnotatedDataset,
base_dir: Optional[data.PathLike] = None,
base_dir: data.PathLike | None = None,
) -> data.AnnotationSet:
"""Load annotations for a single data source based on its configuration.
This function acts as a dispatcher. It inspects the type of the input
`source_config` object (which corresponds to a specific annotation format)
and calls the appropriate loading function (e.g.,
`load_aoef_annotated_dataset` for `AOEFAnnotations`).
This function acts as a dispatcher. It inspects the format of the input
`dataset` object and delegates to the appropriate format-specific loader
registered in the `annotation_format_registry` (e.g.,
`AOEFLoader` for `AOEFAnnotations`).
Parameters
----------
source_config : AnnotationFormats
dataset : AnnotatedDataset
The configuration object for the data source, specifying its format
and necessary details (like paths). Must be an instance of one of the
types included in the `AnnotationFormats` union.
base_dir : Path, optional
An optional base directory path. If provided, relative paths within
the `source_config` might be resolved relative to this directory by
the `dataset` will be resolved relative to this directory by
the underlying loading functions. Defaults to None.
Returns
@ -92,23 +84,8 @@ def load_annotated_dataset(
Raises
------
NotImplementedError
If the type of the `source_config` object does not match any of the
known format-specific loading functions implemented in the dispatch
logic.
If the `format` field of `dataset` does not match any registered
annotation format loader.
"""
if isinstance(dataset, AOEFAnnotations):
return load_aoef_annotated_dataset(dataset, base_dir=base_dir)
if isinstance(dataset, BatDetect2MergedAnnotations):
return load_batdetect2_merged_annotated_dataset(
dataset, base_dir=base_dir
)
if isinstance(dataset, BatDetect2FilesAnnotations):
return load_batdetect2_files_annotated_dataset(
dataset,
base_dir=base_dir,
)
raise NotImplementedError(f"Unknown annotation format: {dataset.name}")
loader = annotation_format_registry.build(dataset)
return loader.load(base_dir=base_dir)

View File

@ -12,17 +12,22 @@ that meet specific status criteria (e.g., completed, verified, without issues).
"""
from pathlib import Path
from typing import Literal, Optional
from typing import Literal
from uuid import uuid5
from pydantic import Field
from soundevent import data, io
from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.types import AnnotatedDataset
from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import (
AnnotatedDataset,
AnnotationLoader,
)
__all__ = [
"AOEFAnnotations",
"AOEFLoader",
"load_aoef_annotated_dataset",
"AnnotationTaskFilter",
]
@ -77,14 +82,30 @@ class AOEFAnnotations(AnnotatedDataset):
annotations_path: Path
filter: Optional[AnnotationTaskFilter] = Field(
filter: AnnotationTaskFilter | None = Field(
default_factory=AnnotationTaskFilter
)
class AOEFLoader(AnnotationLoader):
    """Loader adapter for the AOEF annotation format.

    Wraps an ``AOEFAnnotations`` config and delegates to
    ``load_aoef_annotated_dataset`` to satisfy the ``AnnotationLoader``
    protocol.
    """

    def __init__(self, config: AOEFAnnotations):
        self.config = config

    def load(
        self,
        base_dir: data.PathLike | None = None,
    ) -> data.AnnotationSet:
        """Load the configured AOEF annotations, resolving paths against base_dir."""
        return load_aoef_annotated_dataset(self.config, base_dir=base_dir)

    # Registers this loader's factory in the annotation format registry
    # when the class body is executed (module import time).
    # NOTE(review): decorating a staticmethod object relies on
    # staticmethods being directly callable (Python 3.10+) — confirm
    # the project's minimum supported version.
    @annotation_format_registry.register(AOEFAnnotations)
    @staticmethod
    def from_config(config: AOEFAnnotations):
        """Build an AOEFLoader from its config (registry factory)."""
        return AOEFLoader(config)
def load_aoef_annotated_dataset(
dataset: AOEFAnnotations,
base_dir: Optional[data.PathLike] = None,
base_dir: data.PathLike | None = None,
) -> data.AnnotationSet:
"""Load annotations from an AnnotationSet or AnnotationProject file.

View File

@ -27,7 +27,7 @@ aggregated into a `soundevent.data.AnnotationSet`.
import json
import os
from pathlib import Path
from typing import Literal, Optional, Union
from typing import Literal
from loguru import logger
from pydantic import Field, ValidationError
@ -41,9 +41,13 @@ from batdetect2.data.annotations.legacy import (
list_file_annotations,
load_file_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset
from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import (
AnnotatedDataset,
AnnotationLoader,
)
PathLike = Union[Path, str, os.PathLike]
PathLike = Path | str | os.PathLike
__all__ = [
@ -102,7 +106,7 @@ class BatDetect2FilesAnnotations(AnnotatedDataset):
format: Literal["batdetect2"] = "batdetect2"
annotations_dir: Path
filter: Optional[AnnotationFilter] = Field(
filter: AnnotationFilter | None = Field(
default_factory=AnnotationFilter,
)
@ -133,14 +137,14 @@ class BatDetect2MergedAnnotations(AnnotatedDataset):
format: Literal["batdetect2_file"] = "batdetect2_file"
annotations_path: Path
filter: Optional[AnnotationFilter] = Field(
filter: AnnotationFilter | None = Field(
default_factory=AnnotationFilter,
)
def load_batdetect2_files_annotated_dataset(
dataset: BatDetect2FilesAnnotations,
base_dir: Optional[PathLike] = None,
base_dir: PathLike | None = None,
) -> data.AnnotationSet:
"""Load and convert 'batdetect2_file' annotations into an AnnotationSet.
@ -244,7 +248,7 @@ def load_batdetect2_files_annotated_dataset(
def load_batdetect2_merged_annotated_dataset(
dataset: BatDetect2MergedAnnotations,
base_dir: Optional[PathLike] = None,
base_dir: PathLike | None = None,
) -> data.AnnotationSet:
"""Load and convert 'batdetect2_merged' annotations into an AnnotationSet.
@ -302,7 +306,7 @@ def load_batdetect2_merged_annotated_dataset(
try:
ann = FileAnnotation.model_validate(ann)
except ValueError as err:
logger.warning(f"Invalid annotation file: {err}")
logger.warning("Invalid annotation file: {err}", err=err)
continue
if (
@ -310,17 +314,23 @@ def load_batdetect2_merged_annotated_dataset(
and dataset.filter.only_annotated
and not ann.annotated
):
logger.debug(f"Skipping incomplete annotation {ann.id}")
logger.debug(
"Skipping incomplete annotation {ann_id}",
ann_id=ann.id,
)
continue
if dataset.filter and dataset.filter.exclude_issues and ann.issues:
logger.debug(f"Skipping annotation with issues {ann.id}")
logger.debug(
"Skipping annotation with issues {ann_id}",
ann_id=ann.id,
)
continue
try:
clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
except FileNotFoundError as err:
logger.warning(f"Error loading annotations: {err}")
logger.warning("Error loading annotations: {err}", err=err)
continue
annotations.append(file_annotation_to_clip_annotation(ann, clip))
@ -330,3 +340,41 @@ def load_batdetect2_merged_annotated_dataset(
description=dataset.description,
clip_annotations=annotations,
)
class BatDetect2MergedLoader(AnnotationLoader):
    """Loader adapter for the merged (single-file) BatDetect2 format.

    Wraps a ``BatDetect2MergedAnnotations`` config and delegates to
    ``load_batdetect2_merged_annotated_dataset`` to satisfy the
    ``AnnotationLoader`` protocol.
    """

    def __init__(self, config: BatDetect2MergedAnnotations):
        self.config = config

    def load(
        self,
        base_dir: PathLike | None = None,
    ) -> data.AnnotationSet:
        """Load the configured annotations, resolving paths against base_dir."""
        return load_batdetect2_merged_annotated_dataset(
            self.config,
            base_dir=base_dir,
        )

    # Registered in the annotation format registry at module import time.
    # NOTE(review): decorating a staticmethod object requires directly
    # callable staticmethods (Python 3.10+) — confirm minimum version.
    @annotation_format_registry.register(BatDetect2MergedAnnotations)
    @staticmethod
    def from_config(config: BatDetect2MergedAnnotations):
        """Build a BatDetect2MergedLoader from its config (registry factory)."""
        return BatDetect2MergedLoader(config)
class BatDetect2FilesLoader(AnnotationLoader):
    """Loader adapter for the per-file BatDetect2 annotation format.

    Wraps a ``BatDetect2FilesAnnotations`` config and delegates to
    ``load_batdetect2_files_annotated_dataset`` to satisfy the
    ``AnnotationLoader`` protocol.
    """

    def __init__(self, config: BatDetect2FilesAnnotations):
        self.config = config

    def load(
        self,
        base_dir: PathLike | None = None,
    ) -> data.AnnotationSet:
        """Load the configured annotations, resolving paths against base_dir."""
        return load_batdetect2_files_annotated_dataset(
            self.config,
            base_dir=base_dir,
        )

    # Registered in the annotation format registry at module import time.
    # NOTE(review): decorating a staticmethod object requires directly
    # callable staticmethods (Python 3.10+) — confirm minimum version.
    @annotation_format_registry.register(BatDetect2FilesAnnotations)
    @staticmethod
    def from_config(config: BatDetect2FilesAnnotations):
        """Build a BatDetect2FilesLoader from its config (registry factory)."""
        return BatDetect2FilesLoader(config)

View File

@ -3,12 +3,12 @@
import os
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, List
from pydantic import BaseModel, Field
from soundevent import data
PathLike = Union[Path, str, os.PathLike]
PathLike = Path | str | os.PathLike
__all__ = []
@ -27,7 +27,7 @@ SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5(
)
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
EventFn = Callable[[data.SoundEventAnnotation], str | None]
ClassFn = Callable[[data.Recording], int]
@ -130,7 +130,7 @@ def get_sound_event_tags(
def file_annotation_to_clip(
file_annotation: FileAnnotation,
audio_dir: Optional[PathLike] = None,
audio_dir: PathLike | None = None,
label_key: str = "class",
) -> data.Clip:
"""Convert file annotation to recording."""

View File

@ -0,0 +1,35 @@
from typing import Literal
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.data.annotations.types import AnnotationLoader
__all__ = [
"AnnotationFormatImportConfig",
"annotation_format_registry",
]
# Registry mapping each annotation ``format`` discriminator value to a
# factory producing the matching AnnotationLoader.
annotation_format_registry: Registry[AnnotationLoader, []] = Registry(
    "annotation_format",
    discriminator="format",
)


@add_import_config(annotation_format_registry)
class AnnotationFormatImportConfig(ImportConfig):
    """Import escape hatch for the annotation format registry.

    Use this config to dynamically instantiate any callable as an
    annotation loader without registering it in
    ``annotation_format_registry`` ahead of time.

    Parameters
    ----------
    format : Literal["import"]
        Discriminator value; must always be ``"import"``. Matches the
        registry's ``discriminator="format"`` key.
    target : str
        Fully-qualified dotted path to the callable to instantiate.
    arguments : dict[str, Any]
        Keyword arguments forwarded to the callable.
    """

    format: Literal["import"] = "import"

View File

@ -1,9 +1,13 @@
from pathlib import Path
from typing import Protocol
from soundevent import data
from batdetect2.core.configs import BaseConfig
__all__ = [
"AnnotatedDataset",
"AnnotationLoader",
]
@ -34,3 +38,10 @@ class AnnotatedDataset(BaseConfig):
name: str
audio_dir: Path
description: str = ""
class AnnotationLoader(Protocol):
    """Structural interface for format-specific annotation loaders.

    Any object exposing a matching ``load`` method satisfies this
    protocol; no inheritance is required.
    """

    def load(
        self,
        base_dir: data.PathLike | None = None,
    ) -> data.AnnotationSet:
        """Load annotations, resolving relative paths against base_dir."""
        ...

View File

@ -1,18 +1,33 @@
from collections.abc import Callable
from typing import Annotated, List, Literal, Sequence, Union
from typing import Annotated, List, Literal, Sequence
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
conditions: Registry[SoundEventCondition, []] = Registry("condition")
@add_import_config(conditions)
class SoundEventConditionImportConfig(ImportConfig):
    """Use any callable as a sound event condition.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option. The ``name`` field
    matches the ``conditions`` registry's default discriminator.
    """

    name: Literal["import"] = "import"
class HasTagConfig(BaseConfig):
name: Literal["has_tag"] = "has_tag"
tag: data.Tag
@ -264,16 +279,14 @@ class Not:
SoundEventConditionConfig = Annotated[
Union[
HasTagConfig,
HasAllTagsConfig,
HasAnyTagConfig,
DurationConfig,
FrequencyConfig,
AllOfConfig,
AnyOfConfig,
NotConfig,
],
HasTagConfig
| HasAllTagsConfig
| HasAnyTagConfig
| DurationConfig
| FrequencyConfig
| AllOfConfig
| AnyOfConfig
| NotConfig,
Field(discriminator="name"),
]

View File

@ -19,7 +19,7 @@ The core components are:
"""
from pathlib import Path
from typing import List, Optional, Sequence
from typing import List, Sequence
from loguru import logger
from pydantic import Field
@ -69,7 +69,7 @@ class DatasetConfig(BaseConfig):
description: str
sources: List[AnnotationFormats]
sound_event_filter: Optional[SoundEventConditionConfig] = None
sound_event_filter: SoundEventConditionConfig | None = None
sound_event_transforms: List[SoundEventTransformConfig] = Field(
default_factory=list
)
@ -77,7 +77,7 @@ class DatasetConfig(BaseConfig):
def load_dataset(
config: DatasetConfig,
base_dir: Optional[data.PathLike] = None,
base_dir: data.PathLike | None = None,
) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig."""
clip_annotations = []
@ -161,14 +161,14 @@ def insert_source_tag(
)
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
def load_dataset_config(path: data.PathLike, field: str | None = None):
return load_config(path=path, schema=DatasetConfig, field=field)
def load_dataset_from_config(
path: data.PathLike,
field: Optional[str] = None,
base_dir: Optional[data.PathLike] = None,
field: str | None = None,
base_dir: data.PathLike | None = None,
) -> Dataset:
"""Load dataset annotation metadata from a configuration file.
@ -215,9 +215,9 @@ def load_dataset_from_config(
def save_dataset(
dataset: Dataset,
path: data.PathLike,
name: Optional[str] = None,
description: Optional[str] = None,
audio_dir: Optional[Path] = None,
name: str | None = None,
description: str | None = None,
audio_dir: Path | None = None,
) -> None:
"""Save a loaded dataset (list of ClipAnnotations) to a file.

View File

@ -1,16 +1,15 @@
from collections.abc import Generator
from typing import Optional, Tuple
from soundevent import data
from batdetect2.data.datasets import Dataset
from batdetect2.typing.targets import TargetProtocol
from batdetect2.targets.types import TargetProtocol
def iterate_over_sound_events(
dataset: Dataset,
targets: TargetProtocol,
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
) -> Generator[tuple[str | None, data.SoundEventAnnotation], None, None]:
"""Iterate over sound events in a dataset.
Parameters
@ -24,7 +23,7 @@ def iterate_over_sound_events(
Yields
------
Tuple[Optional[str], data.SoundEventAnnotation]
tuple[Optional[str], data.SoundEventAnnotation]
A tuple containing:
- The encoded class name (str) for the sound event, or None if it
cannot be encoded to a specific class.

View File

@ -1,29 +0,0 @@
from pathlib import Path
from soundevent.data import PathLike
from batdetect2.core import Registry
from batdetect2.typing import (
OutputFormatterProtocol,
TargetProtocol,
)
def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
path = Path(path)
audio_dir = Path(audio_dir)
if path.is_absolute():
if not path.is_relative_to(audio_dir):
raise ValueError(
f"Audio file {path} is not in audio_dir {audio_dir}"
)
return path.relative_to(audio_dir)
return path
prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
Registry(name="output_formatter")
)

View File

@ -1,5 +1,3 @@
from typing import Optional, Tuple
from sklearn.model_selection import train_test_split
from batdetect2.data.datasets import Dataset
@ -7,15 +5,15 @@ from batdetect2.data.summary import (
extract_recordings_df,
extract_sound_events_df,
)
from batdetect2.typing.targets import TargetProtocol
from batdetect2.targets.types import TargetProtocol
def split_dataset_by_recordings(
dataset: Dataset,
targets: TargetProtocol,
train_size: float = 0.75,
random_state: Optional[int] = None,
) -> Tuple[Dataset, Dataset]:
random_state: int | None = None,
) -> tuple[Dataset, Dataset]:
recordings = extract_recordings_df(dataset)
sound_events = extract_sound_events_df(
@ -26,13 +24,15 @@ def split_dataset_by_recordings(
)
majority_class = (
sound_events.groupby("recording_id")
sound_events.groupby("recording_id") # type: ignore
.apply(
lambda group: group["class_name"] # type: ignore
.value_counts()
.sort_values(ascending=False)
.index[0],
include_groups=False, # type: ignore
lambda group: (
group["class_name"]
.value_counts()
.sort_values(ascending=False)
.index[0]
),
include_groups=False,
)
.rename("class_name")
.to_frame()
@ -46,8 +46,8 @@ def split_dataset_by_recordings(
random_state=random_state,
)
train_ids_set = set(train.values) # type: ignore
test_ids_set = set(test.values) # type: ignore
train_ids_set = set(train.values)
test_ids_set = set(test.values)
extra = set(recordings["recording_id"]) - train_ids_set - test_ids_set

View File

@ -2,7 +2,7 @@ import pandas as pd
from soundevent.geometry import compute_bounds
from batdetect2.data.datasets import Dataset
from batdetect2.typing.targets import TargetProtocol
from batdetect2.targets.types import TargetProtocol
__all__ = [
"extract_recordings_df",
@ -175,14 +175,14 @@ def compute_class_summary(
.rename("num recordings")
)
durations = (
sound_events.groupby("class_name")
sound_events.groupby("class_name") # ty: ignore[no-matching-overload]
.apply(
lambda group: recordings[
recordings["clip_annotation_id"].isin(
group["clip_annotation_id"] # type: ignore
group["clip_annotation_id"]
)
]["duration"].sum(),
include_groups=False, # type: ignore
include_groups=False,
)
.sort_values(ascending=False)
.rename("duration")

View File

@ -1,11 +1,15 @@
from collections.abc import Callable
from typing import Annotated, Dict, List, Literal, Optional, Union
from typing import Annotated, Dict, List, Literal
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.data.conditions import (
SoundEventCondition,
SoundEventConditionConfig,
@ -20,6 +24,17 @@ SoundEventTransform = Callable[
transforms: Registry[SoundEventTransform, []] = Registry("transform")
@add_import_config(transforms)
class SoundEventTransformImportConfig(ImportConfig):
"""Use any callable as a sound event transform.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class SetFrequencyBoundConfig(BaseConfig):
name: Literal["set_frequency"] = "set_frequency"
boundary: Literal["low", "high"] = "low"
@ -142,7 +157,7 @@ class MapTagValueConfig(BaseConfig):
name: Literal["map_tag_value"] = "map_tag_value"
tag_key: str
value_mapping: Dict[str, str]
target_key: Optional[str] = None
target_key: str | None = None
class MapTagValue:
@ -150,7 +165,7 @@ class MapTagValue:
self,
tag_key: str,
value_mapping: Dict[str, str],
target_key: Optional[str] = None,
target_key: str | None = None,
):
self.tag_key = tag_key
self.value_mapping = value_mapping
@ -176,12 +191,7 @@ class MapTagValue:
if self.target_key is None:
tags.append(tag.model_copy(update=dict(value=value)))
else:
tags.append(
data.Tag(
key=self.target_key, # type: ignore
value=value,
)
)
tags.append(data.Tag(key=self.target_key, value=value))
return sound_event_annotation.model_copy(update=dict(tags=tags))
@ -221,13 +231,11 @@ class ApplyAll:
SoundEventTransformConfig = Annotated[
Union[
SetFrequencyBoundConfig,
ReplaceTagConfig,
MapTagValueConfig,
ApplyIfConfig,
ApplyAllConfig,
],
SetFrequencyBoundConfig
| ReplaceTagConfig
| MapTagValueConfig
| ApplyIfConfig
| ApplyAllConfig,
Field(discriminator="name"),
]

View File

@ -1,6 +1,6 @@
"""Functions to compute features from predictions."""
from typing import Dict, List, Optional
from typing import Dict, List
import numpy as np
@ -86,7 +86,7 @@ def compute_bandwidth(
def compute_max_power_bb(
prediction: types.Prediction,
spec: Optional[np.ndarray] = None,
spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ,
**_,
@ -131,7 +131,7 @@ def compute_max_power_bb(
def compute_max_power(
prediction: types.Prediction,
spec: Optional[np.ndarray] = None,
spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ,
**_,
@ -157,7 +157,7 @@ def compute_max_power(
def compute_max_power_first(
prediction: types.Prediction,
spec: Optional[np.ndarray] = None,
spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ,
**_,
@ -184,7 +184,7 @@ def compute_max_power_first(
def compute_max_power_second(
prediction: types.Prediction,
spec: Optional[np.ndarray] = None,
spec: np.ndarray | None = None,
min_freq: int = MIN_FREQ_HZ,
max_freq: int = MAX_FREQ_HZ,
**_,
@ -211,7 +211,7 @@ def compute_max_power_second(
def compute_call_interval(
prediction: types.Prediction,
previous: Optional[types.Prediction] = None,
previous: types.Prediction | None = None,
**_,
) -> float:
"""Compute time between this call and the previous call in seconds."""

View File

@ -1,7 +1,7 @@
import datetime
import os
from pathlib import Path
from typing import List, Optional, Union
from typing import List
from pydantic import BaseModel, Field, computed_field
@ -198,8 +198,8 @@ class TrainingParameters(BaseModel):
def get_params(
make_dirs: bool = False,
exps_dir: str = "../../experiments/",
model_name: Optional[str] = None,
experiment: Union[Path, str, None] = None,
model_name: str | None = None,
experiment: Path | str | None = None,
**kwargs,
) -> TrainingParameters:
experiments_dir = Path(exps_dir)

View File

@ -1,7 +1,5 @@
"""Post-processing of the output of the model."""
from typing import List, Tuple, Union
import numpy as np
import torch
from torch import nn
@ -45,7 +43,7 @@ def run_nms(
outputs: ModelOutput,
params: NonMaximumSuppressionConfig,
sampling_rate: np.ndarray,
) -> Tuple[List[PredictionResults], List[np.ndarray]]:
) -> tuple[list[PredictionResults], list[np.ndarray]]:
"""Run non-maximum suppression on the output of the model.
Model outputs processed are expected to have a batch dimension.
@ -73,8 +71,8 @@ def run_nms(
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
# loop over batch to save outputs
preds: List[PredictionResults] = []
feats: List[np.ndarray] = []
preds: list[PredictionResults] = []
feats: list[np.ndarray] = []
for num_detection in range(pred_det_nms.shape[0]):
# get valid indices
inds_ord = torch.argsort(x_pos[num_detection, :])
@ -151,7 +149,7 @@ def run_nms(
def non_max_suppression(
heat: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]],
kernel_size: int | tuple[int, int],
):
# kernel can be an int or list/tuple
if isinstance(kernel_size, int):

View File

@ -1,15 +1,32 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.evaluate.results import save_evaluation_results
from batdetect2.evaluate.tasks import TaskConfig, build_task
from batdetect2.evaluate.types import (
AffinityFunction,
ClipMatches,
EvaluationTaskProtocol,
EvaluatorProtocol,
MetricsProtocol,
PlotterProtocol,
)
__all__ = [
"AffinityFunction",
"ClipMatches",
"DEFAULT_EVAL_DIR",
"EvaluationConfig",
"EvaluationTaskProtocol",
"Evaluator",
"EvaluatorProtocol",
"MatchEvaluation",
"MatcherProtocol",
"MetricsProtocol",
"PlotterProtocol",
"TaskConfig",
"build_evaluator",
"build_task",
"evaluate",
"load_evaluation_config",
"DEFAULT_EVAL_DIR",
"run_evaluate",
"save_evaluation_results",
]

View File

@ -1,76 +1,116 @@
from typing import Annotated, Literal, Optional, Union
from typing import Annotated, Literal
from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.geometry import compute_interval_overlap
from soundevent.geometry import (
buffer_geometry,
compute_bbox_iou,
compute_geometric_iou,
compute_temporal_closeness,
compute_temporal_iou,
)
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.typing.evaluate import AffinityFunction
from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.types import AffinityFunction
from batdetect2.postprocess.types import Detection
affinity_functions: Registry[AffinityFunction, []] = Registry(
"matching_strategy"
"affinity_function"
)
@add_import_config(affinity_functions)
class AffinityFunctionImportConfig(ImportConfig):
"""Use any callable as an affinity function.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class TimeAffinityConfig(BaseConfig):
name: Literal["time_affinity"] = "time_affinity"
time_buffer: float = 0.01
position: Literal["start", "end", "center"] | float = "start"
max_distance: float = 0.01
class TimeAffinity(AffinityFunction):
def __init__(self, time_buffer: float):
self.time_buffer = time_buffer
def __init__(
self,
max_distance: float = 0.01,
position: Literal["start", "end", "center"] | float = "start",
):
if position == "start":
position = 0
elif position == "end":
position = 1
elif position == "center":
position = 0.5
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
return compute_timestamp_affinity(
geometry1, geometry2, time_buffer=self.time_buffer
self.position = position
self.max_distance = max_distance
def __call__(
self,
detection: Detection,
ground_truth: data.SoundEventAnnotation,
) -> float:
target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
return compute_temporal_closeness(
target_geometry,
source_geometry,
ratio=self.position,
max_distance=self.max_distance,
)
@affinity_functions.register(TimeAffinityConfig)
@staticmethod
def from_config(config: TimeAffinityConfig):
return TimeAffinity(time_buffer=config.time_buffer)
def compute_timestamp_affinity(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
) -> float:
assert isinstance(geometry1, data.TimeStamp)
assert isinstance(geometry2, data.TimeStamp)
start_time1 = geometry1.coordinates
start_time2 = geometry2.coordinates
a = min(start_time1, start_time2)
b = max(start_time1, start_time2)
if b - a >= 2 * time_buffer:
return 0
intersection = a - b + 2 * time_buffer
union = b - a + 2 * time_buffer
return intersection / union
return TimeAffinity(
max_distance=config.max_distance,
position=config.position,
)
class IntervalIOUConfig(BaseConfig):
name: Literal["interval_iou"] = "interval_iou"
time_buffer: float = 0.01
time_buffer: float = 0.0
class IntervalIOU(AffinityFunction):
def __init__(self, time_buffer: float):
if time_buffer < 0:
raise ValueError("time_buffer must be non-negative")
self.time_buffer = time_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
return compute_interval_iou(
geometry1,
geometry2,
time_buffer=self.time_buffer,
)
def __call__(
self,
detection: Detection,
ground_truth: data.SoundEventAnnotation,
) -> float:
target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
if self.time_buffer > 0:
target_geometry = buffer_geometry(
target_geometry,
time=self.time_buffer,
)
source_geometry = buffer_geometry(
source_geometry,
time=self.time_buffer,
)
return compute_temporal_iou(target_geometry, source_geometry)
@affinity_functions.register(IntervalIOUConfig)
@staticmethod
@ -78,64 +118,44 @@ class IntervalIOU(AffinityFunction):
return IntervalIOU(time_buffer=config.time_buffer)
def compute_interval_iou(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
) -> float:
assert isinstance(geometry1, data.TimeInterval)
assert isinstance(geometry2, data.TimeInterval)
start_time1, end_time1 = geometry1.coordinates
start_time2, end_time2 = geometry1.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
)
union = (
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
)
if union == 0:
return 0
return intersection / union
class BBoxIOUConfig(BaseConfig):
name: Literal["bbox_iou"] = "bbox_iou"
time_buffer: float = 0.01
freq_buffer: float = 1000
time_buffer: float = 0.0
freq_buffer: float = 0.0
class BBoxIOU(AffinityFunction):
def __init__(self, time_buffer: float, freq_buffer: float):
if time_buffer < 0:
raise ValueError("time_buffer must be non-negative")
if freq_buffer < 0:
raise ValueError("freq_buffer must be non-negative")
self.time_buffer = time_buffer
self.freq_buffer = freq_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
if not isinstance(geometry1, data.BoundingBox):
raise TypeError(
f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
def __call__(
self,
detection: Detection,
ground_truth: data.SoundEventAnnotation,
):
target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
if self.time_buffer > 0 or self.freq_buffer > 0:
target_geometry = buffer_geometry(
target_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
)
source_geometry = buffer_geometry(
source_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
)
if not isinstance(geometry2, data.BoundingBox):
raise TypeError(
f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
)
return bbox_iou(
geometry1,
geometry2,
time_buffer=self.time_buffer,
freq_buffer=self.freq_buffer,
)
return compute_bbox_iou(target_geometry, source_geometry)
@affinity_functions.register(BBoxIOUConfig)
@staticmethod
@ -146,65 +166,44 @@ class BBoxIOU(AffinityFunction):
)
def bbox_iou(
geometry1: data.BoundingBox,
geometry2: data.BoundingBox,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float:
start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
low_freq1 -= freq_buffer
low_freq2 -= freq_buffer
high_freq1 += freq_buffer
high_freq2 += freq_buffer
time_intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
)
freq_intersection = max(
0,
min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
)
intersection = time_intersection * freq_intersection
if intersection == 0:
return 0
union = (
(end_time1 - start_time1) * (high_freq1 - low_freq1)
+ (end_time2 - start_time2) * (high_freq2 - low_freq2)
- intersection
)
return intersection / union
class GeometricIOUConfig(BaseConfig):
name: Literal["geometric_iou"] = "geometric_iou"
time_buffer: float = 0.01
freq_buffer: float = 1000
time_buffer: float = 0.0
freq_buffer: float = 0.0
class GeometricIOU(AffinityFunction):
def __init__(self, time_buffer: float):
self.time_buffer = time_buffer
def __init__(self, time_buffer: float = 0, freq_buffer: float = 0):
if time_buffer < 0:
raise ValueError("time_buffer must be non-negative")
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
return compute_affinity(
geometry1,
geometry2,
time_buffer=self.time_buffer,
)
if freq_buffer < 0:
raise ValueError("freq_buffer must be non-negative")
self.time_buffer = time_buffer
self.freq_buffer = freq_buffer
def __call__(
self,
detection: Detection,
ground_truth: data.SoundEventAnnotation,
):
target_geometry = ground_truth.sound_event.geometry
source_geometry = detection.geometry
if self.time_buffer > 0 or self.freq_buffer > 0:
target_geometry = buffer_geometry(
target_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
)
source_geometry = buffer_geometry(
source_geometry,
time=self.time_buffer,
freq=self.freq_buffer,
)
return compute_geometric_iou(target_geometry, source_geometry)
@affinity_functions.register(GeometricIOUConfig)
@staticmethod
@ -213,18 +212,16 @@ class GeometricIOU(AffinityFunction):
AffinityConfig = Annotated[
Union[
TimeAffinityConfig,
IntervalIOUConfig,
BBoxIOUConfig,
GeometricIOUConfig,
],
TimeAffinityConfig
| IntervalIOUConfig
| BBoxIOUConfig
| GeometricIOUConfig,
Field(discriminator="name"),
]
def build_affinity_function(
config: Optional[AffinityConfig] = None,
config: AffinityConfig | None = None,
) -> AffinityFunction:
config = config or GeometricIOUConfig()
return affinity_functions.build(config)

View File

@ -1,19 +1,14 @@
from typing import List, Optional
from typing import List
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.tasks import (
TaskConfig,
)
from batdetect2.core.configs import BaseConfig
from batdetect2.evaluate.tasks import TaskConfig
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.logging import CSVLoggerConfig, LoggerConfig
__all__ = [
"EvaluationConfig",
"load_evaluation_config",
]
@ -24,7 +19,6 @@ class EvaluationConfig(BaseConfig):
ClassificationTaskConfig(),
]
)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
def get_default_eval_config() -> EvaluationConfig:
@ -47,10 +41,3 @@ def get_default_eval_config() -> EvaluationConfig:
]
}
)
def load_evaluation_config(
path: data.PathLike,
field: Optional[str] = None,
) -> EvaluationConfig:
return load_config(path, schema=EvaluationConfig, field=field)

View File

@ -1,4 +1,4 @@
from typing import List, NamedTuple, Optional, Sequence
from typing import List, NamedTuple, Sequence
import torch
from loguru import logger
@ -8,14 +8,11 @@ from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.audio.types import AudioLoader, ClipperProtocol
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import (
AudioLoader,
ClipperProtocol,
PreprocessorProtocol,
)
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [
"TestDataset",
@ -39,8 +36,8 @@ class TestDataset(Dataset[TestExample]):
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
clipper: Optional[ClipperProtocol] = None,
audio_dir: Optional[data.PathLike] = None,
clipper: ClipperProtocol | None = None,
audio_dir: data.PathLike | None = None,
):
self.clip_annotations = list(clip_annotations)
self.clipper = clipper
@ -51,8 +48,8 @@ class TestDataset(Dataset[TestExample]):
def __len__(self):
return len(self.clip_annotations)
def __getitem__(self, idx: int) -> TestExample:
clip_annotation = self.clip_annotations[idx]
def __getitem__(self, index: int) -> TestExample:
clip_annotation = self.clip_annotations[index]
if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation)
@ -63,14 +60,13 @@ class TestDataset(Dataset[TestExample]):
spectrogram = self.preprocessor(wav_tensor)
return TestExample(
spec=spectrogram,
idx=torch.tensor(idx),
idx=torch.tensor(index),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)
class TestLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
@ -78,10 +74,10 @@ class TestLoaderConfig(BaseConfig):
def build_test_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TestLoaderConfig] = None,
num_workers: Optional[int] = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: TestLoaderConfig | None = None,
num_workers: int = 0,
) -> DataLoader[TestExample]:
logger.info("Building test data loader...")
config = config or TestLoaderConfig()
@ -97,7 +93,6 @@ def build_test_loader(
config=config,
)
num_workers = num_workers or config.num_workers
return DataLoader(
test_dataset,
batch_size=1,
@ -109,9 +104,9 @@ def build_test_loader(
def build_test_dataset(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TestLoaderConfig] = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: TestLoaderConfig | None = None,
) -> TestDataset:
logger.info("Building training dataset...")
config = config or TestLoaderConfig()

View File

@ -1,56 +1,51 @@
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
from typing import Sequence
from lightning import Trainer
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.evaluate import EvaluationConfig
from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger
from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
from batdetect2.models import Model
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.typing.postprocess import RawPrediction
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
OutputFormatterProtocol,
PreprocessorProtocol,
TargetProtocol,
)
from batdetect2.outputs import OutputsConfig, build_output_transform
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
def evaluate(
def run_evaluate(
model: Model,
test_annotations: Sequence[data.ClipAnnotation],
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
formatter: Optional["OutputFormatterProtocol"] = None,
num_workers: Optional[int] = None,
targets: TargetProtocol | None = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
audio_config: AudioConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
output_config: OutputsConfig | None = None,
logger_config: LoggerConfig | None = None,
formatter: OutputFormatterProtocol | None = None,
num_workers: int = 0,
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
from batdetect2.config import BatDetect2Config
experiment_name: str | None = None,
run_name: str | None = None,
) -> tuple[dict[str, float], list[ClipDetections]]:
config = config or BatDetect2Config()
audio_config = audio_config or AudioConfig()
evaluation_config = evaluation_config or EvaluationConfig()
output_config = output_config or OutputsConfig()
audio_loader = audio_loader or build_audio_loader(config=config.audio)
audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or build_preprocessor(
config=config.preprocess,
input_samplerate=audio_loader.samplerate,
)
targets = targets or build_targets(config=config.targets)
preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
loader = build_test_loader(
test_annotations,
@ -59,15 +54,26 @@ def evaluate(
num_workers=num_workers,
)
evaluator = build_evaluator(config=config.evaluation, targets=targets)
output_transform = build_output_transform(
config=output_config.transform,
targets=targets,
)
evaluator = build_evaluator(
config=evaluation_config,
targets=targets,
transform=output_transform,
)
logger = build_logger(
config.evaluation.logger,
logger_config or CSVLoggerConfig(),
log_dir=Path(output_dir),
experiment_name=experiment_name,
run_name=run_name,
)
module = EvaluationModule(model, evaluator)
module = EvaluationModule(
model,
evaluator,
)
trainer = Trainer(logger=logger, enable_checkpointing=False)
metrics = trainer.test(module, loader)

View File

@ -1,13 +1,15 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterable, List, Sequence, Tuple
from matplotlib.figure import Figure
from soundevent import data
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol
__all__ = [
"Evaluator",
@ -19,15 +21,27 @@ class Evaluator:
def __init__(
self,
targets: TargetProtocol,
tasks: Sequence[EvaluatorProtocol],
transform: OutputTransformProtocol,
tasks: Sequence[EvaluationTaskProtocol],
):
self.targets = targets
self.transform = transform
self.tasks = tasks
def to_clip_detections_batch(
self,
clip_detections: Sequence[ClipDetectionsTensor],
clips: Sequence[data.Clip],
) -> list[ClipDetections]:
return [
self.transform.to_clip_detections(detections=dets, clip=clip)
for dets, clip in zip(clip_detections, clips, strict=False)
]
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction],
predictions: Sequence[ClipDetections],
) -> List[Any]:
return [
task.evaluate(clip_annotations, predictions) for task in self.tasks
@ -36,7 +50,7 @@ class Evaluator:
def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
results = {}
for task, outputs in zip(self.tasks, eval_outputs):
for task, outputs in zip(self.tasks, eval_outputs, strict=False):
results.update(task.compute_metrics(outputs))
return results
@ -45,14 +59,15 @@ class Evaluator:
self,
eval_outputs: List[Any],
) -> Iterable[Tuple[str, Figure]]:
for task, outputs in zip(self.tasks, eval_outputs):
for task, outputs in zip(self.tasks, eval_outputs, strict=False):
for name, fig in task.generate_plots(outputs):
yield name, fig
def build_evaluator(
config: Optional[Union[EvaluationConfig, dict]] = None,
targets: Optional[TargetProtocol] = None,
config: EvaluationConfig | dict | None = None,
targets: TargetProtocol | None = None,
transform: OutputTransformProtocol | None = None,
) -> EvaluatorProtocol:
targets = targets or build_targets()
@ -62,7 +77,10 @@ def build_evaluator(
if not isinstance(config, EvaluationConfig):
config = EvaluationConfig.model_validate(config)
transform = transform or build_output_transform(targets=targets)
return Evaluator(
targets=targets,
transform=transform,
tasks=[build_task(task, targets=targets) for task in config.tasks],
)

View File

@ -357,7 +357,7 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001):
clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
clf.fit(x_train, y_train)
y_pred = clf.predict(x_train)
tr_acc = (y_pred == y_train).mean()
(y_pred == y_train).mean()
# print('Train acc', round(tr_acc*100, 2))
return clf, un_train_class
@ -450,7 +450,7 @@ def add_root_path_back(data_sets, ann_path, wav_path):
def check_classes_in_train(gt_list, class_names):
num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list])
np.sum([gg["start_times"].shape[0] for gg in gt_list])
num_with_no_class = 0
for gt in gt_list:
for cc in gt["class_names"]:
@ -569,7 +569,7 @@ if __name__ == "__main__":
num_with_no_class = check_classes_in_train(gt_test, class_names)
if total_num_calls == num_with_no_class:
print("Classes from the test set are not in the train set.")
assert False
raise AssertionError()
# only need the train data if evaluating Sonobat or Tadarida
if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
@ -743,7 +743,7 @@ if __name__ == "__main__":
# check if the class names are the same
if params_bd["class_names"] != class_names:
print("Warning: Class names are not the same as the trained model")
assert False
raise AssertionError()
run_config = {
**bd_args,
@ -753,7 +753,7 @@ if __name__ == "__main__":
preds_bd = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for ii, gg in enumerate(gt_test):
for gg in gt_test:
pred = du.process_file(
gg["file_path"],
model,

View File

@ -5,11 +5,10 @@ from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.postprocess.types import ClipDetections
class EvaluationModule(LightningModule):
@ -24,7 +23,7 @@ class EvaluationModule(LightningModule):
self.evaluator = evaluator
self.clip_annotations: List[data.ClipAnnotation] = []
self.predictions: List[BatDetect2Prediction] = []
self.predictions: List[ClipDetections] = []
def test_step(self, batch: TestExample, batch_idx: int):
dataset = self.get_dataset()
@ -34,22 +33,11 @@ class EvaluationModule(LightningModule):
]
outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(
outputs,
start_times=[ca.clip.start_time for ca in clip_annotations],
clip_detections = self.model.postprocessor(outputs)
predictions = self.evaluator.to_clip_detections_batch(
clip_detections,
[clip_annotation.clip for clip_annotation in clip_annotations],
)
predictions = [
BatDetect2Prediction(
clip=clip_annotation.clip,
predictions=to_raw_predictions(
clip_dets.numpy(),
targets=self.evaluator.targets,
),
)
for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections
)
]
self.clip_annotations.extend(clip_annotations)
self.predictions.extend(predictions)

View File

@ -1,617 +0,0 @@
from collections.abc import Callable, Iterable, Mapping
from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
from pydantic import Field
from scipy.optimize import linear_sum_assignment
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.geometry import buffer_geometry, compute_bounds, scale_geometry
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.affinity import (
AffinityConfig,
BBoxIOUConfig,
GeometricIOUConfig,
build_affinity_function,
)
from batdetect2.targets import build_targets
from batdetect2.typing import (
AffinityFunction,
MatcherProtocol,
MatchEvaluation,
RawPrediction,
TargetProtocol,
)
from batdetect2.typing.evaluate import ClipMatches
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching."""
matching_strategies = Registry("matching_strategy")
def match(
    sound_event_annotations: Sequence[data.SoundEventAnnotation],
    raw_predictions: Sequence[RawPrediction],
    clip: data.Clip,
    scores: Optional[Sequence[float]] = None,
    targets: Optional[TargetProtocol] = None,
    matcher: Optional[MatcherProtocol] = None,
) -> ClipMatches:
    """Match ground-truth annotations to raw predictions for a single clip.

    Runs ``matcher`` over the annotation and prediction geometries and wraps
    every pairing (or miss) it reports in a ``MatchEvaluation``.

    Parameters
    ----------
    sound_event_annotations
        Ground-truth annotations for the clip.
    raw_predictions
        Model predictions for the clip.
    clip
        The clip both sequences refer to.
    scores
        Optional per-prediction scores used to prioritize matching; defaults
        to each prediction's ``detection_score``.
    targets
        Target definitions used to encode classes; built with defaults when
        omitted.
    matcher
        Matching strategy; built with defaults when omitted.

    Returns
    -------
    ClipMatches
        The clip together with one ``MatchEvaluation`` per matcher output.
    """
    if matcher is None:
        matcher = build_matcher()

    if targets is None:
        targets = build_targets()

    # Annotations are assumed to carry a geometry; the type checker cannot
    # prove that, hence the ignore.
    target_geometries: List[data.Geometry] = [  # type: ignore
        sound_event_annotation.sound_event.geometry
        for sound_event_annotation in sound_event_annotations
    ]

    predicted_geometries = [
        raw_prediction.geometry for raw_prediction in raw_predictions
    ]

    if scores is None:
        scores = [
            raw_prediction.detection_score
            for raw_prediction in raw_predictions
        ]

    matches = []
    # The matcher yields (prediction_idx, ground_truth_idx, affinity); either
    # index may be None when that side is unmatched.
    for source_idx, target_idx, affinity in matcher(
        ground_truth=target_geometries,
        predictions=predicted_geometries,
        scores=scores,
    ):
        target = (
            sound_event_annotations[target_idx]
            if target_idx is not None
            else None
        )
        prediction = (
            raw_predictions[source_idx] if source_idx is not None else None
        )

        gt_det = target_idx is not None
        gt_class = targets.encode_class(target) if target is not None else None
        gt_geometry = (
            target_geometries[target_idx] if target_idx is not None else None
        )

        # NOTE(review): relies on truthiness of the prediction object —
        # presumably equivalent to ``is not None`` here; confirm.
        pred_score = float(prediction.detection_score) if prediction else 0
        pred_geometry = (
            predicted_geometries[source_idx]
            if source_idx is not None
            else None
        )

        class_scores = (
            {
                class_name: score
                for class_name, score in zip(
                    targets.class_names,
                    prediction.class_scores,
                )
            }
            if prediction is not None
            else {}
        )

        matches.append(
            MatchEvaluation(
                clip=clip,
                sound_event_annotation=target,
                gt_det=gt_det,
                gt_class=gt_class,
                gt_geometry=gt_geometry,
                pred_score=pred_score,
                pred_class_scores=class_scores,
                pred_geometry=pred_geometry,
                affinity=affinity,
            )
        )

    return ClipMatches(clip=clip, matches=matches)
class StartTimeMatchConfig(BaseConfig):
    """Configuration for onset-time based matching (the default strategy)."""

    name: Literal["start_time_match"] = "start_time_match"
    # Maximum |gt_onset - pred_onset| for a pair to be accepted — presumably
    # in seconds; confirm against callers.
    distance_threshold: float = 0.01
class StartTimeMatcher(MatcherProtocol):
    """Matcher that pairs geometries by proximity of their start times."""

    def __init__(self, distance_threshold: float):
        self.distance_threshold = distance_threshold

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        # Delegates to the module-level generator; see match_start_times.
        return match_start_times(
            ground_truth,
            predictions,
            scores,
            distance_threshold=self.distance_threshold,
        )

    @matching_strategies.register(StartTimeMatchConfig)
    @staticmethod
    def from_config(config: StartTimeMatchConfig):
        # Factory registered with the matching-strategy registry so that
        # build_matcher can construct this matcher from its config.
        return StartTimeMatcher(distance_threshold=config.distance_threshold)
def match_start_times(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    scores: Sequence[float],
    distance_threshold: float = 0.01,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Greedily pair predictions to ground truth by onset-time proximity.

    Predictions are visited from highest to lowest score. Each one is paired
    with the ground truth whose start time is nearest, provided that ground
    truth is still free and the onset gap does not exceed the threshold.

    Yields
    ------
    Tuple[Optional[int], Optional[int], float]
        ``(prediction_idx, ground_truth_idx, affinity)`` for matches,
        ``(prediction_idx, None, 0)`` for unmatched predictions, and
        ``(None, ground_truth_idx, 0)`` for unmatched ground truth.
    """
    num_gt = len(ground_truth)
    num_pred = len(predictions)

    # Degenerate cases: everything on the non-empty side is unmatched.
    if num_gt == 0:
        for pred_idx in range(num_pred):
            yield pred_idx, None, 0
        return
    if num_pred == 0:
        for gt_idx in range(num_gt):
            yield None, gt_idx, 0
        return

    gt_onsets = np.array([compute_bounds(geom)[0] for geom in ground_truth])
    pred_onsets = np.array([compute_bounds(geom)[0] for geom in predictions])

    # Visit predictions in descending confidence order.
    by_confidence = np.argsort(np.array(scores))[::-1]

    # Pairwise |onset difference|, shape (num_pred, num_gt).
    gaps = np.abs(gt_onsets[None, :] - pred_onsets[:, None])
    nearest_gt = np.argmin(gaps, axis=-1)

    available = set(range(num_gt))
    for pred_idx in by_confidence:
        candidate = nearest_gt[pred_idx]
        gap = gaps[pred_idx, candidate]

        # Nearest ground truth already claimed, or too far away: the
        # prediction counts as unmatched.
        if candidate not in available or gap > distance_threshold:
            yield pred_idx, None, 0
            continue

        available.remove(candidate)
        # Affinity decays linearly from 1 (identical onsets) to 0 (gap at
        # the threshold).
        yield pred_idx, candidate, np.interp(
            gap,
            [0, distance_threshold],
            [1, 0],
            left=1,
            right=0,
        )

    for gt_idx in available:
        yield None, gt_idx, 0
def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
    """Collapse any geometry to its time/frequency bounding box."""
    start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
    return data.BoundingBox(
        coordinates=[start_time, low_freq, end_time, high_freq]
    )


def _to_interval(geometry: data.Geometry) -> data.TimeInterval:
    """Collapse any geometry to its [start_time, end_time] interval."""
    start_time, _, end_time, _ = compute_bounds(geometry)
    return data.TimeInterval(coordinates=[start_time, end_time])


def _to_timestamp(geometry: data.Geometry) -> data.TimeStamp:
    """Collapse any geometry to its start time."""
    start_time = compute_bounds(geometry)[0]
    return data.TimeStamp(coordinates=start_time)


# Maps each MatchingGeometry option to the caster that reduces a full
# geometry to that representation before affinities are computed.
_geometry_cast_functions: Mapping[
    MatchingGeometry, Callable[[data.Geometry], data.Geometry]
] = {
    "bbox": _to_bbox,
    "interval": _to_interval,
    "timestamp": _to_timestamp,
}
class GreedyMatchConfig(BaseConfig):
    """Configuration for greedy matching on cast geometries."""

    name: Literal["greedy_match"] = "greedy_match"
    # Representation predictions/ground truth are cast to before affinities
    # are computed: "bbox", "interval" or "timestamp".
    geometry: MatchingGeometry = "timestamp"
    # Minimum affinity (strictly greater-than) for a pair to match.
    affinity_threshold: float = 0.5
    affinity_function: AffinityConfig = Field(
        default_factory=GeometricIOUConfig
    )
class GreedyMatcher(MatcherProtocol):
    """Greedy matcher that casts geometries to a simpler representation
    before computing pairwise affinities."""

    def __init__(
        self,
        geometry: MatchingGeometry,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
    ):
        self.geometry = geometry
        self.affinity_function = affinity_function
        self.affinity_threshold = affinity_threshold
        # Resolve the caster once up front; raises KeyError for unknown
        # geometry names.
        self.cast_geometry = _geometry_cast_functions[self.geometry]

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        # Cast every geometry, then delegate to the greedy algorithm.
        return greedy_match(
            ground_truth=[self.cast_geometry(geom) for geom in ground_truth],
            predictions=[self.cast_geometry(geom) for geom in predictions],
            scores=scores,
            affinity_function=self.affinity_function,
            affinity_threshold=self.affinity_threshold,
        )

    @matching_strategies.register(GreedyMatchConfig)
    @staticmethod
    def from_config(config: GreedyMatchConfig):
        # Factory registered with the matching-strategy registry.
        affinity_function = build_affinity_function(config.affinity_function)
        return GreedyMatcher(
            geometry=config.geometry,
            affinity_threshold=config.affinity_threshold,
            affinity_function=affinity_function,
        )
def greedy_match(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    scores: Sequence[float],
    affinity_threshold: float = 0.5,
    affinity_function: AffinityFunction = compute_affinity,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Performs a greedy, one-to-one matching of predictions to ground truth.

    Iterates through predictions in descending score order. Each prediction
    is matched to the ground-truth geometry with the highest affinity,
    provided the affinity strictly exceeds the threshold and that ground
    truth has not already been claimed by a higher-scoring prediction.

    Parameters
    ----------
    ground_truth
        Ground-truth geometries (matching targets).
    predictions
        Predicted geometries (matching sources).
    scores
        Confidence scores for each prediction, used for prioritization.
    affinity_threshold
        The minimum affinity score required for a valid match.
    affinity_function
        Callable computing the affinity between a predicted and a
        ground-truth geometry.

    Yields
    ------
    Tuple[Optional[int], Optional[int], float]
        A 3-element tuple describing a match or a miss. There are three
        possible formats:

        - Successful Match: ``(prediction_idx, ground_truth_idx, affinity)``
        - Unmatched Prediction (False Positive): ``(prediction_idx, None, 0)``
        - Unmatched Ground Truth (False Negative): ``(None, ground_truth_idx, 0)``
    """
    unassigned_gt = set(range(len(ground_truth)))

    # Degenerate cases: everything on the non-empty side is unmatched.
    if not predictions:
        for gt_idx in range(len(ground_truth)):
            yield None, gt_idx, 0
        return

    if not ground_truth:
        for pred_idx in range(len(predictions)):
            yield pred_idx, None, 0
        return

    # Highest-scoring predictions get first pick of the ground truth.
    indices = np.argsort(scores)[::-1]
    for pred_idx in indices:
        source_geometry = predictions[pred_idx]
        affinities = np.array(
            [
                affinity_function(source_geometry, target_geometry)
                for target_geometry in ground_truth
            ]
        )
        closest_target = int(np.argmax(affinities))
        affinity = affinities[closest_target]

        # Too weak an overlap, or the best ground truth is already taken:
        # the prediction counts as a false positive.
        if (
            affinity <= affinity_threshold
            or closest_target not in unassigned_gt
        ):
            yield pred_idx, None, 0
            continue

        unassigned_gt.remove(closest_target)
        yield pred_idx, closest_target, affinity

    # Any ground truth never claimed is a false negative.
    for gt_idx in unassigned_gt:
        yield None, gt_idx, 0
class GreedyAffinityMatchConfig(BaseConfig):
    """Configuration for greedy matching over a full affinity matrix."""

    name: Literal["greedy_affinity_match"] = "greedy_affinity_match"
    affinity_function: AffinityConfig = Field(default_factory=BBoxIOUConfig)
    # Minimum affinity (strictly greater-than) for a pair to match.
    affinity_threshold: float = 0.5
    # Expand geometries by these amounts before computing affinities.
    time_buffer: float = 0
    frequency_buffer: float = 0
    # Rescale the time/frequency axes before computing affinities.
    time_scale: float = 1.0
    frequency_scale: float = 1.0
class GreedyAffinityMatcher(MatcherProtocol):
    """Greedy matcher driven by a precomputed affinity matrix.

    Optionally buffers geometries in time/frequency and rescales the axes
    before computing pairwise affinities, then selects matches greedily
    (see ``select_greedy_matches``).
    """

    def __init__(
        self,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
        time_buffer: float = 0,
        frequency_buffer: float = 0,
        time_scale: float = 1.0,
        frequency_scale: float = 1.0,
    ):
        self.affinity_threshold = affinity_threshold
        self.affinity_function = affinity_function
        self.time_buffer = time_buffer
        self.frequency_buffer = frequency_buffer
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        # Expand geometries by the configured buffers so near-misses can
        # still overlap. Skipped entirely when both buffers are zero.
        if self.time_buffer != 0 or self.frequency_buffer != 0:
            ground_truth = [
                buffer_geometry(
                    geometry,
                    time_buffer=self.time_buffer,
                    freq_buffer=self.frequency_buffer,
                )
                for geometry in ground_truth
            ]
            predictions = [
                buffer_geometry(
                    geometry,
                    time_buffer=self.time_buffer,
                    freq_buffer=self.frequency_buffer,
                )
                for geometry in predictions
            ]

        affinity_matrix = compute_affinity_matrix(
            ground_truth,
            predictions,
            self.affinity_function,
            time_scale=self.time_scale,
            frequency_scale=self.frequency_scale,
        )
        return select_greedy_matches(
            affinity_matrix,
            affinity_threshold=self.affinity_threshold,
        )

    @matching_strategies.register(GreedyAffinityMatchConfig)
    @staticmethod
    def from_config(config: GreedyAffinityMatchConfig):
        affinity_function = build_affinity_function(config.affinity_function)
        # BUG FIX: the configured time/frequency buffers were previously
        # dropped here, so buffering never took effect when this matcher was
        # built from config (compare OptimalMatcher.from_config, which
        # forwards them).
        return GreedyAffinityMatcher(
            affinity_threshold=config.affinity_threshold,
            affinity_function=affinity_function,
            time_buffer=config.time_buffer,
            frequency_buffer=config.frequency_buffer,
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
        )
class OptimalMatchConfig(BaseConfig):
    """Configuration for globally optimal (Hungarian) affinity matching."""

    name: Literal["optimal_affinity_match"] = "optimal_affinity_match"
    affinity_function: AffinityConfig = Field(default_factory=BBoxIOUConfig)
    # Minimum affinity (strictly greater-than) for an assignment to count.
    affinity_threshold: float = 0.5
    # Expand geometries by these amounts before computing affinities.
    time_buffer: float = 0
    frequency_buffer: float = 0
    # Rescale the time/frequency axes before computing affinities.
    time_scale: float = 1.0
    frequency_scale: float = 1.0
class OptimalMatcher(MatcherProtocol):
    """Matcher that selects a globally optimal one-to-one assignment.

    Optionally buffers and rescales geometries before computing the affinity
    matrix, then delegates to ``select_optimal_matches``.
    """

    def __init__(
        self,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
        time_buffer: float = 0,
        frequency_buffer: float = 0,
        time_scale: float = 1.0,
        frequency_scale: float = 1.0,
    ):
        self.affinity_threshold = affinity_threshold
        self.affinity_function = affinity_function
        self.time_buffer = time_buffer
        self.frequency_buffer = frequency_buffer
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        # NOTE: `scores` is not used by this strategy; the assignment is
        # driven purely by geometry affinities.
        # Buffering is skipped entirely when both buffers are zero.
        if self.time_buffer != 0 or self.frequency_buffer != 0:
            ground_truth = [
                buffer_geometry(
                    geometry,
                    time_buffer=self.time_buffer,
                    freq_buffer=self.frequency_buffer,
                )
                for geometry in ground_truth
            ]
            predictions = [
                buffer_geometry(
                    geometry,
                    time_buffer=self.time_buffer,
                    freq_buffer=self.frequency_buffer,
                )
                for geometry in predictions
            ]

        affinity_matrix = compute_affinity_matrix(
            ground_truth,
            predictions,
            self.affinity_function,
            time_scale=self.time_scale,
            frequency_scale=self.frequency_scale,
        )
        return select_optimal_matches(
            affinity_matrix,
            affinity_threshold=self.affinity_threshold,
        )

    @matching_strategies.register(OptimalMatchConfig)
    @staticmethod
    def from_config(config: OptimalMatchConfig):
        # Factory registered with the matching-strategy registry.
        affinity_function = build_affinity_function(config.affinity_function)
        return OptimalMatcher(
            affinity_threshold=config.affinity_threshold,
            affinity_function=affinity_function,
            time_buffer=config.time_buffer,
            frequency_buffer=config.frequency_buffer,
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
        )
# Discriminated union of all matcher configurations; the ``name`` field
# selects which strategy ``build_matcher`` constructs.
MatchConfig = Annotated[
    Union[
        GreedyMatchConfig,
        StartTimeMatchConfig,
        OptimalMatchConfig,
        GreedyAffinityMatchConfig,
    ],
    Field(discriminator="name"),
]
def compute_affinity_matrix(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    affinity_function: AffinityFunction,
    time_scale: float = 1,
    frequency_scale: float = 1,
) -> np.ndarray:
    """Compute pairwise affinities between ground truth and predictions.

    Returns an array of shape ``(len(ground_truth), len(predictions))``
    where entry ``[i, j]`` is the affinity between ground truth ``i`` and
    prediction ``j``. When either scale differs from 1, every geometry is
    rescaled before affinities are computed.
    """
    # Rescale only when a non-identity scale was requested.
    if time_scale != 1 or frequency_scale != 1:
        ground_truth = [
            scale_geometry(geom, time_scale, frequency_scale)
            for geom in ground_truth
        ]
        predictions = [
            scale_geometry(geom, time_scale, frequency_scale)
            for geom in predictions
        ]

    matrix = np.zeros((len(ground_truth), len(predictions)))
    for row, gt_geom in enumerate(ground_truth):
        for col, pred_geom in enumerate(predictions):
            matrix[row, col] = affinity_function(gt_geom, pred_geom)
    return matrix
def select_optimal_matches(
    affinity_matrix: np.ndarray,
    affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Select a globally optimal one-to-one matching from an affinity matrix.

    Uses the Hungarian algorithm (``scipy.optimize.linear_sum_assignment``)
    to maximize total affinity over ``affinity_matrix`` of shape
    ``(num_ground_truth, num_predictions)``. Assigned pairs whose affinity is
    not strictly above ``affinity_threshold`` are discarded and both sides
    are reported as unmatched.

    Yields
    ------
    Tuple[Optional[int], Optional[int], float]
        ``(prediction_idx, ground_truth_idx, affinity)`` for matches,
        ``(prediction_idx, None, 0)`` for unmatched predictions, and
        ``(None, ground_truth_idx, 0)`` for unmatched ground truth.
    """
    num_gt, num_pred = affinity_matrix.shape
    unmatched_gt = set(range(num_gt))
    unmatched_pred = set(range(num_pred))

    # Typo fix: was `assiged_rows`.
    assigned_rows, assigned_columns = linear_sum_assignment(
        affinity_matrix,
        maximize=True,
    )

    for gt_idx, pred_idx in zip(assigned_rows, assigned_columns):
        affinity = float(affinity_matrix[gt_idx, pred_idx])
        if affinity <= affinity_threshold:
            # Assignment too weak: both sides stay unmatched.
            continue

        # BUG FIX: yield order is (prediction, ground_truth) to agree with
        # the other matching strategies and the MatcherProtocol consumers;
        # this previously yielded (ground_truth, prediction).
        yield pred_idx, gt_idx, affinity
        unmatched_gt.remove(gt_idx)
        unmatched_pred.remove(pred_idx)

    for gt_idx in unmatched_gt:
        yield None, gt_idx, 0

    for pred_idx in unmatched_pred:
        yield pred_idx, None, 0
def select_greedy_matches(
    affinity_matrix: np.ndarray,
    affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Greedily match each ground truth to its best remaining prediction.

    Ground truths are visited in index order over ``affinity_matrix`` of
    shape ``(num_ground_truth, num_predictions)``. Each claims the
    prediction with the highest affinity, unless that affinity is not
    strictly above ``affinity_threshold`` or the prediction was already
    claimed by an earlier ground truth.

    Yields
    ------
    Tuple[Optional[int], Optional[int], float]
        ``(prediction_idx, ground_truth_idx, affinity)`` for matches,
        ``(None, ground_truth_idx, 0)`` for unmatched ground truth, and
        ``(prediction_idx, None, 0)`` for unmatched predictions.
    """
    num_gt, num_pred = affinity_matrix.shape
    available_preds = set(range(num_pred))

    for gt_idx in range(num_gt):
        affinities = affinity_matrix[gt_idx]
        best_pred = int(np.argmax(affinities))
        best_affinity = float(affinities[best_pred])

        claimable = (
            best_affinity > affinity_threshold
            and best_pred in available_preds
        )
        if not claimable:
            # Best candidate is too weak or already taken.
            yield None, gt_idx, 0
            continue

        available_preds.discard(best_pred)
        yield best_pred, gt_idx, best_affinity

    # Predictions never claimed by any ground truth.
    for leftover_pred in available_preds:
        yield leftover_pred, None, 0
def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
    """Construct a matcher from its config.

    Falls back to onset-time matching (``StartTimeMatchConfig``) when no
    config is given.
    """
    if not config:
        config = StartTimeMatchConfig()
    return matching_strategies.build(config)

View File

@ -7,10 +7,8 @@ from typing import (
List,
Literal,
Mapping,
Optional,
Sequence,
Tuple,
Union,
)
import numpy as np
@ -18,16 +16,23 @@ from pydantic import Field
from sklearn import metrics
from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import (
average_precision,
compute_precision_recall,
)
from batdetect2.typing import RawPrediction, TargetProtocol
from batdetect2.postprocess.types import Detection
from batdetect2.targets.types import TargetProtocol
__all__ = [
"ClassificationMetric",
"ClassificationMetricConfig",
"ClassificationMetricImportConfig",
"build_classification_metric",
"compute_precision_recall_curves",
]
@ -36,13 +41,13 @@ __all__ = [
@dataclass
class MatchEval:
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
gt: data.SoundEventAnnotation | None
pred: Detection | None
is_prediction: bool
is_ground_truth: bool
is_generic: bool
true_class: Optional[str]
true_class: str | None
score: float
@ -60,17 +65,28 @@ classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
)
@add_import_config(classification_metrics)
class ClassificationMetricImportConfig(ImportConfig):
"""Use any callable as a classification metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class BaseClassificationConfig(BaseConfig):
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
include: List[str] | None = None
exclude: List[str] | None = None
class BaseClassificationMetric:
def __init__(
self,
targets: TargetProtocol,
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
include: List[str] | None = None,
exclude: List[str] | None = None,
):
self.targets = targets
self.include = include
@ -100,8 +116,8 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
label: str = "average_precision",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
include: List[str] | None = None,
exclude: List[str] | None = None,
):
super().__init__(include=include, exclude=exclude, targets=targets)
self.ignore_non_predictions = ignore_non_predictions
@ -169,8 +185,8 @@ class ClassificationROCAUC(BaseClassificationMetric):
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
label: str = "roc_auc",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
include: List[str] | None = None,
exclude: List[str] | None = None,
):
self.targets = targets
self.ignore_non_predictions = ignore_non_predictions
@ -225,10 +241,7 @@ class ClassificationROCAUC(BaseClassificationMetric):
ClassificationMetricConfig = Annotated[
Union[
ClassificationAveragePrecisionConfig,
ClassificationROCAUCConfig,
],
ClassificationAveragePrecisionConfig | ClassificationROCAUCConfig,
Field(discriminator="name"),
]

View File

@ -1,13 +1,17 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union
from typing import Annotated, Callable, Dict, Literal, Sequence, Set
import numpy as np
from pydantic import Field
from sklearn import metrics
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision
@ -24,6 +28,17 @@ clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
)
@add_import_config(clip_classification_metrics)
class ClipClassificationMetricImportConfig(ImportConfig):
"""Use any callable as a clip classification metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class ClipClassificationAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
@ -123,10 +138,7 @@ class ClipClassificationROCAUC:
ClipClassificationMetricConfig = Annotated[
Union[
ClipClassificationAveragePrecisionConfig,
ClipClassificationROCAUCConfig,
],
ClipClassificationAveragePrecisionConfig | ClipClassificationROCAUCConfig,
Field(discriminator="name"),
]

View File

@ -1,12 +1,16 @@
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Union
from typing import Annotated, Callable, Dict, Literal, Sequence
import numpy as np
from pydantic import Field
from sklearn import metrics
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision
@ -23,6 +27,17 @@ clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
)
@add_import_config(clip_detection_metrics)
class ClipDetectionMetricImportConfig(ImportConfig):
"""Use any callable as a clip detection metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class ClipDetectionAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
@ -159,12 +174,10 @@ class ClipDetectionPrecision:
ClipDetectionMetricConfig = Annotated[
Union[
ClipDetectionAveragePrecisionConfig,
ClipDetectionROCAUCConfig,
ClipDetectionRecallConfig,
ClipDetectionPrecisionConfig,
],
ClipDetectionAveragePrecisionConfig
| ClipDetectionROCAUCConfig
| ClipDetectionRecallConfig
| ClipDetectionPrecisionConfig,
Field(discriminator="name"),
]

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Tuple
import numpy as np
@ -11,7 +11,7 @@ __all__ = [
def compute_precision_recall(
y_true,
y_score,
num_positives: Optional[int] = None,
num_positives: int | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
y_true = np.array(y_true)
y_score = np.array(y_score)
@ -41,7 +41,7 @@ def compute_precision_recall(
def average_precision(
y_true,
y_score,
num_positives: Optional[int] = None,
num_positives: int | None = None,
) -> float:
if num_positives == 0:
return np.nan

View File

@ -5,9 +5,7 @@ from typing import (
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np
@ -15,21 +13,27 @@ from pydantic import Field
from sklearn import metrics
from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction
from batdetect2.postprocess.types import Detection
__all__ = [
"DetectionMetricConfig",
"DetectionMetric",
"DetectionMetricImportConfig",
"build_detection_metric",
]
@dataclass
class MatchEval:
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
gt: data.SoundEventAnnotation | None
pred: Detection | None
is_prediction: bool
is_ground_truth: bool
@ -48,6 +52,17 @@ DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")
@add_import_config(detection_metrics)
class DetectionMetricImportConfig(ImportConfig):
"""Use any callable as a detection metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class DetectionAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
@ -79,7 +94,7 @@ class DetectionAveragePrecision:
y_score.append(m.score)
ap = average_precision(y_true, y_score, num_positives=num_positives)
return {self.label: ap}
return {self.label: float(ap)}
@detection_metrics.register(DetectionAveragePrecisionConfig)
@staticmethod
@ -212,12 +227,10 @@ class DetectionPrecision:
DetectionMetricConfig = Annotated[
Union[
DetectionAveragePrecisionConfig,
DetectionROCAUCConfig,
DetectionRecallConfig,
DetectionPrecisionConfig,
],
DetectionAveragePrecisionConfig
| DetectionROCAUCConfig
| DetectionRecallConfig
| DetectionPrecisionConfig,
Field(discriminator="name"),
]

View File

@ -1,4 +1,3 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import (
Annotated,
@ -6,9 +5,7 @@ from typing import (
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np
@ -16,14 +13,20 @@ from pydantic import Field
from sklearn import metrics, preprocessing
from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction
from batdetect2.typing.targets import TargetProtocol
from batdetect2.postprocess.types import Detection
from batdetect2.targets.types import TargetProtocol
__all__ = [
"TopClassMetricConfig",
"TopClassMetric",
"TopClassMetricImportConfig",
"build_top_class_metric",
]
@ -31,14 +34,14 @@ __all__ = [
@dataclass
class MatchEval:
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
gt: data.SoundEventAnnotation | None
pred: Detection | None
is_ground_truth: bool
is_generic: bool
is_prediction: bool
pred_class: Optional[str]
true_class: Optional[str]
pred_class: str | None
true_class: str | None
score: float
@ -54,6 +57,17 @@ TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")
@add_import_config(top_class_metrics)
class TopClassMetricImportConfig(ImportConfig):
"""Use any callable as a top-class metric.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class TopClassAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
@ -301,13 +315,11 @@ class BalancedAccuracy:
TopClassMetricConfig = Annotated[
Union[
TopClassAveragePrecisionConfig,
TopClassROCAUCConfig,
TopClassRecallConfig,
TopClassPrecisionConfig,
BalancedAccuracyConfig,
],
TopClassAveragePrecisionConfig
| TopClassROCAUCConfig
| TopClassRecallConfig
| TopClassPrecisionConfig
| BalancedAccuracyConfig,
Field(discriminator="name"),
]

View File

@ -1,16 +1,14 @@
from typing import Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from batdetect2.core import BaseConfig
from batdetect2.typing import TargetProtocol
from batdetect2.targets.types import TargetProtocol
class BasePlotConfig(BaseConfig):
label: str = "plot"
theme: str = "default"
title: Optional[str] = None
title: str | None = None
figsize: tuple[int, int] = (10, 10)
dpi: int = 100
@ -21,7 +19,7 @@ class BasePlot:
targets: TargetProtocol,
label: str = "plot",
figsize: tuple[int, int] = (10, 10),
title: Optional[str] = None,
title: str | None = None,
dpi: int = 100,
theme: str = "default",
):

View File

@ -3,10 +3,8 @@ from typing import (
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
@ -14,7 +12,7 @@ from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.classification import (
ClipEval,
_extract_per_class_metric_data,
@ -31,7 +29,7 @@ from batdetect2.plotting.metrics import (
plot_threshold_recall_curve,
plot_threshold_recall_curves,
)
from batdetect2.typing import TargetProtocol
from batdetect2.targets.types import TargetProtocol
ClassificationPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
@ -42,10 +40,21 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
)
@add_import_config(classification_plots)
class ClassificationPlotImportConfig(ImportConfig):
"""Use any callable as a classification plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Classification Precision-Recall Curve"
title: str | None = "Classification Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
@ -88,9 +97,7 @@ class PRCurve(BasePlot):
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(PRCurveConfig)
@ -108,7 +115,7 @@ class PRCurve(BasePlot):
class ThresholdPrecisionCurveConfig(BasePlotConfig):
name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
label: str = "threshold_precision_curve"
title: Optional[str] = "Classification Threshold-Precision Curve"
title: str | None = "Classification Threshold-Precision Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
@ -181,7 +188,7 @@ class ThresholdPrecisionCurve(BasePlot):
class ThresholdRecallCurveConfig(BasePlotConfig):
name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
label: str = "threshold_recall_curve"
title: Optional[str] = "Classification Threshold-Recall Curve"
title: str | None = "Classification Threshold-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
@ -254,7 +261,7 @@ class ThresholdRecallCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Classification ROC Curve"
title: str | None = "Classification ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
@ -326,12 +333,10 @@ class ROCCurve(BasePlot):
ClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ThresholdPrecisionCurveConfig,
ThresholdRecallCurveConfig,
],
PRCurveConfig
| ROCCurveConfig
| ThresholdPrecisionCurveConfig
| ThresholdRecallCurveConfig,
Field(discriminator="name"),
]

View File

@ -3,10 +3,8 @@ from typing import (
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
@ -14,7 +12,7 @@ from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.clip_classification import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
@ -24,10 +22,11 @@ from batdetect2.plotting.metrics import (
plot_roc_curve,
plot_roc_curves,
)
from batdetect2.typing import TargetProtocol
from batdetect2.targets.types import TargetProtocol
__all__ = [
"ClipClassificationPlotConfig",
"ClipClassificationPlotImportConfig",
"ClipClassificationPlotter",
"build_clip_classification_plotter",
]
@ -41,10 +40,21 @@ clip_classification_plots: Registry[
] = Registry("clip_classification_plot")
@add_import_config(clip_classification_plots)
class ClipClassificationPlotImportConfig(ImportConfig):
"""Use any callable as a clip classification plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Clip Classification Precision-Recall Curve"
title: str | None = "Clip Classification Precision-Recall Curve"
separate_figures: bool = False
@ -111,7 +121,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Clip Classification ROC Curve"
title: str | None = "Clip Classification ROC Curve"
separate_figures: bool = False
@ -174,10 +184,7 @@ class ROCCurve(BasePlot):
ClipClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
],
PRCurveConfig | ROCCurveConfig,
Field(discriminator="name"),
]

View File

@ -3,10 +3,8 @@ from typing import (
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import pandas as pd
@ -15,15 +13,16 @@ from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.clip_detection import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol
from batdetect2.targets.types import TargetProtocol
__all__ = [
"ClipDetectionPlotConfig",
"ClipDetectionPlotImportConfig",
"ClipDetectionPlotter",
"build_clip_detection_plotter",
]
@ -38,10 +37,21 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
)
@add_import_config(clip_detection_plots)
class ClipDetectionPlotImportConfig(ImportConfig):
"""Use any callable as a clip detection plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Clip Detection Precision-Recall Curve"
title: str | None = "Clip Detection Precision-Recall Curve"
class PRCurve(BasePlot):
@ -74,7 +84,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Clip Detection ROC Curve"
title: str | None = "Clip Detection ROC Curve"
class ROCCurve(BasePlot):
@ -107,7 +117,7 @@ class ROCCurve(BasePlot):
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Clip Detection Score Distribution"
title: str | None = "Clip Detection Score Distribution"
class ScoreDistributionPlot(BasePlot):
@ -147,11 +157,7 @@ class ScoreDistributionPlot(BasePlot):
ClipDetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
],
PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig,
Field(discriminator="name"),
]

View File

@ -4,10 +4,8 @@ from typing import (
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
@ -18,14 +16,16 @@ from pydantic import Field
from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.audio.types import AudioLoader
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.detections import plot_clip_detections
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
@ -34,10 +34,21 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
)
@add_import_config(detection_plots)
class DetectionPlotImportConfig(ImportConfig):
"""Use any callable as a detection plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Detection Precision-Recall Curve"
title: str | None = "Detection Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -100,7 +111,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Detection ROC Curve"
title: str | None = "Detection ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -159,7 +170,7 @@ class ROCCurve(BasePlot):
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Detection Score Distribution"
title: str | None = "Detection Score Distribution"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -226,7 +237,7 @@ class ScoreDistributionPlot(BasePlot):
class ExampleDetectionPlotConfig(BasePlotConfig):
name: Literal["example_detection"] = "example_detection"
label: str = "example_detection"
title: Optional[str] = "Example Detection"
title: str | None = "Example Detection"
figsize: tuple[int, int] = (10, 4)
num_examples: int = 5
threshold: float = 0.2
@ -292,12 +303,10 @@ class ExampleDetectionPlot(BasePlot):
DetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
ExampleDetectionPlotConfig,
],
PRCurveConfig
| ROCCurveConfig
| ScoreDistributionPlotConfig
| ExampleDetectionPlotConfig,
Field(discriminator="name"),
]

View File

@ -4,14 +4,9 @@ from dataclasses import dataclass, field
from typing import (
Annotated,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
@ -21,7 +16,8 @@ from pydantic import Field
from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.audio.types import AudioLoader
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import (
ClipEval,
@ -32,19 +28,31 @@ from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]]
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
name="top_class_plot"
)
@add_import_config(top_class_plots)
class TopClassPlotImportConfig(ImportConfig):
"""Use any callable as a top-class plot.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Top Class Precision-Recall Curve"
title: str | None = "Top Class Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -64,7 +72,7 @@ class PRCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
) -> Iterable[tuple[str, Figure]]:
y_true = []
y_score = []
num_positives = 0
@ -111,7 +119,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Top Class ROC Curve"
title: str | None = "Top Class ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -131,7 +139,7 @@ class ROCCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
) -> Iterable[tuple[str, Figure]]:
y_true = []
y_score = []
@ -173,7 +181,7 @@ class ROCCurve(BasePlot):
class ConfusionMatrixConfig(BasePlotConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
title: Optional[str] = "Top Class Confusion Matrix"
title: str | None = "Top Class Confusion Matrix"
figsize: tuple[int, int] = (10, 10)
label: str = "confusion_matrix"
exclude_generic: bool = True
@ -214,7 +222,7 @@ class ConfusionMatrix(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
) -> Iterable[tuple[str, Figure]]:
cm, labels = compute_confusion_matrix(
clip_evaluations,
self.targets,
@ -257,7 +265,7 @@ class ConfusionMatrix(BasePlot):
class ExampleClassificationPlotConfig(BasePlotConfig):
name: Literal["example_classification"] = "example_classification"
label: str = "example_classification"
title: Optional[str] = "Example Classification"
title: str | None = "Example Classification"
num_examples: int = 4
threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig)
@ -286,26 +294,26 @@ class ExampleClassificationPlot(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
) -> Iterable[tuple[str, Figure]]:
grouped = group_matches(clip_evaluations, threshold=self.threshold)
for class_name, matches in grouped.items():
true_positives: List[MatchEval] = get_binned_sample(
true_positives: list[MatchEval] = get_binned_sample(
matches.true_positives,
n_examples=self.num_examples,
)
false_positives: List[MatchEval] = get_binned_sample(
false_positives: list[MatchEval] = get_binned_sample(
matches.false_positives,
n_examples=self.num_examples,
)
false_negatives: List[MatchEval] = random.sample(
false_negatives: list[MatchEval] = random.sample(
matches.false_negatives,
k=min(self.num_examples, len(matches.false_negatives)),
)
cross_triggers: List[MatchEval] = get_binned_sample(
cross_triggers: list[MatchEval] = get_binned_sample(
matches.cross_triggers, n_examples=self.num_examples
)
@ -348,12 +356,10 @@ class ExampleClassificationPlot(BasePlot):
TopClassPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ConfusionMatrixConfig,
ExampleClassificationPlotConfig,
],
PRCurveConfig
| ROCCurveConfig
| ConfusionMatrixConfig
| ExampleClassificationPlotConfig,
Field(discriminator="name"),
]
@ -367,16 +373,16 @@ def build_top_class_plotter(
@dataclass
class ClassMatches:
false_positives: List[MatchEval] = field(default_factory=list)
false_negatives: List[MatchEval] = field(default_factory=list)
true_positives: List[MatchEval] = field(default_factory=list)
cross_triggers: List[MatchEval] = field(default_factory=list)
false_positives: list[MatchEval] = field(default_factory=list)
false_negatives: list[MatchEval] = field(default_factory=list)
true_positives: list[MatchEval] = field(default_factory=list)
cross_triggers: list[MatchEval] = field(default_factory=list)
def group_matches(
clip_evals: Sequence[ClipEval],
threshold: float = 0.2,
) -> Dict[str, ClassMatches]:
) -> dict[str, ClassMatches]:
class_examples = defaultdict(ClassMatches)
for clip_eval in clip_evals:
@ -405,12 +411,13 @@ def group_matches(
return class_examples
def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
def get_binned_sample(matches: list[MatchEval], n_examples: int = 5):
if len(matches) < n_examples:
return matches
indices, pred_scores = zip(
*[(index, match.score) for index, match in enumerate(matches)]
*[(index, match.score) for index, match in enumerate(matches)],
strict=False,
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")

View File

@ -0,0 +1,27 @@
import json
from pathlib import Path
from typing import Iterable
from matplotlib.figure import Figure
from soundevent import data
__all__ = ["save_evaluation_results"]
def save_evaluation_results(
metrics: dict[str, float],
plots: Iterable[tuple[str, Figure]],
output_dir: data.PathLike,
) -> None:
"""Save evaluation metrics and plots to disk."""
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
metrics_path = output_path / "metrics.json"
metrics_path.write_text(json.dumps(metrics))
for figure_name, figure in plots:
figure_path = output_path / figure_name
figure_path.parent.mkdir(parents=True, exist_ok=True)
figure.savefig(figure_path)

View File

@ -1,106 +0,0 @@
from typing import Annotated, Callable, Literal, Sequence, Union
import pandas as pd
from pydantic import Field
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipMatches
EvaluationTableGenerator = Callable[[Sequence[ClipMatches]], pd.DataFrame]
tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
"evaluation_table"
)
class FullEvaluationTableConfig(BaseConfig):
name: Literal["full_evaluation"] = "full_evaluation"
class FullEvaluationTable:
def __call__(
self, clip_evaluations: Sequence[ClipMatches]
) -> pd.DataFrame:
return extract_matches_dataframe(clip_evaluations)
@tables_registry.register(FullEvaluationTableConfig)
@staticmethod
def from_config(config: FullEvaluationTableConfig):
return FullEvaluationTable()
def extract_matches_dataframe(
clip_evaluations: Sequence[ClipMatches],
) -> pd.DataFrame:
data = []
for clip_evaluation in clip_evaluations:
for match in clip_evaluation.matches:
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
pred_start_time = pred_low_freq = pred_end_time = (
pred_high_freq
) = None
sound_event_annotation = match.sound_event_annotation
if sound_event_annotation is not None:
geometry = sound_event_annotation.sound_event.geometry
assert geometry is not None
gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
compute_bounds(geometry)
)
if match.pred_geometry is not None:
(
pred_start_time,
pred_low_freq,
pred_end_time,
pred_high_freq,
) = compute_bounds(match.pred_geometry)
data.append(
{
("recording", "uuid"): match.clip.recording.uuid,
("clip", "uuid"): match.clip.uuid,
("clip", "start_time"): match.clip.start_time,
("clip", "end_time"): match.clip.end_time,
("gt", "uuid"): match.sound_event_annotation.uuid
if match.sound_event_annotation is not None
else None,
("gt", "class"): match.gt_class,
("gt", "det"): match.gt_det,
("gt", "start_time"): gt_start_time,
("gt", "end_time"): gt_end_time,
("gt", "low_freq"): gt_low_freq,
("gt", "high_freq"): gt_high_freq,
("pred", "score"): match.pred_score,
("pred", "class"): match.top_class,
("pred", "class_score"): match.top_class_score,
("pred", "start_time"): pred_start_time,
("pred", "end_time"): pred_end_time,
("pred", "low_freq"): pred_low_freq,
("pred", "high_freq"): pred_high_freq,
("match", "affinity"): match.affinity,
**{
("pred_class_score", key): value
for key, value in match.pred_class_scores.items()
},
}
)
df = pd.DataFrame(data)
df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore
return df
EvaluationTableConfig = Annotated[
Union[FullEvaluationTableConfig,], Field(discriminator="name")
]
def build_table_generator(
config: EvaluationTableConfig,
) -> EvaluationTableGenerator:
return tables_registry.build(config)

View File

@ -1,4 +1,4 @@
from typing import Annotated, Optional, Sequence, Union
from typing import Annotated, Sequence
from pydantic import Field
from soundevent import data
@ -11,12 +11,10 @@ from batdetect2.evaluate.tasks.clip_classification import (
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.evaluate.types import EvaluationTaskProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets import build_targets
from batdetect2.typing import (
BatDetect2Prediction,
EvaluatorProtocol,
TargetProtocol,
)
from batdetect2.targets.types import TargetProtocol
__all__ = [
"TaskConfig",
@ -26,31 +24,29 @@ __all__ = [
TaskConfig = Annotated[
Union[
ClassificationTaskConfig,
DetectionTaskConfig,
ClipDetectionTaskConfig,
ClipClassificationTaskConfig,
TopClassDetectionTaskConfig,
],
ClassificationTaskConfig
| DetectionTaskConfig
| ClipDetectionTaskConfig
| ClipClassificationTaskConfig
| TopClassDetectionTaskConfig,
Field(discriminator="name"),
]
def build_task(
config: TaskConfig,
targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
targets: TargetProtocol | None = None,
) -> EvaluationTaskProtocol:
targets = targets or build_targets()
return tasks_registry.build(config, targets)
def evaluate_task(
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction],
task: Optional["str"] = None,
targets: Optional[TargetProtocol] = None,
config: Optional[Union[TaskConfig, dict]] = None,
predictions: Sequence[ClipDetections],
task: str | None = None,
targets: TargetProtocol | None = None,
config: TaskConfig | dict | None = None,
):
if isinstance(config, BaseTaskConfig):
task_obj = build_task(config, targets)

View File

@ -4,78 +4,93 @@ from typing import (
Generic,
Iterable,
List,
Optional,
Literal,
Sequence,
Tuple,
TypeVar,
)
from loguru import logger
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.match import (
MatchConfig,
StartTimeMatchConfig,
build_matcher,
from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
from batdetect2.typing.targets import TargetProtocol
from batdetect2.evaluate.affinity import (
AffinityConfig,
TimeAffinityConfig,
build_affinity_function,
)
from batdetect2.evaluate.types import (
AffinityFunction,
EvaluationTaskProtocol,
)
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
__all__ = [
"BaseTaskConfig",
"BaseTask",
"TaskImportConfig",
]
tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
tasks_registry: Registry[EvaluationTaskProtocol, [TargetProtocol]] = Registry(
"tasks"
)
@add_import_config(tasks_registry)
class TaskImportConfig(ImportConfig):
"""Use any callable as an evaluation task.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
T_Output = TypeVar("T_Output")
class BaseTaskConfig(BaseConfig):
prefix: str
ignore_start_end: float = 0.01
matching_strategy: MatchConfig = Field(
default_factory=StartTimeMatchConfig
)
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
class BaseTask(EvaluationTaskProtocol, Generic[T_Output]):
targets: TargetProtocol
matcher: MatcherProtocol
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
ignore_start_end: float
prefix: str
ignore_start_end: float
def __init__(
self,
matcher: MatcherProtocol,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
prefix: str,
plots: List[
Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
]
| None = None,
ignore_start_end: float = 0.01,
plots: Optional[
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
):
self.matcher = matcher
self.prefix = prefix
self.targets = targets
self.metrics = metrics
self.plots = plots or []
self.targets = targets
self.prefix = prefix
self.ignore_start_end = ignore_start_end
def compute_metrics(
@ -93,24 +108,30 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
self, eval_outputs: List[T_Output]
) -> Iterable[Tuple[str, Figure]]:
for plot in self.plots:
for name, fig in plot(eval_outputs):
yield f"{self.prefix}/{name}", fig
try:
for name, fig in plot(eval_outputs):
yield f"{self.prefix}/{name}", fig
except Exception as e:
logger.error(f"Error plotting {self.prefix}: {e}")
continue
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction],
predictions: Sequence[ClipDetections],
) -> List[T_Output]:
return [
self.evaluate_clip(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions)
for clip_annotation, preds in zip(
clip_annotations, predictions, strict=False
)
]
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction,
) -> T_Output: ...
prediction: ClipDetections,
) -> T_Output: ... # ty: ignore[empty-body]
def include_sound_event_annotation(
self,
@ -121,9 +142,6 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
return False
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
return is_in_bounds(
geometry,
clip,
@ -132,7 +150,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
def include_prediction(
self,
prediction: RawPrediction,
prediction: Detection,
clip: data.Clip,
) -> bool:
return is_in_bounds(
@ -141,25 +159,56 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
self.ignore_start_end,
)
class BaseSEDTaskConfig(BaseTaskConfig):
affinity: AffinityConfig = Field(default_factory=TimeAffinityConfig)
affinity_threshold: float = 0
strict_match: bool = True
class BaseSEDTask(BaseTask[T_Output]):
affinity: AffinityFunction
def __init__(
self,
prefix: str,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
affinity: AffinityFunction,
plots: List[
Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
]
| None = None,
affinity_threshold: float = 0,
ignore_start_end: float = 0.01,
strict_match: bool = True,
):
super().__init__(
prefix=prefix,
metrics=metrics,
plots=plots,
targets=targets,
ignore_start_end=ignore_start_end,
)
self.affinity = affinity
self.affinity_threshold = affinity_threshold
self.strict_match = strict_match
@classmethod
def build(
cls,
config: BaseTaskConfig,
config: BaseSEDTaskConfig,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
plots: Optional[
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
**kwargs,
):
matcher = build_matcher(config.matching_strategy)
affinity = build_affinity_function(config.affinity)
return cls(
matcher=matcher,
targets=targets,
metrics=metrics,
plots=plots,
affinity=affinity,
affinity_threshold=config.affinity_threshold,
prefix=config.prefix,
ignore_start_end=config.ignore_start_end,
strict_match=config.strict_match,
targets=targets,
**kwargs,
)

View File

@ -1,10 +1,9 @@
from typing import (
List,
Literal,
)
from functools import partial
from typing import Literal
from pydantic import Field
from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.metrics.classification import (
ClassificationAveragePrecisionConfig,
@ -18,24 +17,25 @@ from batdetect2.evaluate.plots.classification import (
build_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
BaseSEDTask,
BaseSEDTaskConfig,
tasks_registry,
)
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol
class ClassificationTaskConfig(BaseTaskConfig):
class ClassificationTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_classification"] = "sound_event_classification"
prefix: str = "classification"
metrics: List[ClassificationMetricConfig] = Field(
metrics: list[ClassificationMetricConfig] = Field(
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
)
plots: List[ClassificationPlotConfig] = Field(default_factory=list)
plots: list[ClassificationPlotConfig] = Field(default_factory=list)
include_generics: bool = True
class ClassificationTask(BaseTask[ClipEval]):
class ClassificationTask(BaseSEDTask[ClipEval]):
def __init__(
self,
*args,
@ -48,13 +48,13 @@ class ClassificationTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction,
prediction: ClipDetections,
) -> ClipEval:
clip = clip_annotation.clip
preds = [
pred
for pred in prediction.predictions
for pred in prediction.detections
if self.include_prediction(pred, clip)
]
@ -73,40 +73,40 @@ class ClassificationTask(BaseTask[ClipEval]):
gts = [
sound_event
for sound_event in all_gts
if self.is_class(sound_event, class_name)
if is_target_class(
sound_event,
class_name,
self.targets,
include_generics=self.include_generics,
)
]
scores = [float(pred.class_scores[class_idx]) for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
for match in match_detections_and_gts(
detections=preds,
ground_truths=gts,
affinity=self.affinity,
score=partial(get_class_score, class_idx=class_idx),
strict_match=self.strict_match,
affinity_threshold=self.affinity_threshold,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
true_class = (
self.targets.encode_class(gt) if gt is not None else None
self.targets.encode_class(match.annotation)
if match.annotation is not None
else None
)
score = (
float(pred.class_scores[class_idx])
if pred is not None
else 0
)
matches.append(
MatchEval(
clip=clip,
gt=gt,
pred=pred,
is_prediction=pred is not None,
is_ground_truth=gt is not None,
is_generic=gt is not None and true_class is None,
gt=match.annotation,
pred=match.prediction,
is_prediction=match.prediction is not None,
is_ground_truth=match.annotation is not None,
is_generic=match.annotation is not None
and true_class is None,
true_class=true_class,
score=score,
score=match.prediction_score,
)
)
@ -114,20 +114,6 @@ class ClassificationTask(BaseTask[ClipEval]):
return ClipEval(clip=clip, matches=per_class_matches)
def is_class(
self,
sound_event: data.SoundEventAnnotation,
class_name: str,
) -> bool:
sound_event_class = self.targets.encode_class(sound_event)
if sound_event_class is None and self.include_generics:
# Sound events that are generic could be of the given
# class
return True
return sound_event_class == class_name
@tasks_registry.register(ClassificationTaskConfig)
@staticmethod
def from_config(
@ -147,4 +133,25 @@ class ClassificationTask(BaseTask[ClipEval]):
plots=plots,
targets=targets,
metrics=metrics,
include_generics=config.include_generics,
)
def get_class_score(pred: Detection, class_idx: int) -> float:
return pred.class_scores[class_idx]
def is_target_class(
sound_event: data.SoundEventAnnotation,
class_name: str,
targets: TargetProtocol,
include_generics: bool = True,
) -> bool:
sound_event_class = targets.encode_class(sound_event)
if sound_event_class is None and include_generics:
# Sound events that are generic could be of the given
# class
return True
return sound_event_class == class_name

View File

@ -1,5 +1,5 @@
from collections import defaultdict
from typing import List, Literal
from typing import Literal
from pydantic import Field
from soundevent import data
@ -19,26 +19,26 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
class ClipClassificationTaskConfig(BaseTaskConfig):
name: Literal["clip_classification"] = "clip_classification"
prefix: str = "clip_classification"
metrics: List[ClipClassificationMetricConfig] = Field(
metrics: list[ClipClassificationMetricConfig] = Field(
default_factory=lambda: [
ClipClassificationAveragePrecisionConfig(),
]
)
plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
class ClipClassificationTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction,
prediction: ClipDetections,
) -> ClipEval:
clip = clip_annotation.clip
@ -55,7 +55,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
gt_classes.add(class_name)
pred_scores = defaultdict(float)
for pred in prediction.predictions:
for pred in prediction.detections:
if not self.include_prediction(pred, clip):
continue
@ -78,8 +78,8 @@ class ClipClassificationTask(BaseTask[ClipEval]):
build_clip_classification_plotter(plot, targets)
for plot in config.plots
]
return ClipClassificationTask.build(
config=config,
return ClipClassificationTask(
prefix=config.prefix,
plots=plots,
metrics=metrics,
targets=targets,

View File

@ -1,4 +1,4 @@
from typing import List, Literal
from typing import Literal
from pydantic import Field
from soundevent import data
@ -18,26 +18,26 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
class ClipDetectionTaskConfig(BaseTaskConfig):
name: Literal["clip_detection"] = "clip_detection"
prefix: str = "clip_detection"
metrics: List[ClipDetectionMetricConfig] = Field(
metrics: list[ClipDetectionMetricConfig] = Field(
default_factory=lambda: [
ClipDetectionAveragePrecisionConfig(),
]
)
plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
class ClipDetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction,
prediction: ClipDetections,
) -> ClipEval:
clip = clip_annotation.clip
@ -47,7 +47,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
)
pred_score = 0
for pred in prediction.predictions:
for pred in prediction.detections:
if not self.include_prediction(pred, clip):
continue
@ -69,8 +69,8 @@ class ClipDetectionTask(BaseTask[ClipEval]):
build_clip_detection_plotter(plot, targets)
for plot in config.plots
]
return ClipDetectionTask.build(
config=config,
return ClipDetectionTask(
prefix=config.prefix,
metrics=metrics,
targets=targets,
plots=plots,

View File

@ -1,7 +1,8 @@
from typing import List, Literal
from typing import Literal
from pydantic import Field
from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.metrics.detection import (
ClipEval,
@ -15,28 +16,28 @@ from batdetect2.evaluate.plots.detection import (
build_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
BaseSEDTask,
BaseSEDTaskConfig,
tasks_registry,
)
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
class DetectionTaskConfig(BaseTaskConfig):
class DetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["sound_event_detection"] = "sound_event_detection"
prefix: str = "detection"
metrics: List[DetectionMetricConfig] = Field(
metrics: list[DetectionMetricConfig] = Field(
default_factory=lambda: [DetectionAveragePrecisionConfig()]
)
plots: List[DetectionPlotConfig] = Field(default_factory=list)
plots: list[DetectionPlotConfig] = Field(default_factory=list)
class DetectionTask(BaseTask[ClipEval]):
class DetectionTask(BaseSEDTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction,
prediction: ClipDetections,
) -> ClipEval:
clip = clip_annotation.clip
@ -47,27 +48,26 @@ class DetectionTask(BaseTask[ClipEval]):
]
preds = [
pred
for pred in prediction.predictions
for pred in prediction.detections
if self.include_prediction(pred, clip)
]
scores = [pred.detection_score for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
for match in match_detections_and_gts(
detections=preds,
ground_truths=gts,
affinity=self.affinity,
score=lambda pred: pred.detection_score,
strict_match=self.strict_match,
affinity_threshold=self.affinity_threshold,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
matches.append(
MatchEval(
gt=gt,
pred=pred,
is_prediction=pred is not None,
is_ground_truth=gt is not None,
score=pred.detection_score if pred is not None else 0,
gt=match.annotation,
pred=match.prediction,
is_prediction=match.prediction is not None,
is_ground_truth=match.annotation is not None,
score=match.prediction_score,
)
)

View File

@ -1,7 +1,8 @@
from typing import List, Literal
from typing import Literal
from pydantic import Field
from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.metrics.top_class import (
ClipEval,
@ -15,28 +16,28 @@ from batdetect2.evaluate.plots.top_class import (
build_top_class_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
BaseSEDTask,
BaseSEDTaskConfig,
tasks_registry,
)
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import TargetProtocol
class TopClassDetectionTaskConfig(BaseTaskConfig):
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
name: Literal["top_class_detection"] = "top_class_detection"
prefix: str = "top_class"
metrics: List[TopClassMetricConfig] = Field(
metrics: list[TopClassMetricConfig] = Field(
default_factory=lambda: [TopClassAveragePrecisionConfig()]
)
plots: List[TopClassPlotConfig] = Field(default_factory=list)
plots: list[TopClassPlotConfig] = Field(default_factory=list)
class TopClassDetectionTask(BaseTask[ClipEval]):
class TopClassDetectionTask(BaseSEDTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
prediction: BatDetect2Prediction,
prediction: ClipDetections,
) -> ClipEval:
clip = clip_annotation.clip
@ -47,21 +48,21 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
]
preds = [
pred
for pred in prediction.predictions
for pred in prediction.detections
if self.include_prediction(pred, clip)
]
# Take the highest score for each prediction
scores = [pred.class_scores.max() for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
for match in match_detections_and_gts(
ground_truths=gts,
detections=preds,
affinity=self.affinity,
score=lambda pred: pred.class_scores.max(),
strict_match=self.strict_match,
affinity_threshold=self.affinity_threshold,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
gt = match.annotation
pred = match.prediction
true_class = (
self.targets.encode_class(gt) if gt is not None else None
)
@ -69,11 +70,6 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
class_idx = (
pred.class_scores.argmax() if pred is not None else None
)
score = (
float(pred.class_scores[class_idx]) if pred is not None else 0
)
pred_class = (
self.targets.class_names[class_idx]
if class_idx is not None
@ -90,7 +86,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
true_class=true_class,
is_generic=gt is not None and true_class is None,
pred_class=pred_class,
score=score,
score=match.prediction_score,
)
)

View File

@ -0,0 +1,147 @@
from dataclasses import dataclass
from typing import Generic, Iterable, Protocol, Sequence, TypeVar
from matplotlib.figure import Figure
from soundevent import data
from batdetect2.outputs.types import OutputTransformProtocol
from batdetect2.postprocess.types import (
ClipDetections,
ClipDetectionsTensor,
Detection,
)
from batdetect2.targets.types import TargetProtocol
__all__ = [
"AffinityFunction",
"ClipMatches",
"EvaluationTaskProtocol",
"EvaluatorProtocol",
"MatchEvaluation",
"MatcherProtocol",
"MetricsProtocol",
"PlotterProtocol",
]
@dataclass
class MatchEvaluation:
clip: data.Clip
sound_event_annotation: data.SoundEventAnnotation | None
gt_det: bool
gt_class: str | None
gt_geometry: data.Geometry | None
pred_score: float
pred_class_scores: dict[str, float]
pred_geometry: data.Geometry | None
affinity: float
@property
def top_class(self) -> str | None:
if not self.pred_class_scores:
return None
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
@property
def is_prediction(self) -> bool:
return self.pred_geometry is not None
@property
def is_generic(self) -> bool:
return self.gt_det and self.gt_class is None
@property
def top_class_score(self) -> float:
pred_class = self.top_class
if pred_class is None:
return 0
return self.pred_class_scores[pred_class]
@dataclass
class ClipMatches:
clip: data.Clip
matches: list[MatchEvaluation]
class MatcherProtocol(Protocol):
def __call__(
self,
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
) -> Iterable[tuple[int | None, int | None, float]]: ...
class AffinityFunction(Protocol):
def __call__(
self,
detection: Detection,
ground_truth: data.SoundEventAnnotation,
) -> float: ...
class MetricsProtocol(Protocol):
def __call__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[Detection]],
) -> dict[str, float]: ...
class PlotterProtocol(Protocol):
def __call__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[Detection]],
) -> Iterable[tuple[str, Figure]]: ...
EvaluationOutput = TypeVar("EvaluationOutput")
class EvaluationTaskProtocol(Protocol, Generic[EvaluationOutput]):
targets: TargetProtocol
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[ClipDetections],
) -> EvaluationOutput: ...
def compute_metrics(
self,
eval_outputs: EvaluationOutput,
) -> dict[str, float]: ...
def generate_plots(
self,
eval_outputs: EvaluationOutput,
) -> Iterable[tuple[str, Figure]]: ...
class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
targets: TargetProtocol
transform: OutputTransformProtocol
def to_clip_detections_batch(
self,
clip_detections: Sequence[ClipDetectionsTensor],
clips: Sequence[data.Clip],
) -> list[ClipDetections]: ...
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[ClipDetections],
) -> EvaluationOutput: ...
def compute_metrics(
self,
eval_outputs: EvaluationOutput,
) -> dict[str, float]: ...
def generate_plots(
self,
eval_outputs: EvaluationOutput,
) -> Iterable[tuple[str, Figure]]: ...

View File

@ -1,7 +1,7 @@
import argparse
import os
import warnings
from typing import List, Optional
from typing import List
import torch
import torch.utils.data
@ -88,7 +88,8 @@ def select_device(warn=True) -> str:
if warn:
warnings.warn(
"No GPU available, using the CPU instead. Please consider using a GPU "
"to speed up training."
"to speed up training.",
stacklevel=2,
)
return "cpu"
@ -98,8 +99,8 @@ def load_annotations(
dataset_name: str,
ann_path: str,
audio_path: str,
classes_to_ignore: Optional[List[str]] = None,
events_of_interest: Optional[List[str]] = None,
classes_to_ignore: List[str] | None = None,
events_of_interest: List[str] | None = None,
) -> List[types.FileAnnotation]:
train_sets: List[types.DatasetDict] = []
train_sets.append(

View File

@ -2,7 +2,6 @@ import argparse
import json
import os
from collections import Counter
from typing import List, Optional, Tuple
import numpy as np
from sklearn.model_selection import StratifiedGroupKFold
@ -12,8 +11,8 @@ from batdetect2 import types
def print_dataset_stats(
data: List[types.FileAnnotation],
classes_to_ignore: Optional[List[str]] = None,
data: list[types.FileAnnotation],
classes_to_ignore: list[str] | None = None,
) -> Counter[str]:
print("Num files:", len(data))
counts, _ = tu.get_class_names(data, classes_to_ignore)
@ -22,7 +21,7 @@ def print_dataset_stats(
return counts
def load_file_names(file_name: str) -> List[str]:
def load_file_names(file_name: str) -> list[str]:
if not os.path.isfile(file_name):
raise FileNotFoundError(f"Input file not found - {file_name}")
@ -100,12 +99,12 @@ def parse_args():
def split_data(
data: List[types.FileAnnotation],
data: list[types.FileAnnotation],
train_file: str,
test_file: str,
n_splits: int = 5,
random_state: int = 0,
) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]:
) -> tuple[list[types.FileAnnotation], list[types.FileAnnotation]]:
if train_file != "" and test_file != "":
# user has specifed the train / test split
mapping = {
@ -162,7 +161,7 @@ def main():
# change the names of the classes
ip_names = args.input_class_names.split(";")
op_names = args.output_class_names.split(";")
name_dict = dict(zip(ip_names, op_names))
name_dict = dict(zip(ip_names, op_names, strict=False))
# load annotations
data_all = tu.load_set_of_anns(

View File

@ -1,58 +1,68 @@
from typing import TYPE_CHECKING, List, Optional, Sequence
from typing import Sequence
from lightning import Trainer
from soundevent import data
from batdetect2.audio import AudioConfig
from batdetect2.audio.loader import build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.config import InferenceConfig
from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model
from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.typing.postprocess import BatDetect2Prediction
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
PreprocessorProtocol,
TargetProtocol,
)
from batdetect2.outputs import (
OutputsConfig,
OutputTransformProtocol,
build_output_transform,
)
from batdetect2.postprocess.types import ClipDetections
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
def run_batch_inference(
model,
model: Model,
clips: Sequence[data.Clip],
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
num_workers: Optional[int] = None,
batch_size: Optional[int] = None,
) -> List[BatDetect2Prediction]:
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config()
audio_loader = audio_loader or build_audio_loader()
preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate,
targets: TargetProtocol | None = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
audio_config: AudioConfig | None = None,
output_transform: OutputTransformProtocol | None = None,
output_config: OutputsConfig | None = None,
inference_config: InferenceConfig | None = None,
num_workers: int = 1,
batch_size: int | None = None,
) -> list[ClipDetections]:
audio_config = audio_config or AudioConfig(
samplerate=model.preprocessor.input_samplerate,
)
output_config = output_config or OutputsConfig()
inference_config = inference_config or InferenceConfig()
targets = targets or build_targets()
audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
output_transform = output_transform or build_output_transform(
config=output_config.transform,
targets=targets,
)
loader = build_inference_loader(
clips,
audio_loader=audio_loader,
preprocessor=preprocessor,
config=config.inference.loader,
config=inference_config.loader,
num_workers=num_workers,
batch_size=batch_size,
)
module = InferenceModule(model)
module = InferenceModule(
model,
output_transform=output_transform,
)
trainer = Trainer(enable_checkpointing=False, logger=False)
outputs = trainer.predict(module, loader)
return [
@ -65,13 +75,18 @@ def run_batch_inference(
def process_file_list(
model: Model,
paths: Sequence[data.PathLike],
config: "BatDetect2Config",
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
clip_config = config.inference.clipping
targets: TargetProtocol | None = None,
audio_loader: AudioLoader | None = None,
audio_config: AudioConfig | None = None,
preprocessor: PreprocessorProtocol | None = None,
inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None,
output_transform: OutputTransformProtocol | None = None,
batch_size: int | None = None,
num_workers: int = 0,
) -> list[ClipDetections]:
inference_config = inference_config or InferenceConfig()
clip_config = inference_config.clipping
clips = get_clips_from_files(
paths,
duration=clip_config.duration,
@ -85,6 +100,10 @@ def process_file_list(
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,
config=config,
batch_size=batch_size,
num_workers=num_workers,
output_config=output_config,
audio_config=audio_config,
output_transform=output_transform,
inference_config=inference_config,
)

View File

@ -38,10 +38,10 @@ def get_recording_clips(
discard_empty: bool = True,
) -> Sequence[data.Clip]:
start_time = 0
duration = recording.duration
recording_duration = recording.duration
hop = duration * (1 - overlap)
num_clips = int(np.ceil(duration / hop))
num_clips = int(np.ceil(recording_duration / hop))
if num_clips == 0:
# This should only happen if the clip's duration is zero,
@ -53,8 +53,8 @@ def get_recording_clips(
start = start_time + i * hop
end = start + duration
if end > duration:
empty_duration = end - duration
if end > recording_duration:
empty_duration = end - recording_duration
if empty_duration > max_empty and discard_empty:
# Discard clips that contain too much empty space

View File

@ -1,4 +1,4 @@
from typing import List, NamedTuple, Optional, Sequence
from typing import NamedTuple, Sequence
import torch
from loguru import logger
@ -6,10 +6,11 @@ from soundevent import data
from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [
"InferenceDataset",
@ -29,14 +30,14 @@ class DatasetItem(NamedTuple):
class InferenceDataset(Dataset[DatasetItem]):
clips: List[data.Clip]
clips: list[data.Clip]
def __init__(
self,
clips: Sequence[data.Clip],
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
audio_dir: Optional[data.PathLike] = None,
audio_dir: data.PathLike | None = None,
):
self.clips = list(clips)
self.preprocessor = preprocessor
@ -46,31 +47,30 @@ class InferenceDataset(Dataset[DatasetItem]):
def __len__(self):
return len(self.clips)
def __getitem__(self, idx: int) -> DatasetItem:
clip = self.clips[idx]
def __getitem__(self, index: int) -> DatasetItem:
clip = self.clips[index]
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
wav_tensor = torch.tensor(wav).unsqueeze(0)
spectrogram = self.preprocessor(wav_tensor)
return DatasetItem(
spec=spectrogram,
idx=torch.tensor(idx),
idx=torch.tensor(index),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)
class InferenceLoaderConfig(BaseConfig):
num_workers: int = 0
batch_size: int = 8
def build_inference_loader(
clips: Sequence[data.Clip],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[InferenceLoaderConfig] = None,
num_workers: Optional[int] = None,
batch_size: Optional[int] = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: InferenceLoaderConfig | None = None,
num_workers: int = 0,
batch_size: int | None = None,
) -> DataLoader[DatasetItem]:
logger.info("Building inference data loader...")
config = config or InferenceLoaderConfig()
@ -83,20 +83,19 @@ def build_inference_loader(
batch_size = batch_size or config.batch_size
num_workers = num_workers or config.num_workers
return DataLoader(
inference_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=config.num_workers,
num_workers=num_workers,
collate_fn=_collate_fn,
)
def build_inference_dataset(
clips: Sequence[data.Clip],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
) -> InferenceDataset:
if audio_loader is None:
audio_loader = build_audio_loader()
@ -111,7 +110,7 @@ def build_inference_dataset(
)
def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
def _collate_fn(batch: list[DatasetItem]) -> DatasetItem:
max_width = max(item.spec.shape[-1] for item in batch)
return DatasetItem(
spec=torch.stack(

View File

@ -5,45 +5,44 @@ from torch.utils.data import DataLoader
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections
class InferenceModule(LightningModule):
def __init__(self, model: Model):
def __init__(
self,
model: Model,
output_transform: OutputTransformProtocol | None = None,
):
super().__init__()
self.model = model
self.output_transform = output_transform or build_output_transform(
targets=model.targets
)
def predict_step(
self,
batch: DatasetItem,
batch_idx: int,
dataloader_idx: int = 0,
) -> Sequence[BatDetect2Prediction]:
) -> Sequence[ClipDetections]:
dataset = self.get_dataset()
clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(
outputs,
start_times=[clip.start_time for clip in clips],
)
clip_detections = self.model.postprocessor(outputs)
predictions = [
BatDetect2Prediction(
return [
self.output_transform.to_clip_detections(
detections=clip_dets,
clip=clip,
predictions=to_raw_predictions(
clip_dets.numpy(),
targets=self.model.targets,
),
)
for clip, clip_dets in zip(clips, clip_detections)
for clip, clip_dets in zip(clips, clip_detections, strict=True)
]
return predictions
def get_dataset(self) -> InferenceDataset:
dataloaders = self.trainer.predict_dataloaders
assert isinstance(dataloaders, DataLoader)

View File

@ -9,10 +9,8 @@ from typing import (
Dict,
Generic,
Literal,
Optional,
Protocol,
TypeVar,
Union,
)
import numpy as np
@ -32,6 +30,21 @@ from batdetect2.core.configs import BaseConfig
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
__all__ = [
"AppLoggingConfig",
"BaseLoggerConfig",
"CSVLoggerConfig",
"DEFAULT_LOGS_DIR",
"DVCLiveConfig",
"LoggerConfig",
"MLFlowLoggerConfig",
"TensorBoardLoggerConfig",
"build_logger",
"enable_logging",
"get_image_logger",
"get_table_logger",
]
def enable_logging(level: int):
logger.remove()
@ -49,14 +62,14 @@ def enable_logging(level: int):
class BaseLoggerConfig(BaseConfig):
log_dir: Path = DEFAULT_LOGS_DIR
experiment_name: Optional[str] = None
run_name: Optional[str] = None
experiment_name: str | None = None
run_name: str | None = None
class DVCLiveConfig(BaseLoggerConfig):
name: Literal["dvclive"] = "dvclive"
prefix: str = ""
log_model: Union[bool, Literal["all"]] = False
log_model: bool | Literal["all"] = False
monitor_system: bool = False
@ -72,22 +85,26 @@ class TensorBoardLoggerConfig(BaseLoggerConfig):
class MLFlowLoggerConfig(BaseLoggerConfig):
name: Literal["mlflow"] = "mlflow"
tracking_uri: Optional[str] = "http://localhost:5000"
tags: Optional[dict[str, Any]] = None
tracking_uri: str | None = "http://localhost:5000"
tags: dict[str, Any] | None = None
log_model: bool = False
LoggerConfig = Annotated[
Union[
DVCLiveConfig,
CSVLoggerConfig,
TensorBoardLoggerConfig,
MLFlowLoggerConfig,
],
DVCLiveConfig
| CSVLoggerConfig
| TensorBoardLoggerConfig
| MLFlowLoggerConfig,
Field(discriminator="name"),
]
class AppLoggingConfig(BaseConfig):
train: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
evaluation: LoggerConfig = Field(default_factory=CSVLoggerConfig)
inference: LoggerConfig = Field(default_factory=CSVLoggerConfig)
T = TypeVar("T", bound=LoggerConfig, contravariant=True)
@ -95,20 +112,20 @@ class LoggerBuilder(Protocol, Generic[T]):
def __call__(
self,
config: T,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Logger: ...
def create_dvclive_logger(
config: DVCLiveConfig,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Logger:
try:
from dvclive.lightning import DVCLiveLogger # type: ignore
from dvclive.lightning import DVCLiveLogger
except ImportError as error:
raise ValueError(
"DVCLive is not installed and cannot be used for logging"
@ -130,9 +147,9 @@ def create_dvclive_logger(
def create_csv_logger(
config: CSVLoggerConfig,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Logger:
from lightning.pytorch.loggers import CSVLogger
@ -159,9 +176,9 @@ def create_csv_logger(
def create_tensorboard_logger(
config: TensorBoardLoggerConfig,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Logger:
from lightning.pytorch.loggers import TensorBoardLogger
@ -191,9 +208,9 @@ def create_tensorboard_logger(
def create_mlflow_logger(
config: MLFlowLoggerConfig,
log_dir: Optional[data.PathLike] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
log_dir: data.PathLike | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Logger:
try:
from lightning.pytorch.loggers import MLFlowLogger
@ -232,9 +249,9 @@ LOGGER_FACTORY: Dict[str, LoggerBuilder] = {
def build_logger(
config: LoggerConfig,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Logger:
logger.opt(lazy=True).debug(
"Building logger with config: \n{}",
@ -257,7 +274,7 @@ def build_logger(
PlotLogger = Callable[[str, Figure, int], None]
def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
def get_image_logger(logger: Logger) -> PlotLogger | None:
if isinstance(logger, TensorBoardLogger):
return logger.experiment.add_figure
@ -282,7 +299,7 @@ def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
TableLogger = Callable[[str, pd.DataFrame, int], None]
def get_table_logger(logger: Logger) -> Optional[TableLogger]:
def get_table_logger(logger: Logger) -> TableLogger | None:
if isinstance(logger, TensorBoardLogger):
return partial(save_table, dir=Path(logger.log_dir))

View File

@ -1,36 +1,46 @@
"""Defines and builds the neural network models used in BatDetect2.
"""Neural network model definitions and builders for BatDetect2.
This package (`batdetect2.models`) contains the PyTorch implementations of the
deep neural network architectures used for detecting and classifying bat calls
from spectrograms. It provides modular components and configuration-driven
assembly, allowing for experimentation and use of different architectural
variants.
This package contains the PyTorch implementations of the deep neural network
architectures used to detect and classify bat echolocation calls in
spectrograms. Components are designed to be combined through configuration
objects, making it easy to experiment with different architectures.
Key Submodules:
- `.types`: Defines core data structures (`ModelOutput`) and abstract base
classes (`BackboneModel`, `DetectionModel`) establishing interfaces.
- `.blocks`: Provides reusable neural network building blocks.
- `.encoder`: Defines and builds the downsampling path (encoder) of the network.
- `.bottleneck`: Defines and builds the central bottleneck component.
- `.decoder`: Defines and builds the upsampling path (decoder) of the network.
- `.backbone`: Assembles the encoder, bottleneck, and decoder into a complete
feature extraction backbone (e.g., a U-Net like structure).
- `.heads`: Defines simple prediction heads (detection, classification, size)
that attach to the backbone features.
- `.detectors`: Assembles the backbone and prediction heads into the final,
end-to-end `Detector` model.
Key submodules
--------------
- ``blocks``: Reusable convolutional building blocks (downsampling,
upsampling, attention, coord-conv variants).
- ``encoder``: The downsampling path; reduces spatial resolution whilst
extracting increasingly abstract features.
- ``bottleneck``: The central component connecting encoder to decoder;
optionally applies self-attention along the time axis.
- ``decoder``: The upsampling path; reconstructs high-resolution feature
maps using bottleneck output and skip connections from the encoder.
- ``backbones``: Assembles encoder, bottleneck, and decoder into a complete
U-Net-style feature extraction backbone.
- ``heads``: Lightweight 1×1 convolutional heads that produce detection,
classification, and bounding-box size predictions from backbone features.
- ``detectors``: Combines a backbone with prediction heads into the final
end-to-end ``Detector`` model.
This module re-exports the most important classes, configurations, and builder
functions from these submodules for convenient access. The primary entry point
for creating a standard BatDetect2 model instance is the `build_model` function
provided here.
The primary entry point for building a full, ready-to-use BatDetect2 model
is the ``build_model`` factory function exported from this module.
"""
from typing import List, Optional
from typing import Literal
import torch
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.models.backbones import Backbone, build_backbone
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig
from batdetect2.models.backbones import (
BackboneConfig,
UNetBackbone,
UNetBackboneConfig,
build_backbone,
load_backbone_config,
)
from batdetect2.models.blocks import (
ConvConfig,
FreqCoordConvDownConfig,
@ -43,10 +53,6 @@ from batdetect2.models.bottleneck import (
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.config import (
BackboneConfig,
load_backbone_config,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
@ -59,17 +65,20 @@ from batdetect2.models.encoder import (
build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.typing import (
from batdetect2.models.types import DetectionModel
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.postprocess.types import (
ClipDetectionsTensor,
DetectionModel,
PostprocessorProtocol,
PreprocessorProtocol,
TargetProtocol,
)
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.config import TargetConfig
from batdetect2.targets.types import TargetProtocol
__all__ = [
"BBoxHead",
"Backbone",
"UNetBackbone",
"BackboneConfig",
"Bottleneck",
"BottleneckConfig",
@ -92,11 +101,93 @@ __all__ = [
"build_detector",
"load_backbone_config",
"Model",
"ModelConfig",
"build_model",
"build_model_with_new_targets",
]
class ModelConfig(BaseConfig):
"""Complete configuration describing a BatDetect2 model.
Bundles every parameter that defines a model's behaviour: the input
sample rate, backbone architecture, preprocessing pipeline,
postprocessing pipeline, and detection targets.
Attributes
----------
samplerate : int
Expected input audio sample rate in Hz. Audio must be resampled
to this rate before being passed to the model. Defaults to
``TARGET_SAMPLERATE_HZ`` (256 000 Hz).
architecture : BackboneConfig
Configuration for the encoder-decoder backbone network. Defaults
to ``UNetBackboneConfig()``.
preprocess : PreprocessingConfig
Parameters for the audio-to-spectrogram preprocessing pipeline
(STFT, frequency crop, transforms, resize). Defaults to
``PreprocessingConfig()``.
postprocess : PostprocessConfig
Parameters for converting raw model outputs into detections (NMS
kernel, thresholds, top-k limit). Defaults to
``PostprocessConfig()``.
targets : TargetConfig
Detection and classification target definitions (class list,
detection target, bounding-box mapper). Defaults to
``TargetConfig()``.
"""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
architecture: BackboneConfig = Field(default_factory=UNetBackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
@classmethod
def load(
cls,
path: PathLike,
field: str | None = None,
extra: Literal["ignore", "allow", "forbid"] | None = None,
strict: bool | None = None,
targets: TargetConfig | None = None,
) -> "ModelConfig":
config = super().load(path, field, extra, strict)
if targets is None:
return config
return config.model_copy(update={"targets": targets})
class Model(torch.nn.Module):
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
Combines a preprocessor, a detection model, and a postprocessor into a
single PyTorch module. Calling ``forward`` on a raw waveform tensor
returns a list of detection tensors ready for downstream use.
This class is the top-level object produced by ``build_model``. Most
users will not need to construct it directly.
Attributes
----------
detector : DetectionModel
The neural network that processes spectrograms and produces raw
detection, classification, and bounding-box outputs.
preprocessor : PreprocessorProtocol
Converts a raw waveform tensor into a spectrogram tensor accepted by
``detector``.
postprocessor : PostprocessorProtocol
Converts the raw ``ModelOutput`` from ``detector`` into a list of
per-clip detection tensors.
targets : TargetProtocol
Describes the set of target classes; used when building heads and
during training target construction.
"""
detector: DetectionModel
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
@ -115,31 +206,87 @@ class Model(torch.nn.Module):
self.postprocessor = postprocessor
self.targets = targets
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
"""Run the full detection pipeline on a waveform tensor.
Converts the waveform to a spectrogram, passes it through the
detector, and postprocesses the raw outputs into detection tensors.
Parameters
----------
wav : torch.Tensor
Raw audio waveform tensor. The exact expected shape depends on
the preprocessor, but is typically ``(batch, samples)`` or
``(batch, channels, samples)``.
Returns
-------
list[ClipDetectionsTensor]
One detection tensor per clip in the batch. Each tensor encodes
the detected events (locations, class scores, sizes) for that
clip.
"""
spec = self.preprocessor(wav)
outputs = self.detector(spec)
return self.postprocessor(outputs)
def build_model(
config: Optional[BackboneConfig] = None,
targets: Optional[TargetProtocol] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
postprocessor: Optional[PostprocessorProtocol] = None,
):
config: ModelConfig | None = None,
targets: TargetProtocol | None = None,
preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None,
) -> Model:
"""Build a complete, ready-to-use BatDetect2 model.
Assembles a ``Model`` instance from a ``ModelConfig`` and optional
component overrides. Any component argument left as ``None`` is built
from the configuration. Passing a pre-built component overrides the
corresponding config fields for that component only.
Parameters
----------
config : ModelConfig, optional
Full model configuration (samplerate, architecture, preprocessing,
postprocessing, targets). Defaults to ``ModelConfig()`` if not
provided.
targets : TargetProtocol, optional
Pre-built targets object. If given, overrides
``config.targets``.
preprocessor : PreprocessorProtocol, optional
Pre-built preprocessor. If given, overrides
``config.preprocess`` and ``config.samplerate`` for the
preprocessing step.
postprocessor : PostprocessorProtocol, optional
Pre-built postprocessor. If given, overrides
``config.postprocess``. When omitted and a custom
``preprocessor`` is supplied, the default postprocessor is built
using that preprocessor so that frequency and time scaling remain
consistent.
Returns
-------
Model
A fully assembled ``Model`` instance ready for inference or
training.
"""
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
config = config or BackboneConfig()
targets = targets or build_targets()
preprocessor = preprocessor or build_preprocessor()
config = config or ModelConfig()
targets = targets or build_targets(config=config.targets)
preprocessor = preprocessor or build_preprocessor(
config=config.preprocess,
input_samplerate=config.samplerate,
)
postprocessor = postprocessor or build_postprocessor(
preprocessor=preprocessor,
config=config.postprocess,
)
detector = build_detector(
num_classes=len(targets.class_names),
config=config,
config=config.architecture,
)
return Model(
detector=detector,
@ -147,3 +294,21 @@ def build_model(
preprocessor=preprocessor,
targets=targets,
)
def build_model_with_new_targets(
model: Model,
targets: TargetProtocol,
) -> Model:
"""Build a new model with a different target set."""
detector = build_detector(
num_classes=len(targets.class_names),
backbone=model.detector.backbone,
)
return Model(
detector=detector,
postprocessor=model.postprocessor,
preprocessor=model.preprocessor,
targets=targets,
)

View File

@ -1,99 +1,176 @@
"""Assembles a complete Encoder-Decoder Backbone network.
"""Assembles a complete encoder-decoder backbone network.
This module defines the configuration (`BackboneConfig`) and implementation
(`Backbone`) for a standard encoder-decoder style neural network backbone.
This module defines ``UNetBackboneConfig`` and the ``UNetBackbone``
``nn.Module``, together with the ``build_backbone`` and
``load_backbone_config`` helpers.
It orchestrates the connection between three main components, built using their
respective configurations and factory functions from sibling modules:
1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features
at multiple resolutions and provides skip connections.
2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the
lowest resolution, optionally applying self-attention.
3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs high-
resolution features using bottleneck features and skip connections.
A backbone combines three components built from the sibling modules:
The resulting `Backbone` module takes a spectrogram as input and outputs a
final feature map, typically used by subsequent prediction heads. It includes
automatic padding to handle input sizes not perfectly divisible by the
network's total downsampling factor.
1. **Encoder** (``batdetect2.models.encoder``) reduces spatial resolution
while extracting hierarchical features and storing skip-connection tensors.
2. **Bottleneck** (``batdetect2.models.bottleneck``) processes the
lowest-resolution features, optionally applying self-attention.
3. **Decoder** (``batdetect2.models.decoder``) restores spatial resolution
using bottleneck features and skip connections from the encoder.
The resulting ``UNetBackbone`` takes a spectrogram tensor as input and returns
a high-resolution feature map consumed by the prediction heads in
``batdetect2.models.detectors``.
Input padding is handled automatically: the backbone pads the input to be
divisible by the total downsampling factor and strips the padding from the
output so that the output spatial dimensions always match the input spatial
dimensions.
"""
from typing import Tuple
from typing import Annotated, Literal
import torch
import torch.nn.functional as F
from torch import nn
from pydantic import Field, TypeAdapter
from soundevent import data
from batdetect2.models.bottleneck import build_bottleneck
from batdetect2.models.config import BackboneConfig
from batdetect2.models.decoder import Decoder, build_decoder
from batdetect2.models.encoder import Encoder, build_encoder
from batdetect2.typing.models import BackboneModel
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
build_decoder,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
build_encoder,
)
from batdetect2.models.types import (
BackboneModel,
BottleneckProtocol,
DecoderProtocol,
EncoderProtocol,
)
__all__ = [
"Backbone",
"BackboneImportConfig",
"UNetBackbone",
"BackboneConfig",
"load_backbone_config",
"build_backbone",
]
class Backbone(BackboneModel):
"""Encoder-Decoder Backbone Network Implementation.
class UNetBackboneConfig(BaseConfig):
"""Configuration for a U-Net-style encoder-decoder backbone.
Combines an Encoder, Bottleneck, and Decoder module sequentially, using
skip connections between the Encoder and Decoder. Implements the standard
U-Net style forward pass. Includes automatic input padding to handle
various input sizes and a final convolutional block to adjust the output
channels.
All fields have sensible defaults that reproduce the standard BatDetect2
architecture, so you can start with ``UNetBackboneConfig()`` and override
only the fields you want to change.
This class inherits from `BackboneModel` and implements its `forward`
method. Instances are typically created using the `build_backbone` factory
function.
Attributes
----------
name : str
Discriminator field used by the backbone registry; always
``"UNetBackbone"``.
input_height : int
Number of frequency bins in the input spectrogram. Defaults to
``128``.
in_channels : int
Number of channels in the input spectrogram (e.g. ``1`` for a
standard mel-spectrogram). Defaults to ``1``.
encoder : EncoderConfig
Configuration for the downsampling path. Defaults to
``DEFAULT_ENCODER_CONFIG``.
bottleneck : BottleneckConfig
Configuration for the bottleneck. Defaults to
``DEFAULT_BOTTLENECK_CONFIG``.
decoder : DecoderConfig
Configuration for the upsampling path. Defaults to
``DEFAULT_DECODER_CONFIG``.
"""
name: Literal["UNetBackbone"] = "UNetBackbone"
input_height: int = 128
in_channels: int = 1
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
@add_import_config(backbone_registry)
class BackboneImportConfig(ImportConfig):
"""Use any callable as a backbone model.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class UNetBackbone(BackboneModel):
"""U-Net-style encoder-decoder backbone network.
Combines an encoder, a bottleneck, and a decoder into a single module
that produces a high-resolution feature map from an input spectrogram.
Skip connections from each encoder stage are added element-wise to the
corresponding decoder stage input.
Input spectrograms of arbitrary width are handled automatically: the
backbone pads the input so that its dimensions are divisible by
``divide_factor`` and removes the padding from the output.
Instances are typically created via ``build_backbone``.
Attributes
----------
input_height : int
Expected height of the input spectrogram.
Expected height (frequency bins) of the input spectrogram.
out_channels : int
Number of channels in the final output feature map.
encoder : Encoder
Number of channels in the output feature map (taken from the
decoder's output channel count).
encoder : EncoderProtocol
The instantiated encoder module.
decoder : Decoder
decoder : DecoderProtocol
The instantiated decoder module.
bottleneck : nn.Module
bottleneck : BottleneckProtocol
The instantiated bottleneck module.
final_conv : ConvBlock
Final convolutional block applied after the decoder.
divide_factor : int
The total downsampling factor (2^depth) applied by the encoder,
used for automatic input padding.
The total spatial downsampling factor applied by the encoder
(``input_height // encoder.output_height``). The input width is
padded to be a multiple of this value before processing.
"""
def __init__(
self,
input_height: int,
encoder: Encoder,
decoder: Decoder,
bottleneck: nn.Module,
encoder: EncoderProtocol,
decoder: DecoderProtocol,
bottleneck: BottleneckProtocol,
):
"""Initialize the Backbone network.
"""Initialise the backbone network.
Parameters
----------
input_height : int
Expected height of the input spectrogram.
out_channels : int
Desired number of output channels for the backbone's feature map.
encoder : Encoder
An initialized Encoder module.
decoder : Decoder
An initialized Decoder module.
bottleneck : nn.Module
An initialized Bottleneck module.
Raises
------
ValueError
If component output/input channels or heights are incompatible.
Expected height (frequency bins) of the input spectrogram.
encoder : EncoderProtocol
An initialised encoder module.
decoder : DecoderProtocol
An initialised decoder module. Its ``output_height`` must equal
``input_height``; a ``ValueError`` is raised otherwise.
bottleneck : BottleneckProtocol
An initialised bottleneck module.
"""
super().__init__()
self.input_height = input_height
@ -110,22 +187,25 @@ class Backbone(BackboneModel):
self.divide_factor = input_height // self.encoder.output_height
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Perform the forward pass through the encoder-decoder backbone.
"""Produce a feature map from an input spectrogram.
Applies padding, runs encoder, bottleneck, decoder (with skip
connections), removes padding, and applies a final convolution.
Pads the input if necessary, runs it through the encoder, then
the bottleneck, then the decoder (incorporating encoder skip
connections), and finally removes any padding added earlier.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, shape `(B, C_in, H_in, W_in)`. Must match
`self.encoder.input_channels` and `self.input_height`.
Input spectrogram tensor, shape
``(B, C_in, H_in, W_in)``. ``H_in`` must equal
``self.input_height``; ``W_in`` can be any positive integer.
Returns
-------
torch.Tensor
Output feature map tensor, shape `(B, C_out, H_in, W_in)`, where
`C_out` is `self.out_channels`.
Feature map tensor, shape ``(B, C_out, H_in, W_in)``, where
``C_out`` is ``self.out_channels``. The spatial dimensions
always match those of the input.
"""
spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor)
@ -143,95 +223,97 @@ class Backbone(BackboneModel):
return x
@backbone_registry.register(UNetBackboneConfig)
@staticmethod
def from_config(config: UNetBackboneConfig) -> BackboneModel:
encoder = build_encoder(
in_channels=config.in_channels,
input_height=config.input_height,
config=config.encoder,
)
def build_backbone(config: BackboneConfig) -> BackboneModel:
"""Factory function to build a Backbone from configuration.
bottleneck = build_bottleneck(
input_height=encoder.output_height,
in_channels=encoder.out_channels,
config=config.bottleneck,
)
Constructs the `Encoder`, `Bottleneck`, and `Decoder` components based on
the provided `BackboneConfig`, validates their compatibility, and assembles
them into a `Backbone` instance.
decoder = build_decoder(
in_channels=bottleneck.out_channels,
input_height=encoder.output_height,
config=config.decoder,
)
if decoder.output_height != config.input_height:
raise ValueError(
"Invalid configuration: Decoder output height "
f"({decoder.output_height}) must match the Backbone input height "
f"({config.input_height}). Check encoder/decoder layer "
"configurations and input/bottleneck heights."
)
return UNetBackbone(
input_height=config.input_height,
encoder=encoder,
decoder=decoder,
bottleneck=bottleneck,
)
BackboneConfig = Annotated[
UNetBackboneConfig | BackboneImportConfig,
Field(discriminator="name"),
]
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
"""Build a backbone network from configuration.
Looks up the backbone class corresponding to ``config.name`` in the
backbone registry and calls its ``from_config`` method. If no
configuration is provided, a default ``UNetBackbone`` is returned.
Parameters
----------
config : BackboneConfig
The configuration object detailing the backbone architecture, including
input dimensions and configurations for encoder, bottleneck, and
decoder.
config : BackboneConfig, optional
A configuration object describing the desired backbone. Currently
``UNetBackboneConfig`` is the only supported type. Defaults to
``UNetBackboneConfig()`` if not provided.
Returns
-------
BackboneModel
An initialized `Backbone` module ready for use.
Raises
------
ValueError
If sub-component configurations are incompatible
(e.g., channel mismatches, decoder output height doesn't match backbone
input height).
NotImplementedError
If an unknown block type is specified in sub-configs.
An initialised backbone module.
"""
encoder = build_encoder(
in_channels=config.in_channels,
input_height=config.input_height,
config=config.encoder,
)
bottleneck = build_bottleneck(
input_height=encoder.output_height,
in_channels=encoder.out_channels,
config=config.bottleneck,
)
decoder = build_decoder(
in_channels=bottleneck.out_channels,
input_height=encoder.output_height,
config=config.decoder,
)
if decoder.output_height != config.input_height:
raise ValueError(
"Invalid configuration: Decoder output height "
f"({decoder.output_height}) must match the Backbone input height "
f"({config.input_height}). Check encoder/decoder layer "
"configurations and input/bottleneck heights."
)
return Backbone(
input_height=config.input_height,
encoder=encoder,
decoder=decoder,
bottleneck=bottleneck,
)
config = config or UNetBackboneConfig()
return backbone_registry.build(config)
def _pad_adjust(
spec: torch.Tensor,
factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
"""Pad tensor height and width to be divisible by a factor.
) -> tuple[torch.Tensor, int, int]:
"""Pad a tensor's height and width to be divisible by ``factor``.
Calculates the required padding for the last two dimensions (H, W) to make
them divisible by `factor` and applies right/bottom padding using
`torch.nn.functional.pad`.
Adds zero-padding to the bottom and right edges of the tensor so that
both dimensions are exact multiples of ``factor``. If both dimensions
are already divisible, the tensor is returned unchanged.
Parameters
----------
spec : torch.Tensor
Input tensor, typically shape `(B, C, H, W)`.
Input tensor, typically shape ``(B, C, H, W)``.
factor : int, default=32
The factor to make height and width divisible by.
The factor that both H and W should be divisible by after padding.
Returns
-------
Tuple[torch.Tensor, int, int]
A tuple containing:
- The padded tensor.
- The amount of padding added to height (`h_pad`).
- The amount of padding added to width (`w_pad`).
tuple[torch.Tensor, int, int]
- Padded tensor.
- Number of rows added to the height (``h_pad``).
- Number of columns added to the width (``w_pad``).
"""
h, w = spec.shape[2:]
h, w = spec.shape[-2:]
h_pad = -h % factor
w_pad = -w % factor
@ -244,28 +326,71 @@ def _pad_adjust(
def _restore_pad(
x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor:
"""Remove padding added by _pad_adjust.
"""Remove padding previously added by ``_pad_adjust``.
Removes padding from the bottom and right edges of the tensor.
Trims ``h_pad`` rows from the bottom and ``w_pad`` columns from the
right of the tensor, restoring its original spatial dimensions.
Parameters
----------
x : torch.Tensor
Padded tensor, typically shape `(B, C, H_padded, W_padded)`.
Padded tensor, typically shape ``(B, C, H_padded, W_padded)``.
h_pad : int, default=0
Amount of padding previously added to the height (bottom).
Number of rows to remove from the bottom.
w_pad : int, default=0
Amount of padding previously added to the width (right).
Number of columns to remove from the right.
Returns
-------
torch.Tensor
Tensor with padding removed, shape `(B, C, H_original, W_original)`.
Tensor with padding removed, shape
``(B, C, H_padded - h_pad, W_padded - w_pad)``.
"""
if h_pad > 0:
x = x[:, :, :-h_pad, :]
x = x[..., :-h_pad, :]
if w_pad > 0:
x = x[:, :, :, :-w_pad]
x = x[..., :-w_pad]
return x
def load_backbone_config(
path: data.PathLike,
field: str | None = None,
) -> BackboneConfig:
"""Load a backbone configuration from a YAML or JSON file.
Reads the file at ``path``, optionally descends into a named sub-field,
and validates the result against the ``BackboneConfig`` discriminated
union.
Parameters
----------
path : PathLike
Path to the configuration file. Both YAML and JSON formats are
supported.
field : str, optional
Dot-separated key path to the sub-field that contains the backbone
configuration (e.g. ``"model"``). If ``None``, the root of the
file is used.
Returns
-------
BackboneConfig
A validated backbone configuration object (currently always a
``UNetBackboneConfig`` instance).
Raises
------
FileNotFoundError
If ``path`` does not exist.
ValidationError
If the loaded data does not conform to a known ``BackboneConfig``
schema.
"""
return load_config(
path,
schema=TypeAdapter(BackboneConfig),
field=field,
)

View File

@ -1,42 +1,63 @@
"""Commonly used neural network building blocks for BatDetect2 models.
"""Reusable convolutional building blocks for BatDetect2 models.
This module provides various reusable `torch.nn.Module` subclasses that form
the fundamental building blocks for constructing convolutional neural network
architectures, particularly encoder-decoder backbones used in BatDetect2.
This module provides a collection of ``torch.nn.Module`` subclasses that form
the fundamental building blocks for the encoder-decoder backbone used in
BatDetect2. All blocks follow a consistent interface: they store
``in_channels`` and ``out_channels`` as attributes and implement a
``get_output_height`` method that reports how a given input height maps to an
output height (e.g., halved by downsampling blocks, doubled by upsampling
blocks).
It includes standard components like basic convolutional blocks (`ConvBlock`),
blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with
upsampling (`StandardConvUpBlock`).
Available block families
------------------------
Standard blocks
``ConvBlock`` convolution + batch normalisation + ReLU, no change in
spatial resolution.
Additionally, it features specialized layers investigated in BatDetect2
research:
Downsampling blocks
``StandardConvDownBlock`` convolution then 2×2 max-pooling, halves H
and W.
``FreqCoordConvDownBlock`` like ``StandardConvDownBlock`` but prepends
a normalised frequency-coordinate channel before the convolution
(CoordConv concept), helping filters learn frequency-position-dependent
patterns.
- `SelfAttention`: Applies self-attention along the time dimension, enabling
the model to weigh information across the entire temporal context, often
used in the bottleneck of an encoder-decoder.
- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv"
concept by concatenating normalized frequency coordinate information as an
extra channel to the input of convolutional layers. This explicitly provides
spatial frequency information to filters, potentially enabling them to learn
frequency-dependent patterns more effectively.
Upsampling blocks
``StandardConvUpBlock`` bilinear interpolation then convolution,
doubles H and W.
``FreqCoordConvUpBlock`` like ``StandardConvUpBlock`` but prepends a
frequency-coordinate channel after upsampling.
These blocks can be utilized directly in custom PyTorch model definitions or
assembled into larger architectures.
Bottleneck blocks
``VerticalConv`` 1-D convolution whose kernel spans the entire
frequency axis, collapsing H to 1 whilst preserving W.
``SelfAttention`` scaled dot-product self-attention along the time
axis; typically follows a ``VerticalConv``.
A unified factory function `build_layer_from_config` allows creating instances
of these blocks based on configuration objects.
Group block
``LayerGroup`` chains several blocks sequentially into one unit,
useful when a single encoder or decoder "stage" requires more than one
operation.
Factory function
----------------
``build_layer`` creates any of the above blocks from the matching
configuration object (one of the ``*Config`` classes exported here), using
a discriminated-union ``name`` field to dispatch to the correct class.
"""
from typing import Annotated, List, Literal, Tuple, Union
from typing import Annotated, Literal
import torch
import torch.nn.functional as F
from pydantic import Field
from torch import nn
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.core.configs import BaseConfig
__all__ = [
"BlockImportConfig",
"ConvBlock",
"LayerGroupConfig",
"VerticalConv",
@ -51,63 +72,125 @@ __all__ = [
"FreqCoordConvUpConfig",
"StandardConvUpConfig",
"LayerConfig",
"build_layer_from_config",
"build_layer",
]
class Block(nn.Module):
"""Abstract base class for all BatDetect2 building blocks.
Subclasses must set ``in_channels`` and ``out_channels`` as integer
attributes so that factory functions can wire blocks together without
inspecting configuration objects at runtime. They may also override
``get_output_height`` when the block changes the height dimension (e.g.
downsampling or upsampling blocks).
Attributes
----------
in_channels : int
Number of channels expected in the input tensor.
out_channels : int
Number of channels produced in the output tensor.
"""
in_channels: int
out_channels: int
def get_output_height(self, input_height: int) -> int:
"""Return the output height for a given input height.
The default implementation returns ``input_height`` unchanged,
which is correct for blocks that do not alter spatial resolution.
Override this in downsampling (returns ``input_height // 2``) or
upsampling (returns ``input_height * 2``) subclasses.
Parameters
----------
input_height : int
Height (number of frequency bins) of the input feature map.
Returns
-------
int
Height of the output feature map.
"""
return input_height
block_registry: Registry[Block, [int, int]] = Registry("block")
@add_import_config(block_registry)
class BlockImportConfig(ImportConfig):
"""Use any callable as a model block.
Set ``name="import"`` and provide a ``target`` pointing to any
callable to use it instead of a built-in option.
"""
name: Literal["import"] = "import"
class SelfAttentionConfig(BaseConfig):
"""Configuration for a ``SelfAttention`` block.
Attributes
----------
name : str
Discriminator field; always ``"SelfAttention"``.
attention_channels : int
Dimensionality of the query, key, and value projections.
temperature : float
Scaling factor applied to the weighted values before the final
linear projection. Defaults to ``1``.
"""
name: Literal["SelfAttention"] = "SelfAttention"
attention_channels: int
temperature: float = 1
class SelfAttention(nn.Module):
"""Self-Attention mechanism operating along the time dimension.
class SelfAttention(Block):
"""Self-attention block operating along the time axis.
This module implements a scaled dot-product self-attention mechanism,
specifically designed here to operate across the time steps of an input
feature map, typically after spatial dimensions (like frequency) have been
condensed or squeezed.
Applies a scaled dot-product self-attention mechanism across the time
steps of an input feature map. Before attention is computed the height
dimension (frequency axis) is expected to have been reduced to 1, e.g.
by a preceding ``VerticalConv`` layer.
By calculating attention weights between all pairs of time steps, it allows
the model to capture long-range temporal dependencies and focus on relevant
parts of the sequence. It's often employed in the bottleneck or
intermediate layers of an encoder-decoder architecture to integrate global
temporal context.
The implementation uses linear projections to create query, key, and value
representations, computes scaled dot-product attention scores, applies
softmax, and produces an output by weighting the values according to the
attention scores, followed by a final linear projection. Positional encoding
is not explicitly included in this block.
For each time step the block computes query, key, and value projections
with learned linear weights, then calculates attention weights from the
querykey dot products divided by ``temperature × attention_channels``.
The weighted sum of values is projected back to ``in_channels`` via a
final linear layer, and the height dimension is restored so that the
output shape matches the input shape.
Parameters
----------
in_channels : int
Number of input channels (features per time step after spatial squeeze).
Number of input channels (features per time step). The output will
also have ``in_channels`` channels.
attention_channels : int
Number of channels for the query, key, and value projections. Also the
dimension of the output projection's input.
Dimensionality of the query, key, and value projections.
temperature : float, default=1.0
Scaling factor applied *before* the final projection layer. Can be used
to adjust the sharpness or focus of the attention mechanism, although
scaling within the softmax (dividing by sqrt(dim)) is more common for
standard transformers. Here it scales the weighted values.
Divisor applied together with ``attention_channels`` when scaling
the dot-product scores before softmax. Larger values produce softer
(more uniform) attention distributions.
Attributes
----------
key_fun : nn.Linear
Linear layer for key projection.
Linear projection for keys.
value_fun : nn.Linear
Linear layer for value projection.
Linear projection for values.
query_fun : nn.Linear
Linear layer for query projection.
Linear projection for queries.
pro_fun : nn.Linear
Final linear projection layer applied after attention weighting.
Final linear projection applied to the attended values.
temperature : float
Scaling factor applied before final projection.
Scaling divisor used when computing attention scores.
att_dim : int
Dimensionality of the attention space (`attention_channels`).
Dimensionality of the attention space (``attention_channels``).
"""
def __init__(
@ -117,10 +200,13 @@ class SelfAttention(nn.Module):
temperature: float = 1.0,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels
# Note, does not encode position information (absolute or relative)
self.temperature = temperature
self.att_dim = attention_channels
self.output_channels = in_channels
self.key_fun = nn.Linear(in_channels, attention_channels)
self.value_fun = nn.Linear(in_channels, attention_channels)
@ -133,20 +219,16 @@ class SelfAttention(nn.Module):
Parameters
----------
x : torch.Tensor
Input tensor, expected shape `(B, C, H, W)`, where H is typically
squeezed (e.g., H=1 after a `VerticalConv` or pooling) before
applying attention along the W (time) dimension.
Input tensor with shape ``(B, C, 1, W)``. The height dimension
must be 1 (i.e. the frequency axis should already have been
collapsed by a preceding ``VerticalConv`` layer).
Returns
-------
torch.Tensor
Output tensor of the same shape as the input `(B, C, H, W)`, where
attention has been applied across the W dimension.
Raises
------
RuntimeError
If input tensor dimensions are incompatible with operations.
Output tensor with the same shape ``(B, C, 1, W)`` as the
input, with each time step updated by attended context from all
other time steps.
"""
x = x.squeeze(2).permute(0, 2, 1)
@ -175,6 +257,22 @@ class SelfAttention(nn.Module):
return op
def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
"""Return the softmax attention weight matrix.
Useful for visualising which time steps attend to which others.
Parameters
----------
x : torch.Tensor
Input tensor with shape ``(B, C, 1, W)``.
Returns
-------
torch.Tensor
Attention weight matrix with shape ``(B, W, W)``. Entry
``[b, i, j]`` is the attention weight that time step ``i``
assigns to time step ``j`` in batch item ``b``.
"""
x = x.squeeze(2).permute(0, 2, 1)
key = torch.matmul(
@ -190,6 +288,19 @@ class SelfAttention(nn.Module):
att_weights = F.softmax(kk_qq, 1)
return att_weights
@block_registry.register(SelfAttentionConfig)
@staticmethod
def from_config(
config: SelfAttentionConfig,
input_channels: int,
input_height: int,
) -> "SelfAttention":
return SelfAttention(
in_channels=input_channels,
attention_channels=config.attention_channels,
temperature=config.temperature,
)
class ConvConfig(BaseConfig):
"""Configuration for a basic ConvBlock."""
@ -207,7 +318,7 @@ class ConvConfig(BaseConfig):
"""Padding size."""
class ConvBlock(nn.Module):
class ConvBlock(Block):
"""Basic Convolutional Block.
A standard building block consisting of a 2D convolution, followed by
@ -235,6 +346,8 @@ class ConvBlock(nn.Module):
pad_size: int = 1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d(
in_channels,
out_channels,
@ -258,8 +371,37 @@ class ConvBlock(nn.Module):
"""
return F.relu_(self.batch_norm(self.conv(x)))
@block_registry.register(ConvConfig)
@staticmethod
def from_config(
config: ConvConfig,
input_channels: int,
input_height: int,
):
return ConvBlock(
in_channels=input_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
)
class VerticalConv(nn.Module):
class VerticalConvConfig(BaseConfig):
"""Configuration for a ``VerticalConv`` block.
Attributes
----------
name : str
Discriminator field; always ``"VerticalConv"``.
channels : int
Number of output channels produced by the vertical convolution.
"""
name: Literal["VerticalConv"] = "VerticalConv"
channels: int
class VerticalConv(Block):
"""Convolutional layer that aggregates features across the entire height.
Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
@ -288,6 +430,8 @@ class VerticalConv(nn.Module):
input_height: int,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
@ -312,6 +456,19 @@ class VerticalConv(nn.Module):
"""
return F.relu_(self.bn(self.conv(x)))
@block_registry.register(VerticalConvConfig)
@staticmethod
def from_config(
config: VerticalConvConfig,
input_channels: int,
input_height: int,
):
return VerticalConv(
in_channels=input_channels,
out_channels=config.channels,
input_height=input_height,
)
class FreqCoordConvDownConfig(BaseConfig):
"""Configuration for a FreqCoordConvDownBlock."""
@ -329,7 +486,7 @@ class FreqCoordConvDownConfig(BaseConfig):
"""Padding size."""
class FreqCoordConvDownBlock(nn.Module):
class FreqCoordConvDownBlock(Block):
"""Downsampling Conv Block incorporating Frequency Coordinate features.
This block implements a downsampling step (Conv2d + MaxPool2d) commonly
@ -368,6 +525,8 @@ class FreqCoordConvDownBlock(nn.Module):
pad_size: int = 1,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.coords = nn.Parameter(
torch.linspace(-1, 1, input_height)[None, None, ..., None],
@ -402,6 +561,24 @@ class FreqCoordConvDownBlock(nn.Module):
x = F.relu(self.batch_norm(x), inplace=True)
return x
def get_output_height(self, input_height: int) -> int:
return input_height // 2
@block_registry.register(FreqCoordConvDownConfig)
@staticmethod
def from_config(
config: FreqCoordConvDownConfig,
input_channels: int,
input_height: int,
):
return FreqCoordConvDownBlock(
in_channels=input_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
)
class StandardConvDownConfig(BaseConfig):
"""Configuration for a StandardConvDownBlock."""
@ -419,7 +596,7 @@ class StandardConvDownConfig(BaseConfig):
"""Padding size."""
class StandardConvDownBlock(nn.Module):
class StandardConvDownBlock(Block):
"""Standard Downsampling Convolutional Block.
A basic downsampling block consisting of a 2D convolution, followed by
@ -447,6 +624,8 @@ class StandardConvDownBlock(nn.Module):
pad_size: int = 1,
):
super(StandardConvDownBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = nn.Conv2d(
in_channels,
out_channels,
@ -472,6 +651,23 @@ class StandardConvDownBlock(nn.Module):
x = F.max_pool2d(self.conv(x), 2, 2)
return F.relu(self.batch_norm(x), inplace=True)
def get_output_height(self, input_height: int) -> int:
return input_height // 2
@block_registry.register(StandardConvDownConfig)
@staticmethod
def from_config(
config: StandardConvDownConfig,
input_channels: int,
input_height: int,
):
return StandardConvDownBlock(
in_channels=input_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
)
class FreqCoordConvUpConfig(BaseConfig):
"""Configuration for a FreqCoordConvUpBlock."""
@ -488,8 +684,14 @@ class FreqCoordConvUpConfig(BaseConfig):
pad_size: int = 1
"""Padding size."""
up_mode: str = "bilinear"
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
class FreqCoordConvUpBlock(nn.Module):
up_scale: tuple[int, int] = (2, 2)
"""Scaling factor for height and width during upsampling."""
class FreqCoordConvUpBlock(Block):
"""Upsampling Conv Block incorporating Frequency Coordinate features.
This block implements an upsampling step followed by a convolution,
@ -504,22 +706,22 @@ class FreqCoordConvUpBlock(nn.Module):
Parameters
----------
in_channels : int
in_channels
Number of channels in the input tensor (before upsampling).
out_channels : int
out_channels
Number of output channels after the convolution.
input_height : int
input_height
Height (H dimension, frequency bins) of the tensor *before* upsampling.
Used to calculate the height for coordinate feature generation after
upsampling.
kernel_size : int, default=3
kernel_size
Size of the square convolutional kernel.
pad_size : int, default=1
pad_size
Padding added before convolution.
up_mode : str, default="bilinear"
up_mode
Interpolation mode for upsampling (e.g., "nearest", "bilinear",
"bicubic").
up_scale : Tuple[int, int], default=(2, 2)
up_scale
Scaling factor for height and width during upsampling
(typically (2, 2)).
"""
@ -532,9 +734,11 @@ class FreqCoordConvUpBlock(nn.Module):
kernel_size: int = 3,
pad_size: int = 1,
up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2),
up_scale: tuple[int, int] = (2, 2),
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.up_scale = up_scale
self.up_mode = up_mode
@ -581,6 +785,26 @@ class FreqCoordConvUpBlock(nn.Module):
op = F.relu(self.batch_norm(op), inplace=True)
return op
def get_output_height(self, input_height: int) -> int:
return input_height * 2
@block_registry.register(FreqCoordConvUpConfig)
@staticmethod
def from_config(
config: FreqCoordConvUpConfig,
input_channels: int,
input_height: int,
):
return FreqCoordConvUpBlock(
in_channels=input_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
up_mode=config.up_mode,
up_scale=config.up_scale,
)
class StandardConvUpConfig(BaseConfig):
"""Configuration for a StandardConvUpBlock."""
@ -597,8 +821,14 @@ class StandardConvUpConfig(BaseConfig):
pad_size: int = 1
"""Padding size."""
up_mode: str = "bilinear"
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
class StandardConvUpBlock(nn.Module):
up_scale: tuple[int, int] = (2, 2)
"""Scaling factor for height and width during upsampling."""
class StandardConvUpBlock(Block):
"""Standard Upsampling Convolutional Block.
A basic upsampling block used in CNN decoders. It first upsamples the input
@ -609,17 +839,17 @@ class StandardConvUpBlock(nn.Module):
Parameters
----------
in_channels : int
in_channels
Number of channels in the input tensor (before upsampling).
out_channels : int
out_channels
Number of output channels after the convolution.
kernel_size : int, default=3
kernel_size
Size of the square convolutional kernel.
pad_size : int, default=1
pad_size
Padding added before convolution.
up_mode : str, default="bilinear"
up_mode
Interpolation mode for upsampling (e.g., "nearest", "bilinear").
up_scale : Tuple[int, int], default=(2, 2)
up_scale
Scaling factor for height and width during upsampling.
"""
@ -630,9 +860,11 @@ class StandardConvUpBlock(nn.Module):
kernel_size: int = 3,
pad_size: int = 1,
up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2),
up_scale: tuple[int, int] = (2, 2),
):
super(StandardConvUpBlock, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.up_scale = up_scale
self.up_mode = up_mode
self.conv = nn.Conv2d(
@ -669,155 +901,195 @@ class StandardConvUpBlock(nn.Module):
op = F.relu(self.batch_norm(op), inplace=True)
return op
def get_output_height(self, input_height: int) -> int:
return input_height * 2
@block_registry.register(StandardConvUpConfig)
@staticmethod
def from_config(
config: StandardConvUpConfig,
input_channels: int,
input_height: int,
):
return StandardConvUpBlock(
in_channels=input_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
up_mode=config.up_mode,
up_scale=config.up_scale,
)
class LayerGroupConfig(BaseConfig):
"""Configuration for a ``LayerGroup`` — a sequential chain of blocks.
Use this when a single encoder or decoder stage needs more than one
block. The blocks are executed in the order they appear in ``layers``,
with channel counts and heights propagated automatically.
Attributes
----------
name : str
Discriminator field; always ``"LayerGroup"``.
layers : List[LayerConfig]
Ordered list of block configurations to chain together.
"""
name: Literal["LayerGroup"] = "LayerGroup"
layers: list["LayerConfig"]
LayerConfig = Annotated[
Union[
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
SelfAttentionConfig,
"LayerGroupConfig",
],
ConvConfig
| FreqCoordConvDownConfig
| StandardConvDownConfig
| FreqCoordConvUpConfig
| StandardConvUpConfig
| SelfAttentionConfig
| LayerGroupConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configuration models."""
class LayerGroupConfig(BaseConfig):
name: Literal["LayerGroup"] = "LayerGroup"
layers: List[LayerConfig]
class LayerGroup(nn.Module):
"""Sequential chain of blocks that acts as a single composite block.
Wraps multiple ``Block`` instances in an ``nn.Sequential`` container,
exposing the same ``in_channels``, ``out_channels``, and
``get_output_height`` interface as a regular ``Block`` so it can be
used transparently wherever a single block is expected.
Instances are typically constructed by ``build_layer`` when given a
``LayerGroupConfig``; you rarely need to create them directly.
Parameters
----------
layers : list[Block]
Pre-built block instances to chain, in execution order.
input_height : int
Height of the tensor entering the first block.
input_channels : int
Number of channels in the tensor entering the first block.
Attributes
----------
in_channels : int
Number of input channels (taken from the first block).
out_channels : int
Number of output channels (taken from the last block).
layers : nn.Sequential
The wrapped sequence of block modules.
"""
def __init__(
self,
layers: list[Block],
input_height: int,
input_channels: int,
):
super().__init__()
self.in_channels = input_channels
self.out_channels = (
layers[-1].out_channels if layers else input_channels
)
self.layers = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Pass input through all blocks in sequence.
Parameters
----------
x : torch.Tensor
Input feature map, shape ``(B, C_in, H, W)``.
Returns
-------
torch.Tensor
Output feature map after all blocks have been applied.
"""
return self.layers(x)
def get_output_height(self, input_height: int) -> int:
"""Compute the output height by propagating through all blocks.
Parameters
----------
input_height : int
Height of the input feature map.
Returns
-------
int
Height after all blocks in the group have been applied.
"""
for block in self.layers:
input_height = block.get_output_height(input_height) # type: ignore
return input_height
@block_registry.register(LayerGroupConfig)
@staticmethod
def from_config(
config: LayerGroupConfig,
input_channels: int,
input_height: int,
):
layers = []
for layer_config in config.layers:
layer = build_layer(
input_height=input_height,
in_channels=input_channels,
config=layer_config,
)
layers.append(layer)
input_height = layer.get_output_height(input_height)
input_channels = layer.out_channels
return LayerGroup(
layers=layers,
input_height=input_height,
input_channels=input_channels,
)
def build_layer_from_config(
def build_layer(
input_height: int,
in_channels: int,
config: LayerConfig,
) -> Tuple[nn.Module, int, int]:
"""Factory function to build a specific nn.Module block from its config.
) -> Block:
"""Build a block from its configuration object.
Takes configuration object (one of the types included in the `LayerConfig`
union) and instantiates the corresponding nn.Module block with the correct
parameters derived from the config and the current pipeline state
(`input_height`, `in_channels`).
It uses the `name` field within the `config` object to determine
which block class to instantiate.
Looks up the block class corresponding to ``config.name`` in the
internal block registry and instantiates it with the given input
dimensions. This is the standard way to construct blocks when
assembling an encoder or decoder from a configuration file.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor *to this layer*.
Height (number of frequency bins) of the input tensor to this
block. Required for blocks whose kernel size depends on the input
height (e.g. ``VerticalConv``) and for coordinate-aware blocks.
in_channels : int
Number of channels in the input tensor *to this layer*.
Number of channels in the input tensor to this block.
config : LayerConfig
A Pydantic configuration object for the desired block (e.g., an
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
by its `name` field.
A configuration object for the desired block type. The ``name``
field selects the block class; remaining fields supply its
parameters.
Returns
-------
Tuple[nn.Module, int, int]
A tuple containing:
- The instantiated `nn.Module` block.
- The number of output channels produced by the block.
- The calculated height of the output produced by the block.
Block
An initialised block module ready to be added to an
``nn.Sequential`` or ``nn.ModuleList``.
Raises
------
NotImplementedError
If the `config.name` does not correspond to a known block type.
KeyError
If ``config.name`` does not correspond to a registered block type.
ValueError
If parameters derived from the config are invalid for the block.
If the configuration parameters are invalid for the chosen block.
"""
if config.name == "ConvBlock":
return (
ConvBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height,
)
if config.name == "FreqCoordConvDown":
return (
FreqCoordConvDownBlock(
in_channels=in_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height // 2,
)
if config.name == "StandardConvDown":
return (
StandardConvDownBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height // 2,
)
if config.name == "FreqCoordConvUp":
return (
FreqCoordConvUpBlock(
in_channels=in_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height * 2,
)
if config.name == "StandardConvUp":
return (
StandardConvUpBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height * 2,
)
if config.name == "SelfAttention":
return (
SelfAttention(
in_channels=in_channels,
attention_channels=config.attention_channels,
temperature=config.temperature,
),
config.attention_channels,
input_height,
)
if config.name == "LayerGroup":
current_channels = in_channels
current_height = input_height
blocks = []
for block_config in config.layers:
block, current_channels, current_height = build_layer_from_config(
input_height=current_height,
in_channels=current_channels,
config=block_config,
)
blocks.append(block)
return nn.Sequential(*blocks), current_channels, current_height
raise NotImplementedError(f"Unknown block type {config.name}")
return block_registry.build(config, in_channels, input_height)

View File

@ -1,20 +1,24 @@
"""Defines the Bottleneck component of an Encoder-Decoder architecture.
"""Bottleneck component for encoder-decoder network architectures.
This module provides the configuration (`BottleneckConfig`) and
`torch.nn.Module` implementations (`Bottleneck`, `BottleneckAttn`) for the
bottleneck layer(s) that typically connect the Encoder (downsampling path) and
Decoder (upsampling path) in networks like U-Nets.
The bottleneck sits between the encoder (downsampling path) and the decoder
(upsampling path) and processes the lowest-resolution, highest-channel feature
map produced by the encoder.
The bottleneck processes the lowest-resolution, highest-dimensionality feature
map produced by the Encoder. This module offers a configurable option to include
a `SelfAttention` layer within the bottleneck, allowing the model to capture
global temporal context before features are passed to the Decoder.
This module provides:
A factory function `build_bottleneck` constructs the appropriate bottleneck
module based on the provided configuration.
- ``BottleneckConfig`` configuration dataclass describing the number of
internal channels and an optional sequence of additional layers (currently
only ``SelfAttention`` is supported).
- ``Bottleneck`` the ``torch.nn.Module`` implementation. It first applies a
``VerticalConv`` to collapse the frequency axis to a single bin, optionally
runs one or more additional layers (e.g. self-attention along the time axis),
then repeats the output along the height dimension to restore the original
frequency resolution before passing features to the decoder.
- ``build_bottleneck`` factory function that constructs a ``Bottleneck``
instance from a ``BottleneckConfig`` and the encoder's output dimensions.
"""
from typing import Annotated, List, Optional, Union
from typing import Annotated, List
import torch
from pydantic import Field
@ -22,10 +26,12 @@ from torch import nn
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
Block,
SelfAttentionConfig,
VerticalConv,
build_layer_from_config,
build_layer,
)
from batdetect2.models.types import BottleneckProtocol
__all__ = [
"BottleneckConfig",
@ -34,43 +40,52 @@ __all__ = [
]
class Bottleneck(nn.Module):
"""Base Bottleneck module for Encoder-Decoder architectures.
class Bottleneck(Block):
"""Bottleneck module for encoder-decoder architectures.
This implementation represents the simplest bottleneck structure
considered, primarily consisting of a `VerticalConv` layer. This layer
collapses the frequency dimension (height) to 1, summarizing information
across frequencies at each time step. The output is then repeated along the
height dimension to match the original bottleneck input height before being
passed to the decoder.
Processes the lowest-resolution feature map that links the encoder and
decoder. The sequence of operations is:
This base version does *not* include self-attention.
1. ``VerticalConv`` collapses the frequency axis (height) to a single
bin by applying a convolution whose kernel spans the full height.
2. Optional additional layers (e.g. ``SelfAttention``) applied while
the feature map has height 1, so they operate purely along the time
axis.
3. Height restoration the single-bin output is repeated along the
height axis to restore the original frequency resolution, producing
a tensor that the decoder can accept.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor. Must be positive.
Height (number of frequency bins) of the input tensor. Must be
positive.
in_channels : int
Number of channels in the input tensor from the encoder. Must be
positive.
out_channels : int
Number of output channels. Must be positive.
Number of output channels after the bottleneck. Must be positive.
bottleneck_channels : int, optional
Number of internal channels used by the ``VerticalConv`` layer.
Defaults to ``out_channels`` if not provided.
layers : List[torch.nn.Module], optional
Additional modules (e.g. ``SelfAttention``) to apply after the
``VerticalConv`` and before height restoration.
Attributes
----------
in_channels : int
Number of input channels accepted by the bottleneck.
out_channels : int
Number of output channels produced by the bottleneck.
input_height : int
Expected height of the input tensor.
channels : int
Number of output channels.
bottleneck_channels : int
Number of channels used internally by the vertical convolution.
conv_vert : VerticalConv
The vertical convolution layer.
Raises
------
ValueError
If `input_height`, `in_channels`, or `out_channels` are not positive.
layers : nn.ModuleList
Additional layers applied after the vertical convolution.
"""
def __init__(
@ -78,14 +93,31 @@ class Bottleneck(nn.Module):
input_height: int,
in_channels: int,
out_channels: int,
bottleneck_channels: Optional[int] = None,
layers: Optional[List[torch.nn.Module]] = None,
bottleneck_channels: int | None = None,
layers: List[torch.nn.Module] | None = None,
) -> None:
"""Initialize the base Bottleneck layer."""
"""Initialise the Bottleneck layer.
Parameters
----------
input_height : int
Height (number of frequency bins) of the input tensor.
in_channels : int
Number of channels in the input tensor.
out_channels : int
Number of channels in the output tensor.
bottleneck_channels : int, optional
Number of internal channels for the ``VerticalConv``. Defaults
to ``out_channels``.
layers : List[torch.nn.Module], optional
Additional modules applied after the ``VerticalConv``, such as
a ``SelfAttention`` block.
"""
super().__init__()
self.in_channels = in_channels
self.input_height = input_height
self.out_channels = out_channels
self.bottleneck_channels = (
bottleneck_channels
if bottleneck_channels is not None
@ -100,23 +132,24 @@ class Bottleneck(nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input features through the bottleneck.
"""Process the encoder's bottleneck features.
Applies vertical convolution and repeats the output height.
Applies vertical convolution, optional additional layers, then
restores the height dimension by repetition.
Parameters
----------
x : torch.Tensor
Input tensor from the encoder bottleneck, shape
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
`H_in` must match `self.input_height`.
Input tensor from the encoder, shape
``(B, C_in, H_in, W)``. ``C_in`` must match
``self.in_channels`` and ``H_in`` must match
``self.input_height``.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in, W)`. Note that the height
dimension `H_in` is restored via repetition after the vertical
convolution.
Output tensor with shape ``(B, C_out, H_in, W)``. The height
``H_in`` is restored by repeating the single-bin result.
"""
x = self.conv_vert(x)
@ -127,37 +160,29 @@ class Bottleneck(nn.Module):
BottleneckLayerConfig = Annotated[
Union[SelfAttentionConfig,],
SelfAttentionConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
"""Type alias for the discriminated union of block configs usable in the Bottleneck."""
class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
"""Configuration for the bottleneck component.
Attributes
----------
channels : int
The number of output channels produced by the main convolutional layer
within the bottleneck. This often matches the number of channels coming
from the last encoder stage, but can be different. Must be positive.
This also defines the channel dimensions used within the optional
`SelfAttention` layer.
self_attention : bool
If True, includes a `SelfAttention` layer operating on the time
dimension after an initial `VerticalConv` layer within the bottleneck.
If False, only the initial `VerticalConv` (and height repetition) is
performed.
Number of output channels produced by the bottleneck. This value
is also used as the dimensionality of any optional layers (e.g.
self-attention). Must be positive.
layers : List[BottleneckLayerConfig]
Ordered list of additional block configurations to apply after the
initial ``VerticalConv``. Currently only ``SelfAttentionConfig`` is
supported. Defaults to an empty list (no extra layers).
"""
channels: int
layers: List[BottleneckLayerConfig] = Field(
default_factory=list,
)
layers: List[BottleneckLayerConfig] = Field(default_factory=list)
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
@ -171,32 +196,39 @@ DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
def build_bottleneck(
input_height: int,
in_channels: int,
config: Optional[BottleneckConfig] = None,
) -> nn.Module:
"""Factory function to build the Bottleneck module from configuration.
config: BottleneckConfig | None = None,
) -> BottleneckProtocol:
"""Build a ``Bottleneck`` module from configuration.
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
on the `config.self_attention` flag.
Constructs a ``Bottleneck`` instance whose internal channel count and
optional extra layers (e.g. self-attention) are controlled by
``config``. If no configuration is provided, the default
``DEFAULT_BOTTLENECK_CONFIG`` is used, which includes a
``SelfAttention`` layer.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor. Must be positive.
Height (number of frequency bins) of the input tensor from the
encoder. Must be positive.
in_channels : int
Number of channels in the input tensor. Must be positive.
Number of channels in the input tensor from the encoder. Must be
positive.
config : BottleneckConfig, optional
Configuration object specifying the bottleneck channels and whether
to use self-attention. Uses `DEFAULT_BOTTLENECK_CONFIG` if None.
Configuration specifying the output channel count and any
additional layers. Uses ``DEFAULT_BOTTLENECK_CONFIG`` if ``None``.
Returns
-------
nn.Module
An initialized bottleneck module (`Bottleneck` or `BottleneckAttn`).
BottleneckProtocol
An initialised ``Bottleneck`` module.
Raises
------
ValueError
If `input_height` or `in_channels` are not positive.
AssertionError
If any configured layer changes the height of the feature map
(bottleneck layers must preserve height so that it can be restored
by repetition).
"""
config = config or DEFAULT_BOTTLENECK_CONFIG
@ -206,11 +238,13 @@ def build_bottleneck(
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
layer = build_layer(
input_height=current_height,
in_channels=current_channels,
config=layer_config,
)
current_height = layer.get_output_height(current_height)
current_channels = layer.out_channels
assert current_height == input_height, (
"Bottleneck layers should not change the spectrogram height"
)

View File

@ -1,98 +0,0 @@
from typing import Optional
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
)
__all__ = [
"BackboneConfig",
"load_backbone_config",
]
class BackboneConfig(BaseConfig):
"""Configuration for the Encoder-Decoder Backbone network.
Aggregates configurations for the encoder, bottleneck, and decoder
components, along with defining the input and final output dimensions
for the complete backbone.
Attributes
----------
input_height : int, default=128
Expected height (frequency bins) of the input spectrograms to the
backbone. Must be positive.
in_channels : int, default=1
Expected number of channels in the input spectrograms (e.g., 1 for
mono). Must be positive.
encoder : EncoderConfig, optional
Configuration for the encoder. If None or omitted,
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
encoder module) will be used.
bottleneck : BottleneckConfig, optional
Configuration for the bottleneck layer connecting encoder and decoder.
If None or omitted, the default bottleneck configuration will be used.
decoder : DecoderConfig, optional
Configuration for the decoder. If None or omitted,
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
decoder module) will be used.
out_channels : int, default=32
Desired number of channels in the final feature map output by the
backbone. Must be positive.
"""
input_height: int = 128
in_channels: int = 1
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32
def load_backbone_config(
path: data.PathLike,
field: Optional[str] = None,
) -> BackboneConfig:
"""Load the backbone configuration from a file.
Reads a configuration file (YAML) and validates it against the
`BackboneConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
backbone configuration (e.g., "model.backbone"). If None, the entire
file content is used.
Returns
-------
BackboneConfig
The loaded and validated backbone configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to `BackboneConfig`.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=BackboneConfig, field=field)

View File

@ -1,24 +1,24 @@
"""Constructs the Decoder part of an Encoder-Decoder neural network.
"""Decoder (upsampling path) for the BatDetect2 backbone.
This module defines the configuration structure (`DecoderConfig`) for the layer
sequence and provides the `Decoder` class (an `nn.Module`) along with a factory
function (`build_decoder`). Decoders typically form the upsampling path in
architectures like U-Nets, taking bottleneck features
(usually from an `Encoder`) and skip connections to reconstruct
higher-resolution feature maps.
This module defines ``DecoderConfig`` and the ``Decoder`` ``nn.Module``,
together with the ``build_decoder`` factory function.
The decoder is built dynamically by stacking neural network blocks based on a
list of configuration objects provided in `DecoderConfig.layers`. Each config
object specifies the type of block (e.g., standard convolution,
coordinate-feature convolution with upsampling) and its parameters. This allows
flexible definition of decoder architectures via configuration files.
In a U-Net-style network the decoder progressively restores the spatial
resolution of the feature map back towards the input resolution. At each
stage it combines the upsampled features with the corresponding skip-connection
tensor from the encoder (the residual) by element-wise addition before passing
the result to the upsampling block.
The `Decoder`'s `forward` method is designed to accept skip connection tensors
(`residuals`) from the encoder, merging them with the upsampled feature maps
at each stage.
The decoder is fully configurable: the type, number, and parameters of the
upsampling blocks are described by a ``DecoderConfig`` object containing an
ordered list of block configuration objects (see ``batdetect2.models.blocks``
for available block types).
A default configuration ``DEFAULT_DECODER_CONFIG`` is provided and used by
``build_decoder`` when no explicit configuration is supplied.
"""
from typing import Annotated, List, Optional, Union
from typing import Annotated, List
import torch
from pydantic import Field
@ -30,7 +30,7 @@ from batdetect2.models.blocks import (
FreqCoordConvUpConfig,
LayerGroupConfig,
StandardConvUpConfig,
build_layer_from_config,
build_layer,
)
__all__ = [
@ -41,63 +41,57 @@ __all__ = [
]
DecoderLayerConfig = Annotated[
Union[
ConvConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
LayerGroupConfig,
],
ConvConfig
| FreqCoordConvUpConfig
| StandardConvUpConfig
| LayerGroupConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
class DecoderConfig(BaseConfig):
"""Configuration for the sequence of layers in the Decoder module.
Defines the types and parameters of the neural network blocks that
constitute the decoder's upsampling path.
"""Configuration for the sequential ``Decoder`` module.
Attributes
----------
layers : List[DecoderLayerConfig]
An ordered list of configuration objects, each defining one layer or
block in the decoder sequence. Each item must be a valid block
config including a `name` field and necessary parameters like
`out_channels`. Input channels for each layer are inferred sequentially.
The list must contain at least one layer.
Ordered list of block configuration objects defining the decoder's
upsampling stages (from deepest to shallowest). Each entry
specifies the block type (via its ``name`` field) and any
block-specific parameters such as ``out_channels``. Input channels
for each block are inferred automatically from the output of the
previous block. Must contain at least one entry.
"""
layers: List[DecoderLayerConfig] = Field(min_length=1)
class Decoder(nn.Module):
"""Sequential Decoder module composed of configurable upsampling layers.
"""Sequential decoder module composed of configurable upsampling layers.
Constructs the upsampling path of an encoder-decoder network by stacking
multiple blocks (e.g., `StandardConvUpBlock`, `FreqCoordConvUpBlock`)
based on a list of layer modules provided during initialization (typically
created by the `build_decoder` factory function).
Executes a series of upsampling blocks in order, adding the
corresponding encoder skip-connection tensor (residual) to the feature
map before each block. The residuals are consumed in reverse order (from
deepest encoder layer to shallowest) to match the spatial resolutions at
each decoder stage.
The `forward` method is designed to integrate skip connection tensors
(`residuals`) from the corresponding encoder stages, by adding them
element-wise to the input of each decoder layer before processing.
Instances are typically created by ``build_decoder``.
Attributes
----------
in_channels : int
Number of channels expected in the input tensor.
Number of channels expected in the input tensor (bottleneck output).
out_channels : int
Number of channels in the final output tensor produced by the last
layer.
Number of channels in the final output feature map.
input_height : int
Height (frequency bins) expected in the input tensor.
Height (frequency bins) of the input tensor.
output_height : int
Height (frequency bins) expected in the output tensor.
Height (frequency bins) of the output tensor.
layers : nn.ModuleList
The sequence of instantiated upscaling layer modules.
Sequence of instantiated upsampling block modules.
depth : int
The number of upscaling layers (depth) in the decoder.
Number of upsampling layers.
"""
def __init__(
@ -108,23 +102,24 @@ class Decoder(nn.Module):
output_height: int,
layers: List[nn.Module],
):
"""Initialize the Decoder module.
"""Initialise the Decoder module.
Note: This constructor is typically called internally by the
`build_decoder` factory function.
This constructor is typically called by the ``build_decoder``
factory function.
Parameters
----------
in_channels : int
Number of channels in the input tensor (bottleneck output).
out_channels : int
Number of channels produced by the final layer.
input_height : int
Expected height of the input tensor (bottleneck).
in_channels : int
Expected number of channels in the input tensor (bottleneck).
Height of the input tensor (bottleneck output height).
output_height : int
Height of the output tensor after all layers have been applied.
layers : List[nn.Module]
A list of pre-instantiated upscaling layer modules (e.g.,
`StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired
sequence (from bottleneck towards output resolution).
Pre-built upsampling block modules in execution order (deepest
stage first).
"""
super().__init__()
@ -142,43 +137,35 @@ class Decoder(nn.Module):
x: torch.Tensor,
residuals: List[torch.Tensor],
) -> torch.Tensor:
"""Pass input through decoder layers, incorporating skip connections.
"""Pass input through all decoder layers, incorporating skip connections.
Processes the input tensor `x` sequentially through the upscaling
layers. At each stage, the corresponding skip connection tensor from
the `residuals` list is added element-wise to the input before passing
it to the upscaling block.
At each stage the corresponding residual tensor is added
element-wise to ``x`` before it is passed to the upsampling block.
Residuals are consumed in reverse order the last element of
``residuals`` (the output of the shallowest encoder layer) is added
at the first decoder stage, and the first element (output of the
deepest encoder layer) is added at the last decoder stage.
Parameters
----------
x : torch.Tensor
Input tensor from the previous stage (e.g., encoder bottleneck).
Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
`self.in_channels`.
Bottleneck feature map, shape ``(B, C_in, H_in, W)``.
residuals : List[torch.Tensor]
List containing the skip connection tensors from the corresponding
encoder stages. Should be ordered from the deepest encoder layer
output (lowest resolution) to the shallowest (highest resolution
near input). The number of tensors in this list must match the
number of decoder layers (`self.depth`). Each residual tensor's
channel count must be compatible with the input tensor `x` for
element-wise addition (or concatenation if the blocks were designed
for it).
Skip-connection tensors from the encoder, ordered from shallowest
(index 0) to deepest (index -1). Must contain exactly
``self.depth`` tensors. Each tensor must have the same spatial
dimensions and channel count as ``x`` at the corresponding
decoder stage.
Returns
-------
torch.Tensor
The final decoded feature map tensor produced by the last layer.
Shape `(B, C_out, H_out, W_out)`.
Decoded feature map, shape ``(B, C_out, H_out, W)``.
Raises
------
ValueError
If the number of `residuals` provided does not match the decoder
depth.
RuntimeError
If shapes mismatch during skip connection addition or layer
processing.
If the number of ``residuals`` does not equal ``self.depth``.
"""
if len(residuals) != len(self.layers):
raise ValueError(
@ -187,7 +174,7 @@ class Decoder(nn.Module):
f"but got {len(residuals)}."
)
for layer, res in zip(self.layers, residuals[::-1]):
for layer, res in zip(self.layers, residuals[::-1], strict=False):
x = layer(x + res)
return x
@ -205,53 +192,55 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
),
],
)
"""A default configuration for the Decoder's *layer sequence*.
"""Default decoder configuration used in standard BatDetect2 models.
Specifies an architecture often used in BatDetect2, consisting of three
frequency coordinate-aware upsampling blocks followed by a standard
convolutional block.
Mirrors ``DEFAULT_ENCODER_CONFIG`` in reverse. Assumes the bottleneck
output has 256 channels and height 16, and produces:
- Stage 1 (``FreqCoordConvUp``): 64 channels, height 32.
- Stage 2 (``FreqCoordConvUp``): 32 channels, height 64.
- Stage 3 (``LayerGroup``):
- ``FreqCoordConvUp``: 32 channels, height 128.
- ``ConvBlock``: 32 channels, height 128 (final feature map).
"""
def build_decoder(
in_channels: int,
input_height: int,
config: Optional[DecoderConfig] = None,
config: DecoderConfig | None = None,
) -> Decoder:
"""Factory function to build a Decoder instance from configuration.
"""Build a ``Decoder`` from configuration.
Constructs a sequential `Decoder` module based on the layer sequence
defined in a `DecoderConfig` object and the provided input dimensions
(bottleneck channels and height). If no config is provided, uses the
default layer sequence from `DEFAULT_DECODER_CONFIG`.
It iteratively builds the layers using the unified `build_layer_from_config`
factory (from `.blocks`), tracking the changing number of channels and
feature map height required for each subsequent layer.
Constructs a sequential ``Decoder`` by iterating over the block
configurations in ``config.layers``, building each block with
``build_layer``, and tracking the channel count and feature-map height
as they change through the sequence.
Parameters
----------
in_channels : int
The number of channels in the input tensor to the decoder. Must be > 0.
Number of channels in the input tensor (bottleneck output). Must
be positive.
input_height : int
The height (frequency bins) of the input tensor to the decoder. Must be
> 0.
Height (number of frequency bins) of the input tensor. Must be
positive.
config : DecoderConfig, optional
The configuration object detailing the sequence of layers and their
parameters. If None, `DEFAULT_DECODER_CONFIG` is used.
Configuration specifying the layer sequence. Defaults to
``DEFAULT_DECODER_CONFIG`` if not provided.
Returns
-------
Decoder
An initialized `Decoder` module.
An initialised ``Decoder`` module.
Raises
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `name`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `name`.
If ``in_channels`` or ``input_height`` are not positive.
KeyError
If a layer configuration specifies an unknown block type.
"""
config = config or DEFAULT_DECODER_CONFIG
@ -261,11 +250,13 @@ def build_decoder(
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
layer = build_layer(
in_channels=current_channels,
input_height=current_height,
config=layer_config,
)
current_height = layer.get_output_height(current_height)
current_channels = layer.out_channels
layers.append(layer)
return Decoder(

View File

@ -1,27 +1,32 @@
"""Assembles the complete BatDetect2 Detection Model.
"""Assembles the complete BatDetect2 detection model.
This module defines the concrete `Detector` class, which implements the
`DetectionModel` interface defined in `.types`. It combines a feature
extraction backbone with specific prediction heads to create the end-to-end
neural network used for detecting bat calls, predicting their size, and
classifying them.
This module defines the ``Detector`` class, which combines a backbone
feature extractor with prediction heads for detection, classification, and
bounding-box size regression.
The primary components are:
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
Components
----------
- ``Detector`` — the ``torch.nn.Module`` that wires together a backbone
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
produce a ``ModelOutput`` tuple from an input spectrogram.
- ``build_detector`` — factory function that builds a ready-to-use
``Detector`` from a backbone configuration and a target class count.
This module focuses purely on the neural network architecture definition. The
logic for preprocessing inputs and postprocessing/decoding outputs resides in
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
Note that ``Detector`` operates purely on spectrogram tensors; raw audio
preprocessing and output postprocessing are handled by
``batdetect2.preprocess`` and ``batdetect2.postprocess`` respectively.
"""
from typing import Optional
import torch
from loguru import logger
from batdetect2.models.backbones import BackboneConfig, build_backbone
from batdetect2.models.backbones import (
BackboneConfig,
UNetBackboneConfig,
build_backbone,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
__all__ = [
"Detector",
@ -30,25 +35,30 @@ __all__ = [
class Detector(DetectionModel):
"""Concrete implementation of the BatDetect2 Detection Model.
"""Complete BatDetect2 detection and classification model.
Assembles a complete detection and classification model by combining a
feature extraction backbone network with specific prediction heads for
detection probability, bounding box size regression, and class
probabilities.
Combines a backbone feature extractor with two prediction heads:
- ``ClassifierHead``: predicts per-class probabilities at each
time–frequency location.
- ``BBoxHead``: predicts call duration and bandwidth at each location.
The detection probability map is derived from the class probabilities by
summing across the class dimension (i.e. the probability that *any* class
is present), rather than from a separate detection head.
Instances are typically created via ``build_detector``.
Attributes
----------
backbone : BackboneModel
The feature extraction backbone network module.
The feature extraction backbone.
num_classes : int
The number of specific target classes the model predicts (derived from
the `classifier_head`).
Number of target classes (inferred from the classifier head).
classifier_head : ClassifierHead
The prediction head responsible for generating class probabilities.
Produces per-class probability maps from backbone features.
bbox_head : BBoxHead
The prediction head responsible for generating bounding box size
predictions.
Produces duration and bandwidth predictions from backbone features.
"""
backbone: BackboneModel
@ -59,26 +69,21 @@ class Detector(DetectionModel):
classifier_head: ClassifierHead,
bbox_head: BBoxHead,
):
"""Initialize the Detector model.
"""Initialise the Detector model.
Note: Instances are typically created using the `build_detector`
This constructor is typically called by the ``build_detector``
factory function.
Parameters
----------
backbone : BackboneModel
An initialized feature extraction backbone module (e.g., built by
`build_backbone` from the `.backbone` module).
An initialised backbone module (e.g. built by
``build_backbone``).
classifier_head : ClassifierHead
An initialized classification head module. The number of classes
is inferred from this head.
An initialised classification head. The ``num_classes``
attribute is read from this head.
bbox_head : BBoxHead
An initialized bounding box size prediction head module.
Raises
------
TypeError
If the provided modules are not of the expected types.
An initialised bounding-box size prediction head.
"""
super().__init__()
@ -88,31 +93,34 @@ class Detector(DetectionModel):
self.bbox_head = bbox_head
def forward(self, spec: torch.Tensor) -> ModelOutput:
"""Perform the forward pass of the complete detection model.
"""Run the complete detection model on an input spectrogram.
Processes the input spectrogram through the backbone to extract
features, then passes these features through the separate prediction
heads to generate detection probabilities, class probabilities, and
size predictions.
Passes the spectrogram through the backbone to produce a feature
map, then applies the classifier and bounding-box heads. The
detection probability map is derived by summing the per-class
probability maps across the class dimension; no separate detection
head is used.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, typically with shape
`(batch_size, input_channels, frequency_bins, time_bins)`. The
shape must be compatible with the `self.backbone` input
requirements.
Input spectrogram tensor, shape
``(batch_size, channels, frequency_bins, time_bins)``.
Returns
-------
ModelOutput
A NamedTuple containing the four output tensors:
- `detection_probs`: Detection probability heatmap `(B, 1, H, W)`.
- `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`.
- `class_probs`: Class probabilities (excluding background)
`(B, num_classes, H, W)`.
- `features`: Output feature map from the backbone
`(B, C_out, H, W)`.
A named tuple with four fields:
- ``detection_probs`` — ``(B, 1, H, W)`` — probability that a
call of any class is present at each location. Derived by
summing ``class_probs`` over the class dimension.
- ``size_preds`` — ``(B, 2, H, W)`` — scaled duration (channel
0) and bandwidth (channel 1) predictions at each location.
- ``class_probs`` — ``(B, num_classes, H, W)`` — per-class
probabilities at each location.
- ``features`` — ``(B, C_out, H, W)`` — raw backbone feature
"""
features = self.backbone(spec)
classification = self.classifier_head(features)
@ -127,40 +135,46 @@ class Detector(DetectionModel):
def build_detector(
num_classes: int, config: Optional[BackboneConfig] = None
num_classes: int,
config: BackboneConfig | None = None,
backbone: BackboneModel | None = None,
) -> DetectionModel:
"""Build the complete BatDetect2 detection model.
"""Build a complete BatDetect2 detection model.
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
and a ``BBoxHead`` sized to the backbone's output channel count, and
returns them wrapped in a ``Detector``.
Parameters
----------
num_classes : int
The number of specific target classes the model should predict
(required for the `ClassifierHead`). Must be positive.
Number of target bat species or call types to predict. Must be
positive.
config : BackboneConfig, optional
Configuration object specifying the architecture of the backbone
(encoder, bottleneck, decoder). If None, default configurations defined
within the respective builder functions (`build_encoder`, etc.) will be
used to construct a default backbone architecture.
Backbone architecture configuration. Defaults to
``UNetBackboneConfig()`` (the standard BatDetect2 architecture) if
not provided.
Returns
-------
DetectionModel
An initialized `Detector` model instance.
An initialised ``Detector`` instance ready for training or
inference.
Raises
------
ValueError
If `num_classes` is not positive, or if errors occur during the
construction of the backbone or detector components (e.g., incompatible
configurations, invalid parameters).
If ``num_classes`` is not positive, or if the backbone
configuration is invalid.
"""
config = config or BackboneConfig()
if backbone is None:
config = config or UNetBackboneConfig()
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(), # type: ignore
)
backbone = build_backbone(config=config)
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(),
)
backbone = build_backbone(config=config)
classifier_head = ClassifierHead(
num_classes=num_classes,
in_channels=backbone.out_channels,

View File

@ -1,26 +1,27 @@
"""Constructs the Encoder part of a configurable neural network backbone.
"""Encoder (downsampling path) for the BatDetect2 backbone.
This module defines the configuration structure (`EncoderConfig`) and provides
the `Encoder` class (an `nn.Module`) along with a factory function
(`build_encoder`) to create sequential encoders. Encoders typically form the
downsampling path in architectures like U-Nets, processing input feature maps
(like spectrograms) to produce lower-resolution, higher-dimensionality feature
representations (bottleneck features).
This module defines ``EncoderConfig`` and the ``Encoder`` ``nn.Module``,
together with the ``build_encoder`` factory function.
The encoder is built dynamically by stacking neural network blocks based on a
list of configuration objects provided in `EncoderConfig.layers`. Each
configuration object specifies the type of block (e.g., standard convolution,
coordinate-feature convolution with downsampling) and its parameters
(e.g., output channels). This allows for flexible definition of encoder
architectures via configuration files.
In a U-Net-style network the encoder progressively reduces the spatial
resolution of the spectrogram whilst increasing the number of feature
channels. Each layer in the encoder produces a feature map that is stored
for use as a skip connection in the corresponding decoder layer.
The `Encoder`'s `forward` method returns outputs from all intermediate layers,
suitable for skip connections, while the `encode` method returns only the final
bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
provided.
The encoder is fully configurable: the type, number, and parameters of the
downsampling blocks are described by an ``EncoderConfig`` object containing
an ordered list of block configuration objects (see ``batdetect2.models.blocks``
for available block types).
``Encoder.forward`` returns the outputs of *all* encoder layers as a list,
so that skip connections are available to the decoder.
``Encoder.encode`` returns only the final output (the input to the bottleneck).
A default configuration ``DEFAULT_ENCODER_CONFIG`` is provided and used by
``build_encoder`` when no explicit configuration is supplied.
"""
from typing import Annotated, List, Optional, Union
from typing import Annotated, List
import torch
from pydantic import Field
@ -32,7 +33,7 @@ from batdetect2.models.blocks import (
FreqCoordConvDownConfig,
LayerGroupConfig,
StandardConvDownConfig,
build_layer_from_config,
build_layer,
)
__all__ = [
@ -43,47 +44,42 @@ __all__ = [
]
EncoderLayerConfig = Annotated[
Union[
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
LayerGroupConfig,
],
ConvConfig
| FreqCoordConvDownConfig
| StandardConvDownConfig
| LayerGroupConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Encoder."""
class EncoderConfig(BaseConfig):
"""Configuration for building the sequential Encoder module.
Defines the sequence of neural network blocks that constitute the encoder
(downsampling path).
"""Configuration for the sequential ``Encoder`` module.
Attributes
----------
layers : List[EncoderLayerConfig]
An ordered list of configuration objects, each defining one layer or
block in the encoder sequence. Each item must be a valid block config
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
`StandardConvDownConfig`) including a `name` field and necessary
parameters like `out_channels`. Input channels for each layer are
inferred sequentially. The list must contain at least one layer.
Ordered list of block configuration objects defining the encoder's
downsampling stages. Each entry specifies the block type (via its
``name`` field) and any block-specific parameters such as
``out_channels``. Input channels for each block are inferred
automatically from the output of the previous block. Must contain
at least one entry.
"""
layers: List[EncoderLayerConfig] = Field(min_length=1)
class Encoder(nn.Module):
"""Sequential Encoder module composed of configurable downscaling layers.
"""Sequential encoder module composed of configurable downsampling layers.
Constructs the downsampling path of an encoder-decoder network by stacking
multiple downscaling blocks.
Executes a series of downsampling blocks in order, storing the output of
each block so that it can be passed as a skip connection to the
corresponding decoder layer.
The `forward` method executes the sequence and returns the output feature
map from *each* downscaling stage, facilitating the implementation of skip
connections in U-Net-like architectures. The `encode` method returns only
the final output tensor (bottleneck features).
``forward`` returns the outputs of *all* layers (useful when skip
connections are needed). ``encode`` returns only the final output
(the input to the bottleneck).
Attributes
----------
@ -91,14 +87,14 @@ class Encoder(nn.Module):
Number of channels expected in the input tensor.
input_height : int
Height (frequency bins) expected in the input tensor.
output_channels : int
Number of channels in the final output tensor (bottleneck).
out_channels : int
Number of channels in the final output tensor (bottleneck input).
output_height : int
Height (frequency bins) expected in the output tensor.
Height (frequency bins) of the final output tensor.
layers : nn.ModuleList
The sequence of instantiated downscaling layer modules.
Sequence of instantiated downsampling block modules.
depth : int
The number of downscaling layers in the encoder.
Number of downsampling layers.
"""
def __init__(
@ -109,23 +105,22 @@ class Encoder(nn.Module):
input_height: int = 128,
in_channels: int = 1,
):
"""Initialize the Encoder module.
"""Initialise the Encoder module.
Note: This constructor is typically called internally by the
`build_encoder` factory function, which prepares the `layers` list.
This constructor is typically called by the ``build_encoder`` factory
function, which takes care of building the ``layers`` list from a
configuration object.
Parameters
----------
output_channels : int
Number of channels produced by the final layer.
output_height : int
The expected height of the output tensor.
Height of the output tensor after all layers have been applied.
layers : List[nn.Module]
A list of pre-instantiated downscaling layer modules (e.g.,
`StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
sequence.
Pre-built downsampling block modules in execution order.
input_height : int, default=128
Expected height of the input tensor.
Expected height of the input tensor (frequency bins).
in_channels : int, default=1
Expected number of channels in the input tensor.
"""
@ -140,29 +135,30 @@ class Encoder(nn.Module):
self.depth = len(self.layers)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Pass input through encoder layers, returns all intermediate outputs.
"""Pass input through all encoder layers and return every output.
This method is typically used when the Encoder is part of a U-Net or
similar architecture requiring skip connections.
Used when skip connections are needed (e.g. in a U-Net decoder).
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
`self.in_channels` and `H_in` must match `self.input_height`.
Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
``C_in`` must match ``self.in_channels`` and ``H_in`` must
match ``self.input_height``.
Returns
-------
List[torch.Tensor]
A list containing the output tensors from *each* downscaling layer
in the sequence. `outputs[0]` is the output of the first layer,
`outputs[-1]` is the final output (bottleneck) of the encoder.
Output tensors from every layer in order.
``outputs[0]`` is the output of the first (shallowest) layer;
``outputs[-1]`` is the output of the last (deepest) layer,
which serves as the input to the bottleneck.
Raises
------
ValueError
If input tensor channel count or height does not match expected
values.
If the input channel count or height does not match the
expected values.
"""
if x.shape[1] != self.in_channels:
raise ValueError(
@ -185,28 +181,29 @@ class Encoder(nn.Module):
return outputs
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Pass input through encoder layers, returning only the final output.
"""Pass input through all encoder layers and return only the final output.
This method provides access to the bottleneck features produced after
the last downscaling layer.
Use this when skip connections are not needed and you only require
the bottleneck feature map.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`. Must match expected
`in_channels` and `input_height`.
Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
Must satisfy the same shape requirements as ``forward``.
Returns
-------
torch.Tensor
The final output tensor (bottleneck features) from the last layer
of the encoder. Shape `(B, C_out, H_out, W_out)`.
Output of the last encoder layer, shape
``(B, C_out, H_out, W)``, where ``C_out`` is
``self.out_channels`` and ``H_out`` is ``self.output_height``.
Raises
------
ValueError
If input tensor channel count or height does not match expected
values.
If the input channel count or height does not match the
expected values.
"""
if x.shape[1] != self.in_channels:
raise ValueError(
@ -238,58 +235,57 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
),
],
)
"""Default configuration for the Encoder.
"""Default encoder configuration used in standard BatDetect2 models.
Specifies an architecture typically used in BatDetect2:
- Input: 1 channel, 128 frequency bins.
- Layer 1: FreqCoordConvDown -> 32 channels, H=64
- Layer 2: FreqCoordConvDown -> 64 channels, H=32
- Layer 3: FreqCoordConvDown -> 128 channels, H=16
- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck)
Assumes a 1-channel input with 128 frequency bins and produces the
following feature maps:
- Stage 1 (``FreqCoordConvDown``): 32 channels, height 64.
- Stage 2 (``FreqCoordConvDown``): 64 channels, height 32.
- Stage 3 (``LayerGroup``):
- ``FreqCoordConvDown``: 128 channels, height 16.
- ``ConvBlock``: 256 channels, height 16 (bottleneck input).
"""
def build_encoder(
in_channels: int,
input_height: int,
config: Optional[EncoderConfig] = None,
config: EncoderConfig | None = None,
) -> Encoder:
"""Factory function to build an Encoder instance from configuration.
"""Build an ``Encoder`` from configuration.
Constructs a sequential `Encoder` module based on the layer sequence
defined in an `EncoderConfig` object and the provided input dimensions.
If no config is provided, uses the default layer sequence from
`DEFAULT_ENCODER_CONFIG`.
It iteratively builds the layers using the unified
`build_layer_from_config` factory (from `.blocks`), tracking the changing
number of channels and feature map height required for each subsequent
layer, especially for coordinate- aware blocks.
Constructs a sequential ``Encoder`` by iterating over the block
configurations in ``config.layers``, building each block with
``build_layer``, and tracking the channel count and feature-map height
as they change through the sequence.
Parameters
----------
in_channels : int
The number of channels expected in the input tensor to the encoder.
Must be > 0.
Number of channels in the input spectrogram tensor. Must be
positive.
input_height : int
The height (frequency bins) expected in the input tensor. Must be > 0.
Crucial for initializing coordinate-aware layers correctly.
Height (number of frequency bins) of the input spectrogram.
Must be positive and should be divisible by
``2 ** (number of downsampling stages)`` to avoid size mismatches
later in the network.
config : EncoderConfig, optional
The configuration object detailing the sequence of layers and their
parameters. If None, `DEFAULT_ENCODER_CONFIG` is used.
Configuration specifying the layer sequence. Defaults to
``DEFAULT_ENCODER_CONFIG`` if not provided.
Returns
-------
Encoder
An initialized `Encoder` module.
An initialised ``Encoder`` module.
Raises
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `name`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `name`.
If ``in_channels`` or ``input_height`` are not positive.
KeyError
If a layer configuration specifies an unknown block type.
"""
if in_channels <= 0 or input_height <= 0:
raise ValueError("in_channels and input_height must be positive.")
@ -302,12 +298,14 @@ def build_encoder(
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
layer = build_layer(
in_channels=current_channels,
input_height=current_height,
config=layer_config,
)
layers.append(layer)
current_height = layer.get_output_height(current_height)
current_channels = layer.out_channels
return Encoder(
input_height=input_height,

View File

@ -1,20 +1,19 @@
"""Prediction Head modules for BatDetect2 models.
"""Prediction heads attached to the backbone feature map.
This module defines simple `torch.nn.Module` subclasses that serve as
prediction heads, typically attached to the output feature map of a backbone
network
Each head is a lightweight ``torch.nn.Module`` that applies a 1×1
convolution to map backbone feature channels to one specific type of
output required by BatDetect2:
Each head is responsible for generating one specific type of output required
by the BatDetect2 task:
- `DetectorHead`: Predicts the probability of sound event presence.
- `ClassifierHead`: Predicts the probability distribution over target classes.
- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding
box.
- ``DetectorHead``: single-channel detection probability heatmap (sigmoid
activation).
- ``ClassifierHead``: multi-class probability map over the target bat
species / call types (softmax activation).
- ``BBoxHead``: two-channel map of predicted call duration (time axis) and
bandwidth (frequency axis) at each location (no activation; raw
regression output).
These heads use 1x1 convolutions to map the backbone feature channels
to the desired number of output channels for each prediction task at each
spatial location, followed by an appropriate activation function (e.g., sigmoid
for detection, softmax for classification, none for size regression).
All three heads share the same input feature map produced by the backbone,
so they can be evaluated in parallel in a single forward pass.
"""
import torch
@ -28,42 +27,35 @@ __all__ = [
class ClassifierHead(nn.Module):
"""Prediction head for multi-class classification probabilities.
"""Prediction head for species / call-type classification probabilities.
Takes an input feature map and produces a probability map where each
channel corresponds to a specific target class. It uses a 1x1 convolution
to map input channels to `num_classes + 1` outputs (one for each target
class plus an assumed background/generic class), applies softmax across the
channels, and returns the probabilities for the specific target classes
(excluding the last background/generic channel).
Takes a backbone feature map and produces a probability map where each
channel corresponds to a target class. Internally the 1×1 convolution
maps ``in_channels`` to ``num_classes + 1`` logits (the extra channel
represents a generic background / unknown category); a softmax is then
applied across the channel dimension and the background channel is
discarded before returning.
Parameters
----------
num_classes : int
The number of specific target classes the model should predict
(excluding any background or generic category). Must be positive.
Number of target classes (bat species or call types) to predict,
excluding the background category. Must be positive.
in_channels : int
Number of channels in the input feature map tensor from the backbone.
Must be positive.
Number of channels in the backbone feature map. Must be positive.
Attributes
----------
num_classes : int
Number of specific output classes.
Number of specific output classes (background excluded).
in_channels : int
Number of input channels expected.
classifier : nn.Conv2d
The 1x1 convolutional layer used for prediction.
Output channels = num_classes + 1.
Raises
------
ValueError
If `num_classes` or `in_channels` are not positive.
1×1 convolution with ``num_classes + 1`` output channels.
"""
def __init__(self, num_classes: int, in_channels: int):
"""Initialize the ClassifierHead."""
"""Initialise the ClassifierHead."""
super().__init__()
self.num_classes = num_classes
@ -76,20 +68,20 @@ class ClassifierHead(nn.Module):
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute class probabilities from input features.
"""Compute per-class probabilities from backbone features.
Parameters
----------
features : torch.Tensor
Input feature map tensor from the backbone, typically with shape
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Backbone feature map, shape ``(B, C_in, H, W)``.
Returns
-------
torch.Tensor
Class probability map tensor with shape `(B, num_classes, H, W)`.
Contains probabilities for the specific target classes after
softmax, excluding the implicit background/generic class channel.
Class probability map, shape ``(B, num_classes, H, W)``.
Values are softmax probabilities in the range [0, 1] and
sum to less than 1 per location (the background probability
is discarded).
"""
logits = self.classifier(features)
probs = torch.softmax(logits, dim=1)
@ -97,36 +89,30 @@ class ClassifierHead(nn.Module):
class DetectorHead(nn.Module):
"""Prediction head for sound event detection probability.
"""Prediction head for detection probability (is a call present here?).
Takes an input feature map and produces a single-channel heatmap where
each value represents the probability ([0, 1]) of a relevant sound event
(of any class) being present at that spatial location.
Produces a single-channel heatmap where each value indicates the
probability ([0, 1]) that a bat call of *any* species is present at
that time–frequency location in the spectrogram.
Uses a 1x1 convolution to map input channels to 1 output channel, followed
by a sigmoid activation function.
Applies a 1×1 convolution mapping ``in_channels`` → 1, followed by
sigmoid activation.
Parameters
----------
in_channels : int
Number of channels in the input feature map tensor from the backbone.
Must be positive.
Number of channels in the backbone feature map. Must be positive.
Attributes
----------
in_channels : int
Number of input channels expected.
detector : nn.Conv2d
The 1x1 convolutional layer mapping to a single output channel.
Raises
------
ValueError
If `in_channels` is not positive.
1×1 convolution with a single output channel.
"""
def __init__(self, in_channels: int):
"""Initialize the DetectorHead."""
"""Initialise the DetectorHead."""
super().__init__()
self.in_channels = in_channels
@ -138,62 +124,49 @@ class DetectorHead(nn.Module):
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute detection probabilities from input features.
"""Compute detection probabilities from backbone features.
Parameters
----------
features : torch.Tensor
Input feature map tensor from the backbone, typically with shape
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Backbone feature map, shape ``(B, C_in, H, W)``.
Returns
-------
torch.Tensor
Detection probability heatmap tensor with shape `(B, 1, H, W)`.
Values are in the range [0, 1] due to the sigmoid activation.
Raises
------
RuntimeError
If input channel count does not match `self.in_channels`.
Detection probability heatmap, shape ``(B, 1, H, W)``.
Values are in the range [0, 1].
"""
return torch.sigmoid(self.detector(features))
class BBoxHead(nn.Module):
"""Prediction head for bounding box size dimensions.
"""Prediction head for bounding box size (duration and bandwidth).
Takes an input feature map and produces a two-channel map where each
channel represents a predicted size dimension (typically width/duration and
height/bandwidth) for a potential sound event at that spatial location.
Produces a two-channel map where channel 0 predicts the scaled duration
(time-axis extent) and channel 1 predicts the scaled bandwidth
(frequency-axis extent) of the call at each spectrogram location.
Uses a 1x1 convolution to map input channels to 2 output channels. No
activation function is typically applied, as size prediction is often
treated as a direct regression task. The output values usually represent
*scaled* dimensions that need to be un-scaled during postprocessing.
Applies a 1×1 convolution mapping ``in_channels`` → 2 with no
activation function (raw regression output). The predicted values are
in a scaled space and must be converted to real units (seconds and Hz)
during postprocessing.
Parameters
----------
in_channels : int
Number of channels in the input feature map tensor from the backbone.
Must be positive.
Number of channels in the backbone feature map. Must be positive.
Attributes
----------
in_channels : int
Number of input channels expected.
bbox : nn.Conv2d
The 1x1 convolutional layer mapping to 2 output channels
(width, height).
Raises
------
ValueError
If `in_channels` is not positive.
1×1 convolution with 2 output channels (duration, bandwidth).
"""
def __init__(self, in_channels: int):
"""Initialize the BBoxHead."""
"""Initialise the BBoxHead."""
super().__init__()
self.in_channels = in_channels
@ -205,19 +178,19 @@ class BBoxHead(nn.Module):
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute predicted bounding box dimensions from input features.
"""Predict call duration and bandwidth from backbone features.
Parameters
----------
features : torch.Tensor
Input feature map tensor from the backbone, typically with shape
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Backbone feature map, shape ``(B, C_in, H, W)``.
Returns
-------
torch.Tensor
Predicted size tensor with shape `(B, 2, H, W)`. Channel 0 usually
represents scaled width, Channel 1 scaled height. These values
need to be un-scaled during postprocessing.
Size prediction tensor, shape ``(B, 2, H, W)``. Channel 0 is
the predicted scaled duration; channel 1 is the predicted
scaled bandwidth. Values must be rescaled to real units during
postprocessing.
"""
return self.bbox(features)

View File

@ -0,0 +1,90 @@
from abc import ABC, abstractmethod
from typing import NamedTuple, Protocol
import torch
__all__ = [
"BackboneModel",
"BlockProtocol",
"BottleneckProtocol",
"DecoderProtocol",
"DetectionModel",
"EncoderDecoderModel",
"EncoderProtocol",
"ModelOutput",
]
class BlockProtocol(Protocol):
    """Structural type for a single convolutional block in the backbone.

    Any object exposing matching channel attributes, a tensor
    ``__call__``, and ``get_output_height`` satisfies this protocol —
    no inheritance required.
    """

    # Channel count the block consumes / produces.
    in_channels: int
    out_channels: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...

    # Maps an input height to the block's output height.
    # NOTE(review): height presumably means frequency bins — confirm.
    def get_output_height(self, input_height: int) -> int: ...
def get_output_height(self, input_height: int) -> int: ...
class EncoderProtocol(Protocol):
    """Structural type for the encoder half of an encoder-decoder backbone.

    ``__call__`` returns a list of feature maps; presumably one per
    stage, to be fed back to the decoder as ``residuals``
    (see ``DecoderProtocol.__call__``) — TODO confirm ordering.
    """

    in_channels: int
    out_channels: int
    # Heights of the input to / output of the encoder.
    input_height: int
    output_height: int

    def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
class BottleneckProtocol(Protocol):
    """Structural type for the bottleneck between encoder and decoder."""

    in_channels: int
    out_channels: int
    input_height: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
class DecoderProtocol(Protocol):
    """Structural type for the decoder half of an encoder-decoder backbone."""

    in_channels: int
    out_channels: int
    input_height: int
    output_height: int
    # Number of decoder stages.
    # NOTE(review): presumably matches the number of ``residuals``
    # expected by ``__call__`` — confirm.
    depth: int

    def __call__(
        self,
        x: torch.Tensor,
        residuals: list[torch.Tensor],
    ) -> torch.Tensor: ...
class ModelOutput(NamedTuple):
    """Bundle of tensors produced by a detection model's forward pass."""

    # Detection probability heatmap.
    detection_probs: torch.Tensor
    # Predicted bounding-box size map.
    size_preds: torch.Tensor
    # Per-class probability map.
    class_probs: torch.Tensor
    # Backbone feature map the heads were computed from.
    features: torch.Tensor
class BackboneModel(ABC, torch.nn.Module):
    """Abstract base for feature-extraction backbones.

    Subclasses map an input spectrogram tensor to a feature map that
    downstream prediction heads consume.
    """

    # Expected height of the input spectrogram.
    # NOTE(review): presumably frequency bins — confirm.
    input_height: int
    # Channel count of the feature map returned by forward().
    out_channels: int

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Compute the feature map for ``spec``."""
        raise NotImplementedError
class EncoderDecoderModel(BackboneModel):
    """Backbone with separately callable encode / decode stages."""

    # Channel count at the bottleneck between encode() and decode().
    bottleneck_channels: int

    @abstractmethod
    def encode(self, spec: torch.Tensor) -> torch.Tensor: ...

    # NOTE(review): decode() takes only the encoded tensor, unlike
    # DecoderProtocol which also receives residuals — confirm intended.
    @abstractmethod
    def decode(self, encoded: torch.Tensor) -> torch.Tensor: ...
class DetectionModel(ABC, torch.nn.Module):
    """Abstract base for full detection models.

    Combines a backbone with prediction heads; ``forward`` returns a
    ``ModelOutput``.

    NOTE(review): only ``classifier_head`` and ``bbox_head`` are
    declared here even though ``ModelOutput`` carries
    ``detection_probs`` — a detector-head attribute may be missing
    from this interface; confirm.
    """

    backbone: BackboneModel
    classifier_head: torch.nn.Module
    bbox_head: torch.nn.Module

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> ModelOutput: ...

View File

@ -0,0 +1,35 @@
from batdetect2.outputs.config import OutputsConfig
from batdetect2.outputs.formats import (
BatDetect2OutputConfig,
OutputFormatConfig,
ParquetOutputConfig,
RawOutputConfig,
SoundEventOutputConfig,
build_output_formatter,
get_output_formatter,
load_predictions,
)
from batdetect2.outputs.transforms import (
OutputTransformConfig,
build_output_transform,
)
from batdetect2.outputs.types import (
OutputFormatterProtocol,
OutputTransformProtocol,
)
__all__ = [
"BatDetect2OutputConfig",
"OutputFormatConfig",
"OutputFormatterProtocol",
"OutputTransformConfig",
"OutputTransformProtocol",
"OutputsConfig",
"ParquetOutputConfig",
"RawOutputConfig",
"SoundEventOutputConfig",
"build_output_formatter",
"build_output_transform",
"get_output_formatter",
"load_predictions",
]

View File

@ -0,0 +1,15 @@
from pydantic import Field
from batdetect2.core.configs import BaseConfig
from batdetect2.outputs.formats import OutputFormatConfig
from batdetect2.outputs.formats.raw import RawOutputConfig
from batdetect2.outputs.transforms import OutputTransformConfig
__all__ = ["OutputsConfig"]
class OutputsConfig(BaseConfig):
    """Top-level configuration for prediction outputs.

    Pairs an output format with an output transform. Defaults to the
    raw format and default transform settings.
    """

    # Which serialization format to use; a discriminated union keyed
    # on each config's ``name`` field.
    format: OutputFormatConfig = Field(default_factory=RawOutputConfig)
    transform: OutputTransformConfig = Field(
        default_factory=OutputTransformConfig
    )

View File

@ -1,39 +1,42 @@
from typing import Annotated, Optional, Union
from typing import Annotated
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.data.predictions.base import (
from batdetect2.outputs.formats.base import (
OutputFormatterProtocol,
prediction_formatters,
output_formatters,
)
from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig
from batdetect2.data.predictions.raw import RawOutputConfig
from batdetect2.data.predictions.soundevent import SoundEventOutputConfig
from batdetect2.typing import TargetProtocol
from batdetect2.outputs.formats.batdetect2 import BatDetect2OutputConfig
from batdetect2.outputs.formats.parquet import ParquetOutputConfig
from batdetect2.outputs.formats.raw import RawOutputConfig
from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig
from batdetect2.targets.types import TargetProtocol
__all__ = [
"build_output_formatter",
"get_output_formatter",
"BatDetect2OutputConfig",
"OutputFormatConfig",
"ParquetOutputConfig",
"RawOutputConfig",
"SoundEventOutputConfig",
"build_output_formatter",
"get_output_formatter",
"load_predictions",
]
OutputFormatConfig = Annotated[
Union[
BatDetect2OutputConfig,
SoundEventOutputConfig,
RawOutputConfig,
],
BatDetect2OutputConfig
| ParquetOutputConfig
| SoundEventOutputConfig
| RawOutputConfig,
Field(discriminator="name"),
]
def build_output_formatter(
targets: Optional[TargetProtocol] = None,
config: Optional[OutputFormatConfig] = None,
targets: TargetProtocol | None = None,
config: OutputFormatConfig | None = None,
) -> OutputFormatterProtocol:
"""Construct the final output formatter."""
from batdetect2.targets import build_targets
@ -41,13 +44,13 @@ def build_output_formatter(
config = config or RawOutputConfig()
targets = targets or build_targets()
return prediction_formatters.build(config, targets)
return output_formatters.build(config, targets)
def get_output_formatter(
name: Optional[str] = None,
targets: Optional[TargetProtocol] = None,
config: Optional[OutputFormatConfig] = None,
name: str | None = None,
targets: TargetProtocol | None = None,
config: OutputFormatConfig | None = None,
) -> OutputFormatterProtocol:
"""Get the output formatter by name."""
@ -55,7 +58,7 @@ def get_output_formatter(
if name is None:
raise ValueError("Either config or name must be provided.")
config_class = prediction_formatters.get_config_type(name)
config_class = output_formatters.get_config_type(name)
config = config_class() # type: ignore
if config.name != name: # type: ignore
@ -68,9 +71,9 @@ def get_output_formatter(
def load_predictions(
path: PathLike,
format: Optional[str] = "raw",
config: Optional[OutputFormatConfig] = None,
targets: Optional[TargetProtocol] = None,
format: str | None = "raw",
config: OutputFormatConfig | None = None,
targets: TargetProtocol | None = None,
):
"""Load predictions from a file."""
from batdetect2.targets import build_targets

View File

@ -0,0 +1,46 @@
from pathlib import Path
from typing import Literal
from soundevent.data import PathLike
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.targets.types import TargetProtocol
__all__ = [
"OutputFormatterProtocol",
"PredictionFormatterImportConfig",
"make_path_relative",
"output_formatters",
]
def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
    """Express ``path`` relative to ``audio_dir``.

    Parameters
    ----------
    path : PathLike
        Path to an audio file. May be absolute or already relative.
    audio_dir : PathLike
        Root directory the returned path should be relative to.

    Returns
    -------
    Path
        ``path`` relative to ``audio_dir`` when ``path`` is absolute;
        otherwise ``path`` unchanged (as a ``Path``).

    Raises
    ------
    ValueError
        If ``path`` is absolute but not located under ``audio_dir``.
    """
    candidate = Path(path)
    base = Path(audio_dir)

    # A relative path is assumed to already be expressed relative to
    # the audio directory, so it passes through untouched.
    if not candidate.is_absolute():
        return candidate

    if candidate.is_relative_to(base):
        return candidate.relative_to(base)

    raise ValueError(
        f"Audio file {candidate} is not in audio_dir {base}"
    )
# Registry of output-formatter builders. Each registered builder
# receives a TargetProtocol and produces an OutputFormatterProtocol;
# entries are looked up by their config's ``name``.
output_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
    Registry(name="output_formatter")
)
# Registered into ``output_formatters`` so the "import" name is
# available alongside the built-in formatter configs.
@add_import_config(output_formatters)
class PredictionFormatterImportConfig(ImportConfig):
    """Use any callable as a prediction formatter.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    # Discriminator value selecting this config in the tagged union.
    name: Literal["import"] = "import"

Some files were not shown because too many files have changed in this diff Show More