Mirror of https://github.com/macaodha/batdetect2.git, synced 2026-04-04 23:30:21 +02:00

Compare commits: 73 commits, a4a5a10da1 ... 5a14b29281
Commits in range (newest first): 5a14b29281, 13a31d9de9, a1fad6d7d7, b8acd86c71, 875751d340, 32d8c4a9e5, fb3dc3eaf0, 0d90cb5cc3, 91806aa01e, 23ac619c50, 99b9e55c0e, 652670b01d, 0163a572cb, f0af5dd79e, 2f03abe8f6, 22a3d18d45, bf5b88016a, f9056eb19a, ebe7e134e9, 8e35956007, a332c5c3bd, 9fa703b34b, 0bf809e376, 6276a8884e, 7b1cb402b4, 45ae15eed5, b3af70761e, daff74fdde, 31d4f92359, d56b9f02ae, 573d8e38d6, 751be53edf, c226dc3f2b, 3b47c688dd, feee2bdfa3, 615c7d78fb, 56f6affc72, 65bd0dc6ae, 1a7c0b4b3a, 8ac4f4c44d, 038d58ed99, 47f418a63c, 197cc38e3e, e0503487ec, 4b7d23abde, 3c337a06cb, 0d590a26cc, 46c02962f3, bfc88a4a0f, ce952e364b, ef3348d651, 4207661da4, 45e3cf1434, f2d5088bec, 652d076b46, e393709258, 54605ef269, b8b8a68f49, 6812e1c515, 0b344003a1, c9bcaebcde, d52e988b8f, cce1b49a8d, 8313fe1484, 4509602e70, 0adb58e039, 531ff69974, 750f9e43c4, f71fe0c2e2, 113f438e74, 2563f26ed3, 9c72537ddd, 72278d75ec
.gitignore (vendored, 7 changed lines)

@@ -102,7 +102,7 @@ experiments/*
 DvcLiveLogger/checkpoints
 logs/
 mlruns/
-outputs/
+/outputs/
 notebooks/lightning_logs

 # Jupiter notebooks
@@ -123,3 +123,8 @@ example_data/preprocessed

 # Dev notebooks
 notebooks/tmp
+/tmp
+/.agents/skills
+/notebooks
+/AGENTS.md
+/scripts
docs/source/architecture.md (new file, 93 lines)

# BatDetect2 Architecture Overview

This document maps the `batdetect2` codebase architecture. It is intended as a deep-dive reference for developers, agents, and contributors navigating the project.

`batdetect2` is designed as a modular deep learning pipeline for detecting and classifying bat echolocation calls in high-frequency audio recordings. It builds on **PyTorch**, **PyTorch Lightning** for training, and the **Soundevent** library for standardized audio and geometry data classes.

The repository follows a configuration-driven design, relying on `pydantic`/`omegaconf` (via `BaseConfig`) and the Factory/Registry patterns for dependency injection and modularity. The entire pipeline can be orchestrated through the high-level `BatDetect2API` (`src/batdetect2/api_v2.py`).

---

## 1. Data Flow Pipeline

The standard lifecycle of a prediction request passes through these sequential stages, each handled by an isolated, replaceable module:

1. **Audio loading (`batdetect2.audio`)**: Reads raw `.wav` files into standard NumPy arrays or `soundevent.data.Clip` objects. Handles resampling.
2. **Preprocessing (`batdetect2.preprocess`)**: Converts raw 1D waveforms into 2D spectrogram tensors.
3. **Forward pass (`batdetect2.models`)**: A PyTorch neural network processes the spectrogram and outputs dense prediction tensors (e.g., detection heatmaps, bounding box sizes, class probabilities).
4. **Postprocessing (`batdetect2.postprocess`)**: Decodes the raw output tensors back into explicit bounding box geometries and runs Non-Maximum Suppression (NMS) to filter redundant predictions.
5. **Formatting (`batdetect2.data`)**: Transforms the predictions into standard formats (`.csv`, `.json`, `.parquet`) via `OutputFormatterProtocol`.

---

## 2. Core Modules Breakdown

### 2.1 Audio and Preprocessing

- **`audio/`**:
  - Centralizes audio I/O through `AudioLoader`. It abstracts over the `soundevent` library, efficiently handling full `Recording` files or smaller `Clip` segments and standardizing the sample rate.
- **`preprocess/`**:
  - Governed by the `PreprocessorProtocol`.
  - Its primary responsibility is spectrogram generation via the Short-Time Fourier Transform (STFT).
  - During training, it incorporates data augmentation layers (e.g., amplitude scaling, time masking, frequency masking, spectral mean subtraction) configured via `PreprocessingConfig`.
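The STFT parameters determine the spectrogram shape. A minimal NumPy sketch of the windowing arithmetic, assuming the `window_duration`/`window_overlap` semantics shown in the example config (the real preprocessor is torch-based and configured via `PreprocessingConfig`; `spectrogram` here is an illustrative stand-in, not the library API):

```python
import numpy as np

def spectrogram(wave: np.ndarray, samplerate: int,
                window_duration: float = 0.002,
                window_overlap: float = 0.75) -> np.ndarray:
    win = int(window_duration * samplerate)  # 512 samples at 256 kHz
    hop = int(win * (1 - window_overlap))    # 128-sample hop at 75% overlap
    window = np.hanning(win)
    frames = [wave[i:i + win] * window
              for i in range(0, len(wave) - win + 1, hop)]
    # Magnitude of the one-sided FFT: shape (freq_bins, time_frames)
    return np.abs(np.fft.rfft(np.stack(frames), axis=1)).T

wave = np.random.default_rng(0).normal(size=256_000)  # 1 s at 256 kHz
spec = spectrogram(wave, samplerate=256_000)          # (257, 1997)
```

With a 2 ms window at 256 kHz this yields 257 frequency bins per frame, which the `size` settings (`height`, `resize_factor`) then crop/resize to the model's input height.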
### 2.2 Deep Learning Models (`models/`)

The `models` directory contains all PyTorch neural network architectures. The default architecture is an encoder-decoder (U-Net style) network.

- **`blocks.py`**: Reusable neural network blocks, including standard convolutions (`ConvBlock`) and specialized layers such as `FreqCoordConvDownBlock`/`FreqCoordConvUpBlock`, which append normalized frequency coordinates so that convolutional filters become explicitly frequency-aware.
- **`encoder.py`**: The downsampling path (feature extraction). Builds a sequential list of blocks and captures skip connections.
- **`bottleneck.py`**: The deepest, lowest-resolution segment connecting the encoder and decoder. Features an optional `SelfAttention` mechanism to weigh global temporal context.
- **`decoder.py`**: The upsampling path (reconstruction), integrating the skip connections from the encoder.
- **`heads.py`**: Heads attach to the backbone's feature map to output specific predictions:
  - `BBoxHead`: Predicts bounding box sizes.
  - `ClassifierHead`: Predicts species classes.
  - `DetectorHead`: Predicts detection probability heatmaps.
- **`backbones.py` & `detectors.py`**: Assemble the encoder, bottleneck, decoder, and heads into a cohesive `Detector` model.
- **`__init__.py:Model`**: The overarching wrapper `torch.nn.Module` containing the `detector`, `preprocessor`, `postprocessor`, and `targets`.
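The coord-conv idea behind `FreqCoordConvDownBlock`/`FreqCoordConvUpBlock` can be sketched in a few lines: an extra channel holding each row's normalized frequency position is concatenated to the input before the convolution, so filters can condition on absolute frequency. This is a NumPy illustration of the concept only (the real blocks are `torch.nn.Module`s; `add_freq_coords` is a hypothetical name):

```python
import numpy as np

def add_freq_coords(x: np.ndarray) -> np.ndarray:
    """x: (channels, freq, time) -> (channels + 1, freq, time)."""
    _, n_freq, n_time = x.shape
    # One value per frequency row, 0.0 at the lowest row, 1.0 at the highest.
    coords = np.linspace(0.0, 1.0, n_freq)[:, None]          # (freq, 1)
    coord_channel = np.broadcast_to(coords, (n_freq, n_time))  # (freq, time)
    return np.concatenate([x, coord_channel[None]], axis=0)

spec = np.zeros((1, 128, 512))     # (channels, freq, time)
out = add_freq_coords(spec)        # (2, 128, 512)
```

This matters for bat calls because call shape is strongly tied to absolute frequency, which plain (translation-equivariant) convolutions cannot see.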
### 2.3 Targets and Regions of Interest (`targets/`)

Crucial for training, this module translates physical annotations (Regions of Interest, ROIs) into training targets (tensors).

- **`rois.py`**: Implements `ROITargetMapper`, which maps a geometric bounding box to a 2D reference `Position` (time, frequency) and a `Size` array. Strategies include:
  - `AnchorBBoxMapper`: Maps based on a fixed bounding box corner/center.
  - `PeakEnergyBBoxMapper`: Locates the coordinate of peak acoustic energy inside the bounding box and computes offsets to the box edges.
- **`targets.py`**: Builds the complete multi-channel target heatmaps and coordinate tensors from the ROIs, used to compute losses during training.
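As a rough sketch of the `AnchorBBoxMapper` idea, assuming a `(start_time, low_freq, end_time, high_freq)` box and the bottom-left corner as the anchor (both assumptions for illustration; the actual mapper supports configurable anchors and the function name is hypothetical):

```python
def anchor_bbox_to_target(bbox):
    """Reduce a time/frequency box to a reference position plus a size array."""
    start_time, low_freq, end_time, high_freq = bbox
    position = (start_time, low_freq)                       # anchor point
    size = (end_time - start_time, high_freq - low_freq)    # (duration, bandwidth)
    return position, size

# A 20 ms call sweeping from 30 kHz to 60 kHz, starting at t = 0.10 s:
pos, size = anchor_bbox_to_target((0.10, 30_000.0, 0.12, 60_000.0))
```

The position selects where the target heatmap gets its peak; the size array becomes the regression target for `BBoxHead` at that location.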
### 2.4 Postprocessing (`postprocess/`)

- Implements `PostprocessorProtocol`.
- Reverses the logic in `targets`: it scans the model's output detection heatmaps for peaks, extracts the predicted sizes and class probabilities at those peaks, and decodes them back into physical `soundevent.data.Geometry` bounding boxes.
- Applies Non-Maximum Suppression (NMS), configured via `PostprocessConfig`, to remove highly overlapping predictions.
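Heatmap NMS of this kind is commonly implemented as a local-maximum filter: a cell survives if it exceeds the detection threshold and equals the maximum within the NMS kernel centered on it. A pure-NumPy sketch (the real postprocessor works on torch tensors; the defaults mirror the example config's `nms_kernel_size: 9` and `detection_threshold: 0.01`, but the function itself is illustrative):

```python
import numpy as np

def nms_peaks(heatmap: np.ndarray, kernel: int = 9,
              threshold: float = 0.01) -> list[tuple[int, int]]:
    pad = kernel // 2
    padded = np.pad(heatmap, pad, constant_values=-np.inf)
    peaks = []
    for i in range(heatmap.shape[0]):
        for j in range(heatmap.shape[1]):
            # The kernel x kernel neighborhood of (i, j) in the original grid.
            local = padded[i:i + kernel, j:j + kernel]
            if heatmap[i, j] >= threshold and heatmap[i, j] == local.max():
                peaks.append((i, j))
    return peaks

h = np.zeros((32, 32))
h[10, 10] = 0.9
h[10, 12] = 0.5   # suppressed: inside the 9x9 window of the stronger peak
h[25, 25] = 0.7
peaks = nms_peaks(h)  # [(10, 10), (25, 25)]
```

Production code replaces the double loop with a max-pooling pass (`F.max_pool2d` with `stride=1` and same padding, then an equality mask), which is the same computation vectorized.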
### 2.5 Data Management (`data/`)

- **`annotations/`**: Utilities to load dataset annotations in multiple standardized schemas (`AOEF`, `BatDetect2` formats).
- **`datasets.py`**: Aggregates recordings and annotations into memory.
- **`predictions/`**: Handles exporting model results via `OutputFormatterProtocol`. Includes formatters for `RawOutput`, `.parquet`, `.json`, etc.

### 2.6 Evaluation (`evaluate/`)

- Computes scientific metrics through `EvaluatorProtocol`.
- Provides dedicated evaluation setups for tasks such as `Clip Classification`, `Clip Detection`, and `Top Class` predictions.
- Generates precision-recall curves and scatter plots.

### 2.7 Training (`train/`)

- Implements the distributed PyTorch training loop via PyTorch Lightning.
- **`lightning.py`**: Contains `TrainingModule`, the `LightningModule` that orchestrates the optimizer, learning rate scheduler, forward passes, and backpropagation against the generated `targets`.
---

## 3. Interfaces and Tooling

### 3.1 APIs

- **`api_v2.py` (`BatDetect2API`)**: The modern API object. It is deeply integrated with dependency injection via `BatDetect2Config`: it instantiates the loader, targets, preprocessor, postprocessor, and model, and exposes easy-to-use methods such as `process_file`, `evaluate`, and `train`.
- **`api.py`**: The legacy API, kept for backwards compatibility. Uses hardcoded default instances rather than configuration objects.

### 3.2 Command Line Interface (`cli/`)

- Implements terminal commands using `click`. Commands include `batdetect2 detect`, `evaluate`, and `train`.

### 3.3 Core and Configuration (`core/`, `config.py`)

- **`core/registries.py`**: A string-based registry pattern (e.g., `block_registry`, `roi_mapper_registry`) that lets developers swap components (such as a custom neural network block) via configuration files without modifying Python code.
- **`config.py`**: Aggregates all modular `BaseConfig` objects (`AudioConfig`, `PreprocessingConfig`, `BackboneConfig`) into the monolithic `BatDetect2Config`.
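A minimal string-keyed registry in this spirit (the names and shapes here are illustrative, not the actual `core/registries.py` implementation):

```python
class Registry:
    """Map string names to component classes for config-driven construction."""

    def __init__(self):
        self._items = {}

    def register(self, name):
        def wrap(cls):
            self._items[name] = cls
            return cls
        return wrap

    def build(self, name, **kwargs):
        return self._items[name](**kwargs)

block_registry = Registry()

@block_registry.register("ConvBlock")
class ConvBlock:
    def __init__(self, out_channels: int):
        self.out_channels = out_channels

# Config-driven construction from e.g. {"name": "ConvBlock", "out_channels": 32}:
block = block_registry.build("ConvBlock", out_channels=32)
```

This is why the example config can list encoder layers as `- name: FreqCoordConvDown` entries: each `name` is looked up in a registry and instantiated with the remaining keys as constructor arguments.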
---

## Summary

To navigate this codebase effectively:

1. Follow **`api_v2.py`** to see how high-level operations invoke individual components.
2. Lean on the typed **Protocols** in each subsystem's `types.py` module (for example `src/batdetect2/preprocess/types.py` and `src/batdetect2/postprocess/types.py`) to understand inputs and outputs without reading each implementation.
3. Remember that data flows as `soundevent` primitives externally and as pure `torch.Tensor` values internally through the network.
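A sketch of what that protocol-driven typing buys: callers depend only on a structural interface, so any conforming object can be swapped in. The `SpectrogramBuilder` protocol below is illustrative (the real protocols live in the `types.py` modules and will differ in signature):

```python
from typing import Protocol
import numpy as np

class SpectrogramBuilder(Protocol):
    """Anything callable as (wave, samplerate) -> 2D array conforms."""
    def __call__(self, wave: np.ndarray, samplerate: int) -> np.ndarray: ...

def extract_features(preprocessor: SpectrogramBuilder,
                     wave: np.ndarray, samplerate: int) -> np.ndarray:
    # Typed against the protocol, not a concrete preprocessor class.
    return preprocessor(wave, samplerate)

def dummy_preprocessor(wave: np.ndarray, samplerate: int) -> np.ndarray:
    # Trivial stand-in: reshape the waveform into a one-row "spectrogram".
    return wave.reshape(1, -1)

features = extract_features(dummy_preprocessor, np.zeros(16), 256_000)
```

Because the check is structural, `dummy_preprocessor` satisfies the protocol without inheriting from anything, which is what makes components replaceable per stage.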
@@ -6,6 +6,7 @@ Hi!

    :maxdepth: 1
    :caption: Contents:

+   architecture
    data/index
    preprocessing/index
    postprocessing
@@ -1,70 +1,79 @@

Before:

config_version: v1

audio:
  samplerate: 256000
  resample:
    enabled: True
    method: "poly"

preprocess:
  stft:
    window_duration: 0.002
    window_overlap: 0.75
    window_fn: hann
  frequencies:
    max_freq: 120000
    min_freq: 10000
  size:
    height: 128
    resize_factor: 0.5
  spectrogram_transforms:
    - name: pcen
      time_constant: 0.1
      gain: 0.98
      bias: 2
      power: 0.5
    - name: spectral_mean_substraction

postprocess:
  nms_kernel_size: 9
  detection_threshold: 0.01
  top_k_per_sec: 200
  enabled: true
  method: poly

model:
  input_height: 128
  in_channels: 1
  out_channels: 32
  encoder:
    layers:
      - name: FreqCoordConvDown
        out_channels: 32
      - name: FreqCoordConvDown
        out_channels: 64
      - name: LayerGroup
        layers:
          - name: FreqCoordConvDown
            out_channels: 128
          - name: ConvBlock
            out_channels: 256
  bottleneck:
    channels: 256
    layers:
      - name: SelfAttention
        attention_channels: 256
  decoder:
    layers:
      - name: FreqCoordConvUp
        out_channels: 64
      - name: FreqCoordConvUp
        out_channels: 32
      - name: LayerGroup
        layers:
          - name: FreqCoordConvUp
            out_channels: 32
          - name: ConvBlock
            out_channels: 32

After:

samplerate: 256000

preprocess:
  stft:
    window_duration: 0.002
    window_overlap: 0.75
    window_fn: hann
  frequencies:
    max_freq: 120000
    min_freq: 10000
  size:
    height: 128
    resize_factor: 0.5
  spectrogram_transforms:
    - name: pcen
      time_constant: 0.1
      gain: 0.98
      bias: 2
      power: 0.5
    - name: spectral_mean_subtraction

architecture:
  name: UNetBackbone
  input_height: 128
  in_channels: 1
  encoder:
    layers:
      - name: FreqCoordConvDown
        out_channels: 32
      - name: FreqCoordConvDown
        out_channels: 64
      - name: LayerGroup
        layers:
          - name: FreqCoordConvDown
            out_channels: 128
          - name: ConvBlock
            out_channels: 256
  bottleneck:
    channels: 256
    layers:
      - name: SelfAttention
        attention_channels: 256
  decoder:
    layers:
      - name: FreqCoordConvUp
        out_channels: 64
      - name: FreqCoordConvUp
        out_channels: 32
      - name: LayerGroup
        layers:
          - name: FreqCoordConvUp
            out_channels: 32
          - name: ConvBlock
            out_channels: 32

postprocess:
  nms_kernel_size: 9
  detection_threshold: 0.01
  top_k_per_sec: 200

train:
  optimizer:
    name: adam
    learning_rate: 0.001

  scheduler:
    name: cosine_annealing
    t_max: 100

labels:

@@ -76,10 +85,7 @@ train:

  train_loader:
    batch_size: 8
    num_workers: 2
-   shuffle: True
+   shuffle: true

  clipping_strategy:
    name: random_subclip

@@ -115,7 +121,6 @@ train:

    max_masks: 3

  val_loader:
    num_workers: 2
    clipping_strategy:
      name: whole_audio_padded
      chunk_size: 0.256

@@ -134,9 +139,6 @@ train:

  size:
    weight: 0.1

  logger:
    name: csv

  validation:
    tasks:
      - name: sound_event_detection

@@ -146,6 +148,10 @@ train:

  metrics:
    - name: average_precision

  logging:
    train:
      name: csv

  evaluation:
    tasks:
      - name: sound_event_detection
justfile (53 changed lines)

@@ -14,60 +14,67 @@ HTML_COVERAGE_DIR := "htmlcov"

 help:
     @just --list

 install:
     uv sync

 # Testing & Coverage
 # Run tests using pytest.
 test:
-    pytest {{TESTS_DIR}}
+    uv run pytest {{TESTS_DIR}}

 # Run tests and generate coverage data.
 coverage:
-    pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}
+    uv run pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}

 # Generate an HTML coverage report.
 coverage-html: coverage
     @echo "Generating HTML coverage report..."
-    coverage html -d {{HTML_COVERAGE_DIR}}
+    uv run coverage html -d {{HTML_COVERAGE_DIR}}
     @echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/"

 # Serve the HTML coverage report locally.
 coverage-serve: coverage-html
     @echo "Serving report at http://localhost:8000/ ..."
-    python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000
+    uv run python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000

 # Documentation
 # Build documentation using Sphinx.
 docs:
-    sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
+    uv run sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}

 # Serve documentation with live reload.
 docs-serve:
-    sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
+    uv run sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser

 # Formatting & Linting
 # Format code using ruff.
-format:
-    ruff format {{PYTHON_DIRS}}
-
-# Check code formatting using ruff.
-format-check:
-    ruff format --check {{PYTHON_DIRS}}
-
-# Lint code using ruff.
-lint:
-    ruff check {{PYTHON_DIRS}}
+fix-format:
+    uv run ruff format {{PYTHON_DIRS}}

 # Lint code using ruff and apply automatic fixes.
-lint-fix:
-    ruff check --fix {{PYTHON_DIRS}}
+fix-lint:
+    uv run ruff check --fix {{PYTHON_DIRS}}

+# Combined Formatting & Linting
+fix: fix-format fix-lint

+# Checking tasks
+# Check code formatting using ruff.
+check-format:
+    uv run ruff format --check {{PYTHON_DIRS}}

+# Lint code using ruff.
+check-lint:
+    uv run ruff check {{PYTHON_DIRS}}

 # Type Checking
-# Type check code using pyright.
-typecheck:
-    pyright {{PYTHON_DIRS}}
+# Type check code using ty.
+check-types:
+    uv run ty check {{PYTHON_DIRS}}

 # Combined Checks
-# Run all checks (format-check, lint, typecheck).
-check: format-check lint typecheck test
+check: check-format check-lint check-types

 # Cleaning tasks
 # Remove Python bytecode and cache.

@@ -95,7 +102,7 @@ clean: clean-build clean-pyc clean-test clean-docs

 # Train on example data.
 example-train OPTIONS="":
-    batdetect2 train \
+    uv run batdetect2 train \
     --val-dataset example_data/dataset.yaml \
     --config example_data/config.yaml \
     {{OPTIONS}} \
File diff suppressed because one or more lines are too long (3 files)
@ -1,440 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "cfb0b360-a204-4c27-a18f-3902e8758879",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-11-19T17:33:02.699871Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:02.699590Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:02.710312Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:02.709798Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:02.699839Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%load_ext autoreload\n",
|
||||
"%autoreload 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "326c5432-94e6-4abf-a332-fe902559461b",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-11-19T17:33:02.711324Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:02.711067Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.092380Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.091830Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:02.711304Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from pathlib import Path\n",
|
||||
"from typing import List, Optional\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"import pytorch_lightning as pl\n",
|
||||
"from batdetect2.train.modules import DetectorModel\n",
|
||||
"from batdetect2.train.augmentations import (\n",
|
||||
" add_echo,\n",
|
||||
" select_random_subclip,\n",
|
||||
" warp_spectrogram,\n",
|
||||
")\n",
|
||||
"from batdetect2.train.dataset import LabeledDataset, get_files\n",
|
||||
"from batdetect2.train.preprocess import PreprocessingConfig\n",
|
||||
"from soundevent import data\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"from soundevent.types import ClassMapper\n",
|
||||
"from torch.utils.data import DataLoader"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "9402a473-0b25-4123-9fa8-ad1f71a4237a",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-11-18T22:39:12.395329Z",
|
||||
"iopub.status.busy": "2024-11-18T22:39:12.393444Z",
|
||||
"iopub.status.idle": "2024-11-18T22:39:12.405938Z",
|
||||
"shell.execute_reply": "2024-11-18T22:39:12.402980Z",
|
||||
"shell.execute_reply.started": "2024-11-18T22:39:12.395236Z"
|
||||
}
|
||||
},
|
||||
"source": [
|
||||
"## Training Datasets"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "cfd97d83-8c2b-46c8-9eae-cea59f53bc61",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.093487Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.092990Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.121636Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.121143Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.093459Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_dir = Path.cwd().parent / \"example_data\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "d5131ae9-2efd-4758-b6e5-189a6d90789b",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.122685Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.122270Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.151386Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.150788Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.122661Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"files = get_files(data_dir / \"preprocessed\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "bc733d3d-7829-4e90-896d-a0dc76b33288",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.152327Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.152060Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.184041Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.183372Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.152305Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataset = LabeledDataset(files)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "dfbb94ab-7b12-4689-9c15-4dc34cd17cb2",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.186393Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.186117Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.220175Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.219322Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.186375Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_dataloader = DataLoader(\n",
|
||||
" train_dataset,\n",
|
||||
" shuffle=True,\n",
|
||||
" batch_size=32,\n",
|
||||
" num_workers=4,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "e2eedaa9-6be3-481a-8786-7618515d98f8",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.221653Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.221242Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.260977Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.260375Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.221616Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# List of all possible classes\n",
|
||||
"class Mapper(ClassMapper):\n",
|
||||
" class_labels = [\n",
|
||||
" \"Eptesicus serotinus\",\n",
|
||||
" \"Myotis mystacinus\",\n",
|
||||
" \"Pipistrellus pipistrellus\",\n",
|
||||
" \"Rhinolophus ferrumequinum\",\n",
|
||||
" ]\n",
|
||||
"\n",
|
||||
" def encode(self, x: data.SoundEventAnnotation) -> Optional[str]:\n",
|
||||
" event_tag = data.find_tag(x.tags, \"event\")\n",
|
||||
"\n",
|
||||
" if event_tag.value == \"Social\":\n",
|
||||
" return \"social\"\n",
|
||||
"\n",
|
||||
" if event_tag.value != \"Echolocation\":\n",
|
||||
" # Ignore all other types of calls\n",
|
||||
" return None\n",
|
||||
"\n",
|
||||
" species_tag = data.find_tag(x.tags, \"class\")\n",
|
||||
" return species_tag.value\n",
|
||||
"\n",
|
||||
" def decode(self, class_name: str) -> List[data.Tag]:\n",
|
||||
" if class_name == \"social\":\n",
|
||||
" return [data.Tag(key=\"event\", value=\"social\")]\n",
|
||||
"\n",
|
||||
" return [data.Tag(key=\"class\", value=class_name)]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "1ff6072c-511e-42fe-a74f-282f269b80f0",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.262337Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.261775Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.309793Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.309216Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.262307Z"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"detector = DetectorModel(class_mapper=Mapper())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "3a763ee6-15bc-4105-a409-f06e0ad21a06",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.310695Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.310438Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:09.366636Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:09.366059Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.310669Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"GPU available: False, used: False\n",
|
||||
"TPU available: False, using: 0 TPU cores\n",
|
||||
"HPU available: False, using: 0 HPUs\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"trainer = pl.Trainer(\n",
|
||||
" limit_train_batches=100,\n",
|
||||
" max_epochs=2,\n",
|
||||
" log_every_n_steps=1,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "0b86d49d-3314-4257-94f5-f964855be385",
|
||||
"metadata": {
|
||||
"execution": {
|
||||
"iopub.execute_input": "2024-11-19T17:33:09.367499Z",
|
||||
"iopub.status.busy": "2024-11-19T17:33:09.367242Z",
|
||||
"iopub.status.idle": "2024-11-19T17:33:10.811300Z",
|
||||
"shell.execute_reply": "2024-11-19T17:33:10.809823Z",
|
||||
"shell.execute_reply.started": "2024-11-19T17:33:09.367473Z"
|
||||
}
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
" | Name | Type | Params | Mode \n",
|
||||
"--------------------------------------------------------\n",
|
||||
"0 | feature_extractor | Net2DFast | 119 K | train\n",
|
||||
"1 | classifier | Conv2d | 54 | train\n",
|
||||
"2 | bbox | Conv2d | 18 | train\n",
|
||||
"--------------------------------------------------------\n",
|
||||
"119 K Trainable params\n",
|
||||
"448 Non-trainable params\n",
|
||||
"119 K Total params\n",
|
||||
"0.480 Total estimated model params size (MB)\n",
|
||||
"32 Modules in train mode\n",
|
||||
"0 Modules in eval mode\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Epoch 0: 0%| | 0/1 [00:00<?, ?it/s]class heatmap shape torch.Size([3, 4, 128, 512])\n",
|
||||
"class props shape torch.Size([3, 5, 128, 512])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "RuntimeError",
|
||||
"evalue": "The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdetector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:538\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 536\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m 537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 538\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 539\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 540\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:47\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 47\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 50\u001b[0m _call_teardown_hook(trainer)\n",
|
||||
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:574\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 568\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 569\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 570\u001b[0m ckpt_path,\n\u001b[1;32m 571\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 572\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 573\u001b[0m )\n\u001b[0;32m--> 574\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 576\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 577\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:981\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 976\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 978\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 981\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 983\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 984\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 985\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 986\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:1025\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1023\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1024\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1025\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1026\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1027\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:205\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:363\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# in automatic optimization, there can only be one optimizer\u001b[39;00m\n\u001b[0;32m--> 250\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_optimization\u001b[38;5;241m.\u001b[39mrun(kwargs)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:190\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m closure()\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated gradients\u001b[39;00m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 190\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 192\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:268\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_ready()\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 268\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_completed()\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:167\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m hook_name\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 167\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 170\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/module.py:1306\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m 1275\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimizer_step\u001b[39m(\n\u001b[1;32m 1276\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1277\u001b[0m epoch: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1280\u001b[0m optimizer_closure: Optional[Callable[[], Any]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1281\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;124;03m the optimizer.\u001b[39;00m\n\u001b[1;32m 1284\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1304\u001b[0m \n\u001b[1;32m 1305\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1306\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py:153\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhen `optimizer.step(closure)` is called, the closure should be callable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 153\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_output\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:238\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;66;03m# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 238\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:122\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[39;00m\n\u001b[1;32m 121\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:130\u001b[0m, in \u001b[0;36mLRScheduler.__init__.<locals>.patch_track_step_called.<locals>.wrap_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 128\u001b[0m opt \u001b[38;5;241m=\u001b[39m opt_ref()\n\u001b[1;32m 129\u001b[0m opt\u001b[38;5;241m.\u001b[39m_opt_called \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m \u001b[38;5;66;03m# type: ignore[union-attr]\u001b[39;00m\n\u001b[0;32m--> 130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__get__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__class__\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:484\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 482\u001b[0m )\n\u001b[0;32m--> 484\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 485\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 487\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:89\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 87\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 88\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 89\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 91\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/adam.py:205\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 205\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n\u001b[1;32m 208\u001b[0m params_with_grad: List[Tensor] \u001b[38;5;241m=\u001b[39m []\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:108\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, closure)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrap_closure\u001b[39m(\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 97\u001b[0m model: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpl.LightningModule\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 98\u001b[0m optimizer: Steppable,\n\u001b[1;32m 99\u001b[0m closure: Callable[[], Any],\n\u001b[1;32m 100\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 101\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;124;03m hook is called.\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m closure_result\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:144\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:129\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 129\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwarning_cache\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`training_step` returned `None`. If this was on purpose, ignore this warning...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:317\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Performs the actual train step with the tied hooks.\u001b[39;00m\n\u001b[1;32m 307\u001b[0m \n\u001b[1;32m 308\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 313\u001b[0m \n\u001b[1;32m 314\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 315\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[0;32m--> 317\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step() \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m training_step_output \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mworld_size \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:319\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 319\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 322\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:167\u001b[0m, in \u001b[0;36mDetectorModel.training_step\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtraining_step\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch: TrainExample):\n\u001b[1;32m 166\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(batch\u001b[38;5;241m.\u001b[39mspec)\n\u001b[0;32m--> 167\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:150\u001b[0m, in \u001b[0;36mDetectorModel.compute_loss\u001b[0;34m(self, outputs, batch)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mclass props shape\u001b[39m\u001b[38;5;124m\"\u001b[39m, outputs\u001b[38;5;241m.\u001b[39mclass_probs\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 149\u001b[0m valid_mask \u001b[38;5;241m=\u001b[39m batch\u001b[38;5;241m.\u001b[39mclass_heatmap\u001b[38;5;241m.\u001b[39many(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m--> 150\u001b[0m classification_loss \u001b[38;5;241m=\u001b[39m \u001b[43mlosses\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal_loss\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[43m \u001b[49m\u001b[43moutputs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_probs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_heatmap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[43m \u001b[49m\u001b[43mweights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalid_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalid_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbeta\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 156\u001b[0m 
\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 157\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 160\u001b[0m detection_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mdetection\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 161\u001b[0m \u001b[38;5;241m+\u001b[39m size_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39msize\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 162\u001b[0m \u001b[38;5;241m+\u001b[39m classification_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mclassification\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 163\u001b[0m )\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/losses.py:38\u001b[0m, in \u001b[0;36mfocal_loss\u001b[0;34m(pred, gt, weights, valid_mask, eps, beta, alpha)\u001b[0m\n\u001b[1;32m 35\u001b[0m pos_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39meq(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[1;32m 36\u001b[0m neg_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39mlt(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m---> 38\u001b[0m pos_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpred\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpos_inds\u001b[49m\n\u001b[1;32m 39\u001b[0m neg_loss \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 40\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m pred \u001b[38;5;241m+\u001b[39m eps)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(pred, alpha)\n\u001b[1;32m 42\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m gt, beta)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;241m*\u001b[39m neg_inds\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m weights \u001b[38;5;129;01mis\u001b[39;00m 
\u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1"
]
}
],
"source": [
"trainer.fit(detector, train_dataloaders=train_dataloader)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f6924db-e520-49a1-bbe8-6c4956e46314",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.811729Z",
"iopub.status.idle": "2024-11-19T17:33:10.811955Z",
"shell.execute_reply": "2024-11-19T17:33:10.811858Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.811849Z"
}
},
"outputs": [],
"source": [
"clip_annotation = train_dataset.get_clip_annotation(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23943e13-6875-49b8-9f18-2ba6528aa673",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.812924Z",
"iopub.status.idle": "2024-11-19T17:33:10.813260Z",
"shell.execute_reply": "2024-11-19T17:33:10.813104Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.813087Z"
}
},
"outputs": [],
"source": [
"spec = detector.compute_spectrogram(clip_annotation.clip)\n",
"outputs = detector(torch.tensor(spec.values).unsqueeze(0).unsqueeze(0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd1fe346-0873-4b14-ae1b-92ef1f4f27a5",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.814343Z",
"iopub.status.idle": "2024-11-19T17:33:10.814806Z",
"shell.execute_reply": "2024-11-19T17:33:10.814628Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.814611Z"
}
},
"outputs": [],
"source": [
"_, ax = plt.subplots(figsize=(15, 5))\n",
"spec.plot(ax=ax, add_colorbar=False)\n",
"ax.pcolormesh(spec.time, spec.frequency, outputs.detection_probs.detach().squeeze())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eadd36ef-a04a-4665-b703-cec84cf1673b",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.815603Z",
"iopub.status.idle": "2024-11-19T17:33:10.816065Z",
"shell.execute_reply": "2024-11-19T17:33:10.815894Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.815877Z"
}
},
"outputs": [],
"source": [
"print(f\"Num predicted soundevents: {len(predictions.sound_events)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4e54f3e-6ddc-4fe5-8ce0-b527ff6f18ae",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "batdetect2-dev",
"language": "python",
"name": "batdetect2-dev"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
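The traceback in the notebook output above ends in a shape mismatch inside `focal_loss`: the model's `class_probs` carries one more class channel than the target `class_heatmap` (5 vs 4), so elementwise broadcasting fails. A minimal sketch of a CornerNet-style focal loss (simplified: no class weights or valid mask, which the real `batdetect2.train.losses.focal_loss` supports) reproduces the failure mode; the function names and shapes here are illustrative, not the project's exact API:

```python
import torch


def focal_loss(pred, gt, eps=1e-5, beta=4, alpha=2):
    # Penalty-reduced focal loss over heatmaps. `pred` and `gt` must have
    # identical shapes, including the class dimension (dim 1).
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()
    pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
    neg_loss = (
        torch.log(1 - pred + eps)
        * torch.pow(pred, alpha)
        * torch.pow(1 - gt, beta)
        * neg_inds
    )
    num_pos = pos_inds.sum().clamp(min=1)
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos


# Matching class counts: the loss evaluates normally.
pred = torch.rand(1, 4, 8, 8).clamp(1e-4, 1 - 1e-4)  # (batch, classes, freq, time)
gt = torch.zeros(1, 4, 8, 8)
loss = focal_loss(pred, gt)

# Mismatched class counts (5 predicted vs 4 in the heatmap): dim 1 cannot
# broadcast, raising the RuntimeError seen in the traceback.
try:
    focal_loss(torch.rand(1, 5, 8, 8), gt)
except RuntimeError as err:
    print(err)
```

The usual fix is to make the model's class head and the target heatmap agree on the number of classes before training starts.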
@ -1,194 +0,0 @@
import marimo

__generated_with = "0.14.16"
app = marimo.App(width="medium")


@app.cell
def _():
    import marimo as mo
    return (mo,)


@app.cell
def _():
    from batdetect2.data import (
        load_dataset_config,
        load_dataset,
        extract_recordings_df,
        extract_sound_events_df,
        compute_class_summary,
    )
    return (
        compute_class_summary,
        extract_recordings_df,
        extract_sound_events_df,
        load_dataset,
        load_dataset_config,
    )


@app.cell
def _(mo):
    dataset_config_browser = mo.ui.file_browser(
        selection_mode="file",
        multiple=False,
    )
    dataset_config_browser
    return (dataset_config_browser,)


@app.cell
def _(dataset_config_browser, load_dataset_config, mo):
    mo.stop(dataset_config_browser.path() is None)
    dataset_config = load_dataset_config(dataset_config_browser.path())
    return (dataset_config,)


@app.cell
def _(dataset_config, load_dataset):
    dataset = load_dataset(dataset_config, base_dir="../paper/")
    return (dataset,)


@app.cell
def _():
    from batdetect2.targets import load_target_config, build_targets
    return build_targets, load_target_config


@app.cell
def _(mo):
    targets_config_browser = mo.ui.file_browser(
        selection_mode="file",
        multiple=False,
    )
    targets_config_browser
    return (targets_config_browser,)


@app.cell
def _(load_target_config, mo, targets_config_browser):
    mo.stop(targets_config_browser.path() is None)
    targets_config = load_target_config(targets_config_browser.path())
    return (targets_config,)


@app.cell
def _(build_targets, targets_config):
    targets = build_targets(targets_config)
    return (targets,)


@app.cell
def _():
    import pandas as pd
    from soundevent.geometry import compute_bounds
    return


@app.cell
def _(dataset, extract_recordings_df):
    recordings = extract_recordings_df(dataset)
    recordings
    return


@app.cell
def _(dataset, extract_sound_events_df, targets):
    sound_events = extract_sound_events_df(dataset, targets)
    sound_events
    return


@app.cell
def _(compute_class_summary, dataset, targets):
    compute_class_summary(dataset, targets)
    return


@app.cell
def _():
    from batdetect2.data.split import split_dataset_by_recordings
    return (split_dataset_by_recordings,)


@app.cell
def _(dataset, split_dataset_by_recordings, targets):
    train_dataset, val_dataset = split_dataset_by_recordings(
        dataset, targets, random_state=42
    )
    return train_dataset, val_dataset


@app.cell
def _(compute_class_summary, targets, train_dataset):
    compute_class_summary(train_dataset, targets)
    return


@app.cell
def _(compute_class_summary, targets, val_dataset):
    compute_class_summary(val_dataset, targets)
    return


@app.cell
def _():
    from soundevent import io, data
    from pathlib import Path
    return Path, data, io


@app.cell
def _(Path, data, io, train_dataset):
    io.save(
        data.AnnotationSet(
            name="batdetect2_tuning_train",
            description="Set of annotations used as the train dataset for the hyper-parameter tuning stage.",
            clip_annotations=train_dataset,
        ),
        Path("../paper/data/datasets/annotation_sets/tuning_train.json"),
        audio_dir=Path("../paper/data/datasets/"),
    )
    return


@app.cell
def _(Path, data, io, val_dataset):
    io.save(
        data.AnnotationSet(
            name="batdetect2_tuning_val",
            description="Set of annotations used as the validation dataset for the hyper-parameter tuning stage.",
            clip_annotations=val_dataset,
        ),
        Path("../paper/data/datasets/annotation_sets/tuning_val.json"),
        audio_dir=Path("../paper/data/datasets/"),
    )
    return


@app.cell
def _(load_dataset, load_dataset_config):
    config = load_dataset_config("../paper/conf/datasets/train/uk_tune.yaml")
    rec = load_dataset(config, base_dir="../paper/")
    return (rec,)


@app.cell
def _(rec):
    dict(rec[0].sound_events[0].tags[0].term)
    return


@app.cell
def _(compute_class_summary, rec, targets):
    compute_class_summary(rec, targets)
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()
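The deleted notebook above builds its train/validation split with `split_dataset_by_recordings`, i.e. it splits at the recording level so that clips from one file never land on both sides. The internals of that helper are not shown in this diff; a minimal self-contained sketch of the same idea, using plain dictionaries in place of `ClipAnnotation` objects (the `split_by_recordings` name and `val_fraction` parameter are illustrative), could look like:

```python
import random
from collections import defaultdict


def split_by_recordings(clips, val_fraction=0.2, random_state=42):
    """Split clips so that all clips of one recording stay together.

    Splitting at the recording level avoids leakage: near-duplicate
    clips from the same file cannot appear in both train and val.
    """
    by_recording = defaultdict(list)
    for clip in clips:
        by_recording[clip["recording"]].append(clip)

    recordings = sorted(by_recording)
    random.Random(random_state).shuffle(recordings)

    n_val = max(1, int(len(recordings) * val_fraction))
    val_recs = set(recordings[:n_val])

    train = [c for r in recordings[n_val:] for c in by_recording[r]]
    val = [c for r in recordings[:n_val] for c in by_recording[r]]
    return train, val


# 20 clips drawn from 5 recordings; each recording ends up wholly in one split.
clips = [{"recording": f"rec{i % 5}", "clip": i} for i in range(20)]
train, val = split_by_recordings(clips)
```

The fixed `random_state` mirrors the notebook's `random_state=42` and makes the split reproducible across runs.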
@ -1,273 +0,0 @@
import marimo

__generated_with = "0.14.16"
app = marimo.App(width="medium")


@app.cell
def _():
    import marimo as mo
    return


@app.cell
def _():
    from batdetect2.data import load_dataset_config, load_dataset
    from batdetect2.preprocess import load_preprocessing_config, build_preprocessor
    from batdetect2 import api
    from soundevent import data
    from batdetect2.evaluate.types import MatchEvaluation
    from batdetect2.types import Annotation
    from batdetect2.compat import annotation_to_sound_event_prediction
    from batdetect2.plotting import (
        plot_clip,
        plot_clip_annotation,
        plot_clip_prediction,
        plot_matches,
        plot_false_positive_match,
        plot_false_negative_match,
        plot_true_positive_match,
        plot_cross_trigger_match,
    )
    return (
        MatchEvaluation,
        annotation_to_sound_event_prediction,
        api,
        build_preprocessor,
        data,
        load_dataset,
        load_dataset_config,
        load_preprocessing_config,
        plot_clip_annotation,
        plot_clip_prediction,
        plot_cross_trigger_match,
        plot_false_negative_match,
        plot_false_positive_match,
        plot_matches,
        plot_true_positive_match,
    )


@app.cell
def _(build_preprocessor, load_dataset_config, load_preprocessing_config):
    dataset_config = load_dataset_config(
        path="example_data/config.yaml", field="datasets.train"
    )

    preprocessor_config = load_preprocessing_config(
        path="example_data/config.yaml", field="preprocess"
    )

    preprocessor = build_preprocessor(preprocessor_config)
    return dataset_config, preprocessor


@app.cell
def _(dataset_config, load_dataset):
    dataset = load_dataset(dataset_config)
    return (dataset,)


@app.cell
def _(dataset):
    clip_annotation = dataset[1]
    return (clip_annotation,)


@app.cell
def _(clip_annotation, plot_clip_annotation, preprocessor):
    plot_clip_annotation(
        clip_annotation, preprocessor=preprocessor, figsize=(15, 5)
    )
    return


@app.cell
def _(annotation_to_sound_event_prediction, api, clip_annotation, data):
    audio = api.load_audio(clip_annotation.clip.recording.path)
    detections, features, spec = api.process_audio(audio)
    clip_prediction = data.ClipPrediction(
        clip=clip_annotation.clip,
        sound_events=[
            annotation_to_sound_event_prediction(
                prediction, clip_annotation.clip.recording
            )
            for prediction in detections
        ],
    )
    return (clip_prediction,)


@app.cell
def _(clip_prediction, plot_clip_prediction):
    plot_clip_prediction(clip_prediction, figsize=(15, 5))
    return


@app.cell
def _():
    from batdetect2.evaluate import match_predictions_and_annotations
    import random
    return match_predictions_and_annotations, random


@app.cell
def _(data, random):
    def add_noise(clip_annotation, time_buffer=0.003, freq_buffer=1000):
        def _add_bbox_noise(bbox):
            start_time, low_freq, end_time, high_freq = bbox.coordinates
            return data.BoundingBox(
                coordinates=[
                    start_time + random.uniform(-time_buffer, time_buffer),
                    low_freq + random.uniform(-freq_buffer, freq_buffer),
                    end_time + random.uniform(-time_buffer, time_buffer),
                    high_freq + random.uniform(-freq_buffer, freq_buffer),
                ]
            )

        def _add_noise(se):
            return se.model_copy(
                update=dict(
                    sound_event=se.sound_event.model_copy(
                        update=dict(
                            geometry=_add_bbox_noise(se.sound_event.geometry)
                        )
                    )
                )
            )

        return clip_annotation.model_copy(
            update=dict(
                sound_events=[
                    _add_noise(se) for se in clip_annotation.sound_events
                ]
            )
        )

    def drop_random(obj, p=0.5):
        return obj.model_copy(
            update=dict(
                sound_events=[se for se in obj.sound_events if random.random() > p]
            )
        )
    return add_noise, drop_random


@app.cell
def _(
    add_noise,
    clip_annotation,
    clip_prediction,
    drop_random,
    match_predictions_and_annotations,
):
    matches = match_predictions_and_annotations(
        drop_random(add_noise(clip_annotation), p=0.2),
        drop_random(clip_prediction),
    )
    return (matches,)


@app.cell
def _(clip_annotation, matches, plot_matches):
    plot_matches(matches, clip_annotation.clip, figsize=(15, 5))
    return


@app.cell
def _(matches):
    true_positives = []
    false_positives = []
    false_negatives = []

    for match in matches:
        if match.source is None and match.target is not None:
            false_negatives.append(match)
        elif match.target is None and match.source is not None:
            false_positives.append(match)
        elif match.target is not None and match.source is not None:
            true_positives.append(match)
        else:
            continue

    return false_negatives, false_positives, true_positives


@app.cell
def _(MatchEvaluation, false_positives, plot_false_positive_match):
    false_positive = false_positives[0]
    false_positive_eval = MatchEvaluation(
        match=false_positive,
        gt_det=False,
        gt_class=None,
        pred_score=false_positive.source.score,
        pred_class_scores={
            "myomyo": 0.2
        }
    )

    plot_false_positive_match(false_positive_eval)
    return


@app.cell
def _(MatchEvaluation, false_negatives, plot_false_negative_match):
    false_negative = false_negatives[0]
    false_negative_eval = MatchEvaluation(
        match=false_negative,
        gt_det=True,
        gt_class="myomyo",
        pred_score=None,
        pred_class_scores={}
    )

    plot_false_negative_match(false_negative_eval)
    return


@app.cell
def _(MatchEvaluation, plot_true_positive_match, true_positives):
    true_positive = true_positives[0]
    true_positive_eval = MatchEvaluation(
        match=true_positive,
        gt_det=True,
        gt_class="myomyo",
        pred_score=0.87,
        pred_class_scores={
            "pyomyo": 0.84,
            "pippip": 0.84,
        }
    )

    plot_true_positive_match(true_positive_eval)
    return (true_positive,)


@app.cell
def _(MatchEvaluation, plot_cross_trigger_match, true_positive):
    cross_trigger_eval = MatchEvaluation(
        match=true_positive,
        gt_det=True,
        gt_class="myomyo",
        pred_score=0.87,
        pred_class_scores={
            "pippip": 0.84,
            "myomyo": 0.84,
        }
    )

    plot_cross_trigger_match(cross_trigger_eval)
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()
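The notebook above partitions matcher output into true positives, false positives, and false negatives by checking which side of each match is missing. The same partition directly yields detection precision and recall; a small self-contained sketch (the `Match` dataclass here is a hypothetical stand-in for the matcher's real match type, where `source` is a prediction and `target` an annotation) could look like:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Match:
    # Hypothetical stand-in for a matcher result: `source` is the matched
    # prediction, `target` the matched annotation; either may be missing.
    source: Optional[str]
    target: Optional[str]


def precision_recall(matches):
    # TP: both sides matched; FP: prediction without annotation;
    # FN: annotation without prediction.
    tp = sum(1 for m in matches if m.source is not None and m.target is not None)
    fp = sum(1 for m in matches if m.source is not None and m.target is None)
    fn = sum(1 for m in matches if m.source is None and m.target is not None)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Two true positives, one false positive, one false negative.
matches = [Match("p1", "a1"), Match("p2", None), Match(None, "a2"), Match("p3", "a3")]
precision, recall = precision_recall(matches)
```

With the toy matches above, both precision and recall come out to 2/3, since two of three predictions and two of three annotations are matched.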
@ -1,19 +0,0 @@
import marimo

__generated_with = "0.13.15"
app = marimo.App(width="medium")


@app.cell
def _():
    from batdetect2.preprocess import build_preprocessor
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()
@ -1,97 +0,0 @@
import marimo

__generated_with = "0.14.5"
app = marimo.App(width="medium")


@app.cell
def _():
    import marimo as mo
    return (mo,)


@app.cell
def _():
    from batdetect2.data import load_dataset, load_dataset_config
    return load_dataset, load_dataset_config


@app.cell
def _(mo):
    dataset_input = mo.ui.file_browser(label="dataset config file")
    dataset_input
    return (dataset_input,)


@app.cell
def _(mo):
    audio_dir_input = mo.ui.file_browser(
        label="audio directory", selection_mode="directory"
    )
    audio_dir_input
    return (audio_dir_input,)


@app.cell
def _(dataset_input, load_dataset_config):
    dataset_config = load_dataset_config(
        path=dataset_input.path(), field="datasets.train",
    )
    return (dataset_config,)


@app.cell
def _(audio_dir_input, dataset_config, load_dataset):
    dataset = load_dataset(dataset_config, base_dir=audio_dir_input.path())
    return (dataset,)


@app.cell
def _(dataset):
    len(dataset)
    return


@app.cell
def _(dataset):
    tag_groups = [
        se.tags
        for clip in dataset
        for se in clip.sound_events
        if se.tags
    ]
    all_tags = [
        tag for group in tag_groups for tag in group
    ]
    return (all_tags,)


@app.cell
def _(mo):
    key_search = mo.ui.text(label="key", debounce=0.1)
    value_search = mo.ui.text(label="value", debounce=0.1)
    mo.hstack([key_search, value_search]).left()
    return key_search, value_search


@app.cell
def _(all_tags, key_search, mo, value_search):
    filtered_tags = list(set(all_tags))

    if key_search.value:
        filtered_tags = [
            tag for tag in filtered_tags
            if key_search.value.lower() in tag.key.lower()
        ]

    if value_search.value:
        filtered_tags = [
            tag for tag in filtered_tags
            if value_search.value.lower() in tag.value.lower()
        ]

    mo.vstack([mo.md(f"key={tag.key} value={tag.value}") for tag in filtered_tags[:5]])
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()
@ -25,14 +25,14 @@ dependencies = [
    "scikit-learn>=1.2.2",
    "scipy>=1.10.1",
    "seaborn>=0.13.2",
    "soundevent[audio,geometry,plot]>=2.9.1",
    "soundevent[audio,geometry,plot]>=2.10.0",
    "tensorboard>=2.16.2",
    "torch>=1.13.1",
    "torchaudio>=1.13.1",
    "torchvision>=0.14.0",
    "tqdm>=4.66.2",
]
requires-python = ">=3.9,<3.13"
requires-python = ">=3.10,<3.14"
readme = "README.md"
license = { text = "CC-by-nc-4" }
classifiers = [
@ -75,7 +75,6 @@ dev = [
    "ruff>=0.7.3",
    "ipykernel>=6.29.4",
    "setuptools>=69.5.1",
    "basedpyright>=1.31.0",
    "myst-parser>=3.0.1",
    "sphinx-autobuild>=2024.10.3",
    "numpydoc>=1.8.0",
@ -87,13 +86,24 @@ dev = [
    "rust-just>=1.40.0",
    "pandas-stubs>=2.2.2.240807",
    "python-lsp-server>=1.13.0",
    "deepdiff>=8.6.1",
]
dvclive = ["dvclive>=3.48.2"]
mlflow = ["mlflow>=3.1.1"]
gradio = [
    "gradio>=6.9.0",
]

[tool.ruff]
line-length = 79
target-version = "py39"
target-version = "py310"
exclude = [
    "src/batdetect2/train/legacy",
    "src/batdetect2/plotting/legacy",
    "src/batdetect2/evaluate/legacy",
    "src/batdetect2/finetune",
    "src/batdetect2/utils",
]

[tool.ruff.format]
docstring-code-format = true
@ -105,15 +115,12 @@ select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
[tool.ruff.lint.pydocstyle]
convention = "numpy"

[tool.pyright]
[tool.ty.src]
include = ["src", "tests"]
pythonVersion = "3.9"
pythonPlatform = "All"
exclude = [
    "src/batdetect2/detector/",
    "src/batdetect2/train/legacy",
    "src/batdetect2/plotting/legacy",
    "src/batdetect2/evaluate/legacy",
    "src/batdetect2/finetune",
    "src/batdetect2/utils",
    "src/batdetect2/plot",
    "src/batdetect2/evaluate/legacy",
    "src/batdetect2/train/legacy",
]

@ -98,7 +98,6 @@ consult the API documentation in the code.
"""

import warnings
from typing import List, Optional, Tuple

import numpy as np
import torch
@ -165,7 +164,7 @@ def load_audio(
    time_exp_fact: float = 1,
    target_samp_rate: int = TARGET_SAMPLERATE_HZ,
    scale: bool = False,
    max_duration: Optional[float] = None,
    max_duration: float | None = None,
) -> np.ndarray:
    """Load audio from file.

@ -203,7 +202,7 @@ def load_audio(
def generate_spectrogram(
    audio: np.ndarray,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    config: Optional[SpectrogramParameters] = None,
    config: SpectrogramParameters | None = None,
    device: torch.device = DEVICE,
) -> torch.Tensor:
    """Generate spectrogram from audio array.
@ -240,7 +239,7 @@ def generate_spectrogram(
def process_file(
    audio_file: str,
    model: DetectionModel = MODEL,
    config: Optional[ProcessingConfiguration] = None,
    config: ProcessingConfiguration | None = None,
    device: torch.device = DEVICE,
) -> du.RunResults:
    """Process audio file with model.
@ -271,8 +270,8 @@ def process_spectrogram(
    spec: torch.Tensor,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    model: DetectionModel = MODEL,
    config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], np.ndarray]:
    config: ProcessingConfiguration | None = None,
) -> tuple[list[Annotation], np.ndarray]:
    """Process spectrogram with model.

    Parameters
@ -312,9 +311,9 @@ def process_audio(
    audio: np.ndarray,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    model: DetectionModel = MODEL,
    config: Optional[ProcessingConfiguration] = None,
    config: ProcessingConfiguration | None = None,
    device: torch.device = DEVICE,
) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
    """Process audio array with model.

    Parameters
@ -356,8 +355,8 @@
def postprocess(
    outputs: ModelOutput,
    samp_rate: int = TARGET_SAMPLERATE_HZ,
    config: Optional[ProcessingConfiguration] = None,
) -> Tuple[List[Annotation], np.ndarray]:
    config: ProcessingConfiguration | None = None,
) -> tuple[list[Annotation], np.ndarray]:
    """Postprocess model outputs.

    Convert model tensor outputs to predicted bounding boxes and

@ -1,59 +1,90 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
from typing import Literal, Sequence, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from soundevent import data
|
||||
from soundevent.audio.files import get_audio_files
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.core import merge_configs
|
||||
from batdetect2.data import (
|
||||
OutputFormatConfig,
|
||||
build_output_formatter,
|
||||
get_output_formatter,
|
||||
load_dataset_from_config,
|
||||
from batdetect2.data import Dataset, load_dataset_from_config
|
||||
from batdetect2.evaluate import (
|
||||
DEFAULT_EVAL_DIR,
|
||||
EvaluationConfig,
|
||||
EvaluatorProtocol,
|
||||
build_evaluator,
|
||||
run_evaluate,
|
||||
save_evaluation_results,
|
||||
)
|
||||
from batdetect2.data.datasets import Dataset
|
||||
from batdetect2.data.predictions.base import OutputFormatterProtocol
|
||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
||||
from batdetect2.inference import process_file_list, run_batch_inference
|
||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||
from batdetect2.models import Model, build_model
|
||||
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.inference import (
|
||||
InferenceConfig,
|
||||
process_file_list,
|
||||
run_batch_inference,
|
||||
)
|
||||
from batdetect2.logging import (
|
||||
DEFAULT_LOGS_DIR,
|
||||
AppLoggingConfig,
|
||||
LoggerConfig,
|
||||
)
|
||||
from batdetect2.models import (
|
||||
Model,
|
||||
ModelConfig,
|
||||
build_model,
|
||||
build_model_with_new_targets,
|
||||
)
|
||||
from batdetect2.models.detectors import Detector
|
||||
from batdetect2.outputs import (
|
||||
OutputFormatConfig,
|
||||
OutputFormatterProtocol,
|
||||
OutputsConfig,
|
||||
OutputTransformProtocol,
|
||||
build_output_formatter,
|
||||
build_output_transform,
|
||||
get_output_formatter,
|
||||
)
|
||||
from batdetect2.postprocess import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
PostprocessorProtocol,
|
||||
build_postprocessor,
|
||||
)
|
||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||
from batdetect2.train import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
TrainingConfig,
|
||||
load_model_from_checkpoint,
|
||||
train,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
BatDetect2Prediction,
|
||||
EvaluatorProtocol,
|
||||
PostprocessorProtocol,
|
||||
PreprocessorProtocol,
|
||||
RawPrediction,
|
||||
TargetProtocol,
|
||||
run_train,
|
||||
)
|
||||
|
||||
|
||||

class BatDetect2API:
    def __init__(
        self,
        config: BatDetect2Config,
        model_config: ModelConfig,
        audio_config: AudioConfig,
        train_config: TrainingConfig,
        evaluation_config: EvaluationConfig,
        inference_config: InferenceConfig,
        outputs_config: OutputsConfig,
        logging_config: AppLoggingConfig,
        targets: TargetProtocol,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        postprocessor: PostprocessorProtocol,
        evaluator: EvaluatorProtocol,
        formatter: OutputFormatterProtocol,
        output_transform: OutputTransformProtocol,
        model: Model,
    ):
        self.config = config
        self.model_config = model_config
        self.audio_config = audio_config
        self.train_config = train_config
        self.evaluation_config = evaluation_config
        self.inference_config = inference_config
        self.outputs_config = outputs_config
        self.logging_config = logging_config
        self.targets = targets
        self.audio_loader = audio_loader
        self.preprocessor = preprocessor
@@ -61,34 +92,40 @@ class BatDetect2API:
        self.evaluator = evaluator
        self.model = model
        self.formatter = formatter
        self.output_transform = output_transform

        self.model.eval()

    def load_annotations(
        self,
        path: data.PathLike,
        base_dir: Optional[data.PathLike] = None,
        base_dir: data.PathLike | None = None,
    ) -> Dataset:
        return load_dataset_from_config(path, base_dir=base_dir)

    def train(
        self,
        train_annotations: Sequence[data.ClipAnnotation],
        val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
        train_workers: Optional[int] = None,
        val_workers: Optional[int] = None,
        checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
        log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
        experiment_name: Optional[str] = None,
        num_epochs: Optional[int] = None,
        run_name: Optional[str] = None,
        seed: Optional[int] = None,
        val_annotations: Sequence[data.ClipAnnotation] | None = None,
        train_workers: int = 0,
        val_workers: int = 0,
        checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
        log_dir: Path | None = DEFAULT_LOGS_DIR,
        experiment_name: str | None = None,
        num_epochs: int | None = None,
        run_name: str | None = None,
        seed: int | None = None,
        model_config: ModelConfig | None = None,
        audio_config: AudioConfig | None = None,
        train_config: TrainingConfig | None = None,
        logger_config: LoggerConfig | None = None,
    ):
        train(
        run_train(
            train_annotations=train_annotations,
            val_annotations=val_annotations,
            model=self.model,
            targets=self.targets,
            config=self.config,
            model_config=model_config or self.model_config,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            train_workers=train_workers,
@@ -99,25 +136,81 @@ class BatDetect2API:
            experiment_name=experiment_name,
            run_name=run_name,
            seed=seed,
            train_config=train_config or self.train_config,
            audio_config=audio_config or self.audio_config,
            logger_config=logger_config or self.logging_config.train,
        )
        return self

    def finetune(
        self,
        train_annotations: Sequence[data.ClipAnnotation],
        val_annotations: Sequence[data.ClipAnnotation] | None = None,
        trainable: Literal[
            "all", "heads", "classifier_head", "bbox_head"
        ] = "heads",
        train_workers: int = 0,
        val_workers: int = 0,
        checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
        log_dir: Path | None = DEFAULT_LOGS_DIR,
        experiment_name: str | None = None,
        num_epochs: int | None = None,
        run_name: str | None = None,
        seed: int | None = None,
        model_config: ModelConfig | None = None,
        audio_config: AudioConfig | None = None,
        train_config: TrainingConfig | None = None,
        logger_config: LoggerConfig | None = None,
    ) -> "BatDetect2API":
        """Fine-tune the model with trainable-parameter selection."""

        self._set_trainable_parameters(trainable)

        run_train(
            train_annotations=train_annotations,
            val_annotations=val_annotations,
            model=self.model,
            targets=self.targets,
            model_config=model_config or self.model_config,
            preprocessor=self.preprocessor,
            audio_loader=self.audio_loader,
            train_workers=train_workers,
            val_workers=val_workers,
            checkpoint_dir=checkpoint_dir,
            log_dir=log_dir,
            experiment_name=experiment_name,
            num_epochs=num_epochs,
            run_name=run_name,
            seed=seed,
            audio_config=audio_config or self.audio_config,
            train_config=train_config or self.train_config,
            logger_config=logger_config or self.logging_config.train,
        )
        return self

    def evaluate(
        self,
        test_annotations: Sequence[data.ClipAnnotation],
        num_workers: Optional[int] = None,
        num_workers: int = 0,
        output_dir: data.PathLike = DEFAULT_EVAL_DIR,
        experiment_name: Optional[str] = None,
        run_name: Optional[str] = None,
        experiment_name: str | None = None,
        run_name: str | None = None,
        save_predictions: bool = True,
    ) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
        return evaluate(
        audio_config: AudioConfig | None = None,
        evaluation_config: EvaluationConfig | None = None,
        outputs_config: OutputsConfig | None = None,
        logger_config: LoggerConfig | None = None,
    ) -> tuple[dict[str, float], list[ClipDetections]]:
        return run_evaluate(
            self.model,
            test_annotations,
            targets=self.targets,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            config=self.config,
            audio_config=audio_config or self.audio_config,
            evaluation_config=evaluation_config or self.evaluation_config,
            output_config=outputs_config or self.outputs_config,
            logger_config=logger_config or self.logging_config.evaluation,
            num_workers=num_workers,
            output_dir=output_dir,
            experiment_name=experiment_name,
@@ -128,8 +221,8 @@ class BatDetect2API:
    def evaluate_predictions(
        self,
        annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[BatDetect2Prediction],
        output_dir: Optional[data.PathLike] = None,
        predictions: Sequence[ClipDetections],
        output_dir: data.PathLike | None = None,
    ):
        clip_evals = self.evaluator.evaluate(
            annotations,
@@ -139,30 +232,66 @@ class BatDetect2API:
        metrics = self.evaluator.compute_metrics(clip_evals)

        if output_dir is not None:
            output_dir = Path(output_dir)

            if not output_dir.is_dir():
                output_dir.mkdir(parents=True)

            metrics_path = output_dir / "metrics.json"
            metrics_path.write_text(json.dumps(metrics))

            for figure_name, fig in self.evaluator.generate_plots(clip_evals):
                fig_path = output_dir / figure_name

                if not fig_path.parent.is_dir():
                    fig_path.parent.mkdir(parents=True)

                fig.savefig(fig_path)
            save_evaluation_results(
                metrics=metrics,
                plots=self.evaluator.generate_plots(clip_evals),
                output_dir=output_dir,
            )

        return metrics

    def load_audio(self, path: data.PathLike) -> np.ndarray:
        return self.audio_loader.load_file(path)

    def load_recording(self, recording: data.Recording) -> np.ndarray:
        return self.audio_loader.load_recording(recording)

    def load_clip(self, clip: data.Clip) -> np.ndarray:
        return self.audio_loader.load_clip(clip)

    def get_top_class_name(self, detection: Detection) -> str:
        """Get highest-confidence class name for one detection."""

        top_index = int(np.argmax(detection.class_scores))
        return self.targets.class_names[top_index]

    def get_class_scores(
        self,
        detection: Detection,
        *,
        include_top_class: bool = True,
        sort_descending: bool = True,
    ) -> list[tuple[str, float]]:
        """Get class score list as ``(class_name, score)`` pairs."""

        scores = [
            (class_name, float(score))
            for class_name, score in zip(
                self.targets.class_names,
                detection.class_scores,
                strict=True,
            )
        ]

        if sort_descending:
            scores.sort(key=lambda item: item[1], reverse=True)

        if include_top_class:
            return scores

        top_class_name = self.get_top_class_name(detection)
        return [
            (class_name, score)
            for class_name, score in scores
            if class_name != top_class_name
        ]

    @staticmethod
    def get_detection_features(detection: Detection) -> np.ndarray:
        """Get extracted feature vector for one detection."""

        return detection.features

    def generate_spectrogram(
        self,
        audio: np.ndarray,
@@ -170,24 +299,41 @@ class BatDetect2API:
        tensor = torch.tensor(audio).unsqueeze(0)
        return self.preprocessor(tensor)

    def process_file(self, audio_file: str) -> BatDetect2Prediction:
    def process_file(
        self,
        audio_file: data.PathLike,
        batch_size: int | None = None,
    ) -> ClipDetections:
        recording = data.Recording.from_file(audio_file, compute_hash=False)
        wav = self.audio_loader.load_recording(recording)
        detections = self.process_audio(wav)
        return BatDetect2Prediction(

        predictions = self.process_files(
            [audio_file],
            batch_size=(
                batch_size
                if batch_size is not None
                else self.inference_config.loader.batch_size
            ),
        )
        detections = [
            detection
            for prediction in predictions
            for detection in prediction.detections
        ]

        return ClipDetections(
            clip=data.Clip(
                uuid=recording.uuid,
                recording=recording,
                start_time=0,
                end_time=recording.duration,
            ),
            predictions=detections,
            detections=detections,
        )

    def process_audio(
        self,
        audio: np.ndarray,
    ) -> List[RawPrediction]:
    ) -> list[Detection]:
        spec = self.generate_spectrogram(audio)
        return self.process_spectrogram(spec)

@@ -195,7 +341,7 @@ class BatDetect2API:
        self,
        spec: torch.Tensor,
        start_time: float = 0,
    ) -> List[RawPrediction]:
    ) -> list[Detection]:
        if spec.ndim == 4 and spec.shape[0] > 1:
            raise ValueError("Batched spectrograms not supported.")

@@ -204,59 +350,74 @@ class BatDetect2API:

        outputs = self.model.detector(spec)

        detections = self.model.postprocessor(
        detections = self.postprocessor(
            outputs,
            start_times=[start_time],
        )[0]

        return to_raw_predictions(detections.numpy(), targets=self.targets)
        return self.output_transform.to_detections(
            detections=detections,
            start_time=start_time,
        )

    def process_directory(
        self,
        audio_dir: data.PathLike,
    ) -> List[BatDetect2Prediction]:
    ) -> list[ClipDetections]:
        files = list(get_audio_files(audio_dir))
        return self.process_files(files)

    def process_files(
        self,
        audio_files: Sequence[data.PathLike],
        num_workers: Optional[int] = None,
    ) -> List[BatDetect2Prediction]:
        batch_size: int | None = None,
        num_workers: int = 0,
        audio_config: AudioConfig | None = None,
        inference_config: InferenceConfig | None = None,
        output_config: OutputsConfig | None = None,
    ) -> list[ClipDetections]:
        return process_file_list(
            self.model,
            audio_files,
            config=self.config,
            targets=self.targets,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            output_transform=self.output_transform,
            batch_size=batch_size,
            num_workers=num_workers,
            audio_config=audio_config or self.audio_config,
            inference_config=inference_config or self.inference_config,
            output_config=output_config or self.outputs_config,
        )

    def process_clips(
        self,
        clips: Sequence[data.Clip],
        batch_size: Optional[int] = None,
        num_workers: Optional[int] = None,
    ) -> List[BatDetect2Prediction]:
        batch_size: int | None = None,
        num_workers: int = 0,
        audio_config: AudioConfig | None = None,
        inference_config: InferenceConfig | None = None,
        output_config: OutputsConfig | None = None,
    ) -> list[ClipDetections]:
        return run_batch_inference(
            self.model,
            clips,
            targets=self.targets,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            config=self.config,
            output_transform=self.output_transform,
            batch_size=batch_size,
            num_workers=num_workers,
            audio_config=audio_config or self.audio_config,
            inference_config=inference_config or self.inference_config,
            output_config=output_config or self.outputs_config,
        )

    def save_predictions(
        self,
        predictions: Sequence[BatDetect2Prediction],
        predictions: Sequence[ClipDetections],
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
        format: Optional[str] = None,
        config: Optional[OutputFormatConfig] = None,
        audio_dir: data.PathLike | None = None,
        format: str | None = None,
        config: OutputFormatConfig | None = None,
    ):
        formatter = self.formatter

@@ -274,50 +435,78 @@ class BatDetect2API:
    def load_predictions(
        self,
        path: data.PathLike,
    ) -> List[BatDetect2Prediction]:
        return self.formatter.load(path)
        format: str | None = None,
        config: OutputFormatConfig | None = None,
    ) -> list[object]:
        formatter = self.formatter

        if format is not None or config is not None:
            format = format or config.name  # type: ignore
            formatter = get_output_formatter(
                name=format,
                targets=self.targets,
                config=config,
            )

        return formatter.load(path)

    @classmethod
    def from_config(
        cls,
        config: BatDetect2Config,
    ):
        targets = build_targets(config=config.targets)
    ) -> "BatDetect2API":
        targets = build_targets(config=config.model.targets)

        audio_loader = build_audio_loader(config=config.audio)

        preprocessor = build_preprocessor(
            input_samplerate=audio_loader.samplerate,
            config=config.preprocess,
            config=config.model.preprocess,
        )

        postprocessor = build_postprocessor(
            preprocessor,
            config=config.postprocess,
            config=config.model.postprocess,
        )

        evaluator = build_evaluator(config=config.evaluation, targets=targets)
        formatter = build_output_formatter(
            targets,
            config=config.outputs.format,
        )
        output_transform = build_output_transform(
            config=config.outputs.transform,
            targets=targets,
        )

        # NOTE: Better to have a separate instance of
        # preprocessor and postprocessor as these may be moved
        # to another device.
        evaluator = build_evaluator(
            config=config.evaluation,
            targets=targets,
            transform=output_transform,
        )

        # NOTE: Build separate instances of preprocessor and postprocessor
        # to avoid device mismatch errors
        model = build_model(
            config=config.model,
            targets=targets,
            targets=build_targets(config=config.model.targets),
            preprocessor=build_preprocessor(
                input_samplerate=audio_loader.samplerate,
                config=config.preprocess,
                config=config.model.preprocess,
            ),
            postprocessor=build_postprocessor(
                preprocessor,
                config=config.postprocess,
                config=config.model.postprocess,
            ),
        )

        formatter = build_output_formatter(targets, config=config.output)

        return cls(
            config=config,
            model_config=config.model,
            audio_config=config.audio,
            train_config=config.train,
            evaluation_config=config.evaluation,
            inference_config=config.inference,
            outputs_config=config.outputs,
            logging_config=config.logging,
            targets=targets,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
@@ -325,40 +514,83 @@ class BatDetect2API:
            evaluator=evaluator,
            model=model,
            formatter=formatter,
            output_transform=output_transform,
        )

    @classmethod
    def from_checkpoint(
        cls,
        path: data.PathLike,
        config: Optional[BatDetect2Config] = None,
    ):
        model, stored_config = load_model_from_checkpoint(path)
        targets_config: TargetConfig | None = None,
        audio_config: AudioConfig | None = None,
        train_config: TrainingConfig | None = None,
        evaluation_config: EvaluationConfig | None = None,
        inference_config: InferenceConfig | None = None,
        outputs_config: OutputsConfig | None = None,
        logging_config: AppLoggingConfig | None = None,
    ) -> "BatDetect2API":
        model, model_config = load_model_from_checkpoint(path)

        config = (
            merge_configs(stored_config, config) if config else stored_config
        audio_config = audio_config or AudioConfig(
            samplerate=model_config.samplerate,
        )
        train_config = train_config or TrainingConfig()
        evaluation_config = evaluation_config or EvaluationConfig()
        inference_config = inference_config or InferenceConfig()
        outputs_config = outputs_config or OutputsConfig()
        logging_config = logging_config or AppLoggingConfig()

        targets = build_targets(config=config.targets)
        if (
            targets_config is not None
            and targets_config != model_config.targets
        ):
            targets = build_targets(config=targets_config)
            model = build_model_with_new_targets(
                model=model,
                targets=targets,
            )
            model_config = model_config.model_copy(
                update={"targets": targets_config}
            )

        audio_loader = build_audio_loader(config=config.audio)
        targets = build_targets(config=model_config.targets)

        audio_loader = build_audio_loader(config=audio_config)

        preprocessor = build_preprocessor(
            input_samplerate=audio_loader.samplerate,
            config=config.preprocess,
            config=model_config.preprocess,
        )

        postprocessor = build_postprocessor(
            preprocessor,
            config=config.postprocess,
            config=model_config.postprocess,
        )

        evaluator = build_evaluator(config=config.evaluation, targets=targets)
        formatter = build_output_formatter(
            targets,
            config=outputs_config.format,
        )

        formatter = build_output_formatter(targets, config=config.output)
        output_transform = build_output_transform(
            config=outputs_config.transform,
            targets=targets,
        )

        evaluator = build_evaluator(
            config=evaluation_config,
            targets=targets,
            transform=output_transform,
        )

        return cls(
            config=config,
            model_config=model_config,
            audio_config=audio_config,
            train_config=train_config,
            evaluation_config=evaluation_config,
            inference_config=inference_config,
            outputs_config=outputs_config,
            logging_config=logging_config,
            targets=targets,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
@@ -366,4 +598,27 @@ class BatDetect2API:
            evaluator=evaluator,
            model=model,
            formatter=formatter,
            output_transform=output_transform,
        )

    def _set_trainable_parameters(
        self,
        trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
    ) -> None:
        detector = cast(Detector, self.model.detector)

        for parameter in detector.parameters():
            parameter.requires_grad = False

        if trainable == "all":
            for parameter in detector.parameters():
                parameter.requires_grad = True
            return

        if trainable in {"heads", "classifier_head"}:
            for parameter in detector.classifier_head.parameters():
                parameter.requires_grad = True

        if trainable in {"heads", "bbox_head"}:
            for parameter in detector.bbox_head.parameters():
                parameter.requires_grad = True
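The freeze-then-selectively-unfreeze pattern of `_set_trainable_parameters` above can be illustrated without torch; `Param` and `ToyDetector` below are toy stand-ins for torch parameters and the real `Detector`, not part of batdetect2:

```python
# Toy stand-ins (not the real library classes) demonstrating the
# freeze-all, then unfreeze-selected-heads pattern used above.
class Param:
    def __init__(self):
        self.requires_grad = True

class ToyDetector:
    def __init__(self):
        self.backbone = [Param(), Param()]
        self.classifier_head = [Param()]
        self.bbox_head = [Param()]

    def parameters(self):
        return self.backbone + self.classifier_head + self.bbox_head

def set_trainable(detector, trainable="heads"):
    # Freeze everything first, then re-enable the requested subset.
    for p in detector.parameters():
        p.requires_grad = False
    if trainable == "all":
        for p in detector.parameters():
            p.requires_grad = True
        return
    if trainable in {"heads", "classifier_head"}:
        for p in detector.classifier_head:
            p.requires_grad = True
    if trainable in {"heads", "bbox_head"}:
        for p in detector.bbox_head:
            p.requires_grad = True

det = ToyDetector()
set_trainable(det, "heads")
print([p.requires_grad for p in det.parameters()])  # [False, False, True, True]
```

With `"heads"`, the backbone stays frozen while both heads become trainable, which matches the default used by `finetune` above.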

@@ -5,8 +5,11 @@ from batdetect2.audio.loader import (
    SoundEventAudioLoader,
    build_audio_loader,
)
from batdetect2.audio.types import AudioLoader, ClipperProtocol

__all__ = [
    "AudioLoader",
    "ClipperProtocol",
    "TARGET_SAMPLERATE_HZ",
    "AudioConfig",
    "SoundEventAudioLoader",

@@ -1,4 +1,4 @@
from typing import Annotated, List, Literal, Optional, Union
from typing import Annotated, List, Literal

import numpy as np
from loguru import logger
@@ -6,8 +6,13 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds, intervals_overlap

from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipperProtocol
from batdetect2.audio.types import ClipperProtocol
from batdetect2.core import (
    BaseConfig,
    ImportConfig,
    Registry,
    add_import_config,
)

DEFAULT_TRAIN_CLIP_DURATION = 0.256
DEFAULT_MAX_EMPTY_CLIP = 0.1

@@ -16,12 +21,24 @@ DEFAULT_MAX_EMPTY_CLIP = 0.1
__all__ = [
    "build_clipper",
    "ClipConfig",
    "ClipperImportConfig",
]


clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")


@add_import_config(clipper_registry)
class ClipperImportConfig(ImportConfig):
    """Use any callable as a clipper.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class RandomClipConfig(BaseConfig):
    name: Literal["random_subclip"] = "random_subclip"
    duration: float = DEFAULT_TRAIN_CLIP_DURATION
@@ -245,16 +262,12 @@ class FixedDurationClip:


ClipConfig = Annotated[
    Union[
        RandomClipConfig,
        PaddedClipConfig,
        FixedDurationClipConfig,
    ],
    RandomClipConfig | PaddedClipConfig | FixedDurationClipConfig,
    Field(discriminator="name"),
]


def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
def build_clipper(config: ClipConfig | None = None) -> ClipperProtocol:
    config = config or RandomClipConfig()

    logger.opt(lazy=True).debug(

@@ -1,5 +1,3 @@
from typing import Optional

import numpy as np
from numpy.typing import DTypeLike
from pydantic import Field
@@ -7,8 +5,8 @@ from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError

from batdetect2.audio.types import AudioLoader
from batdetect2.core import BaseConfig
from batdetect2.typing import AudioLoader

__all__ = [
    "SoundEventAudioLoader",
@@ -28,15 +26,17 @@ class ResampleConfig(BaseConfig):

    Attributes
    ----------
    samplerate : int, default=256000
        The target sample rate in Hz to resample the audio to. Must be > 0.
    enabled : bool, default=True
        Whether to resample the audio to the target sample rate. If
        ``False``, the audio is returned at its original sample rate.
    method : str, default="poly"
        The resampling algorithm to use. Options:
        - "poly": Polyphase resampling using `scipy.signal.resample_poly`.
          Generally fast.
        - "fourier": Resampling via Fourier method using
          `scipy.signal.resample`. May handle non-integer
          resampling factors differently.

        - ``"poly"``: Polyphase resampling via
          ``scipy.signal.resample_poly``. Generally fast and accurate.
        - ``"fourier"``: FFT-based resampling via
          ``scipy.signal.resample``. May be preferred for non-integer
          resampling ratios.
    """

    enabled: bool = True
@@ -50,7 +50,7 @@ class AudioConfig(BaseConfig):
    resample: ResampleConfig = Field(default_factory=ResampleConfig)


def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
def build_audio_loader(config: AudioConfig | None = None) -> AudioLoader:
    """Factory function to create an AudioLoader based on configuration."""
    config = config or AudioConfig()
    return SoundEventAudioLoader(
@@ -65,7 +65,7 @@ class SoundEventAudioLoader(AudioLoader):
    def __init__(
        self,
        samplerate: int = TARGET_SAMPLERATE_HZ,
        config: Optional[ResampleConfig] = None,
        config: ResampleConfig | None = None,
    ):
        self.samplerate = samplerate
        self.config = config or ResampleConfig()
@@ -73,7 +73,7 @@ class SoundEventAudioLoader(AudioLoader):
    def load_file(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load and preprocess audio directly from a file path."""
        return load_file_audio(
@@ -86,7 +86,7 @@ class SoundEventAudioLoader(AudioLoader):
    def load_recording(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load and preprocess the entire audio for a Recording object."""
        return load_recording_audio(
@@ -99,7 +99,7 @@ class SoundEventAudioLoader(AudioLoader):
    def load_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
        audio_dir: data.PathLike | None = None,
    ) -> np.ndarray:
        """Load and preprocess the audio segment defined by a Clip object."""
        return load_clip_audio(
@@ -112,10 +112,10 @@ class SoundEventAudioLoader(AudioLoader):

def load_file_audio(
    path: data.PathLike,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
    samplerate: int | None = None,
    config: ResampleConfig | None = None,
    audio_dir: data.PathLike | None = None,
    dtype: DTypeLike = np.float32,
) -> np.ndarray:
    """Load and preprocess audio from a file path using specified config."""
    try:
@@ -136,10 +136,10 @@ def load_file_audio(

def load_recording_audio(
    recording: data.Recording,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
    samplerate: int | None = None,
    config: ResampleConfig | None = None,
    audio_dir: data.PathLike | None = None,
    dtype: DTypeLike = np.float32,
) -> np.ndarray:
    """Load and preprocess the entire audio content of a recording using config."""
    clip = data.Clip(
@@ -158,10 +158,10 @@ def load_recording_audio(

def load_clip_audio(
    clip: data.Clip,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
    samplerate: int | None = None,
    config: ResampleConfig | None = None,
    audio_dir: data.PathLike | None = None,
    dtype: DTypeLike = np.float32,
) -> np.ndarray:
    """Load and preprocess a specific audio clip segment based on config."""
    try:
@@ -194,7 +194,31 @@ def resample_audio(
    samplerate: int = TARGET_SAMPLERATE_HZ,
    method: str = "poly",
) -> np.ndarray:
    """Resample an audio waveform DataArray to a target sample rate."""
    """Resample an audio waveform to a target sample rate.

    Parameters
    ----------
    wav : np.ndarray
        Input waveform array. The last axis is assumed to be time.
    sr : int
        Original sample rate of ``wav`` in Hz.
    samplerate : int, default=256000
        Target sample rate in Hz.
    method : str, default="poly"
        Resampling algorithm: ``"poly"`` (polyphase) or
        ``"fourier"`` (FFT-based).

    Returns
    -------
    np.ndarray
        Resampled waveform. If ``sr == samplerate`` the input array is
        returned unchanged.

    Raises
    ------
    NotImplementedError
        If ``method`` is not ``"poly"`` or ``"fourier"``.
    """
    if sr == samplerate:
        return wav

@@ -264,7 +288,7 @@ def resample_audio_fourier(
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    """Resample a numpy array using `scipy.signal.resample`.
    """Resample a numpy array using ``scipy.signal.resample``.

    This method uses FFTs to resample the signal.

@@ -272,23 +296,20 @@ def resample_audio_fourier(
    ----------
    array : np.ndarray
        The input array to resample.
    num : int
        The desired number of samples in the output array along `axis`.
    sr_orig : int
        The original sample rate in Hz.
    sr_new : int
        The target sample rate in Hz.
    axis : int, default=-1
        The axis of `array` along which to resample.
        The axis of ``array`` along which to resample.

    Returns
    -------
    np.ndarray
        The array resampled to have `num` samples along `axis`.

    Raises
    ------
    ValueError
        If `num` is negative.
        The array resampled to the target sample rate.
    """
    ratio = sr_new / sr_orig
    return resample(  # type: ignore
    return resample(
        array,
        int(array.shape[axis] * ratio),
        axis=axis,
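The output-length computation used by the Fourier path above (`int(array.shape[axis] * ratio)`) can be checked with plain numbers; the helper name and sample rates below are arbitrary examples, not library API:

```python
# Sketch of the Fourier-resampling length computation shown above:
# the output carries n_samples * (sr_new / sr_orig) samples.
def resampled_length(n_samples: int, sr_orig: int, sr_new: int) -> int:
    ratio = sr_new / sr_orig
    return int(n_samples * ratio)

print(resampled_length(256000, 256000, 128000))  # 128000: half the rate, half the samples
```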
40
src/batdetect2/audio/types.py
Normal file
40
src/batdetect2/audio/types.py
Normal file
@ -0,0 +1,40 @@
|
||||
from typing import Protocol
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
"AudioLoader",
|
||||
"ClipperProtocol",
|
||||
]
|
||||
|
||||
|
||||
class AudioLoader(Protocol):
|
||||
samplerate: int
|
||||
|
||||
def load_file(
|
||||
self,
|
||||
path: data.PathLike,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> np.ndarray: ...
|
||||
|
||||
def load_recording(
|
||||
self,
|
||||
recording: data.Recording,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> np.ndarray: ...
|
||||
|
||||
def load_clip(
|
||||
self,
|
||||
clip: data.Clip,
|
||||
audio_dir: data.PathLike | None = None,
|
||||
) -> np.ndarray: ...
|
||||
|
||||
|
||||
class ClipperProtocol(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
) -> data.ClipAnnotation: ...
|
||||
|
||||
def get_subclip(self, clip: data.Clip) -> data.Clip: ...
|
||||
@@ -2,6 +2,7 @@ from batdetect2.cli.base import cli
 from batdetect2.cli.compat import detect
 from batdetect2.cli.data import data
 from batdetect2.cli.evaluate import evaluate_command
+from batdetect2.cli.inference import predict
 from batdetect2.cli.train import train_command

 __all__ = [
@@ -10,6 +11,7 @@ __all__ = [
     "data",
     "train_command",
     "evaluate_command",
+    "predict",
 ]
@@ -1,5 +1,4 @@
 from pathlib import Path
-from typing import Optional

 import click

@@ -34,9 +33,9 @@ def data(): ...
 )
 def summary(
     dataset_config: Path,
-    field: Optional[str] = None,
-    targets_path: Optional[Path] = None,
-    base_dir: Optional[Path] = None,
+    field: str | None = None,
+    targets_path: Path | None = None,
+    base_dir: Path | None = None,
 ):
     from batdetect2.data import compute_class_summary, load_dataset_from_config
     from batdetect2.targets import load_targets

@@ -83,9 +82,9 @@ def summary(
 )
 def convert(
     dataset_config: Path,
-    field: Optional[str] = None,
+    field: str | None = None,
     output: Path = Path("annotations.json"),
-    base_dir: Optional[Path] = None,
+    base_dir: Path | None = None,
 ):
     """Convert a dataset config file to soundevent format."""
     from soundevent import data, io
@@ -1,5 +1,4 @@
 from pathlib import Path
-from typing import Optional

 import click
 from loguru import logger

@@ -13,9 +12,14 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"


 @cli.command(name="evaluate")
-@click.argument("model-path", type=click.Path(exists=True))
+@click.argument("model_path", type=click.Path(exists=True))
 @click.argument("test_dataset", type=click.Path(exists=True))
-@click.option("--config", "config_path", type=click.Path())
+@click.option("--targets", "targets_config", type=click.Path(exists=True))
+@click.option("--audio-config", type=click.Path(exists=True))
+@click.option("--evaluation-config", type=click.Path(exists=True))
+@click.option("--inference-config", type=click.Path(exists=True))
+@click.option("--outputs-config", type=click.Path(exists=True))
+@click.option("--logging-config", type=click.Path(exists=True))
 @click.option("--base-dir", type=click.Path(), default=Path.cwd())
 @click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
 @click.option("--experiment-name", type=str)

@@ -25,15 +29,25 @@ def evaluate_command(
     model_path: Path,
     test_dataset: Path,
     base_dir: Path,
-    config_path: Optional[Path],
+    targets_config: Path | None,
+    audio_config: Path | None,
+    evaluation_config: Path | None,
+    inference_config: Path | None,
+    outputs_config: Path | None,
+    logging_config: Path | None,
     output_dir: Path = DEFAULT_OUTPUT_DIR,
-    num_workers: Optional[int] = None,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
+    num_workers: int = 0,
+    experiment_name: str | None = None,
+    run_name: str | None = None,
 ):
     from batdetect2.api_v2 import BatDetect2API
-    from batdetect2.config import load_full_config
+    from batdetect2.audio import AudioConfig
     from batdetect2.data import load_dataset_from_config
+    from batdetect2.evaluate import EvaluationConfig
+    from batdetect2.inference import InferenceConfig
+    from batdetect2.logging import AppLoggingConfig
+    from batdetect2.outputs import OutputsConfig
+    from batdetect2.targets import TargetConfig

     logger.info("Initiating evaluation process...")

@@ -47,11 +61,44 @@ def evaluate_command(
         num_annotations=len(test_annotations),
     )

-    config = None
-    if config_path is not None:
-        config = load_full_config(config_path)
+    target_conf = (
+        TargetConfig.load(targets_config)
+        if targets_config is not None
+        else None
+    )
+    audio_conf = (
+        AudioConfig.load(audio_config) if audio_config is not None else None
+    )
+    eval_conf = (
+        EvaluationConfig.load(evaluation_config)
+        if evaluation_config is not None
+        else None
+    )
+    inference_conf = (
+        InferenceConfig.load(inference_config)
+        if inference_config is not None
+        else None
+    )
+    outputs_conf = (
+        OutputsConfig.load(outputs_config)
+        if outputs_config is not None
+        else None
+    )
+    logging_conf = (
+        AppLoggingConfig.load(logging_config)
+        if logging_config is not None
+        else None
+    )

-    api = BatDetect2API.from_checkpoint(model_path, config=config)
+    api = BatDetect2API.from_checkpoint(
+        model_path,
+        targets_config=target_conf,
+        audio_config=audio_conf,
+        evaluation_config=eval_conf,
+        inference_config=inference_conf,
+        outputs_config=outputs_conf,
+        logging_config=logging_conf,
+    )

     api.evaluate(
         test_annotations,
src/batdetect2/cli/inference.py (new file, 231 lines)
@@ -0,0 +1,231 @@
+from pathlib import Path
+
+import click
+from loguru import logger
+from soundevent import io
+from soundevent.audio.files import get_audio_files
+
+from batdetect2.cli.base import cli
+
+__all__ = ["predict"]
+
+
+@cli.group(name="predict")
+def predict() -> None:
+    """Run prediction with BatDetect2 API v2."""
+
+
+def _build_api(
+    model_path: Path,
+    audio_config: Path | None,
+    inference_config: Path | None,
+    outputs_config: Path | None,
+    logging_config: Path | None,
+):
+    from batdetect2.api_v2 import BatDetect2API
+    from batdetect2.audio import AudioConfig
+    from batdetect2.inference import InferenceConfig
+    from batdetect2.logging import AppLoggingConfig
+    from batdetect2.outputs import OutputsConfig
+
+    audio_conf = (
+        AudioConfig.load(audio_config) if audio_config is not None else None
+    )
+    inference_conf = (
+        InferenceConfig.load(inference_config)
+        if inference_config is not None
+        else None
+    )
+    outputs_conf = (
+        OutputsConfig.load(outputs_config)
+        if outputs_config is not None
+        else None
+    )
+    logging_conf = (
+        AppLoggingConfig.load(logging_config)
+        if logging_config is not None
+        else None
+    )
+
+    api = BatDetect2API.from_checkpoint(
+        model_path,
+        audio_config=audio_conf,
+        inference_config=inference_conf,
+        outputs_config=outputs_conf,
+        logging_config=logging_conf,
+    )
+    return api, audio_conf, inference_conf, outputs_conf
+
+
+def _run_prediction(
+    model_path: Path,
+    audio_files: list[Path],
+    output_path: Path,
+    audio_config: Path | None,
+    inference_config: Path | None,
+    outputs_config: Path | None,
+    logging_config: Path | None,
+    batch_size: int | None,
+    num_workers: int,
+    format_name: str | None,
+) -> None:
+    logger.info("Initiating prediction process...")
+
+    api, audio_conf, inference_conf, outputs_conf = _build_api(
+        model_path,
+        audio_config,
+        inference_config,
+        outputs_config,
+        logging_config,
+    )
+
+    logger.info("Found {num_files} audio files", num_files=len(audio_files))
+
+    predictions = api.process_files(
+        audio_files,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        audio_config=audio_conf,
+        inference_config=inference_conf,
+        output_config=outputs_conf,
+    )
+
+    common_path = audio_files[0].parent if audio_files else None
+    api.save_predictions(
+        predictions,
+        path=output_path,
+        audio_dir=common_path,
+        format=format_name,
+    )
+
+    logger.info(
+        "Prediction complete. Results saved to {path}", path=output_path
+    )
+
+
+@predict.command(name="directory")
+@click.argument("model_path", type=click.Path(exists=True))
+@click.argument("audio_dir", type=click.Path(exists=True))
+@click.argument("output_path", type=click.Path())
+@click.option("--audio-config", type=click.Path(exists=True))
+@click.option("--inference-config", type=click.Path(exists=True))
+@click.option("--outputs-config", type=click.Path(exists=True))
+@click.option("--logging-config", type=click.Path(exists=True))
+@click.option("--batch-size", type=int)
+@click.option("--workers", "num_workers", type=int, default=0)
+@click.option("--format", "format_name", type=str)
+def predict_directory_command(
+    model_path: Path,
+    audio_dir: Path,
+    output_path: Path,
+    audio_config: Path | None,
+    inference_config: Path | None,
+    outputs_config: Path | None,
+    logging_config: Path | None,
+    batch_size: int | None,
+    num_workers: int,
+    format_name: str | None,
+) -> None:
+    audio_files = list(get_audio_files(audio_dir))
+    _run_prediction(
+        model_path=model_path,
+        audio_files=audio_files,
+        output_path=output_path,
+        audio_config=audio_config,
+        inference_config=inference_config,
+        outputs_config=outputs_config,
+        logging_config=logging_config,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        format_name=format_name,
+    )
+
+
+@predict.command(name="file_list")
+@click.argument("model_path", type=click.Path(exists=True))
+@click.argument("file_list", type=click.Path(exists=True))
+@click.argument("output_path", type=click.Path())
+@click.option("--audio-config", type=click.Path(exists=True))
+@click.option("--inference-config", type=click.Path(exists=True))
+@click.option("--outputs-config", type=click.Path(exists=True))
+@click.option("--logging-config", type=click.Path(exists=True))
+@click.option("--batch-size", type=int)
+@click.option("--workers", "num_workers", type=int, default=0)
+@click.option("--format", "format_name", type=str)
+def predict_file_list_command(
+    model_path: Path,
+    file_list: Path,
+    output_path: Path,
+    audio_config: Path | None,
+    inference_config: Path | None,
+    outputs_config: Path | None,
+    logging_config: Path | None,
+    batch_size: int | None,
+    num_workers: int,
+    format_name: str | None,
+) -> None:
+    file_list = Path(file_list)
+    audio_files = [
+        Path(line.strip())
+        for line in file_list.read_text().splitlines()
+        if line.strip()
+    ]
+
+    _run_prediction(
+        model_path=model_path,
+        audio_files=audio_files,
+        output_path=output_path,
+        audio_config=audio_config,
+        inference_config=inference_config,
+        outputs_config=outputs_config,
+        logging_config=logging_config,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        format_name=format_name,
+    )
+
+
+@predict.command(name="dataset")
+@click.argument("model_path", type=click.Path(exists=True))
+@click.argument("dataset_path", type=click.Path(exists=True))
+@click.argument("output_path", type=click.Path())
+@click.option("--audio-config", type=click.Path(exists=True))
+@click.option("--inference-config", type=click.Path(exists=True))
+@click.option("--outputs-config", type=click.Path(exists=True))
+@click.option("--logging-config", type=click.Path(exists=True))
+@click.option("--batch-size", type=int)
+@click.option("--workers", "num_workers", type=int, default=0)
+@click.option("--format", "format_name", type=str)
+def predict_dataset_command(
+    model_path: Path,
+    dataset_path: Path,
+    output_path: Path,
+    audio_config: Path | None,
+    inference_config: Path | None,
+    outputs_config: Path | None,
+    logging_config: Path | None,
+    batch_size: int | None,
+    num_workers: int,
+    format_name: str | None,
+) -> None:
+    dataset_path = Path(dataset_path)
+    dataset = io.load(dataset_path, type="annotation_set")
+    audio_files = sorted(
+        {
+            Path(clip_annotation.clip.recording.path)
+            for clip_annotation in dataset.clip_annotations
+        }
+    )
+
+    _run_prediction(
+        model_path=model_path,
+        audio_files=audio_files,
+        output_path=output_path,
+        audio_config=audio_config,
+        inference_config=inference_config,
+        outputs_config=outputs_config,
+        logging_config=logging_config,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        format_name=format_name,
+    )
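`predict_file_list_command` above parses its input as one audio path per line, skipping blanks and stray whitespace. That parsing step in isolation, with a throwaway file for illustration (the file contents are made up):

```python
import tempfile
from pathlib import Path


def read_file_list(file_list: Path) -> list[Path]:
    # Mirrors the comprehension in predict_file_list_command: one path per
    # line, blank lines dropped, surrounding whitespace stripped.
    return [
        Path(line.strip())
        for line in file_list.read_text().splitlines()
        if line.strip()
    ]


with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("recordings/a.wav\n\n  recordings/b.wav  \n")
    name = f.name

files = read_file_list(Path(name))
```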
@@ -1,5 +1,4 @@
 from pathlib import Path
-from typing import Optional

 import click
 from loguru import logger

@@ -14,10 +13,15 @@ __all__ = ["train_command"]
 @click.option("--val-dataset", type=click.Path(exists=True))
 @click.option("--model", "model_path", type=click.Path(exists=True))
 @click.option("--targets", "targets_config", type=click.Path(exists=True))
+@click.option("--model-config", type=click.Path(exists=True))
+@click.option("--training-config", type=click.Path(exists=True))
+@click.option("--audio-config", type=click.Path(exists=True))
+@click.option("--evaluation-config", type=click.Path(exists=True))
+@click.option("--inference-config", type=click.Path(exists=True))
+@click.option("--outputs-config", type=click.Path(exists=True))
+@click.option("--logging-config", type=click.Path(exists=True))
 @click.option("--ckpt-dir", type=click.Path(exists=True))
 @click.option("--log-dir", type=click.Path(exists=True))
-@click.option("--config", type=click.Path(exists=True))
-@click.option("--config-field", type=str)
 @click.option("--train-workers", type=int)
 @click.option("--val-workers", type=int)
 @click.option("--num-epochs", type=int)

@@ -26,42 +30,82 @@ __all__ = ["train_command"]
 @click.option("--seed", type=int)
 def train_command(
     train_dataset: Path,
-    val_dataset: Optional[Path] = None,
-    model_path: Optional[Path] = None,
-    ckpt_dir: Optional[Path] = None,
-    log_dir: Optional[Path] = None,
-    config: Optional[Path] = None,
-    targets_config: Optional[Path] = None,
-    config_field: Optional[str] = None,
-    seed: Optional[int] = None,
-    num_epochs: Optional[int] = None,
+    val_dataset: Path | None = None,
+    model_path: Path | None = None,
+    ckpt_dir: Path | None = None,
+    log_dir: Path | None = None,
+    targets_config: Path | None = None,
+    model_config: Path | None = None,
+    training_config: Path | None = None,
+    audio_config: Path | None = None,
+    evaluation_config: Path | None = None,
+    inference_config: Path | None = None,
+    outputs_config: Path | None = None,
+    logging_config: Path | None = None,
+    seed: int | None = None,
+    num_epochs: int | None = None,
     train_workers: int = 0,
     val_workers: int = 0,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
+    experiment_name: str | None = None,
+    run_name: str | None = None,
 ):
     from batdetect2.api_v2 import BatDetect2API
-    from batdetect2.config import (
-        BatDetect2Config,
-        load_full_config,
-    )
+    from batdetect2.audio import AudioConfig
+    from batdetect2.config import BatDetect2Config
     from batdetect2.data import load_dataset_from_config
-    from batdetect2.targets import load_target_config
+    from batdetect2.evaluate import EvaluationConfig
+    from batdetect2.inference import InferenceConfig
+    from batdetect2.logging import AppLoggingConfig
+    from batdetect2.models import ModelConfig
+    from batdetect2.outputs import OutputsConfig
+    from batdetect2.targets import TargetConfig
+    from batdetect2.train import TrainingConfig

     logger.info("Initiating training process...")

-    logger.info("Loading configuration...")
-    conf = (
-        load_full_config(config, field=config_field)
-        if config is not None
-        else BatDetect2Config()
-    )
+    target_conf = (
+        TargetConfig.load(targets_config)
+        if targets_config is not None
+        else None
+    )
+    model_conf = (
+        ModelConfig.load(model_config) if model_config is not None else None
+    )
+    train_conf = (
+        TrainingConfig.load(training_config)
+        if training_config is not None
+        else None
+    )
+    audio_conf = (
+        AudioConfig.load(audio_config) if audio_config is not None else None
+    )
+    eval_conf = (
+        EvaluationConfig.load(evaluation_config)
+        if evaluation_config is not None
+        else None
+    )
+    inference_conf = (
+        InferenceConfig.load(inference_config)
+        if inference_config is not None
+        else None
+    )
+    outputs_conf = (
+        OutputsConfig.load(outputs_config)
+        if outputs_config is not None
+        else None
+    )
+    logging_conf = (
+        AppLoggingConfig.load(logging_config)
+        if logging_config is not None
+        else None
+    )

-    if targets_config is not None:
-        logger.info("Loading targets configuration...")
-        conf = conf.model_copy(
-            update=dict(targets=load_target_config(targets_config))
-        )
+    if target_conf is not None:
+        logger.info("Loaded targets configuration.")
+
+    if model_conf is not None and target_conf is not None:
+        model_conf = model_conf.model_copy(update={"targets": target_conf})

     logger.info("Loading training dataset...")
     train_annotations = load_dataset_from_config(train_dataset)

@@ -82,12 +126,43 @@ def train_command(

     logger.info("Configuration and data loaded. Starting training...")

+    if model_path is not None and model_conf is not None:
+        raise click.UsageError(
+            "--model-config cannot be used with --model. "
+            "Checkpoint model configuration is loaded from the checkpoint."
+        )
+
+    if model_path is None:
+        conf = BatDetect2Config()
+        if model_conf is not None:
+            conf.model = model_conf
+        elif target_conf is not None:
+            conf.model = conf.model.model_copy(update={"targets": target_conf})
+
+        if train_conf is not None:
+            conf.train = train_conf
+        if audio_conf is not None:
+            conf.audio = audio_conf
+        if eval_conf is not None:
+            conf.evaluation = eval_conf
+        if inference_conf is not None:
+            conf.inference = inference_conf
+        if outputs_conf is not None:
+            conf.outputs = outputs_conf
+        if logging_conf is not None:
+            conf.logging = logging_conf
+
+        api = BatDetect2API.from_config(conf)
+    else:
+        api = BatDetect2API.from_checkpoint(
+            model_path,
+            config=conf if config is not None else None,
+            targets_config=target_conf,
+            train_config=train_conf,
+            audio_config=audio_conf,
+            evaluation_config=eval_conf,
+            inference_config=inference_conf,
+            outputs_config=outputs_conf,
+            logging_config=logging_conf,
+        )

     return api.train(
@@ -4,7 +4,7 @@ import json
 import os
 import uuid
 from pathlib import Path
-from typing import Callable, List, Optional, Union
+from typing import Callable, List

 import numpy as np
 from soundevent import data

@@ -17,7 +17,7 @@ from batdetect2.types import (
     FileAnnotation,
 )

-PathLike = Union[Path, str, os.PathLike]
+PathLike = Path | str | os.PathLike

 __all__ = [
     "convert_to_annotation_group",

@@ -33,7 +33,7 @@ UNKNOWN_CLASS = "__UNKNOWN__"
 NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")


-EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
+EventFn = Callable[[data.SoundEventAnnotation], str | None]

 ClassFn = Callable[[data.Recording], int]

@@ -103,17 +103,17 @@ def convert_to_annotation_group(
             y_inds.append(0)

         annotations.append(
-            {
-                "start_time": start_time,
-                "end_time": end_time,
-                "low_freq": low_freq,
-                "high_freq": high_freq,
-                "class_prob": 1.0,
-                "det_prob": 1.0,
-                "individual": "0",
-                "event": event,
-                "class_id": class_id,  # type: ignore
-            }
+            Annotation(
+                start_time=start_time,
+                end_time=end_time,
+                low_freq=low_freq,
+                high_freq=high_freq,
+                class_prob=1.0,
+                det_prob=1.0,
+                individual="0",
+                event=event,
+                class_id=class_id,
+            )
         )

     return {

@@ -221,7 +221,7 @@ def annotation_to_sound_event_prediction(

 def file_annotation_to_clip(
     file_annotation: FileAnnotation,
-    audio_dir: Optional[PathLike] = None,
+    audio_dir: PathLike | None = None,
     label_key: str = "class",
 ) -> data.Clip:
     """Convert file annotation to recording."""
@@ -1,28 +1,20 @@
-from typing import Literal, Optional
+from typing import Literal

 from pydantic import Field
-from soundevent.data import PathLike

-from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.data.predictions import OutputFormatConfig
-from batdetect2.data.predictions.raw import RawOutputConfig
+from batdetect2.audio import AudioConfig
+from batdetect2.core.configs import BaseConfig
 from batdetect2.evaluate.config import (
     EvaluationConfig,
     get_default_eval_config,
 )
 from batdetect2.inference.config import InferenceConfig
-from batdetect2.models.config import BackboneConfig
-from batdetect2.postprocess.config import PostprocessConfig
-from batdetect2.preprocess.config import PreprocessingConfig
-from batdetect2.targets.config import TargetConfig
+from batdetect2.logging import AppLoggingConfig
+from batdetect2.models import ModelConfig
+from batdetect2.outputs import OutputsConfig
 from batdetect2.train.config import TrainingConfig

-__all__ = [
-    "BatDetect2Config",
-    "load_full_config",
-    "validate_config",
-]
+__all__ = ["BatDetect2Config"]


 class BatDetect2Config(BaseConfig):
@@ -32,26 +24,8 @@ class BatDetect2Config(BaseConfig):
     evaluation: EvaluationConfig = Field(
         default_factory=get_default_eval_config
     )
-    model: BackboneConfig = Field(default_factory=BackboneConfig)
-    preprocess: PreprocessingConfig = Field(
-        default_factory=PreprocessingConfig
-    )
-    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
+    model: ModelConfig = Field(default_factory=ModelConfig)
+    audio: AudioConfig = Field(default_factory=AudioConfig)
     targets: TargetConfig = Field(default_factory=TargetConfig)
     inference: InferenceConfig = Field(default_factory=InferenceConfig)
-    output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
-
-
-def validate_config(config: Optional[dict]) -> BatDetect2Config:
-    if config is None:
-        return BatDetect2Config()
-
-    return BatDetect2Config.model_validate(config)
-
-
-def load_full_config(
-    path: PathLike,
-    field: Optional[str] = None,
-) -> BatDetect2Config:
-    return load_config(path, schema=BatDetect2Config, field=field)
+    outputs: OutputsConfig = Field(default_factory=OutputsConfig)
+    logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
@@ -1,8 +1,14 @@
 from batdetect2.core.configs import BaseConfig, load_config, merge_configs
-from batdetect2.core.registries import Registry
+from batdetect2.core.registries import (
+    ImportConfig,
+    Registry,
+    add_import_config,
+)

 __all__ = [
+    "add_import_config",
     "BaseConfig",
+    "ImportConfig",
     "load_config",
     "Registry",
     "merge_configs",
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import numpy as np
 import torch
 import xarray as xr

@@ -86,8 +84,8 @@ def adjust_width(

 def slice_tensor(
     tensor: torch.Tensor,
-    start: Optional[int] = None,
-    end: Optional[int] = None,
+    start: int | None = None,
+    end: int | None = None,
     dim: int = -1,
 ) -> torch.Tensor:
     slices = [slice(None)] * tensor.ndim
@@ -8,11 +8,11 @@ configuration data from files, with optional support for accessing nested
 configuration sections.
 """

-from typing import Any, Optional, Type, TypeVar
+from typing import Any, Literal, Type, TypeVar, overload

 import yaml
 from deepmerge.merger import Merger
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, TypeAdapter
 from soundevent.data import PathLike

 __all__ = [

@@ -21,6 +21,8 @@ __all__ = [
     "merge_configs",
 ]

+C = TypeVar("C", bound="BaseConfig")
+

 class BaseConfig(BaseModel):
     """Base class for all configuration models in BatDetect2.

@@ -62,8 +64,30 @@ class BaseConfig(BaseModel):
         )
     )

     @classmethod
     def from_yaml(cls, yaml_str: str):
         return cls.model_validate(yaml.safe_load(yaml_str))

-T = TypeVar("T", bound=BaseModel)
+    @classmethod
+    def load(
+        cls: Type[C],
+        path: PathLike,
+        field: str | None = None,
+        extra: Literal["ignore", "allow", "forbid"] | None = None,
+        strict: bool | None = None,
+    ) -> C:
+        return load_config(
+            path,
+            schema=cls,
+            field=field,
+            extra=extra,
+            strict=strict,
+        )
+
+
+T = TypeVar("T")
+T_Model = TypeVar("T_Model", bound=BaseModel)
+Schema = Type[T_Model] | TypeAdapter[T]


 def get_object_field(obj: dict, current_key: str) -> Any:

@@ -125,35 +149,69 @@ def get_object_field(obj: dict, current_key: str) -> Any:
     return get_object_field(subobj, rest)


+@overload
 def load_config(
     path: PathLike,
-    schema: Type[T],
-    field: Optional[str] = None,
-) -> T:
+    schema: Type[T_Model],
+    field: str | None = None,
+    extra: Literal["ignore", "allow", "forbid"] | None = None,
+    strict: bool | None = None,
+) -> T_Model: ...
+
+
+@overload
+def load_config(
+    path: PathLike,
+    schema: TypeAdapter[T],
+    field: str | None = None,
+    extra: Literal["ignore", "allow", "forbid"] | None = None,
+    strict: bool | None = None,
+) -> T: ...
+
+
+def load_config(
+    path: PathLike,
+    schema: Type[T_Model] | TypeAdapter[T],
+    field: str | None = None,
+    extra: Literal["ignore", "allow", "forbid"] | None = None,
+    strict: bool | None = None,
+) -> T_Model | T:
     """Load and validate configuration data from a file against a schema.

     Reads a YAML file, optionally extracts a specific section using dot
     notation, and then validates the resulting data against the provided
-    Pydantic `schema`.
+    Pydantic schema.

     Parameters
     ----------
     path : PathLike
         The path to the configuration file (typically `.yaml`).
-    schema : Type[T]
-        The Pydantic `BaseModel` subclass that defines the expected structure
-        and types for the configuration data.
+    schema : Type[T_Model] | TypeAdapter[T]
+        Either a Pydantic `BaseModel` subclass or a `TypeAdapter` instance
+        that defines the expected structure and types for the configuration
+        data.
     field : str, optional
         A dot-separated string indicating a nested section within the YAML
         file to extract before validation. If None (default), the entire
         file content is validated against the schema.
         Example: `"training.optimizer"` would extract the `optimizer` section
         within the `training` section.
+    extra : Literal["ignore", "allow", "forbid"], optional
+        How to handle extra keys in the configuration data. If None (default),
+        the default behaviour of the schema is used. If "ignore", extra keys
+        are ignored. If "allow", extra keys are allowed and will be accessible
+        as attributes on the resulting model instance. If "forbid", extra
+        keys are forbidden and an exception is raised. See pydantic
+        documentation for more details.
+    strict : bool, optional
+        Whether to enforce types strictly. If None (default), the default
+        behaviour of the schema is used. See pydantic documentation for more
+        details.

     Returns
     -------
-    T
-        An instance of the provided `schema`, populated and validated with
+    T_Model | T
+        An instance of the schema type, populated and validated with
         data from the configuration file.

     Raises

@@ -179,7 +237,10 @@ def load_config(
     if field:
         config = get_object_field(config, field)

-    return schema.model_validate(config or {})
+    if isinstance(schema, TypeAdapter):
+        return schema.validate_python(config or {}, extra=extra, strict=strict)
+
+    return schema.model_validate(config or {}, extra=extra, strict=strict)

@@ -189,7 +250,7 @@ default_merger = Merger(
 )


-def merge_configs(config1: T, config2: T) -> T:
+def merge_configs(config1: T_Model, config2: T_Model) -> T_Model:
     """Merge two configuration objects."""
     model = type(config1)
     dict1 = config1.model_dump()
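`load_config`'s `field` parameter walks dot-separated keys before validation, as in the `"training.optimizer"` example from the docstring above. A standalone sketch of that lookup (the repo's `get_object_field` may differ in details such as error handling):

```python
from typing import Any


def get_object_field(obj: dict, field: str) -> Any:
    # Walk one dot-separated segment at a time: "training.optimizer" first
    # selects obj["training"], then "optimizer" within that sub-dict.
    current, _, rest = field.partition(".")
    value = obj[current]
    if not rest:
        return value
    return get_object_field(value, rest)


config = {"training": {"optimizer": {"name": "adam", "lr": 1e-3}}}
section = get_object_field(config, "training.optimizer")
```

The extracted `section` is then what gets handed to the pydantic schema (or `TypeAdapter`) for validation.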
@@ -1,23 +1,28 @@
import sys
from typing import Callable, Dict, Generic, Tuple, Type, TypeVar
from typing import (
    Any,
    Callable,
    Concatenate,
    Generic,
    ParamSpec,
    Sequence,
    Type,
    TypeVar,
)

from pydantic import BaseModel

if sys.version_info >= (3, 10):
    from typing import Concatenate, ParamSpec
else:
    from typing_extensions import Concatenate, ParamSpec
from hydra.utils import instantiate
from pydantic import BaseModel, Field

__all__ = [
    "add_import_config",
    "ImportConfig",
    "Registry",
    "SimpleRegistry",
]


T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
T_Type = TypeVar("T_Type", covariant=True)
P_Type = ParamSpec("P_Type")


T = TypeVar("T")


@@ -43,12 +48,13 @@ class SimpleRegistry(Generic[T]):
class Registry(Generic[T_Type, P_Type]):
    """A generic class to create and manage a registry of items."""

    def __init__(self, name: str):
    def __init__(self, name: str, discriminator: str = "name"):
        self._name = name
        self._registry: Dict[
        self._registry: dict[
            str, Callable[Concatenate[..., P_Type], T_Type]
        ] = {}
        self._config_types: Dict[str, Type[BaseModel]] = {}
        self._discriminator = discriminator
        self._config_types: dict[str, Type[BaseModel]] = {}

    def register(
        self,
@@ -56,15 +62,20 @@ class Registry(Generic[T_Type, P_Type]):
    ):
        fields = config_cls.model_fields

        if "name" not in fields:
            raise ValueError("Configuration object must have a 'name' field.")
        if self._discriminator not in fields:
            raise ValueError(
                "Configuration object must have "
                f"a '{self._discriminator}' field."
            )

        name = fields["name"].default
        name = fields[self._discriminator].default

        self._config_types[name] = config_cls

        if not isinstance(name, str):
            raise ValueError("'name' field must be a string literal.")
            raise ValueError(
                f"'{self._discriminator}' field must be a string literal."
            )

        def decorator(
            func: Callable[Concatenate[T_Config, P_Type], T_Type],
@@ -74,7 +85,7 @@ class Registry(Generic[T_Type, P_Type]):

        return decorator

    def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
    def get_config_types(self) -> tuple[Type[BaseModel], ...]:
        return tuple(self._config_types.values())

    def get_config_type(self, name: str) -> Type[BaseModel]:
@@ -94,10 +105,12 @@ class Registry(Generic[T_Type, P_Type]):
    ) -> T_Type:
        """Builds a logic instance from a config object."""

        name = getattr(config, "name")  # noqa: B009
        name = getattr(config, self._discriminator)  # noqa: B009

        if name is None:
            raise ValueError("Config does not have a name field")
            raise ValueError(
                f"Config does not have a '{self._discriminator}' field"
            )

        if name not in self._registry:
            raise NotImplementedError(
@@ -105,3 +118,92 @@ class Registry(Generic[T_Type, P_Type]):
            )

        return self._registry[name](config, *args, **kwargs)


class ImportConfig(BaseModel):
    """Base config for dynamic instantiation via Hydra.

    Subclass this to create a registry-specific import escape hatch.
    The subclass must add a discriminator field whose name matches the
    registry's own discriminator key, with its value fixed to
    ``Literal["import"]``.

    Attributes
    ----------
    target : str
        Fully-qualified dotted path to the callable to instantiate,
        e.g. ``"mypackage.module.MyClass"``.
    arguments : dict[str, Any]
        Base keyword arguments forwarded to the callable. When the
        same key also appears in ``kwargs`` passed to ``build()``,
        the ``kwargs`` value takes priority.
    """

    target: str
    arguments: dict[str, Any] = Field(default_factory=dict)


T_Import = TypeVar("T_Import", bound=ImportConfig)


def add_import_config(
    registry: Registry[T_Type, P_Type],
    arg_names: Sequence[str] | None = None,
) -> Callable[[Type[T_Import]], Type[T_Import]]:
    """Decorator that registers an ImportConfig subclass as an escape hatch.

    Wraps the decorated class in a builder that calls
    ``hydra.utils.instantiate`` using ``config.target`` and
    ``config.arguments``. The builder is registered on *registry*
    under the discriminator value ``"import"``.

    Parameters
    ----------
    registry : Registry
        The registry instance on which the config should be registered.

    Returns
    -------
    Callable[[type[ImportConfig]], type[ImportConfig]]
        A class decorator that registers the class and returns it
        unchanged.

    Examples
    --------
    Define a per-registry import escape hatch::

        @add_import_config(my_registry)
        class MyRegistryImportConfig(ImportConfig):
            name: Literal["import"] = "import"
    """

    def decorator(config_cls: Type[T_Import]) -> Type[T_Import]:
        def builder(
            config: T_Import,
            *args: P_Type.args,
            **kwargs: P_Type.kwargs,
        ) -> T_Type:
            _arg_names = arg_names or []

            if len(args) != len(_arg_names):
                raise ValueError(
                    "Positional arguments are not supported "
                    "for import escape hatch unless you specify "
                    "the argument names. Use `arg_names` to specify "
                    "the names of the positional arguments."
                )

            args_dict = {_arg_names[i]: args[i] for i in range(len(args))}

            hydra_cfg = {
                "_target_": config.target,
                **config.arguments,
                **args_dict,
                **kwargs,
            }
            return instantiate(hydra_cfg)

        registry.register(config_cls)(builder)
        return config_cls

    return decorator
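The pattern this hunk introduces, a registry keyed by a discriminator field plus an "import" escape hatch that resolves a dotted path at runtime, can be sketched in a self-contained way. This is not batdetect2's actual API: plain dataclasses stand in for the pydantic configs, and `importlib` replaces `hydra.utils.instantiate`.

```python
import importlib
from dataclasses import dataclass, field


@dataclass
class ImportConfig:
    """Hypothetical config selecting a dotted-path target to instantiate."""
    target: str
    arguments: dict = field(default_factory=dict)
    name: str = "import"  # discriminator value


class Registry:
    """Maps a config's discriminator value to a builder callable."""

    def __init__(self, discriminator: str = "name"):
        self._discriminator = discriminator
        self._builders = {}

    def register(self, key, builder):
        self._builders[key] = builder

    def build(self, config):
        key = getattr(config, self._discriminator)
        if key not in self._builders:
            raise NotImplementedError(f"Unknown entry: {key}")
        return self._builders[key](config)


def import_builder(config: ImportConfig):
    # Resolve "package.module.Attr" and call it with the configured kwargs.
    module_path, _, attr = config.target.rpartition(".")
    target = getattr(importlib.import_module(module_path), attr)
    return target(**config.arguments)


registry = Registry()
registry.register("import", import_builder)

# Instantiate stdlib fractions.Fraction through the escape hatch.
frac = registry.build(
    ImportConfig(
        target="fractions.Fraction",
        arguments={"numerator": 3, "denominator": 6},
    )
)
print(frac)  # → 1/2
```

The design point is that configs stay declarative: a YAML file can either pick a registered name or, with `"import"`, point at any callable without touching the registry.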
@@ -7,20 +7,12 @@ from batdetect2.data.annotations (
    load_annotated_dataset,
)
from batdetect2.data.datasets import (
    Dataset,
    DatasetConfig,
    load_dataset,
    load_dataset_config,
    load_dataset_from_config,
)
from batdetect2.data.predictions import (
    BatDetect2OutputConfig,
    OutputFormatConfig,
    RawOutputConfig,
    SoundEventOutputConfig,
    build_output_formatter,
    get_output_formatter,
    load_predictions,
)
from batdetect2.data.summary import (
    compute_class_summary,
    extract_recordings_df,
@@ -28,6 +20,7 @@ from batdetect2.data.summary import (
)

__all__ = [
    "Dataset",
    "AOEFAnnotations",
    "AnnotatedDataset",
    "AnnotationFormats",
@@ -36,6 +29,7 @@ __all__ = [
    "BatDetect2OutputConfig",
    "DatasetConfig",
    "OutputFormatConfig",
    "ParquetOutputConfig",
    "RawOutputConfig",
    "SoundEventOutputConfig",
    "build_output_formatter",
@@ -13,22 +13,18 @@ format-specific loading function to retrieve the annotations as a standard
`soundevent.data.AnnotationSet`.
"""

from typing import Annotated, Optional, Union
from typing import Annotated

from pydantic import Field
from soundevent import data

from batdetect2.data.annotations.aoef import (
    AOEFAnnotations,
    load_aoef_annotated_dataset,
)
from batdetect2.data.annotations.aoef import AOEFAnnotations
from batdetect2.data.annotations.batdetect2 import (
    AnnotationFilter,
    BatDetect2FilesAnnotations,
    BatDetect2MergedAnnotations,
    load_batdetect2_files_annotated_dataset,
    load_batdetect2_merged_annotated_dataset,
)
from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import AnnotatedDataset

__all__ = [
@@ -43,11 +39,7 @@ __all__ = [


AnnotationFormats = Annotated[
    Union[
        BatDetect2MergedAnnotations,
        BatDetect2FilesAnnotations,
        AOEFAnnotations,
    ],
    BatDetect2MergedAnnotations | BatDetect2FilesAnnotations | AOEFAnnotations,
    Field(discriminator="format"),
]
"""Type Alias representing all supported data source configurations.
@@ -63,24 +55,24 @@ source configuration represents.

def load_annotated_dataset(
    dataset: AnnotatedDataset,
    base_dir: Optional[data.PathLike] = None,
    base_dir: data.PathLike | None = None,
) -> data.AnnotationSet:
    """Load annotations for a single data source based on its configuration.

    This function acts as a dispatcher. It inspects the type of the input
    `source_config` object (which corresponds to a specific annotation format)
    and calls the appropriate loading function (e.g.,
    `load_aoef_annotated_dataset` for `AOEFAnnotations`).
    This function acts as a dispatcher. It inspects the format of the input
    `dataset` object and delegates to the appropriate format-specific loader
    registered in the `annotation_format_registry` (e.g.,
    `AOEFLoader` for `AOEFAnnotations`).

    Parameters
    ----------
    source_config : AnnotationFormats
    dataset : AnnotatedDataset
        The configuration object for the data source, specifying its format
        and necessary details (like paths). Must be an instance of one of the
        types included in the `AnnotationFormats` union.
    base_dir : Path, optional
        An optional base directory path. If provided, relative paths within
        the `source_config` might be resolved relative to this directory by
        the `dataset` will be resolved relative to this directory by
        the underlying loading functions. Defaults to None.

    Returns
@@ -92,23 +84,8 @@ def load_annotated_dataset(
    Raises
    ------
    NotImplementedError
        If the type of the `source_config` object does not match any of the
        known format-specific loading functions implemented in the dispatch
        logic.
        If the `format` field of `dataset` does not match any registered
        annotation format loader.
    """

    if isinstance(dataset, AOEFAnnotations):
        return load_aoef_annotated_dataset(dataset, base_dir=base_dir)

    if isinstance(dataset, BatDetect2MergedAnnotations):
        return load_batdetect2_merged_annotated_dataset(
            dataset, base_dir=base_dir
        )

    if isinstance(dataset, BatDetect2FilesAnnotations):
        return load_batdetect2_files_annotated_dataset(
            dataset,
            base_dir=base_dir,
        )

    raise NotImplementedError(f"Unknown annotation format: {dataset.name}")
    loader = annotation_format_registry.build(dataset)
    return loader.load(base_dir=base_dir)
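The change above replaces an isinstance chain with a lookup keyed by the config's `format` field. The shape of that dispatch can be sketched with a plain dict; the loader callables and their return strings here are purely illustrative stand-ins, not batdetect2's real loaders.

```python
# Hypothetical loader table keyed by the "format" discriminator; stands
# in for annotation_format_registry.
LOADERS = {
    "aoef": lambda cfg, base_dir=None: f"aoef:{cfg['annotations_path']}",
    "batdetect2": lambda cfg, base_dir=None: f"bd2:{cfg['annotations_dir']}",
}


def load_annotated_dataset(config: dict, base_dir=None):
    fmt = config["format"]
    if fmt not in LOADERS:
        raise NotImplementedError(f"Unknown annotation format: {fmt}")
    return LOADERS[fmt](config, base_dir=base_dir)


print(load_annotated_dataset({"format": "aoef", "annotations_path": "anns.json"}))
# → aoef:anns.json
```

Unlike the isinstance chain, new formats plug in by adding a table entry rather than editing the dispatcher.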
@@ -12,17 +12,22 @@ that meet specific status criteria (e.g., completed, verified, without issues).
"""

from pathlib import Path
from typing import Literal, Optional
from typing import Literal
from uuid import uuid5

from pydantic import Field
from soundevent import data, io

from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.types import AnnotatedDataset
from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import (
    AnnotatedDataset,
    AnnotationLoader,
)

__all__ = [
    "AOEFAnnotations",
    "AOEFLoader",
    "load_aoef_annotated_dataset",
    "AnnotationTaskFilter",
]
@@ -77,14 +82,30 @@ class AOEFAnnotations(AnnotatedDataset):

    annotations_path: Path

    filter: Optional[AnnotationTaskFilter] = Field(
    filter: AnnotationTaskFilter | None = Field(
        default_factory=AnnotationTaskFilter
    )


class AOEFLoader(AnnotationLoader):
    def __init__(self, config: AOEFAnnotations):
        self.config = config

    def load(
        self,
        base_dir: data.PathLike | None = None,
    ) -> data.AnnotationSet:
        return load_aoef_annotated_dataset(self.config, base_dir=base_dir)

    @annotation_format_registry.register(AOEFAnnotations)
    @staticmethod
    def from_config(config: AOEFAnnotations):
        return AOEFLoader(config)


def load_aoef_annotated_dataset(
    dataset: AOEFAnnotations,
    base_dir: Optional[data.PathLike] = None,
    base_dir: data.PathLike | None = None,
) -> data.AnnotationSet:
    """Load annotations from an AnnotationSet or AnnotationProject file.

@@ -27,7 +27,7 @@ aggregated into a `soundevent.data.AnnotationSet`.
import json
import os
from pathlib import Path
from typing import Literal, Optional, Union
from typing import Literal

from loguru import logger
from pydantic import Field, ValidationError
@@ -41,9 +41,13 @@ from batdetect2.data.annotations.legacy import (
    list_file_annotations,
    load_file_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset
from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import (
    AnnotatedDataset,
    AnnotationLoader,
)

PathLike = Union[Path, str, os.PathLike]
PathLike = Path | str | os.PathLike


__all__ = [
@@ -102,7 +106,7 @@ class BatDetect2FilesAnnotations(AnnotatedDataset):
    format: Literal["batdetect2"] = "batdetect2"
    annotations_dir: Path

    filter: Optional[AnnotationFilter] = Field(
    filter: AnnotationFilter | None = Field(
        default_factory=AnnotationFilter,
    )

@@ -133,14 +137,14 @@ class BatDetect2MergedAnnotations(AnnotatedDataset):
    format: Literal["batdetect2_file"] = "batdetect2_file"
    annotations_path: Path

    filter: Optional[AnnotationFilter] = Field(
    filter: AnnotationFilter | None = Field(
        default_factory=AnnotationFilter,
    )


def load_batdetect2_files_annotated_dataset(
    dataset: BatDetect2FilesAnnotations,
    base_dir: Optional[PathLike] = None,
    base_dir: PathLike | None = None,
) -> data.AnnotationSet:
    """Load and convert 'batdetect2_file' annotations into an AnnotationSet.

@@ -244,7 +248,7 @@ def load_batdetect2_files_annotated_dataset(

def load_batdetect2_merged_annotated_dataset(
    dataset: BatDetect2MergedAnnotations,
    base_dir: Optional[PathLike] = None,
    base_dir: PathLike | None = None,
) -> data.AnnotationSet:
    """Load and convert 'batdetect2_merged' annotations into an AnnotationSet.

@@ -302,7 +306,7 @@ def load_batdetect2_merged_annotated_dataset(
        try:
            ann = FileAnnotation.model_validate(ann)
        except ValueError as err:
            logger.warning(f"Invalid annotation file: {err}")
            logger.warning("Invalid annotation file: {err}", err=err)
            continue

        if (
@@ -310,17 +314,23 @@ def load_batdetect2_merged_annotated_dataset(
            and dataset.filter.only_annotated
            and not ann.annotated
        ):
            logger.debug(f"Skipping incomplete annotation {ann.id}")
            logger.debug(
                "Skipping incomplete annotation {ann_id}",
                ann_id=ann.id,
            )
            continue

        if dataset.filter and dataset.filter.exclude_issues and ann.issues:
            logger.debug(f"Skipping annotation with issues {ann.id}")
            logger.debug(
                "Skipping annotation with issues {ann_id}",
                ann_id=ann.id,
            )
            continue

        try:
            clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
        except FileNotFoundError as err:
            logger.warning(f"Error loading annotations: {err}")
            logger.warning("Error loading annotations: {err}", err=err)
            continue

        annotations.append(file_annotation_to_clip_annotation(ann, clip))
@@ -330,3 +340,41 @@ def load_batdetect2_merged_annotated_dataset(
        description=dataset.description,
        clip_annotations=annotations,
    )


class BatDetect2MergedLoader(AnnotationLoader):
    def __init__(self, config: BatDetect2MergedAnnotations):
        self.config = config

    def load(
        self,
        base_dir: PathLike | None = None,
    ) -> data.AnnotationSet:
        return load_batdetect2_merged_annotated_dataset(
            self.config,
            base_dir=base_dir,
        )

    @annotation_format_registry.register(BatDetect2MergedAnnotations)
    @staticmethod
    def from_config(config: BatDetect2MergedAnnotations):
        return BatDetect2MergedLoader(config)


class BatDetect2FilesLoader(AnnotationLoader):
    def __init__(self, config: BatDetect2FilesAnnotations):
        self.config = config

    def load(
        self,
        base_dir: PathLike | None = None,
    ) -> data.AnnotationSet:
        return load_batdetect2_files_annotated_dataset(
            self.config,
            base_dir=base_dir,
        )

    @annotation_format_registry.register(BatDetect2FilesAnnotations)
    @staticmethod
    def from_config(config: BatDetect2FilesAnnotations):
        return BatDetect2FilesLoader(config)
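The logging changes in this hunk switch the loguru calls from eager f-strings to deferred `{}`-style formatting, so interpolation only happens when a record is actually emitted. The same idea with stdlib `logging` and lazy `%`-style arguments (an analogue, not the loguru API):

```python
import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger("annotations")

ann_id = "clip-0042"  # hypothetical annotation id

# Eager: the f-string is always built, even though DEBUG is disabled here.
log.debug(f"Skipping incomplete annotation {ann_id}")

# Lazy: the message is only interpolated if the record is emitted.
log.debug("Skipping incomplete annotation %s", ann_id)
log.warning("Invalid annotation file: %s", ValueError("bad span"))
```

Besides skipping wasted string formatting on suppressed levels, passing the values separately keeps them available to structured log sinks.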
@@ -3,12 +3,12 @@
import os
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Callable, List

from pydantic import BaseModel, Field
from soundevent import data

PathLike = Union[Path, str, os.PathLike]
PathLike = Path | str | os.PathLike

__all__ = []

@@ -27,7 +27,7 @@ SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5(
)


EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
EventFn = Callable[[data.SoundEventAnnotation], str | None]

ClassFn = Callable[[data.Recording], int]

@@ -130,7 +130,7 @@ def get_sound_event_tags(

def file_annotation_to_clip(
    file_annotation: FileAnnotation,
    audio_dir: Optional[PathLike] = None,
    audio_dir: PathLike | None = None,
    label_key: str = "class",
) -> data.Clip:
    """Convert file annotation to recording."""
35 src/batdetect2/data/annotations/registry.py Normal file
@@ -0,0 +1,35 @@
from typing import Literal

from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.data.annotations.types import AnnotationLoader

__all__ = [
    "AnnotationFormatImportConfig",
    "annotation_format_registry",
]

annotation_format_registry: Registry[AnnotationLoader, []] = Registry(
    "annotation_format",
    discriminator="format",
)


@add_import_config(annotation_format_registry)
class AnnotationFormatImportConfig(ImportConfig):
    """Import escape hatch for the annotation format registry.

    Use this config to dynamically instantiate any callable as an
    annotation loader without registering it in
    ``annotation_format_registry`` ahead of time.

    Parameters
    ----------
    format : Literal["import"]
        Discriminator value; must always be ``"import"``.
    target : str
        Fully-qualified dotted path to the callable to instantiate.
    arguments : dict[str, Any]
        Keyword arguments forwarded to the callable.
    """

    format: Literal["import"] = "import"
@@ -1,9 +1,13 @@
from pathlib import Path
from typing import Protocol

from soundevent import data

from batdetect2.core.configs import BaseConfig

__all__ = [
    "AnnotatedDataset",
    "AnnotationLoader",
]


@@ -34,3 +38,10 @@ class AnnotatedDataset(BaseConfig):
    name: str
    audio_dir: Path
    description: str = ""


class AnnotationLoader(Protocol):
    def load(
        self,
        base_dir: data.PathLike | None = None,
    ) -> data.AnnotationSet: ...
@@ -1,18 +1,33 @@
from collections.abc import Callable
from typing import Annotated, List, Literal, Sequence, Union
from typing import Annotated, List, Literal, Sequence

from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.core.registries import (
    ImportConfig,
    Registry,
    add_import_config,
)

SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]

conditions: Registry[SoundEventCondition, []] = Registry("condition")


@add_import_config(conditions)
class SoundEventConditionImportConfig(ImportConfig):
    """Use any callable as a sound event condition.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class HasTagConfig(BaseConfig):
    name: Literal["has_tag"] = "has_tag"
    tag: data.Tag
@@ -264,16 +279,14 @@ class Not:


SoundEventConditionConfig = Annotated[
    Union[
        HasTagConfig,
        HasAllTagsConfig,
        HasAnyTagConfig,
        DurationConfig,
        FrequencyConfig,
        AllOfConfig,
        AnyOfConfig,
        NotConfig,
    ],
    HasTagConfig
    | HasAllTagsConfig
    | HasAnyTagConfig
    | DurationConfig
    | FrequencyConfig
    | AllOfConfig
    | AnyOfConfig
    | NotConfig,
    Field(discriminator="name"),
]
@@ -19,7 +19,7 @@ The core components are:
"""

from pathlib import Path
from typing import List, Optional, Sequence
from typing import List, Sequence

from loguru import logger
from pydantic import Field
@@ -69,7 +69,7 @@ class DatasetConfig(BaseConfig):
    description: str
    sources: List[AnnotationFormats]

    sound_event_filter: Optional[SoundEventConditionConfig] = None
    sound_event_filter: SoundEventConditionConfig | None = None
    sound_event_transforms: List[SoundEventTransformConfig] = Field(
        default_factory=list
    )
@@ -77,7 +77,7 @@

def load_dataset(
    config: DatasetConfig,
    base_dir: Optional[data.PathLike] = None,
    base_dir: data.PathLike | None = None,
) -> Dataset:
    """Load all clip annotations from the sources defined in a DatasetConfig."""
    clip_annotations = []
@@ -161,14 +161,14 @@ def insert_source_tag(
    )


def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
def load_dataset_config(path: data.PathLike, field: str | None = None):
    return load_config(path=path, schema=DatasetConfig, field=field)


def load_dataset_from_config(
    path: data.PathLike,
    field: Optional[str] = None,
    base_dir: Optional[data.PathLike] = None,
    field: str | None = None,
    base_dir: data.PathLike | None = None,
) -> Dataset:
    """Load dataset annotation metadata from a configuration file.

@@ -215,9 +215,9 @@ def load_dataset_from_config(
def save_dataset(
    dataset: Dataset,
    path: data.PathLike,
    name: Optional[str] = None,
    description: Optional[str] = None,
    audio_dir: Optional[Path] = None,
    name: str | None = None,
    description: str | None = None,
    audio_dir: Path | None = None,
) -> None:
    """Save a loaded dataset (list of ClipAnnotations) to a file.
@@ -1,16 +1,15 @@
from collections.abc import Generator
from typing import Optional, Tuple

from soundevent import data

from batdetect2.data.datasets import Dataset
from batdetect2.typing.targets import TargetProtocol
from batdetect2.targets.types import TargetProtocol


def iterate_over_sound_events(
    dataset: Dataset,
    targets: TargetProtocol,
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
) -> Generator[tuple[str | None, data.SoundEventAnnotation], None, None]:
    """Iterate over sound events in a dataset.

    Parameters
@@ -24,7 +23,7 @@ def iterate_over_sound_events(

    Yields
    ------
    Tuple[Optional[str], data.SoundEventAnnotation]
    tuple[Optional[str], data.SoundEventAnnotation]
        A tuple containing:
        - The encoded class name (str) for the sound event, or None if it
          cannot be encoded to a specific class.
@@ -1,29 +0,0 @@
from pathlib import Path

from soundevent.data import PathLike

from batdetect2.core import Registry
from batdetect2.typing import (
    OutputFormatterProtocol,
    TargetProtocol,
)


def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
    path = Path(path)
    audio_dir = Path(audio_dir)

    if path.is_absolute():
        if not path.is_relative_to(audio_dir):
            raise ValueError(
                f"Audio file {path} is not in audio_dir {audio_dir}"
            )

        return path.relative_to(audio_dir)

    return path


prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
    Registry(name="output_formatter")
)
@@ -1,5 +1,3 @@
from typing import Optional, Tuple

from sklearn.model_selection import train_test_split

from batdetect2.data.datasets import Dataset
@@ -7,15 +5,15 @@ from batdetect2.data.summary import (
    extract_recordings_df,
    extract_sound_events_df,
)
from batdetect2.typing.targets import TargetProtocol
from batdetect2.targets.types import TargetProtocol


def split_dataset_by_recordings(
    dataset: Dataset,
    targets: TargetProtocol,
    train_size: float = 0.75,
    random_state: Optional[int] = None,
) -> Tuple[Dataset, Dataset]:
    random_state: int | None = None,
) -> tuple[Dataset, Dataset]:
    recordings = extract_recordings_df(dataset)

    sound_events = extract_sound_events_df(
@@ -26,13 +24,15 @@ def split_dataset_by_recordings(
    )

    majority_class = (
        sound_events.groupby("recording_id")
        sound_events.groupby("recording_id")  # type: ignore
        .apply(
            lambda group: group["class_name"]  # type: ignore
            .value_counts()
            .sort_values(ascending=False)
            .index[0],
            include_groups=False,  # type: ignore
            lambda group: (
                group["class_name"]
                .value_counts()
                .sort_values(ascending=False)
                .index[0]
            ),
            include_groups=False,
        )
        .rename("class_name")
        .to_frame()
@@ -46,8 +46,8 @@ def split_dataset_by_recordings(
        random_state=random_state,
    )

    train_ids_set = set(train.values)  # type: ignore
    test_ids_set = set(test.values)  # type: ignore
    train_ids_set = set(train.values)
    test_ids_set = set(test.values)

    extra = set(recordings["recording_id"]) - train_ids_set - test_ids_set
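The split above first reduces each recording to its most frequent sound-event class, then stratifies the train/test split on that label. A stdlib sketch of the "majority class per recording" step (the real code does this with a pandas groupby; the recording ids and class names are invented):

```python
from collections import Counter, defaultdict

# (recording_id, class_name) pairs standing in for the sound events frame.
sound_events = [
    ("rec-1", "Myotis"),
    ("rec-1", "Myotis"),
    ("rec-1", "Pipistrellus"),
    ("rec-2", "Pipistrellus"),
]

# Count class occurrences per recording.
by_recording: dict[str, Counter] = defaultdict(Counter)
for recording_id, class_name in sound_events:
    by_recording[recording_id][class_name] += 1

# Keep only the most frequent class for each recording.
majority_class = {
    recording_id: counts.most_common(1)[0][0]
    for recording_id, counts in by_recording.items()
}
print(majority_class)  # → {'rec-1': 'Myotis', 'rec-2': 'Pipistrellus'}
```

Splitting on whole recordings (rather than individual events) avoids leaking events from the same file into both partitions, while the majority label keeps the class balance roughly even across the split.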
@@ -2,7 +2,7 @@ import pandas as pd
from soundevent.geometry import compute_bounds

from batdetect2.data.datasets import Dataset
from batdetect2.typing.targets import TargetProtocol
from batdetect2.targets.types import TargetProtocol

__all__ = [
    "extract_recordings_df",
@@ -175,14 +175,14 @@ def compute_class_summary(
        .rename("num recordings")
    )
    durations = (
        sound_events.groupby("class_name")
        sound_events.groupby("class_name")  # ty: ignore[no-matching-overload]
        .apply(
            lambda group: recordings[
                recordings["clip_annotation_id"].isin(
                    group["clip_annotation_id"]  # type: ignore
                    group["clip_annotation_id"]
                )
            ]["duration"].sum(),
            include_groups=False,  # type: ignore
            include_groups=False,
        )
        .sort_values(ascending=False)
        .rename("duration")
@@ -1,11 +1,15 @@
from collections.abc import Callable
from typing import Annotated, Dict, List, Literal, Optional, Union
from typing import Annotated, Dict, List, Literal

from pydantic import Field
from soundevent import data

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.core.registries import (
    ImportConfig,
    Registry,
    add_import_config,
)
from batdetect2.data.conditions import (
    SoundEventCondition,
    SoundEventConditionConfig,
@@ -20,6 +24,17 @@ SoundEventTransform = Callable[
transforms: Registry[SoundEventTransform, []] = Registry("transform")


@add_import_config(transforms)
class SoundEventTransformImportConfig(ImportConfig):
    """Use any callable as a sound event transform.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class SetFrequencyBoundConfig(BaseConfig):
    name: Literal["set_frequency"] = "set_frequency"
    boundary: Literal["low", "high"] = "low"
@@ -142,7 +157,7 @@ class MapTagValueConfig(BaseConfig):
    name: Literal["map_tag_value"] = "map_tag_value"
    tag_key: str
    value_mapping: Dict[str, str]
    target_key: Optional[str] = None
    target_key: str | None = None


class MapTagValue:
@@ -150,7 +165,7 @@ class MapTagValue:
        self,
        tag_key: str,
        value_mapping: Dict[str, str],
        target_key: Optional[str] = None,
        target_key: str | None = None,
    ):
        self.tag_key = tag_key
        self.value_mapping = value_mapping
@@ -176,12 +191,7 @@ class MapTagValue:
            if self.target_key is None:
                tags.append(tag.model_copy(update=dict(value=value)))
            else:
                tags.append(
                    data.Tag(
                        key=self.target_key,  # type: ignore
                        value=value,
                    )
                )
                tags.append(data.Tag(key=self.target_key, value=value))

        return sound_event_annotation.model_copy(update=dict(tags=tags))

@@ -221,13 +231,11 @@ class ApplyAll:


SoundEventTransformConfig = Annotated[
    Union[
        SetFrequencyBoundConfig,
        ReplaceTagConfig,
        MapTagValueConfig,
        ApplyIfConfig,
        ApplyAllConfig,
    ],
    SetFrequencyBoundConfig
    | ReplaceTagConfig
    | MapTagValueConfig
    | ApplyIfConfig
    | ApplyAllConfig,
    Field(discriminator="name"),
]
@@ -1,6 +1,6 @@
"""Functions to compute features from predictions."""

from typing import Dict, List, Optional
from typing import Dict, List

import numpy as np

@@ -86,7 +86,7 @@ def compute_bandwidth(

def compute_max_power_bb(
    prediction: types.Prediction,
    spec: Optional[np.ndarray] = None,
    spec: np.ndarray | None = None,
    min_freq: int = MIN_FREQ_HZ,
    max_freq: int = MAX_FREQ_HZ,
    **_,
@@ -131,7 +131,7 @@ def compute_max_power_bb(

def compute_max_power(
    prediction: types.Prediction,
    spec: Optional[np.ndarray] = None,
    spec: np.ndarray | None = None,
    min_freq: int = MIN_FREQ_HZ,
    max_freq: int = MAX_FREQ_HZ,
    **_,
@@ -157,7 +157,7 @@ def compute_max_power(

def compute_max_power_first(
    prediction: types.Prediction,
    spec: Optional[np.ndarray] = None,
    spec: np.ndarray | None = None,
    min_freq: int = MIN_FREQ_HZ,
    max_freq: int = MAX_FREQ_HZ,
    **_,
@@ -184,7 +184,7 @@ def compute_max_power_first(

def compute_max_power_second(
    prediction: types.Prediction,
    spec: Optional[np.ndarray] = None,
    spec: np.ndarray | None = None,
    min_freq: int = MIN_FREQ_HZ,
    max_freq: int = MAX_FREQ_HZ,
    **_,
@@ -211,7 +211,7 @@ def compute_max_power_second(

def compute_call_interval(
    prediction: types.Prediction,
    previous: Optional[types.Prediction] = None,
    previous: types.Prediction | None = None,
    **_,
) -> float:
    """Compute time between this call and the previous call in seconds."""
@@ -1,7 +1,7 @@
import datetime
import os
from pathlib import Path
from typing import List, Optional, Union
from typing import List

from pydantic import BaseModel, Field, computed_field

@@ -198,8 +198,8 @@ class TrainingParameters(BaseModel):
def get_params(
    make_dirs: bool = False,
    exps_dir: str = "../../experiments/",
    model_name: Optional[str] = None,
    experiment: Union[Path, str, None] = None,
    model_name: str | None = None,
    experiment: Path | str | None = None,
    **kwargs,
) -> TrainingParameters:
    experiments_dir = Path(exps_dir)
@@ -1,7 +1,5 @@
"""Post-processing of the output of the model."""

from typing import List, Tuple, Union

import numpy as np
import torch
from torch import nn
@@ -45,7 +43,7 @@ def run_nms(
    outputs: ModelOutput,
    params: NonMaximumSuppressionConfig,
    sampling_rate: np.ndarray,
) -> Tuple[List[PredictionResults], List[np.ndarray]]:
) -> tuple[list[PredictionResults], list[np.ndarray]]:
    """Run non-maximum suppression on the output of the model.

    Model outputs processed are expected to have a batch dimension.
@@ -73,8 +71,8 @@ def run_nms(
    scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)

    # loop over batch to save outputs
    preds: List[PredictionResults] = []
    feats: List[np.ndarray] = []
    preds: list[PredictionResults] = []
    feats: list[np.ndarray] = []
    for num_detection in range(pred_det_nms.shape[0]):
        # get valid indices
        inds_ord = torch.argsort(x_pos[num_detection, :])
@@ -151,7 +149,7 @@ def run_nms(

def non_max_suppression(
    heat: torch.Tensor,
    kernel_size: Union[int, Tuple[int, int]],
    kernel_size: int | tuple[int, int],
):
    # kernel can be an int or list/tuple
    if isinstance(kernel_size, int):
@@ -1,15 +1,32 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.evaluate.results import save_evaluation_results
from batdetect2.evaluate.tasks import TaskConfig, build_task
from batdetect2.evaluate.types import (
    AffinityFunction,
    ClipMatches,
    EvaluationTaskProtocol,
    EvaluatorProtocol,
    MetricsProtocol,
    PlotterProtocol,
)

__all__ = [
    "AffinityFunction",
    "ClipMatches",
    "DEFAULT_EVAL_DIR",
    "EvaluationConfig",
    "EvaluationTaskProtocol",
    "Evaluator",
    "EvaluatorProtocol",
    "MatchEvaluation",
    "MatcherProtocol",
    "MetricsProtocol",
    "PlotterProtocol",
    "TaskConfig",
    "build_evaluator",
    "build_task",
    "evaluate",
    "load_evaluation_config",
    "DEFAULT_EVAL_DIR",
    "run_evaluate",
    "save_evaluation_results",
]
@@ -1,76 +1,116 @@
from typing import Annotated, Literal, Optional, Union
from typing import Annotated, Literal

from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.geometry import compute_interval_overlap
from soundevent.geometry import (
    buffer_geometry,
    compute_bbox_iou,
    compute_geometric_iou,
    compute_temporal_closeness,
    compute_temporal_iou,
)

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.typing.evaluate import AffinityFunction
from batdetect2.core import (
    BaseConfig,
    ImportConfig,
    Registry,
    add_import_config,
)
from batdetect2.evaluate.types import AffinityFunction
from batdetect2.postprocess.types import Detection

affinity_functions: Registry[AffinityFunction, []] = Registry(
    "matching_strategy"
    "affinity_function"
)


@add_import_config(affinity_functions)
class AffinityFunctionImportConfig(ImportConfig):
    """Use any callable as an affinity function.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class TimeAffinityConfig(BaseConfig):
    name: Literal["time_affinity"] = "time_affinity"
    time_buffer: float = 0.01
    position: Literal["start", "end", "center"] | float = "start"
    max_distance: float = 0.01


class TimeAffinity(AffinityFunction):
    def __init__(self, time_buffer: float):
        self.time_buffer = time_buffer
    def __init__(
        self,
        max_distance: float = 0.01,
        position: Literal["start", "end", "center"] | float = "start",
    ):
        if position == "start":
            position = 0
        elif position == "end":
            position = 1
        elif position == "center":
            position = 0.5

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        return compute_timestamp_affinity(
            geometry1, geometry2, time_buffer=self.time_buffer
        self.position = position
        self.max_distance = max_distance

    def __call__(
        self,
        detection: Detection,
        ground_truth: data.SoundEventAnnotation,
    ) -> float:
        target_geometry = ground_truth.sound_event.geometry
        source_geometry = detection.geometry
        return compute_temporal_closeness(
            target_geometry,
            source_geometry,
            ratio=self.position,
            max_distance=self.max_distance,
        )

    @affinity_functions.register(TimeAffinityConfig)
    @staticmethod
    def from_config(config: TimeAffinityConfig):
        return TimeAffinity(time_buffer=config.time_buffer)


def compute_timestamp_affinity(
    geometry1: data.Geometry,
    geometry2: data.Geometry,
    time_buffer: float = 0.01,
) -> float:
    assert isinstance(geometry1, data.TimeStamp)
    assert isinstance(geometry2, data.TimeStamp)

    start_time1 = geometry1.coordinates
    start_time2 = geometry2.coordinates

    a = min(start_time1, start_time2)
    b = max(start_time1, start_time2)

    if b - a >= 2 * time_buffer:
        return 0

    intersection = a - b + 2 * time_buffer
    union = b - a + 2 * time_buffer
    return intersection / union
        return TimeAffinity(
            max_distance=config.max_distance,
            position=config.position,
        )


class IntervalIOUConfig(BaseConfig):
    name: Literal["interval_iou"] = "interval_iou"
    time_buffer: float = 0.01
    time_buffer: float = 0.0


class IntervalIOU(AffinityFunction):
    def __init__(self, time_buffer: float):
        if time_buffer < 0:
            raise ValueError("time_buffer must be non-negative")

        self.time_buffer = time_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        return compute_interval_iou(
            geometry1,
            geometry2,
            time_buffer=self.time_buffer,
        )
    def __call__(
        self,
        detection: Detection,
        ground_truth: data.SoundEventAnnotation,
    ) -> float:
        target_geometry = ground_truth.sound_event.geometry
        source_geometry = detection.geometry

        if self.time_buffer > 0:
            target_geometry = buffer_geometry(
                target_geometry,
                time=self.time_buffer,
            )
            source_geometry = buffer_geometry(
                source_geometry,
                time=self.time_buffer,
            )

        return compute_temporal_iou(target_geometry, source_geometry)

    @affinity_functions.register(IntervalIOUConfig)
    @staticmethod
@@ -78,64 +118,44 @@ class IntervalIOU(AffinityFunction):
        return IntervalIOU(time_buffer=config.time_buffer)


def compute_interval_iou(
    geometry1: data.Geometry,
    geometry2: data.Geometry,
    time_buffer: float = 0.01,
) -> float:
    assert isinstance(geometry1, data.TimeInterval)
    assert isinstance(geometry2, data.TimeInterval)

    start_time1, end_time1 = geometry1.coordinates
    start_time2, end_time2 = geometry1.coordinates

    start_time1 -= time_buffer
    start_time2 -= time_buffer
    end_time1 += time_buffer
    end_time2 += time_buffer

    intersection = compute_interval_overlap(
        (start_time1, end_time1),
        (start_time2, end_time2),
    )

    union = (
        (end_time1 - start_time1) + (end_time2 - start_time2) - intersection
    )

    if union == 0:
        return 0

    return intersection / union


class BBoxIOUConfig(BaseConfig):
    name: Literal["bbox_iou"] = "bbox_iou"
    time_buffer: float = 0.01
    freq_buffer: float = 1000
    time_buffer: float = 0.0
    freq_buffer: float = 0.0


class BBoxIOU(AffinityFunction):
    def __init__(self, time_buffer: float, freq_buffer: float):
        if time_buffer < 0:
            raise ValueError("time_buffer must be non-negative")

        if freq_buffer < 0:
            raise ValueError("freq_buffer must be non-negative")

        self.time_buffer = time_buffer
        self.freq_buffer = freq_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        if not isinstance(geometry1, data.BoundingBox):
            raise TypeError(
                f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
    def __call__(
        self,
        detection: Detection,
        ground_truth: data.SoundEventAnnotation,
    ):
        target_geometry = ground_truth.sound_event.geometry
        source_geometry = detection.geometry

        if self.time_buffer > 0 or self.freq_buffer > 0:
            target_geometry = buffer_geometry(
                target_geometry,
                time=self.time_buffer,
                freq=self.freq_buffer,
            )
            source_geometry = buffer_geometry(
                source_geometry,
                time=self.time_buffer,
                freq=self.freq_buffer,
            )

        if not isinstance(geometry2, data.BoundingBox):
            raise TypeError(
                f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
            )
        return bbox_iou(
            geometry1,
            geometry2,
            time_buffer=self.time_buffer,
            freq_buffer=self.freq_buffer,
        )
        return compute_bbox_iou(target_geometry, source_geometry)

    @affinity_functions.register(BBoxIOUConfig)
    @staticmethod
@@ -146,65 +166,44 @@ class BBoxIOU(AffinityFunction):
        )


def bbox_iou(
    geometry1: data.BoundingBox,
    geometry2: data.BoundingBox,
    time_buffer: float = 0.01,
    freq_buffer: float = 1000,
) -> float:
    start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
    start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates

    start_time1 -= time_buffer
    start_time2 -= time_buffer
    end_time1 += time_buffer
    end_time2 += time_buffer

    low_freq1 -= freq_buffer
    low_freq2 -= freq_buffer
    high_freq1 += freq_buffer
    high_freq2 += freq_buffer

    time_intersection = compute_interval_overlap(
        (start_time1, end_time1),
        (start_time2, end_time2),
    )

    freq_intersection = max(
        0,
        min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
    )

    intersection = time_intersection * freq_intersection

    if intersection == 0:
        return 0

    union = (
        (end_time1 - start_time1) * (high_freq1 - low_freq1)
        + (end_time2 - start_time2) * (high_freq2 - low_freq2)
        - intersection
    )

    return intersection / union


class GeometricIOUConfig(BaseConfig):
    name: Literal["geometric_iou"] = "geometric_iou"
    time_buffer: float = 0.01
    freq_buffer: float = 1000
    time_buffer: float = 0.0
    freq_buffer: float = 0.0


class GeometricIOU(AffinityFunction):
    def __init__(self, time_buffer: float):
        self.time_buffer = time_buffer
    def __init__(self, time_buffer: float = 0, freq_buffer: float = 0):
        if time_buffer < 0:
            raise ValueError("time_buffer must be non-negative")

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        return compute_affinity(
            geometry1,
            geometry2,
            time_buffer=self.time_buffer,
        )
        if freq_buffer < 0:
            raise ValueError("freq_buffer must be non-negative")

        self.time_buffer = time_buffer
        self.freq_buffer = freq_buffer

    def __call__(
        self,
        detection: Detection,
        ground_truth: data.SoundEventAnnotation,
    ):
        target_geometry = ground_truth.sound_event.geometry
        source_geometry = detection.geometry

        if self.time_buffer > 0 or self.freq_buffer > 0:
            target_geometry = buffer_geometry(
                target_geometry,
                time=self.time_buffer,
                freq=self.freq_buffer,
            )
            source_geometry = buffer_geometry(
                source_geometry,
                time=self.time_buffer,
                freq=self.freq_buffer,
            )

        return compute_geometric_iou(target_geometry, source_geometry)

    @affinity_functions.register(GeometricIOUConfig)
    @staticmethod
@@ -213,18 +212,16 @@ class GeometricIOU(AffinityFunction):


AffinityConfig = Annotated[
    Union[
        TimeAffinityConfig,
        IntervalIOUConfig,
        BBoxIOUConfig,
        GeometricIOUConfig,
    ],
    TimeAffinityConfig
    | IntervalIOUConfig
    | BBoxIOUConfig
    | GeometricIOUConfig,
    Field(discriminator="name"),
]


def build_affinity_function(
    config: Optional[AffinityConfig] = None,
    config: AffinityConfig | None = None,
) -> AffinityFunction:
    config = config or GeometricIOUConfig()
    return affinity_functions.build(config)
@@ -1,19 +1,14 @@
from typing import List, Optional
from typing import List

from pydantic import Field
from soundevent import data

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.tasks import (
    TaskConfig,
)
from batdetect2.core.configs import BaseConfig
from batdetect2.evaluate.tasks import TaskConfig
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.logging import CSVLoggerConfig, LoggerConfig

__all__ = [
    "EvaluationConfig",
    "load_evaluation_config",
]


@@ -24,7 +19,6 @@ class EvaluationConfig(BaseConfig):
            ClassificationTaskConfig(),
        ]
    )
    logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)


def get_default_eval_config() -> EvaluationConfig:
@@ -47,10 +41,3 @@ def get_default_eval_config() -> EvaluationConfig:
            ]
        }
    )


def load_evaluation_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> EvaluationConfig:
    return load_config(path, schema=EvaluationConfig, field=field)
@@ -1,4 +1,4 @@
from typing import List, NamedTuple, Optional, Sequence
from typing import List, NamedTuple, Sequence

import torch
from loguru import logger
@@ -8,14 +8,11 @@ from torch.utils.data import DataLoader, Dataset

from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.audio.types import AudioLoader, ClipperProtocol
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import (
    AudioLoader,
    ClipperProtocol,
    PreprocessorProtocol,
)
from batdetect2.preprocess.types import PreprocessorProtocol

__all__ = [
    "TestDataset",
@@ -39,8 +36,8 @@ class TestDataset(Dataset[TestExample]):
        clip_annotations: Sequence[data.ClipAnnotation],
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        clipper: Optional[ClipperProtocol] = None,
        audio_dir: Optional[data.PathLike] = None,
        clipper: ClipperProtocol | None = None,
        audio_dir: data.PathLike | None = None,
    ):
        self.clip_annotations = list(clip_annotations)
        self.clipper = clipper
@@ -51,8 +48,8 @@ class TestDataset(Dataset[TestExample]):
    def __len__(self):
        return len(self.clip_annotations)

    def __getitem__(self, idx: int) -> TestExample:
        clip_annotation = self.clip_annotations[idx]
    def __getitem__(self, index: int) -> TestExample:
        clip_annotation = self.clip_annotations[index]

        if self.clipper is not None:
            clip_annotation = self.clipper(clip_annotation)
@@ -63,14 +60,13 @@ class TestDataset(Dataset[TestExample]):
        spectrogram = self.preprocessor(wav_tensor)
        return TestExample(
            spec=spectrogram,
            idx=torch.tensor(idx),
            idx=torch.tensor(index),
            start_time=torch.tensor(clip.start_time),
            end_time=torch.tensor(clip.end_time),
        )


class TestLoaderConfig(BaseConfig):
    num_workers: int = 0
    clipping_strategy: ClipConfig = Field(
        default_factory=lambda: PaddedClipConfig()
    )
@@ -78,10 +74,10 @@ class TestLoaderConfig(BaseConfig):

def build_test_loader(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[TestLoaderConfig] = None,
    num_workers: Optional[int] = None,
    audio_loader: AudioLoader | None = None,
    preprocessor: PreprocessorProtocol | None = None,
    config: TestLoaderConfig | None = None,
    num_workers: int = 0,
) -> DataLoader[TestExample]:
    logger.info("Building test data loader...")
    config = config or TestLoaderConfig()
@@ -97,7 +93,6 @@ def build_test_loader(
        config=config,
    )

    num_workers = num_workers or config.num_workers
    return DataLoader(
        test_dataset,
        batch_size=1,
@@ -109,9 +104,9 @@ def build_test_loader(

def build_test_dataset(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[TestLoaderConfig] = None,
    audio_loader: AudioLoader | None = None,
    preprocessor: PreprocessorProtocol | None = None,
    config: TestLoaderConfig | None = None,
) -> TestDataset:
    logger.info("Building training dataset...")
    config = config or TestLoaderConfig()
@@ -1,56 +1,51 @@
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
from typing import Sequence

from lightning import Trainer
from soundevent import data

from batdetect2.audio import build_audio_loader
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.evaluate import EvaluationConfig
from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger
from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
from batdetect2.models import Model
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.typing.postprocess import RawPrediction

if TYPE_CHECKING:
    from batdetect2.config import BatDetect2Config
    from batdetect2.typing import (
        AudioLoader,
        OutputFormatterProtocol,
        PreprocessorProtocol,
        TargetProtocol,
    )
from batdetect2.outputs import OutputsConfig, build_output_transform
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol

DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"


def evaluate(
def run_evaluate(
    model: Model,
    test_annotations: Sequence[data.ClipAnnotation],
    targets: Optional["TargetProtocol"] = None,
    audio_loader: Optional["AudioLoader"] = None,
    preprocessor: Optional["PreprocessorProtocol"] = None,
    config: Optional["BatDetect2Config"] = None,
    formatter: Optional["OutputFormatterProtocol"] = None,
    num_workers: Optional[int] = None,
    targets: TargetProtocol | None = None,
    audio_loader: AudioLoader | None = None,
    preprocessor: PreprocessorProtocol | None = None,
    audio_config: AudioConfig | None = None,
    evaluation_config: EvaluationConfig | None = None,
    output_config: OutputsConfig | None = None,
    logger_config: LoggerConfig | None = None,
    formatter: OutputFormatterProtocol | None = None,
    num_workers: int = 0,
    output_dir: data.PathLike = DEFAULT_EVAL_DIR,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
    from batdetect2.config import BatDetect2Config
    experiment_name: str | None = None,
    run_name: str | None = None,
) -> tuple[dict[str, float], list[ClipDetections]]:

    config = config or BatDetect2Config()
    audio_config = audio_config or AudioConfig()
    evaluation_config = evaluation_config or EvaluationConfig()
    output_config = output_config or OutputsConfig()

    audio_loader = audio_loader or build_audio_loader(config=config.audio)
    audio_loader = audio_loader or build_audio_loader(config=audio_config)

    preprocessor = preprocessor or build_preprocessor(
        config=config.preprocess,
        input_samplerate=audio_loader.samplerate,
    )

    targets = targets or build_targets(config=config.targets)
    preprocessor = preprocessor or model.preprocessor
    targets = targets or model.targets

    loader = build_test_loader(
        test_annotations,
@@ -59,15 +54,26 @@ def evaluate(
        num_workers=num_workers,
    )

    evaluator = build_evaluator(config=config.evaluation, targets=targets)
    output_transform = build_output_transform(
        config=output_config.transform,
        targets=targets,
    )
    evaluator = build_evaluator(
        config=evaluation_config,
        targets=targets,
        transform=output_transform,
    )

    logger = build_logger(
        config.evaluation.logger,
        logger_config or CSVLoggerConfig(),
        log_dir=Path(output_dir),
        experiment_name=experiment_name,
        run_name=run_name,
    )
    module = EvaluationModule(model, evaluator)
    module = EvaluationModule(
        model,
        evaluator,
    )
    trainer = Trainer(logger=logger, enable_checkpointing=False)
    metrics = trainer.test(module, loader)
@@ -1,13 +1,15 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterable, List, Sequence, Tuple

from matplotlib.figure import Figure
from soundevent import data

from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol

__all__ = [
    "Evaluator",
@@ -19,15 +21,27 @@ class Evaluator:
    def __init__(
        self,
        targets: TargetProtocol,
        tasks: Sequence[EvaluatorProtocol],
        transform: OutputTransformProtocol,
        tasks: Sequence[EvaluationTaskProtocol],
    ):
        self.targets = targets
        self.transform = transform
        self.tasks = tasks

    def to_clip_detections_batch(
        self,
        clip_detections: Sequence[ClipDetectionsTensor],
        clips: Sequence[data.Clip],
    ) -> list[ClipDetections]:
        return [
            self.transform.to_clip_detections(detections=dets, clip=clip)
            for dets, clip in zip(clip_detections, clips, strict=False)
        ]

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[BatDetect2Prediction],
        predictions: Sequence[ClipDetections],
    ) -> List[Any]:
        return [
            task.evaluate(clip_annotations, predictions) for task in self.tasks
@@ -36,7 +50,7 @@ class Evaluator:
    def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
        results = {}

        for task, outputs in zip(self.tasks, eval_outputs):
        for task, outputs in zip(self.tasks, eval_outputs, strict=False):
            results.update(task.compute_metrics(outputs))

        return results
@@ -45,14 +59,15 @@ class Evaluator:
        self,
        eval_outputs: List[Any],
    ) -> Iterable[Tuple[str, Figure]]:
        for task, outputs in zip(self.tasks, eval_outputs):
        for task, outputs in zip(self.tasks, eval_outputs, strict=False):
            for name, fig in task.generate_plots(outputs):
                yield name, fig


def build_evaluator(
    config: Optional[Union[EvaluationConfig, dict]] = None,
    targets: Optional[TargetProtocol] = None,
    config: EvaluationConfig | dict | None = None,
    targets: TargetProtocol | None = None,
    transform: OutputTransformProtocol | None = None,
) -> EvaluatorProtocol:
    targets = targets or build_targets()

@@ -62,7 +77,10 @@ def build_evaluator(
    if not isinstance(config, EvaluationConfig):
        config = EvaluationConfig.model_validate(config)

    transform = transform or build_output_transform(targets=targets)

    return Evaluator(
        targets=targets,
        transform=transform,
        tasks=[build_task(task, targets=targets) for task in config.tasks],
    )
@@ -357,7 +357,7 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001):
    clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
    clf.fit(x_train, y_train)
    y_pred = clf.predict(x_train)
    tr_acc = (y_pred == y_train).mean()
    (y_pred == y_train).mean()
    # print('Train acc', round(tr_acc*100, 2))
    return clf, un_train_class

@@ -450,7 +450,7 @@ def add_root_path_back(data_sets, ann_path, wav_path):


def check_classes_in_train(gt_list, class_names):
    num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list])
    np.sum([gg["start_times"].shape[0] for gg in gt_list])
    num_with_no_class = 0
    for gt in gt_list:
        for cc in gt["class_names"]:
@@ -569,7 +569,7 @@ if __name__ == "__main__":
    num_with_no_class = check_classes_in_train(gt_test, class_names)
    if total_num_calls == num_with_no_class:
        print("Classes from the test set are not in the train set.")
        assert False
        raise AssertionError()

    # only need the train data if evaluating Sonobat or Tadarida
    if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
@@ -743,7 +743,7 @@ if __name__ == "__main__":
    # check if the class names are the same
    if params_bd["class_names"] != class_names:
        print("Warning: Class names are not the same as the trained model")
        assert False
        raise AssertionError()

    run_config = {
        **bd_args,
@@ -753,7 +753,7 @@ if __name__ == "__main__":

    preds_bd = []
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for ii, gg in enumerate(gt_test):
    for gg in gt_test:
        pred = du.process_file(
            gg["file_path"],
            model,
@@ -5,11 +5,10 @@ from soundevent import data
from torch.utils.data import DataLoader

from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.postprocess.types import ClipDetections


class EvaluationModule(LightningModule):
@@ -24,7 +23,7 @@ class EvaluationModule(LightningModule):
        self.evaluator = evaluator

        self.clip_annotations: List[data.ClipAnnotation] = []
        self.predictions: List[BatDetect2Prediction] = []
        self.predictions: List[ClipDetections] = []

    def test_step(self, batch: TestExample, batch_idx: int):
        dataset = self.get_dataset()
@@ -34,22 +33,11 @@ class EvaluationModule(LightningModule):
        ]

        outputs = self.model.detector(batch.spec)
        clip_detections = self.model.postprocessor(
            outputs,
            start_times=[ca.clip.start_time for ca in clip_annotations],
        clip_detections = self.model.postprocessor(outputs)
        predictions = self.evaluator.to_clip_detections_batch(
            clip_detections,
            [clip_annotation.clip for clip_annotation in clip_annotations],
        )
        predictions = [
            BatDetect2Prediction(
                clip=clip_annotation.clip,
                predictions=to_raw_predictions(
                    clip_dets.numpy(),
                    targets=self.evaluator.targets,
                ),
            )
            for clip_annotation, clip_dets in zip(
                clip_annotations, clip_detections
            )
        ]

        self.clip_annotations.extend(clip_annotations)
        self.predictions.extend(predictions)
@@ -1,617 +0,0 @@
from collections.abc import Callable, Iterable, Mapping
from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union

import numpy as np
from pydantic import Field
from scipy.optimize import linear_sum_assignment
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.geometry import buffer_geometry, compute_bounds, scale_geometry

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.affinity import (
    AffinityConfig,
    BBoxIOUConfig,
    GeometricIOUConfig,
    build_affinity_function,
)
from batdetect2.targets import build_targets
from batdetect2.typing import (
    AffinityFunction,
    MatcherProtocol,
    MatchEvaluation,
    RawPrediction,
    TargetProtocol,
)
from batdetect2.typing.evaluate import ClipMatches

MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching."""

matching_strategies = Registry("matching_strategy")


def match(
    sound_event_annotations: Sequence[data.SoundEventAnnotation],
    raw_predictions: Sequence[RawPrediction],
    clip: data.Clip,
    scores: Optional[Sequence[float]] = None,
    targets: Optional[TargetProtocol] = None,
    matcher: Optional[MatcherProtocol] = None,
) -> ClipMatches:
    if matcher is None:
        matcher = build_matcher()

    if targets is None:
        targets = build_targets()

    target_geometries: List[data.Geometry] = [  # type: ignore
        sound_event_annotation.sound_event.geometry
        for sound_event_annotation in sound_event_annotations
    ]

    predicted_geometries = [
        raw_prediction.geometry for raw_prediction in raw_predictions
    ]

    if scores is None:
        scores = [
            raw_prediction.detection_score
            for raw_prediction in raw_predictions
        ]

    matches = []

    for source_idx, target_idx, affinity in matcher(
        ground_truth=target_geometries,
        predictions=predicted_geometries,
        scores=scores,
    ):
        target = (
            sound_event_annotations[target_idx]
            if target_idx is not None
            else None
        )
        prediction = (
            raw_predictions[source_idx] if source_idx is not None else None
        )

        gt_det = target_idx is not None
        gt_class = targets.encode_class(target) if target is not None else None
        gt_geometry = (
            target_geometries[target_idx] if target_idx is not None else None
        )

        pred_score = float(prediction.detection_score) if prediction else 0
        pred_geometry = (
            predicted_geometries[source_idx]
            if source_idx is not None
            else None
        )

        class_scores = (
            {
                class_name: score
                for class_name, score in zip(
                    targets.class_names,
                    prediction.class_scores,
                )
            }
            if prediction is not None
            else {}
        )

        matches.append(
            MatchEvaluation(
                clip=clip,
                sound_event_annotation=target,
                gt_det=gt_det,
                gt_class=gt_class,
                gt_geometry=gt_geometry,
                pred_score=pred_score,
                pred_class_scores=class_scores,
                pred_geometry=pred_geometry,
                affinity=affinity,
            )
        )

    return ClipMatches(clip=clip, matches=matches)


class StartTimeMatchConfig(BaseConfig):
    name: Literal["start_time_match"] = "start_time_match"
    distance_threshold: float = 0.01


class StartTimeMatcher(MatcherProtocol):
    def __init__(self, distance_threshold: float):
        self.distance_threshold = distance_threshold

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        return match_start_times(
            ground_truth,
            predictions,
            scores,
            distance_threshold=self.distance_threshold,
        )

    @matching_strategies.register(StartTimeMatchConfig)
    @staticmethod
    def from_config(config: StartTimeMatchConfig):
        return StartTimeMatcher(distance_threshold=config.distance_threshold)


def match_start_times(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    scores: Sequence[float],
    distance_threshold: float = 0.01,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    if not ground_truth:
        for index in range(len(predictions)):
            yield index, None, 0

        return

    if not predictions:
        for index in range(len(ground_truth)):
            yield None, index, 0

        return

    gt_times = np.array([compute_bounds(geom)[0] for geom in ground_truth])
    pred_times = np.array([compute_bounds(geom)[0] for geom in predictions])

    scores = np.array(scores)
    sort_args = np.argsort(scores)[::-1]

    distances = np.abs(gt_times[None, :] - pred_times[:, None])
    closests = np.argmin(distances, axis=-1)

    unmatched_gt = set(range(len(gt_times)))

    for pred_index in sort_args:
        # Get the closest ground truth
        gt_closest_index = closests[pred_index]

        if gt_closest_index not in unmatched_gt:
            # Does not match if closest has been assigned
            yield pred_index, None, 0
            continue

        # Get the actual distance
        distance = distances[pred_index, gt_closest_index]

        if distance > distance_threshold:
            # Does not match if too far from closest
            yield pred_index, None, 0
            continue

        # Return affinity value: linear interpolation between 0 to 1, where a
        # distance at the threshold maps to 0 affinity and a zero distance maps
        # to 1.
        affinity = np.interp(
            distance,
            [0, distance_threshold],
            [1, 0],
            left=1,
            right=0,
        )
        unmatched_gt.remove(gt_closest_index)
        yield pred_index, gt_closest_index, affinity

    for missing_index in unmatched_gt:
        yield None, missing_index, 0

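The `np.interp` call above turns a start-time distance into an affinity in [0, 1]. As a standalone sketch of just that mapping (the 0.01 s default mirrors `StartTimeMatchConfig.distance_threshold`; the function name is ours, not the module's):

```python
import numpy as np

def start_time_affinity(distance: float, threshold: float = 0.01) -> float:
    # Zero distance maps to affinity 1, a distance at or beyond the
    # threshold maps to 0, with linear interpolation in between.
    return float(np.interp(distance, [0, threshold], [1, 0], left=1, right=0))

print(start_time_affinity(0.0))    # exact start-time match
print(start_time_affinity(0.005))  # halfway to the threshold
print(start_time_affinity(0.02))   # beyond the threshold
```

The `left`/`right` arguments clamp the output, so any distance past the threshold contributes zero affinity rather than going negative.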
def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
    start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
    return data.BoundingBox(
        coordinates=[start_time, low_freq, end_time, high_freq]
    )


def _to_interval(geometry: data.Geometry) -> data.TimeInterval:
    start_time, _, end_time, _ = compute_bounds(geometry)
    return data.TimeInterval(coordinates=[start_time, end_time])


def _to_timestamp(geometry: data.Geometry) -> data.TimeStamp:
    start_time = compute_bounds(geometry)[0]
    return data.TimeStamp(coordinates=start_time)


_geometry_cast_functions: Mapping[
    MatchingGeometry, Callable[[data.Geometry], data.Geometry]
] = {
    "bbox": _to_bbox,
    "interval": _to_interval,
    "timestamp": _to_timestamp,
}


class GreedyMatchConfig(BaseConfig):
    name: Literal["greedy_match"] = "greedy_match"
    geometry: MatchingGeometry = "timestamp"
    affinity_threshold: float = 0.5
    affinity_function: AffinityConfig = Field(
        default_factory=GeometricIOUConfig
    )


class GreedyMatcher(MatcherProtocol):
    def __init__(
        self,
        geometry: MatchingGeometry,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
    ):
        self.geometry = geometry
        self.affinity_function = affinity_function
        self.affinity_threshold = affinity_threshold
        self.cast_geometry = _geometry_cast_functions[self.geometry]

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        return greedy_match(
            ground_truth=[self.cast_geometry(geom) for geom in ground_truth],
            predictions=[self.cast_geometry(geom) for geom in predictions],
            scores=scores,
            affinity_function=self.affinity_function,
            affinity_threshold=self.affinity_threshold,
        )

    @matching_strategies.register(GreedyMatchConfig)
    @staticmethod
    def from_config(config: GreedyMatchConfig):
        affinity_function = build_affinity_function(config.affinity_function)
        return GreedyMatcher(
            geometry=config.geometry,
            affinity_threshold=config.affinity_threshold,
            affinity_function=affinity_function,
        )


def greedy_match(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    scores: Sequence[float],
    affinity_threshold: float = 0.5,
    affinity_function: AffinityFunction = compute_affinity,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Performs a greedy, one-to-one matching of source to target geometries.

    Iterates through source geometries, prioritizing by score if provided. Each
    source is matched to the best available target, provided the affinity
    exceeds the threshold and the target has not already been assigned.

    Parameters
    ----------
    source
        A list of source geometries (e.g., predictions).
    target
        A list of target geometries (e.g., ground truths).
    scores
        Confidence scores for each source geometry for prioritization.
    affinity_threshold
        The minimum affinity score required for a valid match.

    Yields
    ------
    Tuple[Optional[int], Optional[int], float]
        A 3-element tuple describing a match or a miss. There are three
        possible formats:
        - Successful Match: `(source_idx, target_idx, affinity)`
        - Unmatched Source (False Positive): `(source_idx, None, 0)`
        - Unmatched Target (False Negative): `(None, target_idx, 0)`
    """
    unassigned_gt = set(range(len(ground_truth)))

    if not predictions:
        for gt_idx in range(len(ground_truth)):
            yield None, gt_idx, 0

        return

    if not ground_truth:
        for pred_idx in range(len(predictions)):
            yield pred_idx, None, 0

        return

    indices = np.argsort(scores)[::-1]

    for pred_idx in indices:
        source_geometry = predictions[pred_idx]

        affinities = np.array(
            [
                affinity_function(source_geometry, target_geometry)
                for target_geometry in ground_truth
            ]
        )

        closest_target = int(np.argmax(affinities))
        affinity = affinities[closest_target]

        if affinities[closest_target] <= affinity_threshold:
            yield pred_idx, None, 0
            continue

        if closest_target not in unassigned_gt:
            yield pred_idx, None, 0
            continue

        unassigned_gt.remove(closest_target)
        yield pred_idx, closest_target, affinity

    for gt_idx in unassigned_gt:
        yield None, gt_idx, 0

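The greedy strategy documented above can be illustrated with a simplified, self-contained restatement. Plain `(start, end)` tuples and a 1-D interval IoU stand in for the soundevent geometries and affinity functions used by the real module, and the empty-input early returns are omitted:

```python
import numpy as np

def interval_iou(a, b):
    # Intersection-over-union of two (start, end) intervals.
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union if union else 0.0

def greedy_match(ground_truth, predictions, scores, affinity_threshold=0.5):
    # Visit predictions in descending score order; each one claims its
    # best-affinity ground truth if it clears the threshold and is free.
    unassigned_gt = set(range(len(ground_truth)))
    for pred_idx in np.argsort(scores)[::-1]:
        affinities = [interval_iou(predictions[pred_idx], gt) for gt in ground_truth]
        best = int(np.argmax(affinities))
        if affinities[best] <= affinity_threshold or best not in unassigned_gt:
            yield int(pred_idx), None, 0.0
            continue
        unassigned_gt.remove(best)
        yield int(pred_idx), best, affinities[best]
    for gt_idx in unassigned_gt:
        yield None, gt_idx, 0.0

gt = [(0.0, 0.1), (0.5, 0.6)]
preds = [(0.01, 0.11), (0.9, 1.0)]
print(list(greedy_match(gt, preds, scores=[0.9, 0.4])))
```

The first prediction overlaps the first ground truth and is matched; the second overlaps nothing and is yielded as a false positive, while the unclaimed ground truth comes out as a false negative.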
class GreedyAffinityMatchConfig(BaseConfig):
    name: Literal["greedy_affinity_match"] = "greedy_affinity_match"
    affinity_function: AffinityConfig = Field(default_factory=BBoxIOUConfig)
    affinity_threshold: float = 0.5
    time_buffer: float = 0
    frequency_buffer: float = 0
    time_scale: float = 1.0
    frequency_scale: float = 1.0


class GreedyAffinityMatcher(MatcherProtocol):
    def __init__(
        self,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
        time_buffer: float = 0,
        frequency_buffer: float = 0,
        time_scale: float = 1.0,
        frequency_scale: float = 1.0,
    ):
        self.affinity_threshold = affinity_threshold
        self.affinity_function = affinity_function
        self.time_buffer = time_buffer
        self.frequency_buffer = frequency_buffer
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        if self.time_buffer != 0 or self.frequency_buffer != 0:
            ground_truth = [
                buffer_geometry(
                    geometry,
                    time_buffer=self.time_buffer,
                    freq_buffer=self.frequency_buffer,
                )
                for geometry in ground_truth
            ]

            predictions = [
                buffer_geometry(
                    geometry,
                    time_buffer=self.time_buffer,
                    freq_buffer=self.frequency_buffer,
                )
                for geometry in predictions
            ]

        affinity_matrix = compute_affinity_matrix(
            ground_truth,
            predictions,
            self.affinity_function,
            time_scale=self.time_scale,
            frequency_scale=self.frequency_scale,
        )

        return select_greedy_matches(
            affinity_matrix,
            affinity_threshold=self.affinity_threshold,
        )

    @matching_strategies.register(GreedyAffinityMatchConfig)
    @staticmethod
    def from_config(config: GreedyAffinityMatchConfig):
        affinity_function = build_affinity_function(config.affinity_function)
        return GreedyAffinityMatcher(
            affinity_threshold=config.affinity_threshold,
            affinity_function=affinity_function,
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
        )


class OptimalMatchConfig(BaseConfig):
    name: Literal["optimal_affinity_match"] = "optimal_affinity_match"
    affinity_function: AffinityConfig = Field(default_factory=BBoxIOUConfig)
    affinity_threshold: float = 0.5
    time_buffer: float = 0
    frequency_buffer: float = 0
    time_scale: float = 1.0
    frequency_scale: float = 1.0


class OptimalMatcher(MatcherProtocol):
    def __init__(
        self,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
        time_buffer: float = 0,
        frequency_buffer: float = 0,
        time_scale: float = 1.0,
        frequency_scale: float = 1.0,
    ):
        self.affinity_threshold = affinity_threshold
        self.affinity_function = affinity_function
        self.time_buffer = time_buffer
        self.frequency_buffer = frequency_buffer
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        if self.time_buffer != 0 or self.frequency_buffer != 0:
            ground_truth = [
                buffer_geometry(
                    geometry,
                    time_buffer=self.time_buffer,
                    freq_buffer=self.frequency_buffer,
                )
                for geometry in ground_truth
            ]

            predictions = [
                buffer_geometry(
                    geometry,
                    time_buffer=self.time_buffer,
                    freq_buffer=self.frequency_buffer,
                )
                for geometry in predictions
            ]

        affinity_matrix = compute_affinity_matrix(
            ground_truth,
            predictions,
            self.affinity_function,
            time_scale=self.time_scale,
            frequency_scale=self.frequency_scale,
        )
        return select_optimal_matches(
            affinity_matrix,
            affinity_threshold=self.affinity_threshold,
        )

    @matching_strategies.register(OptimalMatchConfig)
    @staticmethod
    def from_config(config: OptimalMatchConfig):
        affinity_function = build_affinity_function(config.affinity_function)
        return OptimalMatcher(
            affinity_threshold=config.affinity_threshold,
            affinity_function=affinity_function,
            time_buffer=config.time_buffer,
            frequency_buffer=config.frequency_buffer,
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
        )


MatchConfig = Annotated[
    Union[
        GreedyMatchConfig,
        StartTimeMatchConfig,
        OptimalMatchConfig,
        GreedyAffinityMatchConfig,
    ],
    Field(discriminator="name"),
]


def compute_affinity_matrix(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    affinity_function: AffinityFunction,
    time_scale: float = 1,
    frequency_scale: float = 1,
) -> np.ndarray:
    # Scale geometries if necessary
    if time_scale != 1 or frequency_scale != 1:
        ground_truth = [
            scale_geometry(geometry, time_scale, frequency_scale)
            for geometry in ground_truth
        ]

        predictions = [
            scale_geometry(geometry, time_scale, frequency_scale)
            for geometry in predictions
        ]

    affinity_matrix = np.zeros((len(ground_truth), len(predictions)))
    for gt_idx, gt_geometry in enumerate(ground_truth):
        for pred_idx, pred_geometry in enumerate(predictions):
            affinity = affinity_function(
                gt_geometry,
                pred_geometry,
            )
            affinity_matrix[gt_idx, pred_idx] = affinity

    return affinity_matrix


def select_optimal_matches(
    affinity_matrix: np.ndarray,
    affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    num_gt, num_pred = affinity_matrix.shape
    gts = set(range(num_gt))
    preds = set(range(num_pred))

    assiged_rows, assigned_columns = linear_sum_assignment(
        affinity_matrix,
        maximize=True,
    )

    for gt_idx, pred_idx in zip(assiged_rows, assigned_columns):
        affinity = float(affinity_matrix[gt_idx, pred_idx])

        if affinity <= affinity_threshold:
            continue

        yield gt_idx, pred_idx, affinity
        gts.remove(gt_idx)
        preds.remove(pred_idx)

    for gt_idx in gts:
        yield gt_idx, None, 0

    for pred_idx in preds:
        yield None, pred_idx, 0

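Unlike the greedy passes, `select_optimal_matches` above delegates to SciPy's linear-sum-assignment (Hungarian) solver, which picks the one-to-one pairing that maximizes total affinity. A minimal demonstration of that SciPy call in isolation:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Rows are ground-truth events, columns are predictions.
affinity = np.array([
    [0.9, 0.6],
    [0.8, 0.1],
])

# maximize=True selects the pairing with the largest total affinity:
# gt 0 -> pred 1 and gt 1 -> pred 0 (total 1.4). A greedy pass would
# instead give pred 0 to gt 0 (0.9) and leave gt 1 effectively unmatched.
rows, cols = linear_sum_assignment(affinity, maximize=True)
print(list(rows), list(cols))
```

This is the trade-off between the two selection strategies: the optimal assignment can sacrifice an individual best pairing to raise the total.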
def select_greedy_matches(
    affinity_matrix: np.ndarray,
    affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    num_gt, num_pred = affinity_matrix.shape
    unmatched_pred = set(range(num_pred))

    for gt_idx in range(num_gt):
        row = affinity_matrix[gt_idx]

        top_pred = int(np.argmax(row))
        top_affinity = float(row[top_pred])

        if (
            top_affinity <= affinity_threshold
            or top_pred not in unmatched_pred
        ):
            yield None, gt_idx, 0
            continue

        unmatched_pred.remove(top_pred)
        yield top_pred, gt_idx, top_affinity

    for pred_idx in unmatched_pred:
        yield pred_idx, None, 0


def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
    config = config or StartTimeMatchConfig()
    return matching_strategies.build(config)
@@ -7,10 +7,8 @@ from typing import (
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import numpy as np
@@ -18,16 +16,23 @@ from pydantic import Field
from sklearn import metrics
from soundevent import data

from batdetect2.core import BaseConfig, Registry
from batdetect2.core import (
    BaseConfig,
    ImportConfig,
    Registry,
    add_import_config,
)
from batdetect2.evaluate.metrics.common import (
    average_precision,
    compute_precision_recall,
)
from batdetect2.typing import RawPrediction, TargetProtocol
from batdetect2.postprocess.types import Detection
from batdetect2.targets.types import TargetProtocol

__all__ = [
    "ClassificationMetric",
    "ClassificationMetricConfig",
    "ClassificationMetricImportConfig",
    "build_classification_metric",
    "compute_precision_recall_curves",
]
@@ -36,13 +41,13 @@ __all__ = [
@dataclass
class MatchEval:
    clip: data.Clip
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]
    gt: data.SoundEventAnnotation | None
    pred: Detection | None

    is_prediction: bool
    is_ground_truth: bool
    is_generic: bool
    true_class: Optional[str]
    true_class: str | None
    score: float


@@ -60,17 +65,28 @@ classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
)


@add_import_config(classification_metrics)
class ClassificationMetricImportConfig(ImportConfig):
    """Use any callable as a classification metric.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class BaseClassificationConfig(BaseConfig):
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None
    include: List[str] | None = None
    exclude: List[str] | None = None


class BaseClassificationMetric:
    def __init__(
        self,
        targets: TargetProtocol,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        include: List[str] | None = None,
        exclude: List[str] | None = None,
    ):
        self.targets = targets
        self.include = include
@@ -100,8 +116,8 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        label: str = "average_precision",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        include: List[str] | None = None,
        exclude: List[str] | None = None,
    ):
        super().__init__(include=include, exclude=exclude, targets=targets)
        self.ignore_non_predictions = ignore_non_predictions
@@ -169,8 +185,8 @@ class ClassificationROCAUC(BaseClassificationMetric):
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        label: str = "roc_auc",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        include: List[str] | None = None,
        exclude: List[str] | None = None,
    ):
        self.targets = targets
        self.ignore_non_predictions = ignore_non_predictions
@@ -225,10 +241,7 @@ class ClassificationROCAUC(BaseClassificationMetric):


ClassificationMetricConfig = Annotated[
    Union[
        ClassificationAveragePrecisionConfig,
        ClassificationROCAUCConfig,
    ],
    ClassificationAveragePrecisionConfig | ClassificationROCAUCConfig,
    Field(discriminator="name"),
]

@@ -1,13 +1,17 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union
from typing import Annotated, Callable, Dict, Literal, Sequence, Set

import numpy as np
from pydantic import Field
from sklearn import metrics

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.core.registries import (
    ImportConfig,
    Registry,
    add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision


@@ -24,6 +28,17 @@ clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
)


@add_import_config(clip_classification_metrics)
class ClipClassificationMetricImportConfig(ImportConfig):
    """Use any callable as a clip classification metric.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class ClipClassificationAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"
@@ -123,10 +138,7 @@ class ClipClassificationROCAUC:


ClipClassificationMetricConfig = Annotated[
    Union[
        ClipClassificationAveragePrecisionConfig,
        ClipClassificationROCAUCConfig,
    ],
    ClipClassificationAveragePrecisionConfig | ClipClassificationROCAUCConfig,
    Field(discriminator="name"),
]

@@ -1,12 +1,16 @@
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Union
from typing import Annotated, Callable, Dict, Literal, Sequence

import numpy as np
from pydantic import Field
from sklearn import metrics

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.core.registries import (
    ImportConfig,
    Registry,
    add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision


@@ -23,6 +27,17 @@ clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
)


@add_import_config(clip_detection_metrics)
class ClipDetectionMetricImportConfig(ImportConfig):
    """Use any callable as a clip detection metric.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class ClipDetectionAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"
@@ -159,12 +174,10 @@ class ClipDetectionPrecision:


ClipDetectionMetricConfig = Annotated[
    Union[
        ClipDetectionAveragePrecisionConfig,
        ClipDetectionROCAUCConfig,
        ClipDetectionRecallConfig,
        ClipDetectionPrecisionConfig,
    ],
    ClipDetectionAveragePrecisionConfig
    | ClipDetectionROCAUCConfig
    | ClipDetectionRecallConfig
    | ClipDetectionPrecisionConfig,
    Field(discriminator="name"),
]

@@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Tuple

import numpy as np

@@ -11,7 +11,7 @@ __all__ = [
def compute_precision_recall(
    y_true,
    y_score,
    num_positives: Optional[int] = None,
    num_positives: int | None = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    y_true = np.array(y_true)
    y_score = np.array(y_score)
@@ -41,7 +41,7 @@ def compute_precision_recall(
def average_precision(
    y_true,
    y_score,
    num_positives: Optional[int] = None,
    num_positives: int | None = None,
) -> float:
    if num_positives == 0:
        return np.nan

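The `num_positives` parameter above lets recall be computed against the full count of ground-truth events, including those never detected at any threshold. A simplified, hypothetical reimplementation of that idea (not the module's exact code):

```python
import numpy as np

def precision_recall(y_true, y_score, num_positives=None):
    # Sort detections by descending score, then accumulate true positives.
    order = np.argsort(y_score)[::-1]
    hits = np.asarray(y_true, dtype=float)[order]
    tp = np.cumsum(hits)
    precision = tp / np.arange(1, len(hits) + 1)
    # Dividing by num_positives (rather than tp[-1]) accounts for
    # ground-truth events that no prediction ever matched.
    total = num_positives if num_positives is not None else tp[-1]
    recall = tp / total
    return precision, recall

p, r = precision_recall([1, 0, 1], [0.9, 0.8, 0.7], num_positives=4)
print(p, r)
```

With four positives in total but only two detected, recall tops out at 0.5 here even though precision is computed as usual.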
@@ -5,9 +5,7 @@ from typing import (
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

import numpy as np
@@ -15,21 +13,27 @@ from pydantic import Field
from sklearn import metrics
from soundevent import data

from batdetect2.core import BaseConfig, Registry
from batdetect2.core import (
    BaseConfig,
    ImportConfig,
    Registry,
    add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction
from batdetect2.postprocess.types import Detection

__all__ = [
    "DetectionMetricConfig",
    "DetectionMetric",
    "DetectionMetricImportConfig",
    "build_detection_metric",
]


@dataclass
class MatchEval:
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]
    gt: data.SoundEventAnnotation | None
    pred: Detection | None

    is_prediction: bool
    is_ground_truth: bool
@@ -48,6 +52,17 @@ DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")


@add_import_config(detection_metrics)
class DetectionMetricImportConfig(ImportConfig):
    """Use any callable as a detection metric.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class DetectionAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"
@@ -79,7 +94,7 @@ class DetectionAveragePrecision:
            y_score.append(m.score)

        ap = average_precision(y_true, y_score, num_positives=num_positives)
        return {self.label: ap}
        return {self.label: float(ap)}

    @detection_metrics.register(DetectionAveragePrecisionConfig)
    @staticmethod
@@ -212,12 +227,10 @@ class DetectionPrecision:


DetectionMetricConfig = Annotated[
    Union[
        DetectionAveragePrecisionConfig,
        DetectionROCAUCConfig,
        DetectionRecallConfig,
        DetectionPrecisionConfig,
    ],
    DetectionAveragePrecisionConfig
    | DetectionROCAUCConfig
    | DetectionRecallConfig
    | DetectionPrecisionConfig,
    Field(discriminator="name"),
]

@@ -1,4 +1,3 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import (
    Annotated,
@@ -6,9 +5,7 @@ from typing import (
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

import numpy as np
@@ -16,14 +13,20 @@ from pydantic import Field
from sklearn import metrics, preprocessing
from soundevent import data

from batdetect2.core import BaseConfig, Registry
from batdetect2.core import (
    BaseConfig,
    ImportConfig,
    Registry,
    add_import_config,
)
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction
from batdetect2.typing.targets import TargetProtocol
from batdetect2.postprocess.types import Detection
from batdetect2.targets.types import TargetProtocol

__all__ = [
    "TopClassMetricConfig",
    "TopClassMetric",
    "TopClassMetricImportConfig",
    "build_top_class_metric",
]

@@ -31,14 +34,14 @@ __all__ = [
@dataclass
class MatchEval:
    clip: data.Clip
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]
    gt: data.SoundEventAnnotation | None
    pred: Detection | None

    is_ground_truth: bool
    is_generic: bool
    is_prediction: bool
    pred_class: Optional[str]
    true_class: Optional[str]
    pred_class: str | None
    true_class: str | None
    score: float


@@ -54,6 +57,17 @@ TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")


@add_import_config(top_class_metrics)
class TopClassMetricImportConfig(ImportConfig):
    """Use any callable as a top-class metric.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class TopClassAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"

@@ -301,13 +315,11 @@ class BalancedAccuracy:


TopClassMetricConfig = Annotated[
    Union[
        TopClassAveragePrecisionConfig,
        TopClassROCAUCConfig,
        TopClassRecallConfig,
        TopClassPrecisionConfig,
        BalancedAccuracyConfig,
    ],
    TopClassAveragePrecisionConfig
    | TopClassROCAUCConfig
    | TopClassRecallConfig
    | TopClassPrecisionConfig
    | BalancedAccuracyConfig,
    Field(discriminator="name"),
]

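Several files in this diff gain an `ImportConfig` registered via `add_import_config`, so a config with `name: "import"` and a dotted `target` path can stand in for any built-in option. The internals of `Registry` and `add_import_config` are not shown in the diff, so the following is a toy sketch of the pattern (all names here are hypothetical stand-ins), not the project's actual implementation:

```python
import importlib
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type


@dataclass
class ImportCfg:
    """Hypothetical config pointing at any callable by dotted path."""

    target: str  # e.g. "json.dumps"
    name: str = "import"


class MiniRegistry:
    """Toy stand-in for a config registry: maps config types to builders."""

    def __init__(self) -> None:
        self._builders: Dict[Type, Callable[[Any], Any]] = {}

    def register(self, config_type: Type):
        def decorator(builder: Callable[[Any], Any]):
            self._builders[config_type] = builder
            return builder

        return decorator

    def build(self, config: Any) -> Any:
        # Dispatch on the concrete config type.
        return self._builders[type(config)](config)


registry = MiniRegistry()


@registry.register(ImportCfg)
def _build_import(config: ImportCfg):
    # Resolve "module.attr" to the actual callable.
    module_name, _, attr = config.target.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


func = registry.build(ImportCfg(target="json.dumps"))
```

The payoff is extensibility: users can plug in their own metric or plot callables purely from configuration, without touching the registry's built-ins.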
@@ -1,16 +1,14 @@
from typing import Optional

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from batdetect2.core import BaseConfig
from batdetect2.typing import TargetProtocol
from batdetect2.targets.types import TargetProtocol


class BasePlotConfig(BaseConfig):
    label: str = "plot"
    theme: str = "default"
    title: Optional[str] = None
    title: str | None = None
    figsize: tuple[int, int] = (10, 10)
    dpi: int = 100

@@ -21,7 +19,7 @@ class BasePlot:
        targets: TargetProtocol,
        label: str = "plot",
        figsize: tuple[int, int] = (10, 10),
        title: Optional[str] = None,
        title: str | None = None,
        dpi: int = 100,
        theme: str = "default",
    ):

@@ -3,10 +3,8 @@ from typing import (
    Callable,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import matplotlib.pyplot as plt
@@ -14,7 +12,7 @@ from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.core import Registry
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.classification import (
    ClipEval,
    _extract_per_class_metric_data,
@@ -31,7 +29,7 @@ from batdetect2.plotting.metrics import (
    plot_threshold_recall_curve,
    plot_threshold_recall_curves,
)
from batdetect2.typing import TargetProtocol
from batdetect2.targets.types import TargetProtocol

ClassificationPlotter = Callable[
    [Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
@@ -42,10 +40,21 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
)


@add_import_config(classification_plots)
class ClassificationPlotImportConfig(ImportConfig):
    """Use any callable as a classification plot.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Classification Precision-Recall Curve"
    title: str | None = "Classification Precision-Recall Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    separate_figures: bool = False

@@ -88,9 +97,7 @@ class PRCurve(BasePlot):

            ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
            ax.set_title(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @classification_plots.register(PRCurveConfig)
@@ -108,7 +115,7 @@ class PRCurve(BasePlot):
class ThresholdPrecisionCurveConfig(BasePlotConfig):
    name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
    label: str = "threshold_precision_curve"
    title: Optional[str] = "Classification Threshold-Precision Curve"
    title: str | None = "Classification Threshold-Precision Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    separate_figures: bool = False

@@ -181,7 +188,7 @@ class ThresholdPrecisionCurve(BasePlot):
class ThresholdRecallCurveConfig(BasePlotConfig):
    name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
    label: str = "threshold_recall_curve"
    title: Optional[str] = "Classification Threshold-Recall Curve"
    title: str | None = "Classification Threshold-Recall Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    separate_figures: bool = False

@@ -254,7 +261,7 @@ class ThresholdRecallCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Classification ROC Curve"
    title: str | None = "Classification ROC Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    separate_figures: bool = False

@@ -326,12 +333,10 @@ class ROCCurve(BasePlot):


ClassificationPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ThresholdPrecisionCurveConfig,
        ThresholdRecallCurveConfig,
    ],
    PRCurveConfig
    | ROCCurveConfig
    | ThresholdPrecisionCurveConfig
    | ThresholdRecallCurveConfig,
    Field(discriminator="name"),
]

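The `PRCurve`, `ThresholdPrecisionCurve`, and related plot classes in this module draw curves computed with scikit-learn (which the module already imports). As a minimal sketch of the data those plots are built from, using toy labels and scores rather than real evaluation output:

```python
from sklearn.metrics import precision_recall_curve

# Toy per-detection outcomes: 1 = correct class, 0 = incorrect.
y_true = [1, 0, 1, 1, 0, 1]
y_score = [0.9, 0.8, 0.7, 0.6, 0.4, 0.2]

precision, recall, thresholds = precision_recall_curve(y_true, y_score)

# precision/recall have one more entry than thresholds; the curve ends
# at the (recall=0, precision=1) anchor point sklearn always appends.
```

The plotting helpers (`plot_pr_curve` etc.) then only need to render these arrays per class, which is why the config classes above are mostly about titles and figure layout.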
@@ -3,10 +3,8 @@ from typing import (
    Callable,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import matplotlib.pyplot as plt
@@ -14,7 +12,7 @@ from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.core import Registry
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.clip_classification import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
@@ -24,10 +22,11 @@ from batdetect2.plotting.metrics import (
    plot_roc_curve,
    plot_roc_curves,
)
from batdetect2.typing import TargetProtocol
from batdetect2.targets.types import TargetProtocol

__all__ = [
    "ClipClassificationPlotConfig",
    "ClipClassificationPlotImportConfig",
    "ClipClassificationPlotter",
    "build_clip_classification_plotter",
]
@@ -41,10 +40,21 @@ clip_classification_plots: Registry[
] = Registry("clip_classification_plot")


@add_import_config(clip_classification_plots)
class ClipClassificationPlotImportConfig(ImportConfig):
    """Use any callable as a clip classification plot.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Clip Classification Precision-Recall Curve"
    title: str | None = "Clip Classification Precision-Recall Curve"
    separate_figures: bool = False


@@ -111,7 +121,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Clip Classification ROC Curve"
    title: str | None = "Clip Classification ROC Curve"
    separate_figures: bool = False


@@ -174,10 +184,7 @@ class ROCCurve(BasePlot):


ClipClassificationPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
    ],
    PRCurveConfig | ROCCurveConfig,
    Field(discriminator="name"),
]

@@ -3,10 +3,8 @@ from typing import (
    Callable,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import pandas as pd
@@ -15,15 +13,16 @@ from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.core import Registry
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.clip_detection import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol
from batdetect2.targets.types import TargetProtocol

__all__ = [
    "ClipDetectionPlotConfig",
    "ClipDetectionPlotImportConfig",
    "ClipDetectionPlotter",
    "build_clip_detection_plotter",
]
@@ -38,10 +37,21 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
)


@add_import_config(clip_detection_plots)
class ClipDetectionPlotImportConfig(ImportConfig):
    """Use any callable as a clip detection plot.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Clip Detection Precision-Recall Curve"
    title: str | None = "Clip Detection Precision-Recall Curve"


class PRCurve(BasePlot):
@@ -74,7 +84,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Clip Detection ROC Curve"
    title: str | None = "Clip Detection ROC Curve"


class ROCCurve(BasePlot):
@@ -107,7 +117,7 @@ class ROCCurve(BasePlot):
class ScoreDistributionPlotConfig(BasePlotConfig):
    name: Literal["score_distribution"] = "score_distribution"
    label: str = "score_distribution"
    title: Optional[str] = "Clip Detection Score Distribution"
    title: str | None = "Clip Detection Score Distribution"


class ScoreDistributionPlot(BasePlot):
@@ -147,11 +157,7 @@ class ScoreDistributionPlot(BasePlot):


ClipDetectionPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ScoreDistributionPlotConfig,
    ],
    PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig,
    Field(discriminator="name"),
]

@@ -4,10 +4,8 @@ from typing import (
    Callable,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import matplotlib.pyplot as plt
@@ -18,14 +16,16 @@ from pydantic import Field
from sklearn import metrics

from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.audio.types import AudioLoader
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.detections import plot_clip_detections
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol

DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]

@@ -34,10 +34,21 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
)


@add_import_config(detection_plots)
class DetectionPlotImportConfig(ImportConfig):
    """Use any callable as a detection plot.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Detection Precision-Recall Curve"
    title: str | None = "Detection Precision-Recall Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True

@@ -100,7 +111,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Detection ROC Curve"
    title: str | None = "Detection ROC Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True

@@ -159,7 +170,7 @@ class ROCCurve(BasePlot):
class ScoreDistributionPlotConfig(BasePlotConfig):
    name: Literal["score_distribution"] = "score_distribution"
    label: str = "score_distribution"
    title: Optional[str] = "Detection Score Distribution"
    title: str | None = "Detection Score Distribution"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True

@@ -226,7 +237,7 @@ class ScoreDistributionPlot(BasePlot):
class ExampleDetectionPlotConfig(BasePlotConfig):
    name: Literal["example_detection"] = "example_detection"
    label: str = "example_detection"
    title: Optional[str] = "Example Detection"
    title: str | None = "Example Detection"
    figsize: tuple[int, int] = (10, 4)
    num_examples: int = 5
    threshold: float = 0.2
@@ -292,12 +303,10 @@ class ExampleDetectionPlot(BasePlot):


DetectionPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ScoreDistributionPlotConfig,
        ExampleDetectionPlotConfig,
    ],
    PRCurveConfig
    | ROCCurveConfig
    | ScoreDistributionPlotConfig
    | ExampleDetectionPlotConfig,
    Field(discriminator="name"),
]

@@ -4,14 +4,9 @@ from dataclasses import dataclass, field
from typing import (
    Annotated,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import matplotlib.pyplot as plt
@@ -21,7 +16,8 @@ from pydantic import Field
from sklearn import metrics

from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.audio.types import AudioLoader
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import (
    ClipEval,
@@ -32,19 +28,31 @@ from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol

TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]]

top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
    name="top_class_plot"
)


@add_import_config(top_class_plots)
class TopClassPlotImportConfig(ImportConfig):
    """Use any callable as a top-class plot.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Top Class Precision-Recall Curve"
    title: str | None = "Top Class Precision-Recall Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True

@@ -64,7 +72,7 @@ class PRCurve(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
    ) -> Iterable[tuple[str, Figure]]:
        y_true = []
        y_score = []
        num_positives = 0
@@ -111,7 +119,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Top Class ROC Curve"
    title: str | None = "Top Class ROC Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True

@@ -131,7 +139,7 @@ class ROCCurve(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
    ) -> Iterable[tuple[str, Figure]]:
        y_true = []
        y_score = []

@@ -173,7 +181,7 @@ class ROCCurve(BasePlot):

class ConfusionMatrixConfig(BasePlotConfig):
    name: Literal["confusion_matrix"] = "confusion_matrix"
    title: Optional[str] = "Top Class Confusion Matrix"
    title: str | None = "Top Class Confusion Matrix"
    figsize: tuple[int, int] = (10, 10)
    label: str = "confusion_matrix"
    exclude_generic: bool = True
@@ -214,7 +222,7 @@ class ConfusionMatrix(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
    ) -> Iterable[tuple[str, Figure]]:
        cm, labels = compute_confusion_matrix(
            clip_evaluations,
            self.targets,
@@ -257,7 +265,7 @@ class ConfusionMatrix(BasePlot):
class ExampleClassificationPlotConfig(BasePlotConfig):
    name: Literal["example_classification"] = "example_classification"
    label: str = "example_classification"
    title: Optional[str] = "Example Classification"
    title: str | None = "Example Classification"
    num_examples: int = 4
    threshold: float = 0.2
    audio: AudioConfig = Field(default_factory=AudioConfig)
@@ -286,26 +294,26 @@ class ExampleClassificationPlot(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
    ) -> Iterable[tuple[str, Figure]]:
        grouped = group_matches(clip_evaluations, threshold=self.threshold)

        for class_name, matches in grouped.items():
            true_positives: List[MatchEval] = get_binned_sample(
            true_positives: list[MatchEval] = get_binned_sample(
                matches.true_positives,
                n_examples=self.num_examples,
            )

            false_positives: List[MatchEval] = get_binned_sample(
            false_positives: list[MatchEval] = get_binned_sample(
                matches.false_positives,
                n_examples=self.num_examples,
            )

            false_negatives: List[MatchEval] = random.sample(
            false_negatives: list[MatchEval] = random.sample(
                matches.false_negatives,
                k=min(self.num_examples, len(matches.false_negatives)),
            )

            cross_triggers: List[MatchEval] = get_binned_sample(
            cross_triggers: list[MatchEval] = get_binned_sample(
                matches.cross_triggers, n_examples=self.num_examples
            )

@@ -348,12 +356,10 @@ class ExampleClassificationPlot(BasePlot):


TopClassPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ConfusionMatrixConfig,
        ExampleClassificationPlotConfig,
    ],
    PRCurveConfig
    | ROCCurveConfig
    | ConfusionMatrixConfig
    | ExampleClassificationPlotConfig,
    Field(discriminator="name"),
]

@@ -367,16 +373,16 @@ def build_top_class_plotter(

@dataclass
class ClassMatches:
    false_positives: List[MatchEval] = field(default_factory=list)
    false_negatives: List[MatchEval] = field(default_factory=list)
    true_positives: List[MatchEval] = field(default_factory=list)
    cross_triggers: List[MatchEval] = field(default_factory=list)
    false_positives: list[MatchEval] = field(default_factory=list)
    false_negatives: list[MatchEval] = field(default_factory=list)
    true_positives: list[MatchEval] = field(default_factory=list)
    cross_triggers: list[MatchEval] = field(default_factory=list)


def group_matches(
    clip_evals: Sequence[ClipEval],
    threshold: float = 0.2,
) -> Dict[str, ClassMatches]:
) -> dict[str, ClassMatches]:
    class_examples = defaultdict(ClassMatches)

    for clip_eval in clip_evals:
@@ -405,12 +411,13 @@ def group_matches(
    return class_examples


def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
def get_binned_sample(matches: list[MatchEval], n_examples: int = 5):
    if len(matches) < n_examples:
        return matches

    indices, pred_scores = zip(
        *[(index, match.score) for index, match in enumerate(matches)]
        *[(index, match.score) for index, match in enumerate(matches)],
        strict=False,
    )

    bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
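`get_binned_sample` uses `pd.qcut` to stratify example matches by prediction score so the gallery shows high-, mid-, and low-confidence cases rather than a random grab. A minimal sketch of that binning step with made-up scores (the grouping loop here is a simplified stand-in for the real function):

```python
import random

import pandas as pd

scores = [0.05, 0.15, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95, 0.25]

# Assign each score to one of 5 quantile bins (0..4). labels=False returns
# integer bin indices; duplicates="drop" tolerates tied bin edges.
bins = pd.qcut(scores, q=5, labels=False, duplicates="drop")

# Pick one example index per bin, mirroring the score-stratified sampling.
by_bin = {}
for index, bin_id in enumerate(bins):
    by_bin.setdefault(bin_id, []).append(index)
sample = [random.choice(indices) for indices in by_bin.values()]
```

With quantile bins, each confidence band contributes examples even when the score distribution is heavily skewed.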
27 src/batdetect2/evaluate/results.py Normal file
@@ -0,0 +1,27 @@
import json
from pathlib import Path
from typing import Iterable

from matplotlib.figure import Figure
from soundevent import data

__all__ = ["save_evaluation_results"]


def save_evaluation_results(
    metrics: dict[str, float],
    plots: Iterable[tuple[str, Figure]],
    output_dir: data.PathLike,
) -> None:
    """Save evaluation metrics and plots to disk."""

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    metrics_path = output_path / "metrics.json"
    metrics_path.write_text(json.dumps(metrics))

    for figure_name, figure in plots:
        figure_path = output_path / figure_name
        figure_path.parent.mkdir(parents=True, exist_ok=True)
        figure.savefig(figure_path)
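The new `save_evaluation_results` writes one `metrics.json` plus one image file per plot under the output directory. The metrics half can be sketched with the standard library alone (the `Figure.savefig` loop is omitted to avoid a matplotlib dependency, and the metric keys below are invented for illustration):

```python
import json
import tempfile
from pathlib import Path

# Hypothetical metric names; real keys come from the configured tasks.
metrics = {"detection/average_precision": 0.91, "top_class/roc_auc": 0.87}

with tempfile.TemporaryDirectory() as tmp:
    output_path = Path(tmp) / "evaluation"
    # parents=True / exist_ok=True mirror the function's directory handling.
    output_path.mkdir(parents=True, exist_ok=True)

    (output_path / "metrics.json").write_text(json.dumps(metrics))
    restored = json.loads((output_path / "metrics.json").read_text())
```

Because plot names may contain slashes (e.g. `detection/pr_curve`), the per-figure `figure_path.parent.mkdir(...)` call in the real function creates the intermediate subdirectories on demand.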
@@ -1,106 +0,0 @@
from typing import Annotated, Callable, Literal, Sequence, Union

import pandas as pd
from pydantic import Field
from soundevent.geometry import compute_bounds

from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipMatches

EvaluationTableGenerator = Callable[[Sequence[ClipMatches]], pd.DataFrame]


tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
    "evaluation_table"
)


class FullEvaluationTableConfig(BaseConfig):
    name: Literal["full_evaluation"] = "full_evaluation"


class FullEvaluationTable:
    def __call__(
        self, clip_evaluations: Sequence[ClipMatches]
    ) -> pd.DataFrame:
        return extract_matches_dataframe(clip_evaluations)

    @tables_registry.register(FullEvaluationTableConfig)
    @staticmethod
    def from_config(config: FullEvaluationTableConfig):
        return FullEvaluationTable()


def extract_matches_dataframe(
    clip_evaluations: Sequence[ClipMatches],
) -> pd.DataFrame:
    data = []

    for clip_evaluation in clip_evaluations:
        for match in clip_evaluation.matches:
            gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
            pred_start_time = pred_low_freq = pred_end_time = (
                pred_high_freq
            ) = None

            sound_event_annotation = match.sound_event_annotation

            if sound_event_annotation is not None:
                geometry = sound_event_annotation.sound_event.geometry
                assert geometry is not None
                gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
                    compute_bounds(geometry)
                )

            if match.pred_geometry is not None:
                (
                    pred_start_time,
                    pred_low_freq,
                    pred_end_time,
                    pred_high_freq,
                ) = compute_bounds(match.pred_geometry)

            data.append(
                {
                    ("recording", "uuid"): match.clip.recording.uuid,
                    ("clip", "uuid"): match.clip.uuid,
                    ("clip", "start_time"): match.clip.start_time,
                    ("clip", "end_time"): match.clip.end_time,
                    ("gt", "uuid"): match.sound_event_annotation.uuid
                    if match.sound_event_annotation is not None
                    else None,
                    ("gt", "class"): match.gt_class,
                    ("gt", "det"): match.gt_det,
                    ("gt", "start_time"): gt_start_time,
                    ("gt", "end_time"): gt_end_time,
                    ("gt", "low_freq"): gt_low_freq,
                    ("gt", "high_freq"): gt_high_freq,
                    ("pred", "score"): match.pred_score,
                    ("pred", "class"): match.top_class,
                    ("pred", "class_score"): match.top_class_score,
                    ("pred", "start_time"): pred_start_time,
                    ("pred", "end_time"): pred_end_time,
                    ("pred", "low_freq"): pred_low_freq,
                    ("pred", "high_freq"): pred_high_freq,
                    ("match", "affinity"): match.affinity,
                    **{
                        ("pred_class_score", key): value
                        for key, value in match.pred_class_scores.items()
                    },
                }
            )

    df = pd.DataFrame(data)
    df.columns = pd.MultiIndex.from_tuples(df.columns)  # type: ignore
    return df


EvaluationTableConfig = Annotated[
    Union[FullEvaluationTableConfig,], Field(discriminator="name")
]


def build_table_generator(
    config: EvaluationTableConfig,
) -> EvaluationTableGenerator:
    return tables_registry.build(config)
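The removed `extract_matches_dataframe` keyed every row dict by `(group, field)` tuples and then promoted those tuples to a two-level column index. A minimal pandas sketch of that technique (the row values below are invented):

```python
import pandas as pd

# Rows keyed by (group, field) tuples, as in extract_matches_dataframe.
rows = [
    {("clip", "uuid"): "c1", ("pred", "score"): 0.9, ("gt", "class"): "Myotis"},
    {("clip", "uuid"): "c2", ("pred", "score"): 0.4, ("gt", "class"): None},
]

df = pd.DataFrame(rows)
# Promote the tuple keys to a hierarchical (two-level) column index.
df.columns = pd.MultiIndex.from_tuples(df.columns)

# Two-level columns allow group-wise selection, e.g. all "pred" fields.
pred_columns = df["pred"]
```

Grouped columns keep ground-truth, prediction, and match fields visually and programmatically separate in the exported table, which is the main reason to prefer a `MultiIndex` over flat `pred_score`-style names.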
@@ -1,4 +1,4 @@
from typing import Annotated, Optional, Sequence, Union
from typing import Annotated, Sequence

from pydantic import Field
from soundevent import data
@@ -11,12 +10,10 @@ from batdetect2.evaluate.tasks.clip_classification (
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.evaluate.types import EvaluationTaskProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets import build_targets
from batdetect2.typing import (
    BatDetect2Prediction,
    EvaluatorProtocol,
    TargetProtocol,
)
from batdetect2.targets.types import TargetProtocol

__all__ = [
    "TaskConfig",
@@ -26,31 +24,29 @@ __all__ = [


TaskConfig = Annotated[
    Union[
        ClassificationTaskConfig,
        DetectionTaskConfig,
        ClipDetectionTaskConfig,
        ClipClassificationTaskConfig,
        TopClassDetectionTaskConfig,
    ],
    ClassificationTaskConfig
    | DetectionTaskConfig
    | ClipDetectionTaskConfig
    | ClipClassificationTaskConfig
    | TopClassDetectionTaskConfig,
    Field(discriminator="name"),
]


def build_task(
    config: TaskConfig,
    targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
    targets: TargetProtocol | None = None,
) -> EvaluationTaskProtocol:
    targets = targets or build_targets()
    return tasks_registry.build(config, targets)


def evaluate_task(
    clip_annotations: Sequence[data.ClipAnnotation],
    predictions: Sequence[BatDetect2Prediction],
    task: Optional["str"] = None,
    targets: Optional[TargetProtocol] = None,
    config: Optional[Union[TaskConfig, dict]] = None,
    predictions: Sequence[ClipDetections],
    task: str | None = None,
    targets: TargetProtocol | None = None,
    config: TaskConfig | dict | None = None,
):
    if isinstance(config, BaseTaskConfig):
        task_obj = build_task(config, targets)

@@ -4,78 +4,93 @@ from typing import (
    Generic,
    Iterable,
    List,
    Optional,
    Literal,
    Sequence,
    Tuple,
    TypeVar,
)

from loguru import logger
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds

from batdetect2.core import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.match import (
    MatchConfig,
    StartTimeMatchConfig,
    build_matcher,
from batdetect2.core import (
    BaseConfig,
    ImportConfig,
    Registry,
    add_import_config,
)
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
from batdetect2.typing.targets import TargetProtocol
from batdetect2.evaluate.affinity import (
    AffinityConfig,
    TimeAffinityConfig,
    build_affinity_function,
)
from batdetect2.evaluate.types import (
    AffinityFunction,
    EvaluationTaskProtocol,
)
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.targets.types import TargetProtocol

__all__ = [
    "BaseTaskConfig",
    "BaseTask",
    "TaskImportConfig",
]

tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
tasks_registry: Registry[EvaluationTaskProtocol, [TargetProtocol]] = Registry(
    "tasks"
)


@add_import_config(tasks_registry)
class TaskImportConfig(ImportConfig):
    """Use any callable as an evaluation task.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


T_Output = TypeVar("T_Output")


class BaseTaskConfig(BaseConfig):
    prefix: str

    ignore_start_end: float = 0.01
    matching_strategy: MatchConfig = Field(
        default_factory=StartTimeMatchConfig
    )


class BaseTask(EvaluatorProtocol, Generic[T_Output]):
class BaseTask(EvaluationTaskProtocol, Generic[T_Output]):
    targets: TargetProtocol

    matcher: MatcherProtocol

    metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]

    plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]

    ignore_start_end: float

    prefix: str

    ignore_start_end: float

    def __init__(
        self,
        matcher: MatcherProtocol,
        targets: TargetProtocol,
        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
        prefix: str,
        plots: List[
            Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
        ]
        | None = None,
        ignore_start_end: float = 0.01,
        plots: Optional[
            List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
        ] = None,
    ):
        self.matcher = matcher
        self.prefix = prefix
        self.targets = targets
        self.metrics = metrics
        self.plots = plots or []
        self.targets = targets
        self.prefix = prefix
        self.ignore_start_end = ignore_start_end

    def compute_metrics(
@@ -93,24 +108,30 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
        self, eval_outputs: List[T_Output]
    ) -> Iterable[Tuple[str, Figure]]:
        for plot in self.plots:
            for name, fig in plot(eval_outputs):
                yield f"{self.prefix}/{name}", fig
            try:
                for name, fig in plot(eval_outputs):
                    yield f"{self.prefix}/{name}", fig
|
||||
except Exception as e:
|
||||
logger.error(f"Error plotting {self.prefix}: {e}")
|
||||
continue
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
predictions: Sequence[ClipDetections],
|
||||
) -> List[T_Output]:
|
||||
return [
|
||||
self.evaluate_clip(clip_annotation, preds)
|
||||
for clip_annotation, preds in zip(clip_annotations, predictions)
|
||||
for clip_annotation, preds in zip(
|
||||
clip_annotations, predictions, strict=False
|
||||
)
|
||||
]
|
||||
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
prediction: BatDetect2Prediction,
|
||||
) -> T_Output: ...
|
||||
prediction: ClipDetections,
|
||||
) -> T_Output: ... # ty: ignore[empty-body]
|
||||
|
||||
def include_sound_event_annotation(
|
||||
self,
|
||||
@ -121,9 +142,6 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
return False
|
||||
|
||||
geometry = sound_event_annotation.sound_event.geometry
|
||||
if geometry is None:
|
||||
return False
|
||||
|
||||
return is_in_bounds(
|
||||
geometry,
|
||||
clip,
|
||||
@ -132,7 +150,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
|
||||
def include_prediction(
|
||||
self,
|
||||
prediction: RawPrediction,
|
||||
prediction: Detection,
|
||||
clip: data.Clip,
|
||||
) -> bool:
|
||||
return is_in_bounds(
|
||||
@ -141,25 +159,56 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
self.ignore_start_end,
|
||||
)
|
||||
|
||||
|
||||
class BaseSEDTaskConfig(BaseTaskConfig):
|
||||
affinity: AffinityConfig = Field(default_factory=TimeAffinityConfig)
|
||||
affinity_threshold: float = 0
|
||||
strict_match: bool = True
|
||||
|
||||
|
||||
class BaseSEDTask(BaseTask[T_Output]):
|
||||
affinity: AffinityFunction
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
targets: TargetProtocol,
|
||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
||||
affinity: AffinityFunction,
|
||||
plots: List[
|
||||
Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
|
||||
]
|
||||
| None = None,
|
||||
affinity_threshold: float = 0,
|
||||
ignore_start_end: float = 0.01,
|
||||
strict_match: bool = True,
|
||||
):
|
||||
super().__init__(
|
||||
prefix=prefix,
|
||||
metrics=metrics,
|
||||
plots=plots,
|
||||
targets=targets,
|
||||
ignore_start_end=ignore_start_end,
|
||||
)
|
||||
self.affinity = affinity
|
||||
self.affinity_threshold = affinity_threshold
|
||||
self.strict_match = strict_match
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
config: BaseTaskConfig,
|
||||
config: BaseSEDTaskConfig,
|
||||
targets: TargetProtocol,
|
||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
||||
plots: Optional[
|
||||
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
|
||||
] = None,
|
||||
**kwargs,
|
||||
):
|
||||
matcher = build_matcher(config.matching_strategy)
|
||||
affinity = build_affinity_function(config.affinity)
|
||||
return cls(
|
||||
matcher=matcher,
|
||||
targets=targets,
|
||||
metrics=metrics,
|
||||
plots=plots,
|
||||
affinity=affinity,
|
||||
affinity_threshold=config.affinity_threshold,
|
||||
prefix=config.prefix,
|
||||
ignore_start_end=config.ignore_start_end,
|
||||
strict_match=config.strict_match,
|
||||
targets=targets,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
from typing import (
|
||||
List,
|
||||
Literal,
|
||||
)
|
||||
from functools import partial
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_detections_and_gts
|
||||
|
||||
from batdetect2.evaluate.metrics.classification import (
|
||||
ClassificationAveragePrecisionConfig,
|
||||
@ -18,24 +17,25 @@ from batdetect2.evaluate.plots.classification import (
|
||||
build_classification_plotter,
|
||||
)
|
||||
from batdetect2.evaluate.tasks.base import (
|
||||
BaseTask,
|
||||
BaseTaskConfig,
|
||||
BaseSEDTask,
|
||||
BaseSEDTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class ClassificationTaskConfig(BaseTaskConfig):
|
||||
class ClassificationTaskConfig(BaseSEDTaskConfig):
|
||||
name: Literal["sound_event_classification"] = "sound_event_classification"
|
||||
prefix: str = "classification"
|
||||
metrics: List[ClassificationMetricConfig] = Field(
|
||||
metrics: list[ClassificationMetricConfig] = Field(
|
||||
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
|
||||
)
|
||||
plots: List[ClassificationPlotConfig] = Field(default_factory=list)
|
||||
plots: list[ClassificationPlotConfig] = Field(default_factory=list)
|
||||
include_generics: bool = True
|
||||
|
||||
|
||||
class ClassificationTask(BaseTask[ClipEval]):
|
||||
class ClassificationTask(BaseSEDTask[ClipEval]):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
@ -48,13 +48,13 @@ class ClassificationTask(BaseTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
prediction: BatDetect2Prediction,
|
||||
prediction: ClipDetections,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
preds = [
|
||||
pred
|
||||
for pred in prediction.predictions
|
||||
for pred in prediction.detections
|
||||
if self.include_prediction(pred, clip)
|
||||
]
|
||||
|
||||
@ -73,40 +73,40 @@ class ClassificationTask(BaseTask[ClipEval]):
|
||||
gts = [
|
||||
sound_event
|
||||
for sound_event in all_gts
|
||||
if self.is_class(sound_event, class_name)
|
||||
if is_target_class(
|
||||
sound_event,
|
||||
class_name,
|
||||
self.targets,
|
||||
include_generics=self.include_generics,
|
||||
)
|
||||
]
|
||||
scores = [float(pred.class_scores[class_idx]) for pred in preds]
|
||||
|
||||
matches = []
|
||||
|
||||
for pred_idx, gt_idx, _ in self.matcher(
|
||||
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
|
||||
predictions=[pred.geometry for pred in preds],
|
||||
scores=scores,
|
||||
for match in match_detections_and_gts(
|
||||
detections=preds,
|
||||
ground_truths=gts,
|
||||
affinity=self.affinity,
|
||||
score=partial(get_class_score, class_idx=class_idx),
|
||||
strict_match=self.strict_match,
|
||||
affinity_threshold=self.affinity_threshold,
|
||||
):
|
||||
gt = gts[gt_idx] if gt_idx is not None else None
|
||||
pred = preds[pred_idx] if pred_idx is not None else None
|
||||
|
||||
true_class = (
|
||||
self.targets.encode_class(gt) if gt is not None else None
|
||||
self.targets.encode_class(match.annotation)
|
||||
if match.annotation is not None
|
||||
else None
|
||||
)
|
||||
|
||||
score = (
|
||||
float(pred.class_scores[class_idx])
|
||||
if pred is not None
|
||||
else 0
|
||||
)
|
||||
|
||||
matches.append(
|
||||
MatchEval(
|
||||
clip=clip,
|
||||
gt=gt,
|
||||
pred=pred,
|
||||
is_prediction=pred is not None,
|
||||
is_ground_truth=gt is not None,
|
||||
is_generic=gt is not None and true_class is None,
|
||||
gt=match.annotation,
|
||||
pred=match.prediction,
|
||||
is_prediction=match.prediction is not None,
|
||||
is_ground_truth=match.annotation is not None,
|
||||
is_generic=match.annotation is not None
|
||||
and true_class is None,
|
||||
true_class=true_class,
|
||||
score=score,
|
||||
score=match.prediction_score,
|
||||
)
|
||||
)
|
||||
|
||||
@ -114,20 +114,6 @@ class ClassificationTask(BaseTask[ClipEval]):
|
||||
|
||||
return ClipEval(clip=clip, matches=per_class_matches)
|
||||
|
||||
def is_class(
|
||||
self,
|
||||
sound_event: data.SoundEventAnnotation,
|
||||
class_name: str,
|
||||
) -> bool:
|
||||
sound_event_class = self.targets.encode_class(sound_event)
|
||||
|
||||
if sound_event_class is None and self.include_generics:
|
||||
# Sound events that are generic could be of the given
|
||||
# class
|
||||
return True
|
||||
|
||||
return sound_event_class == class_name
|
||||
|
||||
@tasks_registry.register(ClassificationTaskConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
@ -147,4 +133,25 @@ class ClassificationTask(BaseTask[ClipEval]):
|
||||
plots=plots,
|
||||
targets=targets,
|
||||
metrics=metrics,
|
||||
include_generics=config.include_generics,
|
||||
)
|
||||
|
||||
|
||||
def get_class_score(pred: Detection, class_idx: int) -> float:
|
||||
return pred.class_scores[class_idx]
|
||||
|
||||
|
||||
def is_target_class(
|
||||
sound_event: data.SoundEventAnnotation,
|
||||
class_name: str,
|
||||
targets: TargetProtocol,
|
||||
include_generics: bool = True,
|
||||
) -> bool:
|
||||
sound_event_class = targets.encode_class(sound_event)
|
||||
|
||||
if sound_event_class is None and include_generics:
|
||||
# Sound events that are generic could be of the given
|
||||
# class
|
||||
return True
|
||||
|
||||
return sound_event_class == class_name
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import List, Literal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -19,26 +19,26 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class ClipClassificationTaskConfig(BaseTaskConfig):
|
||||
name: Literal["clip_classification"] = "clip_classification"
|
||||
prefix: str = "clip_classification"
|
||||
metrics: List[ClipClassificationMetricConfig] = Field(
|
||||
metrics: list[ClipClassificationMetricConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
ClipClassificationAveragePrecisionConfig(),
|
||||
]
|
||||
)
|
||||
plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)
|
||||
plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ClipClassificationTask(BaseTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
prediction: BatDetect2Prediction,
|
||||
prediction: ClipDetections,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
@ -55,7 +55,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
|
||||
gt_classes.add(class_name)
|
||||
|
||||
pred_scores = defaultdict(float)
|
||||
for pred in prediction.predictions:
|
||||
for pred in prediction.detections:
|
||||
if not self.include_prediction(pred, clip):
|
||||
continue
|
||||
|
||||
@ -78,8 +78,8 @@ class ClipClassificationTask(BaseTask[ClipEval]):
|
||||
build_clip_classification_plotter(plot, targets)
|
||||
for plot in config.plots
|
||||
]
|
||||
return ClipClassificationTask.build(
|
||||
config=config,
|
||||
return ClipClassificationTask(
|
||||
prefix=config.prefix,
|
||||
plots=plots,
|
||||
metrics=metrics,
|
||||
targets=targets,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Literal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -18,26 +18,26 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class ClipDetectionTaskConfig(BaseTaskConfig):
|
||||
name: Literal["clip_detection"] = "clip_detection"
|
||||
prefix: str = "clip_detection"
|
||||
metrics: List[ClipDetectionMetricConfig] = Field(
|
||||
metrics: list[ClipDetectionMetricConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
ClipDetectionAveragePrecisionConfig(),
|
||||
]
|
||||
)
|
||||
plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)
|
||||
plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ClipDetectionTask(BaseTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
prediction: BatDetect2Prediction,
|
||||
prediction: ClipDetections,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
@ -47,7 +47,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
|
||||
)
|
||||
|
||||
pred_score = 0
|
||||
for pred in prediction.predictions:
|
||||
for pred in prediction.detections:
|
||||
if not self.include_prediction(pred, clip):
|
||||
continue
|
||||
|
||||
@ -69,8 +69,8 @@ class ClipDetectionTask(BaseTask[ClipEval]):
|
||||
build_clip_detection_plotter(plot, targets)
|
||||
for plot in config.plots
|
||||
]
|
||||
return ClipDetectionTask.build(
|
||||
config=config,
|
||||
return ClipDetectionTask(
|
||||
prefix=config.prefix,
|
||||
metrics=metrics,
|
||||
targets=targets,
|
||||
plots=plots,
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from typing import List, Literal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_detections_and_gts
|
||||
|
||||
from batdetect2.evaluate.metrics.detection import (
|
||||
ClipEval,
|
||||
@ -15,28 +16,28 @@ from batdetect2.evaluate.plots.detection import (
|
||||
build_detection_plotter,
|
||||
)
|
||||
from batdetect2.evaluate.tasks.base import (
|
||||
BaseTask,
|
||||
BaseTaskConfig,
|
||||
BaseSEDTask,
|
||||
BaseSEDTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class DetectionTaskConfig(BaseTaskConfig):
|
||||
class DetectionTaskConfig(BaseSEDTaskConfig):
|
||||
name: Literal["sound_event_detection"] = "sound_event_detection"
|
||||
prefix: str = "detection"
|
||||
metrics: List[DetectionMetricConfig] = Field(
|
||||
metrics: list[DetectionMetricConfig] = Field(
|
||||
default_factory=lambda: [DetectionAveragePrecisionConfig()]
|
||||
)
|
||||
plots: List[DetectionPlotConfig] = Field(default_factory=list)
|
||||
plots: list[DetectionPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DetectionTask(BaseTask[ClipEval]):
|
||||
class DetectionTask(BaseSEDTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
prediction: BatDetect2Prediction,
|
||||
prediction: ClipDetections,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
@ -47,27 +48,26 @@ class DetectionTask(BaseTask[ClipEval]):
|
||||
]
|
||||
preds = [
|
||||
pred
|
||||
for pred in prediction.predictions
|
||||
for pred in prediction.detections
|
||||
if self.include_prediction(pred, clip)
|
||||
]
|
||||
scores = [pred.detection_score for pred in preds]
|
||||
|
||||
matches = []
|
||||
for pred_idx, gt_idx, _ in self.matcher(
|
||||
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
|
||||
predictions=[pred.geometry for pred in preds],
|
||||
scores=scores,
|
||||
for match in match_detections_and_gts(
|
||||
detections=preds,
|
||||
ground_truths=gts,
|
||||
affinity=self.affinity,
|
||||
score=lambda pred: pred.detection_score,
|
||||
strict_match=self.strict_match,
|
||||
affinity_threshold=self.affinity_threshold,
|
||||
):
|
||||
gt = gts[gt_idx] if gt_idx is not None else None
|
||||
pred = preds[pred_idx] if pred_idx is not None else None
|
||||
|
||||
matches.append(
|
||||
MatchEval(
|
||||
gt=gt,
|
||||
pred=pred,
|
||||
is_prediction=pred is not None,
|
||||
is_ground_truth=gt is not None,
|
||||
score=pred.detection_score if pred is not None else 0,
|
||||
gt=match.annotation,
|
||||
pred=match.prediction,
|
||||
is_prediction=match.prediction is not None,
|
||||
is_ground_truth=match.annotation is not None,
|
||||
score=match.prediction_score,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from typing import List, Literal
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_detections_and_gts
|
||||
|
||||
from batdetect2.evaluate.metrics.top_class import (
|
||||
ClipEval,
|
||||
@ -15,28 +16,28 @@ from batdetect2.evaluate.plots.top_class import (
|
||||
build_top_class_plotter,
|
||||
)
|
||||
from batdetect2.evaluate.tasks.base import (
|
||||
BaseTask,
|
||||
BaseTaskConfig,
|
||||
BaseSEDTask,
|
||||
BaseSEDTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
class TopClassDetectionTaskConfig(BaseTaskConfig):
|
||||
class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
|
||||
name: Literal["top_class_detection"] = "top_class_detection"
|
||||
prefix: str = "top_class"
|
||||
metrics: List[TopClassMetricConfig] = Field(
|
||||
metrics: list[TopClassMetricConfig] = Field(
|
||||
default_factory=lambda: [TopClassAveragePrecisionConfig()]
|
||||
)
|
||||
plots: List[TopClassPlotConfig] = Field(default_factory=list)
|
||||
plots: list[TopClassPlotConfig] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TopClassDetectionTask(BaseTask[ClipEval]):
|
||||
class TopClassDetectionTask(BaseSEDTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
prediction: BatDetect2Prediction,
|
||||
prediction: ClipDetections,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
@ -47,21 +48,21 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
||||
]
|
||||
preds = [
|
||||
pred
|
||||
for pred in prediction.predictions
|
||||
for pred in prediction.detections
|
||||
if self.include_prediction(pred, clip)
|
||||
]
|
||||
# Take the highest score for each prediction
|
||||
scores = [pred.class_scores.max() for pred in preds]
|
||||
|
||||
matches = []
|
||||
for pred_idx, gt_idx, _ in self.matcher(
|
||||
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
|
||||
predictions=[pred.geometry for pred in preds],
|
||||
scores=scores,
|
||||
for match in match_detections_and_gts(
|
||||
ground_truths=gts,
|
||||
detections=preds,
|
||||
affinity=self.affinity,
|
||||
score=lambda pred: pred.class_scores.max(),
|
||||
strict_match=self.strict_match,
|
||||
affinity_threshold=self.affinity_threshold,
|
||||
):
|
||||
gt = gts[gt_idx] if gt_idx is not None else None
|
||||
pred = preds[pred_idx] if pred_idx is not None else None
|
||||
|
||||
gt = match.annotation
|
||||
pred = match.prediction
|
||||
true_class = (
|
||||
self.targets.encode_class(gt) if gt is not None else None
|
||||
)
|
||||
@ -69,11 +70,6 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
||||
class_idx = (
|
||||
pred.class_scores.argmax() if pred is not None else None
|
||||
)
|
||||
|
||||
score = (
|
||||
float(pred.class_scores[class_idx]) if pred is not None else 0
|
||||
)
|
||||
|
||||
pred_class = (
|
||||
self.targets.class_names[class_idx]
|
||||
if class_idx is not None
|
||||
@ -90,7 +86,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
||||
true_class=true_class,
|
||||
is_generic=gt is not None and true_class is None,
|
||||
pred_class=pred_class,
|
||||
score=score,
|
||||
score=match.prediction_score,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
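An aside on the `score=partial(get_class_score, class_idx=class_idx)` pattern introduced in the classification task above: instead of precomputing a list of per-prediction scores, the matcher now receives a callable with the class index bound once via `functools.partial`. A minimal, self-contained sketch of the same idea (the `FakeDetection` class is a hypothetical stand-in for batdetect2's `Detection`, used here only for illustration):

```python
from functools import partial


class FakeDetection:
    """Hypothetical stand-in carrying per-class scores, like a Detection."""

    def __init__(self, class_scores):
        self.class_scores = class_scores


def get_class_score(pred, class_idx):
    # Look up the score of one particular class for a prediction.
    return pred.class_scores[class_idx]


# Bind the class index once; the matcher can then call score_fn(pred).
score_fn = partial(get_class_score, class_idx=1)
print(score_fn(FakeDetection([0.2, 0.7, 0.1])))  # 0.7
```

This keeps the scoring rule next to the matching call rather than materialising a parallel `scores` list that must stay index-aligned with `preds`.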
147  src/batdetect2/evaluate/types.py  Normal file
@@ -0,0 +1,147 @@
from dataclasses import dataclass
from typing import Generic, Iterable, Protocol, Sequence, TypeVar

from matplotlib.figure import Figure
from soundevent import data

from batdetect2.outputs.types import OutputTransformProtocol
from batdetect2.postprocess.types import (
    ClipDetections,
    ClipDetectionsTensor,
    Detection,
)
from batdetect2.targets.types import TargetProtocol

__all__ = [
    "AffinityFunction",
    "ClipMatches",
    "EvaluationTaskProtocol",
    "EvaluatorProtocol",
    "MatchEvaluation",
    "MatcherProtocol",
    "MetricsProtocol",
    "PlotterProtocol",
]


@dataclass
class MatchEvaluation:
    clip: data.Clip
    sound_event_annotation: data.SoundEventAnnotation | None
    gt_det: bool
    gt_class: str | None
    gt_geometry: data.Geometry | None
    pred_score: float
    pred_class_scores: dict[str, float]
    pred_geometry: data.Geometry | None
    affinity: float

    @property
    def top_class(self) -> str | None:
        if not self.pred_class_scores:
            return None
        return max(self.pred_class_scores, key=self.pred_class_scores.get)  # type: ignore

    @property
    def is_prediction(self) -> bool:
        return self.pred_geometry is not None

    @property
    def is_generic(self) -> bool:
        return self.gt_det and self.gt_class is None

    @property
    def top_class_score(self) -> float:
        pred_class = self.top_class
        if pred_class is None:
            return 0
        return self.pred_class_scores[pred_class]


@dataclass
class ClipMatches:
    clip: data.Clip
    matches: list[MatchEvaluation]


class MatcherProtocol(Protocol):
    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ) -> Iterable[tuple[int | None, int | None, float]]: ...


class AffinityFunction(Protocol):
    def __call__(
        self,
        detection: Detection,
        ground_truth: data.SoundEventAnnotation,
    ) -> float: ...


class MetricsProtocol(Protocol):
    def __call__(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[Detection]],
    ) -> dict[str, float]: ...


class PlotterProtocol(Protocol):
    def __call__(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[Detection]],
    ) -> Iterable[tuple[str, Figure]]: ...


EvaluationOutput = TypeVar("EvaluationOutput")


class EvaluationTaskProtocol(Protocol, Generic[EvaluationOutput]):
    targets: TargetProtocol

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[ClipDetections],
    ) -> EvaluationOutput: ...

    def compute_metrics(
        self,
        eval_outputs: EvaluationOutput,
    ) -> dict[str, float]: ...

    def generate_plots(
        self,
        eval_outputs: EvaluationOutput,
    ) -> Iterable[tuple[str, Figure]]: ...


class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
    targets: TargetProtocol
    transform: OutputTransformProtocol

    def to_clip_detections_batch(
        self,
        clip_detections: Sequence[ClipDetectionsTensor],
        clips: Sequence[data.Clip],
    ) -> list[ClipDetections]: ...

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[ClipDetections],
    ) -> EvaluationOutput: ...

    def compute_metrics(
        self,
        eval_outputs: EvaluationOutput,
    ) -> dict[str, float]: ...

    def generate_plots(
        self,
        eval_outputs: EvaluationOutput,
    ) -> Iterable[tuple[str, Figure]]: ...
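For orientation, the `MatcherProtocol` signature defined in `types.py` above (ground-truth geometries, prediction geometries, and scores in; `(pred_idx, gt_idx, affinity)` triples out, with `None` marking an unmatched side) can be sketched with plain time intervals and a greedy IoU matcher. `Interval`, `interval_iou`, and `greedy_match` are illustrative names invented for this sketch, not batdetect2 APIs:

```python
from typing import Iterable, Optional, Sequence, Tuple

# Hypothetical stand-in: 1D time intervals instead of soundevent geometries.
Interval = Tuple[float, float]


def interval_iou(a: Interval, b: Interval) -> float:
    # Intersection-over-union of two time intervals, used as affinity.
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union if union > 0 else 0.0


def greedy_match(
    ground_truth: Sequence[Interval],
    predictions: Sequence[Interval],
    scores: Sequence[float],
    threshold: float = 0.0,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Greedily pair predictions (highest score first) with ground truths."""
    unmatched_gt = set(range(len(ground_truth)))
    order = sorted(range(len(predictions)), key=lambda i: scores[i], reverse=True)
    for pred_idx in order:
        best, best_aff = None, threshold
        for gt_idx in unmatched_gt:
            aff = interval_iou(predictions[pred_idx], ground_truth[gt_idx])
            if aff > best_aff:
                best, best_aff = gt_idx, aff
        if best is not None:
            unmatched_gt.discard(best)
            yield pred_idx, best, best_aff
        else:
            yield pred_idx, None, 0.0  # unmatched prediction (false positive)
    for gt_idx in unmatched_gt:
        yield None, gt_idx, 0.0  # unmatched ground truth (missed detection)


gts = [(0.0, 1.0), (2.0, 3.0)]
preds = [(0.1, 0.9), (5.0, 6.0)]
matches = list(greedy_match(gts, preds, scores=[0.9, 0.5]))
```

Here the first prediction pairs with the first ground truth, the second prediction is a false positive, and the second ground truth is missed. Real matchers built from `MatchConfig` work on soundevent geometries, but the index-triple shape is the same.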
@@ -1,7 +1,7 @@
import argparse
import os
import warnings
from typing import List, Optional
from typing import List

import torch
import torch.utils.data

@@ -88,7 +88,8 @@ def select_device(warn=True) -> str:
    if warn:
        warnings.warn(
            "No GPU available, using the CPU instead. Please consider using a GPU "
            "to speed up training."
            "to speed up training.",
            stacklevel=2,
        )

    return "cpu"

@@ -98,8 +99,8 @@ def load_annotations(
    dataset_name: str,
    ann_path: str,
    audio_path: str,
    classes_to_ignore: Optional[List[str]] = None,
    events_of_interest: Optional[List[str]] = None,
    classes_to_ignore: List[str] | None = None,
    events_of_interest: List[str] | None = None,
) -> List[types.FileAnnotation]:
    train_sets: List[types.DatasetDict] = []
    train_sets.append(

@@ -2,7 +2,6 @@ import argparse
import json
import os
from collections import Counter
from typing import List, Optional, Tuple

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

@@ -12,8 +11,8 @@ from batdetect2 import types


def print_dataset_stats(
    data: List[types.FileAnnotation],
    classes_to_ignore: Optional[List[str]] = None,
    data: list[types.FileAnnotation],
    classes_to_ignore: list[str] | None = None,
) -> Counter[str]:
    print("Num files:", len(data))
    counts, _ = tu.get_class_names(data, classes_to_ignore)

@@ -22,7 +21,7 @@
    return counts


def load_file_names(file_name: str) -> List[str]:
def load_file_names(file_name: str) -> list[str]:
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"Input file not found - {file_name}")

@@ -100,12 +99,12 @@


def split_data(
    data: List[types.FileAnnotation],
    data: list[types.FileAnnotation],
    train_file: str,
    test_file: str,
    n_splits: int = 5,
    random_state: int = 0,
) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]:
) -> tuple[list[types.FileAnnotation], list[types.FileAnnotation]]:
    if train_file != "" and test_file != "":
        # user has specifed the train / test split
        mapping = {

@@ -162,7 +161,7 @@ def main():
    # change the names of the classes
    ip_names = args.input_class_names.split(";")
    op_names = args.output_class_names.split(";")
    name_dict = dict(zip(ip_names, op_names))
    name_dict = dict(zip(ip_names, op_names, strict=False))

    # load annotations
    data_all = tu.load_set_of_anns(

@@ -1,58 +1,68 @@
from typing import TYPE_CHECKING, List, Optional, Sequence
from typing import Sequence

from lightning import Trainer
from soundevent import data

from batdetect2.audio import AudioConfig
from batdetect2.audio.loader import build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.config import InferenceConfig
from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model
from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.typing.postprocess import BatDetect2Prediction

if TYPE_CHECKING:
    from batdetect2.config import BatDetect2Config
    from batdetect2.typing import (
        AudioLoader,
        PreprocessorProtocol,
        TargetProtocol,
    )
from batdetect2.outputs import (
    OutputsConfig,
    OutputTransformProtocol,
    build_output_transform,
)
from batdetect2.postprocess.types import ClipDetections
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol


def run_batch_inference(
    model,
    model: Model,
    clips: Sequence[data.Clip],
    targets: Optional["TargetProtocol"] = None,
    audio_loader: Optional["AudioLoader"] = None,
    preprocessor: Optional["PreprocessorProtocol"] = None,
    config: Optional["BatDetect2Config"] = None,
    num_workers: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> List[BatDetect2Prediction]:
    from batdetect2.config import BatDetect2Config

    config = config or BatDetect2Config()

    audio_loader = audio_loader or build_audio_loader()

    preprocessor = preprocessor or build_preprocessor(
        input_samplerate=audio_loader.samplerate,
    targets: TargetProtocol | None = None,
    audio_loader: AudioLoader | None = None,
    preprocessor: PreprocessorProtocol | None = None,
    audio_config: AudioConfig | None = None,
    output_transform: OutputTransformProtocol | None = None,
    output_config: OutputsConfig | None = None,
    inference_config: InferenceConfig | None = None,
    num_workers: int = 1,
    batch_size: int | None = None,
) -> list[ClipDetections]:
    audio_config = audio_config or AudioConfig(
        samplerate=model.preprocessor.input_samplerate,
    )
    output_config = output_config or OutputsConfig()
    inference_config = inference_config or InferenceConfig()

    targets = targets or build_targets()
    audio_loader = audio_loader or build_audio_loader(config=audio_config)

    preprocessor = preprocessor or model.preprocessor
    targets = targets or model.targets

    output_transform = output_transform or build_output_transform(
        config=output_config.transform,
        targets=targets,
    )

    loader = build_inference_loader(
        clips,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        config=config.inference.loader,
        config=inference_config.loader,
        num_workers=num_workers,
        batch_size=batch_size,
    )

    module = InferenceModule(model)
    module = InferenceModule(
        model,
        output_transform=output_transform,
    )
    trainer = Trainer(enable_checkpointing=False, logger=False)
    outputs = trainer.predict(module, loader)
    return [

@@ -65,13 +75,18 @@
def process_file_list(
    model: Model,
    paths: Sequence[data.PathLike],
    config: "BatDetect2Config",
    targets: Optional["TargetProtocol"] = None,
    audio_loader: Optional["AudioLoader"] = None,
    preprocessor: Optional["PreprocessorProtocol"] = None,
    num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
    clip_config = config.inference.clipping
    targets: TargetProtocol | None = None,
    audio_loader: AudioLoader | None = None,
    audio_config: AudioConfig | None = None,
    preprocessor: PreprocessorProtocol | None = None,
    inference_config: InferenceConfig | None = None,
    output_config: OutputsConfig | None = None,
    output_transform: OutputTransformProtocol | None = None,
    batch_size: int | None = None,
    num_workers: int = 0,
) -> list[ClipDetections]:
    inference_config = inference_config or InferenceConfig()
    clip_config = inference_config.clipping
    clips = get_clips_from_files(
        paths,
        duration=clip_config.duration,

@@ -85,6 +100,10 @@ def process_file_list(
|
||||
targets=targets,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
config=config,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
output_config=output_config,
|
||||
audio_config=audio_config,
|
||||
output_transform=output_transform,
|
||||
inference_config=inference_config,
|
||||
)
|
||||
|
||||
@@ -38,10 +38,10 @@ def get_recording_clips(
    discard_empty: bool = True,
) -> Sequence[data.Clip]:
    start_time = 0
    duration = recording.duration
    recording_duration = recording.duration
    hop = duration * (1 - overlap)

    num_clips = int(np.ceil(duration / hop))
    num_clips = int(np.ceil(recording_duration / hop))

    if num_clips == 0:
        # This should only happen if the clip's duration is zero,

@@ -53,8 +53,8 @@ def get_recording_clips(
        start = start_time + i * hop
        end = start + duration

        if end > duration:
            empty_duration = end - duration
        if end > recording_duration:
            empty_duration = end - recording_duration

            if empty_duration > max_empty and discard_empty:
                # Discard clips that contain too much empty space
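The hunk above fixes a shadowing bug: `duration` (the clip length) was also used as the recording length, so the `end > duration` check compared against the wrong quantity. A standalone sketch of the windowing logic as implied by the diff (function name and exact semantics assumed for illustration):

```python
import math

def get_clip_windows(recording_duration, clip_duration, overlap=0.5,
                     max_empty=0.0, discard_empty=True):
    """Slide a window of ``clip_duration`` over the recording with the
    given ``overlap``, discarding windows that extend more than
    ``max_empty`` seconds past the end of the recording."""
    hop = clip_duration * (1 - overlap)
    num_clips = math.ceil(recording_duration / hop)
    windows = []
    for i in range(num_clips):
        start = i * hop
        end = start + clip_duration
        empty = max(0.0, end - recording_duration)  # time past the end
        if discard_empty and empty > max_empty:
            continue
        windows.append((start, end))
    return windows
```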
@@ -1,4 +1,4 @@
from typing import List, NamedTuple, Optional, Sequence
from typing import NamedTuple, Sequence

import torch
from loguru import logger

@@ -6,10 +6,11 @@ from soundevent import data
from torch.utils.data import DataLoader, Dataset

from batdetect2.audio import build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol

__all__ = [
    "InferenceDataset",

@@ -29,14 +30,14 @@ class DatasetItem(NamedTuple):


class InferenceDataset(Dataset[DatasetItem]):
    clips: List[data.Clip]
    clips: list[data.Clip]

    def __init__(
        self,
        clips: Sequence[data.Clip],
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        audio_dir: Optional[data.PathLike] = None,
        audio_dir: data.PathLike | None = None,
    ):
        self.clips = list(clips)
        self.preprocessor = preprocessor

@@ -46,31 +47,30 @@ class InferenceDataset(Dataset[DatasetItem]):
    def __len__(self):
        return len(self.clips)

    def __getitem__(self, idx: int) -> DatasetItem:
        clip = self.clips[idx]
    def __getitem__(self, index: int) -> DatasetItem:
        clip = self.clips[index]
        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
        wav_tensor = torch.tensor(wav).unsqueeze(0)
        spectrogram = self.preprocessor(wav_tensor)
        return DatasetItem(
            spec=spectrogram,
            idx=torch.tensor(idx),
            idx=torch.tensor(index),
            start_time=torch.tensor(clip.start_time),
            end_time=torch.tensor(clip.end_time),
        )


class InferenceLoaderConfig(BaseConfig):
    num_workers: int = 0
    batch_size: int = 8


def build_inference_loader(
    clips: Sequence[data.Clip],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[InferenceLoaderConfig] = None,
    num_workers: Optional[int] = None,
    batch_size: Optional[int] = None,
    audio_loader: AudioLoader | None = None,
    preprocessor: PreprocessorProtocol | None = None,
    config: InferenceLoaderConfig | None = None,
    num_workers: int = 0,
    batch_size: int | None = None,
) -> DataLoader[DatasetItem]:
    logger.info("Building inference data loader...")
    config = config or InferenceLoaderConfig()

@@ -83,20 +83,19 @@ def build_inference_loader(

    batch_size = batch_size or config.batch_size

    num_workers = num_workers or config.num_workers
    return DataLoader(
        inference_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        num_workers=num_workers,
        collate_fn=_collate_fn,
    )


def build_inference_dataset(
    clips: Sequence[data.Clip],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    audio_loader: AudioLoader | None = None,
    preprocessor: PreprocessorProtocol | None = None,
) -> InferenceDataset:
    if audio_loader is None:
        audio_loader = build_audio_loader()

@@ -111,7 +110,7 @@ def build_inference_dataset(
    )


def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
def _collate_fn(batch: list[DatasetItem]) -> DatasetItem:
    max_width = max(item.spec.shape[-1] for item in batch)
    return DatasetItem(
        spec=torch.stack(
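`_collate_fn` has to stack spectrograms whose widths (time axes) differ across the batch, which is why it computes `max_width` and uses `adjust_width` before `torch.stack`. A minimal sketch of that padding step, with numpy standing in for torch (the helper name here is illustrative; the real code uses `batdetect2.core.arrays.adjust_width`):

```python
import numpy as np

def collate_specs(specs):
    """Right-pad each spectrogram with zeros to the widest item in the
    batch, then stack along a new leading batch axis."""
    max_width = max(s.shape[-1] for s in specs)
    padded = [
        np.pad(s, [(0, 0)] * (s.ndim - 1) + [(0, max_width - s.shape[-1])])
        for s in specs
    ]
    return np.stack(padded)
```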
@@ -5,45 +5,44 @@ from torch.utils.data import DataLoader

from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing.postprocess import BatDetect2Prediction
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections


class InferenceModule(LightningModule):
    def __init__(self, model: Model):
    def __init__(
        self,
        model: Model,
        output_transform: OutputTransformProtocol | None = None,
    ):
        super().__init__()
        self.model = model
        self.output_transform = output_transform or build_output_transform(
            targets=model.targets
        )

    def predict_step(
        self,
        batch: DatasetItem,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> Sequence[BatDetect2Prediction]:
    ) -> Sequence[ClipDetections]:
        dataset = self.get_dataset()

        clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]

        outputs = self.model.detector(batch.spec)

        clip_detections = self.model.postprocessor(
            outputs,
            start_times=[clip.start_time for clip in clips],
        )
        clip_detections = self.model.postprocessor(outputs)

        predictions = [
            BatDetect2Prediction(
        return [
            self.output_transform.to_clip_detections(
                detections=clip_dets,
                clip=clip,
                predictions=to_raw_predictions(
                    clip_dets.numpy(),
                    targets=self.model.targets,
                ),
            )
            for clip, clip_dets in zip(clips, clip_detections)
            for clip, clip_dets in zip(clips, clip_detections, strict=True)
        ]

        return predictions

    def get_dataset(self) -> InferenceDataset:
        dataloaders = self.trainer.predict_dataloaders
        assert isinstance(dataloaders, DataLoader)
@@ -9,10 +9,8 @@ from typing import (
    Dict,
    Generic,
    Literal,
    Optional,
    Protocol,
    TypeVar,
    Union,
)

import numpy as np

@@ -32,6 +30,21 @@ from batdetect2.core.configs import BaseConfig

DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"

__all__ = [
    "AppLoggingConfig",
    "BaseLoggerConfig",
    "CSVLoggerConfig",
    "DEFAULT_LOGS_DIR",
    "DVCLiveConfig",
    "LoggerConfig",
    "MLFlowLoggerConfig",
    "TensorBoardLoggerConfig",
    "build_logger",
    "enable_logging",
    "get_image_logger",
    "get_table_logger",
]


def enable_logging(level: int):
    logger.remove()

@@ -49,14 +62,14 @@ def enable_logging(level: int):

class BaseLoggerConfig(BaseConfig):
    log_dir: Path = DEFAULT_LOGS_DIR
    experiment_name: Optional[str] = None
    run_name: Optional[str] = None
    experiment_name: str | None = None
    run_name: str | None = None


class DVCLiveConfig(BaseLoggerConfig):
    name: Literal["dvclive"] = "dvclive"
    prefix: str = ""
    log_model: Union[bool, Literal["all"]] = False
    log_model: bool | Literal["all"] = False
    monitor_system: bool = False


@@ -72,22 +85,26 @@ class TensorBoardLoggerConfig(BaseLoggerConfig):

class MLFlowLoggerConfig(BaseLoggerConfig):
    name: Literal["mlflow"] = "mlflow"
    tracking_uri: Optional[str] = "http://localhost:5000"
    tags: Optional[dict[str, Any]] = None
    tracking_uri: str | None = "http://localhost:5000"
    tags: dict[str, Any] | None = None
    log_model: bool = False


LoggerConfig = Annotated[
    Union[
        DVCLiveConfig,
        CSVLoggerConfig,
        TensorBoardLoggerConfig,
        MLFlowLoggerConfig,
    ],
    DVCLiveConfig
    | CSVLoggerConfig
    | TensorBoardLoggerConfig
    | MLFlowLoggerConfig,
    Field(discriminator="name"),
]
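The `LoggerConfig` union above uses pydantic's `Field(discriminator="name")`: each config class carries a `Literal` tag, and that same tag later drives builder lookup via `LOGGER_FACTORY`. A plain-Python sketch of the tag-based dispatch (the builder bodies here are stand-ins, not the real implementations):

```python
def create_csv_logger(config):
    return ("csv", config)

def create_tensorboard_logger(config):
    return ("tensorboard", config)

# Tag -> builder registry, mirroring the role of LOGGER_FACTORY.
LOGGER_FACTORY = {
    "csv": create_csv_logger,
    "tensorboard": create_tensorboard_logger,
}

def build_logger(config: dict):
    """Dispatch on the ``name`` tag, failing loudly on unknown tags."""
    name = config["name"]
    if name not in LOGGER_FACTORY:
        raise ValueError(f"Unknown logger type: {name}")
    return LOGGER_FACTORY[name](config)
```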

class AppLoggingConfig(BaseConfig):
    train: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
    evaluation: LoggerConfig = Field(default_factory=CSVLoggerConfig)
    inference: LoggerConfig = Field(default_factory=CSVLoggerConfig)


T = TypeVar("T", bound=LoggerConfig, contravariant=True)


@@ -95,20 +112,20 @@ class LoggerBuilder(Protocol, Generic[T]):
    def __call__(
        self,
        config: T,
        log_dir: Optional[Path] = None,
        experiment_name: Optional[str] = None,
        run_name: Optional[str] = None,
        log_dir: Path | None = None,
        experiment_name: str | None = None,
        run_name: str | None = None,
    ) -> Logger: ...


def create_dvclive_logger(
    config: DVCLiveConfig,
    log_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
    log_dir: Path | None = None,
    experiment_name: str | None = None,
    run_name: str | None = None,
) -> Logger:
    try:
        from dvclive.lightning import DVCLiveLogger  # type: ignore
        from dvclive.lightning import DVCLiveLogger
    except ImportError as error:
        raise ValueError(
            "DVCLive is not installed and cannot be used for logging"

@@ -130,9 +147,9 @@ def create_dvclive_logger(

def create_csv_logger(
    config: CSVLoggerConfig,
    log_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
    log_dir: Path | None = None,
    experiment_name: str | None = None,
    run_name: str | None = None,
) -> Logger:
    from lightning.pytorch.loggers import CSVLogger

@@ -159,9 +176,9 @@ def create_csv_logger(

def create_tensorboard_logger(
    config: TensorBoardLoggerConfig,
    log_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
    log_dir: Path | None = None,
    experiment_name: str | None = None,
    run_name: str | None = None,
) -> Logger:
    from lightning.pytorch.loggers import TensorBoardLogger

@@ -191,9 +208,9 @@ def create_tensorboard_logger(

def create_mlflow_logger(
    config: MLFlowLoggerConfig,
    log_dir: Optional[data.PathLike] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
    log_dir: data.PathLike | None = None,
    experiment_name: str | None = None,
    run_name: str | None = None,
) -> Logger:
    try:
        from lightning.pytorch.loggers import MLFlowLogger

@@ -232,9 +249,9 @@ LOGGER_FACTORY: Dict[str, LoggerBuilder] = {

def build_logger(
    config: LoggerConfig,
    log_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
    log_dir: Path | None = None,
    experiment_name: str | None = None,
    run_name: str | None = None,
) -> Logger:
    logger.opt(lazy=True).debug(
        "Building logger with config: \n{}",

@@ -257,7 +274,7 @@ def build_logger(
PlotLogger = Callable[[str, Figure, int], None]


def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
def get_image_logger(logger: Logger) -> PlotLogger | None:
    if isinstance(logger, TensorBoardLogger):
        return logger.experiment.add_figure

@@ -282,7 +299,7 @@ def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
TableLogger = Callable[[str, pd.DataFrame, int], None]


def get_table_logger(logger: Logger) -> Optional[TableLogger]:
def get_table_logger(logger: Logger) -> TableLogger | None:
    if isinstance(logger, TensorBoardLogger):
        return partial(save_table, dir=Path(logger.log_dir))
@@ -1,36 +1,46 @@
"""Defines and builds the neural network models used in BatDetect2.
"""Neural network model definitions and builders for BatDetect2.

This package (`batdetect2.models`) contains the PyTorch implementations of the
deep neural network architectures used for detecting and classifying bat calls
from spectrograms. It provides modular components and configuration-driven
assembly, allowing for experimentation and use of different architectural
variants.
This package contains the PyTorch implementations of the deep neural network
architectures used to detect and classify bat echolocation calls in
spectrograms. Components are designed to be combined through configuration
objects, making it easy to experiment with different architectures.

Key Submodules:
- `.types`: Defines core data structures (`ModelOutput`) and abstract base
  classes (`BackboneModel`, `DetectionModel`) establishing interfaces.
- `.blocks`: Provides reusable neural network building blocks.
- `.encoder`: Defines and builds the downsampling path (encoder) of the network.
- `.bottleneck`: Defines and builds the central bottleneck component.
- `.decoder`: Defines and builds the upsampling path (decoder) of the network.
- `.backbone`: Assembles the encoder, bottleneck, and decoder into a complete
  feature extraction backbone (e.g., a U-Net like structure).
- `.heads`: Defines simple prediction heads (detection, classification, size)
  that attach to the backbone features.
- `.detectors`: Assembles the backbone and prediction heads into the final,
  end-to-end `Detector` model.
Key submodules
--------------
- ``blocks``: Reusable convolutional building blocks (downsampling,
  upsampling, attention, coord-conv variants).
- ``encoder``: The downsampling path; reduces spatial resolution whilst
  extracting increasingly abstract features.
- ``bottleneck``: The central component connecting encoder to decoder;
  optionally applies self-attention along the time axis.
- ``decoder``: The upsampling path; reconstructs high-resolution feature
  maps using bottleneck output and skip connections from the encoder.
- ``backbones``: Assembles encoder, bottleneck, and decoder into a complete
  U-Net-style feature extraction backbone.
- ``heads``: Lightweight 1×1 convolutional heads that produce detection,
  classification, and bounding-box size predictions from backbone features.
- ``detectors``: Combines a backbone with prediction heads into the final
  end-to-end ``Detector`` model.

This module re-exports the most important classes, configurations, and builder
functions from these submodules for convenient access. The primary entry point
for creating a standard BatDetect2 model instance is the `build_model` function
provided here.
The primary entry point for building a full, ready-to-use BatDetect2 model
is the ``build_model`` factory function exported from this module.
"""

from typing import List, Optional
from typing import Literal

import torch
from pydantic import Field
from soundevent.data import PathLike

from batdetect2.models.backbones import Backbone, build_backbone
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig
from batdetect2.models.backbones import (
    BackboneConfig,
    UNetBackbone,
    UNetBackboneConfig,
    build_backbone,
    load_backbone_config,
)
from batdetect2.models.blocks import (
    ConvConfig,
    FreqCoordConvDownConfig,

@@ -43,10 +53,6 @@ from batdetect2.models.bottleneck import (
    BottleneckConfig,
    build_bottleneck,
)
from batdetect2.models.config import (
    BackboneConfig,
    load_backbone_config,
)
from batdetect2.models.decoder import (
    DEFAULT_DECODER_CONFIG,
    DecoderConfig,

@@ -59,17 +65,20 @@ from batdetect2.models.encoder import (
    build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.typing import (
from batdetect2.models.types import DetectionModel
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.postprocess.types import (
    ClipDetectionsTensor,
    DetectionModel,
    PostprocessorProtocol,
    PreprocessorProtocol,
    TargetProtocol,
)
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.config import TargetConfig
from batdetect2.targets.types import TargetProtocol

__all__ = [
    "BBoxHead",
    "Backbone",
    "UNetBackbone",
    "BackboneConfig",
    "Bottleneck",
    "BottleneckConfig",

@@ -92,11 +101,93 @@ __all__ = [
    "build_detector",
    "load_backbone_config",
    "Model",
    "ModelConfig",
    "build_model",
    "build_model_with_new_targets",
]


class ModelConfig(BaseConfig):
    """Complete configuration describing a BatDetect2 model.

    Bundles every parameter that defines a model's behaviour: the input
    sample rate, backbone architecture, preprocessing pipeline,
    postprocessing pipeline, and detection targets.

    Attributes
    ----------
    samplerate : int
        Expected input audio sample rate in Hz. Audio must be resampled
        to this rate before being passed to the model. Defaults to
        ``TARGET_SAMPLERATE_HZ`` (256 000 Hz).
    architecture : BackboneConfig
        Configuration for the encoder-decoder backbone network. Defaults
        to ``UNetBackboneConfig()``.
    preprocess : PreprocessingConfig
        Parameters for the audio-to-spectrogram preprocessing pipeline
        (STFT, frequency crop, transforms, resize). Defaults to
        ``PreprocessingConfig()``.
    postprocess : PostprocessConfig
        Parameters for converting raw model outputs into detections (NMS
        kernel, thresholds, top-k limit). Defaults to
        ``PostprocessConfig()``.
    targets : TargetConfig
        Detection and classification target definitions (class list,
        detection target, bounding-box mapper). Defaults to
        ``TargetConfig()``.
    """

    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
    architecture: BackboneConfig = Field(default_factory=UNetBackboneConfig)
    preprocess: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
    targets: TargetConfig = Field(default_factory=TargetConfig)

    @classmethod
    def load(
        cls,
        path: PathLike,
        field: str | None = None,
        extra: Literal["ignore", "allow", "forbid"] | None = None,
        strict: bool | None = None,
        targets: TargetConfig | None = None,
    ) -> "ModelConfig":
        config = super().load(path, field, extra, strict)

        if targets is None:
            return config

        return config.model_copy(update={"targets": targets})


class Model(torch.nn.Module):
    """End-to-end BatDetect2 model wrapping preprocessing and postprocessing.

    Combines a preprocessor, a detection model, and a postprocessor into a
    single PyTorch module. Calling ``forward`` on a raw waveform tensor
    returns a list of detection tensors ready for downstream use.

    This class is the top-level object produced by ``build_model``. Most
    users will not need to construct it directly.

    Attributes
    ----------
    detector : DetectionModel
        The neural network that processes spectrograms and produces raw
        detection, classification, and bounding-box outputs.
    preprocessor : PreprocessorProtocol
        Converts a raw waveform tensor into a spectrogram tensor accepted by
        ``detector``.
    postprocessor : PostprocessorProtocol
        Converts the raw ``ModelOutput`` from ``detector`` into a list of
        per-clip detection tensors.
    targets : TargetProtocol
        Describes the set of target classes; used when building heads and
        during training target construction.
    """

    detector: DetectionModel
    preprocessor: PreprocessorProtocol
    postprocessor: PostprocessorProtocol

@@ -115,31 +206,87 @@ class Model(torch.nn.Module):
        self.postprocessor = postprocessor
        self.targets = targets

    def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
    def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
        """Run the full detection pipeline on a waveform tensor.

        Converts the waveform to a spectrogram, passes it through the
        detector, and postprocesses the raw outputs into detection tensors.

        Parameters
        ----------
        wav : torch.Tensor
            Raw audio waveform tensor. The exact expected shape depends on
            the preprocessor, but is typically ``(batch, samples)`` or
            ``(batch, channels, samples)``.

        Returns
        -------
        list[ClipDetectionsTensor]
            One detection tensor per clip in the batch. Each tensor encodes
            the detected events (locations, class scores, sizes) for that
            clip.
        """
        spec = self.preprocessor(wav)
        outputs = self.detector(spec)
        return self.postprocessor(outputs)
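`Model.forward` above is a straight three-stage composition: preprocessor, then detector, then postprocessor. A toy sketch of that shape with plain callables standing in for the real torch modules:

```python
def preprocessor(wav):
    # waveform -> "spectrogram" (toy stand-in: elementwise magnitude)
    return [abs(x) for x in wav]

def detector(spec):
    # "spectrogram" -> raw model outputs (toy stand-in: peak value)
    return max(spec)

def postprocessor(outputs):
    # raw outputs -> per-clip detections
    return [("detection", outputs)]

def forward(wav):
    spec = preprocessor(wav)
    outputs = detector(spec)
    return postprocessor(outputs)
```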
def build_model(
    config: Optional[BackboneConfig] = None,
    targets: Optional[TargetProtocol] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    postprocessor: Optional[PostprocessorProtocol] = None,
):
    config: ModelConfig | None = None,
    targets: TargetProtocol | None = None,
    preprocessor: PreprocessorProtocol | None = None,
    postprocessor: PostprocessorProtocol | None = None,
) -> Model:
    """Build a complete, ready-to-use BatDetect2 model.

    Assembles a ``Model`` instance from a ``ModelConfig`` and optional
    component overrides. Any component argument left as ``None`` is built
    from the configuration. Passing a pre-built component overrides the
    corresponding config fields for that component only.

    Parameters
    ----------
    config : ModelConfig, optional
        Full model configuration (samplerate, architecture, preprocessing,
        postprocessing, targets). Defaults to ``ModelConfig()`` if not
        provided.
    targets : TargetProtocol, optional
        Pre-built targets object. If given, overrides
        ``config.targets``.
    preprocessor : PreprocessorProtocol, optional
        Pre-built preprocessor. If given, overrides
        ``config.preprocess`` and ``config.samplerate`` for the
        preprocessing step.
    postprocessor : PostprocessorProtocol, optional
        Pre-built postprocessor. If given, overrides
        ``config.postprocess``. When omitted and a custom
        ``preprocessor`` is supplied, the default postprocessor is built
        using that preprocessor so that frequency and time scaling remain
        consistent.

    Returns
    -------
    Model
        A fully assembled ``Model`` instance ready for inference or
        training.
    """
    from batdetect2.postprocess import build_postprocessor
    from batdetect2.preprocess import build_preprocessor
    from batdetect2.targets import build_targets

    config = config or BackboneConfig()
    targets = targets or build_targets()
    preprocessor = preprocessor or build_preprocessor()
    config = config or ModelConfig()
    targets = targets or build_targets(config=config.targets)
    preprocessor = preprocessor or build_preprocessor(
        config=config.preprocess,
        input_samplerate=config.samplerate,
    )
    postprocessor = postprocessor or build_postprocessor(
        preprocessor=preprocessor,
        config=config.postprocess,
    )
    detector = build_detector(
        num_classes=len(targets.class_names),
        config=config,
        config=config.architecture,
    )
    return Model(
        detector=detector,

@@ -147,3 +294,21 @@ def build_model(
        preprocessor=preprocessor,
        targets=targets,
    )


def build_model_with_new_targets(
    model: Model,
    targets: TargetProtocol,
) -> Model:
    """Build a new model with a different target set."""
    detector = build_detector(
        num_classes=len(targets.class_names),
        backbone=model.detector.backbone,
    )

    return Model(
        detector=detector,
        postprocessor=model.postprocessor,
        preprocessor=model.preprocessor,
        targets=targets,
    )
@@ -1,99 +1,176 @@
"""Assembles a complete Encoder-Decoder Backbone network.
"""Assembles a complete encoder-decoder backbone network.

This module defines the configuration (`BackboneConfig`) and implementation
(`Backbone`) for a standard encoder-decoder style neural network backbone.
This module defines ``UNetBackboneConfig`` and the ``UNetBackbone``
``nn.Module``, together with the ``build_backbone`` and
``load_backbone_config`` helpers.

It orchestrates the connection between three main components, built using their
respective configurations and factory functions from sibling modules:
1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features
   at multiple resolutions and provides skip connections.
2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the
   lowest resolution, optionally applying self-attention.
3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs high-
   resolution features using bottleneck features and skip connections.
A backbone combines three components built from the sibling modules:

The resulting `Backbone` module takes a spectrogram as input and outputs a
final feature map, typically used by subsequent prediction heads. It includes
automatic padding to handle input sizes not perfectly divisible by the
network's total downsampling factor.
1. **Encoder** (``batdetect2.models.encoder``) – reduces spatial resolution
   while extracting hierarchical features and storing skip-connection tensors.
2. **Bottleneck** (``batdetect2.models.bottleneck``) – processes the
   lowest-resolution features, optionally applying self-attention.
3. **Decoder** (``batdetect2.models.decoder``) – restores spatial resolution
   using bottleneck features and skip connections from the encoder.

The resulting ``UNetBackbone`` takes a spectrogram tensor as input and returns
a high-resolution feature map consumed by the prediction heads in
``batdetect2.models.detectors``.

Input padding is handled automatically: the backbone pads the input to be
divisible by the total downsampling factor and strips the padding from the
output so that the output spatial dimensions always match the input spatial
dimensions.
"""
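The automatic padding described in the docstring above amounts to rounding each spatial dimension up to the next multiple of the backbone's total downsampling factor. A minimal sketch of that arithmetic (assumed behaviour, with an illustrative four-level encoder):

```python
def pad_to_multiple(size: int, factor: int) -> int:
    """Padding needed so ``size`` becomes divisible by ``factor``."""
    remainder = size % factor
    return 0 if remainder == 0 else factor - remainder

# e.g. an encoder that halves each dimension four times needs
# inputs divisible by 2**4 = 16; the padding is cropped off again
# after the decoder so output size matches input size.
factor = 2 ** 4
```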
from typing import Annotated, Literal

import torch
import torch.nn.functional as F
from pydantic import Field, TypeAdapter
from soundevent import data
from torch import nn

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import (
    ImportConfig,
    Registry,
    add_import_config,
)
from batdetect2.models.bottleneck import (
    DEFAULT_BOTTLENECK_CONFIG,
    BottleneckConfig,
    build_bottleneck,
)
from batdetect2.models.decoder import (
    DEFAULT_DECODER_CONFIG,
    DecoderConfig,
    build_decoder,
)
from batdetect2.models.encoder import (
    DEFAULT_ENCODER_CONFIG,
    EncoderConfig,
    build_encoder,
)
from batdetect2.models.types import (
    BackboneModel,
    BottleneckProtocol,
    DecoderProtocol,
    EncoderProtocol,
)

__all__ = [
    "BackboneImportConfig",
    "UNetBackbone",
    "BackboneConfig",
    "load_backbone_config",
    "build_backbone",
]


class UNetBackboneConfig(BaseConfig):
    """Configuration for a U-Net-style encoder-decoder backbone.

    All fields have sensible defaults that reproduce the standard BatDetect2
    architecture, so you can start with ``UNetBackboneConfig()`` and override
    only the fields you want to change.

    Attributes
    ----------
    name : str
        Discriminator field used by the backbone registry; always
        ``"UNetBackbone"``.
    input_height : int
        Number of frequency bins in the input spectrogram. Defaults to
        ``128``.
    in_channels : int
        Number of channels in the input spectrogram (e.g. ``1`` for a
        standard mel-spectrogram). Defaults to ``1``.
    encoder : EncoderConfig
        Configuration for the downsampling path. Defaults to
        ``DEFAULT_ENCODER_CONFIG``.
    bottleneck : BottleneckConfig
        Configuration for the bottleneck. Defaults to
        ``DEFAULT_BOTTLENECK_CONFIG``.
    decoder : DecoderConfig
        Configuration for the upsampling path. Defaults to
        ``DEFAULT_DECODER_CONFIG``.
    """

    name: Literal["UNetBackbone"] = "UNetBackbone"
    input_height: int = 128
    in_channels: int = 1
    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG


backbone_registry: Registry[BackboneModel, []] = Registry("backbone")


@add_import_config(backbone_registry)
class BackboneImportConfig(ImportConfig):
    """Use any callable as a backbone model.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class UNetBackbone(BackboneModel):
    """U-Net-style encoder-decoder backbone network.

    Combines an encoder, a bottleneck, and a decoder into a single module
    that produces a high-resolution feature map from an input spectrogram.
    Skip connections from each encoder stage are added element-wise to the
    corresponding decoder stage input.

    Input spectrograms of arbitrary width are handled automatically: the
    backbone pads the input so that its dimensions are divisible by
    ``divide_factor`` and removes the padding from the output.

    Instances are typically created via ``build_backbone``.

    Attributes
    ----------
    input_height : int
        Expected height (frequency bins) of the input spectrogram.
    out_channels : int
        Number of channels in the output feature map (taken from the
        decoder's output channel count).
    encoder : EncoderProtocol
        The instantiated encoder module.
    decoder : DecoderProtocol
        The instantiated decoder module.
    bottleneck : BottleneckProtocol
        The instantiated bottleneck module.
    divide_factor : int
        The total spatial downsampling factor applied by the encoder
        (``input_height // encoder.output_height``). The input width is
        padded to be a multiple of this value before processing.
    """

    def __init__(
        self,
        input_height: int,
        encoder: EncoderProtocol,
        decoder: DecoderProtocol,
        bottleneck: BottleneckProtocol,
    ):
        """Initialise the backbone network.

        Parameters
        ----------
        input_height : int
            Expected height (frequency bins) of the input spectrogram.
        encoder : EncoderProtocol
            An initialised encoder module.
        decoder : DecoderProtocol
            An initialised decoder module. Its ``output_height`` must equal
            ``input_height``; a ``ValueError`` is raised otherwise.
        bottleneck : BottleneckProtocol
            An initialised bottleneck module.
        """
        super().__init__()
        self.input_height = input_height
@ -110,22 +187,25 @@ class Backbone(BackboneModel):
        self.divide_factor = input_height // self.encoder.output_height

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Produce a feature map from an input spectrogram.

        Pads the input if necessary, runs it through the encoder, then
        the bottleneck, then the decoder (incorporating encoder skip
        connections), and finally removes any padding added earlier.

        Parameters
        ----------
        spec : torch.Tensor
            Input spectrogram tensor, shape
            ``(B, C_in, H_in, W_in)``. ``H_in`` must equal
            ``self.input_height``; ``W_in`` can be any positive integer.

        Returns
        -------
        torch.Tensor
            Feature map tensor, shape ``(B, C_out, H_in, W_in)``, where
            ``C_out`` is ``self.out_channels``. The spatial dimensions
            always match those of the input.
        """
        spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor)

@ -143,95 +223,97 @@ class Backbone(BackboneModel):
        return x

    @backbone_registry.register(UNetBackboneConfig)
    @staticmethod
    def from_config(config: UNetBackboneConfig) -> BackboneModel:
        encoder = build_encoder(
            in_channels=config.in_channels,
            input_height=config.input_height,
            config=config.encoder,
        )

        bottleneck = build_bottleneck(
            input_height=encoder.output_height,
            in_channels=encoder.out_channels,
            config=config.bottleneck,
        )

        decoder = build_decoder(
            in_channels=bottleneck.out_channels,
            input_height=encoder.output_height,
            config=config.decoder,
        )

        if decoder.output_height != config.input_height:
            raise ValueError(
                "Invalid configuration: Decoder output height "
                f"({decoder.output_height}) must match the Backbone input height "
                f"({config.input_height}). Check encoder/decoder layer "
                "configurations and input/bottleneck heights."
            )

        return UNetBackbone(
            input_height=config.input_height,
            encoder=encoder,
            decoder=decoder,
            bottleneck=bottleneck,
        )


BackboneConfig = Annotated[
    UNetBackboneConfig | BackboneImportConfig,
    Field(discriminator="name"),
]


def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
    """Build a backbone network from configuration.

    Looks up the backbone class corresponding to ``config.name`` in the
    backbone registry and calls its ``from_config`` method. If no
    configuration is provided, a default ``UNetBackbone`` is returned.

    Parameters
    ----------
    config : BackboneConfig, optional
        A configuration object describing the desired backbone: either a
        ``UNetBackboneConfig`` or a ``BackboneImportConfig``. Defaults to
        ``UNetBackboneConfig()`` if not provided.

    Returns
    -------
    BackboneModel
        An initialised backbone module.
    """
    config = config or UNetBackboneConfig()
    return backbone_registry.build(config)

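The registry dispatch used by `build_backbone` can be sketched with a minimal stand-in (`MiniRegistry` and `UNetConfig` below are hypothetical illustrations, not the real `batdetect2.core.registries.Registry` or config classes): `register` maps a config type to a builder, and `build` looks the builder up from the config's type.

```python
from dataclasses import dataclass
from typing import Callable, Dict


class MiniRegistry:
    """Toy registry mapping config classes to builder callables."""

    def __init__(self) -> None:
        self._builders: Dict[type, Callable] = {}

    def register(self, config_cls: type):
        # Decorator: remember which builder handles this config type.
        def decorator(builder: Callable):
            self._builders[config_cls] = builder
            return builder

        return decorator

    def build(self, config):
        # Dispatch on the concrete config type.
        return self._builders[type(config)](config)


registry = MiniRegistry()


@dataclass
class UNetConfig:
    name: str = "UNetBackbone"
    input_height: int = 128


@registry.register(UNetConfig)
def build_unet(config: UNetConfig) -> str:
    return f"UNet(height={config.input_height})"


print(registry.build(UNetConfig(input_height=64)))  # -> UNet(height=64)
```

The real registry additionally uses the pydantic `name` discriminator so that YAML/JSON configs can be validated before dispatch.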
def _pad_adjust(
    spec: torch.Tensor,
    factor: int = 32,
) -> tuple[torch.Tensor, int, int]:
    """Pad a tensor's height and width to be divisible by ``factor``.

    Adds zero-padding to the bottom and right edges of the tensor so that
    both dimensions are exact multiples of ``factor``. If both dimensions
    are already divisible, the tensor is returned unchanged.

    Parameters
    ----------
    spec : torch.Tensor
        Input tensor, typically shape ``(B, C, H, W)``.
    factor : int, default=32
        The factor that both H and W should be divisible by after padding.

    Returns
    -------
    tuple[torch.Tensor, int, int]
        - Padded tensor.
        - Number of rows added to the height (``h_pad``).
        - Number of columns added to the width (``w_pad``).
    """
    h, w = spec.shape[-2:]
    h_pad = -h % factor
    w_pad = -w % factor

@ -244,28 +326,71 @@ def _pad_adjust(
def _restore_pad(
    x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor:
    """Remove padding previously added by ``_pad_adjust``.

    Trims ``h_pad`` rows from the bottom and ``w_pad`` columns from the
    right of the tensor, restoring its original spatial dimensions.

    Parameters
    ----------
    x : torch.Tensor
        Padded tensor, typically shape ``(B, C, H_padded, W_padded)``.
    h_pad : int, default=0
        Number of rows to remove from the bottom.
    w_pad : int, default=0
        Number of columns to remove from the right.

    Returns
    -------
    torch.Tensor
        Tensor with padding removed, shape
        ``(B, C, H_padded - h_pad, W_padded - w_pad)``.
    """
    if h_pad > 0:
        x = x[..., :-h_pad, :]

    if w_pad > 0:
        x = x[..., :-w_pad]

    return x

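The pad/unpad round trip can be demonstrated with nested lists standing in for tensors (`pad_2d` and `unpad_2d` below are hypothetical stand-ins; the real code uses `torch.nn.functional.pad` and negative slicing on 4-D tensors). Note the `if` guards mirror `_restore_pad`: slicing with `[:-0]` would return an empty sequence, so zero padding must be skipped, not sliced.

```python
def pad_2d(rows, factor):
    """Zero-pad a 2-D list on the bottom/right so H and W divide `factor`."""
    h, w = len(rows), len(rows[0])
    h_pad, w_pad = -h % factor, -w % factor
    padded = [row + [0.0] * w_pad for row in rows]          # pad right
    padded += [[0.0] * (w + w_pad) for _ in range(h_pad)]   # pad bottom
    return padded, h_pad, w_pad


def unpad_2d(rows, h_pad, w_pad):
    """Trim the padding added by pad_2d, guarding against zero amounts."""
    if h_pad:
        rows = rows[:-h_pad]
    if w_pad:
        rows = [row[:-w_pad] for row in rows]
    return rows


grid = [[1.0] * 5 for _ in range(3)]            # 3x5 input
padded, h_pad, w_pad = pad_2d(grid, factor=4)   # padded up to 4x8
restored = unpad_2d(padded, h_pad, w_pad)       # back to 3x5
print(len(padded), len(padded[0]))  # -> 4 8
print(restored == grid)             # -> True
```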
def load_backbone_config(
    path: data.PathLike,
    field: str | None = None,
) -> BackboneConfig:
    """Load a backbone configuration from a YAML or JSON file.

    Reads the file at ``path``, optionally descends into a named sub-field,
    and validates the result against the ``BackboneConfig`` discriminated
    union.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file. Both YAML and JSON formats are
        supported.
    field : str, optional
        Dot-separated key path to the sub-field that contains the backbone
        configuration (e.g. ``"model"``). If ``None``, the root of the
        file is used.

    Returns
    -------
    BackboneConfig
        A validated backbone configuration object (a ``UNetBackboneConfig``
        or ``BackboneImportConfig`` instance).

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    ValidationError
        If the loaded data does not conform to a known ``BackboneConfig``
        schema.
    """
    return load_config(
        path,
        schema=TypeAdapter(BackboneConfig),
        field=field,
    )

@ -1,42 +1,63 @@
"""Reusable convolutional building blocks for BatDetect2 models.

This module provides a collection of ``torch.nn.Module`` subclasses that form
the fundamental building blocks for the encoder-decoder backbone used in
BatDetect2. All blocks follow a consistent interface: they store
``in_channels`` and ``out_channels`` as attributes and implement a
``get_output_height`` method that reports how a given input height maps to an
output height (e.g., halved by downsampling blocks, doubled by upsampling
blocks).

Available block families
------------------------
Standard blocks
    ``ConvBlock`` – convolution + batch normalisation + ReLU, no change in
    spatial resolution.

Downsampling blocks
    ``StandardConvDownBlock`` – convolution then 2×2 max-pooling, halves H
    and W.
    ``FreqCoordConvDownBlock`` – like ``StandardConvDownBlock`` but prepends
    a normalised frequency-coordinate channel before the convolution
    (CoordConv concept), helping filters learn frequency-position-dependent
    patterns.

Upsampling blocks
    ``StandardConvUpBlock`` – bilinear interpolation then convolution,
    doubles H and W.
    ``FreqCoordConvUpBlock`` – like ``StandardConvUpBlock`` but prepends a
    frequency-coordinate channel after upsampling.

Bottleneck blocks
    ``VerticalConv`` – 1-D convolution whose kernel spans the entire
    frequency axis, collapsing H to 1 whilst preserving W.
    ``SelfAttention`` – scaled dot-product self-attention along the time
    axis; typically follows a ``VerticalConv``.

Group block
    ``LayerGroup`` – chains several blocks sequentially into one unit,
    useful when a single encoder or decoder "stage" requires more than one
    operation.

Factory function
----------------
``build_layer`` creates any of the above blocks from the matching
configuration object (one of the ``*Config`` classes exported here), using
a discriminated-union ``name`` field to dispatch to the correct class.
"""

from typing import Annotated, Literal

import torch
import torch.nn.functional as F
from pydantic import Field
from torch import nn

from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.core.configs import BaseConfig

__all__ = [
    "BlockImportConfig",
    "ConvBlock",
    "LayerGroupConfig",
    "VerticalConv",
@ -51,63 +72,125 @@ __all__ = [
    "FreqCoordConvUpConfig",
    "StandardConvUpConfig",
    "LayerConfig",
    "build_layer",
]


class Block(nn.Module):
    """Abstract base class for all BatDetect2 building blocks.

    Subclasses must set ``in_channels`` and ``out_channels`` as integer
    attributes so that factory functions can wire blocks together without
    inspecting configuration objects at runtime. They may also override
    ``get_output_height`` when the block changes the height dimension (e.g.
    downsampling or upsampling blocks).

    Attributes
    ----------
    in_channels : int
        Number of channels expected in the input tensor.
    out_channels : int
        Number of channels produced in the output tensor.
    """

    in_channels: int
    out_channels: int

    def get_output_height(self, input_height: int) -> int:
        """Return the output height for a given input height.

        The default implementation returns ``input_height`` unchanged,
        which is correct for blocks that do not alter spatial resolution.
        Override this in downsampling (returns ``input_height // 2``) or
        upsampling (returns ``input_height * 2``) subclasses.

        Parameters
        ----------
        input_height : int
            Height (number of frequency bins) of the input feature map.

        Returns
        -------
        int
            Height of the output feature map.
        """
        return input_height


block_registry: Registry[Block, [int, int]] = Registry("block")


@add_import_config(block_registry)
class BlockImportConfig(ImportConfig):
    """Use any callable as a model block.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    name: Literal["import"] = "import"


class SelfAttentionConfig(BaseConfig):
    """Configuration for a ``SelfAttention`` block.

    Attributes
    ----------
    name : str
        Discriminator field; always ``"SelfAttention"``.
    attention_channels : int
        Dimensionality of the query, key, and value projections.
    temperature : float
        Scaling divisor applied (together with ``attention_channels``) to
        the attention scores before softmax. Defaults to ``1``.
    """

    name: Literal["SelfAttention"] = "SelfAttention"
    attention_channels: int
    temperature: float = 1


class SelfAttention(Block):
    """Self-attention block operating along the time axis.

    Applies a scaled dot-product self-attention mechanism across the time
    steps of an input feature map. Before attention is computed the height
    dimension (frequency axis) is expected to have been reduced to 1, e.g.
    by a preceding ``VerticalConv`` layer.

    By calculating attention weights between all pairs of time steps, it
    allows the model to capture long-range temporal dependencies and focus
    on relevant parts of the sequence. It is often employed in the
    bottleneck of an encoder-decoder architecture to integrate global
    temporal context.

    For each time step the block computes query, key, and value projections
    with learned linear weights, then calculates attention weights from the
    query–key dot products divided by ``temperature × attention_channels``.
    The weighted sum of values is projected back to ``in_channels`` via a
    final linear layer, and the height dimension is restored so that the
    output shape matches the input shape. Positional encoding (absolute or
    relative) is not included in this block.

    Parameters
    ----------
    in_channels : int
        Number of input channels (features per time step). The output will
        also have ``in_channels`` channels.
    attention_channels : int
        Dimensionality of the query, key, and value projections.
    temperature : float, default=1.0
        Divisor applied together with ``attention_channels`` when scaling
        the dot-product scores before softmax. Larger values produce softer
        (more uniform) attention distributions.

    Attributes
    ----------
    key_fun : nn.Linear
        Linear projection for keys.
    value_fun : nn.Linear
        Linear projection for values.
    query_fun : nn.Linear
        Linear projection for queries.
    pro_fun : nn.Linear
        Final linear projection applied to the attended values.
    temperature : float
        Scaling divisor used when computing attention scores.
    att_dim : int
        Dimensionality of the attention space (``attention_channels``).
    """

    def __init__(
@ -117,10 +200,13 @@ class SelfAttention(nn.Module):
        temperature: float = 1.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels

        # Note: does not encode position information (absolute or relative).
        self.temperature = temperature
        self.att_dim = attention_channels

        self.key_fun = nn.Linear(in_channels, attention_channels)
        self.value_fun = nn.Linear(in_channels, attention_channels)
@ -133,20 +219,16 @@ class SelfAttention(nn.Module):
        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape ``(B, C, 1, W)``. The height dimension
            must be 1 (i.e. the frequency axis should already have been
            collapsed by a preceding ``VerticalConv`` layer).

        Returns
        -------
        torch.Tensor
            Output tensor with the same shape ``(B, C, 1, W)`` as the
            input, with each time step updated by attended context from all
            other time steps.
        """
        x = x.squeeze(2).permute(0, 2, 1)
@ -175,6 +257,22 @@ class SelfAttention(nn.Module):
        return op

    def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
        """Return the softmax attention weight matrix.

        Useful for visualising which time steps attend to which others.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor with shape ``(B, C, 1, W)``.

        Returns
        -------
        torch.Tensor
            Attention weight matrix with shape ``(B, W, W)``. Entry
            ``[b, i, j]`` is the attention weight that time step ``i``
            assigns to time step ``j`` in batch item ``b``.
        """
        x = x.squeeze(2).permute(0, 2, 1)

        key = torch.matmul(
@ -190,6 +288,19 @@ class SelfAttention(nn.Module):
        att_weights = F.softmax(kk_qq, 1)
        return att_weights

    @block_registry.register(SelfAttentionConfig)
    @staticmethod
    def from_config(
        config: SelfAttentionConfig,
        input_channels: int,
        input_height: int,
    ) -> "SelfAttention":
        return SelfAttention(
            in_channels=input_channels,
            attention_channels=config.attention_channels,
            temperature=config.temperature,
        )


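The scaled dot-product weighting described in the docstrings above can be sketched in pure Python with lists standing in for tensors and identity projections standing in for the learned `key_fun`/`query_fun` layers (an illustration of the technique, not this module's code):

```python
import math


def attention_weights(features, temperature=1.0):
    """Return a (W, W) attention matrix for `features` of shape (W, C).

    Scores are query-key dot products divided by temperature * C, then
    each row is softmax-normalised so every time step's weights sum to 1.
    """
    dim = len(features[0])
    scores = [
        [
            sum(q * k for q, k in zip(features[i], features[j]))
            / (temperature * dim)
            for j in range(len(features))
        ]
        for i in range(len(features))
    ]
    weights = []
    for row in scores:
        m = max(row)  # subtract the max for numerical stability
        exps = [math.exp(s - m) for s in row]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


w = attention_weights([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]])
print([round(sum(row), 6) for row in w])  # -> [1.0, 1.0, 1.0]
```

A larger `temperature` shrinks all scores, pushing each row of weights towards a uniform distribution, which matches the "softer attention" behaviour noted in the class docstring.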
class ConvConfig(BaseConfig):
    """Configuration for a basic ConvBlock."""

@ -207,7 +318,7 @@ class ConvConfig(BaseConfig):
    """Padding size."""


class ConvBlock(Block):
    """Basic Convolutional Block.

    A standard building block consisting of a 2D convolution, followed by
@ -235,6 +346,8 @@ class ConvBlock(nn.Module):
        pad_size: int = 1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
@ -258,8 +371,37 @@ class ConvBlock(nn.Module):
        """
        return F.relu_(self.batch_norm(self.conv(x)))

    @block_registry.register(ConvConfig)
    @staticmethod
    def from_config(
        config: ConvConfig,
        input_channels: int,
        input_height: int,
    ):
        return ConvBlock(
            in_channels=input_channels,
            out_channels=config.out_channels,
            kernel_size=config.kernel_size,
            pad_size=config.pad_size,
        )


class VerticalConvConfig(BaseConfig):
    """Configuration for a ``VerticalConv`` block.

    Attributes
    ----------
    name : str
        Discriminator field; always ``"VerticalConv"``.
    channels : int
        Number of output channels produced by the vertical convolution.
    """

    name: Literal["VerticalConv"] = "VerticalConv"
    channels: int


class VerticalConv(Block):
    """Convolutional layer that aggregates features across the entire height.

    Applies a 2D convolution using a kernel with shape ``(input_height, 1)``.
@ -288,6 +430,8 @@ class VerticalConv(nn.Module):
        input_height: int,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
@ -312,6 +456,19 @@ class VerticalConv(nn.Module):
        """
        return F.relu_(self.bn(self.conv(x)))

    @block_registry.register(VerticalConvConfig)
    @staticmethod
    def from_config(
        config: VerticalConvConfig,
        input_channels: int,
        input_height: int,
    ):
        return VerticalConv(
            in_channels=input_channels,
            out_channels=config.channels,
            input_height=input_height,
        )


class FreqCoordConvDownConfig(BaseConfig):
|
||||
"""Configuration for a FreqCoordConvDownBlock."""
|
||||
@ -329,7 +486,7 @@ class FreqCoordConvDownConfig(BaseConfig):
|
||||
"""Padding size."""
|
||||
|
||||
|
||||
class FreqCoordConvDownBlock(nn.Module):
|
||||
class FreqCoordConvDownBlock(Block):
|
||||
"""Downsampling Conv Block incorporating Frequency Coordinate features.
|
||||
|
||||
This block implements a downsampling step (Conv2d + MaxPool2d) commonly
|
||||
@ -368,6 +525,8 @@ class FreqCoordConvDownBlock(nn.Module):
|
||||
pad_size: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.coords = nn.Parameter(
|
||||
torch.linspace(-1, 1, input_height)[None, None, ..., None],
|
||||
@ -402,6 +561,24 @@ class FreqCoordConvDownBlock(nn.Module):
|
||||
x = F.relu(self.batch_norm(x), inplace=True)
|
||||
return x
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
return input_height // 2
|
||||
|
||||
@block_registry.register(FreqCoordConvDownConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: FreqCoordConvDownConfig,
|
||||
input_channels: int,
|
||||
input_height: int,
|
||||
):
|
||||
return FreqCoordConvDownBlock(
|
||||
in_channels=input_channels,
|
||||
out_channels=config.out_channels,
|
||||
input_height=input_height,
|
||||
kernel_size=config.kernel_size,
|
||||
pad_size=config.pad_size,
|
||||
)
|
||||
|
||||
|
||||
class StandardConvDownConfig(BaseConfig):
|
||||
"""Configuration for a StandardConvDownBlock."""
|
||||
@ -419,7 +596,7 @@ class StandardConvDownConfig(BaseConfig):
|
||||
"""Padding size."""
|
||||
|
||||
|
||||
class StandardConvDownBlock(nn.Module):
|
||||
class StandardConvDownBlock(Block):
|
||||
"""Standard Downsampling Convolutional Block.
|
||||
|
||||
A basic downsampling block consisting of a 2D convolution, followed by
|
||||
@ -447,6 +624,8 @@ class StandardConvDownBlock(nn.Module):
|
||||
pad_size: int = 1,
|
||||
):
|
||||
super(StandardConvDownBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
@ -472,6 +651,23 @@ class StandardConvDownBlock(nn.Module):
|
||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||
return F.relu(self.batch_norm(x), inplace=True)
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
return input_height // 2
|
||||
|
||||
@block_registry.register(StandardConvDownConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: StandardConvDownConfig,
|
||||
input_channels: int,
|
||||
input_height: int,
|
||||
):
|
||||
return StandardConvDownBlock(
|
||||
in_channels=input_channels,
|
||||
out_channels=config.out_channels,
|
||||
kernel_size=config.kernel_size,
|
||||
pad_size=config.pad_size,
|
||||
)
|
||||
|
||||
|
||||
class FreqCoordConvUpConfig(BaseConfig):
|
||||
"""Configuration for a FreqCoordConvUpBlock."""
|
||||
@ -488,8 +684,14 @@ class FreqCoordConvUpConfig(BaseConfig):
|
||||
pad_size: int = 1
|
||||
"""Padding size."""
|
||||
|
||||
up_mode: str = "bilinear"
|
||||
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
|
||||
|
||||
class FreqCoordConvUpBlock(nn.Module):
|
||||
up_scale: tuple[int, int] = (2, 2)
|
||||
"""Scaling factor for height and width during upsampling."""
|
||||
|
||||
|
||||
class FreqCoordConvUpBlock(Block):
|
||||
"""Upsampling Conv Block incorporating Frequency Coordinate features.
|
||||
|
||||
This block implements an upsampling step followed by a convolution,
|
||||
@ -504,22 +706,22 @@ class FreqCoordConvUpBlock(nn.Module):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
in_channels
|
||||
Number of channels in the input tensor (before upsampling).
|
||||
out_channels : int
|
||||
out_channels
|
||||
Number of output channels after the convolution.
|
||||
input_height : int
|
||||
input_height
|
||||
Height (H dimension, frequency bins) of the tensor *before* upsampling.
|
||||
Used to calculate the height for coordinate feature generation after
|
||||
upsampling.
|
||||
kernel_size : int, default=3
|
||||
kernel_size
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
pad_size
|
||||
Padding added before convolution.
|
||||
up_mode : str, default="bilinear"
|
||||
up_mode
|
||||
Interpolation mode for upsampling (e.g., "nearest", "bilinear",
|
||||
"bicubic").
|
||||
up_scale : Tuple[int, int], default=(2, 2)
|
||||
up_scale
|
||||
Scaling factor for height and width during upsampling
|
||||
(typically (2, 2)).
|
||||
"""
|
||||
@ -532,9 +734,11 @@ class FreqCoordConvUpBlock(nn.Module):
|
||||
kernel_size: int = 3,
|
||||
pad_size: int = 1,
|
||||
up_mode: str = "bilinear",
|
||||
up_scale: Tuple[int, int] = (2, 2),
|
||||
up_scale: tuple[int, int] = (2, 2),
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.up_scale = up_scale
|
||||
self.up_mode = up_mode
|
||||
@ -581,6 +785,26 @@ class FreqCoordConvUpBlock(nn.Module):
|
||||
op = F.relu(self.batch_norm(op), inplace=True)
|
||||
return op
|
||||
|
||||
def get_output_height(self, input_height: int) -> int:
|
||||
return input_height * 2
|
||||
|
||||
@block_registry.register(FreqCoordConvUpConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: FreqCoordConvUpConfig,
|
||||
input_channels: int,
|
||||
input_height: int,
|
||||
):
|
||||
return FreqCoordConvUpBlock(
|
||||
in_channels=input_channels,
|
||||
out_channels=config.out_channels,
|
||||
input_height=input_height,
|
||||
kernel_size=config.kernel_size,
|
||||
pad_size=config.pad_size,
|
||||
up_mode=config.up_mode,
|
||||
up_scale=config.up_scale,
|
||||
)
|
||||
|
||||
|
||||
class StandardConvUpConfig(BaseConfig):
|
||||
"""Configuration for a StandardConvUpBlock."""
|
||||
@ -597,8 +821,14 @@ class StandardConvUpConfig(BaseConfig):
|
||||
pad_size: int = 1
|
||||
"""Padding size."""
|
||||
|
||||
up_mode: str = "bilinear"
|
||||
"""Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
|
||||
|
||||
class StandardConvUpBlock(nn.Module):
|
||||
up_scale: tuple[int, int] = (2, 2)
|
||||
"""Scaling factor for height and width during upsampling."""
|
||||
|
||||
|
||||
class StandardConvUpBlock(Block):
|
||||
"""Standard Upsampling Convolutional Block.
|
||||
|
||||
A basic upsampling block used in CNN decoders. It first upsamples the input
|
||||
@ -609,17 +839,17 @@ class StandardConvUpBlock(nn.Module):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
in_channels
|
||||
Number of channels in the input tensor (before upsampling).
|
||||
out_channels : int
|
||||
out_channels
|
||||
Number of output channels after the convolution.
|
||||
kernel_size : int, default=3
|
||||
kernel_size
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
pad_size
|
||||
Padding added before convolution.
|
||||
up_mode : str, default="bilinear"
|
||||
up_mode
|
||||
Interpolation mode for upsampling (e.g., "nearest", "bilinear").
|
||||
up_scale : Tuple[int, int], default=(2, 2)
|
||||
up_scale
|
||||
Scaling factor for height and width during upsampling.
|
||||
"""
|
||||
|
||||
@ -630,9 +860,11 @@ class StandardConvUpBlock(nn.Module):
|
||||
kernel_size: int = 3,
|
||||
pad_size: int = 1,
|
||||
up_mode: str = "bilinear",
|
||||
up_scale: Tuple[int, int] = (2, 2),
|
||||
up_scale: tuple[int, int] = (2, 2),
|
||||
):
|
||||
super(StandardConvUpBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.up_scale = up_scale
|
||||
self.up_mode = up_mode
|
||||
self.conv = nn.Conv2d(
|
||||
@@ -669,155 +901,195 @@ class StandardConvUpBlock(nn.Module):
         op = F.relu(self.batch_norm(op), inplace=True)
         return op
 
+    def get_output_height(self, input_height: int) -> int:
+        return input_height * 2
+
+    @block_registry.register(StandardConvUpConfig)
+    @staticmethod
+    def from_config(
+        config: StandardConvUpConfig,
+        input_channels: int,
+        input_height: int,
+    ):
+        return StandardConvUpBlock(
+            in_channels=input_channels,
+            out_channels=config.out_channels,
+            kernel_size=config.kernel_size,
+            pad_size=config.pad_size,
+            up_mode=config.up_mode,
+            up_scale=config.up_scale,
+        )
+
+
+class LayerGroupConfig(BaseConfig):
+    """Configuration for a ``LayerGroup``, a sequential chain of blocks.
+
+    Use this when a single encoder or decoder stage needs more than one
+    block. The blocks are executed in the order they appear in ``layers``,
+    with channel counts and heights propagated automatically.
+
+    Attributes
+    ----------
+    name : str
+        Discriminator field; always ``"LayerGroup"``.
+    layers : List[LayerConfig]
+        Ordered list of block configurations to chain together.
+    """
+
+    name: Literal["LayerGroup"] = "LayerGroup"
+    layers: list["LayerConfig"]
+
+
 LayerConfig = Annotated[
-    Union[
-        ConvConfig,
-        FreqCoordConvDownConfig,
-        StandardConvDownConfig,
-        FreqCoordConvUpConfig,
-        StandardConvUpConfig,
-        SelfAttentionConfig,
-        "LayerGroupConfig",
-    ],
+    ConvConfig
+    | FreqCoordConvDownConfig
+    | StandardConvDownConfig
+    | FreqCoordConvUpConfig
+    | StandardConvUpConfig
+    | SelfAttentionConfig
+    | LayerGroupConfig,
     Field(discriminator="name"),
 ]
 """Type alias for the discriminated union of block configuration models."""
 
 
-class LayerGroupConfig(BaseConfig):
-    name: Literal["LayerGroup"] = "LayerGroup"
-    layers: List[LayerConfig]
+class LayerGroup(nn.Module):
+    """Sequential chain of blocks that acts as a single composite block.
+
+    Wraps multiple ``Block`` instances in an ``nn.Sequential`` container,
+    exposing the same ``in_channels``, ``out_channels``, and
+    ``get_output_height`` interface as a regular ``Block`` so it can be
+    used transparently wherever a single block is expected.
+
+    Instances are typically constructed by ``build_layer`` when given a
+    ``LayerGroupConfig``; you rarely need to create them directly.
+
+    Parameters
+    ----------
+    layers : list[Block]
+        Pre-built block instances to chain, in execution order.
+    input_height : int
+        Height of the tensor entering the first block.
+    input_channels : int
+        Number of channels in the tensor entering the first block.
+
+    Attributes
+    ----------
+    in_channels : int
+        Number of input channels (taken from the first block).
+    out_channels : int
+        Number of output channels (taken from the last block).
+    layers : nn.Sequential
+        The wrapped sequence of block modules.
+    """
+
+    def __init__(
+        self,
+        layers: list[Block],
+        input_height: int,
+        input_channels: int,
+    ):
+        super().__init__()
+        self.in_channels = input_channels
+        self.out_channels = (
+            layers[-1].out_channels if layers else input_channels
+        )
+        self.layers = nn.Sequential(*layers)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Pass input through all blocks in sequence.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input feature map, shape ``(B, C_in, H, W)``.
+
+        Returns
+        -------
+        torch.Tensor
+            Output feature map after all blocks have been applied.
+        """
+        return self.layers(x)
+
+    def get_output_height(self, input_height: int) -> int:
+        """Compute the output height by propagating through all blocks.
+
+        Parameters
+        ----------
+        input_height : int
+            Height of the input feature map.
+
+        Returns
+        -------
+        int
+            Height after all blocks in the group have been applied.
+        """
+        for block in self.layers:
+            input_height = block.get_output_height(input_height)  # type: ignore
+        return input_height
+
+    @block_registry.register(LayerGroupConfig)
+    @staticmethod
+    def from_config(
+        config: LayerGroupConfig,
+        input_channels: int,
+        input_height: int,
+    ):
+        layers = []
+
+        for layer_config in config.layers:
+            layer = build_layer(
+                input_height=input_height,
+                in_channels=input_channels,
+                config=layer_config,
+            )
+            layers.append(layer)
+            input_height = layer.get_output_height(input_height)
+            input_channels = layer.out_channels
+
+        return LayerGroup(
+            layers=layers,
+            input_height=input_height,
+            input_channels=input_channels,
+        )
 
 
-def build_layer_from_config(
+def build_layer(
     input_height: int,
     in_channels: int,
     config: LayerConfig,
-) -> Tuple[nn.Module, int, int]:
-    """Factory function to build a specific nn.Module block from its config.
+) -> Block:
+    """Build a block from its configuration object.
 
-    Takes configuration object (one of the types included in the `LayerConfig`
-    union) and instantiates the corresponding nn.Module block with the correct
-    parameters derived from the config and the current pipeline state
-    (`input_height`, `in_channels`).
-
-    It uses the `name` field within the `config` object to determine
-    which block class to instantiate.
+    Looks up the block class corresponding to ``config.name`` in the
+    internal block registry and instantiates it with the given input
+    dimensions. This is the standard way to construct blocks when
+    assembling an encoder or decoder from a configuration file.
 
     Parameters
     ----------
     input_height : int
-        Height (frequency bins) of the input tensor *to this layer*.
+        Height (number of frequency bins) of the input tensor to this
+        block. Required for blocks whose kernel size depends on the input
+        height (e.g. ``VerticalConv``) and for coordinate-aware blocks.
     in_channels : int
-        Number of channels in the input tensor *to this layer*.
+        Number of channels in the input tensor to this block.
     config : LayerConfig
-        A Pydantic configuration object for the desired block (e.g., an
-        instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
-        by its `name` field.
+        A configuration object for the desired block type. The ``name``
+        field selects the block class; remaining fields supply its
+        parameters.
 
     Returns
     -------
-    Tuple[nn.Module, int, int]
-        A tuple containing:
-        - The instantiated `nn.Module` block.
-        - The number of output channels produced by the block.
-        - The calculated height of the output produced by the block.
+    Block
+        An initialised block module ready to be added to an
+        ``nn.Sequential`` or ``nn.ModuleList``.
 
     Raises
     ------
-    NotImplementedError
-        If the `config.name` does not correspond to a known block type.
+    KeyError
+        If ``config.name`` does not correspond to a registered block type.
     ValueError
-        If parameters derived from the config are invalid for the block.
+        If the configuration parameters are invalid for the chosen block.
     """
-    if config.name == "ConvBlock":
-        return (
-            ConvBlock(
-                in_channels=in_channels,
-                out_channels=config.out_channels,
-                kernel_size=config.kernel_size,
-                pad_size=config.pad_size,
-            ),
-            config.out_channels,
-            input_height,
-        )
-
-    if config.name == "FreqCoordConvDown":
-        return (
-            FreqCoordConvDownBlock(
-                in_channels=in_channels,
-                out_channels=config.out_channels,
-                input_height=input_height,
-                kernel_size=config.kernel_size,
-                pad_size=config.pad_size,
-            ),
-            config.out_channels,
-            input_height // 2,
-        )
-
-    if config.name == "StandardConvDown":
-        return (
-            StandardConvDownBlock(
-                in_channels=in_channels,
-                out_channels=config.out_channels,
-                kernel_size=config.kernel_size,
-                pad_size=config.pad_size,
-            ),
-            config.out_channels,
-            input_height // 2,
-        )
-
-    if config.name == "FreqCoordConvUp":
-        return (
-            FreqCoordConvUpBlock(
-                in_channels=in_channels,
-                out_channels=config.out_channels,
-                input_height=input_height,
-                kernel_size=config.kernel_size,
-                pad_size=config.pad_size,
-            ),
-            config.out_channels,
-            input_height * 2,
-        )
-
-    if config.name == "StandardConvUp":
-        return (
-            StandardConvUpBlock(
-                in_channels=in_channels,
-                out_channels=config.out_channels,
-                kernel_size=config.kernel_size,
-                pad_size=config.pad_size,
-            ),
-            config.out_channels,
-            input_height * 2,
-        )
-
-    if config.name == "SelfAttention":
-        return (
-            SelfAttention(
-                in_channels=in_channels,
-                attention_channels=config.attention_channels,
-                temperature=config.temperature,
-            ),
-            config.attention_channels,
-            input_height,
-        )
-
-    if config.name == "LayerGroup":
-        current_channels = in_channels
-        current_height = input_height
-
-        blocks = []
-
-        for block_config in config.layers:
-            block, current_channels, current_height = build_layer_from_config(
-                input_height=current_height,
-                in_channels=current_channels,
-                config=block_config,
-            )
-            blocks.append(block)
-
-        return nn.Sequential(*blocks), current_channels, current_height
-
-    raise NotImplementedError(f"Unknown block type {config.name}")
+    return block_registry.build(config, in_channels, input_height)
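The change above replaces the explicit `if config.name == ...` chain in `build_layer_from_config` with a single registry lookup. A minimal, dependency-free sketch of that dispatch pattern follows; the `Registry`, `ConvConfig`, and `make_conv` names here are illustrative stand-ins, not batdetect2's actual implementation:

```python
class Registry:
    """Map config classes to builder callables, keyed by the config's type."""

    def __init__(self):
        self._builders = {}

    def register(self, config_cls):
        """Decorator: associate a config class with a builder callable."""
        def decorator(builder):
            self._builders[config_cls] = builder
            return builder
        return decorator

    def build(self, config, in_channels, input_height):
        """Dispatch on the config's type to the registered builder."""
        try:
            builder = self._builders[type(config)]
        except KeyError:
            raise KeyError(f"No builder registered for {type(config).__name__}")
        return builder(config, in_channels, input_height)


registry = Registry()


class ConvConfig:
    def __init__(self, out_channels):
        self.out_channels = out_channels


@registry.register(ConvConfig)
def make_conv(config, in_channels, input_height):
    # A real builder would return an nn.Module; a dict keeps the sketch
    # free of torch dependencies.
    return {"in": in_channels, "out": config.out_channels, "height": input_height}


block = registry.build(ConvConfig(out_channels=64), in_channels=32, input_height=128)
# block == {"in": 32, "out": 64, "height": 128}
```

Registering one builder per config class keeps `build` closed for modification: supporting a new block type means adding one decorated builder rather than editing a central dispatch function.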
@@ -1,20 +1,24 @@
-"""Defines the Bottleneck component of an Encoder-Decoder architecture.
+"""Bottleneck component for encoder-decoder network architectures.
 
-This module provides the configuration (`BottleneckConfig`) and
-`torch.nn.Module` implementations (`Bottleneck`, `BottleneckAttn`) for the
-bottleneck layer(s) that typically connect the Encoder (downsampling path) and
-Decoder (upsampling path) in networks like U-Nets.
+The bottleneck sits between the encoder (downsampling path) and the decoder
+(upsampling path) and processes the lowest-resolution, highest-channel feature
+map produced by the encoder.
 
-The bottleneck processes the lowest-resolution, highest-dimensionality feature
-map produced by the Encoder. This module offers a configurable option to include
-a `SelfAttention` layer within the bottleneck, allowing the model to capture
-global temporal context before features are passed to the Decoder.
+This module provides:
 
-A factory function `build_bottleneck` constructs the appropriate bottleneck
-module based on the provided configuration.
+- ``BottleneckConfig`` – configuration dataclass describing the number of
+  internal channels and an optional sequence of additional layers (currently
+  only ``SelfAttention`` is supported).
+- ``Bottleneck`` – the ``torch.nn.Module`` implementation. It first applies a
+  ``VerticalConv`` to collapse the frequency axis to a single bin, optionally
+  runs one or more additional layers (e.g. self-attention along the time axis),
+  then repeats the output along the height dimension to restore the original
+  frequency resolution before passing features to the decoder.
+- ``build_bottleneck`` – factory function that constructs a ``Bottleneck``
+  instance from a ``BottleneckConfig`` and the encoder's output dimensions.
 """
 
-from typing import Annotated, List, Optional, Union
+from typing import Annotated, List
 
 import torch
 from pydantic import Field
@@ -22,10 +26,12 @@ from torch import nn
 
 from batdetect2.core.configs import BaseConfig
 from batdetect2.models.blocks import (
+    Block,
     SelfAttentionConfig,
     VerticalConv,
-    build_layer_from_config,
+    build_layer,
 )
+from batdetect2.models.types import BottleneckProtocol
 
 __all__ = [
     "BottleneckConfig",
@@ -34,43 +40,52 @@ __all__ = [
 ]
 
 
-class Bottleneck(nn.Module):
-    """Base Bottleneck module for Encoder-Decoder architectures.
+class Bottleneck(Block):
+    """Bottleneck module for encoder-decoder architectures.
 
-    This implementation represents the simplest bottleneck structure
-    considered, primarily consisting of a `VerticalConv` layer. This layer
-    collapses the frequency dimension (height) to 1, summarizing information
-    across frequencies at each time step. The output is then repeated along the
-    height dimension to match the original bottleneck input height before being
-    passed to the decoder.
+    Processes the lowest-resolution feature map that links the encoder and
+    decoder. The sequence of operations is:
 
-    This base version does *not* include self-attention.
+    1. ``VerticalConv`` – collapses the frequency axis (height) to a single
+       bin by applying a convolution whose kernel spans the full height.
+    2. Optional additional layers (e.g. ``SelfAttention``) – applied while
+       the feature map has height 1, so they operate purely along the time
+       axis.
+    3. Height restoration – the single-bin output is repeated along the
+       height axis to restore the original frequency resolution, producing
+       a tensor that the decoder can accept.
 
     Parameters
     ----------
     input_height : int
-        Height (frequency bins) of the input tensor. Must be positive.
+        Height (number of frequency bins) of the input tensor. Must be
+        positive.
     in_channels : int
         Number of channels in the input tensor from the encoder. Must be
         positive.
     out_channels : int
-        Number of output channels. Must be positive.
+        Number of output channels after the bottleneck. Must be positive.
+    bottleneck_channels : int, optional
+        Number of internal channels used by the ``VerticalConv`` layer.
+        Defaults to ``out_channels`` if not provided.
+    layers : List[torch.nn.Module], optional
+        Additional modules (e.g. ``SelfAttention``) to apply after the
+        ``VerticalConv`` and before height restoration.
 
     Attributes
     ----------
+    in_channels : int
+        Number of input channels accepted by the bottleneck.
+    out_channels : int
+        Number of output channels produced by the bottleneck.
     input_height : int
         Expected height of the input tensor.
-    channels : int
-        Number of output channels.
+    bottleneck_channels : int
+        Number of channels used internally by the vertical convolution.
     conv_vert : VerticalConv
         The vertical convolution layer.
-
-    Raises
-    ------
-    ValueError
-        If `input_height`, `in_channels`, or `out_channels` are not positive.
+    layers : nn.ModuleList
+        Additional layers applied after the vertical convolution.
     """
 
     def __init__(
@@ -78,14 +93,31 @@ class Bottleneck(nn.Module):
         input_height: int,
         in_channels: int,
         out_channels: int,
-        bottleneck_channels: Optional[int] = None,
-        layers: Optional[List[torch.nn.Module]] = None,
+        bottleneck_channels: int | None = None,
+        layers: List[torch.nn.Module] | None = None,
     ) -> None:
-        """Initialize the base Bottleneck layer."""
+        """Initialise the Bottleneck layer.
+
+        Parameters
+        ----------
+        input_height : int
+            Height (number of frequency bins) of the input tensor.
+        in_channels : int
+            Number of channels in the input tensor.
+        out_channels : int
+            Number of channels in the output tensor.
+        bottleneck_channels : int, optional
+            Number of internal channels for the ``VerticalConv``. Defaults
+            to ``out_channels``.
+        layers : List[torch.nn.Module], optional
+            Additional modules applied after the ``VerticalConv``, such as
+            a ``SelfAttention`` block.
+        """
         super().__init__()
+        self.in_channels = in_channels
         self.input_height = input_height
         self.out_channels = out_channels
 
         self.bottleneck_channels = (
             bottleneck_channels
             if bottleneck_channels is not None
@@ -100,23 +132,24 @@ class Bottleneck(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Process input features through the bottleneck.
+        """Process the encoder's bottleneck features.
 
-        Applies vertical convolution and repeats the output height.
+        Applies vertical convolution, optional additional layers, then
+        restores the height dimension by repetition.
 
         Parameters
        ----------
         x : torch.Tensor
-            Input tensor from the encoder bottleneck, shape
-            `(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
-            `H_in` must match `self.input_height`.
+            Input tensor from the encoder, shape
+            ``(B, C_in, H_in, W)``. ``C_in`` must match
+            ``self.in_channels`` and ``H_in`` must match
+            ``self.input_height``.
 
         Returns
         -------
         torch.Tensor
-            Output tensor, shape `(B, C_out, H_in, W)`. Note that the height
-            dimension `H_in` is restored via repetition after the vertical
-            convolution.
+            Output tensor with shape ``(B, C_out, H_in, W)``. The height
+            ``H_in`` is restored by repeating the single-bin result.
         """
         x = self.conv_vert(x)
 
@@ -127,37 +160,29 @@ class Bottleneck(nn.Module):
 
 
 BottleneckLayerConfig = Annotated[
-    Union[SelfAttentionConfig,],
+    SelfAttentionConfig,
     Field(discriminator="name"),
 ]
-"""Type alias for the discriminated union of block configs usable in Decoder."""
+"""Type alias for the discriminated union of block configs usable in the Bottleneck."""
 
 
 class BottleneckConfig(BaseConfig):
-    """Configuration for the bottleneck layer(s).
-
-    Defines the number of channels within the bottleneck and whether to include
-    a self-attention mechanism.
+    """Configuration for the bottleneck component.
 
     Attributes
     ----------
     channels : int
-        The number of output channels produced by the main convolutional layer
-        within the bottleneck. This often matches the number of channels coming
-        from the last encoder stage, but can be different. Must be positive.
-        This also defines the channel dimensions used within the optional
-        `SelfAttention` layer.
-    self_attention : bool
-        If True, includes a `SelfAttention` layer operating on the time
-        dimension after an initial `VerticalConv` layer within the bottleneck.
-        If False, only the initial `VerticalConv` (and height repetition) is
-        performed.
+        Number of output channels produced by the bottleneck. This value
+        is also used as the dimensionality of any optional layers (e.g.
+        self-attention). Must be positive.
+    layers : List[BottleneckLayerConfig]
+        Ordered list of additional block configurations to apply after the
+        initial ``VerticalConv``. Currently only ``SelfAttentionConfig`` is
+        supported. Defaults to an empty list (no extra layers).
     """
 
     channels: int
-    layers: List[BottleneckLayerConfig] = Field(
-        default_factory=list,
-    )
+    layers: List[BottleneckLayerConfig] = Field(default_factory=list)
 
 
 DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
@@ -171,32 +196,39 @@ DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
 def build_bottleneck(
     input_height: int,
     in_channels: int,
-    config: Optional[BottleneckConfig] = None,
-) -> nn.Module:
-    """Factory function to build the Bottleneck module from configuration.
+    config: BottleneckConfig | None = None,
+) -> BottleneckProtocol:
+    """Build a ``Bottleneck`` module from configuration.
 
-    Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
-    on the `config.self_attention` flag.
+    Constructs a ``Bottleneck`` instance whose internal channel count and
+    optional extra layers (e.g. self-attention) are controlled by
+    ``config``. If no configuration is provided, the default
+    ``DEFAULT_BOTTLENECK_CONFIG`` is used, which includes a
+    ``SelfAttention`` layer.
 
     Parameters
     ----------
     input_height : int
-        Height (frequency bins) of the input tensor. Must be positive.
+        Height (number of frequency bins) of the input tensor from the
+        encoder. Must be positive.
     in_channels : int
-        Number of channels in the input tensor. Must be positive.
+        Number of channels in the input tensor from the encoder. Must be
+        positive.
     config : BottleneckConfig, optional
-        Configuration object specifying the bottleneck channels and whether
-        to use self-attention. Uses `DEFAULT_BOTTLENECK_CONFIG` if None.
+        Configuration specifying the output channel count and any
+        additional layers. Uses ``DEFAULT_BOTTLENECK_CONFIG`` if ``None``.
 
     Returns
     -------
-    nn.Module
-        An initialized bottleneck module (`Bottleneck` or `BottleneckAttn`).
+    BottleneckProtocol
+        An initialised ``Bottleneck`` module.
 
     Raises
     ------
-    ValueError
-        If `input_height` or `in_channels` are not positive.
+    AssertionError
+        If any configured layer changes the height of the feature map
+        (bottleneck layers must preserve height so that it can be restored
+        by repetition).
     """
     config = config or DEFAULT_BOTTLENECK_CONFIG
 
@@ -206,11 +238,13 @@ def build_bottleneck(
     layers = []
 
     for layer_config in config.layers:
-        layer, current_channels, current_height = build_layer_from_config(
+        layer = build_layer(
             input_height=current_height,
             in_channels=current_channels,
             config=layer_config,
         )
+        current_height = layer.get_output_height(current_height)
+        current_channels = layer.out_channels
         assert current_height == input_height, (
             "Bottleneck layers should not change the spectrogram height"
        )
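The rewritten `Bottleneck` docstring above describes three steps: collapse the frequency axis to one bin, optionally transform along time, then repeat the result to restore the height. A toy sketch of that shape contract, using a plain column mean where the real module uses a learned `VerticalConv` (pure Python, no torch; `bottleneck_sketch` is an illustrative name, not part of batdetect2):

```python
import statistics


def bottleneck_sketch(feature_map):
    """Collapse the height axis to one summary row, then repeat it back out.

    `feature_map` is a list of H rows, each a list of W time steps.
    """
    height = len(feature_map)
    width = len(feature_map[0])
    # Step 1: collapse height, producing one summary value per time step
    # (stand-in for VerticalConv, whose kernel spans the full height).
    summary = [
        statistics.fmean(feature_map[h][w] for h in range(height))
        for w in range(width)
    ]
    # Step 2: optional layers such as SelfAttention would run here, while
    # the map has height 1 and so operate purely along the time axis.
    # Step 3: repeat the summary row to restore the original height.
    return [list(summary) for _ in range(height)]


restored = bottleneck_sketch([[1.0, 2.0], [3.0, 4.0]])
# Height is restored: two identical rows holding the column means [2.0, 3.0].
```

This is also why `build_bottleneck` asserts that configured layers preserve height: the repetition in step 3 only reproduces `input_height` if the intermediate layers leave the single-bin height untouched.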
@@ -1,98 +0,0 @@
-from typing import Optional
-
-from soundevent import data
-
-from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.models.bottleneck import (
-    DEFAULT_BOTTLENECK_CONFIG,
-    BottleneckConfig,
-)
-from batdetect2.models.decoder import (
-    DEFAULT_DECODER_CONFIG,
-    DecoderConfig,
-)
-from batdetect2.models.encoder import (
-    DEFAULT_ENCODER_CONFIG,
-    EncoderConfig,
-)
-
-__all__ = [
-    "BackboneConfig",
-    "load_backbone_config",
-]
-
-
-class BackboneConfig(BaseConfig):
-    """Configuration for the Encoder-Decoder Backbone network.
-
-    Aggregates configurations for the encoder, bottleneck, and decoder
-    components, along with defining the input and final output dimensions
-    for the complete backbone.
-
-    Attributes
-    ----------
-    input_height : int, default=128
-        Expected height (frequency bins) of the input spectrograms to the
-        backbone. Must be positive.
-    in_channels : int, default=1
-        Expected number of channels in the input spectrograms (e.g., 1 for
-        mono). Must be positive.
-    encoder : EncoderConfig, optional
-        Configuration for the encoder. If None or omitted,
-        the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
-        encoder module) will be used.
-    bottleneck : BottleneckConfig, optional
-        Configuration for the bottleneck layer connecting encoder and decoder.
-        If None or omitted, the default bottleneck configuration will be used.
-    decoder : DecoderConfig, optional
-        Configuration for the decoder. If None or omitted,
-        the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
-        decoder module) will be used.
-    out_channels : int, default=32
-        Desired number of channels in the final feature map output by the
-        backbone. Must be positive.
-    """
-
-    input_height: int = 128
-    in_channels: int = 1
-    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
-    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
-    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
-    out_channels: int = 32
-
-
-def load_backbone_config(
-    path: data.PathLike,
-    field: Optional[str] = None,
-) -> BackboneConfig:
-    """Load the backbone configuration from a file.
-
-    Reads a configuration file (YAML) and validates it against the
-    `BackboneConfig` schema, potentially extracting data from a nested field.
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the configuration file.
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing the
|
||||
backbone configuration (e.g., "model.backbone"). If None, the entire
|
||||
file content is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
BackboneConfig
|
||||
The loaded and validated backbone configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
yaml.YAMLError
|
||||
If the file content is not valid YAML.
|
||||
pydantic.ValidationError
|
||||
If the loaded config data does not conform to `BackboneConfig`.
|
||||
KeyError, TypeError
|
||||
If `field` specifies an invalid path.
|
||||
"""
|
||||
return load_config(path, schema=BackboneConfig, field=field)
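For context, the loader removed above read plain YAML and could pull the backbone section out of a larger file via ``field`` (e.g. ``field="model.backbone"``). A sketch of such a file, using the field names from the ``BackboneConfig`` attributes above; the values shown are just the documented defaults, not a tested configuration:

```yaml
model:
  backbone:
    input_height: 128
    in_channels: 1
    out_channels: 32
```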
@@ -1,24 +1,24 @@
"""Constructs the Decoder part of an Encoder-Decoder neural network.
"""Decoder (upsampling path) for the BatDetect2 backbone.

This module defines the configuration structure (`DecoderConfig`) for the layer
sequence and provides the `Decoder` class (an `nn.Module`) along with a factory
function (`build_decoder`). Decoders typically form the upsampling path in
architectures like U-Nets, taking bottleneck features
(usually from an `Encoder`) and skip connections to reconstruct
higher-resolution feature maps.
This module defines ``DecoderConfig`` and the ``Decoder`` ``nn.Module``,
together with the ``build_decoder`` factory function.

The decoder is built dynamically by stacking neural network blocks based on a
list of configuration objects provided in `DecoderConfig.layers`. Each config
object specifies the type of block (e.g., standard convolution,
coordinate-feature convolution with upsampling) and its parameters. This allows
flexible definition of decoder architectures via configuration files.
In a U-Net-style network the decoder progressively restores the spatial
resolution of the feature map back towards the input resolution. At each
stage it combines the upsampled features with the corresponding skip-connection
tensor from the encoder (the residual) by element-wise addition before passing
the result to the upsampling block.

The `Decoder`'s `forward` method is designed to accept skip connection tensors
(`residuals`) from the encoder, merging them with the upsampled feature maps
at each stage.
The decoder is fully configurable: the type, number, and parameters of the
upsampling blocks are described by a ``DecoderConfig`` object containing an
ordered list of block configuration objects (see ``batdetect2.models.blocks``
for available block types).

A default configuration ``DEFAULT_DECODER_CONFIG`` is provided and used by
``build_decoder`` when no explicit configuration is supplied.
"""

from typing import Annotated, List, Optional, Union
from typing import Annotated, List

import torch
from pydantic import Field
@@ -30,7 +30,7 @@ from batdetect2.models.blocks import (
FreqCoordConvUpConfig,
LayerGroupConfig,
StandardConvUpConfig,
build_layer_from_config,
build_layer,
)

__all__ = [
@@ -41,63 +41,57 @@ __all__ = [
]

DecoderLayerConfig = Annotated[
Union[
ConvConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
LayerGroupConfig,
],
ConvConfig
| FreqCoordConvUpConfig
| StandardConvUpConfig
| LayerGroupConfig,
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
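The ``Field(discriminator="name")`` annotation tells the validation layer to pick the concrete config class by inspecting each entry's ``name`` field before validating the rest. The mechanism can be sketched without pydantic as a plain dispatch table; the class and block names below are illustrative stand-ins, not the real BatDetect2 configs:

```python
from dataclasses import dataclass


@dataclass
class ConvConfig:
    name: str = "ConvBlock"
    out_channels: int = 32


@dataclass
class StandardConvUpConfig:
    name: str = "StandardConvUp"
    out_channels: int = 32


# Dispatch on the "name" key, as a discriminated union does during parsing.
REGISTRY = {"ConvBlock": ConvConfig, "StandardConvUp": StandardConvUpConfig}


def parse_layer(raw: dict):
    cls = REGISTRY[raw["name"]]  # an unknown block name raises KeyError
    return cls(**raw)


layer = parse_layer({"name": "StandardConvUp", "out_channels": 64})
print(type(layer).__name__, layer.out_channels)  # -> StandardConvUpConfig 64
```

This also matches the ``Raises`` sections below: an unknown block type surfaces as a ``KeyError``.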


class DecoderConfig(BaseConfig):
"""Configuration for the sequence of layers in the Decoder module.

Defines the types and parameters of the neural network blocks that
constitute the decoder's upsampling path.
"""Configuration for the sequential ``Decoder`` module.

Attributes
----------
layers : List[DecoderLayerConfig]
An ordered list of configuration objects, each defining one layer or
block in the decoder sequence. Each item must be a valid block
config including a `name` field and necessary parameters like
`out_channels`. Input channels for each layer are inferred sequentially.
The list must contain at least one layer.
Ordered list of block configuration objects defining the decoder's
upsampling stages (from deepest to shallowest). Each entry
specifies the block type (via its ``name`` field) and any
block-specific parameters such as ``out_channels``. Input channels
for each block are inferred automatically from the output of the
previous block. Must contain at least one entry.
"""

layers: List[DecoderLayerConfig] = Field(min_length=1)


class Decoder(nn.Module):
"""Sequential Decoder module composed of configurable upsampling layers.
"""Sequential decoder module composed of configurable upsampling layers.

Constructs the upsampling path of an encoder-decoder network by stacking
multiple blocks (e.g., `StandardConvUpBlock`, `FreqCoordConvUpBlock`)
based on a list of layer modules provided during initialization (typically
created by the `build_decoder` factory function).
Executes a series of upsampling blocks in order, adding the
corresponding encoder skip-connection tensor (residual) to the feature
map before each block. The residuals are consumed in reverse order (from
deepest encoder layer to shallowest) to match the spatial resolutions at
each decoder stage.

The `forward` method is designed to integrate skip connection tensors
(`residuals`) from the corresponding encoder stages, by adding them
element-wise to the input of each decoder layer before processing.
Instances are typically created by ``build_decoder``.

Attributes
----------
in_channels : int
Number of channels expected in the input tensor.
Number of channels expected in the input tensor (bottleneck output).
out_channels : int
Number of channels in the final output tensor produced by the last
layer.
Number of channels in the final output feature map.
input_height : int
Height (frequency bins) expected in the input tensor.
Height (frequency bins) of the input tensor.
output_height : int
Height (frequency bins) expected in the output tensor.
Height (frequency bins) of the output tensor.
layers : nn.ModuleList
The sequence of instantiated upscaling layer modules.
Sequence of instantiated upsampling block modules.
depth : int
The number of upscaling layers (depth) in the decoder.
Number of upsampling layers.
"""

def __init__(
@@ -108,23 +102,24 @@ class Decoder(nn.Module):
output_height: int,
layers: List[nn.Module],
):
"""Initialize the Decoder module.
"""Initialise the Decoder module.

Note: This constructor is typically called internally by the
`build_decoder` factory function.
This constructor is typically called by the ``build_decoder``
factory function.

Parameters
----------
in_channels : int
Number of channels in the input tensor (bottleneck output).
out_channels : int
Number of channels produced by the final layer.
input_height : int
Expected height of the input tensor (bottleneck).
in_channels : int
Expected number of channels in the input tensor (bottleneck).
Height of the input tensor (bottleneck output height).
output_height : int
Height of the output tensor after all layers have been applied.
layers : List[nn.Module]
A list of pre-instantiated upscaling layer modules (e.g.,
`StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired
sequence (from bottleneck towards output resolution).
Pre-built upsampling block modules in execution order (deepest
stage first).
"""
super().__init__()

@@ -142,43 +137,35 @@ class Decoder(nn.Module):
x: torch.Tensor,
residuals: List[torch.Tensor],
) -> torch.Tensor:
"""Pass input through decoder layers, incorporating skip connections.
"""Pass input through all decoder layers, incorporating skip connections.

Processes the input tensor `x` sequentially through the upscaling
layers. At each stage, the corresponding skip connection tensor from
the `residuals` list is added element-wise to the input before passing
it to the upscaling block.
At each stage the corresponding residual tensor is added
element-wise to ``x`` before it is passed to the upsampling block.
Residuals are consumed in reverse order: the last element of
``residuals`` (the output of the deepest encoder layer) is added
at the first decoder stage, and the first element (output of the
shallowest encoder layer) is added at the last decoder stage.

Parameters
----------
x : torch.Tensor
Input tensor from the previous stage (e.g., encoder bottleneck).
Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
`self.in_channels`.
Bottleneck feature map, shape ``(B, C_in, H_in, W)``.
residuals : List[torch.Tensor]
List containing the skip connection tensors from the corresponding
encoder stages. Should be ordered from the deepest encoder layer
output (lowest resolution) to the shallowest (highest resolution
near input). The number of tensors in this list must match the
number of decoder layers (`self.depth`). Each residual tensor's
channel count must be compatible with the input tensor `x` for
element-wise addition (or concatenation if the blocks were designed
for it).
Skip-connection tensors from the encoder, ordered from shallowest
(index 0) to deepest (index -1). Must contain exactly
``self.depth`` tensors. Each tensor must have the same spatial
dimensions and channel count as ``x`` at the corresponding
decoder stage.

Returns
-------
torch.Tensor
The final decoded feature map tensor produced by the last layer.
Shape `(B, C_out, H_out, W_out)`.
Decoded feature map, shape ``(B, C_out, H_out, W)``.

Raises
------
ValueError
If the number of `residuals` provided does not match the decoder
depth.
RuntimeError
If shapes mismatch during skip connection addition or layer
processing.
If the number of ``residuals`` does not equal ``self.depth``.
"""
if len(residuals) != len(self.layers):
raise ValueError(
@@ -187,7 +174,7 @@ class Decoder(nn.Module):
f"but got {len(residuals)}."
)

for layer, res in zip(self.layers, residuals[::-1]):
for layer, res in zip(self.layers, residuals[::-1], strict=False):
x = layer(x + res)

return x
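The reverse-order pairing of decoder layers and encoder residuals can be illustrated with plain Python values standing in for tensors and blocks; this is a sketch of the iteration order only, not the real module:

```python
# Each "layer" doubles its input; "tensors" are plain integers here.
layers = [lambda v: v * 2, lambda v: v * 2, lambda v: v * 2]

# Residuals as produced by the encoder: shallowest first, deepest last.
residuals = [1, 10, 100]

x = 0
consumed = []
for layer, res in zip(layers, residuals[::-1]):
    consumed.append(res)  # the deepest encoder output is used first
    x = layer(x + res)

print(consumed)  # -> [100, 10, 1]
```

Reversing `residuals` pairs the deepest (lowest-resolution) encoder output with the first decoder stage, which is the only stage whose spatial size matches it.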
@@ -205,53 +192,55 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
),
],
)
"""A default configuration for the Decoder's *layer sequence*.
"""Default decoder configuration used in standard BatDetect2 models.

Specifies an architecture often used in BatDetect2, consisting of three
frequency coordinate-aware upsampling blocks followed by a standard
convolutional block.
Mirrors ``DEFAULT_ENCODER_CONFIG`` in reverse. Assumes the bottleneck
output has 256 channels and height 16, and produces:

- Stage 1 (``FreqCoordConvUp``): 64 channels, height 32.
- Stage 2 (``FreqCoordConvUp``): 32 channels, height 64.
- Stage 3 (``LayerGroup``):

- ``FreqCoordConvUp``: 32 channels, height 128.
- ``ConvBlock``: 32 channels, height 128 (final feature map).
"""


def build_decoder(
in_channels: int,
input_height: int,
config: Optional[DecoderConfig] = None,
config: DecoderConfig | None = None,
) -> Decoder:
"""Factory function to build a Decoder instance from configuration.
"""Build a ``Decoder`` from configuration.

Constructs a sequential `Decoder` module based on the layer sequence
defined in a `DecoderConfig` object and the provided input dimensions
(bottleneck channels and height). If no config is provided, uses the
default layer sequence from `DEFAULT_DECODER_CONFIG`.

It iteratively builds the layers using the unified `build_layer_from_config`
factory (from `.blocks`), tracking the changing number of channels and
feature map height required for each subsequent layer.
Constructs a sequential ``Decoder`` by iterating over the block
configurations in ``config.layers``, building each block with
``build_layer``, and tracking the channel count and feature-map height
as they change through the sequence.

Parameters
----------
in_channels : int
The number of channels in the input tensor to the decoder. Must be > 0.
Number of channels in the input tensor (bottleneck output). Must
be positive.
input_height : int
The height (frequency bins) of the input tensor to the decoder. Must be
> 0.
Height (number of frequency bins) of the input tensor. Must be
positive.
config : DecoderConfig, optional
The configuration object detailing the sequence of layers and their
parameters. If None, `DEFAULT_DECODER_CONFIG` is used.
Configuration specifying the layer sequence. Defaults to
``DEFAULT_DECODER_CONFIG`` if not provided.

Returns
-------
Decoder
An initialized `Decoder` module.
An initialised ``Decoder`` module.

Raises
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `name`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `name`.
If ``in_channels`` or ``input_height`` are not positive.
KeyError
If a layer configuration specifies an unknown block type.
"""
config = config or DEFAULT_DECODER_CONFIG

@@ -261,11 +250,13 @@ def build_decoder(
layers = []

for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
layer = build_layer(
in_channels=current_channels,
input_height=current_height,
config=layer_config,
)
current_height = layer.get_output_height(current_height)
current_channels = layer.out_channels
layers.append(layer)

return Decoder(

@@ -1,27 +1,32 @@
"""Assembles the complete BatDetect2 Detection Model.
"""Assembles the complete BatDetect2 detection model.

This module defines the concrete `Detector` class, which implements the
`DetectionModel` interface defined in `.types`. It combines a feature
extraction backbone with specific prediction heads to create the end-to-end
neural network used for detecting bat calls, predicting their size, and
classifying them.
This module defines the ``Detector`` class, which combines a backbone
feature extractor with prediction heads for detection, classification, and
bounding-box size regression.

The primary components are:
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
Components
----------
- ``Detector`` – the ``torch.nn.Module`` that wires together a backbone
(``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
produce a ``ModelOutput`` tuple from an input spectrogram.
- ``build_detector`` – factory function that builds a ready-to-use
``Detector`` from a backbone configuration and a target class count.

This module focuses purely on the neural network architecture definition. The
logic for preprocessing inputs and postprocessing/decoding outputs resides in
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
Note that ``Detector`` operates purely on spectrogram tensors; raw audio
preprocessing and output postprocessing are handled by
``batdetect2.preprocess`` and ``batdetect2.postprocess`` respectively.
"""

from typing import Optional

import torch
from loguru import logger

from batdetect2.models.backbones import BackboneConfig, build_backbone
from batdetect2.models.backbones import (
BackboneConfig,
UNetBackboneConfig,
build_backbone,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput

__all__ = [
"Detector",
@@ -30,25 +35,30 @@ __all__ = [


class Detector(DetectionModel):
"""Concrete implementation of the BatDetect2 Detection Model.
"""Complete BatDetect2 detection and classification model.

Assembles a complete detection and classification model by combining a
feature extraction backbone network with specific prediction heads for
detection probability, bounding box size regression, and class
probabilities.
Combines a backbone feature extractor with two prediction heads:

- ``ClassifierHead``: predicts per-class probabilities at each
time–frequency location.
- ``BBoxHead``: predicts call duration and bandwidth at each location.

The detection probability map is derived from the class probabilities by
summing across the class dimension (i.e. the probability that *any* class
is present), rather than from a separate detection head.

Instances are typically created via ``build_detector``.

Attributes
----------
backbone : BackboneModel
The feature extraction backbone network module.
The feature extraction backbone.
num_classes : int
The number of specific target classes the model predicts (derived from
the `classifier_head`).
Number of target classes (inferred from the classifier head).
classifier_head : ClassifierHead
The prediction head responsible for generating class probabilities.
Produces per-class probability maps from backbone features.
bbox_head : BBoxHead
The prediction head responsible for generating bounding box size
predictions.
Produces duration and bandwidth predictions from backbone features.
"""

backbone: BackboneModel
@@ -59,26 +69,21 @@ class Detector(DetectionModel):
classifier_head: ClassifierHead,
bbox_head: BBoxHead,
):
"""Initialize the Detector model.
"""Initialise the Detector model.

Note: Instances are typically created using the `build_detector`
This constructor is typically called by the ``build_detector``
factory function.

Parameters
----------
backbone : BackboneModel
An initialized feature extraction backbone module (e.g., built by
`build_backbone` from the `.backbone` module).
An initialised backbone module (e.g. built by
``build_backbone``).
classifier_head : ClassifierHead
An initialized classification head module. The number of classes
is inferred from this head.
An initialised classification head. The ``num_classes``
attribute is read from this head.
bbox_head : BBoxHead
An initialized bounding box size prediction head module.

Raises
------
TypeError
If the provided modules are not of the expected types.
An initialised bounding-box size prediction head.
"""
super().__init__()

@@ -88,31 +93,34 @@ class Detector(DetectionModel):
self.bbox_head = bbox_head

def forward(self, spec: torch.Tensor) -> ModelOutput:
"""Perform the forward pass of the complete detection model.
"""Run the complete detection model on an input spectrogram.

Processes the input spectrogram through the backbone to extract
features, then passes these features through the separate prediction
heads to generate detection probabilities, class probabilities, and
size predictions.
Passes the spectrogram through the backbone to produce a feature
map, then applies the classifier and bounding-box heads. The
detection probability map is derived by summing the per-class
probability maps across the class dimension; no separate detection
head is used.

Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, typically with shape
`(batch_size, input_channels, frequency_bins, time_bins)`. The
shape must be compatible with the `self.backbone` input
requirements.
Input spectrogram tensor, shape
``(batch_size, channels, frequency_bins, time_bins)``.

Returns
-------
ModelOutput
A NamedTuple containing the four output tensors:
- `detection_probs`: Detection probability heatmap `(B, 1, H, W)`.
- `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`.
- `class_probs`: Class probabilities (excluding background)
`(B, num_classes, H, W)`.
- `features`: Output feature map from the backbone
`(B, C_out, H, W)`.
A named tuple with four fields:

- ``detection_probs`` – ``(B, 1, H, W)`` – probability that a
call of any class is present at each location. Derived by
summing ``class_probs`` over the class dimension.
- ``size_preds`` – ``(B, 2, H, W)`` – scaled duration (channel
0) and bandwidth (channel 1) predictions at each location.
- ``class_probs`` – ``(B, num_classes, H, W)`` – per-class
probabilities at each location.
- ``features`` – ``(B, C_out, H, W)`` – raw backbone feature
map.
"""
features = self.backbone(spec)
classification = self.classifier_head(features)
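The derivation of the detection map from the class probabilities, summing over the class dimension, can be sketched with plain Python lists standing in for tensors; the real model performs the equivalent tensor sum over the class axis:

```python
# class_probs[c][i]: probability of class c at spatial cell i
# (3 classes, 4 cells; values are illustrative).
class_probs = [
    [0.1, 0.0, 0.2, 0.0],
    [0.2, 0.1, 0.0, 0.0],
    [0.0, 0.3, 0.1, 0.0],
]

# Detection probability at each cell: probability that *any* class is
# present, obtained by summing over the class dimension.
detection_probs = [round(sum(col), 6) for col in zip(*class_probs)]
print(detection_probs)  # -> [0.3, 0.4, 0.3, 0.0]
```

Summing is a sensible "any class present" score when the per-class probabilities are mutually exclusive, which is what dispensing with a separate detection head relies on.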
|
||||
@ -127,40 +135,46 @@ class Detector(DetectionModel):
|
||||
|
||||
|
||||
def build_detector(
|
||||
num_classes: int, config: Optional[BackboneConfig] = None
|
||||
num_classes: int,
|
||||
config: BackboneConfig | None = None,
|
||||
backbone: BackboneModel | None = None,
|
||||
) -> DetectionModel:
|
||||
"""Build the complete BatDetect2 detection model.
|
||||
"""Build a complete BatDetect2 detection model.
|
||||
|
||||
Constructs a backbone from ``config``, attaches a ``ClassifierHead``
|
||||
and a ``BBoxHead`` sized to the backbone's output channel count, and
|
||||
returns them wrapped in a ``Detector``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_classes : int
|
||||
The number of specific target classes the model should predict
|
||||
(required for the `ClassifierHead`). Must be positive.
|
||||
Number of target bat species or call types to predict. Must be
|
||||
positive.
|
||||
config : BackboneConfig, optional
|
||||
Configuration object specifying the architecture of the backbone
|
||||
(encoder, bottleneck, decoder). If None, default configurations defined
|
||||
within the respective builder functions (`build_encoder`, etc.) will be
|
||||
used to construct a default backbone architecture.
|
||||
Backbone architecture configuration. Defaults to
|
||||
``UNetBackboneConfig()`` (the standard BatDetect2 architecture) if
|
||||
not provided.
|
||||
|
||||
Returns
|
||||
-------
|
||||
DetectionModel
|
||||
An initialized `Detector` model instance.
|
||||
An initialised ``Detector`` instance ready for training or
|
||||
inference.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `num_classes` is not positive, or if errors occur during the
|
||||
construction of the backbone or detector components (e.g., incompatible
|
||||
configurations, invalid parameters).
|
||||
If ``num_classes`` is not positive, or if the backbone
|
||||
configuration is invalid.
|
||||
"""
|
||||
config = config or BackboneConfig()
|
||||
if backbone is None:
|
||||
config = config or UNetBackboneConfig()
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building model with config: \n{}",
|
||||
lambda: config.to_yaml_string(), # type: ignore
|
||||
)
|
||||
backbone = build_backbone(config=config)
|
||||
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building model with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
)
|
||||
backbone = build_backbone(config=config)
|
||||
classifier_head = ClassifierHead(
|
||||
num_classes=num_classes,
|
||||
in_channels=backbone.out_channels,
|
||||
|
||||
@ -1,26 +1,27 @@
|
||||
"""Constructs the Encoder part of a configurable neural network backbone.
|
||||
"""Encoder (downsampling path) for the BatDetect2 backbone.
|
||||
|
||||
This module defines the configuration structure (`EncoderConfig`) and provides
|
||||
the `Encoder` class (an `nn.Module`) along with a factory function
|
||||
(`build_encoder`) to create sequential encoders. Encoders typically form the
|
||||
downsampling path in architectures like U-Nets, processing input feature maps
|
||||
(like spectrograms) to produce lower-resolution, higher-dimensionality feature
|
||||
representations (bottleneck features).
|
||||
This module defines ``EncoderConfig`` and the ``Encoder`` ``nn.Module``,
|
||||
together with the ``build_encoder`` factory function.
|
||||
|
||||
The encoder is built dynamically by stacking neural network blocks based on a
|
||||
list of configuration objects provided in `EncoderConfig.layers`. Each
|
||||
configuration object specifies the type of block (e.g., standard convolution,
|
||||
coordinate-feature convolution with downsampling) and its parameters
|
||||
(e.g., output channels). This allows for flexible definition of encoder
|
||||
architectures via configuration files.
|
||||
In a U-Net-style network the encoder progressively reduces the spatial
|
||||
resolution of the spectrogram whilst increasing the number of feature
|
||||
channels. Each layer in the encoder produces a feature map that is stored
|
||||
for use as a skip connection in the corresponding decoder layer.
|
||||
|
||||
The `Encoder`'s `forward` method returns outputs from all intermediate layers,
|
||||
suitable for skip connections, while the `encode` method returns only the final
|
||||
bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
|
||||
provided.
|
||||
The encoder is fully configurable: the type, number, and parameters of the
|
||||
downsampling blocks are described by an ``EncoderConfig`` object containing
|
||||
an ordered list of block configuration objects (see ``batdetect2.models.blocks``
|
||||
for available block types).
|
||||
|
||||
``Encoder.forward`` returns the outputs of *all* encoder layers as a list,
|
||||
so that skip connections are available to the decoder.
|
||||
``Encoder.encode`` returns only the final output (the input to the bottleneck).
|
||||
|
||||
A default configuration ``DEFAULT_ENCODER_CONFIG`` is provided and used by
|
||||
``build_encoder`` when no explicit configuration is supplied.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List, Optional, Union
|
||||
from typing import Annotated, List
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
@ -32,7 +33,7 @@ from batdetect2.models.blocks import (
|
||||
FreqCoordConvDownConfig,
|
||||
LayerGroupConfig,
|
||||
StandardConvDownConfig,
|
||||
build_layer_from_config,
|
||||
build_layer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -43,47 +44,42 @@ __all__ = [
|
||||
]
|
||||
|
||||
EncoderLayerConfig = Annotated[
|
||||
Union[
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
StandardConvDownConfig,
|
||||
LayerGroupConfig,
|
||||
],
|
||||
ConvConfig
|
||||
| FreqCoordConvDownConfig
|
||||
| StandardConvDownConfig
|
||||
| LayerGroupConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
||||
|
||||
|
||||
class EncoderConfig(BaseConfig):
    """Configuration for building the sequential Encoder module.

    Defines the sequence of neural network blocks that constitute the encoder
    (downsampling path).
    """Configuration for the sequential ``Encoder`` module.

    Attributes
    ----------
    layers : List[EncoderLayerConfig]
        An ordered list of configuration objects, each defining one layer or
        block in the encoder sequence. Each item must be a valid block config
        (e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
        `StandardConvDownConfig`) including a `name` field and necessary
        parameters like `out_channels`. Input channels for each layer are
        inferred sequentially. The list must contain at least one layer.
        Ordered list of block configuration objects defining the encoder's
        downsampling stages. Each entry specifies the block type (via its
        ``name`` field) and any block-specific parameters such as
        ``out_channels``. Input channels for each block are inferred
        automatically from the output of the previous block. Must contain
        at least one entry.
    """

    layers: List[EncoderLayerConfig] = Field(min_length=1)

class Encoder(nn.Module):
    """Sequential Encoder module composed of configurable downscaling layers.
    """Sequential encoder module composed of configurable downsampling layers.

    Constructs the downsampling path of an encoder-decoder network by stacking
    multiple downscaling blocks.
    Executes a series of downsampling blocks in order, storing the output of
    each block so that it can be passed as a skip connection to the
    corresponding decoder layer.

    The `forward` method executes the sequence and returns the output feature
    map from *each* downscaling stage, facilitating the implementation of skip
    connections in U-Net-like architectures. The `encode` method returns only
    the final output tensor (bottleneck features).
    ``forward`` returns the outputs of *all* layers (useful when skip
    connections are needed). ``encode`` returns only the final output
    (the input to the bottleneck).

    Attributes
    ----------
@@ -91,14 +87,14 @@ class Encoder(nn.Module):
        Number of channels expected in the input tensor.
    input_height : int
        Height (frequency bins) expected in the input tensor.
    output_channels : int
        Number of channels in the final output tensor (bottleneck).
    out_channels : int
        Number of channels in the final output tensor (bottleneck input).
    output_height : int
        Height (frequency bins) expected in the output tensor.
        Height (frequency bins) of the final output tensor.
    layers : nn.ModuleList
        The sequence of instantiated downscaling layer modules.
        Sequence of instantiated downsampling block modules.
    depth : int
        The number of downscaling layers in the encoder.
        Number of downsampling layers.
    """

    def __init__(
@@ -109,23 +105,22 @@ class Encoder(nn.Module):
        input_height: int = 128,
        in_channels: int = 1,
    ):
        """Initialize the Encoder module.
        """Initialise the Encoder module.

        Note: This constructor is typically called internally by the
        `build_encoder` factory function, which prepares the `layers` list.
        This constructor is typically called by the ``build_encoder`` factory
        function, which takes care of building the ``layers`` list from a
        configuration object.

        Parameters
        ----------
        output_channels : int
            Number of channels produced by the final layer.
        output_height : int
            The expected height of the output tensor.
            Height of the output tensor after all layers have been applied.
        layers : List[nn.Module]
            A list of pre-instantiated downscaling layer modules (e.g.,
            `StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
            sequence.
            Pre-built downsampling block modules in execution order.
        input_height : int, default=128
            Expected height of the input tensor.
            Expected height of the input tensor (frequency bins).
        in_channels : int, default=1
            Expected number of channels in the input tensor.
        """
@@ -140,29 +135,30 @@ class Encoder(nn.Module):
        self.depth = len(self.layers)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Pass input through encoder layers, returns all intermediate outputs.
        """Pass input through all encoder layers and return every output.

        This method is typically used when the Encoder is part of a U-Net or
        similar architecture requiring skip connections.
        Used when skip connections are needed (e.g. in a U-Net decoder).

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
            `self.in_channels` and `H_in` must match `self.input_height`.
            Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
            ``C_in`` must match ``self.in_channels`` and ``H_in`` must
            match ``self.input_height``.

        Returns
        -------
        List[torch.Tensor]
            A list containing the output tensors from *each* downscaling layer
            in the sequence. `outputs[0]` is the output of the first layer,
            `outputs[-1]` is the final output (bottleneck) of the encoder.
            Output tensors from every layer in order.
            ``outputs[0]`` is the output of the first (shallowest) layer;
            ``outputs[-1]`` is the output of the last (deepest) layer,
            which serves as the input to the bottleneck.

        Raises
        ------
        ValueError
            If input tensor channel count or height does not match expected
            values.
            If the input channel count or height does not match the
            expected values.
        """
        if x.shape[1] != self.in_channels:
            raise ValueError(
@@ -185,28 +181,29 @@ class Encoder(nn.Module):
        return outputs

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Pass input through encoder layers, returning only the final output.
        """Pass input through all encoder layers and return only the final output.

        This method provides access to the bottleneck features produced after
        the last downscaling layer.
        Use this when skip connections are not needed and you only require
        the bottleneck feature map.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, shape `(B, C_in, H_in, W)`. Must match expected
            `in_channels` and `input_height`.
            Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
            Must satisfy the same shape requirements as ``forward``.

        Returns
        -------
        torch.Tensor
            The final output tensor (bottleneck features) from the last layer
            of the encoder. Shape `(B, C_out, H_out, W_out)`.
            Output of the last encoder layer, shape
            ``(B, C_out, H_out, W)``, where ``C_out`` is
            ``self.out_channels`` and ``H_out`` is ``self.output_height``.

        Raises
        ------
        ValueError
            If input tensor channel count or height does not match expected
            values.
            If the input channel count or height does not match the
            expected values.
        """
        if x.shape[1] != self.in_channels:
            raise ValueError(
@@ -238,58 +235,57 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
        ),
    ],
)
"""Default configuration for the Encoder.
"""Default encoder configuration used in standard BatDetect2 models.

Specifies an architecture typically used in BatDetect2:
- Input: 1 channel, 128 frequency bins.
- Layer 1: FreqCoordConvDown -> 32 channels, H=64
- Layer 2: FreqCoordConvDown -> 64 channels, H=32
- Layer 3: FreqCoordConvDown -> 128 channels, H=16
- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck)
Assumes a 1-channel input with 128 frequency bins and produces the
following feature maps:

- Stage 1 (``FreqCoordConvDown``): 32 channels, height 64.
- Stage 2 (``FreqCoordConvDown``): 64 channels, height 32.
- Stage 3 (``LayerGroup``):

  - ``FreqCoordConvDown``: 128 channels, height 16.
  - ``ConvBlock``: 256 channels, height 16 (bottleneck input).
"""

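The stage schedule in the default configuration can be checked with a little arithmetic: each downsampling stage halves the frequency height, while a plain conv block keeps it. A pure-Python sketch mirroring the numbers listed in that docstring:

```python
# (out_channels, halves_height?) per stage, taken from the default
# encoder configuration described above.
STAGES = [(32, True), (64, True), (128, True), (256, False)]


def schedule(input_height: int):
    """Yield (out_channels, output_height) after each stage."""
    height = input_height
    for out_channels, downsamples in STAGES:
        if downsamples:
            height //= 2  # a downsampling stage halves the height
        yield out_channels, height


print(list(schedule(128)))
# [(32, 64), (64, 32), (128, 16), (256, 16)]
```

Starting from 128 frequency bins this reproduces the 64 / 32 / 16 / 16 progression, which is also why the input height should be divisible by a power of two matching the number of downsampling stages.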
def build_encoder(
    in_channels: int,
    input_height: int,
    config: Optional[EncoderConfig] = None,
    config: EncoderConfig | None = None,
) -> Encoder:
    """Factory function to build an Encoder instance from configuration.
    """Build an ``Encoder`` from configuration.

    Constructs a sequential `Encoder` module based on the layer sequence
    defined in an `EncoderConfig` object and the provided input dimensions.
    If no config is provided, uses the default layer sequence from
    `DEFAULT_ENCODER_CONFIG`.

    It iteratively builds the layers using the unified
    `build_layer_from_config` factory (from `.blocks`), tracking the changing
    number of channels and feature map height required for each subsequent
    layer, especially for coordinate-aware blocks.
    Constructs a sequential ``Encoder`` by iterating over the block
    configurations in ``config.layers``, building each block with
    ``build_layer``, and tracking the channel count and feature-map height
    as they change through the sequence.

    Parameters
    ----------
    in_channels : int
        The number of channels expected in the input tensor to the encoder.
        Must be > 0.
        Number of channels in the input spectrogram tensor. Must be
        positive.
    input_height : int
        The height (frequency bins) expected in the input tensor. Must be > 0.
        Crucial for initializing coordinate-aware layers correctly.
        Height (number of frequency bins) of the input spectrogram.
        Must be positive and should be divisible by
        ``2 ** (number of downsampling stages)`` to avoid size mismatches
        later in the network.
    config : EncoderConfig, optional
        The configuration object detailing the sequence of layers and their
        parameters. If None, `DEFAULT_ENCODER_CONFIG` is used.
        Configuration specifying the layer sequence. Defaults to
        ``DEFAULT_ENCODER_CONFIG`` if not provided.

    Returns
    -------
    Encoder
        An initialized `Encoder` module.
        An initialised ``Encoder`` module.

    Raises
    ------
    ValueError
        If `in_channels` or `input_height` are not positive, or if the layer
        configuration is invalid (e.g., empty list, unknown `name`).
    NotImplementedError
        If `build_layer_from_config` encounters an unknown `name`.
        If ``in_channels`` or ``input_height`` are not positive.
    KeyError
        If a layer configuration specifies an unknown block type.
    """
    if in_channels <= 0 or input_height <= 0:
        raise ValueError("in_channels and input_height must be positive.")
@@ -302,12 +298,14 @@ def build_encoder(
    layers = []

    for layer_config in config.layers:
        layer, current_channels, current_height = build_layer_from_config(
        layer = build_layer(
            in_channels=current_channels,
            input_height=current_height,
            config=layer_config,
        )
        layers.append(layer)
        current_height = layer.get_output_height(current_height)
        current_channels = layer.out_channels

    return Encoder(
        input_height=input_height,

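The loop in ``build_encoder`` threads two running values through the layer sequence: the current channel count and the current height. A stdlib-only sketch of that bookkeeping, with a toy block class standing in for the real ones (hypothetical, for illustration only):

```python
class ToyBlock:
    """Hypothetical stand-in for a real downsampling block."""

    def __init__(self, in_channels: int, out_channels: int):
        self.in_channels = in_channels
        self.out_channels = out_channels

    def get_output_height(self, input_height: int) -> int:
        return input_height // 2  # this toy block always halves the height


def build_chain(in_channels: int, input_height: int, out_channel_list):
    layers = []
    current_channels = in_channels
    current_height = input_height
    for out_channels in out_channel_list:
        layer = ToyBlock(current_channels, out_channels)
        layers.append(layer)
        # Same update order as build_encoder: ask the layer for its
        # output height, then take its out_channels as the next input.
        current_height = layer.get_output_height(current_height)
        current_channels = layer.out_channels
    return layers, current_channels, current_height


layers, channels, height = build_chain(1, 128, [32, 64, 128])
print(channels, height)  # 128 16
```

The second layer's ``in_channels`` is the first layer's ``out_channels``, which is exactly how "input channels are inferred sequentially" in the real factory.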
@@ -1,20 +1,19 @@
"""Prediction Head modules for BatDetect2 models.
"""Prediction heads attached to the backbone feature map.

This module defines simple `torch.nn.Module` subclasses that serve as
prediction heads, typically attached to the output feature map of a backbone
network
Each head is a lightweight ``torch.nn.Module`` that applies a 1×1
convolution to map backbone feature channels to one specific type of
output required by BatDetect2:

Each head is responsible for generating one specific type of output required
by the BatDetect2 task:
- `DetectorHead`: Predicts the probability of sound event presence.
- `ClassifierHead`: Predicts the probability distribution over target classes.
- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding
  box.
- ``DetectorHead``: single-channel detection probability heatmap (sigmoid
  activation).
- ``ClassifierHead``: multi-class probability map over the target bat
  species / call types (softmax activation).
- ``BBoxHead``: two-channel map of predicted call duration (time axis) and
  bandwidth (frequency axis) at each location (no activation; raw
  regression output).

These heads use 1x1 convolutions to map the backbone feature channels
to the desired number of output channels for each prediction task at each
spatial location, followed by an appropriate activation function (e.g., sigmoid
for detection, softmax for classification, none for size regression).
All three heads share the same input feature map produced by the backbone,
so they can be evaluated in parallel in a single forward pass.
"""

import torch
@@ -28,42 +27,35 @@ __all__ = [


class ClassifierHead(nn.Module):
    """Prediction head for multi-class classification probabilities.
    """Prediction head for species / call-type classification probabilities.

    Takes an input feature map and produces a probability map where each
    channel corresponds to a specific target class. It uses a 1x1 convolution
    to map input channels to `num_classes + 1` outputs (one for each target
    class plus an assumed background/generic class), applies softmax across the
    channels, and returns the probabilities for the specific target classes
    (excluding the last background/generic channel).
    Takes a backbone feature map and produces a probability map where each
    channel corresponds to a target class. Internally the 1×1 convolution
    maps ``in_channels`` to ``num_classes + 1`` logits (the extra channel
    represents a generic background / unknown category); a softmax is then
    applied across the channel dimension and the background channel is
    discarded before returning.

    Parameters
    ----------
    num_classes : int
        The number of specific target classes the model should predict
        (excluding any background or generic category). Must be positive.
        Number of target classes (bat species or call types) to predict,
        excluding the background category. Must be positive.
    in_channels : int
        Number of channels in the input feature map tensor from the backbone.
        Must be positive.
        Number of channels in the backbone feature map. Must be positive.

    Attributes
    ----------
    num_classes : int
        Number of specific output classes.
        Number of specific output classes (background excluded).
    in_channels : int
        Number of input channels expected.
    classifier : nn.Conv2d
        The 1x1 convolutional layer used for prediction.
        Output channels = num_classes + 1.

    Raises
    ------
    ValueError
        If `num_classes` or `in_channels` are not positive.
        1×1 convolution with ``num_classes + 1`` output channels.
    """

    def __init__(self, num_classes: int, in_channels: int):
        """Initialize the ClassifierHead."""
        """Initialise the ClassifierHead."""
        super().__init__()

        self.num_classes = num_classes
@@ -76,20 +68,20 @@ class ClassifierHead(nn.Module):
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Compute class probabilities from input features.
        """Compute per-class probabilities from backbone features.

        Parameters
        ----------
        features : torch.Tensor
            Input feature map tensor from the backbone, typically with shape
            `(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
            Backbone feature map, shape ``(B, C_in, H, W)``.

        Returns
        -------
        torch.Tensor
            Class probability map tensor with shape `(B, num_classes, H, W)`.
            Contains probabilities for the specific target classes after
            softmax, excluding the implicit background/generic class channel.
            Class probability map, shape ``(B, num_classes, H, W)``.
            Values are softmax probabilities in the range [0, 1] and
            sum to less than 1 per location (the background probability
            is discarded).
        """
        logits = self.classifier(features)
        probs = torch.softmax(logits, dim=1)
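The effect of the softmax-then-drop step in ``ClassifierHead.forward`` is easy to see with plain numbers: with ``num_classes + 1`` logits per location, the returned per-class probabilities sum to less than one because the background share is discarded. A stdlib sketch of that computation at a single spatial location:

```python
import math


def class_probs(logits):
    """Softmax over the logits, then drop the last (background) entry."""
    exps = [math.exp(v) for v in logits]
    total = sum(exps)
    probs = [v / total for v in exps]
    return probs[:-1]  # keep only the num_classes specific classes


# Three specific classes plus one background logit at one location.
probs = class_probs([2.0, 0.5, 0.1, 1.0])
assert len(probs) == 3
assert sum(probs) < 1.0  # the background mass was discarded
```

With two equal logits (one class, one background) the returned single probability is exactly 0.5, so the "sums to less than 1" property in the docstring follows directly from discarding the background channel.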
@@ -97,36 +89,30 @@


class DetectorHead(nn.Module):
    """Prediction head for sound event detection probability.
    """Prediction head for detection probability (is a call present here?).

    Takes an input feature map and produces a single-channel heatmap where
    each value represents the probability ([0, 1]) of a relevant sound event
    (of any class) being present at that spatial location.
    Produces a single-channel heatmap where each value indicates the
    probability ([0, 1]) that a bat call of *any* species is present at
    that time–frequency location in the spectrogram.

    Uses a 1x1 convolution to map input channels to 1 output channel, followed
    by a sigmoid activation function.
    Applies a 1×1 convolution mapping ``in_channels`` → 1, followed by
    sigmoid activation.

    Parameters
    ----------
    in_channels : int
        Number of channels in the input feature map tensor from the backbone.
        Must be positive.
        Number of channels in the backbone feature map. Must be positive.

    Attributes
    ----------
    in_channels : int
        Number of input channels expected.
    detector : nn.Conv2d
        The 1x1 convolutional layer mapping to a single output channel.

    Raises
    ------
    ValueError
        If `in_channels` is not positive.
        1×1 convolution with a single output channel.
    """

    def __init__(self, in_channels: int):
        """Initialize the DetectorHead."""
        """Initialise the DetectorHead."""
        super().__init__()
        self.in_channels = in_channels

@@ -138,62 +124,49 @@ class DetectorHead(nn.Module):
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Compute detection probabilities from input features.
        """Compute detection probabilities from backbone features.

        Parameters
        ----------
        features : torch.Tensor
            Input feature map tensor from the backbone, typically with shape
            `(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
            Backbone feature map, shape ``(B, C_in, H, W)``.

        Returns
        -------
        torch.Tensor
            Detection probability heatmap tensor with shape `(B, 1, H, W)`.
            Values are in the range [0, 1] due to the sigmoid activation.

        Raises
        ------
        RuntimeError
            If input channel count does not match `self.in_channels`.
            Detection probability heatmap, shape ``(B, 1, H, W)``.
            Values are in the range [0, 1].
        """
        return torch.sigmoid(self.detector(features))


class BBoxHead(nn.Module):
    """Prediction head for bounding box size dimensions.
    """Prediction head for bounding box size (duration and bandwidth).

    Takes an input feature map and produces a two-channel map where each
    channel represents a predicted size dimension (typically width/duration and
    height/bandwidth) for a potential sound event at that spatial location.
    Produces a two-channel map where channel 0 predicts the scaled duration
    (time-axis extent) and channel 1 predicts the scaled bandwidth
    (frequency-axis extent) of the call at each spectrogram location.

    Uses a 1x1 convolution to map input channels to 2 output channels. No
    activation function is typically applied, as size prediction is often
    treated as a direct regression task. The output values usually represent
    *scaled* dimensions that need to be un-scaled during postprocessing.
    Applies a 1×1 convolution mapping ``in_channels`` → 2 with no
    activation function (raw regression output). The predicted values are
    in a scaled space and must be converted to real units (seconds and Hz)
    during postprocessing.

    Parameters
    ----------
    in_channels : int
        Number of channels in the input feature map tensor from the backbone.
        Must be positive.
        Number of channels in the backbone feature map. Must be positive.

    Attributes
    ----------
    in_channels : int
        Number of input channels expected.
    bbox : nn.Conv2d
        The 1x1 convolutional layer mapping to 2 output channels
        (width, height).

    Raises
    ------
    ValueError
        If `in_channels` is not positive.
        1×1 convolution with 2 output channels (duration, bandwidth).
    """

    def __init__(self, in_channels: int):
        """Initialize the BBoxHead."""
        """Initialise the BBoxHead."""
        super().__init__()
        self.in_channels = in_channels

@@ -205,19 +178,19 @@ class BBoxHead(nn.Module):
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Compute predicted bounding box dimensions from input features.
        """Predict call duration and bandwidth from backbone features.

        Parameters
        ----------
        features : torch.Tensor
            Input feature map tensor from the backbone, typically with shape
            `(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
            Backbone feature map, shape ``(B, C_in, H, W)``.

        Returns
        -------
        torch.Tensor
            Predicted size tensor with shape `(B, 2, H, W)`. Channel 0 usually
            represents scaled width, Channel 1 scaled height. These values
            need to be un-scaled during postprocessing.
            Size prediction tensor, shape ``(B, 2, H, W)``. Channel 0 is
            the predicted scaled duration; channel 1 is the predicted
            scaled bandwidth. Values must be rescaled to real units during
            postprocessing.
        """
        return self.bbox(features)

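The "rescale during postprocessing" step for ``BBoxHead`` outputs amounts to a unit conversion. The actual conversion factors live in the project's preprocessing configuration and are not shown in this diff, so the factors below are made-up placeholders purely to illustrate the shape of the operation:

```python
# Hypothetical conversion factors: the real values come from the
# preprocessing configuration and are NOT shown in this diff.
SECONDS_PER_UNIT = 0.25  # assumed seconds per scaled width unit
HZ_PER_UNIT = 250.0      # assumed Hz per scaled height unit


def unscale_size(scaled_width: float, scaled_height: float):
    """Convert BBoxHead-style outputs back to real units (seconds, Hz)."""
    duration_s = scaled_width * SECONDS_PER_UNIT
    bandwidth_hz = scaled_height * HZ_PER_UNIT
    return duration_s, bandwidth_hz


duration, bandwidth = unscale_size(0.02, 80.0)
print(duration, bandwidth)
```

The same conversion is applied independently at every spatial location of the ``(B, 2, H, W)`` output map.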
src/batdetect2/models/types.py (new file, 90 lines)
@@ -0,0 +1,90 @@
from abc import ABC, abstractmethod
from typing import NamedTuple, Protocol

import torch

__all__ = [
    "BackboneModel",
    "BlockProtocol",
    "BottleneckProtocol",
    "DecoderProtocol",
    "DetectionModel",
    "EncoderDecoderModel",
    "EncoderProtocol",
    "ModelOutput",
]


class BlockProtocol(Protocol):
    in_channels: int
    out_channels: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...

    def get_output_height(self, input_height: int) -> int: ...


class EncoderProtocol(Protocol):
    in_channels: int
    out_channels: int
    input_height: int
    output_height: int

    def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...


class BottleneckProtocol(Protocol):
    in_channels: int
    out_channels: int
    input_height: int

    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...


class DecoderProtocol(Protocol):
    in_channels: int
    out_channels: int
    input_height: int
    output_height: int
    depth: int

    def __call__(
        self,
        x: torch.Tensor,
        residuals: list[torch.Tensor],
    ) -> torch.Tensor: ...


class ModelOutput(NamedTuple):
    detection_probs: torch.Tensor
    size_preds: torch.Tensor
    class_probs: torch.Tensor
    features: torch.Tensor


class BackboneModel(ABC, torch.nn.Module):
    input_height: int
    out_channels: int

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class EncoderDecoderModel(BackboneModel):
    bottleneck_channels: int

    @abstractmethod
    def encode(self, spec: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def decode(self, encoded: torch.Tensor) -> torch.Tensor: ...


class DetectionModel(ABC, torch.nn.Module):
    backbone: BackboneModel
    classifier_head: torch.nn.Module
    bbox_head: torch.nn.Module

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> ModelOutput: ...
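The ``*Protocol`` classes in this new file rely on structural typing: any object with the right attributes and call signatures satisfies the protocol without inheriting from it. A stdlib sketch of the idea, using a toy protocol with ints instead of tensors (not the actual torch-based ones):

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class BlockLike(Protocol):
    """Toy analogue of BlockProtocol, with ints instead of tensors."""

    in_channels: int
    out_channels: int

    def get_output_height(self, input_height: int) -> int: ...


class HalvingBlock:
    # Note: no inheritance from BlockLike; conformance is structural.
    def __init__(self, in_channels: int, out_channels: int):
        self.in_channels = in_channels
        self.out_channels = out_channels

    def get_output_height(self, input_height: int) -> int:
        return input_height // 2


block = HalvingBlock(1, 32)
assert isinstance(block, BlockLike)  # satisfied without subclassing
assert block.get_output_height(128) == 64
```

This is why the encoder, bottleneck, and decoder implementations can be swapped freely: anything matching the protocol's shape plugs in, with no shared base class required.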
src/batdetect2/outputs/__init__.py (new file, 35 lines)
@@ -0,0 +1,35 @@
from batdetect2.outputs.config import OutputsConfig
from batdetect2.outputs.formats import (
    BatDetect2OutputConfig,
    OutputFormatConfig,
    ParquetOutputConfig,
    RawOutputConfig,
    SoundEventOutputConfig,
    build_output_formatter,
    get_output_formatter,
    load_predictions,
)
from batdetect2.outputs.transforms import (
    OutputTransformConfig,
    build_output_transform,
)
from batdetect2.outputs.types import (
    OutputFormatterProtocol,
    OutputTransformProtocol,
)

__all__ = [
    "BatDetect2OutputConfig",
    "OutputFormatConfig",
    "OutputFormatterProtocol",
    "OutputTransformConfig",
    "OutputTransformProtocol",
    "OutputsConfig",
    "ParquetOutputConfig",
    "RawOutputConfig",
    "SoundEventOutputConfig",
    "build_output_formatter",
    "build_output_transform",
    "get_output_formatter",
    "load_predictions",
]

src/batdetect2/outputs/config.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from pydantic import Field

from batdetect2.core.configs import BaseConfig
from batdetect2.outputs.formats import OutputFormatConfig
from batdetect2.outputs.formats.raw import RawOutputConfig
from batdetect2.outputs.transforms import OutputTransformConfig

__all__ = ["OutputsConfig"]


class OutputsConfig(BaseConfig):
    format: OutputFormatConfig = Field(default_factory=RawOutputConfig)
    transform: OutputTransformConfig = Field(
        default_factory=OutputTransformConfig
    )
@@ -1,39 +1,42 @@
from typing import Annotated, Optional, Union
from typing import Annotated

from pydantic import Field
from soundevent.data import PathLike

from batdetect2.data.predictions.base import (
from batdetect2.outputs.formats.base import (
    OutputFormatterProtocol,
    prediction_formatters,
    output_formatters,
)
from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig
from batdetect2.data.predictions.raw import RawOutputConfig
from batdetect2.data.predictions.soundevent import SoundEventOutputConfig
from batdetect2.typing import TargetProtocol
from batdetect2.outputs.formats.batdetect2 import BatDetect2OutputConfig
from batdetect2.outputs.formats.parquet import ParquetOutputConfig
from batdetect2.outputs.formats.raw import RawOutputConfig
from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig
from batdetect2.targets.types import TargetProtocol

__all__ = [
    "build_output_formatter",
    "get_output_formatter",
    "BatDetect2OutputConfig",
    "OutputFormatConfig",
    "ParquetOutputConfig",
    "RawOutputConfig",
    "SoundEventOutputConfig",
    "build_output_formatter",
    "get_output_formatter",
    "load_predictions",
]


OutputFormatConfig = Annotated[
    Union[
        BatDetect2OutputConfig,
        SoundEventOutputConfig,
        RawOutputConfig,
    ],
    BatDetect2OutputConfig
    | ParquetOutputConfig
    | SoundEventOutputConfig
    | RawOutputConfig,
    Field(discriminator="name"),
]


def build_output_formatter(
    targets: Optional[TargetProtocol] = None,
    config: Optional[OutputFormatConfig] = None,
    targets: TargetProtocol | None = None,
    config: OutputFormatConfig | None = None,
) -> OutputFormatterProtocol:
    """Construct the final output formatter."""
    from batdetect2.targets import build_targets
@@ -41,13 +44,13 @@ def build_output_formatter(
    config = config or RawOutputConfig()

    targets = targets or build_targets()
    return prediction_formatters.build(config, targets)
    return output_formatters.build(config, targets)


def get_output_formatter(
    name: Optional[str] = None,
    targets: Optional[TargetProtocol] = None,
    config: Optional[OutputFormatConfig] = None,
    name: str | None = None,
    targets: TargetProtocol | None = None,
    config: OutputFormatConfig | None = None,
) -> OutputFormatterProtocol:
    """Get the output formatter by name."""

@@ -55,7 +58,7 @@ def get_output_formatter(
    if name is None:
        raise ValueError("Either config or name must be provided.")

    config_class = prediction_formatters.get_config_type(name)
    config_class = output_formatters.get_config_type(name)
    config = config_class()  # type: ignore

    if config.name != name:  # type: ignore
@@ -68,9 +71,9 @@ def get_output_formatter(

def load_predictions(
    path: PathLike,
    format: Optional[str] = "raw",
    config: Optional[OutputFormatConfig] = None,
    targets: Optional[TargetProtocol] = None,
    format: str | None = "raw",
    config: OutputFormatConfig | None = None,
    targets: TargetProtocol | None = None,
):
    """Load predictions from a file."""
    from batdetect2.targets import build_targets
src/batdetect2/outputs/formats/base.py (new file, 46 lines)
@@ -0,0 +1,46 @@
from pathlib import Path
from typing import Literal

from soundevent.data import PathLike

from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.targets.types import TargetProtocol

__all__ = [
    "OutputFormatterProtocol",
    "PredictionFormatterImportConfig",
    "make_path_relative",
    "output_formatters",
]


def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
    path = Path(path)
    audio_dir = Path(audio_dir)

    if path.is_absolute():
        if not path.is_relative_to(audio_dir):
            raise ValueError(
                f"Audio file {path} is not in audio_dir {audio_dir}"
            )

        return path.relative_to(audio_dir)

    return path

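``make_path_relative`` has three behaviours: absolute paths inside ``audio_dir`` become relative, absolute paths outside it raise ``ValueError``, and already-relative paths pass through unchanged. A self-contained restatement of the same logic with plain ``pathlib`` types (so the example runs without the package's ``PathLike`` alias; paths shown are POSIX-style):

```python
from pathlib import Path


def make_path_relative(path, audio_dir) -> Path:
    # Same logic as the function above, restated with plain pathlib.
    path = Path(path)
    audio_dir = Path(audio_dir)
    if path.is_absolute():
        if not path.is_relative_to(audio_dir):
            raise ValueError(
                f"Audio file {path} is not in audio_dir {audio_dir}"
            )
        return path.relative_to(audio_dir)
    return path


print(make_path_relative("/data/audio/night1/rec.wav", "/data/audio"))
print(make_path_relative("rec.wav", "/data/audio"))
```

Note that ``Path.is_relative_to`` requires Python 3.9 or newer.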
|
||||
output_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
|
||||
Registry(name="output_formatter")
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(output_formatters)
|
||||
class PredictionFormatterImportConfig(ImportConfig):
|
||||
"""Use any callable as a prediction formatter.
|
||||
|
||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||
callable to use it instead of a built-in option.
|
||||
"""
|
||||
|
||||
name: Literal["import"] = "import"
|
||||
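The ``ImportConfig`` escape hatch resolves a dotted ``target`` string to a callable at build time. A stdlib sketch of that resolution step (the helper name here is hypothetical; the real implementation lives in ``batdetect2.core`` and is not shown in this diff):

```python
from importlib import import_module


def resolve_target(target: str):
    """Resolve 'package.module.attr' to the named attribute.

    Hypothetical helper illustrating how an ImportConfig-style
    ``target`` string can be turned into a callable.
    """
    module_path, _, attr = target.rpartition(".")
    module = import_module(module_path)
    return getattr(module, attr)


# Resolve a stdlib callable as a stand-in for a custom formatter.
fn = resolve_target("json.dumps")
print(fn({"detections": 2}))  # {"detections": 2}
```

Combined with the registry, this lets a config file name any importable function as a formatter without touching the built-in options.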
Some files were not shown because too many files have changed in this diff.