Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-04-09 09:40:19 +02:00)

Compare commits 5a14b29281...a4a5a10da1 — no commits in common: "5a14b29281f9bb8df90d80cca0e89235b8fcd4ae" and "a4a5a10da184695a8a8f81004f8ce84dc131ece8" have entirely different histories.
.gitignore (vendored, 7 lines changed)

```diff
@@ -102,7 +102,7 @@ experiments/*
 DvcLiveLogger/checkpoints
 logs/
 mlruns/
-/outputs/
+outputs/
 notebooks/lightning_logs

 # Jupiter notebooks
@@ -123,8 +123,3 @@ example_data/preprocessed

 # Dev notebooks
 notebooks/tmp
-/tmp
-/.agents/skills
-/notebooks
-/AGENTS.md
-/scripts
```
Removed file (93 lines):

@@ -1,93 +0,0 @@

# BatDetect2 Architecture Overview

This document provides a comprehensive map of the `batdetect2` codebase architecture, intended as a deep-dive reference for developers, agents, and contributors navigating the project.

`batdetect2` is a modular deep learning pipeline for detecting and classifying bat echolocation calls in high-frequency audio recordings. It relies heavily on **PyTorch**, with **PyTorch Lightning** for training and the **Soundevent** library for standardized audio and geometry data classes.

The repository follows a configuration-driven design pattern, making heavy use of `pydantic`/`omegaconf` (via `BaseConfig`) together with Factory/Registry patterns for dependency injection and modularity. The entire pipeline can be orchestrated through the high-level `BatDetect2API` (`src/batdetect2/api_v2.py`).

---

## 1. Data Flow Pipeline

The standard lifecycle of a prediction request passes through these sequential stages, each handled by an isolated, replaceable module:

1. **Audio Loading (`batdetect2.audio`)**: Reads raw `.wav` files into standard NumPy arrays or `soundevent.data.Clip` objects; handles resampling.
2. **Preprocessing (`batdetect2.preprocess`)**: Converts raw 1D waveforms into 2D spectrogram tensors.
3. **Forward Pass (`batdetect2.models`)**: A PyTorch neural network processes the spectrogram and outputs dense prediction tensors (e.g., detection heatmaps, bounding box sizes, class probabilities).
4. **Postprocessing (`batdetect2.postprocess`)**: Decodes the raw output tensors back into explicit bounding box geometries and runs Non-Maximum Suppression (NMS) to filter redundant predictions.
5. **Formatting (`batdetect2.data`)**: Transforms the predictions into standard formats (`.csv`, `.json`, `.parquet`) via `OutputFormatterProtocol`.
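The five stages can be sketched end-to-end with toy stand-ins. Every function name, shape, and threshold below is illustrative only, not the project's actual API:

```python
import numpy as np

def load_audio(path):                      # 1. Audio Loading (toy stand-in)
    """Return half a second of a pure 30 kHz tone at a 256 kHz sample rate."""
    sr = 256_000
    t = np.arange(int(0.5 * sr)) / sr
    return np.sin(2 * np.pi * 30_000 * t), sr

def preprocess(wave, sr):                  # 2. Preprocessing
    frames = wave.reshape(100, -1)         # 100 fixed-size frames
    return np.abs(np.fft.rfft(frames, axis=1))

def forward(spec):                         # 3. Forward Pass (stand-in for the network)
    return {"detection": spec / spec.max()}

def postprocess(outputs, threshold=0.5):   # 4. Postprocessing: keep strong peaks
    heat = outputs["detection"]
    return list(zip(*np.nonzero(heat > threshold)))

def format_predictions(dets):              # 5. Formatting
    return [{"frame": int(f), "freq_bin": int(b)} for f, b in dets]

wave, sr = load_audio("recording.wav")
preds = format_predictions(postprocess(forward(preprocess(wave, sr))))
print(len(preds))  # 100: the tone lands exactly on one FFT bin in every frame
```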
---

## 2. Core Modules Breakdown

### 2.1 Audio and Preprocessing

- **`audio/`**:
  - Centralizes audio I/O through `AudioLoader`. It abstracts over the `soundevent` library, efficiently handling full `Recording` files or smaller `Clip` segments and standardizing the sample rate.
- **`preprocess/`**:
  - Defined by the `PreprocessorProtocol`.
  - Its primary responsibility is spectrogram generation via the Short-Time Fourier Transform (STFT).
  - During training, it incorporates data augmentation layers (e.g., amplitude scaling, time masking, frequency masking, spectral mean subtraction) configured via `PreprocessingConfig`.
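A minimal Hann-windowed STFT shows what spectrogram generation amounts to. The window/hop sizes here are placeholders; the real values come from `PreprocessingConfig`:

```python
import numpy as np

def spectrogram(wave, win_size=512, hop=256):
    """Magnitude spectrogram via a Hann-windowed STFT (illustrative sketch)."""
    window = np.hanning(win_size)
    n_frames = 1 + (len(wave) - win_size) // hop
    frames = np.stack(
        [wave[i * hop : i * hop + win_size] * window for i in range(n_frames)]
    )
    return np.abs(np.fft.rfft(frames, axis=1)).T  # (freq_bins, time_frames)

# 0.1 s of an upward chirp at a 256 kHz sample rate, typical for bat recordings.
sr = 256_000
t = np.arange(int(0.1 * sr)) / sr
wave = np.sin(2 * np.pi * (30_000 + 400_000 * t) * t)
spec = spectrogram(wave)
print(spec.shape)  # (257, 99): 512 // 2 + 1 frequency bins, 99 hops
```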
### 2.2 Deep Learning Models (`models/`)

The `models` directory contains all PyTorch neural network architectures. The default architecture is an Encoder-Decoder (U-Net style) network.

- **`blocks.py`**: Reusable neural network blocks, including standard convolutions (`ConvBlock`) and specialized layers such as `FreqCoordConvDownBlock`/`FreqCoordConvUpBlock`, which append normalized frequency coordinates to give convolutional filters explicit frequency awareness.
- **`encoder.py`**: The downsampling path (feature extraction). Builds a sequential list of blocks and captures skip connections.
- **`bottleneck.py`**: The deepest, lowest-resolution segment connecting the Encoder and Decoder. Features an optional `SelfAttention` mechanism to weigh global temporal context.
- **`decoder.py`**: The upsampling path (reconstruction), integrating the skip connections from the Encoder.
- **`heads.py`**: Heads attach to the backbone's feature map and output specific predictions:
  - `BBoxHead`: Predicts bounding box sizes.
  - `ClassifierHead`: Predicts species classes.
  - `DetectorHead`: Predicts detection probability heatmaps.
- **`backbones.py` & `detectors.py`**: Assemble the encoder, bottleneck, decoder, and heads into a cohesive `Detector` model.
- **`__init__.py:Model`**: The overarching wrapper `torch.nn.Module` containing the `detector`, `preprocessor`, `postprocessor`, and `targets`.
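The frequency-coordinate trick used by the `FreqCoordConv*` blocks can be illustrated in NumPy: concatenate a channel that encodes each row's normalized frequency position (a CoordConv-style idea; the real blocks are `torch.nn.Module`s and may differ in detail):

```python
import numpy as np

def add_freq_coords(spec_batch):
    """Append a normalized frequency-coordinate channel to a spectrogram batch.

    spec_batch: array of shape (batch, channels, freq, time).
    Returns an array of shape (batch, channels + 1, freq, time).
    """
    b, c, f, t = spec_batch.shape
    coords = np.linspace(0.0, 1.0, f).reshape(1, 1, f, 1)  # 0 = lowest bin, 1 = highest
    coords = np.broadcast_to(coords, (b, 1, f, t))
    return np.concatenate([spec_batch, coords], axis=1)

x = np.zeros((2, 1, 128, 512))
y = add_freq_coords(x)
print(y.shape)  # (2, 2, 128, 512)
```

Because ordinary convolutions are translation-invariant, a filter cannot otherwise tell a 30 kHz row from a 90 kHz row; the extra channel makes that position available to every filter.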
### 2.3 Targets and Regions of Interest (`targets/`)

Crucial for training, this module translates physical annotations (Regions of Interest, ROIs) into training targets (tensors).

- **`rois.py`**: Implements `ROITargetMapper`, which maps a geometric bounding box to a 2D reference `Position` (time, frequency) and a `Size` array. Strategies include:
  - `AnchorBBoxMapper`: Maps based on a fixed bounding box corner or center.
  - `PeakEnergyBBoxMapper`: Finds the coordinate of peak acoustic energy inside the bounding box and computes offsets to the box edges.
- **`targets.py`**: Constructs the complete multi-channel target heatmaps and coordinate tensors from the ROIs, used to compute losses during training.
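A common way to turn reference positions into a detection heatmap target is to render a small Gaussian at each annotated position. This is a simplified stand-in for the kind of tensor `targets.py` builds, not its actual implementation:

```python
import numpy as np

def detection_heatmap(shape, positions, sigma=2.0):
    """Render point targets as a Gaussian detection heatmap.

    shape: (freq_bins, time_frames); positions: list of (freq_idx, time_idx).
    """
    ff, tt = np.mgrid[0 : shape[0], 0 : shape[1]]
    heatmap = np.zeros(shape)
    for f0, t0 in positions:
        blob = np.exp(-((ff - f0) ** 2 + (tt - t0) ** 2) / (2 * sigma**2))
        heatmap = np.maximum(heatmap, blob)  # overlapping calls keep the max
    return heatmap

hm = detection_heatmap((128, 512), [(40, 100), (80, 300)])
print(hm.shape, hm[40, 100])  # peak is exactly 1.0 at each annotated position
```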
### 2.4 Postprocessing (`postprocess/`)

- Implements `PostprocessorProtocol`.
- Reverses the logic of `targets`: it scans the model's output detection heatmaps for peaks, extracts the predicted sizes and class probabilities at those peaks, and decodes them back into physical `soundevent.data.Geometry` bounding boxes.
- Applies Non-Maximum Suppression (NMS), configured via `PostprocessConfig`, to remove highly overlapping predictions.
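Heatmap NMS of this kind is usually a "keep only local maxima above a threshold" operation. A slow but explicit sketch, mirroring the idea behind kernel-size and threshold settings (not the project's actual implementation, which would use a max-pool):

```python
import numpy as np

def heatmap_nms(heatmap, kernel_size=9, threshold=0.01):
    """Keep cells that are the maximum of their kernel_size x kernel_size window."""
    pad = kernel_size // 2
    padded = np.pad(heatmap, pad, mode="constant", constant_values=-np.inf)
    peaks = []
    for f in range(heatmap.shape[0]):
        for t in range(heatmap.shape[1]):
            window = padded[f : f + kernel_size, t : t + kernel_size]
            if heatmap[f, t] >= threshold and heatmap[f, t] == window.max():
                peaks.append((f, t, float(heatmap[f, t])))
    return peaks

hm = np.zeros((32, 64))
hm[10, 20] = 0.9
hm[10, 22] = 0.5   # suppressed: falls inside the 9x9 window of the 0.9 peak
print(heatmap_nms(hm))  # [(10, 20, 0.9)]
```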
### 2.5 Data Management (`data/`)

- **`annotations/`**: Utilities to load dataset annotations in multiple standardized schemas (the `AOEF` and `BatDetect2` formats).
- **`datasets.py`**: Aggregates recordings and annotations in memory.
- **`predictions/`**: Handles exporting model results via `OutputFormatterProtocol`, with formatters for `RawOutput`, `.parquet`, `.json`, etc.
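The formatter-protocol idea can be sketched with `typing.Protocol`. The field names and the `format` signature below are hypothetical; only the pattern (one protocol, many interchangeable output formats) comes from the document:

```python
import csv
import io
from dataclasses import dataclass
from typing import Protocol, Sequence

@dataclass
class Detection:
    start: float
    end: float
    low_freq: float
    high_freq: float
    species: str
    score: float

class OutputFormatter(Protocol):
    """Hypothetical shape of an output formatter."""
    def format(self, detections: Sequence[Detection]) -> str: ...

class CsvFormatter:
    def format(self, detections: Sequence[Detection]) -> str:
        buf = io.StringIO()
        writer = csv.writer(buf)
        writer.writerow(["start", "end", "low_freq", "high_freq", "species", "score"])
        for d in detections:
            writer.writerow([d.start, d.end, d.low_freq, d.high_freq, d.species, d.score])
        return buf.getvalue()

out = CsvFormatter().format(
    [Detection(0.10, 0.12, 30000, 80000, "Pipistrellus pipistrellus", 0.93)]
)
print(out.splitlines()[0])  # the CSV header row
```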
### 2.6 Evaluation (`evaluate/`)

- Computes scientific metrics through `EvaluatorProtocol`.
- Provides dedicated evaluation setups for tasks such as `Clip Classification`, `Clip Detection`, and `Top Class` prediction.
- Generates precision-recall curves and scatter plots.
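To make the detection metric concrete: average precision over a ranked list of detections is the mean of the precision values at each true-positive rank. This is the standard computation, shown only as background for the `average_precision` metric named in the example config:

```python
def average_precision(scores, labels):
    """AP over ranked detections; labels are 1 for true positives, 0 otherwise."""
    ranked = sorted(zip(scores, labels), reverse=True)
    tp = fp = 0
    total_pos = sum(labels)
    ap = 0.0
    for _, is_tp in ranked:
        if is_tp:
            tp += 1
            ap += tp / (tp + fp)  # precision at this recall step
        else:
            fp += 1
    return ap / total_pos

# Three true detections and one false alarm ranked second:
# precisions at the TP ranks are 1/1, 2/3, 3/4, so AP = (1 + 2/3 + 3/4) / 3.
print(average_precision([0.9, 0.8, 0.7, 0.6], [1, 0, 1, 1]))
```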
### 2.7 Training (`train/`)

- Implements the distributed PyTorch training loop via PyTorch Lightning.
- **`lightning.py`**: Contains `TrainingModule`, the `LightningModule` that orchestrates the optimizer, learning-rate scheduler, forward passes, and backpropagation using the generated `targets`.
---

## 3. Interfaces and Tooling

### 3.1 APIs

- **`api_v2.py` (`BatDetect2API`)**: The modern API object, built around dependency injection via `BatDetect2Config`. It instantiates the loader, targets, preprocessor, postprocessor, and model, and exposes convenience methods such as `process_file`, `evaluate`, and `train`.
- **`api.py`**: The legacy API, kept for backwards compatibility. It uses hardcoded default instances rather than configuration objects.

### 3.2 Command Line Interface (`cli/`)

- Implements terminal commands using `click`, including `batdetect2 detect`, `evaluate`, and `train`.
### 3.3 Core and Configuration (`core/`, `config.py`)

- **`core/registries.py`**: A string-keyed Registry pattern (e.g., `block_registry`, `roi_mapper_registry`) that lets developers swap components (such as a custom neural network block) via configuration files, without modifying Python code.
- **`config.py`**: Aggregates the modular `BaseConfig` objects (`AudioConfig`, `PreprocessingConfig`, `BackboneConfig`) into the monolithic `BatDetect2Config`.
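A minimal version of such a string-keyed registry fits in a few lines. This is an illustrative reimplementation of the pattern, not the project's actual `Registry` class:

```python
class Registry:
    """Map string names to classes so configs can select implementations."""

    def __init__(self):
        self._items = {}

    def register(self, name):
        def decorator(cls):
            self._items[name] = cls
            return cls
        return decorator

    def build(self, name, **kwargs):
        return self._items[name](**kwargs)

block_registry = Registry()

@block_registry.register("ConvBlock")
class ConvBlock:
    def __init__(self, out_channels):
        self.out_channels = out_channels

# A YAML config can now pick the block by name, e.g. `- name: ConvBlock`:
block = block_registry.build("ConvBlock", out_channels=32)
print(type(block).__name__, block.out_channels)  # ConvBlock 32
```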
---

## Summary

To navigate this codebase effectively:

1. Follow **`api_v2.py`** to see how high-level operations invoke the individual components.
2. Lean on the typed **Protocols** in each subsystem's `types.py` module (for example `src/batdetect2/preprocess/types.py` and `src/batdetect2/postprocess/types.py`) to understand inputs and outputs without reading every implementation.
3. Remember that data flows as `soundevent` primitives externally and as plain `torch.Tensor` values internally through the network.
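Those `types.py` protocols are structural: any object with the right methods satisfies them. A sketch with a hypothetical method name (the real `PreprocessorProtocol` signature may differ):

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class Preprocessor(Protocol):
    """Illustrative protocol; batdetect2's actual interface may differ."""
    def compute_spectrogram(self, wave, samplerate: int): ...

class NaivePreprocessor:
    # No inheritance needed: matching the method shape is enough.
    def compute_spectrogram(self, wave, samplerate):
        return [list(wave)]  # toy 1-row "spectrogram"

print(isinstance(NaivePreprocessor(), Preprocessor))  # True
```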
Docs index (1 toctree entry removed):

```diff
@@ -6,7 +6,6 @@ Hi!
 :maxdepth: 1
 :caption: Contents:

-architecture
 data/index
 preprocessing/index
 postprocessing
```
Example training config:

```diff
@@ -1,13 +1,8 @@
-config_version: v1
-
 audio:
   samplerate: 256000
   resample:
-    enabled: true
-    method: poly
+    enabled: True
+    method: "poly"

-model:
-  samplerate: 256000
-
 preprocess:
   stft:
@@ -26,12 +21,17 @@ model:
         gain: 0.98
         bias: 2
         power: 0.5
-    - name: spectral_mean_subtraction
+    - name: spectral_mean_substraction

-architecture:
-  name: UNetBackbone
+postprocess:
+  nms_kernel_size: 9
+  detection_threshold: 0.01
+  top_k_per_sec: 200
+
+model:
   input_height: 128
   in_channels: 1
+  out_channels: 32
   encoder:
     layers:
       - name: FreqCoordConvDown
@@ -62,18 +62,9 @@ model:
       - name: ConvBlock
         out_channels: 32

-postprocess:
-  nms_kernel_size: 9
-  detection_threshold: 0.01
-  top_k_per_sec: 200
-
 train:
   optimizer:
-    name: adam
     learning_rate: 0.001
-
-  scheduler:
-    name: cosine_annealing
     t_max: 100

   labels:
@@ -85,7 +76,10 @@ train:

   train_loader:
     batch_size: 8
-    shuffle: true
+    num_workers: 2
+
+    shuffle: True

   clipping_strategy:
     name: random_subclip
@@ -121,6 +115,7 @@ train:
           max_masks: 3

   val_loader:
+    num_workers: 2
     clipping_strategy:
       name: whole_audio_padded
       chunk_size: 0.256
@@ -139,6 +134,9 @@ train:
     size:
       weight: 0.1

+  logger:
+    name: csv
+
   validation:
     tasks:
       - name: sound_event_detection
@@ -148,10 +146,6 @@ train:
     metrics:
       - name: average_precision

-  logging:
-    train:
-      name: csv
-
 evaluation:
   tasks:
     - name: sound_event_detection
```
justfile (49 lines changed)

```diff
@@ -14,67 +14,60 @@ HTML_COVERAGE_DIR := "htmlcov"
 help:
     @just --list

-install:
-    uv sync
-
 # Testing & Coverage
 # Run tests using pytest.
 test:
-    uv run pytest {{TESTS_DIR}}
+    pytest {{TESTS_DIR}}

 # Run tests and generate coverage data.
 coverage:
-    uv run pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}
+    pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml {{TESTS_DIR}}

 # Generate an HTML coverage report.
 coverage-html: coverage
     @echo "Generating HTML coverage report..."
-    uv run coverage html -d {{HTML_COVERAGE_DIR}}
+    coverage html -d {{HTML_COVERAGE_DIR}}
     @echo "HTML coverage report generated in {{HTML_COVERAGE_DIR}}/"

 # Serve the HTML coverage report locally.
 coverage-serve: coverage-html
     @echo "Serving report at http://localhost:8000/ ..."
-    uv run python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000
+    python -m http.server --directory {{HTML_COVERAGE_DIR}} 8000

 # Documentation
 # Build documentation using Sphinx.
 docs:
-    uv run sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}
+    sphinx-build -b html {{DOCS_SOURCE}} {{DOCS_BUILD}}

 # Serve documentation with live reload.
 docs-serve:
-    uv run sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser
+    sphinx-autobuild {{DOCS_SOURCE}} {{DOCS_BUILD}} --watch {{SOURCE_DIR}} --open-browser

 # Formatting & Linting
 # Format code using ruff.
-fix-format:
-    uv run ruff format {{PYTHON_DIRS}}
+format:
+    ruff format {{PYTHON_DIRS}}

-# Lint code using ruff and apply automatic fixes.
-fix-lint:
-    uv run ruff check --fix {{PYTHON_DIRS}}
-
-# Combined Formatting & Linting
-fix: fix-format fix-lint
-
-# Checking tasks
 # Check code formatting using ruff.
-check-format:
-    uv run ruff format --check {{PYTHON_DIRS}}
+format-check:
+    ruff format --check {{PYTHON_DIRS}}

 # Lint code using ruff.
-check-lint:
-    uv run ruff check {{PYTHON_DIRS}}
+lint:
+    ruff check {{PYTHON_DIRS}}
+
+# Lint code using ruff and apply automatic fixes.
+lint-fix:
+    ruff check --fix {{PYTHON_DIRS}}

 # Type Checking
-# Type check code using ty.
-check-types:
-    uv run ty check {{PYTHON_DIRS}}
+# Type check code using pyright.
+typecheck:
+    pyright {{PYTHON_DIRS}}

 # Combined Checks
 # Run all checks (format-check, lint, typecheck).
-check: check-format check-lint check-types
+check: format-check lint typecheck test

 # Cleaning tasks
 # Remove Python bytecode and cache.
@@ -102,7 +95,7 @@ clean: clean-build clean-pyc clean-test clean-docs

 # Train on example data.
 example-train OPTIONS="":
-    uv run batdetect2 train \
+    batdetect2 train \
         --val-dataset example_data/dataset.yaml \
         --config example_data/config.yaml \
         {{OPTIONS}} \
```
New notebooks (diffs suppressed because one or more lines are too long):

- notebooks/Augmentations.ipynb (743 lines)
- notebooks/Migrations.ipynb (1166 lines)
- notebooks/Training Preprocess.ipynb (801 lines)

notebooks/Training.ipynb (new file, 440 lines):
@ -0,0 +1,440 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "cfb0b360-a204-4c27-a18f-3902e8758879",
|
||||||
|
"metadata": {
|
||||||
|
"execution": {
|
||||||
|
"iopub.execute_input": "2024-11-19T17:33:02.699871Z",
|
||||||
|
"iopub.status.busy": "2024-11-19T17:33:02.699590Z",
|
||||||
|
"iopub.status.idle": "2024-11-19T17:33:02.710312Z",
|
||||||
|
"shell.execute_reply": "2024-11-19T17:33:02.709798Z",
|
||||||
|
"shell.execute_reply.started": "2024-11-19T17:33:02.699839Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"%load_ext autoreload\n",
|
||||||
|
"%autoreload 2"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "326c5432-94e6-4abf-a332-fe902559461b",
|
||||||
|
"metadata": {
|
||||||
|
"execution": {
|
||||||
|
"iopub.execute_input": "2024-11-19T17:33:02.711324Z",
|
||||||
|
"iopub.status.busy": "2024-11-19T17:33:02.711067Z",
|
||||||
|
"iopub.status.idle": "2024-11-19T17:33:09.092380Z",
|
||||||
|
"shell.execute_reply": "2024-11-19T17:33:09.091830Z",
|
||||||
|
"shell.execute_reply.started": "2024-11-19T17:33:02.711304Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"/home/santiago/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||||
|
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"from typing import List, Optional\n",
|
||||||
|
"import torch\n",
|
||||||
|
"\n",
|
||||||
|
"import pytorch_lightning as pl\n",
|
||||||
|
"from batdetect2.train.modules import DetectorModel\n",
|
||||||
|
"from batdetect2.train.augmentations import (\n",
|
||||||
|
" add_echo,\n",
|
||||||
|
" select_random_subclip,\n",
|
||||||
|
" warp_spectrogram,\n",
|
||||||
|
")\n",
|
||||||
|
"from batdetect2.train.dataset import LabeledDataset, get_files\n",
|
||||||
|
"from batdetect2.train.preprocess import PreprocessingConfig\n",
|
||||||
|
"from soundevent import data\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"from soundevent.types import ClassMapper\n",
|
||||||
|
"from torch.utils.data import DataLoader"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "9402a473-0b25-4123-9fa8-ad1f71a4237a",
|
||||||
|
"metadata": {
|
||||||
|
"execution": {
|
||||||
|
"iopub.execute_input": "2024-11-18T22:39:12.395329Z",
|
||||||
|
"iopub.status.busy": "2024-11-18T22:39:12.393444Z",
|
||||||
|
"iopub.status.idle": "2024-11-18T22:39:12.405938Z",
|
||||||
|
"shell.execute_reply": "2024-11-18T22:39:12.402980Z",
|
||||||
|
"shell.execute_reply.started": "2024-11-18T22:39:12.395236Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"source": [
|
||||||
|
"## Training Datasets"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "cfd97d83-8c2b-46c8-9eae-cea59f53bc61",
|
||||||
|
"metadata": {
|
||||||
|
"execution": {
|
||||||
|
"iopub.execute_input": "2024-11-19T17:33:09.093487Z",
|
||||||
|
"iopub.status.busy": "2024-11-19T17:33:09.092990Z",
|
||||||
|
"iopub.status.idle": "2024-11-19T17:33:09.121636Z",
|
||||||
|
"shell.execute_reply": "2024-11-19T17:33:09.121143Z",
|
||||||
|
"shell.execute_reply.started": "2024-11-19T17:33:09.093459Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"data_dir = Path.cwd().parent / \"example_data\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "d5131ae9-2efd-4758-b6e5-189a6d90789b",
|
||||||
|
"metadata": {
|
||||||
|
"execution": {
|
||||||
|
"iopub.execute_input": "2024-11-19T17:33:09.122685Z",
|
||||||
|
"iopub.status.busy": "2024-11-19T17:33:09.122270Z",
|
||||||
|
"iopub.status.idle": "2024-11-19T17:33:09.151386Z",
|
||||||
|
"shell.execute_reply": "2024-11-19T17:33:09.150788Z",
|
||||||
|
"shell.execute_reply.started": "2024-11-19T17:33:09.122661Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"files = get_files(data_dir / \"preprocessed\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "bc733d3d-7829-4e90-896d-a0dc76b33288",
|
||||||
|
"metadata": {
|
||||||
|
"execution": {
|
||||||
|
"iopub.execute_input": "2024-11-19T17:33:09.152327Z",
|
||||||
|
"iopub.status.busy": "2024-11-19T17:33:09.152060Z",
|
||||||
|
"iopub.status.idle": "2024-11-19T17:33:09.184041Z",
|
||||||
|
"shell.execute_reply": "2024-11-19T17:33:09.183372Z",
|
||||||
|
"shell.execute_reply.started": "2024-11-19T17:33:09.152305Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"train_dataset = LabeledDataset(files)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "dfbb94ab-7b12-4689-9c15-4dc34cd17cb2",
|
||||||
|
"metadata": {
|
||||||
|
"execution": {
|
||||||
|
"iopub.execute_input": "2024-11-19T17:33:09.186393Z",
|
||||||
|
"iopub.status.busy": "2024-11-19T17:33:09.186117Z",
|
||||||
|
"iopub.status.idle": "2024-11-19T17:33:09.220175Z",
|
||||||
|
"shell.execute_reply": "2024-11-19T17:33:09.219322Z",
|
||||||
|
"shell.execute_reply.started": "2024-11-19T17:33:09.186375Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"train_dataloader = DataLoader(\n",
|
||||||
|
" train_dataset,\n",
|
||||||
|
" shuffle=True,\n",
|
||||||
|
" batch_size=32,\n",
|
||||||
|
" num_workers=4,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "e2eedaa9-6be3-481a-8786-7618515d98f8",
|
||||||
|
"metadata": {
|
||||||
|
"execution": {
|
||||||
|
"iopub.execute_input": "2024-11-19T17:33:09.221653Z",
|
||||||
|
"iopub.status.busy": "2024-11-19T17:33:09.221242Z",
|
||||||
|
"iopub.status.idle": "2024-11-19T17:33:09.260977Z",
|
||||||
|
"shell.execute_reply": "2024-11-19T17:33:09.260375Z",
|
||||||
|
"shell.execute_reply.started": "2024-11-19T17:33:09.221616Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# List of all possible classes\n",
|
||||||
|
"class Mapper(ClassMapper):\n",
|
||||||
|
" class_labels = [\n",
|
||||||
|
" \"Eptesicus serotinus\",\n",
|
||||||
|
" \"Myotis mystacinus\",\n",
|
||||||
|
" \"Pipistrellus pipistrellus\",\n",
|
||||||
|
" \"Rhinolophus ferrumequinum\",\n",
|
||||||
|
" ]\n",
|
||||||
|
"\n",
|
||||||
|
" def encode(self, x: data.SoundEventAnnotation) -> Optional[str]:\n",
|
||||||
|
" event_tag = data.find_tag(x.tags, \"event\")\n",
|
||||||
|
"\n",
|
||||||
|
" if event_tag.value == \"Social\":\n",
|
||||||
|
" return \"social\"\n",
|
||||||
|
"\n",
|
||||||
|
" if event_tag.value != \"Echolocation\":\n",
|
||||||
|
" # Ignore all other types of calls\n",
|
||||||
|
" return None\n",
|
||||||
|
"\n",
|
||||||
|
" species_tag = data.find_tag(x.tags, \"class\")\n",
|
||||||
|
" return species_tag.value\n",
|
||||||
|
"\n",
|
||||||
|
" def decode(self, class_name: str) -> List[data.Tag]:\n",
|
||||||
|
" if class_name == \"social\":\n",
|
||||||
|
" return [data.Tag(key=\"event\", value=\"social\")]\n",
|
||||||
|
"\n",
|
||||||
|
" return [data.Tag(key=\"class\", value=class_name)]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"id": "1ff6072c-511e-42fe-a74f-282f269b80f0",
|
||||||
|
"metadata": {
|
||||||
|
"execution": {
|
||||||
|
"iopub.execute_input": "2024-11-19T17:33:09.262337Z",
|
||||||
|
"iopub.status.busy": "2024-11-19T17:33:09.261775Z",
|
||||||
|
"iopub.status.idle": "2024-11-19T17:33:09.309793Z",
|
||||||
|
"shell.execute_reply": "2024-11-19T17:33:09.309216Z",
|
||||||
|
"shell.execute_reply.started": "2024-11-19T17:33:09.262307Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"detector = DetectorModel(class_mapper=Mapper())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"id": "3a763ee6-15bc-4105-a409-f06e0ad21a06",
|
||||||
|
"metadata": {
|
||||||
|
"execution": {
|
||||||
|
"iopub.execute_input": "2024-11-19T17:33:09.310695Z",
|
||||||
|
"iopub.status.busy": "2024-11-19T17:33:09.310438Z",
|
||||||
|
"iopub.status.idle": "2024-11-19T17:33:09.366636Z",
|
||||||
|
"shell.execute_reply": "2024-11-19T17:33:09.366059Z",
|
||||||
|
"shell.execute_reply.started": "2024-11-19T17:33:09.310669Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"GPU available: False, used: False\n",
|
||||||
|
"TPU available: False, using: 0 TPU cores\n",
|
||||||
|
"HPU available: False, using: 0 HPUs\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"trainer = pl.Trainer(\n",
|
||||||
|
" limit_train_batches=100,\n",
|
||||||
|
" max_epochs=2,\n",
|
||||||
|
" log_every_n_steps=1,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"id": "0b86d49d-3314-4257-94f5-f964855be385",
|
||||||
|
"metadata": {
|
||||||
|
"execution": {
|
||||||
|
"iopub.execute_input": "2024-11-19T17:33:09.367499Z",
|
||||||
|
"iopub.status.busy": "2024-11-19T17:33:09.367242Z",
|
||||||
|
"iopub.status.idle": "2024-11-19T17:33:10.811300Z",
|
||||||
|
"shell.execute_reply": "2024-11-19T17:33:10.809823Z",
|
||||||
|
"shell.execute_reply.started": "2024-11-19T17:33:09.367473Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"\n",
|
||||||
|
" | Name | Type | Params | Mode \n",
|
||||||
|
"--------------------------------------------------------\n",
|
||||||
|
"0 | feature_extractor | Net2DFast | 119 K | train\n",
|
||||||
|
"1 | classifier | Conv2d | 54 | train\n",
|
||||||
|
"2 | bbox | Conv2d | 18 | train\n",
|
||||||
|
"--------------------------------------------------------\n",
|
||||||
|
"119 K Trainable params\n",
|
||||||
|
"448 Non-trainable params\n",
|
||||||
|
"119 K Total params\n",
|
||||||
|
"0.480 Total estimated model params size (MB)\n",
|
||||||
|
"32 Modules in train mode\n",
|
||||||
|
"0 Modules in eval mode\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch 0: 0%| | 0/1 [00:00<?, ?it/s]class heatmap shape torch.Size([3, 4, 128, 512])\n",
|
||||||
|
"class props shape torch.Size([3, 5, 128, 512])\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"ename": "RuntimeError",
|
||||||
|
"evalue": "The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||||
|
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
||||||
|
"Cell \u001b[0;32mIn[10], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdetector\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain_dataloader\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
|
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:538\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 536\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m 537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 538\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 539\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m 540\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||||
|
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:47\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 47\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 49\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n\u001b[1;32m 50\u001b[0m _call_teardown_hook(trainer)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:574\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m 567\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 568\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m 569\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m 570\u001b[0m ckpt_path,\n\u001b[1;32m 571\u001b[0m model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m 572\u001b[0m model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 573\u001b[0m )\n\u001b[0;32m--> 574\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 576\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n\u001b[1;32m 577\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:981\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m 976\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_signal_connector\u001b[38;5;241m.\u001b[39mregister_signal_handlers()\n\u001b[1;32m 978\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 979\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m 980\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 981\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 983\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 984\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m 985\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m 986\u001b[0m log\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: trainer tearing down\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/trainer.py:1025\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1023\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_run_sanity_check()\n\u001b[1;32m 1024\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1025\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1026\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1027\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected state \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:205\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 205\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 206\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:363\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_epoch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 362\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 363\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:140\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdone:\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n\u001b[1;32m 142\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_restarting \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/training_epoch_loop.py:250\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrun_training_batch\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[1;32m 248\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m 249\u001b[0m \u001b[38;5;66;03m# in automatic optimization, there can only be one optimizer\u001b[39;00m\n\u001b[0;32m--> 250\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 251\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 252\u001b[0m batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmanual_optimization\u001b[38;5;241m.\u001b[39mrun(kwargs)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:190\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m 183\u001b[0m closure()\n\u001b[1;32m 185\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m 188\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated gradients\u001b[39;00m\n\u001b[1;32m 189\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 190\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 192\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n\u001b[1;32m 193\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m result\u001b[38;5;241m.\u001b[39mloss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:268\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_ready()\n\u001b[1;32m 267\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 268\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 269\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 270\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 271\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 272\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 273\u001b[0m \u001b[43m \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 274\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 275\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 277\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n\u001b[1;32m 278\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moptim_progress\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mstep\u001b[38;5;241m.\u001b[39mincrement_completed()\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:167\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m 164\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m hook_name\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 167\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 170\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/module.py:1306\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m 1275\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21moptimizer_step\u001b[39m(\n\u001b[1;32m 1276\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 1277\u001b[0m epoch: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1280\u001b[0m optimizer_closure: Optional[Callable[[], Any]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1281\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 1282\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m 1283\u001b[0m \u001b[38;5;124;03m the optimizer.\u001b[39;00m\n\u001b[1;32m 1284\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1304\u001b[0m \n\u001b[1;32m 1305\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1306\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/core/optimizer.py:153\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m MisconfigurationException(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWhen `optimizer.step(closure)` is called, the closure should be callable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 152\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 153\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n\u001b[1;32m 157\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m step_output\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:238\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;66;03m# TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed\u001b[39;00m\n\u001b[1;32m 237\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 238\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:122\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Hook to run the optimizer step.\"\"\"\u001b[39;00m\n\u001b[1;32m 121\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 122\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/lr_scheduler.py:130\u001b[0m, in \u001b[0;36mLRScheduler.__init__.<locals>.patch_track_step_called.<locals>.wrap_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 128\u001b[0m opt \u001b[38;5;241m=\u001b[39m opt_ref()\n\u001b[1;32m 129\u001b[0m opt\u001b[38;5;241m.\u001b[39m_opt_called \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m \u001b[38;5;66;03m# type: ignore[union-attr]\u001b[39;00m\n\u001b[0;32m--> 130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__get__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mopt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mopt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;18;43m__class__\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:484\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 480\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m 481\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 482\u001b[0m )\n\u001b[0;32m--> 484\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 485\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n\u001b[1;32m 487\u001b[0m \u001b[38;5;66;03m# call optimizer step post hooks\u001b[39;00m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/optimizer.py:89\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 87\u001b[0m torch\u001b[38;5;241m.\u001b[39mset_grad_enabled(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefaults[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdifferentiable\u001b[39m\u001b[38;5;124m\"\u001b[39m])\n\u001b[1;32m 88\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 89\u001b[0m ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 90\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 91\u001b[0m torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/optim/adam.py:205\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m closure \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 205\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 207\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n\u001b[1;32m 208\u001b[0m params_with_grad: List[Tensor] \u001b[38;5;241m=\u001b[39m []\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/plugins/precision/precision.py:108\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, closure)\u001b[0m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_wrap_closure\u001b[39m(\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 97\u001b[0m model: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpl.LightningModule\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 98\u001b[0m optimizer: Steppable,\n\u001b[1;32m 99\u001b[0m closure: Callable[[], Any],\n\u001b[1;32m 100\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m 101\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m 102\u001b[0m \u001b[38;5;124;03m hook is called.\u001b[39;00m\n\u001b[1;32m 103\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 106\u001b[0m \n\u001b[1;32m 107\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 108\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 109\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n\u001b[1;32m 110\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m closure_result\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:144\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m 114\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecorate_context\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:129\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 126\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m 127\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m 128\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 129\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 131\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 132\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mwarning_cache\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`training_step` returned `None`. If this was on purpose, ignore this warning...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/loops/optimization/automatic.py:317\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m 306\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Performs the actual train step with the tied hooks.\u001b[39;00m\n\u001b[1;32m 307\u001b[0m \n\u001b[1;32m 308\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 313\u001b[0m \n\u001b[1;32m 314\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 315\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[0;32m--> 317\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step() \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n\u001b[1;32m 320\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m training_step_output \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mworld_size \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/call.py:319\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 318\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 319\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 321\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n\u001b[1;32m 322\u001b[0m pl_module\u001b[38;5;241m.\u001b[39m_current_fx_name \u001b[38;5;241m=\u001b[39m prev_fx_name\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/.venv/lib/python3.12/site-packages/pytorch_lightning/strategies/strategy.py:390\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module:\n\u001b[1;32m 389\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 390\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:167\u001b[0m, in \u001b[0;36mDetectorModel.training_step\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtraining_step\u001b[39m(\u001b[38;5;28mself\u001b[39m, batch: TrainExample):\n\u001b[1;32m 166\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mforward(batch\u001b[38;5;241m.\u001b[39mspec)\n\u001b[0;32m--> 167\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43moutputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/modules.py:150\u001b[0m, in \u001b[0;36mDetectorModel.compute_loss\u001b[0;34m(self, outputs, batch)\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mclass props shape\u001b[39m\u001b[38;5;124m\"\u001b[39m, outputs\u001b[38;5;241m.\u001b[39mclass_probs\u001b[38;5;241m.\u001b[39mshape)\n\u001b[1;32m 149\u001b[0m valid_mask \u001b[38;5;241m=\u001b[39m batch\u001b[38;5;241m.\u001b[39mclass_heatmap\u001b[38;5;241m.\u001b[39many(dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, keepdim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m--> 150\u001b[0m classification_loss \u001b[38;5;241m=\u001b[39m \u001b[43mlosses\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal_loss\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 151\u001b[0m \u001b[43m \u001b[49m\u001b[43moutputs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_probs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 152\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_heatmap\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 153\u001b[0m \u001b[43m \u001b[49m\u001b[43mweights\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclass_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 154\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalid_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalid_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 155\u001b[0m \u001b[43m \u001b[49m\u001b[43mbeta\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbeta\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 156\u001b[0m \u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mconf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclassification\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfocal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43malpha\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 157\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 159\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 160\u001b[0m detection_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mdetection\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 161\u001b[0m \u001b[38;5;241m+\u001b[39m size_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39msize\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 162\u001b[0m \u001b[38;5;241m+\u001b[39m classification_loss \u001b[38;5;241m*\u001b[39m conf\u001b[38;5;241m.\u001b[39mclassification\u001b[38;5;241m.\u001b[39mweight\n\u001b[1;32m 163\u001b[0m )\n",
"File \u001b[0;32m~/Software/bat_detectors/batdetect2/batdetect2/train/losses.py:38\u001b[0m, in \u001b[0;36mfocal_loss\u001b[0;34m(pred, gt, weights, valid_mask, eps, beta, alpha)\u001b[0m\n\u001b[1;32m 35\u001b[0m pos_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39meq(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[1;32m 36\u001b[0m neg_inds \u001b[38;5;241m=\u001b[39m gt\u001b[38;5;241m.\u001b[39mlt(\u001b[38;5;241m1\u001b[39m)\u001b[38;5;241m.\u001b[39mfloat()\n\u001b[0;32m---> 38\u001b[0m pos_loss \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpred\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43meps\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpow\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpred\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43malpha\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mpos_inds\u001b[49m\n\u001b[1;32m 39\u001b[0m neg_loss \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 40\u001b[0m torch\u001b[38;5;241m.\u001b[39mlog(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m pred \u001b[38;5;241m+\u001b[39m eps)\n\u001b[1;32m 41\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(pred, alpha)\n\u001b[1;32m 42\u001b[0m \u001b[38;5;241m*\u001b[39m torch\u001b[38;5;241m.\u001b[39mpow(\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m-\u001b[39m gt, beta)\n\u001b[1;32m 43\u001b[0m \u001b[38;5;241m*\u001b[39m neg_inds\n\u001b[1;32m 44\u001b[0m )\n\u001b[1;32m 46\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m weights \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"\u001b[0;31mRuntimeError\u001b[0m: The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1"
]
}
],
"source": [
"trainer.fit(detector, train_dataloaders=train_dataloader)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f6924db-e520-49a1-bbe8-6c4956e46314",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.811729Z",
"iopub.status.idle": "2024-11-19T17:33:10.811955Z",
"shell.execute_reply": "2024-11-19T17:33:10.811858Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.811849Z"
}
},
"outputs": [],
"source": [
"clip_annotation = train_dataset.get_clip_annotation(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "23943e13-6875-49b8-9f18-2ba6528aa673",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.812924Z",
"iopub.status.idle": "2024-11-19T17:33:10.813260Z",
"shell.execute_reply": "2024-11-19T17:33:10.813104Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.813087Z"
}
},
"outputs": [],
"source": [
"spec = detector.compute_spectrogram(clip_annotation.clip)\n",
"outputs = detector(torch.tensor(spec.values).unsqueeze(0).unsqueeze(0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd1fe346-0873-4b14-ae1b-92ef1f4f27a5",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.814343Z",
"iopub.status.idle": "2024-11-19T17:33:10.814806Z",
"shell.execute_reply": "2024-11-19T17:33:10.814628Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.814611Z"
}
},
"outputs": [],
"source": [
"_, ax = plt.subplots(figsize=(15, 5))\n",
"spec.plot(ax=ax, add_colorbar=False)\n",
"ax.pcolormesh(spec.time, spec.frequency, outputs.detection_probs.detach().squeeze())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eadd36ef-a04a-4665-b703-cec84cf1673b",
"metadata": {
"execution": {
"iopub.status.busy": "2024-11-19T17:33:10.815603Z",
"iopub.status.idle": "2024-11-19T17:33:10.816065Z",
"shell.execute_reply": "2024-11-19T17:33:10.815894Z",
"shell.execute_reply.started": "2024-11-19T17:33:10.815877Z"
}
},
"outputs": [],
"source": [
"print(f\"Num predicted soundevents: {len(predictions.sound_events)}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4e54f3e-6ddc-4fe5-8ce0-b527ff6f18ae",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "batdetect2-dev",
"language": "python",
"name": "batdetect2-dev"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
194
notebooks/data.py
Normal file
194
notebooks/data.py
Normal file
@ -0,0 +1,194 @@
import marimo

__generated_with = "0.14.16"
app = marimo.App(width="medium")


@app.cell
def _():
    import marimo as mo
    return (mo,)


@app.cell
def _():
    from batdetect2.data import (
        load_dataset_config,
        load_dataset,
        extract_recordings_df,
        extract_sound_events_df,
        compute_class_summary,
    )
    return (
        compute_class_summary,
        extract_recordings_df,
        extract_sound_events_df,
        load_dataset,
        load_dataset_config,
    )


@app.cell
def _(mo):
    dataset_config_browser = mo.ui.file_browser(
        selection_mode="file",
        multiple=False,
    )
    dataset_config_browser
    return (dataset_config_browser,)


@app.cell
def _(dataset_config_browser, load_dataset_config, mo):
    mo.stop(dataset_config_browser.path() is None)
    dataset_config = load_dataset_config(dataset_config_browser.path())
    return (dataset_config,)


@app.cell
def _(dataset_config, load_dataset):
    dataset = load_dataset(dataset_config, base_dir="../paper/")
    return (dataset,)


@app.cell
def _():
    from batdetect2.targets import load_target_config, build_targets
    return build_targets, load_target_config


@app.cell
def _(mo):
    targets_config_browser = mo.ui.file_browser(
        selection_mode="file",
        multiple=False,
    )
    targets_config_browser
    return (targets_config_browser,)


@app.cell
def _(load_target_config, mo, targets_config_browser):
    mo.stop(targets_config_browser.path() is None)
    targets_config = load_target_config(targets_config_browser.path())
    return (targets_config,)


@app.cell
def _(build_targets, targets_config):
    targets = build_targets(targets_config)
    return (targets,)


@app.cell
def _():
    import pandas as pd
    from soundevent.geometry import compute_bounds
    return


@app.cell
def _(dataset, extract_recordings_df):
    recordings = extract_recordings_df(dataset)
    recordings
    return


@app.cell
def _(dataset, extract_sound_events_df, targets):
    sound_events = extract_sound_events_df(dataset, targets)
    sound_events
    return


@app.cell
def _(compute_class_summary, dataset, targets):
    compute_class_summary(dataset, targets)
    return


@app.cell
def _():
    from batdetect2.data.split import split_dataset_by_recordings
    return (split_dataset_by_recordings,)


@app.cell
def _(dataset, split_dataset_by_recordings, targets):
    train_dataset, val_dataset = split_dataset_by_recordings(dataset, targets, random_state=42)
    return train_dataset, val_dataset


@app.cell
def _(compute_class_summary, targets, train_dataset):
    compute_class_summary(train_dataset, targets)
    return


@app.cell
def _(compute_class_summary, targets, val_dataset):
    compute_class_summary(val_dataset, targets)
    return


@app.cell
def _():
    from soundevent import io, data
    from pathlib import Path
    return Path, data, io


@app.cell
def _(Path, data, io, train_dataset):
    io.save(
        data.AnnotationSet(
            name="batdetect2_tuning_train",
            description="Set of annotations used as the train dataset for the hyper-parameter tuning stage.",
            clip_annotations=train_dataset,
        ),
        Path("../paper/data/datasets/annotation_sets/tuning_train.json"),
        audio_dir=Path("../paper/data/datasets/"),
    )
    return


@app.cell
def _(Path, data, io, val_dataset):
    io.save(
        data.AnnotationSet(
            name="batdetect2_tuning_val",
            description="Set of annotations used as the validation dataset for the hyper-parameter tuning stage.",
            clip_annotations=val_dataset,
        ),
        Path("../paper/data/datasets/annotation_sets/tuning_val.json"),
        audio_dir=Path("../paper/data/datasets/"),
    )
    return


@app.cell
def _(load_dataset, load_dataset_config):
    config = load_dataset_config("../paper/conf/datasets/train/uk_tune.yaml")
    rec = load_dataset(config, base_dir="../paper/")
    return (rec,)


@app.cell
def _(rec):
    dict(rec[0].sound_events[0].tags[0].term)
    return


@app.cell
def _(compute_class_summary, rec, targets):
    compute_class_summary(rec, targets)
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()
273
notebooks/plotting.py
Normal file
273
notebooks/plotting.py
Normal file
@ -0,0 +1,273 @@
import marimo

__generated_with = "0.14.16"
app = marimo.App(width="medium")


@app.cell
def _():
    import marimo as mo
    return


@app.cell
def _():
    from batdetect2.data import load_dataset_config, load_dataset
    from batdetect2.preprocess import load_preprocessing_config, build_preprocessor
    from batdetect2 import api
    from soundevent import data
    from batdetect2.evaluate.types import MatchEvaluation
    from batdetect2.types import Annotation
    from batdetect2.compat import annotation_to_sound_event_prediction
    from batdetect2.plotting import (
        plot_clip,
        plot_clip_annotation,
        plot_clip_prediction,
        plot_matches,
        plot_false_positive_match,
        plot_false_negative_match,
        plot_true_positive_match,
        plot_cross_trigger_match,
    )
    return (
        MatchEvaluation,
        annotation_to_sound_event_prediction,
        api,
        build_preprocessor,
        data,
        load_dataset,
        load_dataset_config,
        load_preprocessing_config,
        plot_clip_annotation,
        plot_clip_prediction,
        plot_cross_trigger_match,
        plot_false_negative_match,
        plot_false_positive_match,
        plot_matches,
        plot_true_positive_match,
    )


@app.cell
def _(build_preprocessor, load_dataset_config, load_preprocessing_config):
    dataset_config = load_dataset_config(
        path="example_data/config.yaml", field="datasets.train"
    )

    preprocessor_config = load_preprocessing_config(
        path="example_data/config.yaml", field="preprocess"
    )

    preprocessor = build_preprocessor(preprocessor_config)
    return dataset_config, preprocessor


@app.cell
def _(dataset_config, load_dataset):
    dataset = load_dataset(dataset_config)
    return (dataset,)


@app.cell
def _(dataset):
    clip_annotation = dataset[1]
    return (clip_annotation,)


@app.cell
def _(clip_annotation, plot_clip_annotation, preprocessor):
    plot_clip_annotation(
        clip_annotation, preprocessor=preprocessor, figsize=(15, 5)
    )
    return


@app.cell
def _(annotation_to_sound_event_prediction, api, clip_annotation, data):
    audio = api.load_audio(clip_annotation.clip.recording.path)
    detections, features, spec = api.process_audio(audio)
    clip_prediction = data.ClipPrediction(
        clip=clip_annotation.clip,
        sound_events=[
            annotation_to_sound_event_prediction(
                prediction, clip_annotation.clip.recording
            )
            for prediction in detections
        ],
    )
    return (clip_prediction,)


@app.cell
def _(clip_prediction, plot_clip_prediction):
    plot_clip_prediction(clip_prediction, figsize=(15, 5))
    return


@app.cell
def _():
    from batdetect2.evaluate import match_predictions_and_annotations
    import random
    return match_predictions_and_annotations, random


@app.cell
def _(data, random):
    def add_noise(clip_annotation, time_buffer=0.003, freq_buffer=1000):
        def _add_bbox_noise(bbox):
            start_time, low_freq, end_time, high_freq = bbox.coordinates
            return data.BoundingBox(
                coordinates=[
                    start_time + random.uniform(-time_buffer, time_buffer),
                    low_freq + random.uniform(-freq_buffer, freq_buffer),
                    end_time + random.uniform(-time_buffer, time_buffer),
                    high_freq + random.uniform(-freq_buffer, freq_buffer),
                ]
            )

        def _add_noise(se):
            return se.model_copy(
                update=dict(
                    sound_event=se.sound_event.model_copy(
                        update=dict(
                            geometry=_add_bbox_noise(se.sound_event.geometry)
                        )
                    )
                )
            )

        return clip_annotation.model_copy(
            update=dict(
                sound_events=[
                    _add_noise(se) for se in clip_annotation.sound_events
                ]
            )
        )

    def drop_random(obj, p=0.5):
        return obj.model_copy(
            update=dict(
                sound_events=[se for se in obj.sound_events if random.random() > p]
            )
        )
    return add_noise, drop_random


@app.cell
def _(
    add_noise,
    clip_annotation,
    clip_prediction,
    drop_random,
    match_predictions_and_annotations,
):
    matches = match_predictions_and_annotations(
        drop_random(add_noise(clip_annotation), p=0.2),
        drop_random(clip_prediction),
    )
    return (matches,)


@app.cell
def _(clip_annotation, matches, plot_matches):
    plot_matches(matches, clip_annotation.clip, figsize=(15, 5))
    return


@app.cell
def _(matches):
    true_positives = []
    false_positives = []
    false_negatives = []

    for match in matches:
        if match.source is None and match.target is not None:
            false_negatives.append(match)
        elif match.target is None and match.source is not None:
            false_positives.append(match)
        elif match.target is not None and match.source is not None:
            true_positives.append(match)
        else:
            continue

    return false_negatives, false_positives, true_positives


@app.cell
def _(MatchEvaluation, false_positives, plot_false_positive_match):
    false_positive = false_positives[0]
    false_positive_eval = MatchEvaluation(
        match=false_positive,
        gt_det=False,
        gt_class=None,
        pred_score=false_positive.source.score,
        pred_class_scores={
            "myomyo": 0.2
        }
    )

    plot_false_positive_match(false_positive_eval)
    return


@app.cell
def _(MatchEvaluation, false_negatives, plot_false_negative_match):
    false_negative = false_negatives[0]
    false_negative_eval = MatchEvaluation(
        match=false_negative,
        gt_det=True,
        gt_class="myomyo",
        pred_score=None,
        pred_class_scores={}
    )

    plot_false_negative_match(false_negative_eval)
    return


@app.cell
def _(MatchEvaluation, plot_true_positive_match, true_positives):
    true_positive = true_positives[0]
    true_positive_eval = MatchEvaluation(
        match=true_positive,
        gt_det=True,
        gt_class="myomyo",
        pred_score=0.87,
        pred_class_scores={
            "pyomyo": 0.84,
            "pippip": 0.84,
        }
    )

    plot_true_positive_match(true_positive_eval)
    return (true_positive,)


@app.cell
def _(MatchEvaluation, plot_cross_trigger_match, true_positive):
    cross_trigger_eval = MatchEvaluation(
        match=true_positive,
        gt_det=True,
        gt_class="myomyo",
        pred_score=0.87,
        pred_class_scores={
            "pippip": 0.84,
            "myomyo": 0.84,
        }
    )

    plot_cross_trigger_match(cross_trigger_eval)
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()
19
notebooks/signal_generation.py
Normal file
19
notebooks/signal_generation.py
Normal file
@ -0,0 +1,19 @@
import marimo

__generated_with = "0.13.15"
app = marimo.App(width="medium")


@app.cell
def _():
    from batdetect2.preprocess import build_preprocessor
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()
97
notebooks/targets.py
Normal file
97
notebooks/targets.py
Normal file
@ -0,0 +1,97 @@
import marimo

__generated_with = "0.14.5"
app = marimo.App(width="medium")


@app.cell
def _():
    import marimo as mo
    return (mo,)


@app.cell
def _():
    from batdetect2.data import load_dataset, load_dataset_config
    return load_dataset, load_dataset_config


@app.cell
def _(mo):
    dataset_input = mo.ui.file_browser(label="dataset config file")
    dataset_input
    return (dataset_input,)


@app.cell
def _(mo):
    audio_dir_input = mo.ui.file_browser(
        label="audio directory", selection_mode="directory"
    )
    audio_dir_input
    return (audio_dir_input,)


@app.cell
def _(dataset_input, load_dataset_config):
    dataset_config = load_dataset_config(
        path=dataset_input.path(), field="datasets.train",
    )
    return (dataset_config,)


@app.cell
def _(audio_dir_input, dataset_config, load_dataset):
    dataset = load_dataset(dataset_config, base_dir=audio_dir_input.path())
    return (dataset,)


@app.cell
def _(dataset):
    len(dataset)
    return


@app.cell
def _(dataset):
    tag_groups = [
        se.tags
        for clip in dataset
        for se in clip.sound_events
        if se.tags
    ]
    all_tags = [
        tag for group in tag_groups for tag in group
    ]
    return (all_tags,)


@app.cell
def _(mo):
    key_search = mo.ui.text(label="key", debounce=0.1)
    value_search = mo.ui.text(label="value", debounce=0.1)
    mo.hstack([key_search, value_search]).left()
    return key_search, value_search


@app.cell
def _(all_tags, key_search, mo, value_search):
    filtered_tags = list(set(all_tags))

    if key_search.value:
        filtered_tags = [tag for tag in filtered_tags if key_search.value.lower() in tag.key.lower()]

    if value_search.value:
        filtered_tags = [tag for tag in filtered_tags if value_search.value.lower() in tag.value.lower()]

    mo.vstack([mo.md(f"key={tag.key} value={tag.value}") for tag in filtered_tags[:5]])
    return


@app.cell
def _():
    return


if __name__ == "__main__":
    app.run()
@ -25,14 +25,14 @@ dependencies = [
|
|||||||
"scikit-learn>=1.2.2",
|
"scikit-learn>=1.2.2",
|
||||||
"scipy>=1.10.1",
|
"scipy>=1.10.1",
|
||||||
"seaborn>=0.13.2",
|
"seaborn>=0.13.2",
|
||||||
"soundevent[audio,geometry,plot]>=2.10.0",
|
"soundevent[audio,geometry,plot]>=2.9.1",
|
||||||
"tensorboard>=2.16.2",
|
"tensorboard>=2.16.2",
|
||||||
"torch>=1.13.1",
|
"torch>=1.13.1",
|
||||||
"torchaudio>=1.13.1",
|
"torchaudio>=1.13.1",
|
||||||
"torchvision>=0.14.0",
|
"torchvision>=0.14.0",
|
||||||
"tqdm>=4.66.2",
|
"tqdm>=4.66.2",
|
||||||
]
|
]
|
||||||
requires-python = ">=3.10,<3.14"
|
requires-python = ">=3.9,<3.13"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { text = "CC-by-nc-4" }
|
license = { text = "CC-by-nc-4" }
|
||||||
classifiers = [
|
classifiers = [
|
||||||
@ -75,6 +75,7 @@ dev = [
|
|||||||
"ruff>=0.7.3",
|
"ruff>=0.7.3",
|
||||||
"ipykernel>=6.29.4",
|
"ipykernel>=6.29.4",
|
||||||
"setuptools>=69.5.1",
|
"setuptools>=69.5.1",
|
||||||
|
"basedpyright>=1.31.0",
|
||||||
"myst-parser>=3.0.1",
|
"myst-parser>=3.0.1",
|
||||||
"sphinx-autobuild>=2024.10.3",
|
"sphinx-autobuild>=2024.10.3",
|
||||||
"numpydoc>=1.8.0",
|
"numpydoc>=1.8.0",
|
||||||
@ -86,24 +87,13 @@ dev = [
|
|||||||
"rust-just>=1.40.0",
|
"rust-just>=1.40.0",
|
||||||
"pandas-stubs>=2.2.2.240807",
|
"pandas-stubs>=2.2.2.240807",
|
||||||
"python-lsp-server>=1.13.0",
|
"python-lsp-server>=1.13.0",
|
||||||
"deepdiff>=8.6.1",
|
|
||||||
]
|
]
|
||||||
dvclive = ["dvclive>=3.48.2"]
|
dvclive = ["dvclive>=3.48.2"]
|
||||||
mlflow = ["mlflow>=3.1.1"]
|
mlflow = ["mlflow>=3.1.1"]
|
||||||
gradio = [
|
|
||||||
"gradio>=6.9.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 79
|
line-length = 79
|
||||||
target-version = "py310"
|
target-version = "py39"
|
||||||
exclude = [
|
|
||||||
"src/batdetect2/train/legacy",
|
|
||||||
"src/batdetect2/plotting/legacy",
|
|
||||||
"src/batdetect2/evaluate/legacy",
|
|
||||||
"src/batdetect2/finetune",
|
|
||||||
"src/batdetect2/utils",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
docstring-code-format = true
|
docstring-code-format = true
|
||||||
@ -115,12 +105,15 @@ select = ["E4", "E7", "E9", "F", "B", "Q", "I", "NPY201"]
|
|||||||
[tool.ruff.lint.pydocstyle]
|
[tool.ruff.lint.pydocstyle]
|
||||||
convention = "numpy"
|
convention = "numpy"
|
||||||
|
|
||||||
[tool.ty.src]
|
[tool.pyright]
|
||||||
include = ["src", "tests"]
|
include = ["src", "tests"]
|
||||||
|
pythonVersion = "3.9"
|
||||||
|
pythonPlatform = "All"
|
||||||
exclude = [
|
exclude = [
|
||||||
"src/batdetect2/train/legacy",
|
"src/batdetect2/detector/",
|
||||||
"src/batdetect2/plotting/legacy",
|
|
||||||
"src/batdetect2/evaluate/legacy",
|
|
||||||
"src/batdetect2/finetune",
|
"src/batdetect2/finetune",
|
||||||
"src/batdetect2/utils",
|
"src/batdetect2/utils",
|
||||||
|
"src/batdetect2/plot",
|
||||||
|
"src/batdetect2/evaluate/legacy",
|
||||||
|
"src/batdetect2/train/legacy",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -98,6 +98,7 @@ consult the API documentation in the code.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -164,7 +165,7 @@ def load_audio(
|
|||||||
time_exp_fact: float = 1,
|
time_exp_fact: float = 1,
|
||||||
target_samp_rate: int = TARGET_SAMPLERATE_HZ,
|
target_samp_rate: int = TARGET_SAMPLERATE_HZ,
|
||||||
scale: bool = False,
|
scale: bool = False,
|
||||||
max_duration: float | None = None,
|
max_duration: Optional[float] = None,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""Load audio from file.
|
"""Load audio from file.
|
||||||
|
|
||||||
@@ -202,7 +203,7 @@ def load_audio(
 def generate_spectrogram(
     audio: np.ndarray,
     samp_rate: int = TARGET_SAMPLERATE_HZ,
-    config: SpectrogramParameters | None = None,
+    config: Optional[SpectrogramParameters] = None,
     device: torch.device = DEVICE,
 ) -> torch.Tensor:
     """Generate spectrogram from audio array.
@@ -239,7 +240,7 @@ def generate_spectrogram(
 def process_file(
     audio_file: str,
     model: DetectionModel = MODEL,
-    config: ProcessingConfiguration | None = None,
+    config: Optional[ProcessingConfiguration] = None,
     device: torch.device = DEVICE,
 ) -> du.RunResults:
     """Process audio file with model.
@@ -270,8 +271,8 @@ def process_spectrogram(
     spec: torch.Tensor,
     samp_rate: int = TARGET_SAMPLERATE_HZ,
     model: DetectionModel = MODEL,
-    config: ProcessingConfiguration | None = None,
-) -> tuple[list[Annotation], np.ndarray]:
+    config: Optional[ProcessingConfiguration] = None,
+) -> Tuple[List[Annotation], np.ndarray]:
     """Process spectrogram with model.
 
     Parameters
@@ -311,9 +312,9 @@ def process_audio(
     audio: np.ndarray,
     samp_rate: int = TARGET_SAMPLERATE_HZ,
     model: DetectionModel = MODEL,
-    config: ProcessingConfiguration | None = None,
+    config: Optional[ProcessingConfiguration] = None,
     device: torch.device = DEVICE,
-) -> tuple[list[Annotation], np.ndarray, torch.Tensor]:
+) -> Tuple[List[Annotation], np.ndarray, torch.Tensor]:
     """Process audio array with model.
 
     Parameters
@@ -355,8 +356,8 @@ def process_audio(
 def postprocess(
     outputs: ModelOutput,
     samp_rate: int = TARGET_SAMPLERATE_HZ,
-    config: ProcessingConfiguration | None = None,
-) -> tuple[list[Annotation], np.ndarray]:
+    config: Optional[ProcessingConfiguration] = None,
+) -> Tuple[List[Annotation], np.ndarray]:
     """Postprocess model outputs.
 
     Convert model tensor outputs to predicted bounding boxes and
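The recurring pattern in the hunks above swaps PEP 604 union syntax (`float | None`) and builtin generics (`tuple[list[...]]`) for their `typing`-module equivalents, consistent with the `pythonVersion = "3.9"` set for pyright in `pyproject.toml`: the `|` and `tuple[...]` spellings only work as runtime annotations on Python 3.10+/3.9+. A minimal sketch of the equivalence (the `load` function here is illustrative only, not part of the batdetect2 API):

```python
from typing import List, Optional, Tuple, Union

# `Optional[float]` is sugar for `Union[float, None]`; on Python 3.10+
# the same annotation can be written `float | None` (PEP 604).
def load(max_duration: Optional[float] = None) -> Tuple[List[int], float]:
    duration = max_duration if max_duration is not None else 1.0
    return [], duration
```

Rewriting annotations this way keeps the module importable on Python 3.9 without changing behaviour.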
@@ -1,90 +1,59 @@
+import json
 from pathlib import Path
-from typing import Literal, Sequence, cast
+from typing import Dict, List, Optional, Sequence, Tuple
 
 import numpy as np
 import torch
 from soundevent import data
 from soundevent.audio.files import get_audio_files
 
-from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
+from batdetect2.audio import build_audio_loader
 from batdetect2.config import BatDetect2Config
-from batdetect2.data import Dataset, load_dataset_from_config
-from batdetect2.evaluate import (
-    DEFAULT_EVAL_DIR,
-    EvaluationConfig,
-    EvaluatorProtocol,
-    build_evaluator,
-    run_evaluate,
-    save_evaluation_results,
-)
-from batdetect2.inference import (
-    InferenceConfig,
-    process_file_list,
-    run_batch_inference,
-)
-from batdetect2.logging import (
-    DEFAULT_LOGS_DIR,
-    AppLoggingConfig,
-    LoggerConfig,
-)
-from batdetect2.models import (
-    Model,
-    ModelConfig,
-    build_model,
-    build_model_with_new_targets,
-)
-from batdetect2.models.detectors import Detector
-from batdetect2.outputs import (
+from batdetect2.core import merge_configs
+from batdetect2.data import (
     OutputFormatConfig,
-    OutputFormatterProtocol,
-    OutputsConfig,
-    OutputTransformProtocol,
     build_output_formatter,
-    build_output_transform,
     get_output_formatter,
+    load_dataset_from_config,
 )
-from batdetect2.postprocess import (
-    ClipDetections,
-    Detection,
-    PostprocessorProtocol,
-    build_postprocessor,
-)
-from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
-from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
+from batdetect2.data.datasets import Dataset
+from batdetect2.data.predictions.base import OutputFormatterProtocol
+from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
+from batdetect2.inference import process_file_list, run_batch_inference
+from batdetect2.logging import DEFAULT_LOGS_DIR
+from batdetect2.models import Model, build_model
+from batdetect2.postprocess import build_postprocessor, to_raw_predictions
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.targets import build_targets
 from batdetect2.train import (
     DEFAULT_CHECKPOINT_DIR,
-    TrainingConfig,
     load_model_from_checkpoint,
-    run_train,
+    train,
+)
+from batdetect2.typing import (
+    AudioLoader,
+    BatDetect2Prediction,
+    EvaluatorProtocol,
+    PostprocessorProtocol,
+    PreprocessorProtocol,
+    RawPrediction,
+    TargetProtocol,
 )
 
 
 class BatDetect2API:
     def __init__(
         self,
-        model_config: ModelConfig,
-        audio_config: AudioConfig,
-        train_config: TrainingConfig,
-        evaluation_config: EvaluationConfig,
-        inference_config: InferenceConfig,
-        outputs_config: OutputsConfig,
-        logging_config: AppLoggingConfig,
+        config: BatDetect2Config,
         targets: TargetProtocol,
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
         postprocessor: PostprocessorProtocol,
         evaluator: EvaluatorProtocol,
         formatter: OutputFormatterProtocol,
-        output_transform: OutputTransformProtocol,
         model: Model,
     ):
-        self.model_config = model_config
-        self.audio_config = audio_config
-        self.train_config = train_config
-        self.evaluation_config = evaluation_config
-        self.inference_config = inference_config
-        self.outputs_config = outputs_config
-        self.logging_config = logging_config
+        self.config = config
         self.targets = targets
         self.audio_loader = audio_loader
         self.preprocessor = preprocessor
@@ -92,40 +61,34 @@ class BatDetect2API:
         self.evaluator = evaluator
         self.model = model
         self.formatter = formatter
-        self.output_transform = output_transform
 
         self.model.eval()
 
     def load_annotations(
         self,
         path: data.PathLike,
-        base_dir: data.PathLike | None = None,
+        base_dir: Optional[data.PathLike] = None,
     ) -> Dataset:
         return load_dataset_from_config(path, base_dir=base_dir)
 
     def train(
         self,
         train_annotations: Sequence[data.ClipAnnotation],
-        val_annotations: Sequence[data.ClipAnnotation] | None = None,
-        train_workers: int = 0,
-        val_workers: int = 0,
-        checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
-        log_dir: Path | None = DEFAULT_LOGS_DIR,
-        experiment_name: str | None = None,
-        num_epochs: int | None = None,
-        run_name: str | None = None,
-        seed: int | None = None,
-        model_config: ModelConfig | None = None,
-        audio_config: AudioConfig | None = None,
-        train_config: TrainingConfig | None = None,
-        logger_config: LoggerConfig | None = None,
+        val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
+        train_workers: Optional[int] = None,
+        val_workers: Optional[int] = None,
+        checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
+        log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
+        experiment_name: Optional[str] = None,
+        num_epochs: Optional[int] = None,
+        run_name: Optional[str] = None,
+        seed: Optional[int] = None,
     ):
-        run_train(
+        train(
             train_annotations=train_annotations,
             val_annotations=val_annotations,
-            model=self.model,
             targets=self.targets,
-            model_config=model_config or self.model_config,
+            config=self.config,
             audio_loader=self.audio_loader,
             preprocessor=self.preprocessor,
             train_workers=train_workers,
@@ -136,81 +99,25 @@ class BatDetect2API:
             experiment_name=experiment_name,
             run_name=run_name,
             seed=seed,
-            train_config=train_config or self.train_config,
-            audio_config=audio_config or self.audio_config,
-            logger_config=logger_config or self.logging_config.train,
-        )
-        return self
-
-    def finetune(
-        self,
-        train_annotations: Sequence[data.ClipAnnotation],
-        val_annotations: Sequence[data.ClipAnnotation] | None = None,
-        trainable: Literal[
-            "all", "heads", "classifier_head", "bbox_head"
-        ] = "heads",
-        train_workers: int = 0,
-        val_workers: int = 0,
-        checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
-        log_dir: Path | None = DEFAULT_LOGS_DIR,
-        experiment_name: str | None = None,
-        num_epochs: int | None = None,
-        run_name: str | None = None,
-        seed: int | None = None,
-        model_config: ModelConfig | None = None,
-        audio_config: AudioConfig | None = None,
-        train_config: TrainingConfig | None = None,
-        logger_config: LoggerConfig | None = None,
-    ) -> "BatDetect2API":
-        """Fine-tune the model with trainable-parameter selection."""
-
-        self._set_trainable_parameters(trainable)
-
-        run_train(
-            train_annotations=train_annotations,
-            val_annotations=val_annotations,
-            model=self.model,
-            targets=self.targets,
-            model_config=model_config or self.model_config,
-            preprocessor=self.preprocessor,
-            audio_loader=self.audio_loader,
-            train_workers=train_workers,
-            val_workers=val_workers,
-            checkpoint_dir=checkpoint_dir,
-            log_dir=log_dir,
-            experiment_name=experiment_name,
-            num_epochs=num_epochs,
-            run_name=run_name,
-            seed=seed,
-            audio_config=audio_config or self.audio_config,
-            train_config=train_config or self.train_config,
-            logger_config=logger_config or self.logging_config.train,
         )
         return self
 
     def evaluate(
         self,
         test_annotations: Sequence[data.ClipAnnotation],
-        num_workers: int = 0,
+        num_workers: Optional[int] = None,
         output_dir: data.PathLike = DEFAULT_EVAL_DIR,
-        experiment_name: str | None = None,
-        run_name: str | None = None,
+        experiment_name: Optional[str] = None,
+        run_name: Optional[str] = None,
         save_predictions: bool = True,
-        audio_config: AudioConfig | None = None,
-        evaluation_config: EvaluationConfig | None = None,
-        outputs_config: OutputsConfig | None = None,
-        logger_config: LoggerConfig | None = None,
-    ) -> tuple[dict[str, float], list[ClipDetections]]:
-        return run_evaluate(
+    ) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
+        return evaluate(
             self.model,
             test_annotations,
             targets=self.targets,
             audio_loader=self.audio_loader,
             preprocessor=self.preprocessor,
-            audio_config=audio_config or self.audio_config,
-            evaluation_config=evaluation_config or self.evaluation_config,
-            output_config=outputs_config or self.outputs_config,
-            logger_config=logger_config or self.logging_config.evaluation,
+            config=self.config,
             num_workers=num_workers,
             output_dir=output_dir,
             experiment_name=experiment_name,
@@ -221,8 +128,8 @@ class BatDetect2API:
     def evaluate_predictions(
         self,
         annotations: Sequence[data.ClipAnnotation],
-        predictions: Sequence[ClipDetections],
-        output_dir: data.PathLike | None = None,
+        predictions: Sequence[BatDetect2Prediction],
+        output_dir: Optional[data.PathLike] = None,
     ):
         clip_evals = self.evaluator.evaluate(
             annotations,
@@ -232,66 +139,30 @@ class BatDetect2API:
         metrics = self.evaluator.compute_metrics(clip_evals)
 
         if output_dir is not None:
-            save_evaluation_results(
-                metrics=metrics,
-                plots=self.evaluator.generate_plots(clip_evals),
-                output_dir=output_dir,
-            )
+            output_dir = Path(output_dir)
+
+            if not output_dir.is_dir():
+                output_dir.mkdir(parents=True)
+
+            metrics_path = output_dir / "metrics.json"
+            metrics_path.write_text(json.dumps(metrics))
+
+            for figure_name, fig in self.evaluator.generate_plots(clip_evals):
+                fig_path = output_dir / figure_name
+
+                if not fig_path.parent.is_dir():
+                    fig_path.parent.mkdir(parents=True)
+
+                fig.savefig(fig_path)
 
         return metrics
 
     def load_audio(self, path: data.PathLike) -> np.ndarray:
         return self.audio_loader.load_file(path)
 
-    def load_recording(self, recording: data.Recording) -> np.ndarray:
-        return self.audio_loader.load_recording(recording)
-
     def load_clip(self, clip: data.Clip) -> np.ndarray:
         return self.audio_loader.load_clip(clip)
 
-    def get_top_class_name(self, detection: Detection) -> str:
-        """Get highest-confidence class name for one detection."""
-
-        top_index = int(np.argmax(detection.class_scores))
-        return self.targets.class_names[top_index]
-
-    def get_class_scores(
-        self,
-        detection: Detection,
-        *,
-        include_top_class: bool = True,
-        sort_descending: bool = True,
-    ) -> list[tuple[str, float]]:
-        """Get class score list as ``(class_name, score)`` pairs."""
-
-        scores = [
-            (class_name, float(score))
-            for class_name, score in zip(
-                self.targets.class_names,
-                detection.class_scores,
-                strict=True,
-            )
-        ]
-
-        if sort_descending:
-            scores.sort(key=lambda item: item[1], reverse=True)
-
-        if include_top_class:
-            return scores
-
-        top_class_name = self.get_top_class_name(detection)
-        return [
-            (class_name, score)
-            for class_name, score in scores
-            if class_name != top_class_name
-        ]
-
-    @staticmethod
-    def get_detection_features(detection: Detection) -> np.ndarray:
-        """Get extracted feature vector for one detection."""
-
-        return detection.features
-
     def generate_spectrogram(
         self,
         audio: np.ndarray,
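The removed `get_top_class_name` helper reduces to an argmax over the per-class scores. As a dependency-free sketch of that logic (the standalone `top_class_name` function and its parameters are illustrative, mirroring the deleted method rather than any surviving API):

```python
from typing import List, Sequence

def top_class_name(class_names: List[str], class_scores: Sequence[float]) -> str:
    # Same idea as int(np.argmax(detection.class_scores)) in the removed
    # method: the index of the highest score selects the class label.
    top_index = max(range(len(class_scores)), key=lambda i: class_scores[i])
    return class_names[top_index]
```

Note the deleted `get_class_scores` also used `zip(..., strict=True)`, a Python 3.10+ feature, which fits the broader move in this diff toward Python 3.9 compatibility.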
@@ -299,41 +170,24 @@ class BatDetect2API:
         tensor = torch.tensor(audio).unsqueeze(0)
         return self.preprocessor(tensor)
 
-    def process_file(
-        self,
-        audio_file: data.PathLike,
-        batch_size: int | None = None,
-    ) -> ClipDetections:
+    def process_file(self, audio_file: str) -> BatDetect2Prediction:
         recording = data.Recording.from_file(audio_file, compute_hash=False)
-        predictions = self.process_files(
-            [audio_file],
-            batch_size=(
-                batch_size
-                if batch_size is not None
-                else self.inference_config.loader.batch_size
-            ),
-        )
-        detections = [
-            detection
-            for prediction in predictions
-            for detection in prediction.detections
-        ]
-
-        return ClipDetections(
+        wav = self.audio_loader.load_recording(recording)
+        detections = self.process_audio(wav)
+        return BatDetect2Prediction(
             clip=data.Clip(
                 uuid=recording.uuid,
                 recording=recording,
                 start_time=0,
                 end_time=recording.duration,
             ),
-            detections=detections,
+            predictions=detections,
         )
 
     def process_audio(
         self,
         audio: np.ndarray,
-    ) -> list[Detection]:
+    ) -> List[RawPrediction]:
         spec = self.generate_spectrogram(audio)
         return self.process_spectrogram(spec)
 
@@ -341,7 +195,7 @@ class BatDetect2API:
         self,
         spec: torch.Tensor,
         start_time: float = 0,
-    ) -> list[Detection]:
+    ) -> List[RawPrediction]:
         if spec.ndim == 4 and spec.shape[0] > 1:
             raise ValueError("Batched spectrograms not supported.")
 
@@ -350,74 +204,59 @@ class BatDetect2API:
 
         outputs = self.model.detector(spec)
 
-        detections = self.postprocessor(
+        detections = self.model.postprocessor(
             outputs,
+            start_times=[start_time],
         )[0]
-        return self.output_transform.to_detections(
-            detections=detections,
-            start_time=start_time,
-        )
+        return to_raw_predictions(detections.numpy(), targets=self.targets)
 
     def process_directory(
         self,
         audio_dir: data.PathLike,
-    ) -> list[ClipDetections]:
+    ) -> List[BatDetect2Prediction]:
         files = list(get_audio_files(audio_dir))
         return self.process_files(files)
 
     def process_files(
         self,
         audio_files: Sequence[data.PathLike],
-        batch_size: int | None = None,
-        num_workers: int = 0,
-        audio_config: AudioConfig | None = None,
-        inference_config: InferenceConfig | None = None,
-        output_config: OutputsConfig | None = None,
-    ) -> list[ClipDetections]:
+        num_workers: Optional[int] = None,
+    ) -> List[BatDetect2Prediction]:
         return process_file_list(
             self.model,
             audio_files,
+            config=self.config,
             targets=self.targets,
             audio_loader=self.audio_loader,
             preprocessor=self.preprocessor,
-            output_transform=self.output_transform,
-            batch_size=batch_size,
             num_workers=num_workers,
-            audio_config=audio_config or self.audio_config,
-            inference_config=inference_config or self.inference_config,
-            output_config=output_config or self.outputs_config,
         )
 
     def process_clips(
         self,
         clips: Sequence[data.Clip],
-        batch_size: int | None = None,
-        num_workers: int = 0,
-        audio_config: AudioConfig | None = None,
-        inference_config: InferenceConfig | None = None,
-        output_config: OutputsConfig | None = None,
-    ) -> list[ClipDetections]:
+        batch_size: Optional[int] = None,
+        num_workers: Optional[int] = None,
+    ) -> List[BatDetect2Prediction]:
         return run_batch_inference(
             self.model,
             clips,
             targets=self.targets,
             audio_loader=self.audio_loader,
             preprocessor=self.preprocessor,
-            output_transform=self.output_transform,
+            config=self.config,
             batch_size=batch_size,
             num_workers=num_workers,
-            audio_config=audio_config or self.audio_config,
-            inference_config=inference_config or self.inference_config,
-            output_config=output_config or self.outputs_config,
         )
 
     def save_predictions(
         self,
-        predictions: Sequence[ClipDetections],
+        predictions: Sequence[BatDetect2Prediction],
         path: data.PathLike,
-        audio_dir: data.PathLike | None = None,
-        format: str | None = None,
-        config: OutputFormatConfig | None = None,
+        audio_dir: Optional[data.PathLike] = None,
+        format: Optional[str] = None,
+        config: Optional[OutputFormatConfig] = None,
     ):
         formatter = self.formatter
 
@@ -435,78 +274,50 @@ class BatDetect2API:
     def load_predictions(
         self,
         path: data.PathLike,
-        format: str | None = None,
-        config: OutputFormatConfig | None = None,
-    ) -> list[object]:
-        formatter = self.formatter
-
-        if format is not None or config is not None:
-            format = format or config.name  # type: ignore
-            formatter = get_output_formatter(
-                name=format,
-                targets=self.targets,
-                config=config,
-            )
-
-        return formatter.load(path)
+    ) -> List[BatDetect2Prediction]:
+        return self.formatter.load(path)
 
     @classmethod
     def from_config(
         cls,
         config: BatDetect2Config,
-    ) -> "BatDetect2API":
-        targets = build_targets(config=config.model.targets)
+    ):
+        targets = build_targets(config=config.targets)
 
         audio_loader = build_audio_loader(config=config.audio)
 
         preprocessor = build_preprocessor(
             input_samplerate=audio_loader.samplerate,
-            config=config.model.preprocess,
+            config=config.preprocess,
         )
 
         postprocessor = build_postprocessor(
             preprocessor,
-            config=config.model.postprocess,
+            config=config.postprocess,
         )
 
-        formatter = build_output_formatter(
-            targets,
-            config=config.outputs.format,
-        )
-        output_transform = build_output_transform(
-            config=config.outputs.transform,
-            targets=targets,
-        )
+        evaluator = build_evaluator(config=config.evaluation, targets=targets)
 
-        evaluator = build_evaluator(
-            config=config.evaluation,
-            targets=targets,
-            transform=output_transform,
-        )
+        # NOTE: Better to have a separate instance of
+        # preprocessor and postprocessor as these may be moved
+        # to another device.
 
-        # NOTE: Build separate instances of preprocessor and postprocessor
-        # to avoid device mismatch errors
         model = build_model(
             config=config.model,
-            targets=build_targets(config=config.model.targets),
+            targets=targets,
             preprocessor=build_preprocessor(
                 input_samplerate=audio_loader.samplerate,
-                config=config.model.preprocess,
+                config=config.preprocess,
             ),
             postprocessor=build_postprocessor(
                 preprocessor,
-                config=config.model.postprocess,
+                config=config.postprocess,
             ),
         )
 
+        formatter = build_output_formatter(targets, config=config.output)
+
         return cls(
-            model_config=config.model,
-            audio_config=config.audio,
-            train_config=config.train,
-            evaluation_config=config.evaluation,
-            inference_config=config.inference,
-            outputs_config=config.outputs,
-            logging_config=config.logging,
+            config=config,
             targets=targets,
             audio_loader=audio_loader,
             preprocessor=preprocessor,
@@ -514,83 +325,40 @@ class BatDetect2API:
             evaluator=evaluator,
             model=model,
             formatter=formatter,
-            output_transform=output_transform,
         )
 
     @classmethod
     def from_checkpoint(
         cls,
         path: data.PathLike,
-        targets_config: TargetConfig | None = None,
-        audio_config: AudioConfig | None = None,
-        train_config: TrainingConfig | None = None,
-        evaluation_config: EvaluationConfig | None = None,
-        inference_config: InferenceConfig | None = None,
-        outputs_config: OutputsConfig | None = None,
-        logging_config: AppLoggingConfig | None = None,
-    ) -> "BatDetect2API":
-        model, model_config = load_model_from_checkpoint(path)
-
-        audio_config = audio_config or AudioConfig(
-            samplerate=model_config.samplerate,
-        )
-        train_config = train_config or TrainingConfig()
-        evaluation_config = evaluation_config or EvaluationConfig()
-        inference_config = inference_config or InferenceConfig()
-        outputs_config = outputs_config or OutputsConfig()
-        logging_config = logging_config or AppLoggingConfig()
-
-        if (
-            targets_config is not None
-            and targets_config != model_config.targets
-        ):
-            targets = build_targets(config=targets_config)
-            model = build_model_with_new_targets(
-                model=model,
-                targets=targets,
-            )
-            model_config = model_config.model_copy(
-                update={"targets": targets_config}
-            )
+        config: Optional[BatDetect2Config] = None,
+    ):
+        model, stored_config = load_model_from_checkpoint(path)
+
+        config = (
+            merge_configs(stored_config, config) if config else stored_config
+        )
 
-        targets = build_targets(config=model_config.targets)
+        targets = build_targets(config=config.targets)
 
-        audio_loader = build_audio_loader(config=audio_config)
+        audio_loader = build_audio_loader(config=config.audio)
 
         preprocessor = build_preprocessor(
             input_samplerate=audio_loader.samplerate,
-            config=model_config.preprocess,
+            config=config.preprocess,
         )
 
         postprocessor = build_postprocessor(
             preprocessor,
-            config=model_config.postprocess,
+            config=config.postprocess,
         )
 
-        formatter = build_output_formatter(
-            targets,
-            config=outputs_config.format,
-        )
+        evaluator = build_evaluator(config=config.evaluation, targets=targets)
 
-        output_transform = build_output_transform(
-            config=outputs_config.transform,
-            targets=targets,
-        )
-
-        evaluator = build_evaluator(
-            config=evaluation_config,
-            targets=targets,
-            transform=output_transform,
-        )
+        formatter = build_output_formatter(targets, config=config.output)
 
         return cls(
-            model_config=model_config,
-            audio_config=audio_config,
-            train_config=train_config,
-            evaluation_config=evaluation_config,
-            inference_config=inference_config,
-            outputs_config=outputs_config,
-            logging_config=logging_config,
+            config=config,
             targets=targets,
             audio_loader=audio_loader,
             preprocessor=preprocessor,
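In the new `from_checkpoint` above, `merge_configs(stored_config, config)` lets caller-supplied fields override the configuration stored in the checkpoint. The real semantics live in `batdetect2.core` and may differ; purely as an illustration of the override-wins pattern, a recursive merge over plain dicts might look like:

```python
def merge_dicts(base: dict, override: dict) -> dict:
    """Recursively merge `override` into `base`, override values winning.

    Hypothetical stand-in for merge_configs-style behaviour; the actual
    batdetect2.core implementation operates on config objects and may
    merge differently.
    """
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Descend into nested sections so untouched keys survive.
            merged[key] = merge_dicts(merged[key], value)
        else:
            merged[key] = value
    return merged
```

The `if config else stored_config` guard in the diff keeps the checkpoint's configuration untouched when no override is given.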
@ -598,27 +366,4 @@ class BatDetect2API:
|
|||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
model=model,
|
model=model,
|
||||||
formatter=formatter,
|
formatter=formatter,
|
||||||
output_transform=output_transform,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _set_trainable_parameters(
|
|
||||||
self,
|
|
||||||
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
|
|
||||||
) -> None:
|
|
||||||
detector = cast(Detector, self.model.detector)
|
|
||||||
|
|
||||||
for parameter in detector.parameters():
|
|
||||||
parameter.requires_grad = False
|
|
||||||
|
|
||||||
if trainable == "all":
|
|
||||||
for parameter in detector.parameters():
|
|
||||||
parameter.requires_grad = True
|
|
||||||
return
|
|
||||||
|
|
||||||
if trainable in {"heads", "classifier_head"}:
|
|
||||||
for parameter in detector.classifier_head.parameters():
|
|
||||||
parameter.requires_grad = True
|
|
||||||
|
|
||||||
if trainable in {"heads", "bbox_head"}:
|
|
||||||
for parameter in detector.bbox_head.parameters():
|
|
||||||
parameter.requires_grad = True
|
|
||||||
|
|||||||
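The `_set_trainable_parameters` helper removed above implements a common fine-tuning pattern: freeze every detector parameter, then selectively re-enable the classification and bounding-box heads. The following sketch rewrites it as a free function; it is duck-typed (any object whose `.parameters()` yields items carrying a `requires_grad` flag works, e.g. a `torch.nn.Module`), so the torch dependency is not needed for the illustration. The free-function form and the `set_trainable_parameters` name are illustrative, not the project's API.

```python
# Sketch of the removed fine-tuning helper. Assumes `detector` exposes
# torch-style `.parameters()` plus `classifier_head` / `bbox_head`
# sub-modules, as the diff above shows.
from typing import Literal


def set_trainable_parameters(
    detector,
    trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
) -> None:
    # Freeze everything first.
    for parameter in detector.parameters():
        parameter.requires_grad = False

    if trainable == "all":
        for parameter in detector.parameters():
            parameter.requires_grad = True
        return

    # "heads" re-enables both heads; the specific options just one.
    if trainable in {"heads", "classifier_head"}:
        for parameter in detector.classifier_head.parameters():
            parameter.requires_grad = True

    if trainable in {"heads", "bbox_head"}:
        for parameter in detector.bbox_head.parameters():
            parameter.requires_grad = True
```

Freezing everything before selectively unfreezing keeps the logic correct for every `trainable` value without enumerating the complement of each option.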
@@ -5,11 +5,8 @@ from batdetect2.audio.loader import (
     SoundEventAudioLoader,
     build_audio_loader,
 )
-from batdetect2.audio.types import AudioLoader, ClipperProtocol
 
 __all__ = [
-    "AudioLoader",
-    "ClipperProtocol",
     "TARGET_SAMPLERATE_HZ",
     "AudioConfig",
     "SoundEventAudioLoader",
@@ -1,4 +1,4 @@
-from typing import Annotated, List, Literal
+from typing import Annotated, List, Literal, Optional, Union
 
 import numpy as np
 from loguru import logger
@@ -6,13 +6,8 @@ from pydantic import Field
 from soundevent import data
 from soundevent.geometry import compute_bounds, intervals_overlap
 
-from batdetect2.audio.types import ClipperProtocol
-from batdetect2.core import (
-    BaseConfig,
-    ImportConfig,
-    Registry,
-    add_import_config,
-)
+from batdetect2.core import BaseConfig, Registry
+from batdetect2.typing import ClipperProtocol
 
 DEFAULT_TRAIN_CLIP_DURATION = 0.256
 DEFAULT_MAX_EMPTY_CLIP = 0.1
@@ -21,24 +16,12 @@ DEFAULT_MAX_EMPTY_CLIP = 0.1
 __all__ = [
     "build_clipper",
     "ClipConfig",
-    "ClipperImportConfig",
 ]
 
 
 clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")
 
 
-@add_import_config(clipper_registry)
-class ClipperImportConfig(ImportConfig):
-    """Use any callable as a clipper.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class RandomClipConfig(BaseConfig):
     name: Literal["random_subclip"] = "random_subclip"
     duration: float = DEFAULT_TRAIN_CLIP_DURATION
@@ -262,12 +245,16 @@ class FixedDurationClip:
 
 
 ClipConfig = Annotated[
-    RandomClipConfig | PaddedClipConfig | FixedDurationClipConfig,
+    Union[
+        RandomClipConfig,
+        PaddedClipConfig,
+        FixedDurationClipConfig,
+    ],
     Field(discriminator="name"),
 ]
 
 
-def build_clipper(config: ClipConfig | None = None) -> ClipperProtocol:
+def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
     config = config or RandomClipConfig()
 
     logger.opt(lazy=True).debug(
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 from numpy.typing import DTypeLike
 from pydantic import Field
@@ -5,8 +7,8 @@ from scipy.signal import resample, resample_poly
 from soundevent import audio, data
 from soundfile import LibsndfileError
 
-from batdetect2.audio.types import AudioLoader
 from batdetect2.core import BaseConfig
+from batdetect2.typing import AudioLoader
 
 __all__ = [
     "SoundEventAudioLoader",
@@ -26,17 +28,15 @@ class ResampleConfig(BaseConfig):
 
     Attributes
     ----------
-    enabled : bool, default=True
-        Whether to resample the audio to the target sample rate. If
-        ``False``, the audio is returned at its original sample rate.
+    samplerate : int, default=256000
+        The target sample rate in Hz to resample the audio to. Must be > 0.
     method : str, default="poly"
        The resampling algorithm to use. Options:
-        - ``"poly"``: Polyphase resampling via
-          ``scipy.signal.resample_poly``. Generally fast and accurate.
-        - ``"fourier"``: FFT-based resampling via
-          ``scipy.signal.resample``. May be preferred for non-integer
-          resampling ratios.
+        - "poly": Polyphase resampling using `scipy.signal.resample_poly`.
+          Generally fast.
+        - "fourier": Resampling via Fourier method using
+          `scipy.signal.resample`. May handle non-integer
+          resampling factors differently.
     """
 
     enabled: bool = True
@@ -50,7 +50,7 @@ class AudioConfig(BaseConfig):
     resample: ResampleConfig = Field(default_factory=ResampleConfig)
 
 
-def build_audio_loader(config: AudioConfig | None = None) -> AudioLoader:
+def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
     """Factory function to create an AudioLoader based on configuration."""
     config = config or AudioConfig()
     return SoundEventAudioLoader(
@@ -65,7 +65,7 @@ class SoundEventAudioLoader(AudioLoader):
     def __init__(
         self,
         samplerate: int = TARGET_SAMPLERATE_HZ,
-        config: ResampleConfig | None = None,
+        config: Optional[ResampleConfig] = None,
     ):
         self.samplerate = samplerate
         self.config = config or ResampleConfig()
@@ -73,7 +73,7 @@ class SoundEventAudioLoader(AudioLoader):
     def load_file(
         self,
         path: data.PathLike,
-        audio_dir: data.PathLike | None = None,
+        audio_dir: Optional[data.PathLike] = None,
     ) -> np.ndarray:
         """Load and preprocess audio directly from a file path."""
         return load_file_audio(
@@ -86,7 +86,7 @@ class SoundEventAudioLoader(AudioLoader):
     def load_recording(
         self,
         recording: data.Recording,
-        audio_dir: data.PathLike | None = None,
+        audio_dir: Optional[data.PathLike] = None,
     ) -> np.ndarray:
         """Load and preprocess the entire audio for a Recording object."""
         return load_recording_audio(
@@ -99,7 +99,7 @@ class SoundEventAudioLoader(AudioLoader):
     def load_clip(
         self,
         clip: data.Clip,
-        audio_dir: data.PathLike | None = None,
+        audio_dir: Optional[data.PathLike] = None,
     ) -> np.ndarray:
         """Load and preprocess the audio segment defined by a Clip object."""
         return load_clip_audio(
@@ -112,10 +112,10 @@ class SoundEventAudioLoader(AudioLoader):
 
 def load_file_audio(
     path: data.PathLike,
-    samplerate: int | None = None,
-    config: ResampleConfig | None = None,
-    audio_dir: data.PathLike | None = None,
-    dtype: DTypeLike = np.float32,
+    samplerate: Optional[int] = None,
+    config: Optional[ResampleConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
+    dtype: DTypeLike = np.float32,  # type: ignore
 ) -> np.ndarray:
     """Load and preprocess audio from a file path using specified config."""
     try:
@@ -136,10 +136,10 @@ def load_file_audio(
 
 def load_recording_audio(
     recording: data.Recording,
-    samplerate: int | None = None,
-    config: ResampleConfig | None = None,
-    audio_dir: data.PathLike | None = None,
-    dtype: DTypeLike = np.float32,
+    samplerate: Optional[int] = None,
+    config: Optional[ResampleConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
+    dtype: DTypeLike = np.float32,  # type: ignore
 ) -> np.ndarray:
     """Load and preprocess the entire audio content of a recording using config."""
     clip = data.Clip(
@@ -158,10 +158,10 @@ def load_recording_audio(
 
 def load_clip_audio(
     clip: data.Clip,
-    samplerate: int | None = None,
-    config: ResampleConfig | None = None,
-    audio_dir: data.PathLike | None = None,
-    dtype: DTypeLike = np.float32,
+    samplerate: Optional[int] = None,
+    config: Optional[ResampleConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
+    dtype: DTypeLike = np.float32,  # type: ignore
 ) -> np.ndarray:
     """Load and preprocess a specific audio clip segment based on config."""
     try:
@@ -194,31 +194,7 @@ def resample_audio(
     samplerate: int = TARGET_SAMPLERATE_HZ,
     method: str = "poly",
 ) -> np.ndarray:
-    """Resample an audio waveform to a target sample rate.
-
-    Parameters
-    ----------
-    wav : np.ndarray
-        Input waveform array. The last axis is assumed to be time.
-    sr : int
-        Original sample rate of ``wav`` in Hz.
-    samplerate : int, default=256000
-        Target sample rate in Hz.
-    method : str, default="poly"
-        Resampling algorithm: ``"poly"`` (polyphase) or
-        ``"fourier"`` (FFT-based).
-
-    Returns
-    -------
-    np.ndarray
-        Resampled waveform. If ``sr == samplerate`` the input array is
-        returned unchanged.
-
-    Raises
-    ------
-    NotImplementedError
-        If ``method`` is not ``"poly"`` or ``"fourier"``.
-    """
+    """Resample an audio waveform DataArray to a target sample rate."""
     if sr == samplerate:
         return wav
 
@@ -288,7 +264,7 @@ def resample_audio_fourier(
     sr_new: int,
     axis: int = -1,
 ) -> np.ndarray:
-    """Resample a numpy array using ``scipy.signal.resample``.
+    """Resample a numpy array using `scipy.signal.resample`.
 
     This method uses FFTs to resample the signal.
 
@@ -296,20 +272,23 @@ def resample_audio_fourier(
     ----------
     array : np.ndarray
         The input array to resample.
-    sr_orig : int
-        The original sample rate in Hz.
-    sr_new : int
-        The target sample rate in Hz.
+    num : int
+        The desired number of samples in the output array along `axis`.
     axis : int, default=-1
-        The axis of ``array`` along which to resample.
+        The axis of `array` along which to resample.
 
     Returns
    -------
     np.ndarray
-        The array resampled to the target sample rate.
+        The array resampled to have `num` samples along `axis`.
+
+    Raises
+    ------
+    ValueError
+        If `num` is negative.
     """
     ratio = sr_new / sr_orig
-    return resample(
+    return resample(  # type: ignore
         array,
         int(array.shape[axis] * ratio),
         axis=axis,
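The loader module above dispatches between two scipy resampling strategies: `scipy.signal.resample_poly`, which needs reduced integer up/down factors, and `scipy.signal.resample`, which needs the output length `int(array.shape[axis] * ratio)` shown in the diff. A stdlib-only sketch of how those call parameters are derived; `poly_factors` and `fourier_num_samples` are illustrative names, not part of the codebase:

```python
# Parameter derivation for the two resampling paths in the loader.
# Helper names here are hypothetical; only the formulas mirror the diff.
from math import gcd


def poly_factors(sr_orig: int, sr_new: int) -> tuple[int, int]:
    """Reduced (up, down) integers for polyphase resampling."""
    g = gcd(sr_orig, sr_new)
    return sr_new // g, sr_orig // g


def fourier_num_samples(n_samples: int, sr_orig: int, sr_new: int) -> int:
    """Output length for FFT-based resampling, as in the diff:
    int(array.shape[axis] * ratio) with ratio = sr_new / sr_orig."""
    return int(n_samples * (sr_new / sr_orig))


# Resampling 44.1 kHz audio to the 256 kHz target used by BatDetect2:
up, down = poly_factors(44_100, 256_000)  # (2560, 441)
```

Reducing by the GCD keeps the polyphase filter small, which is why `resample_poly` is the fast default; the Fourier path sidesteps the integer-factor requirement entirely at the cost of an FFT over the whole signal.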
@@ -1,40 +0,0 @@
-from typing import Protocol
-
-import numpy as np
-from soundevent import data
-
-__all__ = [
-    "AudioLoader",
-    "ClipperProtocol",
-]
-
-
-class AudioLoader(Protocol):
-    samplerate: int
-
-    def load_file(
-        self,
-        path: data.PathLike,
-        audio_dir: data.PathLike | None = None,
-    ) -> np.ndarray: ...
-
-    def load_recording(
-        self,
-        recording: data.Recording,
-        audio_dir: data.PathLike | None = None,
-    ) -> np.ndarray: ...
-
-    def load_clip(
-        self,
-        clip: data.Clip,
-        audio_dir: data.PathLike | None = None,
-    ) -> np.ndarray: ...
-
-
-class ClipperProtocol(Protocol):
-    def __call__(
-        self,
-        clip_annotation: data.ClipAnnotation,
-    ) -> data.ClipAnnotation: ...
-
-    def get_subclip(self, clip: data.Clip) -> data.Clip: ...
@@ -2,7 +2,6 @@ from batdetect2.cli.base import cli
 from batdetect2.cli.compat import detect
 from batdetect2.cli.data import data
 from batdetect2.cli.evaluate import evaluate_command
-from batdetect2.cli.inference import predict
 from batdetect2.cli.train import train_command
 
 __all__ = [
@@ -11,7 +10,6 @@ __all__ = [
     "data",
     "train_command",
     "evaluate_command",
-    "predict",
 ]
 
 
@@ -1,4 +1,5 @@
 from pathlib import Path
+from typing import Optional
 
 import click
 
@@ -33,9 +34,9 @@ def data(): ...
 )
 def summary(
     dataset_config: Path,
-    field: str | None = None,
-    targets_path: Path | None = None,
-    base_dir: Path | None = None,
+    field: Optional[str] = None,
+    targets_path: Optional[Path] = None,
+    base_dir: Optional[Path] = None,
 ):
     from batdetect2.data import compute_class_summary, load_dataset_from_config
     from batdetect2.targets import load_targets
@@ -82,9 +83,9 @@ def summary(
 )
 def convert(
     dataset_config: Path,
-    field: str | None = None,
+    field: Optional[str] = None,
     output: Path = Path("annotations.json"),
-    base_dir: Path | None = None,
+    base_dir: Optional[Path] = None,
 ):
     """Convert a dataset config file to soundevent format."""
     from soundevent import data, io
@@ -1,4 +1,5 @@
 from pathlib import Path
+from typing import Optional
 
 import click
 from loguru import logger
@@ -12,14 +13,9 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
 
 
 @cli.command(name="evaluate")
-@click.argument("model_path", type=click.Path(exists=True))
+@click.argument("model-path", type=click.Path(exists=True))
 @click.argument("test_dataset", type=click.Path(exists=True))
-@click.option("--targets", "targets_config", type=click.Path(exists=True))
-@click.option("--audio-config", type=click.Path(exists=True))
-@click.option("--evaluation-config", type=click.Path(exists=True))
-@click.option("--inference-config", type=click.Path(exists=True))
-@click.option("--outputs-config", type=click.Path(exists=True))
-@click.option("--logging-config", type=click.Path(exists=True))
+@click.option("--config", "config_path", type=click.Path())
 @click.option("--base-dir", type=click.Path(), default=Path.cwd())
 @click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
 @click.option("--experiment-name", type=str)
@@ -29,25 +25,15 @@ def evaluate_command(
     model_path: Path,
     test_dataset: Path,
     base_dir: Path,
-    targets_config: Path | None,
-    audio_config: Path | None,
-    evaluation_config: Path | None,
-    inference_config: Path | None,
-    outputs_config: Path | None,
-    logging_config: Path | None,
+    config_path: Optional[Path],
     output_dir: Path = DEFAULT_OUTPUT_DIR,
-    num_workers: int = 0,
-    experiment_name: str | None = None,
-    run_name: str | None = None,
+    num_workers: Optional[int] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ):
     from batdetect2.api_v2 import BatDetect2API
-    from batdetect2.audio import AudioConfig
+    from batdetect2.config import load_full_config
     from batdetect2.data import load_dataset_from_config
-    from batdetect2.evaluate import EvaluationConfig
-    from batdetect2.inference import InferenceConfig
-    from batdetect2.logging import AppLoggingConfig
-    from batdetect2.outputs import OutputsConfig
-    from batdetect2.targets import TargetConfig
 
     logger.info("Initiating evaluation process...")
 
@@ -61,44 +47,11 @@ def evaluate_command(
         num_annotations=len(test_annotations),
     )
 
-    target_conf = (
-        TargetConfig.load(targets_config)
-        if targets_config is not None
-        else None
-    )
-    audio_conf = (
-        AudioConfig.load(audio_config) if audio_config is not None else None
-    )
-    eval_conf = (
-        EvaluationConfig.load(evaluation_config)
-        if evaluation_config is not None
-        else None
-    )
-    inference_conf = (
-        InferenceConfig.load(inference_config)
-        if inference_config is not None
-        else None
-    )
-    outputs_conf = (
-        OutputsConfig.load(outputs_config)
-        if outputs_config is not None
-        else None
-    )
-    logging_conf = (
-        AppLoggingConfig.load(logging_config)
-        if logging_config is not None
-        else None
-    )
+    config = None
+    if config_path is not None:
+        config = load_full_config(config_path)
 
-    api = BatDetect2API.from_checkpoint(
-        model_path,
-        targets_config=target_conf,
-        audio_config=audio_conf,
-        evaluation_config=eval_conf,
-        inference_config=inference_conf,
-        outputs_config=outputs_conf,
-        logging_config=logging_conf,
-    )
+    api = BatDetect2API.from_checkpoint(model_path, config=config)
 
     api.evaluate(
         test_annotations,
@@ -1,231 +0,0 @@
-from pathlib import Path
-
-import click
-from loguru import logger
-from soundevent import io
-from soundevent.audio.files import get_audio_files
-
-from batdetect2.cli.base import cli
-
-__all__ = ["predict"]
-
-
-@cli.group(name="predict")
-def predict() -> None:
-    """Run prediction with BatDetect2 API v2."""
-
-
-def _build_api(
-    model_path: Path,
-    audio_config: Path | None,
-    inference_config: Path | None,
-    outputs_config: Path | None,
-    logging_config: Path | None,
-):
-    from batdetect2.api_v2 import BatDetect2API
-    from batdetect2.audio import AudioConfig
-    from batdetect2.inference import InferenceConfig
-    from batdetect2.logging import AppLoggingConfig
-    from batdetect2.outputs import OutputsConfig
-
-    audio_conf = (
-        AudioConfig.load(audio_config) if audio_config is not None else None
-    )
-    inference_conf = (
-        InferenceConfig.load(inference_config)
-        if inference_config is not None
-        else None
-    )
-    outputs_conf = (
-        OutputsConfig.load(outputs_config)
-        if outputs_config is not None
-        else None
-    )
-    logging_conf = (
-        AppLoggingConfig.load(logging_config)
-        if logging_config is not None
-        else None
-    )
-
-    api = BatDetect2API.from_checkpoint(
-        model_path,
-        audio_config=audio_conf,
-        inference_config=inference_conf,
-        outputs_config=outputs_conf,
-        logging_config=logging_conf,
-    )
-    return api, audio_conf, inference_conf, outputs_conf
-
-
-def _run_prediction(
-    model_path: Path,
-    audio_files: list[Path],
-    output_path: Path,
-    audio_config: Path | None,
-    inference_config: Path | None,
-    outputs_config: Path | None,
-    logging_config: Path | None,
-    batch_size: int | None,
-    num_workers: int,
-    format_name: str | None,
-) -> None:
-    logger.info("Initiating prediction process...")
-
-    api, audio_conf, inference_conf, outputs_conf = _build_api(
-        model_path,
-        audio_config,
-        inference_config,
-        outputs_config,
-        logging_config,
-    )
-
-    logger.info("Found {num_files} audio files", num_files=len(audio_files))
-
-    predictions = api.process_files(
-        audio_files,
-        batch_size=batch_size,
-        num_workers=num_workers,
-        audio_config=audio_conf,
-        inference_config=inference_conf,
-        output_config=outputs_conf,
-    )
-
-    common_path = audio_files[0].parent if audio_files else None
-    api.save_predictions(
-        predictions,
-        path=output_path,
-        audio_dir=common_path,
-        format=format_name,
-    )
-
-    logger.info(
-        "Prediction complete. Results saved to {path}", path=output_path
-    )
-
-
-@predict.command(name="directory")
-@click.argument("model_path", type=click.Path(exists=True))
-@click.argument("audio_dir", type=click.Path(exists=True))
-@click.argument("output_path", type=click.Path())
-@click.option("--audio-config", type=click.Path(exists=True))
-@click.option("--inference-config", type=click.Path(exists=True))
-@click.option("--outputs-config", type=click.Path(exists=True))
-@click.option("--logging-config", type=click.Path(exists=True))
-@click.option("--batch-size", type=int)
-@click.option("--workers", "num_workers", type=int, default=0)
-@click.option("--format", "format_name", type=str)
-def predict_directory_command(
-    model_path: Path,
-    audio_dir: Path,
-    output_path: Path,
-    audio_config: Path | None,
-    inference_config: Path | None,
-    outputs_config: Path | None,
-    logging_config: Path | None,
-    batch_size: int | None,
-    num_workers: int,
-    format_name: str | None,
-) -> None:
-    audio_files = list(get_audio_files(audio_dir))
-    _run_prediction(
-        model_path=model_path,
-        audio_files=audio_files,
-        output_path=output_path,
-        audio_config=audio_config,
-        inference_config=inference_config,
-        outputs_config=outputs_config,
-        logging_config=logging_config,
-        batch_size=batch_size,
-        num_workers=num_workers,
-        format_name=format_name,
-    )
-
-
-@predict.command(name="file_list")
-@click.argument("model_path", type=click.Path(exists=True))
-@click.argument("file_list", type=click.Path(exists=True))
-@click.argument("output_path", type=click.Path())
-@click.option("--audio-config", type=click.Path(exists=True))
-@click.option("--inference-config", type=click.Path(exists=True))
-@click.option("--outputs-config", type=click.Path(exists=True))
-@click.option("--logging-config", type=click.Path(exists=True))
-@click.option("--batch-size", type=int)
-@click.option("--workers", "num_workers", type=int, default=0)
-@click.option("--format", "format_name", type=str)
-def predict_file_list_command(
-    model_path: Path,
-    file_list: Path,
-    output_path: Path,
-    audio_config: Path | None,
-    inference_config: Path | None,
-    outputs_config: Path | None,
-    logging_config: Path | None,
-    batch_size: int | None,
-    num_workers: int,
-    format_name: str | None,
-) -> None:
-    file_list = Path(file_list)
-    audio_files = [
-        Path(line.strip())
-        for line in file_list.read_text().splitlines()
-        if line.strip()
-    ]
-
-    _run_prediction(
-        model_path=model_path,
-        audio_files=audio_files,
-        output_path=output_path,
-        audio_config=audio_config,
-        inference_config=inference_config,
-        outputs_config=outputs_config,
-        logging_config=logging_config,
-        batch_size=batch_size,
-        num_workers=num_workers,
-        format_name=format_name,
-    )
-
-
-@predict.command(name="dataset")
-@click.argument("model_path", type=click.Path(exists=True))
-@click.argument("dataset_path", type=click.Path(exists=True))
-@click.argument("output_path", type=click.Path())
-@click.option("--audio-config", type=click.Path(exists=True))
-@click.option("--inference-config", type=click.Path(exists=True))
-@click.option("--outputs-config", type=click.Path(exists=True))
-@click.option("--logging-config", type=click.Path(exists=True))
-@click.option("--batch-size", type=int)
-@click.option("--workers", "num_workers", type=int, default=0)
-@click.option("--format", "format_name", type=str)
-def predict_dataset_command(
-    model_path: Path,
-    dataset_path: Path,
-    output_path: Path,
-    audio_config: Path | None,
-    inference_config: Path | None,
-    outputs_config: Path | None,
-    logging_config: Path | None,
-    batch_size: int | None,
-    num_workers: int,
-    format_name: str | None,
-) -> None:
-    dataset_path = Path(dataset_path)
-    dataset = io.load(dataset_path, type="annotation_set")
-    audio_files = sorted(
-        {
-            Path(clip_annotation.clip.recording.path)
-            for clip_annotation in dataset.clip_annotations
-        }
-    )
-
-    _run_prediction(
-        model_path=model_path,
-        audio_files=audio_files,
-        output_path=output_path,
-        audio_config=audio_config,
-        inference_config=inference_config,
-        outputs_config=outputs_config,
-        logging_config=logging_config,
-        batch_size=batch_size,
-        num_workers=num_workers,
-        format_name=format_name,
-    )
@@ -1,4 +1,5 @@
 from pathlib import Path
+from typing import Optional

 import click
 from loguru import logger
@@ -13,15 +14,10 @@ __all__ = ["train_command"]
 @click.option("--val-dataset", type=click.Path(exists=True))
 @click.option("--model", "model_path", type=click.Path(exists=True))
 @click.option("--targets", "targets_config", type=click.Path(exists=True))
-@click.option("--model-config", type=click.Path(exists=True))
-@click.option("--training-config", type=click.Path(exists=True))
-@click.option("--audio-config", type=click.Path(exists=True))
-@click.option("--evaluation-config", type=click.Path(exists=True))
-@click.option("--inference-config", type=click.Path(exists=True))
-@click.option("--outputs-config", type=click.Path(exists=True))
-@click.option("--logging-config", type=click.Path(exists=True))
 @click.option("--ckpt-dir", type=click.Path(exists=True))
 @click.option("--log-dir", type=click.Path(exists=True))
+@click.option("--config", type=click.Path(exists=True))
+@click.option("--config-field", type=str)
 @click.option("--train-workers", type=int)
 @click.option("--val-workers", type=int)
 @click.option("--num-epochs", type=int)
@@ -30,82 +26,42 @@ __all__ = ["train_command"]
 @click.option("--seed", type=int)
 def train_command(
     train_dataset: Path,
-    val_dataset: Path | None = None,
-    model_path: Path | None = None,
-    ckpt_dir: Path | None = None,
-    log_dir: Path | None = None,
-    targets_config: Path | None = None,
-    model_config: Path | None = None,
-    training_config: Path | None = None,
-    audio_config: Path | None = None,
-    evaluation_config: Path | None = None,
-    inference_config: Path | None = None,
-    outputs_config: Path | None = None,
-    logging_config: Path | None = None,
-    seed: int | None = None,
-    num_epochs: int | None = None,
+    val_dataset: Optional[Path] = None,
+    model_path: Optional[Path] = None,
+    ckpt_dir: Optional[Path] = None,
+    log_dir: Optional[Path] = None,
+    config: Optional[Path] = None,
+    targets_config: Optional[Path] = None,
+    config_field: Optional[str] = None,
+    seed: Optional[int] = None,
+    num_epochs: Optional[int] = None,
     train_workers: int = 0,
     val_workers: int = 0,
-    experiment_name: str | None = None,
-    run_name: str | None = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ):
     from batdetect2.api_v2 import BatDetect2API
-    from batdetect2.audio import AudioConfig
-    from batdetect2.config import BatDetect2Config
+    from batdetect2.config import (
+        BatDetect2Config,
+        load_full_config,
+    )
     from batdetect2.data import load_dataset_from_config
-    from batdetect2.evaluate import EvaluationConfig
-    from batdetect2.inference import InferenceConfig
-    from batdetect2.logging import AppLoggingConfig
-    from batdetect2.models import ModelConfig
-    from batdetect2.outputs import OutputsConfig
-    from batdetect2.targets import TargetConfig
-    from batdetect2.train import TrainingConfig
+    from batdetect2.targets import load_target_config

     logger.info("Initiating training process...")

     logger.info("Loading configuration...")
-    target_conf = (
-        TargetConfig.load(targets_config)
-        if targets_config is not None
-        else None
-    )
-    model_conf = (
-        ModelConfig.load(model_config) if model_config is not None else None
-    )
-    train_conf = (
-        TrainingConfig.load(training_config)
-        if training_config is not None
-        else None
-    )
-    audio_conf = (
-        AudioConfig.load(audio_config) if audio_config is not None else None
-    )
-    eval_conf = (
-        EvaluationConfig.load(evaluation_config)
-        if evaluation_config is not None
-        else None
-    )
-    inference_conf = (
-        InferenceConfig.load(inference_config)
-        if inference_config is not None
-        else None
-    )
-    outputs_conf = (
-        OutputsConfig.load(outputs_config)
-        if outputs_config is not None
-        else None
-    )
-    logging_conf = (
-        AppLoggingConfig.load(logging_config)
-        if logging_config is not None
-        else None
+    conf = (
+        load_full_config(config, field=config_field)
+        if config is not None
+        else BatDetect2Config()
     )

-    if target_conf is not None:
-        logger.info("Loaded targets configuration.")
-
-    if model_conf is not None and target_conf is not None:
-        model_conf = model_conf.model_copy(update={"targets": target_conf})
+    if targets_config is not None:
+        logger.info("Loading targets configuration...")
+        conf = conf.model_copy(
+            update=dict(targets=load_target_config(targets_config))
+        )

     logger.info("Loading training dataset...")
     train_annotations = load_dataset_from_config(train_dataset)
@@ -126,43 +82,12 @@ def train_command(

     logger.info("Configuration and data loaded. Starting training...")

-    if model_path is not None and model_conf is not None:
-        raise click.UsageError(
-            "--model-config cannot be used with --model. "
-            "Checkpoint model configuration is loaded from the checkpoint."
-        )
-
     if model_path is None:
-        conf = BatDetect2Config()
-        if model_conf is not None:
-            conf.model = model_conf
-        elif target_conf is not None:
-            conf.model = conf.model.model_copy(update={"targets": target_conf})
-
-        if train_conf is not None:
-            conf.train = train_conf
-        if audio_conf is not None:
-            conf.audio = audio_conf
-        if eval_conf is not None:
-            conf.evaluation = eval_conf
-        if inference_conf is not None:
-            conf.inference = inference_conf
-        if outputs_conf is not None:
-            conf.outputs = outputs_conf
-        if logging_conf is not None:
-            conf.logging = logging_conf
-
         api = BatDetect2API.from_config(conf)
     else:
         api = BatDetect2API.from_checkpoint(
             model_path,
-            targets_config=target_conf,
-            train_config=train_conf,
-            audio_config=audio_conf,
-            evaluation_config=eval_conf,
-            inference_config=inference_conf,
-            outputs_config=outputs_conf,
-            logging_config=logging_conf,
+            config=conf if config is not None else None,
         )

     return api.train(
@@ -4,7 +4,7 @@ import json
 import os
 import uuid
 from pathlib import Path
-from typing import Callable, List
+from typing import Callable, List, Optional, Union

 import numpy as np
 from soundevent import data
@@ -17,7 +17,7 @@ from batdetect2.types import (
     FileAnnotation,
 )

-PathLike = Path | str | os.PathLike
+PathLike = Union[Path, str, os.PathLike]

 __all__ = [
     "convert_to_annotation_group",
@@ -33,7 +33,7 @@ UNKNOWN_CLASS = "__UNKNOWN__"
 NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")


-EventFn = Callable[[data.SoundEventAnnotation], str | None]
+EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]

 ClassFn = Callable[[data.Recording], int]

@@ -103,17 +103,17 @@ def convert_to_annotation_group(
         y_inds.append(0)

         annotations.append(
-            Annotation(
-                start_time=start_time,
-                end_time=end_time,
-                low_freq=low_freq,
-                high_freq=high_freq,
-                class_prob=1.0,
-                det_prob=1.0,
-                individual="0",
-                event=event,
-                class_id=class_id,
-            )
+            {
+                "start_time": start_time,
+                "end_time": end_time,
+                "low_freq": low_freq,
+                "high_freq": high_freq,
+                "class_prob": 1.0,
+                "det_prob": 1.0,
+                "individual": "0",
+                "event": event,
+                "class_id": class_id,  # type: ignore
+            }
         )

     return {
@@ -221,7 +221,7 @@ def annotation_to_sound_event_prediction(

 def file_annotation_to_clip(
     file_annotation: FileAnnotation,
-    audio_dir: PathLike | None = None,
+    audio_dir: Optional[PathLike] = None,
     label_key: str = "class",
 ) -> data.Clip:
     """Convert file annotation to recording."""
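A recurring pattern across these hunks is swapping PEP 604 unions (`Path | str | os.PathLike`, `str | None`) for `typing.Union` / `typing.Optional`: the `|` syntax in annotations needs Python 3.10+, while the `typing` forms also work on older interpreters. A minimal sketch of the equivalence (the `PathLike` alias mirrors the one in the hunk above; `resolve` is an illustrative helper, not from the codebase):

```python
import os
from pathlib import Path
from typing import Optional, Union

# Equivalent to `Path | str | os.PathLike`, but importable on Python < 3.10.
PathLike = Union[Path, str, os.PathLike]


def resolve(path: PathLike, base: Optional[Path] = None) -> Path:
    """Normalize any path-like value, optionally anchoring it to a base dir."""
    path = Path(path)
    return base / path if base is not None else path
```

At runtime both spellings produce the same `typing.Union` object; the difference is purely whether the annotation can be evaluated on older Pythons.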
@@ -1,20 +1,28 @@
-from typing import Literal
+from typing import Literal, Optional

 from pydantic import Field
+from soundevent.data import PathLike

 from batdetect2.audio import AudioConfig
-from batdetect2.core.configs import BaseConfig
+from batdetect2.core.configs import BaseConfig, load_config
+from batdetect2.data.predictions import OutputFormatConfig
+from batdetect2.data.predictions.raw import RawOutputConfig
 from batdetect2.evaluate.config import (
     EvaluationConfig,
     get_default_eval_config,
 )
 from batdetect2.inference.config import InferenceConfig
-from batdetect2.logging import AppLoggingConfig
-from batdetect2.models import ModelConfig
-from batdetect2.outputs import OutputsConfig
+from batdetect2.models.config import BackboneConfig
+from batdetect2.postprocess.config import PostprocessConfig
+from batdetect2.preprocess.config import PreprocessingConfig
+from batdetect2.targets.config import TargetConfig
 from batdetect2.train.config import TrainingConfig

-__all__ = ["BatDetect2Config"]
+__all__ = [
+    "BatDetect2Config",
+    "load_full_config",
+    "validate_config",
+]


 class BatDetect2Config(BaseConfig):
@@ -24,8 +32,26 @@ class BatDetect2Config(BaseConfig):
     evaluation: EvaluationConfig = Field(
         default_factory=get_default_eval_config
     )
-    model: ModelConfig = Field(default_factory=ModelConfig)
+    model: BackboneConfig = Field(default_factory=BackboneConfig)
+    preprocess: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
     audio: AudioConfig = Field(default_factory=AudioConfig)
+    targets: TargetConfig = Field(default_factory=TargetConfig)
     inference: InferenceConfig = Field(default_factory=InferenceConfig)
-    outputs: OutputsConfig = Field(default_factory=OutputsConfig)
-    logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
+    output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
+
+
+def validate_config(config: Optional[dict]) -> BatDetect2Config:
+    if config is None:
+        return BatDetect2Config()
+
+    return BatDetect2Config.model_validate(config)
+
+
+def load_full_config(
+    path: PathLike,
+    field: Optional[str] = None,
+) -> BatDetect2Config:
+    return load_config(path, schema=BatDetect2Config, field=field)
@@ -1,14 +1,8 @@
 from batdetect2.core.configs import BaseConfig, load_config, merge_configs
-from batdetect2.core.registries import (
-    ImportConfig,
-    Registry,
-    add_import_config,
-)
+from batdetect2.core.registries import Registry

 __all__ = [
-    "add_import_config",
     "BaseConfig",
-    "ImportConfig",
     "load_config",
     "Registry",
     "merge_configs",
@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import torch
 import xarray as xr
@@ -84,8 +86,8 @@ def adjust_width(

 def slice_tensor(
     tensor: torch.Tensor,
-    start: int | None = None,
-    end: int | None = None,
+    start: Optional[int] = None,
+    end: Optional[int] = None,
     dim: int = -1,
 ) -> torch.Tensor:
     slices = [slice(None)] * tensor.ndim
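`slice_tensor` above builds a list of `slice(None)` objects (one per dimension) and swaps in the requested range along a single axis; that indexing trick is framework-agnostic. A dependency-free sketch of just the index construction (the helper name is illustrative, not from the codebase):

```python
from typing import Optional


def axis_slices(
    ndim: int,
    start: Optional[int] = None,
    end: Optional[int] = None,
    dim: int = -1,
) -> tuple:
    """Build an index tuple selecting [start:end] along one axis only."""
    slices = [slice(None)] * ndim   # ':' for every dimension
    slices[dim] = slice(start, end)  # restrict just the chosen axis
    return tuple(slices)
```

Applied to a NumPy array or torch tensor, `arr[axis_slices(arr.ndim, 0, 5, dim=-1)]` trims only the last axis while leaving the others untouched.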
@@ -8,11 +8,11 @@ configuration data from files, with optional support for accessing nested
 configuration sections.
 """

-from typing import Any, Literal, Type, TypeVar, overload
+from typing import Any, Optional, Type, TypeVar

 import yaml
 from deepmerge.merger import Merger
-from pydantic import BaseModel, ConfigDict, TypeAdapter
+from pydantic import BaseModel, ConfigDict
 from soundevent.data import PathLike

 __all__ = [
@@ -21,8 +21,6 @@ __all__ = [
     "merge_configs",
 ]

-C = TypeVar("C", bound="BaseConfig")
-

 class BaseConfig(BaseModel):
     """Base class for all configuration models in BatDetect2.
@@ -64,30 +62,8 @@ class BaseConfig(BaseModel):
         )
     )

-    @classmethod
-    def from_yaml(cls, yaml_str: str):
-        return cls.model_validate(yaml.safe_load(yaml_str))
-
-    @classmethod
-    def load(
-        cls: Type[C],
-        path: PathLike,
-        field: str | None = None,
-        extra: Literal["ignore", "allow", "forbid"] | None = None,
-        strict: bool | None = None,
-    ) -> C:
-        return load_config(
-            path,
-            schema=cls,
-            field=field,
-            extra=extra,
-            strict=strict,
-        )
-
-
-T = TypeVar("T")
-T_Model = TypeVar("T_Model", bound=BaseModel)
-Schema = Type[T_Model] | TypeAdapter[T]
+T = TypeVar("T", bound=BaseModel)


 def get_object_field(obj: dict, current_key: str) -> Any:
@@ -149,69 +125,35 @@
     return get_object_field(subobj, rest)


-@overload
 def load_config(
     path: PathLike,
-    schema: Type[T_Model],
-    field: str | None = None,
-    extra: Literal["ignore", "allow", "forbid"] | None = None,
-    strict: bool | None = None,
-) -> T_Model: ...
-
-
-@overload
-def load_config(
-    path: PathLike,
-    schema: TypeAdapter[T],
-    field: str | None = None,
-    extra: Literal["ignore", "allow", "forbid"] | None = None,
-    strict: bool | None = None,
-) -> T: ...
-
-
-def load_config(
-    path: PathLike,
-    schema: Type[T_Model] | TypeAdapter[T],
-    field: str | None = None,
-    extra: Literal["ignore", "allow", "forbid"] | None = None,
-    strict: bool | None = None,
-) -> T_Model | T:
+    schema: Type[T],
+    field: Optional[str] = None,
+) -> T:
     """Load and validate configuration data from a file against a schema.

     Reads a YAML file, optionally extracts a specific section using dot
     notation, and then validates the resulting data against the provided
-    Pydantic schema.
+    Pydantic `schema`.

     Parameters
     ----------
     path : PathLike
         The path to the configuration file (typically `.yaml`).
-    schema : Type[T_Model] | TypeAdapter[T]
-        Either a Pydantic `BaseModel` subclass or a `TypeAdapter` instance
-        that defines the expected structure and types for the configuration
-        data.
+    schema : Type[T]
+        The Pydantic `BaseModel` subclass that defines the expected structure
+        and types for the configuration data.
     field : str, optional
         A dot-separated string indicating a nested section within the YAML
         file to extract before validation. If None (default), the entire
         file content is validated against the schema.
         Example: `"training.optimizer"` would extract the `optimizer` section
         within the `training` section.
-    extra : Literal["ignore", "allow", "forbid"], optional
-        How to handle extra keys in the configuration data. If None (default),
-        the default behaviour of the schema is used. If "ignore", extra keys
-        are ignored. If "allow", extra keys are allowed and will be accessible
-        as attributes on the resulting model instance. If "forbid", extra
-        keys are forbidden and an exception is raised. See pydantic
-        documentation for more details.
-    strict : bool, optional
-        Whether to enforce types strictly. If None (default), the default
-        behaviour of the schema is used. See pydantic documentation for more
-        details.

     Returns
     -------
-    T_Model | T
-        An instance of the schema type, populated and validated with
+    T
+        An instance of the provided `schema`, populated and validated with
         data from the configuration file.

     Raises
@@ -237,10 +179,7 @@ def load_config(
     if field:
         config = get_object_field(config, field)

-    if isinstance(schema, TypeAdapter):
-        return schema.validate_python(config or {}, extra=extra, strict=strict)
-
-    return schema.model_validate(config or {}, extra=extra, strict=strict)
+    return schema.model_validate(config or {})


 default_merger = Merger(
@@ -250,7 +189,7 @@ default_merger = Merger(
 )


-def merge_configs(config1: T_Model, config2: T_Model) -> T_Model:
+def merge_configs(config1: T, config2: T) -> T:
     """Merge two configuration objects."""
     model = type(config1)
     dict1 = config1.model_dump()
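On both sides of the hunks above, `load_config`'s `field` parameter walks into nested YAML sections by dot notation (the docstring's `"training.optimizer"` example) before Pydantic validation. A minimal stand-alone sketch of that lookup, assuming plain dicts in place of the parsed YAML and a hypothetical `get_field` name (the real helper is `get_object_field`):

```python
from typing import Any


def get_field(obj: dict, field: str) -> Any:
    """Resolve a dot-separated path, e.g. "training.optimizer"."""
    key, _, rest = field.partition(".")
    value = obj[key]  # KeyError propagates if the section is missing
    return get_field(value, rest) if rest else value
```

With this in place, validating only a subsection amounts to extracting it first and handing the result to the schema, which is exactly what `load_config(path, schema=..., field="training.optimizer")` does.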
@@ -1,28 +1,23 @@
-from typing import (
-    Any,
-    Callable,
-    Concatenate,
-    Generic,
-    ParamSpec,
-    Sequence,
-    Type,
-    TypeVar,
-)
-
-from hydra.utils import instantiate
-from pydantic import BaseModel, Field
+import sys
+from typing import Callable, Dict, Generic, Tuple, Type, TypeVar
+
+from pydantic import BaseModel
+
+if sys.version_info >= (3, 10):
+    from typing import Concatenate, ParamSpec
+else:
+    from typing_extensions import Concatenate, ParamSpec

 __all__ = [
-    "add_import_config",
-    "ImportConfig",
     "Registry",
     "SimpleRegistry",
 ]


 T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
 T_Type = TypeVar("T_Type", covariant=True)
 P_Type = ParamSpec("P_Type")


 T = TypeVar("T")
@@ -48,13 +43,12 @@ class SimpleRegistry(Generic[T]):
 class Registry(Generic[T_Type, P_Type]):
     """A generic class to create and manage a registry of items."""

-    def __init__(self, name: str, discriminator: str = "name"):
+    def __init__(self, name: str):
         self._name = name
-        self._registry: dict[
+        self._registry: Dict[
             str, Callable[Concatenate[..., P_Type], T_Type]
         ] = {}
-        self._discriminator = discriminator
-        self._config_types: dict[str, Type[BaseModel]] = {}
+        self._config_types: Dict[str, Type[BaseModel]] = {}

     def register(
         self,
@@ -62,20 +56,15 @@ class Registry(Generic[T_Type, P_Type]):
     ):
         fields = config_cls.model_fields

-        if self._discriminator not in fields:
-            raise ValueError(
-                "Configuration object must have "
-                f"a '{self._discriminator}' field."
-            )
+        if "name" not in fields:
+            raise ValueError("Configuration object must have a 'name' field.")

-        name = fields[self._discriminator].default
+        name = fields["name"].default

         self._config_types[name] = config_cls

         if not isinstance(name, str):
-            raise ValueError(
-                f"'{self._discriminator}' field must be a string literal."
-            )
+            raise ValueError("'name' field must be a string literal.")

         def decorator(
             func: Callable[Concatenate[T_Config, P_Type], T_Type],
@@ -85,7 +74,7 @@ class Registry(Generic[T_Type, P_Type]):

         return decorator

-    def get_config_types(self) -> tuple[Type[BaseModel], ...]:
+    def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
         return tuple(self._config_types.values())

     def get_config_type(self, name: str) -> Type[BaseModel]:
@@ -105,12 +94,10 @@ class Registry(Generic[T_Type, P_Type]):
     ) -> T_Type:
         """Builds a logic instance from a config object."""

-        name = getattr(config, self._discriminator)  # noqa: B009
+        name = getattr(config, "name")  # noqa: B009

         if name is None:
-            raise ValueError(
-                f"Config does not have a '{self._discriminator}' field"
-            )
+            raise ValueError("Config does not have a name field")

         if name not in self._registry:
             raise NotImplementedError(
@@ -118,92 +105,3 @@ class Registry(Generic[T_Type, P_Type]):
             )

         return self._registry[name](config, *args, **kwargs)
-
-
-class ImportConfig(BaseModel):
-    """Base config for dynamic instantiation via Hydra.
-
-    Subclass this to create a registry-specific import escape hatch.
-    The subclass must add a discriminator field whose name matches the
-    registry's own discriminator key, with its value fixed to
-    ``Literal["import"]``.
-
-    Attributes
-    ----------
-    target : str
-        Fully-qualified dotted path to the callable to instantiate,
-        e.g. ``"mypackage.module.MyClass"``.
-    arguments : dict[str, Any]
-        Base keyword arguments forwarded to the callable. When the
-        same key also appears in ``kwargs`` passed to ``build()``,
-        the ``kwargs`` value takes priority.
-    """
-
-    target: str
-    arguments: dict[str, Any] = Field(default_factory=dict)
-
-
-T_Import = TypeVar("T_Import", bound=ImportConfig)
-
-
-def add_import_config(
-    registry: Registry[T_Type, P_Type],
-    arg_names: Sequence[str] | None = None,
-) -> Callable[[Type[T_Import]], Type[T_Import]]:
-    """Decorator that registers an ImportConfig subclass as an escape hatch.
-
-    Wraps the decorated class in a builder that calls
-    ``hydra.utils.instantiate`` using ``config.target`` and
-    ``config.arguments``. The builder is registered on *registry*
-    under the discriminator value ``"import"``.
-
-    Parameters
-    ----------
-    registry : Registry
-        The registry instance on which the config should be registered.
-
-    Returns
-    -------
-    Callable[[type[ImportConfig]], type[ImportConfig]]
-        A class decorator that registers the class and returns it
-        unchanged.
-
-    Examples
-    --------
-    Define a per-registry import escape hatch::
-
-        @add_import_config(my_registry)
-        class MyRegistryImportConfig(ImportConfig):
-            name: Literal["import"] = "import"
-    """
-
-    def decorator(config_cls: Type[T_Import]) -> Type[T_Import]:
-        def builder(
-            config: T_Import,
-            *args: P_Type.args,
-            **kwargs: P_Type.kwargs,
-        ) -> T_Type:
-            _arg_names = arg_names or []
-
-            if len(args) != len(_arg_names):
-                raise ValueError(
-                    "Positional arguments are not supported "
-                    "for import escape hatch unless you specify "
-                    "the argument names. Use `arg_names` to specify "
-                    "the names of the positional arguments."
-                )
-
-            args_dict = {_arg_names[i]: args[i] for i in range(len(args))}
-
-            hydra_cfg = {
-                "_target_": config.target,
-                **config.arguments,
-                **args_dict,
-                **kwargs,
-            }
-            return instantiate(hydra_cfg)
-
-        registry.register(config_cls)(builder)
-        return config_cls
-
-    return decorator
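Both versions of the registry hunks above implement the same dispatch idea: a config object carries a discriminator (`name`), and `build()` routes it to the builder registered under that value. A stripped-down sketch with plain dicts standing in for the pydantic config models (all names here are illustrative, not from the codebase):

```python
from typing import Callable, Dict


class MiniRegistry:
    """Dispatch configs to builders keyed by their 'name' discriminator."""

    def __init__(self) -> None:
        self._builders: Dict[str, Callable[[dict], object]] = {}

    def register(self, name: str) -> Callable:
        def decorator(func: Callable[[dict], object]) -> Callable[[dict], object]:
            self._builders[name] = func
            return func

        return decorator

    def build(self, config: dict) -> object:
        name = config.get("name")
        if name not in self._builders:
            raise NotImplementedError(f"No builder registered for {name!r}")
        return self._builders[name](config)


models = MiniRegistry()


@models.register("cnn")
def build_cnn(config: dict) -> str:
    # A real builder would construct a model; a string suffices for the sketch.
    return f"cnn(channels={config.get('channels', 32)})"
```

The real `Registry` adds type safety on top of this: it reads the discriminator's literal default from the pydantic `model_fields`, so the registration key comes from the config class itself rather than from the caller.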
@@ -7,12 +7,20 @@ from batdetect2.data.annotations import (
     load_annotated_dataset,
 )
 from batdetect2.data.datasets import (
-    Dataset,
     DatasetConfig,
     load_dataset,
     load_dataset_config,
     load_dataset_from_config,
 )
+from batdetect2.data.predictions import (
+    BatDetect2OutputConfig,
+    OutputFormatConfig,
+    RawOutputConfig,
+    SoundEventOutputConfig,
+    build_output_formatter,
+    get_output_formatter,
+    load_predictions,
+)
 from batdetect2.data.summary import (
     compute_class_summary,
     extract_recordings_df,
@@ -20,7 +28,6 @@ from batdetect2.data.summary import (
 )

 __all__ = [
-    "Dataset",
     "AOEFAnnotations",
     "AnnotatedDataset",
     "AnnotationFormats",
@@ -29,7 +36,6 @@ __all__ = [
     "BatDetect2OutputConfig",
     "DatasetConfig",
     "OutputFormatConfig",
-    "ParquetOutputConfig",
     "RawOutputConfig",
     "SoundEventOutputConfig",
     "build_output_formatter",
@@ -13,18 +13,22 @@ format-specific loading function to retrieve the annotations as a standard
 `soundevent.data.AnnotationSet`.
 """

-from typing import Annotated
+from typing import Annotated, Optional, Union

 from pydantic import Field
 from soundevent import data

-from batdetect2.data.annotations.aoef import AOEFAnnotations
+from batdetect2.data.annotations.aoef import (
+    AOEFAnnotations,
+    load_aoef_annotated_dataset,
+)
 from batdetect2.data.annotations.batdetect2 import (
     AnnotationFilter,
     BatDetect2FilesAnnotations,
     BatDetect2MergedAnnotations,
+    load_batdetect2_files_annotated_dataset,
+    load_batdetect2_merged_annotated_dataset,
 )
-from batdetect2.data.annotations.registry import annotation_format_registry
 from batdetect2.data.annotations.types import AnnotatedDataset

 __all__ = [
@@ -39,7 +43,11 @@ __all__ = [


 AnnotationFormats = Annotated[
-    BatDetect2MergedAnnotations | BatDetect2FilesAnnotations | AOEFAnnotations,
+    Union[
+        BatDetect2MergedAnnotations,
+        BatDetect2FilesAnnotations,
+        AOEFAnnotations,
+    ],
     Field(discriminator="format"),
 ]
 """Type Alias representing all supported data source configurations.
@@ -55,24 +63,24 @@ source configuration represents.

 def load_annotated_dataset(
     dataset: AnnotatedDataset,
-    base_dir: data.PathLike | None = None,
+    base_dir: Optional[data.PathLike] = None,
 ) -> data.AnnotationSet:
     """Load annotations for a single data source based on its configuration.

-    This function acts as a dispatcher. It inspects the format of the input
-    `dataset` object and delegates to the appropriate format-specific loader
-    registered in the `annotation_format_registry` (e.g.,
-    `AOEFLoader` for `AOEFAnnotations`).
+    This function acts as a dispatcher. It inspects the type of the input
+    `source_config` object (which corresponds to a specific annotation format)
+    and calls the appropriate loading function (e.g.,
+    `load_aoef_annotated_dataset` for `AOEFAnnotations`).

     Parameters
     ----------
-    dataset : AnnotatedDataset
+    source_config : AnnotationFormats
         The configuration object for the data source, specifying its format
         and necessary details (like paths). Must be an instance of one of the
         types included in the `AnnotationFormats` union.
     base_dir : Path, optional
         An optional base directory path. If provided, relative paths within
-        the `dataset` will be resolved relative to this directory by
+        the `source_config` might be resolved relative to this directory by
         the underlying loading functions. Defaults to None.

     Returns
@@ -84,8 +92,23 @@ def load_annotated_dataset(
     Raises
     ------
     NotImplementedError
-        If the `format` field of `dataset` does not match any registered
-        annotation format loader.
+        If the type of the `source_config` object does not match any of the
+        known format-specific loading functions implemented in the dispatch
+        logic.
     """
-    loader = annotation_format_registry.build(dataset)
-    return loader.load(base_dir=base_dir)
+    if isinstance(dataset, AOEFAnnotations):
+        return load_aoef_annotated_dataset(dataset, base_dir=base_dir)
+
+    if isinstance(dataset, BatDetect2MergedAnnotations):
+        return load_batdetect2_merged_annotated_dataset(
+            dataset, base_dir=base_dir
+        )
+
+    if isinstance(dataset, BatDetect2FilesAnnotations):
+        return load_batdetect2_files_annotated_dataset(
+            dataset,
+            base_dir=base_dir,
+        )
+
+    raise NotImplementedError(f"Unknown annotation format: {dataset.name}")
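The `load_annotated_dataset` hunk above swaps a registry lookup (`annotation_format_registry.build(dataset)`) for an explicit `isinstance` chain over the config types. A minimal, stdlib-only sketch of that dispatch style (the class names mirror the diff; the dataclasses and return values are placeholders, not batdetect2's real configs or loaders):

```python
from dataclasses import dataclass


@dataclass
class AOEFAnnotations:
    name: str
    annotations_path: str


@dataclass
class BatDetect2MergedAnnotations:
    name: str
    annotations_path: str


def load_annotated_dataset(dataset):
    # Dispatch on the concrete config type, as on the right-hand side
    # of the diff; returns a label here instead of an AnnotationSet.
    if isinstance(dataset, AOEFAnnotations):
        return f"aoef:{dataset.annotations_path}"
    if isinstance(dataset, BatDetect2MergedAnnotations):
        return f"merged:{dataset.annotations_path}"
    raise NotImplementedError(f"Unknown annotation format: {dataset!r}")
```

The trade-off visible in the diff: the `isinstance` chain is simpler and self-contained, while the registry version allows new formats to be plugged in without editing the dispatcher.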
@@ -12,22 +12,17 @@ that meet specific status criteria (e.g., completed, verified, without issues).
 """

 from pathlib import Path
-from typing import Literal
+from typing import Literal, Optional
 from uuid import uuid5

 from pydantic import Field
 from soundevent import data, io

 from batdetect2.core.configs import BaseConfig
-from batdetect2.data.annotations.registry import annotation_format_registry
-from batdetect2.data.annotations.types import (
-    AnnotatedDataset,
-    AnnotationLoader,
-)
+from batdetect2.data.annotations.types import AnnotatedDataset

 __all__ = [
     "AOEFAnnotations",
-    "AOEFLoader",
     "load_aoef_annotated_dataset",
     "AnnotationTaskFilter",
 ]
@@ -82,30 +77,14 @@ class AOEFAnnotations(AnnotatedDataset):

     annotations_path: Path

-    filter: AnnotationTaskFilter | None = Field(
+    filter: Optional[AnnotationTaskFilter] = Field(
         default_factory=AnnotationTaskFilter
     )


-class AOEFLoader(AnnotationLoader):
-    def __init__(self, config: AOEFAnnotations):
-        self.config = config
-
-    def load(
-        self,
-        base_dir: data.PathLike | None = None,
-    ) -> data.AnnotationSet:
-        return load_aoef_annotated_dataset(self.config, base_dir=base_dir)
-
-    @annotation_format_registry.register(AOEFAnnotations)
-    @staticmethod
-    def from_config(config: AOEFAnnotations):
-        return AOEFLoader(config)
-
-
 def load_aoef_annotated_dataset(
     dataset: AOEFAnnotations,
-    base_dir: data.PathLike | None = None,
+    base_dir: Optional[data.PathLike] = None,
 ) -> data.AnnotationSet:
     """Load annotations from an AnnotationSet or AnnotationProject file.

@@ -27,7 +27,7 @@ aggregated into a `soundevent.data.AnnotationSet`.
 import json
 import os
 from pathlib import Path
-from typing import Literal
+from typing import Literal, Optional, Union

 from loguru import logger
 from pydantic import Field, ValidationError
@@ -41,13 +41,9 @@ from batdetect2.data.annotations.legacy import (
     list_file_annotations,
     load_file_annotation,
 )
-from batdetect2.data.annotations.registry import annotation_format_registry
-from batdetect2.data.annotations.types import (
-    AnnotatedDataset,
-    AnnotationLoader,
-)
+from batdetect2.data.annotations.types import AnnotatedDataset

-PathLike = Path | str | os.PathLike
+PathLike = Union[Path, str, os.PathLike]


 __all__ = [
@@ -106,7 +102,7 @@ class BatDetect2FilesAnnotations(AnnotatedDataset):
     format: Literal["batdetect2"] = "batdetect2"
     annotations_dir: Path

-    filter: AnnotationFilter | None = Field(
+    filter: Optional[AnnotationFilter] = Field(
         default_factory=AnnotationFilter,
     )

@@ -137,14 +133,14 @@ class BatDetect2MergedAnnotations(AnnotatedDataset):
     format: Literal["batdetect2_file"] = "batdetect2_file"
     annotations_path: Path

-    filter: AnnotationFilter | None = Field(
+    filter: Optional[AnnotationFilter] = Field(
         default_factory=AnnotationFilter,
     )


 def load_batdetect2_files_annotated_dataset(
     dataset: BatDetect2FilesAnnotations,
-    base_dir: PathLike | None = None,
+    base_dir: Optional[PathLike] = None,
 ) -> data.AnnotationSet:
     """Load and convert 'batdetect2_file' annotations into an AnnotationSet.

@@ -248,7 +244,7 @@ def load_batdetect2_files_annotated_dataset(

 def load_batdetect2_merged_annotated_dataset(
     dataset: BatDetect2MergedAnnotations,
-    base_dir: PathLike | None = None,
+    base_dir: Optional[PathLike] = None,
 ) -> data.AnnotationSet:
     """Load and convert 'batdetect2_merged' annotations into an AnnotationSet.

@@ -306,7 +302,7 @@ def load_batdetect2_merged_annotated_dataset(
         try:
             ann = FileAnnotation.model_validate(ann)
         except ValueError as err:
-            logger.warning("Invalid annotation file: {err}", err=err)
+            logger.warning(f"Invalid annotation file: {err}")
             continue

         if (
@@ -314,23 +310,17 @@ def load_batdetect2_merged_annotated_dataset(
             and dataset.filter.only_annotated
             and not ann.annotated
         ):
-            logger.debug(
-                "Skipping incomplete annotation {ann_id}",
-                ann_id=ann.id,
-            )
+            logger.debug(f"Skipping incomplete annotation {ann.id}")
             continue

         if dataset.filter and dataset.filter.exclude_issues and ann.issues:
-            logger.debug(
-                "Skipping annotation with issues {ann_id}",
-                ann_id=ann.id,
-            )
+            logger.debug(f"Skipping annotation with issues {ann.id}")
             continue

         try:
             clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
         except FileNotFoundError as err:
-            logger.warning("Error loading annotations: {err}", err=err)
+            logger.warning(f"Error loading annotations: {err}")
             continue

         annotations.append(file_annotation_to_clip_annotation(ann, clip))
@@ -340,41 +330,3 @@ def load_batdetect2_merged_annotated_dataset(
         description=dataset.description,
         clip_annotations=annotations,
     )
-
-
-class BatDetect2MergedLoader(AnnotationLoader):
-    def __init__(self, config: BatDetect2MergedAnnotations):
-        self.config = config
-
-    def load(
-        self,
-        base_dir: PathLike | None = None,
-    ) -> data.AnnotationSet:
-        return load_batdetect2_merged_annotated_dataset(
-            self.config,
-            base_dir=base_dir,
-        )
-
-    @annotation_format_registry.register(BatDetect2MergedAnnotations)
-    @staticmethod
-    def from_config(config: BatDetect2MergedAnnotations):
-        return BatDetect2MergedLoader(config)
-
-
-class BatDetect2FilesLoader(AnnotationLoader):
-    def __init__(self, config: BatDetect2FilesAnnotations):
-        self.config = config
-
-    def load(
-        self,
-        base_dir: PathLike | None = None,
-    ) -> data.AnnotationSet:
-        return load_batdetect2_files_annotated_dataset(
-            self.config,
-            base_dir=base_dir,
-        )
-
-    @annotation_format_registry.register(BatDetect2FilesAnnotations)
-    @staticmethod
-    def from_config(config: BatDetect2FilesAnnotations):
-        return BatDetect2FilesLoader(config)
@@ -3,12 +3,12 @@
 import os
 import uuid
 from pathlib import Path
-from typing import Callable, List
+from typing import Callable, List, Optional, Union

 from pydantic import BaseModel, Field
 from soundevent import data

-PathLike = Path | str | os.PathLike
+PathLike = Union[Path, str, os.PathLike]

 __all__ = []

@@ -27,7 +27,7 @@ SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5(
 )


-EventFn = Callable[[data.SoundEventAnnotation], str | None]
+EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]

 ClassFn = Callable[[data.Recording], int]

@@ -130,7 +130,7 @@ def get_sound_event_tags(

 def file_annotation_to_clip(
     file_annotation: FileAnnotation,
-    audio_dir: PathLike | None = None,
+    audio_dir: Optional[PathLike] = None,
     label_key: str = "class",
 ) -> data.Clip:
     """Convert file annotation to recording."""
@@ -1,35 +0,0 @@
-from typing import Literal
-
-from batdetect2.core import ImportConfig, Registry, add_import_config
-from batdetect2.data.annotations.types import AnnotationLoader
-
-__all__ = [
-    "AnnotationFormatImportConfig",
-    "annotation_format_registry",
-]
-
-annotation_format_registry: Registry[AnnotationLoader, []] = Registry(
-    "annotation_format",
-    discriminator="format",
-)
-
-
-@add_import_config(annotation_format_registry)
-class AnnotationFormatImportConfig(ImportConfig):
-    """Import escape hatch for the annotation format registry.
-
-    Use this config to dynamically instantiate any callable as an
-    annotation loader without registering it in
-    ``annotation_format_registry`` ahead of time.
-
-    Parameters
-    ----------
-    format : Literal["import"]
-        Discriminator value; must always be ``"import"``.
-    target : str
-        Fully-qualified dotted path to the callable to instantiate.
-    arguments : dict[str, Any]
-        Keyword arguments forwarded to the callable.
-    """
-
-    format: Literal["import"] = "import"
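The deleted module above implements a config-keyed registry for annotation loaders. A rough stdlib sketch of the same idea, registering a loader factory per config class and building from a config instance (the names and API here are illustrative, not the actual `batdetect2.core.Registry` interface):

```python
class Registry:
    """Map config classes to loader factories (illustrative sketch)."""

    def __init__(self, name: str):
        self.name = name
        self._factories = {}

    def register(self, config_cls):
        # Usable as ``@registry.register(SomeConfig)`` above a factory,
        # mirroring the decorators removed elsewhere in this diff.
        def decorator(factory):
            self._factories[config_cls] = factory
            return factory

        return decorator

    def build(self, config):
        # Look up the factory for this config's concrete type and call it.
        try:
            factory = self._factories[type(config)]
        except KeyError:
            raise NotImplementedError(
                f"Unknown config: {type(config).__name__}"
            ) from None
        return factory(config)


annotation_format_registry = Registry("annotation_format")


class AOEFConfig:
    pass


@annotation_format_registry.register(AOEFConfig)
def make_aoef_loader(config):
    return ("aoef-loader", config)
```

With this in place, `annotation_format_registry.build(AOEFConfig())` resolves the factory without the caller knowing any concrete loader class, which is the indirection the rest of the diff removes.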
@@ -1,13 +1,9 @@
 from pathlib import Path
-from typing import Protocol
-
-from soundevent import data

 from batdetect2.core.configs import BaseConfig

 __all__ = [
     "AnnotatedDataset",
-    "AnnotationLoader",
 ]


@@ -38,10 +34,3 @@ class AnnotatedDataset(BaseConfig):
     name: str
     audio_dir: Path
     description: str = ""
-
-
-class AnnotationLoader(Protocol):
-    def load(
-        self,
-        base_dir: data.PathLike | None = None,
-    ) -> data.AnnotationSet: ...
@@ -1,33 +1,18 @@
 from collections.abc import Callable
-from typing import Annotated, List, Literal, Sequence
+from typing import Annotated, List, Literal, Sequence, Union

 from pydantic import Field
 from soundevent import data
 from soundevent.geometry import compute_bounds

 from batdetect2.core.configs import BaseConfig
-from batdetect2.core.registries import (
-    ImportConfig,
-    Registry,
-    add_import_config,
-)
+from batdetect2.core.registries import Registry

 SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]

 conditions: Registry[SoundEventCondition, []] = Registry("condition")


-@add_import_config(conditions)
-class SoundEventConditionImportConfig(ImportConfig):
-    """Use any callable as a sound event condition.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class HasTagConfig(BaseConfig):
     name: Literal["has_tag"] = "has_tag"
     tag: data.Tag
@@ -279,14 +264,16 @@ class Not:


 SoundEventConditionConfig = Annotated[
-    HasTagConfig
-    | HasAllTagsConfig
-    | HasAnyTagConfig
-    | DurationConfig
-    | FrequencyConfig
-    | AllOfConfig
-    | AnyOfConfig
-    | NotConfig,
+    Union[
+        HasTagConfig,
+        HasAllTagsConfig,
+        HasAnyTagConfig,
+        DurationConfig,
+        FrequencyConfig,
+        AllOfConfig,
+        AnyOfConfig,
+        NotConfig,
+    ],
     Field(discriminator="name"),
 ]

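The `SoundEventConditionConfig` union above wires named configs (`has_tag`, `all_of`, `any_of`, `not`, …) to predicate functions over sound event annotations. A stdlib-only sketch of such composable predicates, with plain dicts standing in for `SoundEventAnnotation` objects (the function names are illustrative, not batdetect2's API):

```python
from typing import Callable, Dict

Condition = Callable[[Dict], bool]


def has_tag(tag: str) -> Condition:
    # True when the annotation carries the given tag.
    return lambda ann: tag in ann.get("tags", [])


def all_of(*conditions: Condition) -> Condition:
    # Logical AND over sub-conditions (the AllOfConfig analogue).
    return lambda ann: all(cond(ann) for cond in conditions)


def any_of(*conditions: Condition) -> Condition:
    # Logical OR over sub-conditions (the AnyOfConfig analogue).
    return lambda ann: any(cond(ann) for cond in conditions)


def negate(condition: Condition) -> Condition:
    # Logical NOT (the NotConfig analogue).
    return lambda ann: not condition(ann)


# Keep events tagged "bat" that are not also tagged "noise".
keep = all_of(has_tag("bat"), negate(has_tag("noise")))
```

The discriminated union lets a YAML/JSON config tree of nested `all_of`/`any_of`/`not` nodes be validated and compiled into exactly this kind of composed callable.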
@@ -19,7 +19,7 @@ The core components are:
 """

 from pathlib import Path
-from typing import List, Sequence
+from typing import List, Optional, Sequence

 from loguru import logger
 from pydantic import Field
@@ -69,7 +69,7 @@ class DatasetConfig(BaseConfig):
     description: str
     sources: List[AnnotationFormats]

-    sound_event_filter: SoundEventConditionConfig | None = None
+    sound_event_filter: Optional[SoundEventConditionConfig] = None
     sound_event_transforms: List[SoundEventTransformConfig] = Field(
         default_factory=list
     )
@@ -77,7 +77,7 @@ class DatasetConfig(BaseConfig):

 def load_dataset(
     config: DatasetConfig,
-    base_dir: data.PathLike | None = None,
+    base_dir: Optional[data.PathLike] = None,
 ) -> Dataset:
     """Load all clip annotations from the sources defined in a DatasetConfig."""
     clip_annotations = []
@@ -161,14 +161,14 @@ def insert_source_tag(
     )


-def load_dataset_config(path: data.PathLike, field: str | None = None):
+def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
     return load_config(path=path, schema=DatasetConfig, field=field)


 def load_dataset_from_config(
     path: data.PathLike,
-    field: str | None = None,
-    base_dir: data.PathLike | None = None,
+    field: Optional[str] = None,
+    base_dir: Optional[data.PathLike] = None,
 ) -> Dataset:
     """Load dataset annotation metadata from a configuration file.

@@ -215,9 +215,9 @@ def load_dataset_from_config(
 def save_dataset(
     dataset: Dataset,
     path: data.PathLike,
-    name: str | None = None,
-    description: str | None = None,
-    audio_dir: Path | None = None,
+    name: Optional[str] = None,
+    description: Optional[str] = None,
+    audio_dir: Optional[Path] = None,
 ) -> None:
     """Save a loaded dataset (list of ClipAnnotations) to a file.

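Nearly every `-`/`+` pair in this diff swaps a PEP 604 union (`str | None`, syntax that only parses on Python 3.10+) for `typing.Optional[str]`, which works on older interpreters. The two spellings denote the same type; the hypothetical stand-in function below mirrors the signatures in the hunk above without reproducing batdetect2's actual behavior:

```python
from typing import Optional, Union

# Optional[X] is shorthand for Union[X, None]; on Python 3.10+ the
# same type can also be written `X | None` (PEP 604).
assert Optional[str] == Union[str, None]


def load_dataset_config(path: str, field: Optional[str] = None) -> str:
    # Hypothetical stand-in: return the requested sub-field name, or a
    # marker meaning "use the whole file".
    return field if field is not None else "whole-file"
```

Since type checkers treat both forms identically, this rewrite changes runtime compatibility only, not semantics.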
@@ -1,15 +1,16 @@
 from collections.abc import Generator
+from typing import Optional, Tuple

 from soundevent import data

 from batdetect2.data.datasets import Dataset
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing.targets import TargetProtocol


 def iterate_over_sound_events(
     dataset: Dataset,
     targets: TargetProtocol,
-) -> Generator[tuple[str | None, data.SoundEventAnnotation], None, None]:
+) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
     """Iterate over sound events in a dataset.

     Parameters
@@ -23,7 +24,7 @@ def iterate_over_sound_events(

     Yields
     ------
-    tuple[Optional[str], data.SoundEventAnnotation]
+    Tuple[Optional[str], data.SoundEventAnnotation]
         A tuple containing:
         - The encoded class name (str) for the sound event, or None if it
           cannot be encoded to a specific class.
@@ -1,42 +1,39 @@
-from typing import Annotated
+from typing import Annotated, Optional, Union

 from pydantic import Field
 from soundevent.data import PathLike

-from batdetect2.outputs.formats.base import (
+from batdetect2.data.predictions.base import (
     OutputFormatterProtocol,
-    output_formatters,
+    prediction_formatters,
 )
-from batdetect2.outputs.formats.batdetect2 import BatDetect2OutputConfig
-from batdetect2.outputs.formats.parquet import ParquetOutputConfig
-from batdetect2.outputs.formats.raw import RawOutputConfig
-from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig
+from batdetect2.data.predictions.raw import RawOutputConfig
+from batdetect2.data.predictions.soundevent import SoundEventOutputConfig
+from batdetect2.typing import TargetProtocol

 __all__ = [
-    "BatDetect2OutputConfig",
-    "OutputFormatConfig",
-    "ParquetOutputConfig",
-    "RawOutputConfig",
-    "SoundEventOutputConfig",
     "build_output_formatter",
     "get_output_formatter",
-    "load_predictions",
+    "BatDetect2OutputConfig",
+    "RawOutputConfig",
+    "SoundEventOutputConfig",
 ]


 OutputFormatConfig = Annotated[
-    BatDetect2OutputConfig
-    | ParquetOutputConfig
-    | SoundEventOutputConfig
-    | RawOutputConfig,
+    Union[
+        BatDetect2OutputConfig,
+        SoundEventOutputConfig,
+        RawOutputConfig,
+    ],
     Field(discriminator="name"),
 ]


 def build_output_formatter(
-    targets: TargetProtocol | None = None,
-    config: OutputFormatConfig | None = None,
+    targets: Optional[TargetProtocol] = None,
+    config: Optional[OutputFormatConfig] = None,
 ) -> OutputFormatterProtocol:
     """Construct the final output formatter."""
     from batdetect2.targets import build_targets
@@ -44,13 +41,13 @@ def build_output_formatter(
     config = config or RawOutputConfig()

     targets = targets or build_targets()
-    return output_formatters.build(config, targets)
+    return prediction_formatters.build(config, targets)


 def get_output_formatter(
-    name: str | None = None,
-    targets: TargetProtocol | None = None,
-    config: OutputFormatConfig | None = None,
+    name: Optional[str] = None,
+    targets: Optional[TargetProtocol] = None,
+    config: Optional[OutputFormatConfig] = None,
 ) -> OutputFormatterProtocol:
     """Get the output formatter by name."""

@@ -58,7 +55,7 @@ def get_output_formatter(
     if name is None:
         raise ValueError("Either config or name must be provided.")

-    config_class = output_formatters.get_config_type(name)
+    config_class = prediction_formatters.get_config_type(name)
     config = config_class()  # type: ignore

     if config.name != name:  # type: ignore
@@ -71,9 +68,9 @@ def get_output_formatter(

 def load_predictions(
     path: PathLike,
-    format: str | None = "raw",
-    config: OutputFormatConfig | None = None,
-    targets: TargetProtocol | None = None,
+    format: Optional[str] = "raw",
+    config: Optional[OutputFormatConfig] = None,
+    targets: Optional[TargetProtocol] = None,
 ):
     """Load predictions from a file."""
     from batdetect2.targets import build_targets
src/batdetect2/data/predictions/base.py (new file)
@@ -0,0 +1,29 @@
+from pathlib import Path
+
+from soundevent.data import PathLike
+
+from batdetect2.core import Registry
+from batdetect2.typing import (
+    OutputFormatterProtocol,
+    TargetProtocol,
+)
+
+
+def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
+    path = Path(path)
+    audio_dir = Path(audio_dir)
+
+    if path.is_absolute():
+        if not path.is_relative_to(audio_dir):
+            raise ValueError(
+                f"Audio file {path} is not in audio_dir {audio_dir}"
+            )
+
+        return path.relative_to(audio_dir)
+
+    return path
+
+
+prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
+    Registry(name="output_formatter")
+)
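The new `src/batdetect2/data/predictions/base.py` shown above adds `make_path_relative`, which relies on `Path.is_relative_to` (available since Python 3.9). The same logic, reproduced standalone so its behavior can be checked in isolation:

```python
from pathlib import Path


def make_path_relative(path, audio_dir) -> Path:
    # Absolute paths must live under audio_dir and are rewritten relative
    # to it; relative paths are returned unchanged (same logic as the
    # helper added in base.py).
    path = Path(path)
    audio_dir = Path(audio_dir)

    if path.is_absolute():
        if not path.is_relative_to(audio_dir):
            raise ValueError(
                f"Audio file {path} is not in audio_dir {audio_dir}"
            )
        return path.relative_to(audio_dir)

    return path
```

Raising on absolute paths outside `audio_dir`, instead of silently keeping them, ensures every formatted prediction path can later be resolved against the dataset's audio directory.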
@@ -1,20 +1,23 @@
 import json
 from pathlib import Path
-from typing import List, Literal, Sequence, TypedDict
+from typing import List, Literal, Optional, Sequence, TypedDict

 import numpy as np
 from soundevent import data
 from soundevent.geometry import compute_bounds

 from batdetect2.core import BaseConfig
-from batdetect2.outputs.formats.base import (
+from batdetect2.data.predictions.base import (
     make_path_relative,
-    output_formatters,
+    prediction_formatters,
 )
-from batdetect2.outputs.types import OutputFormatterProtocol
-from batdetect2.postprocess.types import ClipDetections, Detection
 from batdetect2.targets import terms
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import (
+    BatDetect2Prediction,
+    OutputFormatterProtocol,
+    RawPrediction,
+    TargetProtocol,
+)

 try:
     from typing import NotRequired  # type: ignore
@@ -25,32 +28,76 @@ DictWithClass = TypedDict("DictWithClass", {"class": str})


 class Annotation(DictWithClass):
+    """Format of annotations.
+
+    This is the format of a single annotation as expected by the
+    annotation tool.
+    """
+
     start_time: float
+    """Start time in seconds."""
+
     end_time: float
+    """End time in seconds."""
+
     low_freq: float
+    """Low frequency in Hz."""
+
     high_freq: float
+    """High frequency in Hz."""
+
     class_prob: float
+    """Probability of class assignment."""
+
     det_prob: float
+    """Probability of detection."""
+
     individual: str
+    """Individual ID."""
+
     event: str
+    """Type of detected event."""


 class FileAnnotation(TypedDict):
+    """Format of results.
+
+    This is the format of the results expected by the annotation tool.
+    """
+
     id: str
+    """File ID."""
+
     annotated: bool
+    """Whether file has been annotated."""
+
     duration: float
+    """Duration of audio file."""
+
     issues: bool
+    """Whether file has issues."""
+
     time_exp: float
+    """Time expansion factor."""
+
     class_name: str
+    """Class predicted at file level."""
+
     notes: str
+    """Notes of file."""
+
     annotation: List[Annotation]
-    file_path: NotRequired[str]  # ty: ignore[invalid-type-form]
+    """List of annotations."""
+
+    file_path: NotRequired[str]
+    """Path to file."""


 class BatDetect2OutputConfig(BaseConfig):
     name: Literal["batdetect2"] = "batdetect2"

     event_name: str = "Echolocation"

     annotation_note: str = "Automatically generated."


@@ -66,7 +113,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
         self.annotation_note = annotation_note

     def format(
-        self, predictions: Sequence[ClipDetections]
+        self, predictions: Sequence[BatDetect2Prediction]
     ) -> List[FileAnnotation]:
         return [
             self.format_prediction(prediction) for prediction in predictions
|
self.format_prediction(prediction) for prediction in predictions
|
||||||
@ -76,7 +123,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
self,
|
self,
|
||||||
predictions: Sequence[FileAnnotation],
|
predictions: Sequence[FileAnnotation],
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
audio_dir: data.PathLike | None = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
|
|
||||||
@ -109,18 +156,22 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def get_recording_class(self, annotations: List[Annotation]) -> str:
|
def get_recording_class(self, annotations: List[Annotation]) -> str:
|
||||||
|
"""Get class of recording from annotations."""
|
||||||
|
|
||||||
if not annotations:
|
if not annotations:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
highest_scoring = max(annotations, key=lambda x: x["class_prob"])
|
highest_scoring = max(annotations, key=lambda x: x["class_prob"])
|
||||||
return highest_scoring["class"]
|
return highest_scoring["class"]
|
||||||
|
|
||||||
def format_prediction(self, prediction: ClipDetections) -> FileAnnotation:
|
def format_prediction(
|
||||||
|
self, prediction: BatDetect2Prediction
|
||||||
|
) -> FileAnnotation:
|
||||||
recording = prediction.clip.recording
|
recording = prediction.clip.recording
|
||||||
|
|
||||||
annotations = [
|
annotations = [
|
||||||
self.format_sound_event_prediction(pred)
|
self.format_sound_event_prediction(pred)
|
||||||
for pred in prediction.detections
|
for pred in prediction.predictions
|
||||||
]
|
]
|
||||||
|
|
||||||
return FileAnnotation(
|
return FileAnnotation(
|
||||||
@ -145,7 +196,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
def format_sound_event_prediction(
|
def format_sound_event_prediction(
|
||||||
self, prediction: Detection
|
self, prediction: RawPrediction
|
||||||
) -> Annotation:
|
) -> Annotation:
|
||||||
start_time, low_freq, end_time, high_freq = compute_bounds(
|
start_time, low_freq, end_time, high_freq = compute_bounds(
|
||||||
prediction.geometry
|
prediction.geometry
|
||||||
@ -166,7 +217,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
**{"class": top_class},
|
**{"class": top_class},
|
||||||
)
|
)
|
||||||
|
|
||||||
@output_formatters.register(BatDetect2OutputConfig)
|
@prediction_formatters.register(BatDetect2OutputConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol):
|
def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol):
|
||||||
return BatDetect2Formatter(
|
return BatDetect2Formatter(
|
||||||
@@ -1,6 +1,6 @@
 from collections import defaultdict
 from pathlib import Path
-from typing import List, Literal, Sequence
+from typing import List, Literal, Optional, Sequence
 from uuid import UUID, uuid4

 import numpy as np
@@ -10,13 +10,16 @@ from soundevent import data
 from soundevent.geometry import compute_bounds

 from batdetect2.core import BaseConfig
-from batdetect2.outputs.formats.base import (
+from batdetect2.data.predictions.base import (
     make_path_relative,
-    output_formatters,
+    prediction_formatters,
+)
+from batdetect2.typing import (
+    BatDetect2Prediction,
+    OutputFormatterProtocol,
+    RawPrediction,
+    TargetProtocol,
 )
-from batdetect2.outputs.types import OutputFormatterProtocol
-from batdetect2.postprocess.types import ClipDetections, Detection
-from batdetect2.targets.types import TargetProtocol


 class RawOutputConfig(BaseConfig):
@@ -27,7 +30,7 @@ class RawOutputConfig(BaseConfig):
     include_geometry: bool = True


-class RawFormatter(OutputFormatterProtocol[ClipDetections]):
+class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
     def __init__(
         self,
         targets: TargetProtocol,
@@ -44,15 +47,15 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):

     def format(
         self,
-        predictions: Sequence[ClipDetections],
-    ) -> List[ClipDetections]:
+        predictions: Sequence[BatDetect2Prediction],
+    ) -> List[BatDetect2Prediction]:
         return list(predictions)

     def save(
         self,
-        predictions: Sequence[ClipDetections],
+        predictions: Sequence[BatDetect2Prediction],
         path: data.PathLike,
-        audio_dir: data.PathLike | None = None,
+        audio_dir: Optional[data.PathLike] = None,
     ) -> None:
         path = Path(path)

@@ -65,10 +68,10 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
             dataset = self.pred_to_xr(prediction, audio_dir)
             dataset.to_netcdf(path / f"{clip.uuid}.nc")

-    def load(self, path: data.PathLike) -> List[ClipDetections]:
+    def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
         path = Path(path)
         files = list(path.glob("*.nc"))
-        predictions: List[ClipDetections] = []
+        predictions: List[BatDetect2Prediction] = []

         for filepath in files:
             logger.debug(f"Loading clip predictions {filepath}")
@@ -80,8 +83,8 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):

     def pred_to_xr(
         self,
-        prediction: ClipDetections,
-        audio_dir: data.PathLike | None = None,
+        prediction: BatDetect2Prediction,
+        audio_dir: Optional[data.PathLike] = None,
     ) -> xr.Dataset:
         clip = prediction.clip
         recording = clip.recording
@@ -92,56 +95,56 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
             update=dict(path=make_path_relative(recording.path, audio_dir))
         )

-        values = defaultdict(list)
+        data = defaultdict(list)

-        for pred in prediction.detections:
+        for pred in prediction.predictions:
             detection_id = str(uuid4())

-            values["detection_id"].append(detection_id)
-            values["detection_score"].append(pred.detection_score)
+            data["detection_id"].append(detection_id)
+            data["detection_score"].append(pred.detection_score)

             start_time, low_freq, end_time, high_freq = compute_bounds(
                 pred.geometry
             )

-            values["start_time"].append(start_time)
-            values["end_time"].append(end_time)
-            values["low_freq"].append(low_freq)
-            values["high_freq"].append(high_freq)
+            data["start_time"].append(start_time)
+            data["end_time"].append(end_time)
+            data["low_freq"].append(low_freq)
+            data["high_freq"].append(high_freq)

-            values["geometry"].append(pred.geometry.model_dump_json())
+            data["geometry"].append(pred.geometry.model_dump_json())

             top_class_index = int(np.argmax(pred.class_scores))
             top_class_score = float(pred.class_scores[top_class_index])
             top_class = self.targets.class_names[top_class_index]

-            values["top_class"].append(top_class)
-            values["top_class_score"].append(top_class_score)
+            data["top_class"].append(top_class)
+            data["top_class_score"].append(top_class_score)

-            values["class_scores"].append(pred.class_scores)
-            values["features"].append(pred.features)
+            data["class_scores"].append(pred.class_scores)
+            data["features"].append(pred.features)

             num_features = len(pred.features)

         data_vars = {
-            "score": (["detection"], values["detection_score"]),
-            "start_time": (["detection"], values["start_time"]),
-            "end_time": (["detection"], values["end_time"]),
-            "low_freq": (["detection"], values["low_freq"]),
-            "high_freq": (["detection"], values["high_freq"]),
-            "top_class": (["detection"], values["top_class"]),
-            "top_class_score": (["detection"], values["top_class_score"]),
+            "score": (["detection"], data["detection_score"]),
+            "start_time": (["detection"], data["start_time"]),
+            "end_time": (["detection"], data["end_time"]),
+            "low_freq": (["detection"], data["low_freq"]),
+            "high_freq": (["detection"], data["high_freq"]),
+            "top_class": (["detection"], data["top_class"]),
+            "top_class_score": (["detection"], data["top_class_score"]),
         }

         coords = {
-            "detection": ("detection", values["detection_id"]),
+            "detection": ("detection", data["detection_id"]),
             "clip_start": clip.start_time,
             "clip_end": clip.end_time,
             "clip_id": str(clip.uuid),
         }

         if self.include_class_scores:
-            class_scores = np.stack(values["class_scores"], axis=0)
+            class_scores = np.stack(data["class_scores"], axis=0)
             data_vars["class_scores"] = (
                 ["detection", "classes"],
                 class_scores,
@@ -149,12 +152,12 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
             coords["classes"] = ("classes", self.targets.class_names)

         if self.include_features:
-            features = np.stack(values["features"], axis=0)
+            features = np.stack(data["features"], axis=0)
             data_vars["features"] = (["detection", "feature"], features)
             coords["feature"] = ("feature", np.arange(num_features))

         if self.include_geometry:
-            data_vars["geometry"] = (["detection"], values["geometry"])
+            data_vars["geometry"] = (["detection"], data["geometry"])

         return xr.Dataset(
             data_vars=data_vars,
@@ -164,8 +167,9 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
             },
         )

-    def pred_from_xr(self, dataset: xr.Dataset) -> ClipDetections:
+    def pred_from_xr(self, dataset: xr.Dataset) -> BatDetect2Prediction:
         clip_data = dataset
+        clip_id = clip_data.clip_id.item()

         recording = data.Recording.model_validate_json(
             clip_data.attrs["recording"]
@@ -215,7 +219,7 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
                 features = np.zeros(0)

             sound_events.append(
-                Detection(
+                RawPrediction(
                     geometry=geometry,
                     detection_score=score,
                     class_scores=class_scores,
@@ -223,12 +227,12 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
                 )
             )

-        return ClipDetections(
+        return BatDetect2Prediction(
             clip=clip,
-            detections=sound_events,
+            predictions=sound_events,
         )

-    @output_formatters.register(RawOutputConfig)
+    @prediction_formatters.register(RawOutputConfig)
     @staticmethod
     def from_config(config: RawOutputConfig, targets: TargetProtocol):
         return RawFormatter(
@@ -1,30 +1,33 @@
 from pathlib import Path
-from typing import List, Literal, Sequence
+from typing import List, Literal, Optional, Sequence

 import numpy as np
 from soundevent import data, io

 from batdetect2.core import BaseConfig
-from batdetect2.outputs.formats.base import (
-    output_formatters,
+from batdetect2.data.predictions.base import (
+    prediction_formatters,
+)
+from batdetect2.typing import (
+    BatDetect2Prediction,
+    OutputFormatterProtocol,
+    RawPrediction,
+    TargetProtocol,
 )
-from batdetect2.outputs.types import OutputFormatterProtocol
-from batdetect2.postprocess.types import ClipDetections, Detection
-from batdetect2.targets.types import TargetProtocol


 class SoundEventOutputConfig(BaseConfig):
     name: Literal["soundevent"] = "soundevent"
-    top_k: int | None = 1
-    min_score: float | None = None
+    top_k: Optional[int] = 1
+    min_score: Optional[float] = None


 class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
     def __init__(
         self,
         targets: TargetProtocol,
-        top_k: int | None = 1,
-        min_score: float | None = 0,
+        top_k: Optional[int] = 1,
+        min_score: Optional[float] = 0,
     ):
         self.targets = targets
         self.top_k = top_k
@@ -32,7 +35,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):

     def format(
         self,
-        predictions: Sequence[ClipDetections],
+        predictions: Sequence[BatDetect2Prediction],
     ) -> List[data.ClipPrediction]:
         return [
             self.format_prediction(prediction) for prediction in predictions
@@ -42,7 +45,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
         self,
         predictions: Sequence[data.ClipPrediction],
         path: data.PathLike,
-        audio_dir: data.PathLike | None = None,
+        audio_dir: Optional[data.PathLike] = None,
     ) -> None:
         run = data.PredictionSet(clip_predictions=list(predictions))

@@ -60,20 +63,20 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):

     def format_prediction(
         self,
-        prediction: ClipDetections,
+        prediction: BatDetect2Prediction,
     ) -> data.ClipPrediction:
         recording = prediction.clip.recording
         return data.ClipPrediction(
             clip=prediction.clip,
             sound_events=[
                 self.format_sound_event_prediction(pred, recording)
-                for pred in prediction.detections
+                for pred in prediction.predictions
             ],
         )

     def format_sound_event_prediction(
         self,
-        prediction: Detection,
+        prediction: RawPrediction,
         recording: data.Recording,
     ) -> data.SoundEventPrediction:
         return data.SoundEventPrediction(
@@ -86,7 +89,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
         )

     def get_sound_event_tags(
-        self, prediction: Detection
+        self, prediction: RawPrediction
     ) -> List[data.PredictedTag]:
         sorted_indices = np.argsort(prediction.class_scores)[::-1]

@@ -118,7 +121,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):

         return tags

-    @output_formatters.register(SoundEventOutputConfig)
+    @prediction_formatters.register(SoundEventOutputConfig)
     @staticmethod
     def from_config(config: SoundEventOutputConfig, targets: TargetProtocol):
         return SoundEventOutputFormatter(
@@ -1,3 +1,5 @@
+from typing import Optional, Tuple
+
 from sklearn.model_selection import train_test_split

 from batdetect2.data.datasets import Dataset
@@ -5,15 +7,15 @@ from batdetect2.data.summary import (
     extract_recordings_df,
     extract_sound_events_df,
 )
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing.targets import TargetProtocol


 def split_dataset_by_recordings(
     dataset: Dataset,
     targets: TargetProtocol,
     train_size: float = 0.75,
-    random_state: int | None = None,
-) -> tuple[Dataset, Dataset]:
+    random_state: Optional[int] = None,
+) -> Tuple[Dataset, Dataset]:
     recordings = extract_recordings_df(dataset)

     sound_events = extract_sound_events_df(
@@ -24,15 +26,13 @@ def split_dataset_by_recordings(
     )

     majority_class = (
-        sound_events.groupby("recording_id")  # type: ignore
+        sound_events.groupby("recording_id")
         .apply(
-            lambda group: (
-                group["class_name"]
-                .value_counts()
-                .sort_values(ascending=False)
-                .index[0]
-            ),
-            include_groups=False,
+            lambda group: group["class_name"]  # type: ignore
+            .value_counts()
+            .sort_values(ascending=False)
+            .index[0],
+            include_groups=False,  # type: ignore
         )
         .rename("class_name")
         .to_frame()
@@ -46,8 +46,8 @@
         random_state=random_state,
     )

-    train_ids_set = set(train.values)
-    test_ids_set = set(test.values)
+    train_ids_set = set(train.values)  # type: ignore
+    test_ids_set = set(test.values)  # type: ignore

     extra = set(recordings["recording_id"]) - train_ids_set - test_ids_set

@@ -2,7 +2,7 @@ import pandas as pd
 from soundevent.geometry import compute_bounds

 from batdetect2.data.datasets import Dataset
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing.targets import TargetProtocol

 __all__ = [
     "extract_recordings_df",
@@ -175,14 +175,14 @@ def compute_class_summary(
         .rename("num recordings")
     )
     durations = (
-        sound_events.groupby("class_name")  # ty: ignore[no-matching-overload]
+        sound_events.groupby("class_name")
         .apply(
             lambda group: recordings[
                 recordings["clip_annotation_id"].isin(
-                    group["clip_annotation_id"]
+                    group["clip_annotation_id"]  # type: ignore
                 )
             ]["duration"].sum(),
-            include_groups=False,
+            include_groups=False,  # type: ignore
         )
         .sort_values(ascending=False)
         .rename("duration")
@@ -1,15 +1,11 @@
 from collections.abc import Callable
-from typing import Annotated, Dict, List, Literal
+from typing import Annotated, Dict, List, Literal, Optional, Union

 from pydantic import Field
 from soundevent import data

 from batdetect2.core.configs import BaseConfig
-from batdetect2.core.registries import (
-    ImportConfig,
-    Registry,
-    add_import_config,
-)
+from batdetect2.core.registries import Registry
 from batdetect2.data.conditions import (
     SoundEventCondition,
     SoundEventConditionConfig,
@@ -24,17 +20,6 @@ SoundEventTransform = Callable[
 transforms: Registry[SoundEventTransform, []] = Registry("transform")


-@add_import_config(transforms)
-class SoundEventTransformImportConfig(ImportConfig):
-    """Use any callable as a sound event transform.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class SetFrequencyBoundConfig(BaseConfig):
     name: Literal["set_frequency"] = "set_frequency"
     boundary: Literal["low", "high"] = "low"
@@ -157,7 +142,7 @@ class MapTagValueConfig(BaseConfig):
     name: Literal["map_tag_value"] = "map_tag_value"
     tag_key: str
     value_mapping: Dict[str, str]
-    target_key: str | None = None
+    target_key: Optional[str] = None


 class MapTagValue:
@@ -165,7 +150,7 @@ class MapTagValue:
         self,
         tag_key: str,
         value_mapping: Dict[str, str],
-        target_key: str | None = None,
+        target_key: Optional[str] = None,
     ):
         self.tag_key = tag_key
         self.value_mapping = value_mapping
@@ -191,7 +176,12 @@
             if self.target_key is None:
                 tags.append(tag.model_copy(update=dict(value=value)))
             else:
-                tags.append(data.Tag(key=self.target_key, value=value))
+                tags.append(
+                    data.Tag(
+                        key=self.target_key,  # type: ignore
+                        value=value,
+                    )
+                )

         return sound_event_annotation.model_copy(update=dict(tags=tags))

@@ -231,11 +221,13 @@ class ApplyAll:


 SoundEventTransformConfig = Annotated[
-    SetFrequencyBoundConfig
-    | ReplaceTagConfig
-    | MapTagValueConfig
-    | ApplyIfConfig
-    | ApplyAllConfig,
+    Union[
+        SetFrequencyBoundConfig,
+        ReplaceTagConfig,
+        MapTagValueConfig,
+        ApplyIfConfig,
+        ApplyAllConfig,
+    ],
     Field(discriminator="name"),
 ]

@@ -1,6 +1,6 @@
 """Functions to compute features from predictions."""

-from typing import Dict, List
+from typing import Dict, List, Optional

 import numpy as np

@@ -86,7 +86,7 @@ def compute_bandwidth(

 def compute_max_power_bb(
     prediction: types.Prediction,
-    spec: np.ndarray | None = None,
+    spec: Optional[np.ndarray] = None,
     min_freq: int = MIN_FREQ_HZ,
     max_freq: int = MAX_FREQ_HZ,
     **_,
@@ -131,7 +131,7 @@

 def compute_max_power(
     prediction: types.Prediction,
-    spec: np.ndarray | None = None,
+    spec: Optional[np.ndarray] = None,
     min_freq: int = MIN_FREQ_HZ,
     max_freq: int = MAX_FREQ_HZ,
     **_,
@@ -157,7 +157,7 @@

 def compute_max_power_first(
     prediction: types.Prediction,
-    spec: np.ndarray | None = None,
+    spec: Optional[np.ndarray] = None,
     min_freq: int = MIN_FREQ_HZ,
     max_freq: int = MAX_FREQ_HZ,
     **_,
@@ -184,7 +184,7 @@

 def compute_max_power_second(
     prediction: types.Prediction,
-    spec: np.ndarray | None = None,
+    spec: Optional[np.ndarray] = None,
     min_freq: int = MIN_FREQ_HZ,
     max_freq: int = MAX_FREQ_HZ,
     **_,
@@ -211,7 +211,7 @@

 def compute_call_interval(
     prediction: types.Prediction,
-    previous: types.Prediction | None = None,
+    previous: Optional[types.Prediction] = None,
     **_,
 ) -> float:
     """Compute time between this call and the previous call in seconds."""
@@ -1,7 +1,7 @@
 import datetime
 import os
 from pathlib import Path
-from typing import List
+from typing import List, Optional, Union

 from pydantic import BaseModel, Field, computed_field

@@ -198,8 +198,8 @@ class TrainingParameters(BaseModel):
 def get_params(
     make_dirs: bool = False,
     exps_dir: str = "../../experiments/",
-    model_name: str | None = None,
-    experiment: Path | str | None = None,
+    model_name: Optional[str] = None,
+    experiment: Union[Path, str, None] = None,
     **kwargs,
 ) -> TrainingParameters:
     experiments_dir = Path(exps_dir)
@@ -1,5 +1,7 @@
 """Post-processing of the output of the model."""

+from typing import List, Tuple, Union
+
 import numpy as np
 import torch
 from torch import nn
@@ -43,7 +45,7 @@ def run_nms(
     outputs: ModelOutput,
     params: NonMaximumSuppressionConfig,
     sampling_rate: np.ndarray,
-) -> tuple[list[PredictionResults], list[np.ndarray]]:
+) -> Tuple[List[PredictionResults], List[np.ndarray]]:
     """Run non-maximum suppression on the output of the model.

     Model outputs processed are expected to have a batch dimension.
@@ -71,8 +73,8 @@ def run_nms(
     scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)

     # loop over batch to save outputs
-    preds: list[PredictionResults] = []
-    feats: list[np.ndarray] = []
+    preds: List[PredictionResults] = []
+    feats: List[np.ndarray] = []
     for num_detection in range(pred_det_nms.shape[0]):
         # get valid indices
         inds_ord = torch.argsort(x_pos[num_detection, :])
@@ -149,7 +151,7 @@ def run_nms(

 def non_max_suppression(
     heat: torch.Tensor,
-    kernel_size: int | tuple[int, int],
+    kernel_size: Union[int, Tuple[int, int]],
 ):
     # kernel can be an int or list/tuple
     if isinstance(kernel_size, int):
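The `non_max_suppression` touched above is, by its name and `kernel_size` parameter, a heatmap NMS of the kind used in CenterNet-style detectors: a cell in the detection heatmap survives only if it is the maximum of its local neighborhood. A dependency-free sketch of that standard technique (an assumption about the approach, not a copy of batdetect2's implementation):

```python
def heatmap_nms(heat, kernel_size=3):
    # Keep a cell only if it equals the max over its kernel_size x kernel_size
    # neighborhood; everything else is zeroed, mimicking max-pool-based NMS.
    pad = (kernel_size - 1) // 2
    h, w = len(heat), len(heat[0])
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            neighborhood = [
                heat[y][x]
                for y in range(max(0, i - pad), min(h, i + pad + 1))
                for x in range(max(0, j - pad), min(w, j + pad + 1))
            ]
            if heat[i][j] == max(neighborhood):
                out[i][j] = heat[i][j]
    return out


heat = [[0.0] * 5 for _ in range(5)]
heat[2][2] = 1.0  # a strong peak, kept
heat[2][3] = 0.5  # a weaker neighbor, suppressed by the peak
out = heatmap_nms(heat)
```

In the real model this runs on a torch tensor (e.g. via `max_pool2d` and an equality mask), which is why `kernel_size` accepts an int or a 2-tuple.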
@@ -1,32 +1,15 @@
-from batdetect2.evaluate.config import EvaluationConfig
-from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
+from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
+from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate
 from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
-from batdetect2.evaluate.results import save_evaluation_results
 from batdetect2.evaluate.tasks import TaskConfig, build_task
-from batdetect2.evaluate.types import (
-    AffinityFunction,
-    ClipMatches,
-    EvaluationTaskProtocol,
-    EvaluatorProtocol,
-    MetricsProtocol,
-    PlotterProtocol,
-)

 __all__ = [
-    "AffinityFunction",
-    "ClipMatches",
-    "DEFAULT_EVAL_DIR",
     "EvaluationConfig",
-    "EvaluationTaskProtocol",
     "Evaluator",
-    "EvaluatorProtocol",
-    "MatchEvaluation",
-    "MatcherProtocol",
-    "MetricsProtocol",
-    "PlotterProtocol",
     "TaskConfig",
     "build_evaluator",
     "build_task",
-    "run_evaluate",
-    "save_evaluation_results",
+    "evaluate",
+    "load_evaluation_config",
+    "DEFAULT_EVAL_DIR",
 ]
@@ -1,116 +1,76 @@
-from typing import Annotated, Literal
+from typing import Annotated, Literal, Optional, Union

 from pydantic import Field
 from soundevent import data
-from soundevent.geometry import (
-    buffer_geometry,
-    compute_bbox_iou,
-    compute_geometric_iou,
-    compute_temporal_closeness,
-    compute_temporal_iou,
-)
+from soundevent.evaluation import compute_affinity
+from soundevent.geometry import compute_interval_overlap

-from batdetect2.core import (
-    BaseConfig,
-    ImportConfig,
-    Registry,
-    add_import_config,
-)
-from batdetect2.evaluate.types import AffinityFunction
-from batdetect2.postprocess.types import Detection
+from batdetect2.core.configs import BaseConfig
+from batdetect2.core.registries import Registry
+from batdetect2.typing.evaluate import AffinityFunction

 affinity_functions: Registry[AffinityFunction, []] = Registry(
-    "affinity_function"
+    "matching_strategy"
 )


-@add_import_config(affinity_functions)
-class AffinityFunctionImportConfig(ImportConfig):
-    """Use any callable as an affinity function.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class TimeAffinityConfig(BaseConfig):
     name: Literal["time_affinity"] = "time_affinity"
-    position: Literal["start", "end", "center"] | float = "start"
-    max_distance: float = 0.01
+    time_buffer: float = 0.01


 class TimeAffinity(AffinityFunction):
-    def __init__(
-        self,
-        max_distance: float = 0.01,
-        position: Literal["start", "end", "center"] | float = "start",
-    ):
-        if position == "start":
-            position = 0
-        elif position == "end":
-            position = 1
-        elif position == "center":
-            position = 0.5
-
-        self.position = position
-        self.max_distance = max_distance
-
-    def __call__(
-        self,
-        detection: Detection,
-        ground_truth: data.SoundEventAnnotation,
-    ) -> float:
-        target_geometry = ground_truth.sound_event.geometry
-        source_geometry = detection.geometry
-        return compute_temporal_closeness(
-            target_geometry,
-            source_geometry,
-            ratio=self.position,
-            max_distance=self.max_distance,
-        )
+    def __init__(self, time_buffer: float):
+        self.time_buffer = time_buffer
+
+    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
+        return compute_timestamp_affinity(
+            geometry1, geometry2, time_buffer=self.time_buffer
+        )

     @affinity_functions.register(TimeAffinityConfig)
     @staticmethod
     def from_config(config: TimeAffinityConfig):
-        return TimeAffinity(
-            max_distance=config.max_distance,
-            position=config.position,
-        )
+        return TimeAffinity(time_buffer=config.time_buffer)
+
+
+def compute_timestamp_affinity(
+    geometry1: data.Geometry,
+    geometry2: data.Geometry,
+    time_buffer: float = 0.01,
+) -> float:
+    assert isinstance(geometry1, data.TimeStamp)
+    assert isinstance(geometry2, data.TimeStamp)
+
+    start_time1 = geometry1.coordinates
+    start_time2 = geometry2.coordinates
+
+    a = min(start_time1, start_time2)
+    b = max(start_time1, start_time2)
+
+    if b - a >= 2 * time_buffer:
+        return 0
+
+    intersection = a - b + 2 * time_buffer
+    union = b - a + 2 * time_buffer
+    return intersection / union


 class IntervalIOUConfig(BaseConfig):
     name: Literal["interval_iou"] = "interval_iou"
-    time_buffer: float = 0.0
+    time_buffer: float = 0.01


 class IntervalIOU(AffinityFunction):
     def __init__(self, time_buffer: float):
-        if time_buffer < 0:
-            raise ValueError("time_buffer must be non-negative")
-
         self.time_buffer = time_buffer

-    def __call__(
-        self,
-        detection: Detection,
-        ground_truth: data.SoundEventAnnotation,
-    ) -> float:
-        target_geometry = ground_truth.sound_event.geometry
-        source_geometry = detection.geometry
-
-        if self.time_buffer > 0:
-            target_geometry = buffer_geometry(
-                target_geometry,
-                time=self.time_buffer,
-            )
-            source_geometry = buffer_geometry(
-                source_geometry,
-                time=self.time_buffer,
-            )
-
-        return compute_temporal_iou(target_geometry, source_geometry)
+    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
+        return compute_interval_iou(
+            geometry1,
+            geometry2,
+            time_buffer=self.time_buffer,
+        )

     @affinity_functions.register(IntervalIOUConfig)
     @staticmethod
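The `compute_timestamp_affinity` added above scores two point events by the IoU of symmetric `time_buffer` windows placed around each timestamp. The same arithmetic on plain floats (standing in for `data.TimeStamp` coordinates):

```python
def timestamp_affinity(t1: float, t2: float, time_buffer: float = 0.01) -> float:
    # IoU of the windows [t - time_buffer, t + time_buffer] around each timestamp.
    a, b = min(t1, t2), max(t1, t2)
    if b - a >= 2 * time_buffer:
        return 0.0  # the two windows do not overlap at all
    intersection = 2 * time_buffer - (b - a)
    union = 2 * time_buffer + (b - a)
    return intersection / union
```

Identical timestamps score 1.0, and the score decays to 0 once the timestamps are `2 * time_buffer` apart.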
@@ -118,44 +78,64 @@ class IntervalIOU(AffinityFunction):
         return IntervalIOU(time_buffer=config.time_buffer)


+def compute_interval_iou(
+    geometry1: data.Geometry,
+    geometry2: data.Geometry,
+    time_buffer: float = 0.01,
+) -> float:
+    assert isinstance(geometry1, data.TimeInterval)
+    assert isinstance(geometry2, data.TimeInterval)
+
+    start_time1, end_time1 = geometry1.coordinates
+    start_time2, end_time2 = geometry1.coordinates
+
+    start_time1 -= time_buffer
+    start_time2 -= time_buffer
+    end_time1 += time_buffer
+    end_time2 += time_buffer
+
+    intersection = compute_interval_overlap(
+        (start_time1, end_time1),
+        (start_time2, end_time2),
+    )
+
+    union = (
+        (end_time1 - start_time1) + (end_time2 - start_time2) - intersection
+    )
+
+    if union == 0:
+        return 0
+
+    return intersection / union
+
+
 class BBoxIOUConfig(BaseConfig):
     name: Literal["bbox_iou"] = "bbox_iou"
-    time_buffer: float = 0.0
-    freq_buffer: float = 0.0
+    time_buffer: float = 0.01
+    freq_buffer: float = 1000


 class BBoxIOU(AffinityFunction):
     def __init__(self, time_buffer: float, freq_buffer: float):
-        if time_buffer < 0:
-            raise ValueError("time_buffer must be non-negative")
-
-        if freq_buffer < 0:
-            raise ValueError("freq_buffer must be non-negative")
-
         self.time_buffer = time_buffer
         self.freq_buffer = freq_buffer

-    def __call__(
-        self,
-        detection: Detection,
-        ground_truth: data.SoundEventAnnotation,
-    ):
-        target_geometry = ground_truth.sound_event.geometry
-        source_geometry = detection.geometry
-
-        if self.time_buffer > 0 or self.freq_buffer > 0:
-            target_geometry = buffer_geometry(
-                target_geometry,
-                time=self.time_buffer,
-                freq=self.freq_buffer,
-            )
-            source_geometry = buffer_geometry(
-                source_geometry,
-                time=self.time_buffer,
-                freq=self.freq_buffer,
-            )
-
-        return compute_bbox_iou(target_geometry, source_geometry)
+    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
+        if not isinstance(geometry1, data.BoundingBox):
+            raise TypeError(
+                f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
+            )
+
+        if not isinstance(geometry2, data.BoundingBox):
+            raise TypeError(
+                f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
+            )
+        return bbox_iou(
+            geometry1,
+            geometry2,
+            time_buffer=self.time_buffer,
+            freq_buffer=self.freq_buffer,
+        )

     @affinity_functions.register(BBoxIOUConfig)
     @staticmethod
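Note that the added `compute_interval_iou` unpacks `geometry1.coordinates` twice, so the second interval is never actually read; assuming that is a typo for `geometry2`, the intended computation on plain `(start, end)` tuples looks like:

```python
def interval_iou(interval1, interval2, time_buffer: float = 0.01) -> float:
    # Expand both intervals symmetrically by the buffer, then take the 1-D IoU.
    start1, end1 = interval1[0] - time_buffer, interval1[1] + time_buffer
    start2, end2 = interval2[0] - time_buffer, interval2[1] + time_buffer
    intersection = max(0.0, min(end1, end2) - max(start1, start2))
    union = (end1 - start1) + (end2 - start2) - intersection
    if union == 0:
        return 0.0
    return intersection / union
```

With `time_buffer=0`, identical intervals score 1.0 and disjoint intervals score 0.0.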
@@ -166,44 +146,65 @@ class BBoxIOU(AffinityFunction):
         )


+def bbox_iou(
+    geometry1: data.BoundingBox,
+    geometry2: data.BoundingBox,
+    time_buffer: float = 0.01,
+    freq_buffer: float = 1000,
+) -> float:
+    start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
+    start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
+
+    start_time1 -= time_buffer
+    start_time2 -= time_buffer
+    end_time1 += time_buffer
+    end_time2 += time_buffer
+
+    low_freq1 -= freq_buffer
+    low_freq2 -= freq_buffer
+    high_freq1 += freq_buffer
+    high_freq2 += freq_buffer
+
+    time_intersection = compute_interval_overlap(
+        (start_time1, end_time1),
+        (start_time2, end_time2),
+    )
+
+    freq_intersection = max(
+        0,
+        min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
+    )
+
+    intersection = time_intersection * freq_intersection
+
+    if intersection == 0:
+        return 0
+
+    union = (
+        (end_time1 - start_time1) * (high_freq1 - low_freq1)
+        + (end_time2 - start_time2) * (high_freq2 - low_freq2)
+        - intersection
+    )
+
+    return intersection / union
+
+
 class GeometricIOUConfig(BaseConfig):
     name: Literal["geometric_iou"] = "geometric_iou"
-    time_buffer: float = 0.0
-    freq_buffer: float = 0.0
+    time_buffer: float = 0.01
+    freq_buffer: float = 1000


 class GeometricIOU(AffinityFunction):
-    def __init__(self, time_buffer: float = 0, freq_buffer: float = 0):
-        if time_buffer < 0:
-            raise ValueError("time_buffer must be non-negative")
-
-        if freq_buffer < 0:
-            raise ValueError("freq_buffer must be non-negative")
-
+    def __init__(self, time_buffer: float):
         self.time_buffer = time_buffer
-        self.freq_buffer = freq_buffer

-    def __call__(
-        self,
-        detection: Detection,
-        ground_truth: data.SoundEventAnnotation,
-    ):
-        target_geometry = ground_truth.sound_event.geometry
-        source_geometry = detection.geometry
-
-        if self.time_buffer > 0 or self.freq_buffer > 0:
-            target_geometry = buffer_geometry(
-                target_geometry,
-                time=self.time_buffer,
-                freq=self.freq_buffer,
-            )
-            source_geometry = buffer_geometry(
-                source_geometry,
-                time=self.time_buffer,
-                freq=self.freq_buffer,
-            )
-
-        return compute_geometric_iou(target_geometry, source_geometry)
+    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
+        return compute_affinity(
+            geometry1,
+            geometry2,
+            time_buffer=self.time_buffer,
+        )

     @affinity_functions.register(GeometricIOUConfig)
     @staticmethod
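The `bbox_iou` added above extends the same idea to time-frequency boxes: buffered 1-D overlaps in time and frequency are multiplied into an intersection area and divided by the union of the two box areas. The same computation on plain `(start_time, low_freq, end_time, high_freq)` tuples:

```python
def box_iou(box1, box2, time_buffer: float = 0.0, freq_buffer: float = 0.0) -> float:
    # Boxes are (start_time, low_freq, end_time, high_freq), expanded by buffers.
    s1, l1, e1, h1 = box1
    s2, l2, e2, h2 = box2
    s1, e1 = s1 - time_buffer, e1 + time_buffer
    s2, e2 = s2 - time_buffer, e2 + time_buffer
    l1, h1 = l1 - freq_buffer, h1 + freq_buffer
    l2, h2 = l2 - freq_buffer, h2 + freq_buffer
    time_overlap = max(0.0, min(e1, e2) - max(s1, s2))
    freq_overlap = max(0.0, min(h1, h2) - max(l1, l2))
    intersection = time_overlap * freq_overlap
    if intersection == 0:
        return 0.0
    union = (e1 - s1) * (h1 - l1) + (e2 - s2) * (h2 - l2) - intersection
    return intersection / union
```

For two 2x2 boxes offset by 1 in each axis the overlap area is 1 and the union is 4 + 4 - 1 = 7, giving an IoU of 1/7.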
@@ -212,16 +213,18 @@ class GeometricIOU(AffinityFunction):


 AffinityConfig = Annotated[
-    TimeAffinityConfig
-    | IntervalIOUConfig
-    | BBoxIOUConfig
-    | GeometricIOUConfig,
+    Union[
+        TimeAffinityConfig,
+        IntervalIOUConfig,
+        BBoxIOUConfig,
+        GeometricIOUConfig,
+    ],
     Field(discriminator="name"),
 ]


 def build_affinity_function(
-    config: AffinityConfig | None = None,
+    config: Optional[AffinityConfig] = None,
 ) -> AffinityFunction:
     config = config or GeometricIOUConfig()
     return affinity_functions.build(config)
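Each affinity config carries a literal `name` field: pydantic's `Field(discriminator="name")` selects the config class when parsing, and the registry dispatches `affinity_functions.build(config)` on the same discriminator. A dependency-free sketch of that dispatch pattern, with hypothetical minimal stand-ins for `Registry` and the config class:

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class TimeAffinityConfig:
    name: str = "time_affinity"
    time_buffer: float = 0.01


class Registry:
    """Maps a config's `name` discriminator to a factory callable."""

    def __init__(self) -> None:
        self._factories: Dict[str, Callable] = {}

    def register(self, name: str) -> Callable:
        def decorator(factory: Callable) -> Callable:
            self._factories[name] = factory
            return factory

        return decorator

    def build(self, config):
        # Dispatch on the discriminator, mirroring affinity_functions.build(config).
        return self._factories[config.name](config)


affinity_functions = Registry()


@affinity_functions.register("time_affinity")
def make_time_affinity(config: TimeAffinityConfig):
    return ("time_affinity", config.time_buffer)
```

The real `Registry` registers config *classes* rather than name strings, but the dispatch-by-`name` mechanics are the same.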
@@ -1,14 +1,19 @@
-from typing import List
+from typing import List, Optional

 from pydantic import Field
+from soundevent import data

-from batdetect2.core.configs import BaseConfig
-from batdetect2.evaluate.tasks import TaskConfig
+from batdetect2.core.configs import BaseConfig, load_config
+from batdetect2.evaluate.tasks import (
+    TaskConfig,
+)
 from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
 from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
+from batdetect2.logging import CSVLoggerConfig, LoggerConfig

 __all__ = [
     "EvaluationConfig",
+    "load_evaluation_config",
 ]


@@ -19,6 +24,7 @@ class EvaluationConfig(BaseConfig):
             ClassificationTaskConfig(),
         ]
     )
+    logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)


 def get_default_eval_config() -> EvaluationConfig:
@@ -41,3 +47,10 @@ def get_default_eval_config() -> EvaluationConfig:
             ]
         }
     )
+
+
+def load_evaluation_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+) -> EvaluationConfig:
+    return load_config(path, schema=EvaluationConfig, field=field)
@@ -1,4 +1,4 @@
-from typing import List, NamedTuple, Sequence
+from typing import List, NamedTuple, Optional, Sequence

 import torch
 from loguru import logger
@@ -8,11 +8,14 @@ from torch.utils.data import DataLoader, Dataset

 from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
 from batdetect2.audio.clips import PaddedClipConfig
-from batdetect2.audio.types import AudioLoader, ClipperProtocol
 from batdetect2.core import BaseConfig
 from batdetect2.core.arrays import adjust_width
 from batdetect2.preprocess import build_preprocessor
-from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.typing import (
+    AudioLoader,
+    ClipperProtocol,
+    PreprocessorProtocol,
+)

 __all__ = [
     "TestDataset",
@@ -36,8 +39,8 @@ class TestDataset(Dataset[TestExample]):
         clip_annotations: Sequence[data.ClipAnnotation],
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
-        clipper: ClipperProtocol | None = None,
-        audio_dir: data.PathLike | None = None,
+        clipper: Optional[ClipperProtocol] = None,
+        audio_dir: Optional[data.PathLike] = None,
     ):
         self.clip_annotations = list(clip_annotations)
         self.clipper = clipper
@@ -48,8 +51,8 @@ class TestDataset(Dataset[TestExample]):
     def __len__(self):
         return len(self.clip_annotations)

-    def __getitem__(self, index: int) -> TestExample:
-        clip_annotation = self.clip_annotations[index]
+    def __getitem__(self, idx: int) -> TestExample:
+        clip_annotation = self.clip_annotations[idx]

         if self.clipper is not None:
             clip_annotation = self.clipper(clip_annotation)
@@ -60,13 +63,14 @@ class TestDataset(Dataset[TestExample]):
         spectrogram = self.preprocessor(wav_tensor)
         return TestExample(
             spec=spectrogram,
-            idx=torch.tensor(index),
+            idx=torch.tensor(idx),
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )


 class TestLoaderConfig(BaseConfig):
+    num_workers: int = 0
     clipping_strategy: ClipConfig = Field(
         default_factory=lambda: PaddedClipConfig()
     )
@@ -74,10 +78,10 @@ class TestLoaderConfig(BaseConfig):

 def build_test_loader(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader | None = None,
-    preprocessor: PreprocessorProtocol | None = None,
-    config: TestLoaderConfig | None = None,
-    num_workers: int = 0,
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[TestLoaderConfig] = None,
+    num_workers: Optional[int] = None,
 ) -> DataLoader[TestExample]:
     logger.info("Building test data loader...")
     config = config or TestLoaderConfig()
@@ -93,6 +97,7 @@ def build_test_loader(
         config=config,
     )

+    num_workers = num_workers or config.num_workers
     return DataLoader(
         test_dataset,
         batch_size=1,
@@ -104,9 +109,9 @@ def build_test_loader(

 def build_test_dataset(
     clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: AudioLoader | None = None,
-    preprocessor: PreprocessorProtocol | None = None,
-    config: TestLoaderConfig | None = None,
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[TestLoaderConfig] = None,
 ) -> TestDataset:
     logger.info("Building training dataset...")
     config = config or TestLoaderConfig()
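The `num_workers` change above (from `int = 0` to `Optional[int] = None`, resolved with `num_workers = num_workers or config.num_workers`) relies on `or` for defaulting, and `or` falls through for every falsy value, so an explicit `0` behaves the same as `None`. A small illustration of that behavior:

```python
def resolve_num_workers(num_workers=None, config_num_workers=4):
    # `or` falls through for any falsy value, so 0 and None behave alike here;
    # config_num_workers is a hypothetical stand-in for config.num_workers.
    return num_workers or config_num_workers
```

This is harmless when the config default is also 0, but it is why `x if x is not None else default` is the stricter idiom for `Optional[int]` sentinels.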
@@ -1,51 +1,56 @@
 from pathlib import Path
-from typing import Sequence
+from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple

 from lightning import Trainer
 from soundevent import data

-from batdetect2.audio import AudioConfig, build_audio_loader
-from batdetect2.audio.types import AudioLoader
-from batdetect2.evaluate import EvaluationConfig
+from batdetect2.audio import build_audio_loader
 from batdetect2.evaluate.dataset import build_test_loader
 from batdetect2.evaluate.evaluator import build_evaluator
 from batdetect2.evaluate.lightning import EvaluationModule
-from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
+from batdetect2.logging import build_logger
 from batdetect2.models import Model
-from batdetect2.outputs import OutputsConfig, build_output_transform
-from batdetect2.outputs.types import OutputFormatterProtocol
-from batdetect2.postprocess.types import ClipDetections
-from batdetect2.preprocess.types import PreprocessorProtocol
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.targets import build_targets
+from batdetect2.typing.postprocess import RawPrediction
+
+if TYPE_CHECKING:
+    from batdetect2.config import BatDetect2Config
+    from batdetect2.typing import (
+        AudioLoader,
+        OutputFormatterProtocol,
+        PreprocessorProtocol,
+        TargetProtocol,
+    )

 DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"


-def run_evaluate(
+def evaluate(
     model: Model,
     test_annotations: Sequence[data.ClipAnnotation],
-    targets: TargetProtocol | None = None,
-    audio_loader: AudioLoader | None = None,
-    preprocessor: PreprocessorProtocol | None = None,
-    audio_config: AudioConfig | None = None,
-    evaluation_config: EvaluationConfig | None = None,
-    output_config: OutputsConfig | None = None,
-    logger_config: LoggerConfig | None = None,
-    formatter: OutputFormatterProtocol | None = None,
-    num_workers: int = 0,
+    targets: Optional["TargetProtocol"] = None,
+    audio_loader: Optional["AudioLoader"] = None,
+    preprocessor: Optional["PreprocessorProtocol"] = None,
+    config: Optional["BatDetect2Config"] = None,
+    formatter: Optional["OutputFormatterProtocol"] = None,
+    num_workers: Optional[int] = None,
     output_dir: data.PathLike = DEFAULT_EVAL_DIR,
-    experiment_name: str | None = None,
-    run_name: str | None = None,
-) -> tuple[dict[str, float], list[ClipDetections]]:
-    audio_config = audio_config or AudioConfig()
-    evaluation_config = evaluation_config or EvaluationConfig()
-    output_config = output_config or OutputsConfig()
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
+) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
+    from batdetect2.config import BatDetect2Config
+
+    config = config or BatDetect2Config()

-    audio_loader = audio_loader or build_audio_loader(config=audio_config)
+    audio_loader = audio_loader or build_audio_loader(config=config.audio)

-    preprocessor = preprocessor or model.preprocessor
-    targets = targets or model.targets
+    preprocessor = preprocessor or build_preprocessor(
+        config=config.preprocess,
+        input_samplerate=audio_loader.samplerate,
+    )
+
+    targets = targets or build_targets(config=config.targets)

     loader = build_test_loader(
         test_annotations,
@@ -54,26 +59,15 @@ def run_evaluate(
         num_workers=num_workers,
     )

-    output_transform = build_output_transform(
-        config=output_config.transform,
-        targets=targets,
-    )
-    evaluator = build_evaluator(
-        config=evaluation_config,
-        targets=targets,
-        transform=output_transform,
-    )
+    evaluator = build_evaluator(config=config.evaluation, targets=targets)

     logger = build_logger(
-        logger_config or CSVLoggerConfig(),
+        config.evaluation.logger,
         log_dir=Path(output_dir),
         experiment_name=experiment_name,
         run_name=run_name,
     )
-    module = EvaluationModule(
-        model,
-        evaluator,
-    )
+    module = EvaluationModule(model, evaluator)
     trainer = Trainer(logger=logger, enable_checkpointing=False)
     metrics = trainer.test(module, loader)
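The evaluator hunks that follow replace `zip(self.tasks, eval_outputs, strict=False)` with plain `zip(...)`; `strict` is a Python 3.10+ keyword, and without it `zip` silently truncates to the shortest iterable, which is also what `strict=False` did:

```python
# zip's default (and only pre-3.10) behavior: stop at the shortest iterable.
tasks = ["detection", "classification"]
outputs = [0.9, 0.7, 0.5]  # one extra element is silently dropped
pairs = list(zip(tasks, outputs))
```

On 3.10+, `zip(tasks, outputs, strict=True)` would instead raise `ValueError` on the length mismatch.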
@@ -1,15 +1,13 @@
-from typing import Any, Dict, Iterable, List, Sequence, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

 from matplotlib.figure import Figure
 from soundevent import data

 from batdetect2.evaluate.config import EvaluationConfig
 from batdetect2.evaluate.tasks import build_task
-from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol
-from batdetect2.outputs import OutputTransformProtocol, build_output_transform
-from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor
 from batdetect2.targets import build_targets
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import EvaluatorProtocol, TargetProtocol
+from batdetect2.typing.postprocess import BatDetect2Prediction

 __all__ = [
     "Evaluator",
@@ -21,27 +19,15 @@ class Evaluator:
     def __init__(
         self,
         targets: TargetProtocol,
-        transform: OutputTransformProtocol,
-        tasks: Sequence[EvaluationTaskProtocol],
+        tasks: Sequence[EvaluatorProtocol],
     ):
         self.targets = targets
-        self.transform = transform
         self.tasks = tasks

-    def to_clip_detections_batch(
-        self,
-        clip_detections: Sequence[ClipDetectionsTensor],
-        clips: Sequence[data.Clip],
-    ) -> list[ClipDetections]:
-        return [
-            self.transform.to_clip_detections(detections=dets, clip=clip)
-            for dets, clip in zip(clip_detections, clips, strict=False)
-        ]
-
     def evaluate(
         self,
         clip_annotations: Sequence[data.ClipAnnotation],
-        predictions: Sequence[ClipDetections],
+        predictions: Sequence[BatDetect2Prediction],
     ) -> List[Any]:
         return [
             task.evaluate(clip_annotations, predictions) for task in self.tasks
@@ -50,7 +36,7 @@ class Evaluator:
     def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
         results = {}

-        for task, outputs in zip(self.tasks, eval_outputs, strict=False):
+        for task, outputs in zip(self.tasks, eval_outputs):
             results.update(task.compute_metrics(outputs))

         return results
@@ -59,15 +45,14 @@ class Evaluator:
         self,
         eval_outputs: List[Any],
     ) -> Iterable[Tuple[str, Figure]]:
-        for task, outputs in zip(self.tasks, eval_outputs, strict=False):
+        for task, outputs in zip(self.tasks, eval_outputs):
             for name, fig in task.generate_plots(outputs):
                 yield name, fig


 def build_evaluator(
-    config: EvaluationConfig | dict | None = None,
-    targets: TargetProtocol | None = None,
-    transform: OutputTransformProtocol | None = None,
+    config: Optional[Union[EvaluationConfig, dict]] = None,
+    targets: Optional[TargetProtocol] = None,
 ) -> EvaluatorProtocol:
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
|
|
||||||
@ -77,10 +62,7 @@ def build_evaluator(
|
|||||||
if not isinstance(config, EvaluationConfig):
|
if not isinstance(config, EvaluationConfig):
|
||||||
config = EvaluationConfig.model_validate(config)
|
config = EvaluationConfig.model_validate(config)
|
||||||
|
|
||||||
transform = transform or build_output_transform(targets=targets)
|
|
||||||
|
|
||||||
return Evaluator(
|
return Evaluator(
|
||||||
targets=targets,
|
targets=targets,
|
||||||
transform=transform,
|
|
||||||
tasks=[build_task(task, targets=targets) for task in config.tasks],
|
tasks=[build_task(task, targets=targets) for task in config.tasks],
|
||||||
)
|
)
|
||||||
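The hunks above reduce `Evaluator` to a thin fan-out over evaluation tasks: each task produces per-clip outputs, and the evaluator merges every task's metric dict into one flat mapping. A minimal sketch of that aggregation pattern in plain Python (the `CountTask` and its numbers are illustrative, not from the codebase):

```python
from typing import Any, Dict, List, Protocol


class TaskProtocol(Protocol):
    """Minimal stand-in for an evaluation task: produce per-task outputs,
    then reduce them to named metrics."""

    def evaluate(self, annotations: list, predictions: list) -> Any: ...
    def compute_metrics(self, outputs: Any) -> Dict[str, float]: ...


class Evaluator:
    def __init__(self, tasks: List[TaskProtocol]):
        self.tasks = tasks

    def evaluate(self, annotations: list, predictions: list) -> List[Any]:
        # One output object per task, in task order.
        return [task.evaluate(annotations, predictions) for task in self.tasks]

    def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
        # Merge every task's metric dict into one flat mapping.
        results: Dict[str, float] = {}
        for task, outputs in zip(self.tasks, eval_outputs):
            results.update(task.compute_metrics(outputs))
        return results


class CountTask:
    """Toy task: just counts annotations and predictions."""

    def evaluate(self, annotations, predictions):
        return (len(annotations), len(predictions))

    def compute_metrics(self, outputs):
        n_gt, n_pred = outputs
        return {"num_ground_truth": float(n_gt), "num_predictions": float(n_pred)}


evaluator = Evaluator(tasks=[CountTask()])
outputs = evaluator.evaluate(["a", "b"], ["p"])
metrics = evaluator.compute_metrics(outputs)
print(metrics)  # → {'num_ground_truth': 2.0, 'num_predictions': 1.0}
```

Because task metrics are merged with `dict.update`, two tasks emitting the same metric name would silently overwrite each other; distinct labels per task avoid that.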
```diff
@@ -357,7 +357,7 @@ def train_rf_model(x_train, y_train, num_classes, seed=2001):
     clf = RandomForestClassifier(random_state=seed, n_jobs=-1)
     clf.fit(x_train, y_train)
     y_pred = clf.predict(x_train)
-    (y_pred == y_train).mean()
+    tr_acc = (y_pred == y_train).mean()
     # print('Train acc', round(tr_acc*100, 2))
     return clf, un_train_class
 
@@ -450,7 +450,7 @@ def add_root_path_back(data_sets, ann_path, wav_path):
 
 
 def check_classes_in_train(gt_list, class_names):
-    np.sum([gg["start_times"].shape[0] for gg in gt_list])
+    num_gt_total = np.sum([gg["start_times"].shape[0] for gg in gt_list])
     num_with_no_class = 0
     for gt in gt_list:
         for cc in gt["class_names"]:
@@ -569,7 +569,7 @@ if __name__ == "__main__":
     num_with_no_class = check_classes_in_train(gt_test, class_names)
     if total_num_calls == num_with_no_class:
         print("Classes from the test set are not in the train set.")
-        raise AssertionError()
+        assert False
 
     # only need the train data if evaluating Sonobat or Tadarida
     if args["sb_ip_dir"] != "" or args["td_ip_dir"] != "":
@@ -743,7 +743,7 @@ if __name__ == "__main__":
     # check if the class names are the same
     if params_bd["class_names"] != class_names:
         print("Warning: Class names are not the same as the trained model")
-        raise AssertionError()
+        assert False
 
     run_config = {
         **bd_args,
@@ -753,7 +753,7 @@ if __name__ == "__main__":
 
     preds_bd = []
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    for gg in gt_test:
+    for ii, gg in enumerate(gt_test):
         pred = du.process_file(
             gg["file_path"],
             model,
```
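The `tr_acc` hunk above matters because the commented-out print references that name; the underlying idiom is computing accuracy as the mean of an element-wise label comparison. A minimal sketch with made-up labels:

```python
import numpy as np

# Hypothetical predicted and true class labels for five training calls.
y_train = np.array([0, 1, 1, 2, 0])
y_pred = np.array([0, 1, 2, 2, 0])

# Boolean comparison -> array of 0/1 -> mean gives the fraction correct.
tr_acc = (y_pred == y_train).mean()
print("Train acc", round(tr_acc * 100, 2))  # → Train acc 80.0
```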
```diff
@@ -5,10 +5,11 @@ from soundevent import data
 from torch.utils.data import DataLoader
 
 from batdetect2.evaluate.dataset import TestDataset, TestExample
-from batdetect2.evaluate.types import EvaluatorProtocol
 from batdetect2.logging import get_image_logger
 from batdetect2.models import Model
-from batdetect2.postprocess.types import ClipDetections
+from batdetect2.postprocess import to_raw_predictions
+from batdetect2.typing import EvaluatorProtocol
+from batdetect2.typing.postprocess import BatDetect2Prediction
 
 
 class EvaluationModule(LightningModule):
@@ -23,7 +24,7 @@ class EvaluationModule(LightningModule):
         self.evaluator = evaluator
 
         self.clip_annotations: List[data.ClipAnnotation] = []
-        self.predictions: List[ClipDetections] = []
+        self.predictions: List[BatDetect2Prediction] = []
 
     def test_step(self, batch: TestExample, batch_idx: int):
         dataset = self.get_dataset()
@@ -33,11 +34,22 @@
         ]
 
         outputs = self.model.detector(batch.spec)
-        clip_detections = self.model.postprocessor(outputs)
-        predictions = self.evaluator.to_clip_detections_batch(
-            clip_detections,
-            [clip_annotation.clip for clip_annotation in clip_annotations],
+        clip_detections = self.model.postprocessor(
+            outputs,
+            start_times=[ca.clip.start_time for ca in clip_annotations],
         )
+        predictions = [
+            BatDetect2Prediction(
+                clip=clip_annotation.clip,
+                predictions=to_raw_predictions(
+                    clip_dets.numpy(),
+                    targets=self.evaluator.targets,
+                ),
+            )
+            for clip_annotation, clip_dets in zip(
+                clip_annotations, clip_detections
+            )
+        ]
 
         self.clip_annotations.extend(clip_annotations)
         self.predictions.extend(predictions)
```
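In both versions of `test_step` above, batch-level model outputs are paired back to their source clips with `zip`, producing one per-clip prediction record each. The regrouping pattern can be sketched with stdlib dataclasses (the `ClipPrediction` record and the sample values are hypothetical, not the project's types):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ClipPrediction:
    """Hypothetical per-clip record pairing a clip id with its detections."""

    clip_id: str
    detections: List[float]


# Batch-level outputs: one detection-score list per clip, in dataset order.
clip_ids = ["rec1_0s", "rec1_3s"]
batch_detections = [[0.9, 0.4], [0.7]]

# zip pairs each clip with its slice of the batch output.
predictions = [
    ClipPrediction(clip_id=cid, detections=dets)
    for cid, dets in zip(clip_ids, batch_detections)
]
print(predictions[1])
```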
New file `src/batdetect2/evaluate/match.py` (617 lines):

```python
from collections.abc import Callable, Iterable, Mapping
from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union

import numpy as np
from pydantic import Field
from scipy.optimize import linear_sum_assignment
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.geometry import buffer_geometry, compute_bounds, scale_geometry

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.affinity import (
    AffinityConfig,
    BBoxIOUConfig,
    GeometricIOUConfig,
    build_affinity_function,
)
from batdetect2.targets import build_targets
from batdetect2.typing import (
    AffinityFunction,
    MatcherProtocol,
    MatchEvaluation,
    RawPrediction,
    TargetProtocol,
)
from batdetect2.typing.evaluate import ClipMatches

MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching."""

matching_strategies = Registry("matching_strategy")


def match(
    sound_event_annotations: Sequence[data.SoundEventAnnotation],
    raw_predictions: Sequence[RawPrediction],
    clip: data.Clip,
    scores: Optional[Sequence[float]] = None,
    targets: Optional[TargetProtocol] = None,
    matcher: Optional[MatcherProtocol] = None,
) -> ClipMatches:
    if matcher is None:
        matcher = build_matcher()

    if targets is None:
        targets = build_targets()

    target_geometries: List[data.Geometry] = [  # type: ignore
        sound_event_annotation.sound_event.geometry
        for sound_event_annotation in sound_event_annotations
    ]

    predicted_geometries = [
        raw_prediction.geometry for raw_prediction in raw_predictions
    ]

    if scores is None:
        scores = [
            raw_prediction.detection_score
            for raw_prediction in raw_predictions
        ]

    matches = []

    for source_idx, target_idx, affinity in matcher(
        ground_truth=target_geometries,
        predictions=predicted_geometries,
        scores=scores,
    ):
        target = (
            sound_event_annotations[target_idx]
            if target_idx is not None
            else None
        )
        prediction = (
            raw_predictions[source_idx] if source_idx is not None else None
        )

        gt_det = target_idx is not None
        gt_class = targets.encode_class(target) if target is not None else None
        gt_geometry = (
            target_geometries[target_idx] if target_idx is not None else None
        )

        pred_score = float(prediction.detection_score) if prediction else 0
        pred_geometry = (
            predicted_geometries[source_idx]
            if source_idx is not None
            else None
        )

        class_scores = (
            {
                class_name: score
                for class_name, score in zip(
                    targets.class_names,
                    prediction.class_scores,
                )
            }
            if prediction is not None
            else {}
        )

        matches.append(
            MatchEvaluation(
                clip=clip,
                sound_event_annotation=target,
                gt_det=gt_det,
                gt_class=gt_class,
                gt_geometry=gt_geometry,
                pred_score=pred_score,
                pred_class_scores=class_scores,
                pred_geometry=pred_geometry,
                affinity=affinity,
            )
        )

    return ClipMatches(clip=clip, matches=matches)
```
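`match` consumes triples `(source_idx, target_idx, affinity)` from a matcher: both indices present means a hit, a lone prediction index an unmatched prediction, and a lone ground-truth index a miss. Tallying such triples into detection outcomes can be sketched as follows (the toy triples are illustrative, not real matcher output):

```python
def count_outcomes(triples):
    """Tally matcher triples into detection outcomes.

    (pred, gt, affinity) -> true positive
    (pred, None, 0)      -> false positive (unmatched prediction)
    (None, gt, 0)        -> false negative (missed ground truth)
    """
    tp = fp = fn = 0
    for pred_idx, gt_idx, _affinity in triples:
        if pred_idx is not None and gt_idx is not None:
            tp += 1
        elif pred_idx is not None:
            fp += 1
        else:
            fn += 1
    return tp, fp, fn


triples = [(0, 1, 0.8), (1, None, 0), (None, 0, 0), (2, 2, 0.6)]
print(count_outcomes(triples))  # → (2, 1, 1)
```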
```python
# src/batdetect2/evaluate/match.py (continued)


class StartTimeMatchConfig(BaseConfig):
    name: Literal["start_time_match"] = "start_time_match"
    distance_threshold: float = 0.01


class StartTimeMatcher(MatcherProtocol):
    def __init__(self, distance_threshold: float):
        self.distance_threshold = distance_threshold

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        return match_start_times(
            ground_truth,
            predictions,
            scores,
            distance_threshold=self.distance_threshold,
        )

    @matching_strategies.register(StartTimeMatchConfig)
    @staticmethod
    def from_config(config: StartTimeMatchConfig):
        return StartTimeMatcher(distance_threshold=config.distance_threshold)


def match_start_times(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    scores: Sequence[float],
    distance_threshold: float = 0.01,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    if not ground_truth:
        for index in range(len(predictions)):
            yield index, None, 0

        return

    if not predictions:
        for index in range(len(ground_truth)):
            yield None, index, 0

        return

    gt_times = np.array([compute_bounds(geom)[0] for geom in ground_truth])
    pred_times = np.array([compute_bounds(geom)[0] for geom in predictions])

    scores = np.array(scores)
    sort_args = np.argsort(scores)[::-1]

    distances = np.abs(gt_times[None, :] - pred_times[:, None])
    closests = np.argmin(distances, axis=-1)

    unmatched_gt = set(range(len(gt_times)))

    for pred_index in sort_args:
        # Get the closest ground truth
        gt_closest_index = closests[pred_index]

        if gt_closest_index not in unmatched_gt:
            # Does not match if closest has been assigned
            yield pred_index, None, 0
            continue

        # Get the actual distance
        distance = distances[pred_index, gt_closest_index]

        if distance > distance_threshold:
            # Does not match if too far from closest
            yield pred_index, None, 0
            continue

        # Return affinity value: linear interpolation between 0 to 1, where a
        # distance at the threshold maps to 0 affinity and a zero distance maps
        # to 1.
        affinity = np.interp(
            distance,
            [0, distance_threshold],
            [1, 0],
            left=1,
            right=0,
        )
        unmatched_gt.remove(gt_closest_index)
        yield pred_index, gt_closest_index, affinity

    for missing_index in unmatched_gt:
        yield None, missing_index, 0
```
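The start-time matcher above pairs each prediction (best score first) with its nearest unclaimed ground-truth onset, rejects pairs farther apart than the threshold, and assigns an affinity that falls linearly from 1 at zero distance to 0 at the threshold. A stdlib re-implementation of that logic on bare onset times (no geometries, illustrative values):

```python
def match_start_times(gt_times, pred_times, scores, threshold=0.01):
    """Greedy nearest-onset matching, highest-scoring prediction first."""
    order = sorted(range(len(pred_times)), key=lambda i: -scores[i])
    unmatched_gt = set(range(len(gt_times)))
    matches = []
    for p in order:
        # Nearest ground-truth onset to this prediction's onset.
        g = min(range(len(gt_times)), key=lambda j: abs(gt_times[j] - pred_times[p]))
        dist = abs(gt_times[g] - pred_times[p])
        if g not in unmatched_gt or dist > threshold:
            matches.append((p, None, 0.0))
            continue
        unmatched_gt.remove(g)
        # Linear affinity: 1 at zero distance, 0 at the threshold.
        matches.append((p, g, 1.0 - dist / threshold))
    # Any ground truth left unclaimed is a miss.
    matches.extend((None, g, 0.0) for g in sorted(unmatched_gt))
    return matches


result = match_start_times(
    gt_times=[0.10, 0.50], pred_times=[0.105, 0.35], scores=[0.9, 0.8]
)
print(result)
```

Here the first prediction lands 5 ms from the first onset (within the 10 ms threshold, affinity ≈ 0.5), the second is 150 ms from its nearest onset and goes unmatched, and the second ground truth is reported as missed.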
```python
# src/batdetect2/evaluate/match.py (continued)


def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
    start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
    return data.BoundingBox(
        coordinates=[start_time, low_freq, end_time, high_freq]
    )


def _to_interval(geometry: data.Geometry) -> data.TimeInterval:
    start_time, _, end_time, _ = compute_bounds(geometry)
    return data.TimeInterval(coordinates=[start_time, end_time])


def _to_timestamp(geometry: data.Geometry) -> data.TimeStamp:
    start_time = compute_bounds(geometry)[0]
    return data.TimeStamp(coordinates=start_time)


_geometry_cast_functions: Mapping[
    MatchingGeometry, Callable[[data.Geometry], data.Geometry]
] = {
    "bbox": _to_bbox,
    "interval": _to_interval,
    "timestamp": _to_timestamp,
}


class GreedyMatchConfig(BaseConfig):
    name: Literal["greedy_match"] = "greedy_match"
    geometry: MatchingGeometry = "timestamp"
    affinity_threshold: float = 0.5
    affinity_function: AffinityConfig = Field(
        default_factory=GeometricIOUConfig
    )


class GreedyMatcher(MatcherProtocol):
    def __init__(
        self,
        geometry: MatchingGeometry,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
    ):
        self.geometry = geometry
        self.affinity_function = affinity_function
        self.affinity_threshold = affinity_threshold
        self.cast_geometry = _geometry_cast_functions[self.geometry]

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        return greedy_match(
            ground_truth=[self.cast_geometry(geom) for geom in ground_truth],
            predictions=[self.cast_geometry(geom) for geom in predictions],
            scores=scores,
            affinity_function=self.affinity_function,
            affinity_threshold=self.affinity_threshold,
        )

    @matching_strategies.register(GreedyMatchConfig)
    @staticmethod
    def from_config(config: GreedyMatchConfig):
        affinity_function = build_affinity_function(config.affinity_function)
        return GreedyMatcher(
            geometry=config.geometry,
            affinity_threshold=config.affinity_threshold,
            affinity_function=affinity_function,
        )


def greedy_match(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    scores: Sequence[float],
    affinity_threshold: float = 0.5,
    affinity_function: AffinityFunction = compute_affinity,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Performs a greedy, one-to-one matching of source to target geometries.

    Iterates through source geometries, prioritizing by score if provided. Each
    source is matched to the best available target, provided the affinity
    exceeds the threshold and the target has not already been assigned.

    Parameters
    ----------
    source
        A list of source geometries (e.g., predictions).
    target
        A list of target geometries (e.g., ground truths).
    scores
        Confidence scores for each source geometry for prioritization.
    affinity_threshold
        The minimum affinity score required for a valid match.

    Yields
    ------
    Tuple[Optional[int], Optional[int], float]
        A 3-element tuple describing a match or a miss. There are three
        possible formats:
        - Successful Match: `(source_idx, target_idx, affinity)`
        - Unmatched Source (False Positive): `(source_idx, None, 0)`
        - Unmatched Target (False Negative): `(None, target_idx, 0)`
    """
    unassigned_gt = set(range(len(ground_truth)))

    if not predictions:
        for gt_idx in range(len(ground_truth)):
            yield None, gt_idx, 0

        return

    if not ground_truth:
        for pred_idx in range(len(predictions)):
            yield pred_idx, None, 0

        return

    indices = np.argsort(scores)[::-1]

    for pred_idx in indices:
        source_geometry = predictions[pred_idx]

        affinities = np.array(
            [
                affinity_function(source_geometry, target_geometry)
                for target_geometry in ground_truth
            ]
        )

        closest_target = int(np.argmax(affinities))
        affinity = affinities[closest_target]

        if affinities[closest_target] <= affinity_threshold:
            yield pred_idx, None, 0
            continue

        if closest_target not in unassigned_gt:
            yield pred_idx, None, 0
            continue

        unassigned_gt.remove(closest_target)
        yield pred_idx, closest_target, affinity

    for gt_idx in unassigned_gt:
        yield None, gt_idx, 0
```
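`greedy_match` scores each prediction against every ground truth with a pluggable affinity function. A 1-D interval IoU makes a concrete, self-contained affinity for a sketch of the same greedy loop (this is a hand-rolled IoU on `(start, end)` tuples, not soundevent's `compute_affinity`, and it assumes both lists are non-empty):

```python
def interval_iou(a, b):
    """Intersection-over-union of two (start, end) time intervals."""
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union if union > 0 else 0.0


def greedy_match(ground_truth, predictions, scores, threshold=0.5):
    """Highest-scoring prediction claims its best-IoU ground truth first."""
    unassigned = set(range(len(ground_truth)))
    out = []
    for p in sorted(range(len(predictions)), key=lambda i: -scores[i]):
        ious = [interval_iou(predictions[p], g) for g in ground_truth]
        best = max(range(len(ground_truth)), key=lambda j: ious[j])
        # Reject if below threshold or the best target is already claimed.
        if ious[best] <= threshold or best not in unassigned:
            out.append((p, None, 0.0))
            continue
        unassigned.remove(best)
        out.append((p, best, ious[best]))
    out.extend((None, g, 0.0) for g in sorted(unassigned))
    return out


gt = [(0.0, 1.0), (2.0, 3.0)]
preds = [(0.1, 1.1), (5.0, 6.0)]
result = greedy_match(gt, preds, scores=[0.9, 0.8])
print(result)
```

The first prediction overlaps the first ground truth with IoU 0.9/1.1 ≈ 0.82 and is matched; the second overlaps nothing and becomes a false positive, leaving the second ground truth as a miss.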
```python
# src/batdetect2/evaluate/match.py (continued)


class GreedyAffinityMatchConfig(BaseConfig):
    name: Literal["greedy_affinity_match"] = "greedy_affinity_match"
    affinity_function: AffinityConfig = Field(default_factory=BBoxIOUConfig)
    affinity_threshold: float = 0.5
    time_buffer: float = 0
    frequency_buffer: float = 0
    time_scale: float = 1.0
    frequency_scale: float = 1.0


class GreedyAffinityMatcher(MatcherProtocol):
    def __init__(
        self,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
        time_buffer: float = 0,
        frequency_buffer: float = 0,
        time_scale: float = 1.0,
        frequency_scale: float = 1.0,
    ):
        self.affinity_threshold = affinity_threshold
        self.affinity_function = affinity_function
        self.time_buffer = time_buffer
        self.frequency_buffer = frequency_buffer
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        if self.time_buffer != 0 or self.frequency_buffer != 0:
            ground_truth = [
                buffer_geometry(
                    geometry,
                    time_buffer=self.time_buffer,
                    freq_buffer=self.frequency_buffer,
                )
                for geometry in ground_truth
            ]

            predictions = [
                buffer_geometry(
                    geometry,
                    time_buffer=self.time_buffer,
                    freq_buffer=self.frequency_buffer,
                )
                for geometry in predictions
            ]

        affinity_matrix = compute_affinity_matrix(
            ground_truth,
            predictions,
            self.affinity_function,
            time_scale=self.time_scale,
            frequency_scale=self.frequency_scale,
        )

        return select_greedy_matches(
            affinity_matrix,
            affinity_threshold=self.affinity_threshold,
        )

    @matching_strategies.register(GreedyAffinityMatchConfig)
    @staticmethod
    def from_config(config: GreedyAffinityMatchConfig):
        affinity_function = build_affinity_function(config.affinity_function)
        return GreedyAffinityMatcher(
            affinity_threshold=config.affinity_threshold,
            affinity_function=affinity_function,
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
        )


class OptimalMatchConfig(BaseConfig):
    name: Literal["optimal_affinity_match"] = "optimal_affinity_match"
    affinity_function: AffinityConfig = Field(default_factory=BBoxIOUConfig)
    affinity_threshold: float = 0.5
    time_buffer: float = 0
    frequency_buffer: float = 0
    time_scale: float = 1.0
    frequency_scale: float = 1.0


class OptimalMatcher(MatcherProtocol):
    def __init__(
        self,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
        time_buffer: float = 0,
        frequency_buffer: float = 0,
        time_scale: float = 1.0,
        frequency_scale: float = 1.0,
    ):
        self.affinity_threshold = affinity_threshold
        self.affinity_function = affinity_function
        self.time_buffer = time_buffer
        self.frequency_buffer = frequency_buffer
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        if self.time_buffer != 0 or self.frequency_buffer != 0:
            ground_truth = [
                buffer_geometry(
                    geometry,
                    time_buffer=self.time_buffer,
                    freq_buffer=self.frequency_buffer,
                )
                for geometry in ground_truth
            ]

            predictions = [
                buffer_geometry(
                    geometry,
                    time_buffer=self.time_buffer,
                    freq_buffer=self.frequency_buffer,
                )
                for geometry in predictions
            ]

        affinity_matrix = compute_affinity_matrix(
            ground_truth,
            predictions,
            self.affinity_function,
            time_scale=self.time_scale,
            frequency_scale=self.frequency_scale,
        )
        return select_optimal_matches(
            affinity_matrix,
            affinity_threshold=self.affinity_threshold,
        )

    @matching_strategies.register(OptimalMatchConfig)
    @staticmethod
    def from_config(config: OptimalMatchConfig):
        affinity_function = build_affinity_function(config.affinity_function)
        return OptimalMatcher(
            affinity_threshold=config.affinity_threshold,
            affinity_function=affinity_function,
            time_buffer=config.time_buffer,
            frequency_buffer=config.frequency_buffer,
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
        )


MatchConfig = Annotated[
    Union[
        GreedyMatchConfig,
        StartTimeMatchConfig,
        OptimalMatchConfig,
        GreedyAffinityMatchConfig,
    ],
    Field(discriminator="name"),
]
```
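`MatchConfig` is a pydantic discriminated union keyed on each config's `name` literal, so a plain dict is routed to exactly one config class and on to its registered builder. The same dispatch can be sketched with a stdlib registry (all names here are hypothetical stand-ins, and pydantic is deliberately not used):

```python
from dataclasses import dataclass

# Registry mapping each config's "name" discriminator to a builder,
# mirroring how a discriminated union routes a dict to one config class.
MATCHERS = {}


def register(name):
    def deco(builder):
        MATCHERS[name] = builder
        return builder
    return deco


@dataclass
class StartTimeMatcher:
    distance_threshold: float


@register("start_time_match")
def build_start_time(config: dict) -> StartTimeMatcher:
    return StartTimeMatcher(config.get("distance_threshold", 0.01))


def build_matcher(config=None):
    # Default strategy when no config is given, as in the module above.
    config = config or {"name": "start_time_match"}
    return MATCHERS[config["name"]](config)


matcher = build_matcher({"name": "start_time_match", "distance_threshold": 0.02})
print(matcher)
```

The discriminator field doubles as the registry key, which is why every config class above pins `name` to a single `Literal` value.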
```python
# src/batdetect2/evaluate/match.py (continued)


def compute_affinity_matrix(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    affinity_function: AffinityFunction,
    time_scale: float = 1,
    frequency_scale: float = 1,
) -> np.ndarray:
    # Scale geometries if necessary
    if time_scale != 1 or frequency_scale != 1:
        ground_truth = [
            scale_geometry(geometry, time_scale, frequency_scale)
            for geometry in ground_truth
        ]

        predictions = [
            scale_geometry(geometry, time_scale, frequency_scale)
            for geometry in predictions
        ]

    affinity_matrix = np.zeros((len(ground_truth), len(predictions)))
    for gt_idx, gt_geometry in enumerate(ground_truth):
        for pred_idx, pred_geometry in enumerate(predictions):
            affinity = affinity_function(
                gt_geometry,
                pred_geometry,
            )
            affinity_matrix[gt_idx, pred_idx] = affinity

    return affinity_matrix


def select_optimal_matches(
    affinity_matrix: np.ndarray,
    affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    num_gt, num_pred = affinity_matrix.shape
    gts = set(range(num_gt))
    preds = set(range(num_pred))

    assiged_rows, assigned_columns = linear_sum_assignment(
        affinity_matrix,
        maximize=True,
    )

    for gt_idx, pred_idx in zip(assiged_rows, assigned_columns):
        affinity = float(affinity_matrix[gt_idx, pred_idx])

        if affinity <= affinity_threshold:
            continue

        yield gt_idx, pred_idx, affinity
        gts.remove(gt_idx)
        preds.remove(pred_idx)

    for gt_idx in gts:
        yield gt_idx, None, 0

    for pred_idx in preds:
        yield None, pred_idx, 0


def select_greedy_matches(
    affinity_matrix: np.ndarray,
    affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    num_gt, num_pred = affinity_matrix.shape
    unmatched_pred = set(range(num_pred))

    for gt_idx in range(num_gt):
        row = affinity_matrix[gt_idx]

        top_pred = int(np.argmax(row))
        top_affinity = float(row[top_pred])

        if (
            top_affinity <= affinity_threshold
            or top_pred not in unmatched_pred
        ):
            yield None, gt_idx, 0
            continue

        unmatched_pred.remove(top_pred)
        yield top_pred, gt_idx, top_affinity

    for pred_idx in unmatched_pred:
        yield pred_idx, None, 0


def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
    config = config or StartTimeMatchConfig()
    return matching_strategies.build(config)
```
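`select_optimal_matches` uses `scipy.optimize.linear_sum_assignment` to maximize total affinity jointly, which can beat the per-row picks of `select_greedy_matches`. A brute-force stdlib equivalent for tiny square matrices (what the Hungarian-style solver computes efficiently; the matrix values are made up) makes the difference visible:

```python
from itertools import permutations


def optimal_assignment(matrix):
    """Exhaustive maximum-sum assignment for a small square matrix
    (the quantity linear_sum_assignment maximizes efficiently)."""
    n = len(matrix)
    best_perm, best_sum = None, float("-inf")
    for perm in permutations(range(n)):
        total = sum(matrix[i][perm[i]] for i in range(n))
        if total > best_sum:
            best_perm, best_sum = perm, total
    return list(best_perm), best_sum


# Greedy row-by-row would pick (0, 0) = 0.9 first, leaving row 1 only
# 0.1, for a total of 1.0; the joint optimum pairs 0.8 + 0.7 = 1.5.
affinity = [
    [0.9, 0.8],
    [0.7, 0.1],
]
assignment, total = optimal_assignment(affinity)
print(assignment, total)
```

Both selectors still apply the same post-filter: assignments at or below `affinity_threshold` are discarded, and leftover rows and columns are reported as misses and false positives.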
@ -7,8 +7,10 @@ from typing import (
|
|||||||
List,
|
List,
|
||||||
Literal,
|
Literal,
|
||||||
Mapping,
|
Mapping,
|
||||||
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -16,23 +18,16 @@ from pydantic import Field
|
|||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core import (
|
from batdetect2.core import BaseConfig, Registry
|
||||||
BaseConfig,
|
|
||||||
ImportConfig,
|
|
||||||
Registry,
|
|
||||||
add_import_config,
|
|
||||||
)
|
|
||||||
from batdetect2.evaluate.metrics.common import (
|
from batdetect2.evaluate.metrics.common import (
|
||||||
average_precision,
|
average_precision,
|
||||||
compute_precision_recall,
|
compute_precision_recall,
|
||||||
)
|
)
|
||||||
from batdetect2.postprocess.types import Detection
|
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ClassificationMetric",
|
"ClassificationMetric",
|
||||||
"ClassificationMetricConfig",
|
"ClassificationMetricConfig",
|
||||||
"ClassificationMetricImportConfig",
|
|
||||||
"build_classification_metric",
|
"build_classification_metric",
|
||||||
"compute_precision_recall_curves",
|
"compute_precision_recall_curves",
|
||||||
]
|
]
|
||||||
@@ -41,13 +36,13 @@ __all__ = [
 @dataclass
 class MatchEval:
     clip: data.Clip
-    gt: data.SoundEventAnnotation | None
-    pred: Detection | None
+    gt: Optional[data.SoundEventAnnotation]
+    pred: Optional[RawPrediction]
 
     is_prediction: bool
     is_ground_truth: bool
     is_generic: bool
-    true_class: str | None
+    true_class: Optional[str]
     score: float
 
 
@@ -65,28 +60,17 @@ classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
 )
 
 
-@add_import_config(classification_metrics)
-class ClassificationMetricImportConfig(ImportConfig):
-    """Use any callable as a classification metric.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class BaseClassificationConfig(BaseConfig):
-    include: List[str] | None = None
-    exclude: List[str] | None = None
+    include: Optional[List[str]] = None
+    exclude: Optional[List[str]] = None
 
 
 class BaseClassificationMetric:
     def __init__(
         self,
         targets: TargetProtocol,
-        include: List[str] | None = None,
-        exclude: List[str] | None = None,
+        include: Optional[List[str]] = None,
+        exclude: Optional[List[str]] = None,
     ):
         self.targets = targets
         self.include = include
@@ -116,8 +100,8 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
         ignore_non_predictions: bool = True,
         ignore_generic: bool = True,
         label: str = "average_precision",
-        include: List[str] | None = None,
-        exclude: List[str] | None = None,
+        include: Optional[List[str]] = None,
+        exclude: Optional[List[str]] = None,
     ):
         super().__init__(include=include, exclude=exclude, targets=targets)
         self.ignore_non_predictions = ignore_non_predictions
@@ -185,8 +169,8 @@ class ClassificationROCAUC(BaseClassificationMetric):
         ignore_non_predictions: bool = True,
         ignore_generic: bool = True,
         label: str = "roc_auc",
-        include: List[str] | None = None,
-        exclude: List[str] | None = None,
+        include: Optional[List[str]] = None,
+        exclude: Optional[List[str]] = None,
     ):
         self.targets = targets
         self.ignore_non_predictions = ignore_non_predictions
@@ -241,7 +225,10 @@ class ClassificationROCAUC(BaseClassificationMetric):
 
 
 ClassificationMetricConfig = Annotated[
-    ClassificationAveragePrecisionConfig | ClassificationROCAUCConfig,
+    Union[
+        ClassificationAveragePrecisionConfig,
+        ClassificationROCAUCConfig,
+    ],
     Field(discriminator="name"),
 ]
 
@@ -1,17 +1,13 @@
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Annotated, Callable, Dict, Literal, Sequence, Set
+from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union
 
 import numpy as np
 from pydantic import Field
 from sklearn import metrics
 
 from batdetect2.core.configs import BaseConfig
-from batdetect2.core.registries import (
-    ImportConfig,
-    Registry,
-    add_import_config,
-)
+from batdetect2.core.registries import Registry
 from batdetect2.evaluate.metrics.common import average_precision
 
 
@@ -28,17 +24,6 @@ clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
 )
 
 
-@add_import_config(clip_classification_metrics)
-class ClipClassificationMetricImportConfig(ImportConfig):
-    """Use any callable as a clip classification metric.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class ClipClassificationAveragePrecisionConfig(BaseConfig):
     name: Literal["average_precision"] = "average_precision"
     label: str = "average_precision"
@@ -138,7 +123,10 @@ class ClipClassificationROCAUC:
 
 
 ClipClassificationMetricConfig = Annotated[
-    ClipClassificationAveragePrecisionConfig | ClipClassificationROCAUCConfig,
+    Union[
+        ClipClassificationAveragePrecisionConfig,
+        ClipClassificationROCAUCConfig,
+    ],
     Field(discriminator="name"),
 ]
 
@@ -1,16 +1,12 @@
 from dataclasses import dataclass
-from typing import Annotated, Callable, Dict, Literal, Sequence
+from typing import Annotated, Callable, Dict, Literal, Sequence, Union
 
 import numpy as np
 from pydantic import Field
 from sklearn import metrics
 
 from batdetect2.core.configs import BaseConfig
-from batdetect2.core.registries import (
-    ImportConfig,
-    Registry,
-    add_import_config,
-)
+from batdetect2.core.registries import Registry
 from batdetect2.evaluate.metrics.common import average_precision
 
 
@@ -27,17 +23,6 @@ clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
 )
 
 
-@add_import_config(clip_detection_metrics)
-class ClipDetectionMetricImportConfig(ImportConfig):
-    """Use any callable as a clip detection metric.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class ClipDetectionAveragePrecisionConfig(BaseConfig):
     name: Literal["average_precision"] = "average_precision"
     label: str = "average_precision"
@@ -174,10 +159,12 @@ class ClipDetectionPrecision:
 
 
 ClipDetectionMetricConfig = Annotated[
-    ClipDetectionAveragePrecisionConfig
-    | ClipDetectionROCAUCConfig
-    | ClipDetectionRecallConfig
-    | ClipDetectionPrecisionConfig,
+    Union[
+        ClipDetectionAveragePrecisionConfig,
+        ClipDetectionROCAUCConfig,
+        ClipDetectionRecallConfig,
+        ClipDetectionPrecisionConfig,
+    ],
     Field(discriminator="name"),
 ]
 
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
 import numpy as np
 
@@ -11,7 +11,7 @@ __all__ = [
 def compute_precision_recall(
     y_true,
     y_score,
-    num_positives: int | None = None,
+    num_positives: Optional[int] = None,
 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
     y_true = np.array(y_true)
     y_score = np.array(y_score)
@@ -41,7 +41,7 @@ def compute_precision_recall(
 def average_precision(
     y_true,
     y_score,
-    num_positives: int | None = None,
+    num_positives: Optional[int] = None,
 ) -> float:
     if num_positives == 0:
         return np.nan
@@ -5,7 +5,9 @@ from typing import (
     Dict,
     List,
     Literal,
+    Optional,
     Sequence,
+    Union,
 )
 
 import numpy as np
@@ -13,27 +15,21 @@ from pydantic import Field
 from sklearn import metrics
 from soundevent import data
 
-from batdetect2.core import (
-    BaseConfig,
-    ImportConfig,
-    Registry,
-    add_import_config,
-)
+from batdetect2.core import BaseConfig, Registry
 from batdetect2.evaluate.metrics.common import average_precision
-from batdetect2.postprocess.types import Detection
+from batdetect2.typing import RawPrediction
 
 __all__ = [
     "DetectionMetricConfig",
     "DetectionMetric",
-    "DetectionMetricImportConfig",
     "build_detection_metric",
 ]
 
 
 @dataclass
 class MatchEval:
-    gt: data.SoundEventAnnotation | None
-    pred: Detection | None
+    gt: Optional[data.SoundEventAnnotation]
+    pred: Optional[RawPrediction]
 
     is_prediction: bool
     is_ground_truth: bool
@@ -52,17 +48,6 @@ DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
 detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")
 
 
-@add_import_config(detection_metrics)
-class DetectionMetricImportConfig(ImportConfig):
-    """Use any callable as a detection metric.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class DetectionAveragePrecisionConfig(BaseConfig):
     name: Literal["average_precision"] = "average_precision"
     label: str = "average_precision"
@@ -94,7 +79,7 @@ class DetectionAveragePrecision:
             y_score.append(m.score)
 
         ap = average_precision(y_true, y_score, num_positives=num_positives)
-        return {self.label: float(ap)}
+        return {self.label: ap}
 
     @detection_metrics.register(DetectionAveragePrecisionConfig)
     @staticmethod
@@ -227,10 +212,12 @@ class DetectionPrecision:
 
 
 DetectionMetricConfig = Annotated[
-    DetectionAveragePrecisionConfig
-    | DetectionROCAUCConfig
-    | DetectionRecallConfig
-    | DetectionPrecisionConfig,
+    Union[
+        DetectionAveragePrecisionConfig,
+        DetectionROCAUCConfig,
+        DetectionRecallConfig,
+        DetectionPrecisionConfig,
+    ],
     Field(discriminator="name"),
 ]
 
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import (
     Annotated,
@@ -5,7 +6,9 @@ from typing import (
     Dict,
     List,
     Literal,
+    Optional,
     Sequence,
+    Union,
 )
 
 import numpy as np
@@ -13,20 +16,14 @@ from pydantic import Field
 from sklearn import metrics, preprocessing
 from soundevent import data
 
-from batdetect2.core import (
-    BaseConfig,
-    ImportConfig,
-    Registry,
-    add_import_config,
-)
+from batdetect2.core import BaseConfig, Registry
 from batdetect2.evaluate.metrics.common import average_precision
-from batdetect2.postprocess.types import Detection
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import RawPrediction
+from batdetect2.typing.targets import TargetProtocol
 
 __all__ = [
     "TopClassMetricConfig",
     "TopClassMetric",
-    "TopClassMetricImportConfig",
     "build_top_class_metric",
 ]
 
@@ -34,14 +31,14 @@ __all__ = [
 @dataclass
 class MatchEval:
     clip: data.Clip
-    gt: data.SoundEventAnnotation | None
-    pred: Detection | None
+    gt: Optional[data.SoundEventAnnotation]
+    pred: Optional[RawPrediction]
 
     is_ground_truth: bool
     is_generic: bool
     is_prediction: bool
-    pred_class: str | None
-    true_class: str | None
+    pred_class: Optional[str]
+    true_class: Optional[str]
     score: float
 
 
@@ -57,17 +54,6 @@ TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
 top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")
 
 
-@add_import_config(top_class_metrics)
-class TopClassMetricImportConfig(ImportConfig):
-    """Use any callable as a top-class metric.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class TopClassAveragePrecisionConfig(BaseConfig):
     name: Literal["average_precision"] = "average_precision"
     label: str = "average_precision"
@@ -315,11 +301,13 @@ class BalancedAccuracy:
 
 
 TopClassMetricConfig = Annotated[
-    TopClassAveragePrecisionConfig
-    | TopClassROCAUCConfig
-    | TopClassRecallConfig
-    | TopClassPrecisionConfig
-    | BalancedAccuracyConfig,
+    Union[
+        TopClassAveragePrecisionConfig,
+        TopClassROCAUCConfig,
+        TopClassRecallConfig,
+        TopClassPrecisionConfig,
+        BalancedAccuracyConfig,
+    ],
     Field(discriminator="name"),
 ]
 
@@ -1,14 +1,16 @@
+from typing import Optional
+
 import matplotlib.pyplot as plt
 from matplotlib.figure import Figure
 
 from batdetect2.core import BaseConfig
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import TargetProtocol
 
 
 class BasePlotConfig(BaseConfig):
     label: str = "plot"
     theme: str = "default"
-    title: str | None = None
+    title: Optional[str] = None
     figsize: tuple[int, int] = (10, 10)
     dpi: int = 100
 
@@ -19,7 +21,7 @@ class BasePlot:
         targets: TargetProtocol,
         label: str = "plot",
         figsize: tuple[int, int] = (10, 10),
-        title: str | None = None,
+        title: Optional[str] = None,
         dpi: int = 100,
         theme: str = "default",
     ):
@@ -3,8 +3,10 @@ from typing import (
     Callable,
     Iterable,
     Literal,
+    Optional,
     Sequence,
     Tuple,
+    Union,
 )
 
 import matplotlib.pyplot as plt
@@ -12,7 +14,7 @@ from matplotlib.figure import Figure
 from pydantic import Field
 from sklearn import metrics
 
-from batdetect2.core import ImportConfig, Registry, add_import_config
+from batdetect2.core import Registry
 from batdetect2.evaluate.metrics.classification import (
     ClipEval,
     _extract_per_class_metric_data,
@@ -29,7 +31,7 @@ from batdetect2.plotting.metrics import (
     plot_threshold_recall_curve,
     plot_threshold_recall_curves,
 )
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import TargetProtocol
 
 ClassificationPlotter = Callable[
     [Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
@@ -40,21 +42,10 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
 )
 
 
-@add_import_config(classification_plots)
-class ClassificationPlotImportConfig(ImportConfig):
-    """Use any callable as a classification plot.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class PRCurveConfig(BasePlotConfig):
     name: Literal["pr_curve"] = "pr_curve"
     label: str = "pr_curve"
-    title: str | None = "Classification Precision-Recall Curve"
+    title: Optional[str] = "Classification Precision-Recall Curve"
     ignore_non_predictions: bool = True
     ignore_generic: bool = True
     separate_figures: bool = False
@@ -97,7 +88,9 @@ class PRCurve(BasePlot):
 
             ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
             ax.set_title(class_name)
+
             yield f"{self.label}/{class_name}", fig
+
             plt.close(fig)
 
     @classification_plots.register(PRCurveConfig)
@@ -115,7 +108,7 @@ class PRCurve(BasePlot):
 class ThresholdPrecisionCurveConfig(BasePlotConfig):
     name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
     label: str = "threshold_precision_curve"
-    title: str | None = "Classification Threshold-Precision Curve"
+    title: Optional[str] = "Classification Threshold-Precision Curve"
     ignore_non_predictions: bool = True
     ignore_generic: bool = True
     separate_figures: bool = False
@@ -188,7 +181,7 @@ class ThresholdPrecisionCurve(BasePlot):
 class ThresholdRecallCurveConfig(BasePlotConfig):
     name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
     label: str = "threshold_recall_curve"
-    title: str | None = "Classification Threshold-Recall Curve"
+    title: Optional[str] = "Classification Threshold-Recall Curve"
     ignore_non_predictions: bool = True
     ignore_generic: bool = True
     separate_figures: bool = False
@@ -261,7 +254,7 @@ class ThresholdRecallCurve(BasePlot):
 class ROCCurveConfig(BasePlotConfig):
     name: Literal["roc_curve"] = "roc_curve"
     label: str = "roc_curve"
-    title: str | None = "Classification ROC Curve"
+    title: Optional[str] = "Classification ROC Curve"
     ignore_non_predictions: bool = True
     ignore_generic: bool = True
     separate_figures: bool = False
@@ -333,10 +326,12 @@ class ROCCurve(BasePlot):
 
 
 ClassificationPlotConfig = Annotated[
-    PRCurveConfig
-    | ROCCurveConfig
-    | ThresholdPrecisionCurveConfig
-    | ThresholdRecallCurveConfig,
+    Union[
+        PRCurveConfig,
+        ROCCurveConfig,
+        ThresholdPrecisionCurveConfig,
+        ThresholdRecallCurveConfig,
+    ],
     Field(discriminator="name"),
 ]
 
@@ -3,8 +3,10 @@ from typing import (
     Callable,
     Iterable,
     Literal,
+    Optional,
     Sequence,
     Tuple,
+    Union,
 )
 
 import matplotlib.pyplot as plt
@@ -12,7 +14,7 @@ from matplotlib.figure import Figure
 from pydantic import Field
 from sklearn import metrics
 
-from batdetect2.core import ImportConfig, Registry, add_import_config
+from batdetect2.core import Registry
 from batdetect2.evaluate.metrics.clip_classification import ClipEval
 from batdetect2.evaluate.metrics.common import compute_precision_recall
 from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
@@ -22,11 +24,10 @@ from batdetect2.plotting.metrics import (
     plot_roc_curve,
     plot_roc_curves,
 )
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import TargetProtocol
 
 __all__ = [
     "ClipClassificationPlotConfig",
-    "ClipClassificationPlotImportConfig",
     "ClipClassificationPlotter",
     "build_clip_classification_plotter",
 ]
@@ -40,21 +41,10 @@ clip_classification_plots: Registry[
 ] = Registry("clip_classification_plot")
 
 
-@add_import_config(clip_classification_plots)
-class ClipClassificationPlotImportConfig(ImportConfig):
-    """Use any callable as a clip classification plot.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class PRCurveConfig(BasePlotConfig):
     name: Literal["pr_curve"] = "pr_curve"
     label: str = "pr_curve"
-    title: str | None = "Clip Classification Precision-Recall Curve"
+    title: Optional[str] = "Clip Classification Precision-Recall Curve"
     separate_figures: bool = False
 
 
@@ -121,7 +111,7 @@ class PRCurve(BasePlot):
 class ROCCurveConfig(BasePlotConfig):
     name: Literal["roc_curve"] = "roc_curve"
     label: str = "roc_curve"
-    title: str | None = "Clip Classification ROC Curve"
+    title: Optional[str] = "Clip Classification ROC Curve"
     separate_figures: bool = False
 
 
@@ -184,7 +174,10 @@ class ROCCurve(BasePlot):
 
 
 ClipClassificationPlotConfig = Annotated[
-    PRCurveConfig | ROCCurveConfig,
+    Union[
+        PRCurveConfig,
+        ROCCurveConfig,
+    ],
     Field(discriminator="name"),
 ]
 
@@ -3,8 +3,10 @@ from typing import (
     Callable,
     Iterable,
     Literal,
+    Optional,
     Sequence,
     Tuple,
+    Union,
 )
 
 import pandas as pd
@@ -13,16 +15,15 @@ from matplotlib.figure import Figure
 from pydantic import Field
 from sklearn import metrics
 
-from batdetect2.core import ImportConfig, Registry, add_import_config
+from batdetect2.core import Registry
 from batdetect2.evaluate.metrics.clip_detection import ClipEval
 from batdetect2.evaluate.metrics.common import compute_precision_recall
 from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
 from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import TargetProtocol
 
 __all__ = [
     "ClipDetectionPlotConfig",
-    "ClipDetectionPlotImportConfig",
     "ClipDetectionPlotter",
     "build_clip_detection_plotter",
 ]
@@ -37,21 +38,10 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
 )
 
 
-@add_import_config(clip_detection_plots)
-class ClipDetectionPlotImportConfig(ImportConfig):
-    """Use any callable as a clip detection plot.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class PRCurveConfig(BasePlotConfig):
     name: Literal["pr_curve"] = "pr_curve"
     label: str = "pr_curve"
-    title: str | None = "Clip Detection Precision-Recall Curve"
+    title: Optional[str] = "Clip Detection Precision-Recall Curve"
 
 
 class PRCurve(BasePlot):
@@ -84,7 +74,7 @@ class PRCurve(BasePlot):
 class ROCCurveConfig(BasePlotConfig):
     name: Literal["roc_curve"] = "roc_curve"
     label: str = "roc_curve"
-    title: str | None = "Clip Detection ROC Curve"
+    title: Optional[str] = "Clip Detection ROC Curve"
 
 
 class ROCCurve(BasePlot):
@@ -117,7 +107,7 @@ class ROCCurve(BasePlot):
 class ScoreDistributionPlotConfig(BasePlotConfig):
     name: Literal["score_distribution"] = "score_distribution"
     label: str = "score_distribution"
-    title: str | None = "Clip Detection Score Distribution"
+    title: Optional[str] = "Clip Detection Score Distribution"
 
 
 class ScoreDistributionPlot(BasePlot):
@@ -157,7 +147,11 @@ class ScoreDistributionPlot(BasePlot):
 
 
 ClipDetectionPlotConfig = Annotated[
-    PRCurveConfig | ROCCurveConfig | ScoreDistributionPlotConfig,
+    Union[
+        PRCurveConfig,
+        ROCCurveConfig,
+        ScoreDistributionPlotConfig,
+    ],
     Field(discriminator="name"),
 ]
 
@@ -4,8 +4,10 @@ from typing import (
     Callable,
     Iterable,
     Literal,
+    Optional,
     Sequence,
     Tuple,
+    Union,
 )
 
 import matplotlib.pyplot as plt
@ -16,16 +18,14 @@ from pydantic import Field
|
|||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
from batdetect2.audio.types import AudioLoader
|
from batdetect2.core import Registry
|
||||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
|
||||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||||
from batdetect2.evaluate.metrics.detection import ClipEval
|
from batdetect2.evaluate.metrics.detection import ClipEval
|
||||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||||
from batdetect2.plotting.detections import plot_clip_detections
|
from batdetect2.plotting.detections import plot_clip_detections
|
||||||
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
|
|
||||||
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
|
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
|
||||||
|
|
||||||
@ -34,21 +34,10 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_import_config(detection_plots)
|
|
||||||
class DetectionPlotImportConfig(ImportConfig):
|
|
||||||
"""Use any callable as a detection plot.
|
|
||||||
|
|
||||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
|
||||||
callable to use it instead of a built-in option.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: Literal["import"] = "import"
|
|
||||||
|
|
||||||
|
|
||||||
class PRCurveConfig(BasePlotConfig):
|
class PRCurveConfig(BasePlotConfig):
|
||||||
name: Literal["pr_curve"] = "pr_curve"
|
name: Literal["pr_curve"] = "pr_curve"
|
||||||
label: str = "pr_curve"
|
label: str = "pr_curve"
|
||||||
title: str | None = "Detection Precision-Recall Curve"
|
title: Optional[str] = "Detection Precision-Recall Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
|
||||||
@ -111,7 +100,7 @@ class PRCurve(BasePlot):
|
|||||||
class ROCCurveConfig(BasePlotConfig):
|
class ROCCurveConfig(BasePlotConfig):
|
||||||
name: Literal["roc_curve"] = "roc_curve"
|
name: Literal["roc_curve"] = "roc_curve"
|
||||||
label: str = "roc_curve"
|
label: str = "roc_curve"
|
||||||
title: str | None = "Detection ROC Curve"
|
title: Optional[str] = "Detection ROC Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
|
||||||
@ -170,7 +159,7 @@ class ROCCurve(BasePlot):
|
|||||||
class ScoreDistributionPlotConfig(BasePlotConfig):
|
class ScoreDistributionPlotConfig(BasePlotConfig):
|
||||||
name: Literal["score_distribution"] = "score_distribution"
|
name: Literal["score_distribution"] = "score_distribution"
|
||||||
label: str = "score_distribution"
|
label: str = "score_distribution"
|
||||||
title: str | None = "Detection Score Distribution"
|
title: Optional[str] = "Detection Score Distribution"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
|
||||||
@ -237,7 +226,7 @@ class ScoreDistributionPlot(BasePlot):
|
|||||||
class ExampleDetectionPlotConfig(BasePlotConfig):
|
class ExampleDetectionPlotConfig(BasePlotConfig):
|
||||||
name: Literal["example_detection"] = "example_detection"
|
name: Literal["example_detection"] = "example_detection"
|
||||||
label: str = "example_detection"
|
label: str = "example_detection"
|
||||||
title: str | None = "Example Detection"
|
title: Optional[str] = "Example Detection"
|
||||||
figsize: tuple[int, int] = (10, 4)
|
figsize: tuple[int, int] = (10, 4)
|
||||||
num_examples: int = 5
|
num_examples: int = 5
|
||||||
threshold: float = 0.2
|
threshold: float = 0.2
|
||||||
@ -303,10 +292,12 @@ class ExampleDetectionPlot(BasePlot):
|
|||||||
|
|
||||||
|
|
||||||
DetectionPlotConfig = Annotated[
|
DetectionPlotConfig = Annotated[
|
||||||
PRCurveConfig
|
Union[
|
||||||
| ROCCurveConfig
|
PRCurveConfig,
|
||||||
| ScoreDistributionPlotConfig
|
ROCCurveConfig,
|
||||||
| ExampleDetectionPlotConfig,
|
ScoreDistributionPlotConfig,
|
||||||
|
ExampleDetectionPlotConfig,
|
||||||
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -4,9 +4,14 @@ from dataclasses import dataclass, field
 from typing import (
     Annotated,
     Callable,
+    Dict,
     Iterable,
+    List,
     Literal,
+    Optional,
     Sequence,
+    Tuple,
+    Union,
 )

 import matplotlib.pyplot as plt
@@ -16,8 +21,7 @@ from pydantic import Field
 from sklearn import metrics

 from batdetect2.audio import AudioConfig, build_audio_loader
-from batdetect2.audio.types import AudioLoader
-from batdetect2.core import ImportConfig, Registry, add_import_config
+from batdetect2.core import Registry
 from batdetect2.evaluate.metrics.common import compute_precision_recall
 from batdetect2.evaluate.metrics.top_class import (
     ClipEval,
@@ -28,31 +32,19 @@ from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
 from batdetect2.plotting.gallery import plot_match_gallery
 from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
 from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
-from batdetect2.preprocess.types import PreprocessorProtocol
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol

-TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[tuple[str, Figure]]]
+TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]

 top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
     name="top_class_plot"
 )


-@add_import_config(top_class_plots)
-class TopClassPlotImportConfig(ImportConfig):
-    """Use any callable as a top-class plot.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class PRCurveConfig(BasePlotConfig):
     name: Literal["pr_curve"] = "pr_curve"
     label: str = "pr_curve"
-    title: str | None = "Top Class Precision-Recall Curve"
+    title: Optional[str] = "Top Class Precision-Recall Curve"
     ignore_non_predictions: bool = True
     ignore_generic: bool = True

@@ -72,7 +64,7 @@ class PRCurve(BasePlot):
     def __call__(
         self,
         clip_evaluations: Sequence[ClipEval],
-    ) -> Iterable[tuple[str, Figure]]:
+    ) -> Iterable[Tuple[str, Figure]]:
         y_true = []
         y_score = []
         num_positives = 0
@@ -119,7 +111,7 @@ class PRCurve(BasePlot):
 class ROCCurveConfig(BasePlotConfig):
     name: Literal["roc_curve"] = "roc_curve"
     label: str = "roc_curve"
-    title: str | None = "Top Class ROC Curve"
+    title: Optional[str] = "Top Class ROC Curve"
     ignore_non_predictions: bool = True
     ignore_generic: bool = True

@@ -139,7 +131,7 @@ class ROCCurve(BasePlot):
     def __call__(
         self,
         clip_evaluations: Sequence[ClipEval],
-    ) -> Iterable[tuple[str, Figure]]:
+    ) -> Iterable[Tuple[str, Figure]]:
         y_true = []
         y_score = []

@@ -181,7 +173,7 @@ class ROCCurve(BasePlot):

 class ConfusionMatrixConfig(BasePlotConfig):
     name: Literal["confusion_matrix"] = "confusion_matrix"
-    title: str | None = "Top Class Confusion Matrix"
+    title: Optional[str] = "Top Class Confusion Matrix"
     figsize: tuple[int, int] = (10, 10)
     label: str = "confusion_matrix"
     exclude_generic: bool = True
@@ -222,7 +214,7 @@ class ConfusionMatrix(BasePlot):
     def __call__(
         self,
         clip_evaluations: Sequence[ClipEval],
-    ) -> Iterable[tuple[str, Figure]]:
+    ) -> Iterable[Tuple[str, Figure]]:
         cm, labels = compute_confusion_matrix(
             clip_evaluations,
             self.targets,
@@ -265,7 +257,7 @@ class ConfusionMatrix(BasePlot):
 class ExampleClassificationPlotConfig(BasePlotConfig):
     name: Literal["example_classification"] = "example_classification"
     label: str = "example_classification"
-    title: str | None = "Example Classification"
+    title: Optional[str] = "Example Classification"
     num_examples: int = 4
     threshold: float = 0.2
     audio: AudioConfig = Field(default_factory=AudioConfig)
@@ -294,26 +286,26 @@ class ExampleClassificationPlot(BasePlot):
     def __call__(
         self,
         clip_evaluations: Sequence[ClipEval],
-    ) -> Iterable[tuple[str, Figure]]:
+    ) -> Iterable[Tuple[str, Figure]]:
         grouped = group_matches(clip_evaluations, threshold=self.threshold)

         for class_name, matches in grouped.items():
-            true_positives: list[MatchEval] = get_binned_sample(
+            true_positives: List[MatchEval] = get_binned_sample(
                 matches.true_positives,
                 n_examples=self.num_examples,
             )

-            false_positives: list[MatchEval] = get_binned_sample(
+            false_positives: List[MatchEval] = get_binned_sample(
                 matches.false_positives,
                 n_examples=self.num_examples,
             )

-            false_negatives: list[MatchEval] = random.sample(
+            false_negatives: List[MatchEval] = random.sample(
                 matches.false_negatives,
                 k=min(self.num_examples, len(matches.false_negatives)),
             )

-            cross_triggers: list[MatchEval] = get_binned_sample(
+            cross_triggers: List[MatchEval] = get_binned_sample(
                 matches.cross_triggers, n_examples=self.num_examples
             )

@@ -356,10 +348,12 @@ class ExampleClassificationPlot(BasePlot):


 TopClassPlotConfig = Annotated[
-    PRCurveConfig
-    | ROCCurveConfig
-    | ConfusionMatrixConfig
-    | ExampleClassificationPlotConfig,
+    Union[
+        PRCurveConfig,
+        ROCCurveConfig,
+        ConfusionMatrixConfig,
+        ExampleClassificationPlotConfig,
+    ],
     Field(discriminator="name"),
 ]

@@ -373,16 +367,16 @@ def build_top_class_plotter(

 @dataclass
 class ClassMatches:
-    false_positives: list[MatchEval] = field(default_factory=list)
-    false_negatives: list[MatchEval] = field(default_factory=list)
-    true_positives: list[MatchEval] = field(default_factory=list)
-    cross_triggers: list[MatchEval] = field(default_factory=list)
+    false_positives: List[MatchEval] = field(default_factory=list)
+    false_negatives: List[MatchEval] = field(default_factory=list)
+    true_positives: List[MatchEval] = field(default_factory=list)
+    cross_triggers: List[MatchEval] = field(default_factory=list)


 def group_matches(
     clip_evals: Sequence[ClipEval],
     threshold: float = 0.2,
-) -> dict[str, ClassMatches]:
+) -> Dict[str, ClassMatches]:
     class_examples = defaultdict(ClassMatches)

     for clip_eval in clip_evals:
@@ -411,13 +405,12 @@ def group_matches(
     return class_examples


-def get_binned_sample(matches: list[MatchEval], n_examples: int = 5):
+def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
     if len(matches) < n_examples:
         return matches

     indices, pred_scores = zip(
-        *[(index, match.score) for index, match in enumerate(matches)],
-        strict=False,
+        *[(index, match.score) for index, match in enumerate(matches)]
     )

     bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
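`get_binned_sample` above stratifies example picks across score quantiles (via `pd.qcut`) rather than sampling uniformly, so galleries show low-, mid-, and high-confidence matches. A stdlib-only sketch of the same idea; `binned_sample` is a hypothetical helper operating on bare scores, not the project's function:

```python
import random
import statistics


def binned_sample(scores, n_examples=5, seed=0):
    """Pick roughly one index per score-quantile bin (qcut-style sampling)."""
    if len(scores) < n_examples:
        # Too few items to stratify: keep everything.
        return list(range(len(scores)))
    # n_examples - 1 cut points split the score distribution into bins.
    cuts = statistics.quantiles(scores, n=n_examples)
    bins = {}
    for index, score in enumerate(scores):
        bin_id = sum(score > cut for cut in cuts)  # which bin the score lands in
        bins.setdefault(bin_id, []).append(index)
    rng = random.Random(seed)
    return sorted(rng.choice(members) for members in bins.values())


picked = binned_sample([0.1, 0.2, 0.35, 0.5, 0.6, 0.75, 0.9, 0.95], n_examples=4)
print(picked)  # at most 4 indices, one drawn from each score quartile
```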
@@ -1,27 +0,0 @@
-import json
-from pathlib import Path
-from typing import Iterable
-
-from matplotlib.figure import Figure
-from soundevent import data
-
-__all__ = ["save_evaluation_results"]
-
-
-def save_evaluation_results(
-    metrics: dict[str, float],
-    plots: Iterable[tuple[str, Figure]],
-    output_dir: data.PathLike,
-) -> None:
-    """Save evaluation metrics and plots to disk."""
-
-    output_path = Path(output_dir)
-    output_path.mkdir(parents=True, exist_ok=True)
-
-    metrics_path = output_path / "metrics.json"
-    metrics_path.write_text(json.dumps(metrics))
-
-    for figure_name, figure in plots:
-        figure_path = output_path / figure_name
-        figure_path.parent.mkdir(parents=True, exist_ok=True)
-        figure.savefig(figure_path)
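The deleted `save_evaluation_results` helper above is small enough to reconstruct from the stdlib; this sketch keeps only its metrics half so it needs no matplotlib, and `save_metrics` is an illustrative name, not part of the package:

```python
import json
import tempfile
from pathlib import Path


def save_metrics(metrics: dict, output_dir) -> Path:
    """Write a metrics dict to <output_dir>/metrics.json, creating dirs as needed."""
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    metrics_path = output_path / "metrics.json"
    metrics_path.write_text(json.dumps(metrics))
    return metrics_path


with tempfile.TemporaryDirectory() as tmp:
    path = save_metrics({"detection/ap": 0.91}, Path(tmp) / "eval")
    print(json.loads(path.read_text()))  # → {'detection/ap': 0.91}
```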
src/batdetect2/evaluate/tables.py (new file, 106 lines)
@@ -0,0 +1,106 @@
+from typing import Annotated, Callable, Literal, Sequence, Union
+
+import pandas as pd
+from pydantic import Field
+from soundevent.geometry import compute_bounds
+
+from batdetect2.core import BaseConfig, Registry
+from batdetect2.typing import ClipMatches
+
+EvaluationTableGenerator = Callable[[Sequence[ClipMatches]], pd.DataFrame]
+
+
+tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
+    "evaluation_table"
+)
+
+
+class FullEvaluationTableConfig(BaseConfig):
+    name: Literal["full_evaluation"] = "full_evaluation"
+
+
+class FullEvaluationTable:
+    def __call__(
+        self, clip_evaluations: Sequence[ClipMatches]
+    ) -> pd.DataFrame:
+        return extract_matches_dataframe(clip_evaluations)
+
+    @tables_registry.register(FullEvaluationTableConfig)
+    @staticmethod
+    def from_config(config: FullEvaluationTableConfig):
+        return FullEvaluationTable()
+
+
+def extract_matches_dataframe(
+    clip_evaluations: Sequence[ClipMatches],
+) -> pd.DataFrame:
+    data = []
+
+    for clip_evaluation in clip_evaluations:
+        for match in clip_evaluation.matches:
+            gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
+            pred_start_time = pred_low_freq = pred_end_time = (
+                pred_high_freq
+            ) = None
+
+            sound_event_annotation = match.sound_event_annotation
+
+            if sound_event_annotation is not None:
+                geometry = sound_event_annotation.sound_event.geometry
+                assert geometry is not None
+                gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
+                    compute_bounds(geometry)
+                )
+
+            if match.pred_geometry is not None:
+                (
+                    pred_start_time,
+                    pred_low_freq,
+                    pred_end_time,
+                    pred_high_freq,
+                ) = compute_bounds(match.pred_geometry)
+
+            data.append(
+                {
+                    ("recording", "uuid"): match.clip.recording.uuid,
+                    ("clip", "uuid"): match.clip.uuid,
+                    ("clip", "start_time"): match.clip.start_time,
+                    ("clip", "end_time"): match.clip.end_time,
+                    ("gt", "uuid"): match.sound_event_annotation.uuid
+                    if match.sound_event_annotation is not None
+                    else None,
+                    ("gt", "class"): match.gt_class,
+                    ("gt", "det"): match.gt_det,
+                    ("gt", "start_time"): gt_start_time,
+                    ("gt", "end_time"): gt_end_time,
+                    ("gt", "low_freq"): gt_low_freq,
+                    ("gt", "high_freq"): gt_high_freq,
+                    ("pred", "score"): match.pred_score,
+                    ("pred", "class"): match.top_class,
+                    ("pred", "class_score"): match.top_class_score,
+                    ("pred", "start_time"): pred_start_time,
+                    ("pred", "end_time"): pred_end_time,
+                    ("pred", "low_freq"): pred_low_freq,
+                    ("pred", "high_freq"): pred_high_freq,
+                    ("match", "affinity"): match.affinity,
+                    **{
+                        ("pred_class_score", key): value
+                        for key, value in match.pred_class_scores.items()
+                    },
+                }
+            )
+
+    df = pd.DataFrame(data)
+    df.columns = pd.MultiIndex.from_tuples(df.columns)  # type: ignore
+    return df
+
+
+EvaluationTableConfig = Annotated[
+    Union[FullEvaluationTableConfig,], Field(discriminator="name")
+]
+
+
+def build_table_generator(
+    config: EvaluationTableConfig,
+) -> EvaluationTableGenerator:
+    return tables_registry.build(config)
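The new `tables.py` follows the codebase's registry pattern: a factory is registered under a config class whose `name` field acts as the discriminator, and `build(config)` dispatches on it. A plain-Python sketch of that pattern; `Config`, `FullTableConfig`, and `make_full_table` are hypothetical stand-ins, not the batdetect2 `Registry` API:

```python
from dataclasses import dataclass
from typing import Callable, Dict, Type


@dataclass
class Config:
    name: str


class Registry:
    """Map each config type's ``name`` discriminator to a factory callable."""

    def __init__(self) -> None:
        self._factories: Dict[str, Callable] = {}

    def register(self, config_cls: Type[Config]):
        def decorator(factory: Callable):
            # Key the factory by the config's default discriminator value.
            self._factories[config_cls().name] = factory
            return factory

        return decorator

    def build(self, config: Config):
        # Dispatch on the discriminator carried by the config instance.
        return self._factories[config.name](config)


registry = Registry()


@dataclass
class FullTableConfig(Config):
    name: str = "full_evaluation"


@registry.register(FullTableConfig)
def make_full_table(config: FullTableConfig) -> str:
    return f"table generator for {config.name}"


print(registry.build(FullTableConfig()))  # → table generator for full_evaluation
```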
@@ -1,4 +1,4 @@
-from typing import Annotated, Sequence
+from typing import Annotated, Optional, Sequence, Union

 from pydantic import Field
 from soundevent import data
@@ -11,10 +11,12 @@ from batdetect2.evaluate.tasks.clip_classification import (
 from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
 from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
 from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
-from batdetect2.evaluate.types import EvaluationTaskProtocol
-from batdetect2.postprocess.types import ClipDetections
 from batdetect2.targets import build_targets
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import (
+    BatDetect2Prediction,
+    EvaluatorProtocol,
+    TargetProtocol,
+)

 __all__ = [
     "TaskConfig",
@@ -24,29 +26,31 @@ __all__ = [


 TaskConfig = Annotated[
-    ClassificationTaskConfig
-    | DetectionTaskConfig
-    | ClipDetectionTaskConfig
-    | ClipClassificationTaskConfig
-    | TopClassDetectionTaskConfig,
+    Union[
+        ClassificationTaskConfig,
+        DetectionTaskConfig,
+        ClipDetectionTaskConfig,
+        ClipClassificationTaskConfig,
+        TopClassDetectionTaskConfig,
+    ],
     Field(discriminator="name"),
 ]


 def build_task(
     config: TaskConfig,
-    targets: TargetProtocol | None = None,
-) -> EvaluationTaskProtocol:
+    targets: Optional[TargetProtocol] = None,
+) -> EvaluatorProtocol:
     targets = targets or build_targets()
     return tasks_registry.build(config, targets)


 def evaluate_task(
     clip_annotations: Sequence[data.ClipAnnotation],
-    predictions: Sequence[ClipDetections],
-    task: str | None = None,
-    targets: TargetProtocol | None = None,
-    config: TaskConfig | dict | None = None,
+    predictions: Sequence[BatDetect2Prediction],
+    task: Optional["str"] = None,
+    targets: Optional[TargetProtocol] = None,
+    config: Optional[Union[TaskConfig, dict]] = None,
 ):
     if isinstance(config, BaseTaskConfig):
         task_obj = build_task(config, targets)
@@ -4,93 +4,78 @@ from typing import (
     Generic,
     Iterable,
     List,
-    Literal,
+    Optional,
     Sequence,
     Tuple,
     TypeVar,
 )

-from loguru import logger
 from matplotlib.figure import Figure
 from pydantic import Field
 from soundevent import data
 from soundevent.geometry import compute_bounds

-from batdetect2.core import (
-    BaseConfig,
-    ImportConfig,
-    Registry,
-    add_import_config,
-)
-from batdetect2.evaluate.affinity import (
-    AffinityConfig,
-    TimeAffinityConfig,
-    build_affinity_function,
-)
-from batdetect2.evaluate.types import (
-    AffinityFunction,
-    EvaluationTaskProtocol,
-)
-from batdetect2.postprocess.types import ClipDetections, Detection
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.core import BaseConfig
+from batdetect2.core.registries import Registry
+from batdetect2.evaluate.match import (
+    MatchConfig,
+    StartTimeMatchConfig,
+    build_matcher,
+)
+from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
+from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
+from batdetect2.typing.targets import TargetProtocol

 __all__ = [
     "BaseTaskConfig",
     "BaseTask",
-    "TaskImportConfig",
 ]

-tasks_registry: Registry[EvaluationTaskProtocol, [TargetProtocol]] = Registry(
+tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
     "tasks"
 )


-@add_import_config(tasks_registry)
-class TaskImportConfig(ImportConfig):
-    """Use any callable as an evaluation task.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 T_Output = TypeVar("T_Output")


 class BaseTaskConfig(BaseConfig):
     prefix: str

     ignore_start_end: float = 0.01
+    matching_strategy: MatchConfig = Field(
+        default_factory=StartTimeMatchConfig
+    )


-class BaseTask(EvaluationTaskProtocol, Generic[T_Output]):
+class BaseTask(EvaluatorProtocol, Generic[T_Output]):
     targets: TargetProtocol

+    matcher: MatcherProtocol
+
     metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]

     plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]

-    prefix: str
-
     ignore_start_end: float

+    prefix: str
+
     def __init__(
         self,
+        matcher: MatcherProtocol,
         targets: TargetProtocol,
         metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
         prefix: str,
-        plots: List[
-            Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
-        ]
-        | None = None,
         ignore_start_end: float = 0.01,
+        plots: Optional[
+            List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
+        ] = None,
     ):
-        self.prefix = prefix
-        self.targets = targets
+        self.matcher = matcher
         self.metrics = metrics
         self.plots = plots or []
+        self.targets = targets
+        self.prefix = prefix
         self.ignore_start_end = ignore_start_end

     def compute_metrics(
@@ -108,30 +93,24 @@ class BaseTask(EvaluationTaskProtocol, Generic[T_Output]):
         self, eval_outputs: List[T_Output]
     ) -> Iterable[Tuple[str, Figure]]:
         for plot in self.plots:
-            try:
             for name, fig in plot(eval_outputs):
                 yield f"{self.prefix}/{name}", fig
-            except Exception as e:
-                logger.error(f"Error plotting {self.prefix}: {e}")
-                continue

     def evaluate(
         self,
         clip_annotations: Sequence[data.ClipAnnotation],
-        predictions: Sequence[ClipDetections],
+        predictions: Sequence[BatDetect2Prediction],
     ) -> List[T_Output]:
         return [
             self.evaluate_clip(clip_annotation, preds)
-            for clip_annotation, preds in zip(
-                clip_annotations, predictions, strict=False
-            )
+            for clip_annotation, preds in zip(clip_annotations, predictions)
         ]

     def evaluate_clip(
         self,
         clip_annotation: data.ClipAnnotation,
-        prediction: ClipDetections,
-    ) -> T_Output: ...  # ty: ignore[empty-body]
+        prediction: BatDetect2Prediction,
+    ) -> T_Output: ...

     def include_sound_event_annotation(
         self,
@@ -142,6 +121,9 @@ class BaseTask(EvaluationTaskProtocol, Generic[T_Output]):
             return False

         geometry = sound_event_annotation.sound_event.geometry
+        if geometry is None:
+            return False
+
         return is_in_bounds(
             geometry,
             clip,
@@ -150,7 +132,7 @@ class BaseTask(EvaluationTaskProtocol, Generic[T_Output]):

     def include_prediction(
         self,
-        prediction: Detection,
+        prediction: RawPrediction,
         clip: data.Clip,
     ) -> bool:
         return is_in_bounds(
@@ -159,56 +141,25 @@ class BaseTask(EvaluationTaskProtocol, Generic[T_Output]):
             self.ignore_start_end,
         )


-class BaseSEDTaskConfig(BaseTaskConfig):
-    affinity: AffinityConfig = Field(default_factory=TimeAffinityConfig)
-    affinity_threshold: float = 0
-    strict_match: bool = True
-
-
-class BaseSEDTask(BaseTask[T_Output]):
-    affinity: AffinityFunction
-
-    def __init__(
-        self,
-        prefix: str,
-        targets: TargetProtocol,
-        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
-        affinity: AffinityFunction,
-        plots: List[
-            Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]
-        ]
-        | None = None,
-        affinity_threshold: float = 0,
-        ignore_start_end: float = 0.01,
-        strict_match: bool = True,
-    ):
-        super().__init__(
-            prefix=prefix,
-            metrics=metrics,
-            plots=plots,
-            targets=targets,
-            ignore_start_end=ignore_start_end,
-        )
-        self.affinity = affinity
-        self.affinity_threshold = affinity_threshold
-        self.strict_match = strict_match
-
     @classmethod
     def build(
         cls,
-        config: BaseSEDTaskConfig,
+        config: BaseTaskConfig,
         targets: TargetProtocol,
+        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
+        plots: Optional[
+            List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
+        ] = None,
         **kwargs,
     ):
-        affinity = build_affinity_function(config.affinity)
+        matcher = build_matcher(config.matching_strategy)
         return cls(
-            affinity=affinity,
-            affinity_threshold=config.affinity_threshold,
+            matcher=matcher,
+            targets=targets,
+            metrics=metrics,
+            plots=plots,
             prefix=config.prefix,
             ignore_start_end=config.ignore_start_end,
-            strict_match=config.strict_match,
-            targets=targets,
             **kwargs,
        )
@@ -1,9 +1,10 @@
-from functools import partial
-from typing import Literal
+from typing import (
+    List,
+    Literal,
+)
 
 from pydantic import Field
 from soundevent import data
-from soundevent.evaluation import match_detections_and_gts
 
 from batdetect2.evaluate.metrics.classification import (
     ClassificationAveragePrecisionConfig,
@@ -17,25 +18,24 @@ from batdetect2.evaluate.plots.classification import (
     build_classification_plotter,
 )
 from batdetect2.evaluate.tasks.base import (
-    BaseSEDTask,
-    BaseSEDTaskConfig,
+    BaseTask,
+    BaseTaskConfig,
     tasks_registry,
 )
-from batdetect2.postprocess.types import ClipDetections, Detection
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import BatDetect2Prediction, TargetProtocol
 
 
-class ClassificationTaskConfig(BaseSEDTaskConfig):
+class ClassificationTaskConfig(BaseTaskConfig):
     name: Literal["sound_event_classification"] = "sound_event_classification"
     prefix: str = "classification"
-    metrics: list[ClassificationMetricConfig] = Field(
+    metrics: List[ClassificationMetricConfig] = Field(
         default_factory=lambda: [ClassificationAveragePrecisionConfig()]
     )
-    plots: list[ClassificationPlotConfig] = Field(default_factory=list)
+    plots: List[ClassificationPlotConfig] = Field(default_factory=list)
     include_generics: bool = True
 
 
-class ClassificationTask(BaseSEDTask[ClipEval]):
+class ClassificationTask(BaseTask[ClipEval]):
     def __init__(
         self,
         *args,
@@ -48,13 +48,13 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
     def evaluate_clip(
         self,
         clip_annotation: data.ClipAnnotation,
-        prediction: ClipDetections,
+        prediction: BatDetect2Prediction,
     ) -> ClipEval:
         clip = clip_annotation.clip
 
         preds = [
             pred
-            for pred in prediction.detections
+            for pred in prediction.predictions
             if self.include_prediction(pred, clip)
         ]
 
@@ -73,40 +73,40 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
             gts = [
                 sound_event
                 for sound_event in all_gts
-                if is_target_class(
-                    sound_event,
-                    class_name,
-                    self.targets,
-                    include_generics=self.include_generics,
-                )
+                if self.is_class(sound_event, class_name)
             ]
+            scores = [float(pred.class_scores[class_idx]) for pred in preds]
 
             matches = []
 
-            for match in match_detections_and_gts(
-                detections=preds,
-                ground_truths=gts,
-                affinity=self.affinity,
-                score=partial(get_class_score, class_idx=class_idx),
-                strict_match=self.strict_match,
-                affinity_threshold=self.affinity_threshold,
+            for pred_idx, gt_idx, _ in self.matcher(
+                ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
+                predictions=[pred.geometry for pred in preds],
+                scores=scores,
             ):
+                gt = gts[gt_idx] if gt_idx is not None else None
+                pred = preds[pred_idx] if pred_idx is not None else None
+
                 true_class = (
-                    self.targets.encode_class(match.annotation)
-                    if match.annotation is not None
-                    else None
+                    self.targets.encode_class(gt) if gt is not None else None
                 )
 
+                score = (
+                    float(pred.class_scores[class_idx])
+                    if pred is not None
+                    else 0
+                )
+
                 matches.append(
                     MatchEval(
                         clip=clip,
-                        gt=match.annotation,
-                        pred=match.prediction,
-                        is_prediction=match.prediction is not None,
-                        is_ground_truth=match.annotation is not None,
-                        is_generic=match.annotation is not None
-                        and true_class is None,
+                        gt=gt,
+                        pred=pred,
+                        is_prediction=pred is not None,
+                        is_ground_truth=gt is not None,
+                        is_generic=gt is not None and true_class is None,
                         true_class=true_class,
-                        score=match.prediction_score,
+                        score=score,
                     )
                 )
 
@@ -114,6 +114,20 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
 
         return ClipEval(clip=clip, matches=per_class_matches)
 
+    def is_class(
+        self,
+        sound_event: data.SoundEventAnnotation,
+        class_name: str,
+    ) -> bool:
+        sound_event_class = self.targets.encode_class(sound_event)
+
+        if sound_event_class is None and self.include_generics:
+            # Sound events that are generic could be of the given
+            # class
+            return True
+
+        return sound_event_class == class_name
+
     @tasks_registry.register(ClassificationTaskConfig)
     @staticmethod
     def from_config(
@@ -133,25 +147,4 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
             plots=plots,
             targets=targets,
             metrics=metrics,
-            include_generics=config.include_generics,
         )
-
-
-def get_class_score(pred: Detection, class_idx: int) -> float:
-    return pred.class_scores[class_idx]
-
-
-def is_target_class(
-    sound_event: data.SoundEventAnnotation,
-    class_name: str,
-    targets: TargetProtocol,
-    include_generics: bool = True,
-) -> bool:
-    sound_event_class = targets.encode_class(sound_event)
-
-    if sound_event_class is None and include_generics:
-        # Sound events that are generic could be of the given
-        # class
-        return True
-
-    return sound_event_class == class_name
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Literal
+from typing import List, Literal
 
 from pydantic import Field
 from soundevent import data
@@ -19,26 +19,26 @@ from batdetect2.evaluate.tasks.base import (
     BaseTaskConfig,
     tasks_registry,
 )
-from batdetect2.postprocess.types import ClipDetections
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import TargetProtocol
+from batdetect2.typing.postprocess import BatDetect2Prediction
 
 
 class ClipClassificationTaskConfig(BaseTaskConfig):
     name: Literal["clip_classification"] = "clip_classification"
     prefix: str = "clip_classification"
-    metrics: list[ClipClassificationMetricConfig] = Field(
+    metrics: List[ClipClassificationMetricConfig] = Field(
         default_factory=lambda: [
             ClipClassificationAveragePrecisionConfig(),
         ]
     )
-    plots: list[ClipClassificationPlotConfig] = Field(default_factory=list)
+    plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)
 
 
 class ClipClassificationTask(BaseTask[ClipEval]):
     def evaluate_clip(
         self,
         clip_annotation: data.ClipAnnotation,
-        prediction: ClipDetections,
+        prediction: BatDetect2Prediction,
     ) -> ClipEval:
         clip = clip_annotation.clip
 
@@ -55,7 +55,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
                 gt_classes.add(class_name)
 
         pred_scores = defaultdict(float)
-        for pred in prediction.detections:
+        for pred in prediction.predictions:
             if not self.include_prediction(pred, clip):
                 continue
 
@@ -78,8 +78,8 @@ class ClipClassificationTask(BaseTask[ClipEval]):
             build_clip_classification_plotter(plot, targets)
             for plot in config.plots
         ]
-        return ClipClassificationTask(
-            prefix=config.prefix,
+        return ClipClassificationTask.build(
+            config=config,
             plots=plots,
             metrics=metrics,
             targets=targets,
@@ -1,4 +1,4 @@
-from typing import Literal
+from typing import List, Literal
 
 from pydantic import Field
 from soundevent import data
@@ -18,26 +18,26 @@ from batdetect2.evaluate.tasks.base import (
     BaseTaskConfig,
     tasks_registry,
 )
-from batdetect2.postprocess.types import ClipDetections
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import TargetProtocol
+from batdetect2.typing.postprocess import BatDetect2Prediction
 
 
 class ClipDetectionTaskConfig(BaseTaskConfig):
     name: Literal["clip_detection"] = "clip_detection"
     prefix: str = "clip_detection"
-    metrics: list[ClipDetectionMetricConfig] = Field(
+    metrics: List[ClipDetectionMetricConfig] = Field(
         default_factory=lambda: [
             ClipDetectionAveragePrecisionConfig(),
         ]
     )
-    plots: list[ClipDetectionPlotConfig] = Field(default_factory=list)
+    plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)
 
 
 class ClipDetectionTask(BaseTask[ClipEval]):
     def evaluate_clip(
         self,
         clip_annotation: data.ClipAnnotation,
-        prediction: ClipDetections,
+        prediction: BatDetect2Prediction,
     ) -> ClipEval:
         clip = clip_annotation.clip
 
@@ -47,7 +47,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
         )
 
         pred_score = 0
-        for pred in prediction.detections:
+        for pred in prediction.predictions:
             if not self.include_prediction(pred, clip):
                 continue
 
@@ -69,8 +69,8 @@ class ClipDetectionTask(BaseTask[ClipEval]):
             build_clip_detection_plotter(plot, targets)
             for plot in config.plots
         ]
-        return ClipDetectionTask(
-            prefix=config.prefix,
+        return ClipDetectionTask.build(
+            config=config,
             metrics=metrics,
             targets=targets,
             plots=plots,
@@ -1,8 +1,7 @@
-from typing import Literal
+from typing import List, Literal
 
 from pydantic import Field
 from soundevent import data
-from soundevent.evaluation import match_detections_and_gts
 
 from batdetect2.evaluate.metrics.detection import (
     ClipEval,
@@ -16,28 +15,28 @@ from batdetect2.evaluate.plots.detection import (
     build_detection_plotter,
 )
 from batdetect2.evaluate.tasks.base import (
-    BaseSEDTask,
-    BaseSEDTaskConfig,
+    BaseTask,
+    BaseTaskConfig,
     tasks_registry,
 )
-from batdetect2.postprocess.types import ClipDetections
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import TargetProtocol
+from batdetect2.typing.postprocess import BatDetect2Prediction
 
 
-class DetectionTaskConfig(BaseSEDTaskConfig):
+class DetectionTaskConfig(BaseTaskConfig):
     name: Literal["sound_event_detection"] = "sound_event_detection"
     prefix: str = "detection"
-    metrics: list[DetectionMetricConfig] = Field(
+    metrics: List[DetectionMetricConfig] = Field(
         default_factory=lambda: [DetectionAveragePrecisionConfig()]
     )
-    plots: list[DetectionPlotConfig] = Field(default_factory=list)
+    plots: List[DetectionPlotConfig] = Field(default_factory=list)
 
 
-class DetectionTask(BaseSEDTask[ClipEval]):
+class DetectionTask(BaseTask[ClipEval]):
     def evaluate_clip(
         self,
         clip_annotation: data.ClipAnnotation,
-        prediction: ClipDetections,
+        prediction: BatDetect2Prediction,
     ) -> ClipEval:
         clip = clip_annotation.clip
 
@@ -48,26 +47,27 @@ class DetectionTask(BaseSEDTask[ClipEval]):
         ]
         preds = [
             pred
-            for pred in prediction.detections
+            for pred in prediction.predictions
             if self.include_prediction(pred, clip)
         ]
+        scores = [pred.detection_score for pred in preds]
 
         matches = []
-        for match in match_detections_and_gts(
-            detections=preds,
-            ground_truths=gts,
-            affinity=self.affinity,
-            score=lambda pred: pred.detection_score,
-            strict_match=self.strict_match,
-            affinity_threshold=self.affinity_threshold,
+        for pred_idx, gt_idx, _ in self.matcher(
+            ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
+            predictions=[pred.geometry for pred in preds],
+            scores=scores,
         ):
+            gt = gts[gt_idx] if gt_idx is not None else None
+            pred = preds[pred_idx] if pred_idx is not None else None
+
             matches.append(
                 MatchEval(
-                    gt=match.annotation,
-                    pred=match.prediction,
-                    is_prediction=match.prediction is not None,
-                    is_ground_truth=match.annotation is not None,
-                    score=match.prediction_score,
+                    gt=gt,
+                    pred=pred,
+                    is_prediction=pred is not None,
+                    is_ground_truth=gt is not None,
+                    score=pred.detection_score if pred is not None else 0,
                 )
             )
 
@@ -1,8 +1,7 @@
-from typing import Literal
+from typing import List, Literal
 
 from pydantic import Field
 from soundevent import data
-from soundevent.evaluation import match_detections_and_gts
 
 from batdetect2.evaluate.metrics.top_class import (
     ClipEval,
@@ -16,28 +15,28 @@ from batdetect2.evaluate.plots.top_class import (
     build_top_class_plotter,
 )
 from batdetect2.evaluate.tasks.base import (
-    BaseSEDTask,
-    BaseSEDTaskConfig,
+    BaseTask,
+    BaseTaskConfig,
     tasks_registry,
 )
-from batdetect2.postprocess.types import ClipDetections
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing import TargetProtocol
+from batdetect2.typing.postprocess import BatDetect2Prediction
 
 
-class TopClassDetectionTaskConfig(BaseSEDTaskConfig):
+class TopClassDetectionTaskConfig(BaseTaskConfig):
     name: Literal["top_class_detection"] = "top_class_detection"
     prefix: str = "top_class"
-    metrics: list[TopClassMetricConfig] = Field(
+    metrics: List[TopClassMetricConfig] = Field(
         default_factory=lambda: [TopClassAveragePrecisionConfig()]
     )
-    plots: list[TopClassPlotConfig] = Field(default_factory=list)
+    plots: List[TopClassPlotConfig] = Field(default_factory=list)
 
 
-class TopClassDetectionTask(BaseSEDTask[ClipEval]):
+class TopClassDetectionTask(BaseTask[ClipEval]):
     def evaluate_clip(
         self,
         clip_annotation: data.ClipAnnotation,
-        prediction: ClipDetections,
+        prediction: BatDetect2Prediction,
     ) -> ClipEval:
         clip = clip_annotation.clip
 
@@ -48,21 +47,21 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
         ]
         preds = [
             pred
-            for pred in prediction.detections
+            for pred in prediction.predictions
            if self.include_prediction(pred, clip)
         ]
+        # Take the highest score for each prediction
+        scores = [pred.class_scores.max() for pred in preds]
 
         matches = []
-        for match in match_detections_and_gts(
-            ground_truths=gts,
-            detections=preds,
-            affinity=self.affinity,
-            score=lambda pred: pred.class_scores.max(),
-            strict_match=self.strict_match,
-            affinity_threshold=self.affinity_threshold,
+        for pred_idx, gt_idx, _ in self.matcher(
+            ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
+            predictions=[pred.geometry for pred in preds],
+            scores=scores,
        ):
-            gt = match.annotation
-            pred = match.prediction
+            gt = gts[gt_idx] if gt_idx is not None else None
+            pred = preds[pred_idx] if pred_idx is not None else None
 
             true_class = (
                 self.targets.encode_class(gt) if gt is not None else None
             )
@@ -70,6 +69,11 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
             class_idx = (
                 pred.class_scores.argmax() if pred is not None else None
             )
 
+            score = (
+                float(pred.class_scores[class_idx]) if pred is not None else 0
+            )
+
             pred_class = (
                 self.targets.class_names[class_idx]
                 if class_idx is not None
@@ -86,7 +90,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
                     true_class=true_class,
                     is_generic=gt is not None and true_class is None,
                     pred_class=pred_class,
-                    score=match.prediction_score,
+                    score=score,
                 )
             )
 
@@ -1,147 +0,0 @@
-from dataclasses import dataclass
-from typing import Generic, Iterable, Protocol, Sequence, TypeVar
-
-from matplotlib.figure import Figure
-from soundevent import data
-
-from batdetect2.outputs.types import OutputTransformProtocol
-from batdetect2.postprocess.types import (
-    ClipDetections,
-    ClipDetectionsTensor,
-    Detection,
-)
-from batdetect2.targets.types import TargetProtocol
-
-__all__ = [
-    "AffinityFunction",
-    "ClipMatches",
-    "EvaluationTaskProtocol",
-    "EvaluatorProtocol",
-    "MatchEvaluation",
-    "MatcherProtocol",
-    "MetricsProtocol",
-    "PlotterProtocol",
-]
-
-
-@dataclass
-class MatchEvaluation:
-    clip: data.Clip
-    sound_event_annotation: data.SoundEventAnnotation | None
-    gt_det: bool
-    gt_class: str | None
-    gt_geometry: data.Geometry | None
-    pred_score: float
-    pred_class_scores: dict[str, float]
-    pred_geometry: data.Geometry | None
-    affinity: float
-
-    @property
-    def top_class(self) -> str | None:
-        if not self.pred_class_scores:
-            return None
-        return max(self.pred_class_scores, key=self.pred_class_scores.get)  # type: ignore
-
-    @property
-    def is_prediction(self) -> bool:
-        return self.pred_geometry is not None
-
-    @property
-    def is_generic(self) -> bool:
-        return self.gt_det and self.gt_class is None
-
-    @property
-    def top_class_score(self) -> float:
-        pred_class = self.top_class
-        if pred_class is None:
-            return 0
-        return self.pred_class_scores[pred_class]
-
-
-@dataclass
-class ClipMatches:
-    clip: data.Clip
-    matches: list[MatchEvaluation]
-
-
-class MatcherProtocol(Protocol):
-    def __call__(
-        self,
-        ground_truth: Sequence[data.Geometry],
-        predictions: Sequence[data.Geometry],
-        scores: Sequence[float],
-    ) -> Iterable[tuple[int | None, int | None, float]]: ...
-
-
-class AffinityFunction(Protocol):
-    def __call__(
-        self,
-        detection: Detection,
-        ground_truth: data.SoundEventAnnotation,
-    ) -> float: ...
-
-
-class MetricsProtocol(Protocol):
-    def __call__(
-        self,
-        clip_annotations: Sequence[data.ClipAnnotation],
-        predictions: Sequence[Sequence[Detection]],
-    ) -> dict[str, float]: ...
-
-
-class PlotterProtocol(Protocol):
-    def __call__(
-        self,
-        clip_annotations: Sequence[data.ClipAnnotation],
-        predictions: Sequence[Sequence[Detection]],
-    ) -> Iterable[tuple[str, Figure]]: ...
-
-
-EvaluationOutput = TypeVar("EvaluationOutput")
-
-
-class EvaluationTaskProtocol(Protocol, Generic[EvaluationOutput]):
-    targets: TargetProtocol
-
-    def evaluate(
-        self,
-        clip_annotations: Sequence[data.ClipAnnotation],
-        predictions: Sequence[ClipDetections],
-    ) -> EvaluationOutput: ...
-
-    def compute_metrics(
-        self,
-        eval_outputs: EvaluationOutput,
-    ) -> dict[str, float]: ...
-
-    def generate_plots(
-        self,
-        eval_outputs: EvaluationOutput,
-    ) -> Iterable[tuple[str, Figure]]: ...
-
-
-class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
-    targets: TargetProtocol
-    transform: OutputTransformProtocol
-
-    def to_clip_detections_batch(
-        self,
-        clip_detections: Sequence[ClipDetectionsTensor],
-        clips: Sequence[data.Clip],
-    ) -> list[ClipDetections]: ...
-
-    def evaluate(
-        self,
-        clip_annotations: Sequence[data.ClipAnnotation],
-        predictions: Sequence[ClipDetections],
-    ) -> EvaluationOutput: ...
-
-    def compute_metrics(
-        self,
-        eval_outputs: EvaluationOutput,
-    ) -> dict[str, float]: ...
-
-    def generate_plots(
-        self,
-        eval_outputs: EvaluationOutput,
-    ) -> Iterable[tuple[str, Figure]]: ...
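For orientation, the file deleted above defines a `MatcherProtocol` that takes ground-truth geometries, predicted geometries, and prediction scores, and yields `(pred_idx, gt_idx, affinity)` triples with `None` marking false positives and missed ground truths. A minimal sketch of one matcher satisfying that shape, assuming for illustration plain `(start, end)` time intervals in place of soundevent geometries and greedy best-score-first matching by temporal IoU (the actual matcher built by `build_matcher` may differ):

```python
from typing import Iterable, Optional, Sequence, Tuple

Interval = Tuple[float, float]  # (start, end) stands in for a soundevent geometry


def interval_affinity(a: Interval, b: Interval) -> float:
    """Temporal intersection-over-union of two (start, end) intervals."""
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union if union > 0 else 0.0


def greedy_matcher(
    ground_truth: Sequence[Interval],
    predictions: Sequence[Interval],
    scores: Sequence[float],
    threshold: float = 0.0,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Yield (pred_idx, gt_idx, affinity) triples, one-to-one matching."""
    matched_gts = set()
    # Visit predictions in descending score order, as AP-style matching does.
    order = sorted(range(len(predictions)), key=lambda i: -scores[i])
    for pred_idx in order:
        best_gt: Optional[int] = None
        best_aff = threshold
        for gt_idx, gt in enumerate(ground_truth):
            if gt_idx in matched_gts:
                continue
            aff = interval_affinity(predictions[pred_idx], gt)
            if aff > best_aff:
                best_gt, best_aff = gt_idx, aff
        if best_gt is not None:
            matched_gts.add(best_gt)
            yield pred_idx, best_gt, best_aff
        else:
            yield pred_idx, None, 0.0  # false positive
    for gt_idx in range(len(ground_truth)):
        if gt_idx not in matched_gts:
            yield None, gt_idx, 0.0  # missed ground truth
```

This triple shape is what the task classes above consume when they rebuild `gt` and `pred` from `gt_idx` and `pred_idx`, treating a `None` index on either side as an unmatched detection or annotation.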
@@ -1,7 +1,7 @@
 import argparse
 import os
 import warnings
-from typing import List
+from typing import List, Optional
 
 import torch
 import torch.utils.data
@@ -88,8 +88,7 @@ def select_device(warn=True) -> str:
     if warn:
         warnings.warn(
             "No GPU available, using the CPU instead. Please consider using a GPU "
-            "to speed up training.",
-            stacklevel=2,
+            "to speed up training."
         )
 
     return "cpu"
@@ -99,8 +98,8 @@ def load_annotations(
     dataset_name: str,
     ann_path: str,
     audio_path: str,
-    classes_to_ignore: List[str] | None = None,
-    events_of_interest: List[str] | None = None,
+    classes_to_ignore: Optional[List[str]] = None,
+    events_of_interest: Optional[List[str]] = None,
 ) -> List[types.FileAnnotation]:
     train_sets: List[types.DatasetDict] = []
     train_sets.append(
@@ -2,6 +2,7 @@ import argparse
 import json
 import os
 from collections import Counter
+from typing import List, Optional, Tuple
 
 import numpy as np
 from sklearn.model_selection import StratifiedGroupKFold
@@ -11,8 +12,8 @@ from batdetect2 import types
 
 
 def print_dataset_stats(
-    data: list[types.FileAnnotation],
-    classes_to_ignore: list[str] | None = None,
+    data: List[types.FileAnnotation],
+    classes_to_ignore: Optional[List[str]] = None,
 ) -> Counter[str]:
     print("Num files:", len(data))
     counts, _ = tu.get_class_names(data, classes_to_ignore)
@@ -21,7 +22,7 @@ def print_dataset_stats(
     return counts
 
 
-def load_file_names(file_name: str) -> list[str]:
+def load_file_names(file_name: str) -> List[str]:
     if not os.path.isfile(file_name):
         raise FileNotFoundError(f"Input file not found - {file_name}")
 
@@ -99,12 +100,12 @@ def parse_args():
 
 
 def split_data(
-    data: list[types.FileAnnotation],
+    data: List[types.FileAnnotation],
     train_file: str,
     test_file: str,
     n_splits: int = 5,
     random_state: int = 0,
-) -> tuple[list[types.FileAnnotation], list[types.FileAnnotation]]:
+) -> Tuple[List[types.FileAnnotation], List[types.FileAnnotation]]:
     if train_file != "" and test_file != "":
         # user has specifed the train / test split
         mapping = {
@@ -161,7 +162,7 @@ def main():
     # change the names of the classes
     ip_names = args.input_class_names.split(";")
     op_names = args.output_class_names.split(";")
-    name_dict = dict(zip(ip_names, op_names, strict=False))
+    name_dict = dict(zip(ip_names, op_names))
 
     # load annotations
     data_all = tu.load_set_of_anns(
|
|||||||
@@ -1,68 +1,58 @@
-from typing import Sequence
+from typing import TYPE_CHECKING, List, Optional, Sequence

 from lightning import Trainer
 from soundevent import data

-from batdetect2.audio import AudioConfig
 from batdetect2.audio.loader import build_audio_loader
-from batdetect2.audio.types import AudioLoader
 from batdetect2.inference.clips import get_clips_from_files
-from batdetect2.inference.config import InferenceConfig
 from batdetect2.inference.dataset import build_inference_loader
 from batdetect2.inference.lightning import InferenceModule
 from batdetect2.models import Model
-from batdetect2.outputs import (
-    OutputsConfig,
-    OutputTransformProtocol,
-    build_output_transform,
+from batdetect2.preprocess.preprocessor import build_preprocessor
+from batdetect2.targets.targets import build_targets
+from batdetect2.typing.postprocess import BatDetect2Prediction
+
+if TYPE_CHECKING:
+    from batdetect2.config import BatDetect2Config
+    from batdetect2.typing import (
+        AudioLoader,
+        PreprocessorProtocol,
+        TargetProtocol,
 )
-from batdetect2.postprocess.types import ClipDetections
-from batdetect2.preprocess.types import PreprocessorProtocol
-from batdetect2.targets.types import TargetProtocol


 def run_batch_inference(
-    model: Model,
+    model,
     clips: Sequence[data.Clip],
-    targets: TargetProtocol | None = None,
-    audio_loader: AudioLoader | None = None,
-    preprocessor: PreprocessorProtocol | None = None,
-    audio_config: AudioConfig | None = None,
-    output_transform: OutputTransformProtocol | None = None,
-    output_config: OutputsConfig | None = None,
-    inference_config: InferenceConfig | None = None,
-    num_workers: int = 1,
-    batch_size: int | None = None,
-) -> list[ClipDetections]:
-    audio_config = audio_config or AudioConfig(
-        samplerate=model.preprocessor.input_samplerate,
+    targets: Optional["TargetProtocol"] = None,
+    audio_loader: Optional["AudioLoader"] = None,
+    preprocessor: Optional["PreprocessorProtocol"] = None,
+    config: Optional["BatDetect2Config"] = None,
+    num_workers: Optional[int] = None,
+    batch_size: Optional[int] = None,
+) -> List[BatDetect2Prediction]:
+    from batdetect2.config import BatDetect2Config
+
+    config = config or BatDetect2Config()
+
+    audio_loader = audio_loader or build_audio_loader()

+    preprocessor = preprocessor or build_preprocessor(
+        input_samplerate=audio_loader.samplerate,
     )
-    output_config = output_config or OutputsConfig()
-    inference_config = inference_config or InferenceConfig()

-    audio_loader = audio_loader or build_audio_loader(config=audio_config)
+    targets = targets or build_targets()

-    preprocessor = preprocessor or model.preprocessor
-    targets = targets or model.targets

-    output_transform = output_transform or build_output_transform(
-        config=output_config.transform,
-        targets=targets,
-    )

     loader = build_inference_loader(
         clips,
         audio_loader=audio_loader,
         preprocessor=preprocessor,
-        config=inference_config.loader,
+        config=config.inference.loader,
         num_workers=num_workers,
         batch_size=batch_size,
     )

-    module = InferenceModule(
-        model,
-        output_transform=output_transform,
-    )
+    module = InferenceModule(model)
     trainer = Trainer(enable_checkpointing=False, logger=False)
     outputs = trainer.predict(module, loader)
     return [
@@ -75,18 +65,13 @@ def run_batch_inference(
 def process_file_list(
     model: Model,
     paths: Sequence[data.PathLike],
-    targets: TargetProtocol | None = None,
-    audio_loader: AudioLoader | None = None,
-    audio_config: AudioConfig | None = None,
-    preprocessor: PreprocessorProtocol | None = None,
-    inference_config: InferenceConfig | None = None,
-    output_config: OutputsConfig | None = None,
-    output_transform: OutputTransformProtocol | None = None,
-    batch_size: int | None = None,
-    num_workers: int = 0,
-) -> list[ClipDetections]:
-    inference_config = inference_config or InferenceConfig()
-    clip_config = inference_config.clipping
+    config: "BatDetect2Config",
+    targets: Optional["TargetProtocol"] = None,
+    audio_loader: Optional["AudioLoader"] = None,
+    preprocessor: Optional["PreprocessorProtocol"] = None,
+    num_workers: Optional[int] = None,
+) -> List[BatDetect2Prediction]:
+    clip_config = config.inference.clipping
     clips = get_clips_from_files(
         paths,
         duration=clip_config.duration,
@@ -100,10 +85,6 @@ def process_file_list(
         targets=targets,
         audio_loader=audio_loader,
         preprocessor=preprocessor,
-        batch_size=batch_size,
+        config=config,
         num_workers=num_workers,
-        output_config=output_config,
-        audio_config=audio_config,
-        output_transform=output_transform,
-        inference_config=inference_config,
     )
@@ -38,10 +38,10 @@ def get_recording_clips(
     discard_empty: bool = True,
 ) -> Sequence[data.Clip]:
     start_time = 0
-    recording_duration = recording.duration
+    duration = recording.duration
     hop = duration * (1 - overlap)

-    num_clips = int(np.ceil(recording_duration / hop))
+    num_clips = int(np.ceil(duration / hop))

     if num_clips == 0:
         # This should only happen if the clip's duration is zero,
@@ -53,8 +53,8 @@ def get_recording_clips(
         start = start_time + i * hop
         end = start + duration

-        if end > recording_duration:
-            empty_duration = end - recording_duration
+        if end > duration:
+            empty_duration = end - duration

             if empty_duration > max_empty and discard_empty:
                 # Discard clips that contain too much empty space
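The windowing arithmetic in the `get_recording_clips` hunks above steps through a recording with `hop = duration * (1 - overlap)` and takes `ceil(recording_duration / hop)` windows. A minimal standalone sketch of that arithmetic (`window_clips` is a hypothetical helper for illustration, not part of the batdetect2 API):

```python
import math


def window_clips(recording_duration: float, duration: float, overlap: float = 0.0):
    """Compute (start, end) windows covering a recording of the given length."""
    # Step size between consecutive window starts; overlap=0.5 halves the hop.
    hop = duration * (1 - overlap)
    # Enough windows to cover the whole recording (the last one may overrun).
    num_clips = math.ceil(recording_duration / hop)
    return [(i * hop, i * hop + duration) for i in range(num_clips)]


# A 1 s recording split into 0.5 s windows with 50% overlap:
# starts fall at 0.0, 0.25, 0.5, 0.75, and the last window overruns the end.
windows = window_clips(1.0, duration=0.5, overlap=0.5)
```

The overrun of the final window is what the `empty_duration > max_empty` check in the diff guards against: windows that are mostly past the end of the recording get discarded.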
@@ -1,4 +1,4 @@
-from typing import NamedTuple, Sequence
+from typing import List, NamedTuple, Optional, Sequence

 import torch
 from loguru import logger
@@ -6,11 +6,10 @@ from soundevent import data
 from torch.utils.data import DataLoader, Dataset

 from batdetect2.audio import build_audio_loader
-from batdetect2.audio.types import AudioLoader
 from batdetect2.core import BaseConfig
 from batdetect2.core.arrays import adjust_width
 from batdetect2.preprocess import build_preprocessor
-from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.typing import AudioLoader, PreprocessorProtocol

 __all__ = [
     "InferenceDataset",
@@ -30,14 +29,14 @@ class DatasetItem(NamedTuple):


 class InferenceDataset(Dataset[DatasetItem]):
-    clips: list[data.Clip]
+    clips: List[data.Clip]

     def __init__(
         self,
         clips: Sequence[data.Clip],
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
-        audio_dir: data.PathLike | None = None,
+        audio_dir: Optional[data.PathLike] = None,
     ):
         self.clips = list(clips)
         self.preprocessor = preprocessor
@@ -47,30 +46,31 @@ class InferenceDataset(Dataset[DatasetItem]):
     def __len__(self):
         return len(self.clips)

-    def __getitem__(self, index: int) -> DatasetItem:
-        clip = self.clips[index]
+    def __getitem__(self, idx: int) -> DatasetItem:
+        clip = self.clips[idx]
         wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
         wav_tensor = torch.tensor(wav).unsqueeze(0)
         spectrogram = self.preprocessor(wav_tensor)
         return DatasetItem(
             spec=spectrogram,
-            idx=torch.tensor(index),
+            idx=torch.tensor(idx),
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )


 class InferenceLoaderConfig(BaseConfig):
+    num_workers: int = 0
     batch_size: int = 8


 def build_inference_loader(
     clips: Sequence[data.Clip],
-    audio_loader: AudioLoader | None = None,
-    preprocessor: PreprocessorProtocol | None = None,
-    config: InferenceLoaderConfig | None = None,
-    num_workers: int = 0,
-    batch_size: int | None = None,
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[InferenceLoaderConfig] = None,
+    num_workers: Optional[int] = None,
+    batch_size: Optional[int] = None,
 ) -> DataLoader[DatasetItem]:
     logger.info("Building inference data loader...")
     config = config or InferenceLoaderConfig()
@@ -83,19 +83,20 @@ def build_inference_loader(

     batch_size = batch_size or config.batch_size

+    num_workers = num_workers or config.num_workers
     return DataLoader(
         inference_dataset,
         batch_size=batch_size,
         shuffle=False,
-        num_workers=num_workers,
+        num_workers=config.num_workers,
         collate_fn=_collate_fn,
     )


 def build_inference_dataset(
     clips: Sequence[data.Clip],
-    audio_loader: AudioLoader | None = None,
-    preprocessor: PreprocessorProtocol | None = None,
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
 ) -> InferenceDataset:
     if audio_loader is None:
         audio_loader = build_audio_loader()
@@ -110,7 +111,7 @@ def build_inference_dataset(
     )


-def _collate_fn(batch: list[DatasetItem]) -> DatasetItem:
+def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
     max_width = max(item.spec.shape[-1] for item in batch)
     return DatasetItem(
         spec=torch.stack(
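The `_collate_fn` at the end of the dataset hunk has to handle clips of unequal length: each spectrogram is widened to the largest width in the batch (the diff uses `adjust_width` and `torch.stack`) before stacking. The idea can be sketched with NumPy (`pad_and_stack` is a hypothetical stand-in, not the project's function):

```python
import numpy as np


def pad_and_stack(specs):
    """Right-pad each [C, H, W] spectrogram to the widest W, then stack into a batch."""
    max_width = max(s.shape[-1] for s in specs)
    padded = [
        # Pad only the last (time) axis on the right with zeros.
        np.pad(s, [(0, 0)] * (s.ndim - 1) + [(0, max_width - s.shape[-1])])
        for s in specs
    ]
    return np.stack(padded)


# Two spectrograms with different time widths collate into one [2, 1, 4, 10] batch.
batch = pad_and_stack([np.ones((1, 4, 10)), np.ones((1, 4, 7))])
```

Padding rather than cropping keeps every detection inside the batch; the per-item `start_time`/`end_time` tensors let postprocessing map detections back to real clip coordinates.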
@@ -5,44 +5,45 @@ from torch.utils.data import DataLoader

 from batdetect2.inference.dataset import DatasetItem, InferenceDataset
 from batdetect2.models import Model
-from batdetect2.outputs import OutputTransformProtocol, build_output_transform
-from batdetect2.postprocess.types import ClipDetections
+from batdetect2.postprocess import to_raw_predictions
+from batdetect2.typing.postprocess import BatDetect2Prediction


 class InferenceModule(LightningModule):
-    def __init__(
-        self,
-        model: Model,
-        output_transform: OutputTransformProtocol | None = None,
-    ):
+    def __init__(self, model: Model):
         super().__init__()
         self.model = model
-        self.output_transform = output_transform or build_output_transform(
-            targets=model.targets
-        )

     def predict_step(
         self,
         batch: DatasetItem,
         batch_idx: int,
         dataloader_idx: int = 0,
-    ) -> Sequence[ClipDetections]:
+    ) -> Sequence[BatDetect2Prediction]:
         dataset = self.get_dataset()

         clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]

         outputs = self.model.detector(batch.spec)

-        clip_detections = self.model.postprocessor(outputs)
-
-        return [
-            self.output_transform.to_clip_detections(
-                detections=clip_dets,
-                clip=clip,
+        clip_detections = self.model.postprocessor(
+            outputs,
+            start_times=[clip.start_time for clip in clips],
         )
-            for clip, clip_dets in zip(clips, clip_detections, strict=True)
+
+        predictions = [
+            BatDetect2Prediction(
+                clip=clip,
+                predictions=to_raw_predictions(
+                    clip_dets.numpy(),
+                    targets=self.model.targets,
+                ),
+            )
+            for clip, clip_dets in zip(clips, clip_detections)
         ]

+        return predictions
+
     def get_dataset(self) -> InferenceDataset:
         dataloaders = self.trainer.predict_dataloaders
         assert isinstance(dataloaders, DataLoader)
@@ -9,8 +9,10 @@ from typing import (
     Dict,
     Generic,
     Literal,
+    Optional,
     Protocol,
     TypeVar,
+    Union,
 )

 import numpy as np
@@ -30,21 +32,6 @@ from batdetect2.core.configs import BaseConfig

 DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"

-__all__ = [
-    "AppLoggingConfig",
-    "BaseLoggerConfig",
-    "CSVLoggerConfig",
-    "DEFAULT_LOGS_DIR",
-    "DVCLiveConfig",
-    "LoggerConfig",
-    "MLFlowLoggerConfig",
-    "TensorBoardLoggerConfig",
-    "build_logger",
-    "enable_logging",
-    "get_image_logger",
-    "get_table_logger",
-]
-

 def enable_logging(level: int):
     logger.remove()
@@ -62,14 +49,14 @@ def enable_logging(level: int):

 class BaseLoggerConfig(BaseConfig):
     log_dir: Path = DEFAULT_LOGS_DIR
-    experiment_name: str | None = None
-    run_name: str | None = None
+    experiment_name: Optional[str] = None
+    run_name: Optional[str] = None


 class DVCLiveConfig(BaseLoggerConfig):
     name: Literal["dvclive"] = "dvclive"
     prefix: str = ""
-    log_model: bool | Literal["all"] = False
+    log_model: Union[bool, Literal["all"]] = False
     monitor_system: bool = False


@@ -85,26 +72,22 @@ class TensorBoardLoggerConfig(BaseLoggerConfig):

 class MLFlowLoggerConfig(BaseLoggerConfig):
     name: Literal["mlflow"] = "mlflow"
-    tracking_uri: str | None = "http://localhost:5000"
-    tags: dict[str, Any] | None = None
+    tracking_uri: Optional[str] = "http://localhost:5000"
+    tags: Optional[dict[str, Any]] = None
     log_model: bool = False


 LoggerConfig = Annotated[
-    DVCLiveConfig
-    | CSVLoggerConfig
-    | TensorBoardLoggerConfig
-    | MLFlowLoggerConfig,
+    Union[
+        DVCLiveConfig,
+        CSVLoggerConfig,
+        TensorBoardLoggerConfig,
+        MLFlowLoggerConfig,
+    ],
     Field(discriminator="name"),
 ]


-class AppLoggingConfig(BaseConfig):
-    train: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
-    evaluation: LoggerConfig = Field(default_factory=CSVLoggerConfig)
-    inference: LoggerConfig = Field(default_factory=CSVLoggerConfig)
-
-
 T = TypeVar("T", bound=LoggerConfig, contravariant=True)


@@ -112,20 +95,20 @@ class LoggerBuilder(Protocol, Generic[T]):
     def __call__(
         self,
         config: T,
-        log_dir: Path | None = None,
-        experiment_name: str | None = None,
-        run_name: str | None = None,
+        log_dir: Optional[Path] = None,
+        experiment_name: Optional[str] = None,
+        run_name: Optional[str] = None,
     ) -> Logger: ...


 def create_dvclive_logger(
     config: DVCLiveConfig,
-    log_dir: Path | None = None,
-    experiment_name: str | None = None,
-    run_name: str | None = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     try:
-        from dvclive.lightning import DVCLiveLogger
+        from dvclive.lightning import DVCLiveLogger  # type: ignore
     except ImportError as error:
         raise ValueError(
             "DVCLive is not installed and cannot be used for logging"
@@ -147,9 +130,9 @@ def create_dvclive_logger(

 def create_csv_logger(
     config: CSVLoggerConfig,
-    log_dir: Path | None = None,
-    experiment_name: str | None = None,
-    run_name: str | None = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     from lightning.pytorch.loggers import CSVLogger

@@ -176,9 +159,9 @@ def create_csv_logger(

 def create_tensorboard_logger(
     config: TensorBoardLoggerConfig,
-    log_dir: Path | None = None,
-    experiment_name: str | None = None,
-    run_name: str | None = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     from lightning.pytorch.loggers import TensorBoardLogger

@@ -208,9 +191,9 @@ def create_tensorboard_logger(

 def create_mlflow_logger(
     config: MLFlowLoggerConfig,
-    log_dir: data.PathLike | None = None,
-    experiment_name: str | None = None,
-    run_name: str | None = None,
+    log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     try:
         from lightning.pytorch.loggers import MLFlowLogger
@@ -249,9 +232,9 @@ LOGGER_FACTORY: Dict[str, LoggerBuilder] = {

 def build_logger(
     config: LoggerConfig,
-    log_dir: Path | None = None,
-    experiment_name: str | None = None,
-    run_name: str | None = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     logger.opt(lazy=True).debug(
         "Building logger with config: \n{}",
@@ -274,7 +257,7 @@ def build_logger(
 PlotLogger = Callable[[str, Figure, int], None]


-def get_image_logger(logger: Logger) -> PlotLogger | None:
+def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
     if isinstance(logger, TensorBoardLogger):
         return logger.experiment.add_figure

@@ -299,7 +282,7 @@ def get_image_logger(logger: Logger) -> PlotLogger | None:
 TableLogger = Callable[[str, pd.DataFrame, int], None]


-def get_table_logger(logger: Logger) -> TableLogger | None:
+def get_table_logger(logger: Logger) -> Optional[TableLogger]:
     if isinstance(logger, TensorBoardLogger):
         return partial(save_table, dir=Path(logger.log_dir))
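Both sides of the logging hunks above keep the same pattern: a `LoggerConfig` union discriminated on the `name` literal, and a `LOGGER_FACTORY` dict that `build_logger` uses to dispatch on that field. A minimal sketch of the dispatch idea, using plain dataclasses as hypothetical stand-ins for the pydantic configs (the real code validates the union via `Field(discriminator="name")`):

```python
from dataclasses import dataclass
from typing import Callable, Dict


# Hypothetical stand-ins for the config classes in the diff above.
@dataclass
class CSVLoggerConfig:
    name: str = "csv"
    log_dir: str = "outputs/logs"


@dataclass
class TensorBoardLoggerConfig:
    name: str = "tensorboard"
    log_dir: str = "outputs/logs"


# One builder per discriminator value, mirroring LOGGER_FACTORY in the diff.
LOGGER_FACTORY: Dict[str, Callable] = {
    "csv": lambda cfg: f"CSVLogger(save_dir={cfg.log_dir!r})",
    "tensorboard": lambda cfg: f"TensorBoardLogger(save_dir={cfg.log_dir!r})",
}


def build_logger(config) -> str:
    # Dispatch on the discriminator field; unknown names raise KeyError.
    return LOGGER_FACTORY[config.name](config)
```

The discriminator means a YAML/JSON config can select its logger with a single `name:` key, and adding a backend is just a new config class plus a factory entry.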
@ -1,46 +1,36 @@
|
|||||||
"""Neural network model definitions and builders for BatDetect2.
|
"""Defines and builds the neural network models used in BatDetect2.
|
||||||
|
|
||||||
This package contains the PyTorch implementations of the deep neural network
|
This package (`batdetect2.models`) contains the PyTorch implementations of the
|
||||||
architectures used to detect and classify bat echolocation calls in
|
deep neural network architectures used for detecting and classifying bat calls
|
||||||
spectrograms. Components are designed to be combined through configuration
|
from spectrograms. It provides modular components and configuration-driven
|
||||||
objects, making it easy to experiment with different architectures.
|
assembly, allowing for experimentation and use of different architectural
|
||||||
|
variants.
|
||||||
|
|
||||||
Key submodules
|
Key Submodules:
|
||||||
--------------
|
- `.types`: Defines core data structures (`ModelOutput`) and abstract base
|
||||||
- ``blocks``: Reusable convolutional building blocks (downsampling,
|
classes (`BackboneModel`, `DetectionModel`) establishing interfaces.
|
||||||
upsampling, attention, coord-conv variants).
|
- `.blocks`: Provides reusable neural network building blocks.
|
||||||
- ``encoder``: The downsampling path; reduces spatial resolution whilst
|
- `.encoder`: Defines and builds the downsampling path (encoder) of the network.
|
||||||
extracting increasingly abstract features.
|
- `.bottleneck`: Defines and builds the central bottleneck component.
|
||||||
- ``bottleneck``: The central component connecting encoder to decoder;
|
- `.decoder`: Defines and builds the upsampling path (decoder) of the network.
|
||||||
optionally applies self-attention along the time axis.
|
- `.backbone`: Assembles the encoder, bottleneck, and decoder into a complete
|
||||||
- ``decoder``: The upsampling path; reconstructs high-resolution feature
|
feature extraction backbone (e.g., a U-Net like structure).
|
||||||
maps using bottleneck output and skip connections from the encoder.
|
- `.heads`: Defines simple prediction heads (detection, classification, size)
|
||||||
- ``backbones``: Assembles encoder, bottleneck, and decoder into a complete
|
that attach to the backbone features.
|
||||||
U-Net-style feature extraction backbone.
|
- `.detectors`: Assembles the backbone and prediction heads into the final,
|
||||||
- ``heads``: Lightweight 1×1 convolutional heads that produce detection,
|
end-to-end `Detector` model.
|
||||||
classification, and bounding-box size predictions from backbone features.
|
|
||||||
- ``detectors``: Combines a backbone with prediction heads into the final
|
|
||||||
end-to-end ``Detector`` model.
|
|
||||||
|
|
||||||
The primary entry point for building a full, ready-to-use BatDetect2 model
|
This module re-exports the most important classes, configurations, and builder
|
||||||
is the ``build_model`` factory function exported from this module.
|
functions from these submodules for convenient access. The primary entry point
|
||||||
|
for creating a standard BatDetect2 model instance is the `build_model` function
|
||||||
|
provided here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Literal
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
|
||||||
from soundevent.data import PathLike
|
|
||||||
|
|
||||||
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
from batdetect2.models.backbones import Backbone, build_backbone
|
||||||
from batdetect2.core.configs import BaseConfig
|
|
||||||
from batdetect2.models.backbones import (
|
|
||||||
BackboneConfig,
|
|
||||||
UNetBackbone,
|
|
||||||
UNetBackboneConfig,
|
|
||||||
build_backbone,
|
|
||||||
load_backbone_config,
|
|
||||||
)
|
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvDownConfig,
|
FreqCoordConvDownConfig,
|
||||||
@ -53,6 +43,10 @@ from batdetect2.models.bottleneck import (
|
|||||||
BottleneckConfig,
|
BottleneckConfig,
|
||||||
build_bottleneck,
|
build_bottleneck,
|
||||||
)
|
)
|
||||||
|
from batdetect2.models.config import (
|
||||||
|
BackboneConfig,
|
||||||
|
load_backbone_config,
|
||||||
|
)
|
||||||
from batdetect2.models.decoder import (
|
from batdetect2.models.decoder import (
|
||||||
DEFAULT_DECODER_CONFIG,
|
DEFAULT_DECODER_CONFIG,
|
||||||
DecoderConfig,
|
DecoderConfig,
|
||||||
@ -65,20 +59,17 @@ from batdetect2.models.encoder import (
|
|||||||
build_encoder,
|
build_encoder,
|
||||||
)
|
)
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||||
from batdetect2.models.types import DetectionModel
|
from batdetect2.typing import (
|
||||||
from batdetect2.postprocess.config import PostprocessConfig
|
|
||||||
from batdetect2.postprocess.types import (
|
|
||||||
ClipDetectionsTensor,
|
ClipDetectionsTensor,
|
||||||
|
DetectionModel,
|
||||||
PostprocessorProtocol,
|
PostprocessorProtocol,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess.config import PreprocessingConfig
|
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
|
||||||
from batdetect2.targets.config import TargetConfig
|
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BBoxHead",
|
"BBoxHead",
|
||||||
"UNetBackbone",
|
"Backbone",
|
||||||
"BackboneConfig",
|
"BackboneConfig",
|
||||||
"Bottleneck",
|
"Bottleneck",
|
||||||
"BottleneckConfig",
|
"BottleneckConfig",
|
||||||
@ -101,93 +92,11 @@ __all__ = [
|
|||||||
"build_detector",
|
"build_detector",
|
||||||
     "load_backbone_config",
     "Model",
-    "ModelConfig",
     "build_model",
-    "build_model_with_new_targets",
 ]
 
 
-class ModelConfig(BaseConfig):
-    """Complete configuration describing a BatDetect2 model.
-
-    Bundles every parameter that defines a model's behaviour: the input
-    sample rate, backbone architecture, preprocessing pipeline,
-    postprocessing pipeline, and detection targets.
-
-    Attributes
-    ----------
-    samplerate : int
-        Expected input audio sample rate in Hz. Audio must be resampled
-        to this rate before being passed to the model. Defaults to
-        ``TARGET_SAMPLERATE_HZ`` (256 000 Hz).
-    architecture : BackboneConfig
-        Configuration for the encoder-decoder backbone network. Defaults
-        to ``UNetBackboneConfig()``.
-    preprocess : PreprocessingConfig
-        Parameters for the audio-to-spectrogram preprocessing pipeline
-        (STFT, frequency crop, transforms, resize). Defaults to
-        ``PreprocessingConfig()``.
-    postprocess : PostprocessConfig
-        Parameters for converting raw model outputs into detections (NMS
-        kernel, thresholds, top-k limit). Defaults to
-        ``PostprocessConfig()``.
-    targets : TargetConfig
-        Detection and classification target definitions (class list,
-        detection target, bounding-box mapper). Defaults to
-        ``TargetConfig()``.
-    """
-
-    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
-    architecture: BackboneConfig = Field(default_factory=UNetBackboneConfig)
-    preprocess: PreprocessingConfig = Field(
-        default_factory=PreprocessingConfig
-    )
-    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
-    targets: TargetConfig = Field(default_factory=TargetConfig)
-
-    @classmethod
-    def load(
-        cls,
-        path: PathLike,
-        field: str | None = None,
-        extra: Literal["ignore", "allow", "forbid"] | None = None,
-        strict: bool | None = None,
-        targets: TargetConfig | None = None,
-    ) -> "ModelConfig":
-        config = super().load(path, field, extra, strict)
-
-        if targets is None:
-            return config
-
-        return config.model_copy(update={"targets": targets})
-
-
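The removed `ModelConfig.load` above swaps in a caller-supplied `targets` section via pydantic's `model_copy(update=...)` without mutating the loaded config. A minimal, framework-free sketch of that copy-with-override pattern, using `dataclasses.replace` as a stand-in for `model_copy` and purely illustrative field values (these are not the real batdetect2 schemas):

```python
# Sketch only: TargetConfig/ModelConfig fields here are hypothetical
# stand-ins, not the actual batdetect2 configuration classes.
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TargetConfig:
    classes: tuple = ("Myotis", "Pipistrellus")


@dataclass(frozen=True)
class ModelConfig:
    samplerate: int = 256_000
    targets: TargetConfig = TargetConfig()


base = ModelConfig()
# Equivalent in spirit to: base.model_copy(update={"targets": ...})
override = replace(base, targets=TargetConfig(classes=("Nyctalus",)))

assert base.targets.classes == ("Myotis", "Pipistrellus")  # original kept
assert override.targets.classes == ("Nyctalus",)
assert override.samplerate == base.samplerate  # other fields unchanged
```

The point of the pattern is that the loaded configuration stays immutable; callers get a fresh object with only the `targets` section replaced.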
 class Model(torch.nn.Module):
-    """End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
-
-    Combines a preprocessor, a detection model, and a postprocessor into a
-    single PyTorch module. Calling ``forward`` on a raw waveform tensor
-    returns a list of detection tensors ready for downstream use.
-
-    This class is the top-level object produced by ``build_model``. Most
-    users will not need to construct it directly.
-
-    Attributes
-    ----------
-    detector : DetectionModel
-        The neural network that processes spectrograms and produces raw
-        detection, classification, and bounding-box outputs.
-    preprocessor : PreprocessorProtocol
-        Converts a raw waveform tensor into a spectrogram tensor accepted by
-        ``detector``.
-    postprocessor : PostprocessorProtocol
-        Converts the raw ``ModelOutput`` from ``detector`` into a list of
-        per-clip detection tensors.
-    targets : TargetProtocol
-        Describes the set of target classes; used when building heads and
-        during training target construction.
-    """
-
     detector: DetectionModel
     preprocessor: PreprocessorProtocol
     postprocessor: PostprocessorProtocol
@@ -206,87 +115,31 @@ class Model(torch.nn.Module):
         self.postprocessor = postprocessor
         self.targets = targets
 
-    def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
-        """Run the full detection pipeline on a waveform tensor.
-
-        Converts the waveform to a spectrogram, passes it through the
-        detector, and postprocesses the raw outputs into detection tensors.
-
-        Parameters
-        ----------
-        wav : torch.Tensor
-            Raw audio waveform tensor. The exact expected shape depends on
-            the preprocessor, but is typically ``(batch, samples)`` or
-            ``(batch, channels, samples)``.
-
-        Returns
-        -------
-        list[ClipDetectionsTensor]
-            One detection tensor per clip in the batch. Each tensor encodes
-            the detected events (locations, class scores, sizes) for that
-            clip.
-        """
+    def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
         spec = self.preprocessor(wav)
         outputs = self.detector(spec)
         return self.postprocessor(outputs)
 
 
 def build_model(
-    config: ModelConfig | None = None,
-    targets: TargetProtocol | None = None,
-    preprocessor: PreprocessorProtocol | None = None,
-    postprocessor: PostprocessorProtocol | None = None,
-) -> Model:
-    """Build a complete, ready-to-use BatDetect2 model.
-
-    Assembles a ``Model`` instance from a ``ModelConfig`` and optional
-    component overrides. Any component argument left as ``None`` is built
-    from the configuration. Passing a pre-built component overrides the
-    corresponding config fields for that component only.
-
-    Parameters
-    ----------
-    config : ModelConfig, optional
-        Full model configuration (samplerate, architecture, preprocessing,
-        postprocessing, targets). Defaults to ``ModelConfig()`` if not
-        provided.
-    targets : TargetProtocol, optional
-        Pre-built targets object. If given, overrides
-        ``config.targets``.
-    preprocessor : PreprocessorProtocol, optional
-        Pre-built preprocessor. If given, overrides
-        ``config.preprocess`` and ``config.samplerate`` for the
-        preprocessing step.
-    postprocessor : PostprocessorProtocol, optional
-        Pre-built postprocessor. If given, overrides
-        ``config.postprocess``. When omitted and a custom
-        ``preprocessor`` is supplied, the default postprocessor is built
-        using that preprocessor so that frequency and time scaling remain
-        consistent.
-
-    Returns
-    -------
-    Model
-        A fully assembled ``Model`` instance ready for inference or
-        training.
-    """
+    config: Optional[BackboneConfig] = None,
+    targets: Optional[TargetProtocol] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    postprocessor: Optional[PostprocessorProtocol] = None,
+):
     from batdetect2.postprocess import build_postprocessor
     from batdetect2.preprocess import build_preprocessor
     from batdetect2.targets import build_targets
 
-    config = config or ModelConfig()
-    targets = targets or build_targets(config=config.targets)
-    preprocessor = preprocessor or build_preprocessor(
-        config=config.preprocess,
-        input_samplerate=config.samplerate,
-    )
+    config = config or BackboneConfig()
+    targets = targets or build_targets()
+    preprocessor = preprocessor or build_preprocessor()
     postprocessor = postprocessor or build_postprocessor(
         preprocessor=preprocessor,
-        config=config.postprocess,
     )
     detector = build_detector(
        num_classes=len(targets.class_names),
-        config=config.architecture,
+        config=config,
     )
     return Model(
         detector=detector,
@@ -294,21 +147,3 @@ def build_model(
         preprocessor=preprocessor,
         targets=targets,
     )
-
-
-def build_model_with_new_targets(
-    model: Model,
-    targets: TargetProtocol,
-) -> Model:
-    """Build a new model with a different target set."""
-    detector = build_detector(
-        num_classes=len(targets.class_names),
-        backbone=model.detector.backbone,
-    )
-
-    return Model(
-        detector=detector,
-        postprocessor=model.postprocessor,
-        preprocessor=model.preprocessor,
-        targets=targets,
-    )
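Both versions of `build_model` above follow the same assembly rule: any component passed as `None` is built from the configuration, while a pre-built component takes precedence. A simplified sketch of that override pattern, where the class and return values are placeholders rather than the real batdetect2 APIs:

```python
# Hypothetical stand-ins for build_preprocessor/build_model; the real
# functions return preprocessor and Model objects, not tuples/dicts.
class Config:
    def __init__(self, samplerate=256_000, n_classes=2):
        self.samplerate = samplerate
        self.n_classes = n_classes


def build_preprocessor(config):
    # Placeholder "preprocessor" tagged with the configured sample rate.
    return ("preprocessor", config.samplerate)


def build_model(config=None, preprocessor=None):
    config = config or Config()
    # None -> derive from config; an explicit component takes precedence.
    preprocessor = preprocessor or build_preprocessor(config)
    return {"preprocessor": preprocessor, "n_classes": config.n_classes}


default_model = build_model()
custom_model = build_model(preprocessor=("preprocessor", 44_100))
```

Letting an explicit component win keeps callers in control (e.g. reusing a trained model's preprocessor) while the config covers the common case.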
@@ -1,176 +1,99 @@
-"""Assembles a complete encoder-decoder backbone network.
+"""Assembles a complete Encoder-Decoder Backbone network.
 
-This module defines ``UNetBackboneConfig`` and the ``UNetBackbone``
-``nn.Module``, together with the ``build_backbone`` and
-``load_backbone_config`` helpers.
+This module defines the configuration (`BackboneConfig`) and implementation
+(`Backbone`) for a standard encoder-decoder style neural network backbone.
 
-A backbone combines three components built from the sibling modules:
+It orchestrates the connection between three main components, built using their
+respective configurations and factory functions from sibling modules:
+
+1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features
+   at multiple resolutions and provides skip connections.
+2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the
+   lowest resolution, optionally applying self-attention.
+3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs high-
+   resolution features using bottleneck features and skip connections.
 
-1. **Encoder** (``batdetect2.models.encoder``) – reduces spatial resolution
-   while extracting hierarchical features and storing skip-connection tensors.
-2. **Bottleneck** (``batdetect2.models.bottleneck``) – processes the
-   lowest-resolution features, optionally applying self-attention.
-3. **Decoder** (``batdetect2.models.decoder``) – restores spatial resolution
-   using bottleneck features and skip connections from the encoder.
-
-The resulting ``UNetBackbone`` takes a spectrogram tensor as input and returns
-a high-resolution feature map consumed by the prediction heads in
-``batdetect2.models.detectors``.
-
-Input padding is handled automatically: the backbone pads the input to be
-divisible by the total downsampling factor and strips the padding from the
-output so that the output spatial dimensions always match the input spatial
-dimensions.
+The resulting `Backbone` module takes a spectrogram as input and outputs a
+final feature map, typically used by subsequent prediction heads. It includes
+automatic padding to handle input sizes not perfectly divisible by the
+network's total downsampling factor.
 """
 
-from typing import Annotated, Literal
+from typing import Tuple
 
 import torch
 import torch.nn.functional as F
-from pydantic import Field, TypeAdapter
-from soundevent import data
+from torch import nn
 
-from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.core.registries import (
-    ImportConfig,
-    Registry,
-    add_import_config,
-)
-from batdetect2.models.bottleneck import (
-    DEFAULT_BOTTLENECK_CONFIG,
-    BottleneckConfig,
-    build_bottleneck,
-)
-from batdetect2.models.decoder import (
-    DEFAULT_DECODER_CONFIG,
-    DecoderConfig,
-    build_decoder,
-)
-from batdetect2.models.encoder import (
-    DEFAULT_ENCODER_CONFIG,
-    EncoderConfig,
-    build_encoder,
-)
-from batdetect2.models.types import (
-    BackboneModel,
-    BottleneckProtocol,
-    DecoderProtocol,
-    EncoderProtocol,
-)
+from batdetect2.models.bottleneck import build_bottleneck
+from batdetect2.models.config import BackboneConfig
+from batdetect2.models.decoder import Decoder, build_decoder
+from batdetect2.models.encoder import Encoder, build_encoder
+from batdetect2.typing.models import BackboneModel
 
 __all__ = [
-    "BackboneImportConfig",
-    "UNetBackbone",
-    "BackboneConfig",
-    "load_backbone_config",
+    "Backbone",
     "build_backbone",
 ]
 
 
-class UNetBackboneConfig(BaseConfig):
-    """Configuration for a U-Net-style encoder-decoder backbone.
-
-    All fields have sensible defaults that reproduce the standard BatDetect2
-    architecture, so you can start with ``UNetBackboneConfig()`` and override
-    only the fields you want to change.
-
-    Attributes
-    ----------
-    name : str
-        Discriminator field used by the backbone registry; always
-        ``"UNetBackbone"``.
-    input_height : int
-        Number of frequency bins in the input spectrogram. Defaults to
-        ``128``.
-    in_channels : int
-        Number of channels in the input spectrogram (e.g. ``1`` for a
-        standard mel-spectrogram). Defaults to ``1``.
-    encoder : EncoderConfig
-        Configuration for the downsampling path. Defaults to
-        ``DEFAULT_ENCODER_CONFIG``.
-    bottleneck : BottleneckConfig
-        Configuration for the bottleneck. Defaults to
-        ``DEFAULT_BOTTLENECK_CONFIG``.
-    decoder : DecoderConfig
-        Configuration for the upsampling path. Defaults to
-        ``DEFAULT_DECODER_CONFIG``.
-    """
-
-    name: Literal["UNetBackbone"] = "UNetBackbone"
-    input_height: int = 128
-    in_channels: int = 1
-    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
-    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
-    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
-
-
-backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
-
-
-@add_import_config(backbone_registry)
-class BackboneImportConfig(ImportConfig):
-    """Use any callable as a backbone model.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
-class UNetBackbone(BackboneModel):
-    """U-Net-style encoder-decoder backbone network.
-
-    Combines an encoder, a bottleneck, and a decoder into a single module
-    that produces a high-resolution feature map from an input spectrogram.
-    Skip connections from each encoder stage are added element-wise to the
-    corresponding decoder stage input.
-
-    Input spectrograms of arbitrary width are handled automatically: the
-    backbone pads the input so that its dimensions are divisible by
-    ``divide_factor`` and removes the padding from the output.
-
-    Instances are typically created via ``build_backbone``.
+class Backbone(BackboneModel):
+    """Encoder-Decoder Backbone Network Implementation.
+
+    Combines an Encoder, Bottleneck, and Decoder module sequentially, using
+    skip connections between the Encoder and Decoder. Implements the standard
+    U-Net style forward pass. Includes automatic input padding to handle
+    various input sizes and a final convolutional block to adjust the output
+    channels.
+
+    This class inherits from `BackboneModel` and implements its `forward`
+    method. Instances are typically created using the `build_backbone` factory
+    function.
 
     Attributes
     ----------
     input_height : int
-        Expected height (frequency bins) of the input spectrogram.
+        Expected height of the input spectrogram.
     out_channels : int
-        Number of channels in the output feature map (taken from the
-        decoder's output channel count).
-    encoder : EncoderProtocol
+        Number of channels in the final output feature map.
+    encoder : Encoder
         The instantiated encoder module.
-    decoder : DecoderProtocol
+    decoder : Decoder
         The instantiated decoder module.
-    bottleneck : BottleneckProtocol
+    bottleneck : nn.Module
         The instantiated bottleneck module.
+    final_conv : ConvBlock
+        Final convolutional block applied after the decoder.
     divide_factor : int
-        The total spatial downsampling factor applied by the encoder
-        (``input_height // encoder.output_height``). The input width is
-        padded to be a multiple of this value before processing.
+        The total downsampling factor (2^depth) applied by the encoder,
+        used for automatic input padding.
     """
 
     def __init__(
         self,
         input_height: int,
-        encoder: EncoderProtocol,
-        decoder: DecoderProtocol,
-        bottleneck: BottleneckProtocol,
+        encoder: Encoder,
+        decoder: Decoder,
+        bottleneck: nn.Module,
     ):
-        """Initialise the backbone network.
+        """Initialize the Backbone network.
 
         Parameters
         ----------
         input_height : int
-            Expected height (frequency bins) of the input spectrogram.
-        encoder : EncoderProtocol
-            An initialised encoder module.
-        decoder : DecoderProtocol
-            An initialised decoder module. Its ``output_height`` must equal
-            ``input_height``; a ``ValueError`` is raised otherwise.
-        bottleneck : BottleneckProtocol
-            An initialised bottleneck module.
+            Expected height of the input spectrogram.
+        out_channels : int
+            Desired number of output channels for the backbone's feature map.
+        encoder : Encoder
+            An initialized Encoder module.
+        decoder : Decoder
+            An initialized Decoder module.
+        bottleneck : nn.Module
+            An initialized Bottleneck module.
+
+        Raises
+        ------
+        ValueError
+            If component output/input channels or heights are incompatible.
         """
         super().__init__()
         self.input_height = input_height
@@ -187,25 +110,22 @@ class UNetBackbone(BackboneModel):
         self.divide_factor = input_height // self.encoder.output_height
 
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        """Produce a feature map from an input spectrogram.
+        """Perform the forward pass through the encoder-decoder backbone.
 
-        Pads the input if necessary, runs it through the encoder, then
-        the bottleneck, then the decoder (incorporating encoder skip
-        connections), and finally removes any padding added earlier.
+        Applies padding, runs encoder, bottleneck, decoder (with skip
+        connections), removes padding, and applies a final convolution.
 
         Parameters
         ----------
         spec : torch.Tensor
-            Input spectrogram tensor, shape
-            ``(B, C_in, H_in, W_in)``. ``H_in`` must equal
-            ``self.input_height``; ``W_in`` can be any positive integer.
+            Input spectrogram tensor, shape `(B, C_in, H_in, W_in)`. Must match
+            `self.encoder.input_channels` and `self.input_height`.
 
         Returns
         -------
         torch.Tensor
-            Feature map tensor, shape ``(B, C_out, H_in, W_in)``, where
-            ``C_out`` is ``self.out_channels``. The spatial dimensions
-            always match those of the input.
+            Output feature map tensor, shape `(B, C_out, H_in, W_in)`, where
+            `C_out` is `self.out_channels`.
         """
         spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor)
 
@@ -223,9 +143,35 @@ class UNetBackbone(BackboneModel):
 
         return x
 
-    @backbone_registry.register(UNetBackboneConfig)
-    @staticmethod
-    def from_config(config: UNetBackboneConfig) -> BackboneModel:
+
+def build_backbone(config: BackboneConfig) -> BackboneModel:
+    """Factory function to build a Backbone from configuration.
+
+    Constructs the `Encoder`, `Bottleneck`, and `Decoder` components based on
+    the provided `BackboneConfig`, validates their compatibility, and assembles
+    them into a `Backbone` instance.
+
+    Parameters
+    ----------
+    config : BackboneConfig
+        The configuration object detailing the backbone architecture, including
+        input dimensions and configurations for encoder, bottleneck, and
+        decoder.
+
+    Returns
+    -------
+    BackboneModel
+        An initialized `Backbone` module ready for use.
+
+    Raises
+    ------
+    ValueError
+        If sub-component configurations are incompatible
+        (e.g., channel mismatches, decoder output height doesn't match backbone
+        input height).
+    NotImplementedError
+        If an unknown block type is specified in sub-configs.
+    """
    encoder = build_encoder(
        in_channels=config.in_channels,
        input_height=config.input_height,
@@ -252,7 +198,7 @@ class UNetBackbone(BackboneModel):
            "configurations and input/bottleneck heights."
        )
 
-    return UNetBackbone(
+    return Backbone(
        input_height=config.input_height,
        encoder=encoder,
        decoder=decoder,
@@ -260,60 +206,32 @@ class UNetBackbone(BackboneModel):
    )
 
 
-BackboneConfig = Annotated[
-    UNetBackboneConfig | BackboneImportConfig,
-    Field(discriminator="name"),
-]
-
-
-def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
-    """Build a backbone network from configuration.
-
-    Looks up the backbone class corresponding to ``config.name`` in the
-    backbone registry and calls its ``from_config`` method. If no
-    configuration is provided, a default ``UNetBackbone`` is returned.
-
-    Parameters
-    ----------
-    config : BackboneConfig, optional
-        A configuration object describing the desired backbone. Currently
-        ``UNetBackboneConfig`` is the only supported type. Defaults to
-        ``UNetBackboneConfig()`` if not provided.
-
-    Returns
-    -------
-    BackboneModel
-        An initialised backbone module.
-    """
-    config = config or UNetBackboneConfig()
-    return backbone_registry.build(config)
-
-
 def _pad_adjust(
     spec: torch.Tensor,
     factor: int = 32,
-) -> tuple[torch.Tensor, int, int]:
-    """Pad a tensor's height and width to be divisible by ``factor``.
+) -> Tuple[torch.Tensor, int, int]:
+    """Pad tensor height and width to be divisible by a factor.
 
-    Adds zero-padding to the bottom and right edges of the tensor so that
-    both dimensions are exact multiples of ``factor``. If both dimensions
-    are already divisible, the tensor is returned unchanged.
+    Calculates the required padding for the last two dimensions (H, W) to make
+    them divisible by `factor` and applies right/bottom padding using
+    `torch.nn.functional.pad`.
 
     Parameters
     ----------
     spec : torch.Tensor
-        Input tensor, typically shape ``(B, C, H, W)``.
+        Input tensor, typically shape `(B, C, H, W)`.
     factor : int, default=32
-        The factor that both H and W should be divisible by after padding.
+        The factor to make height and width divisible by.
 
     Returns
     -------
-    tuple[torch.Tensor, int, int]
-        - Padded tensor.
-        - Number of rows added to the height (``h_pad``).
-        - Number of columns added to the width (``w_pad``).
+    Tuple[torch.Tensor, int, int]
+        A tuple containing:
+        - The padded tensor.
+        - The amount of padding added to height (`h_pad`).
+        - The amount of padding added to width (`w_pad`).
     """
-    h, w = spec.shape[-2:]
+    h, w = spec.shape[2:]
    h_pad = -h % factor
    w_pad = -w % factor
 
@@ -326,71 +244,28 @@ def _pad_adjust(
 def _restore_pad(
     x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
 ) -> torch.Tensor:
-    """Remove padding previously added by ``_pad_adjust``.
+    """Remove padding added by _pad_adjust.
 
-    Trims ``h_pad`` rows from the bottom and ``w_pad`` columns from the
-    right of the tensor, restoring its original spatial dimensions.
+    Removes padding from the bottom and right edges of the tensor.
 
     Parameters
     ----------
     x : torch.Tensor
-        Padded tensor, typically shape ``(B, C, H_padded, W_padded)``.
+        Padded tensor, typically shape `(B, C, H_padded, W_padded)`.
     h_pad : int, default=0
-        Number of rows to remove from the bottom.
+        Amount of padding previously added to the height (bottom).
     w_pad : int, default=0
-        Number of columns to remove from the right.
+        Amount of padding previously added to the width (right).
 
     Returns
     -------
     torch.Tensor
-        Tensor with padding removed, shape
-        ``(B, C, H_padded - h_pad, W_padded - w_pad)``.
+        Tensor with padding removed, shape `(B, C, H_original, W_original)`.
     """
     if h_pad > 0:
-        x = x[..., :-h_pad, :]
+        x = x[:, :, :-h_pad, :]
 
     if w_pad > 0:
-        x = x[..., :-w_pad]
+        x = x[:, :, :, :-w_pad]
 
     return x
-
-
-def load_backbone_config(
-    path: data.PathLike,
-    field: str | None = None,
-) -> BackboneConfig:
-    """Load a backbone configuration from a YAML or JSON file.
-
-    Reads the file at ``path``, optionally descends into a named sub-field,
-    and validates the result against the ``BackboneConfig`` discriminated
-    union.
-
-    Parameters
-    ----------
-    path : PathLike
-        Path to the configuration file. Both YAML and JSON formats are
-        supported.
-    field : str, optional
-        Dot-separated key path to the sub-field that contains the backbone
-        configuration (e.g. ``"model"``). If ``None``, the root of the
-        file is used.
-
-    Returns
-    -------
-    BackboneConfig
-        A validated backbone configuration object (currently always a
-        ``UNetBackboneConfig`` instance).
-
-    Raises
-    ------
-    FileNotFoundError
-        If ``path`` does not exist.
-    ValidationError
-        If the loaded data does not conform to a known ``BackboneConfig``
-        schema.
-    """
-    return load_config(
-        path,
-        schema=TypeAdapter(BackboneConfig),
-        field=field,
-    )
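The `_pad_adjust` helper above relies on Python's modulo semantics: `-x % factor` is the smallest non-negative amount that brings `x` up to a multiple of `factor`. A self-contained sketch of just that arithmetic (the helper name here is illustrative):

```python
def pad_amounts(h, w, factor=32):
    """Padding (rows, cols) needed so both dims divide evenly by factor."""
    # -x % factor is 0 when x is already a multiple of factor, and
    # otherwise the shortfall to the next multiple.
    return -h % factor, -w % factor


# 100 needs 28 extra columns to reach 128, the next multiple of 32;
# dimensions that already divide evenly need no padding at all.
h_pad, w_pad = pad_amounts(128, 100)  # → (0, 28)
```

Because the same `h_pad`/`w_pad` values are returned to the caller, `_restore_pad` can later slice them back off and restore the original spatial dimensions exactly.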
@@ -1,63 +1,42 @@
-"""Reusable convolutional building blocks for BatDetect2 models.
+"""Commonly used neural network building blocks for BatDetect2 models.
 
-This module provides a collection of ``torch.nn.Module`` subclasses that form
-the fundamental building blocks for the encoder-decoder backbone used in
-BatDetect2. All blocks follow a consistent interface: they store
-``in_channels`` and ``out_channels`` as attributes and implement a
-``get_output_height`` method that reports how a given input height maps to an
-output height (e.g., halved by downsampling blocks, doubled by upsampling
-blocks).
+This module provides various reusable `torch.nn.Module` subclasses that form
+the fundamental building blocks for constructing convolutional neural network
+architectures, particularly encoder-decoder backbones used in BatDetect2.
 
-Available block families
-------------------------
-Standard blocks
-    ``ConvBlock`` – convolution + batch normalisation + ReLU, no change in
-    spatial resolution.
+It includes standard components like basic convolutional blocks (`ConvBlock`),
+blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with
+upsampling (`StandardConvUpBlock`).
 
-Downsampling blocks
-    ``StandardConvDownBlock`` – convolution then 2×2 max-pooling, halves H
-    and W.
-    ``FreqCoordConvDownBlock`` – like ``StandardConvDownBlock`` but prepends
-    a normalised frequency-coordinate channel before the convolution
-    (CoordConv concept), helping filters learn frequency-position-dependent
-    patterns.
+Additionally, it features specialized layers investigated in BatDetect2
+research:
 
-Upsampling blocks
-    ``StandardConvUpBlock`` – bilinear interpolation then convolution,
-    doubles H and W.
-    ``FreqCoordConvUpBlock`` – like ``StandardConvUpBlock`` but prepends a
-    frequency-coordinate channel after upsampling.
+- `SelfAttention`: Applies self-attention along the time dimension, enabling
+  the model to weigh information across the entire temporal context, often
+  used in the bottleneck of an encoder-decoder.
+- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv"
+  concept by concatenating normalized frequency coordinate information as an
+  extra channel to the input of convolutional layers. This explicitly provides
+  spatial frequency information to filters, potentially enabling them to learn
+  frequency-dependent patterns more effectively.
 
-Bottleneck blocks
-    ``VerticalConv`` – 1-D convolution whose kernel spans the entire
-    frequency axis, collapsing H to 1 whilst preserving W.
-    ``SelfAttention`` – scaled dot-product self-attention along the time
-    axis; typically follows a ``VerticalConv``.
+These blocks can be utilized directly in custom PyTorch model definitions or
+assembled into larger architectures.
 
-Group block
-    ``LayerGroup`` – chains several blocks sequentially into one unit,
-    useful when a single encoder or decoder "stage" requires more than one
-    operation.
+A unified factory function `build_layer_from_config` allows creating instances
+of these blocks based on configuration objects.
 
-Factory function
-----------------
-``build_layer`` creates any of the above blocks from the matching
-configuration object (one of the ``*Config`` classes exported here), using
-a discriminated-union ``name`` field to dispatch to the correct class.
 """
 
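The CoordConv idea described in the docstring diff above amounts to appending one extra channel whose value encodes each row's normalized frequency position. A toy, list-based sketch (the helper name and shapes are illustrative, not BatDetect2's actual implementation, which concatenates tensors inside the block's forward pass):

```python
def add_freq_coord_channel(x):
    """x: nested lists shaped (C, H, W); returns (C + 1, H, W)."""
    h, w = len(x[0]), len(x[0][0])
    # Each row of the new channel holds its normalized frequency
    # position in [0, 1], constant across the time (width) axis.
    coord = [[row / (h - 1)] * w for row in range(h)]
    return x + [coord]


feat = [[[0.0] * 4 for _ in range(3)]]  # one channel, H=3, W=4
out = add_freq_coord_channel(feat)
# out has 2 channels; the coordinate channel runs 0.0, 0.5, 1.0 down
# the frequency axis.
```

A convolution applied after this concatenation can condition its response on absolute frequency position, which is useful for spectrograms where the same shape means different things at different frequencies.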
from typing import Annotated, Literal
|
from typing import Annotated, List, Literal, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BlockImportConfig",
|
|
||||||
"ConvBlock",
|
"ConvBlock",
|
||||||
"LayerGroupConfig",
|
"LayerGroupConfig",
|
||||||
"VerticalConv",
|
"VerticalConv",
|
||||||
@@ -72,125 +51,63 @@ __all__ = [
     "FreqCoordConvUpConfig",
     "StandardConvUpConfig",
     "LayerConfig",
-    "build_layer",
+    "build_layer_from_config",
 ]
 
 
-class Block(nn.Module):
-    """Abstract base class for all BatDetect2 building blocks.
-
-    Subclasses must set ``in_channels`` and ``out_channels`` as integer
-    attributes so that factory functions can wire blocks together without
-    inspecting configuration objects at runtime. They may also override
-    ``get_output_height`` when the block changes the height dimension (e.g.
-    downsampling or upsampling blocks).
-
-    Attributes
-    ----------
-    in_channels : int
-        Number of channels expected in the input tensor.
-    out_channels : int
-        Number of channels produced in the output tensor.
-    """
-
-    in_channels: int
-    out_channels: int
-
-    def get_output_height(self, input_height: int) -> int:
-        """Return the output height for a given input height.
-
-        The default implementation returns ``input_height`` unchanged,
-        which is correct for blocks that do not alter spatial resolution.
-        Override this in downsampling (returns ``input_height // 2``) or
-        upsampling (returns ``input_height * 2``) subclasses.
-
-        Parameters
-        ----------
-        input_height : int
-            Height (number of frequency bins) of the input feature map.
-
-        Returns
-        -------
-        int
-            Height of the output feature map.
-        """
-        return input_height
-
-
-block_registry: Registry[Block, [int, int]] = Registry("block")
-
-
-@add_import_config(block_registry)
-class BlockImportConfig(ImportConfig):
-    """Use any callable as a model block.
-
-    Set ``name="import"`` and provide a ``target`` pointing to any
-    callable to use it instead of a built-in option.
-    """
-
-    name: Literal["import"] = "import"
-
-
 class SelfAttentionConfig(BaseConfig):
-    """Configuration for a ``SelfAttention`` block.
-
-    Attributes
-    ----------
-    name : str
-        Discriminator field; always ``"SelfAttention"``.
-    attention_channels : int
-        Dimensionality of the query, key, and value projections.
-    temperature : float
-        Scaling factor applied to the weighted values before the final
-        linear projection. Defaults to ``1``.
-    """
-
     name: Literal["SelfAttention"] = "SelfAttention"
     attention_channels: int
     temperature: float = 1
 
 
-class SelfAttention(Block):
-    """Self-attention block operating along the time axis.
+class SelfAttention(nn.Module):
+    """Self-Attention mechanism operating along the time dimension.
 
-    Applies a scaled dot-product self-attention mechanism across the time
-    steps of an input feature map. Before attention is computed the height
-    dimension (frequency axis) is expected to have been reduced to 1, e.g.
-    by a preceding ``VerticalConv`` layer.
+    This module implements a scaled dot-product self-attention mechanism,
+    specifically designed here to operate across the time steps of an input
+    feature map, typically after spatial dimensions (like frequency) have been
+    condensed or squeezed.
 
-    For each time step the block computes query, key, and value projections
-    with learned linear weights, then calculates attention weights from the
-    query–key dot products divided by ``temperature × attention_channels``.
-    The weighted sum of values is projected back to ``in_channels`` via a
-    final linear layer, and the height dimension is restored so that the
-    output shape matches the input shape.
+    By calculating attention weights between all pairs of time steps, it allows
+    the model to capture long-range temporal dependencies and focus on relevant
+    parts of the sequence. It's often employed in the bottleneck or
+    intermediate layers of an encoder-decoder architecture to integrate global
+    temporal context.
+
+    The implementation uses linear projections to create query, key, and value
+    representations, computes scaled dot-product attention scores, applies
+    softmax, and produces an output by weighting the values according to the
+    attention scores, followed by a final linear projection. Positional encoding
+    is not explicitly included in this block.
 
     Parameters
     ----------
     in_channels : int
-        Number of input channels (features per time step). The output will
-        also have ``in_channels`` channels.
+        Number of input channels (features per time step after spatial squeeze).
     attention_channels : int
-        Dimensionality of the query, key, and value projections.
+        Number of channels for the query, key, and value projections. Also the
+        dimension of the output projection's input.
     temperature : float, default=1.0
-        Divisor applied together with ``attention_channels`` when scaling
-        the dot-product scores before softmax. Larger values produce softer
-        (more uniform) attention distributions.
+        Scaling factor applied *before* the final projection layer. Can be used
+        to adjust the sharpness or focus of the attention mechanism, although
+        scaling within the softmax (dividing by sqrt(dim)) is more common for
+        standard transformers. Here it scales the weighted values.
 
     Attributes
     ----------
     key_fun : nn.Linear
-        Linear projection for keys.
+        Linear layer for key projection.
     value_fun : nn.Linear
-        Linear projection for values.
+        Linear layer for value projection.
     query_fun : nn.Linear
-        Linear projection for queries.
+        Linear layer for query projection.
     pro_fun : nn.Linear
-        Final linear projection applied to the attended values.
+        Final linear projection layer applied after attention weighting.
     temperature : float
-        Scaling divisor used when computing attention scores.
+        Scaling factor applied before final projection.
     att_dim : int
-        Dimensionality of the attention space (``attention_channels``).
+        Dimensionality of the attention space (`attention_channels`).
     """
 
     def __init__(
@@ -200,13 +117,10 @@ class SelfAttention(Block):
         temperature: float = 1.0,
     ):
         super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = in_channels
 
         # Note, does not encode position information (absolute or relative)
         self.temperature = temperature
         self.att_dim = attention_channels
-        self.output_channels = in_channels
 
         self.key_fun = nn.Linear(in_channels, attention_channels)
         self.value_fun = nn.Linear(in_channels, attention_channels)
@@ -219,16 +133,20 @@ class SelfAttention(Block):
         Parameters
         ----------
         x : torch.Tensor
-            Input tensor with shape ``(B, C, 1, W)``. The height dimension
-            must be 1 (i.e. the frequency axis should already have been
-            collapsed by a preceding ``VerticalConv`` layer).
+            Input tensor, expected shape `(B, C, H, W)`, where H is typically
+            squeezed (e.g., H=1 after a `VerticalConv` or pooling) before
+            applying attention along the W (time) dimension.
 
         Returns
         -------
         torch.Tensor
-            Output tensor with the same shape ``(B, C, 1, W)`` as the
-            input, with each time step updated by attended context from all
-            other time steps.
+            Output tensor of the same shape as the input `(B, C, H, W)`, where
+            attention has been applied across the W dimension.
+
+        Raises
+        ------
+        RuntimeError
+            If input tensor dimensions are incompatible with operations.
         """
 
         x = x.squeeze(2).permute(0, 2, 1)
@@ -257,22 +175,6 @@ class SelfAttention(Block):
         return op
 
     def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
-        """Return the softmax attention weight matrix.
-
-        Useful for visualising which time steps attend to which others.
-
-        Parameters
-        ----------
-        x : torch.Tensor
-            Input tensor with shape ``(B, C, 1, W)``.
-
-        Returns
-        -------
-        torch.Tensor
-            Attention weight matrix with shape ``(B, W, W)``. Entry
-            ``[b, i, j]`` is the attention weight that time step ``i``
-            assigns to time step ``j`` in batch item ``b``.
-        """
         x = x.squeeze(2).permute(0, 2, 1)
 
         key = torch.matmul(
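The attention computation both versions of the docstrings describe (linear Q/K/V projections, scaled dot-product scores, softmax along time, weighted sum of values) can be sketched without torch. This is an illustrative toy, not the module's implementation, and it omits the final output projection:

```python
import math


def transpose(m):
    return [list(row) for row in zip(*m)]


def matmul(a, b):
    # a: n x k, b: k x m -> n x m
    return [[sum(x * y for x, y in zip(row, col)) for col in transpose(b)]
            for row in a]


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x, wq, wk, wv, temperature=1.0):
    # x: T x C feature rows; wq/wk/wv: C x D projection matrices.
    q, k, v = matmul(x, wq), matmul(x, wk), matmul(x, wv)
    d = len(wq[0])  # attention_channels
    scores = matmul(q, transpose(k))  # T x T similarity matrix
    weights = [softmax([s / (temperature * d) for s in row]) for row in scores]
    return matmul(weights, v)  # each output row is a convex mix of value rows


# Identity projections on two one-hot time steps: each output row is a
# weighted average of the value rows, slightly favouring its own step.
eye = [[1.0, 0.0], [0.0, 1.0]]
out = self_attention(eye, eye, eye, eye)
```

Each row of the softmax weights sums to one, so the output stays in the convex hull of the value rows regardless of the `temperature` chosen.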
@@ -288,19 +190,6 @@ class SelfAttention(Block):
         att_weights = F.softmax(kk_qq, 1)
         return att_weights
 
-    @block_registry.register(SelfAttentionConfig)
-    @staticmethod
-    def from_config(
-        config: SelfAttentionConfig,
-        input_channels: int,
-        input_height: int,
-    ) -> "SelfAttention":
-        return SelfAttention(
-            in_channels=input_channels,
-            attention_channels=config.attention_channels,
-            temperature=config.temperature,
-        )
-
 
 class ConvConfig(BaseConfig):
     """Configuration for a basic ConvBlock."""
@@ -318,7 +207,7 @@ class ConvConfig(BaseConfig):
     """Padding size."""
 
 
-class ConvBlock(Block):
+class ConvBlock(nn.Module):
     """Basic Convolutional Block.
 
     A standard building block consisting of a 2D convolution, followed by
@@ -346,8 +235,6 @@ class ConvBlock(Block):
         pad_size: int = 1,
     ):
         super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
         self.conv = nn.Conv2d(
             in_channels,
             out_channels,
@@ -371,37 +258,8 @@ class ConvBlock(Block):
         """
         return F.relu_(self.batch_norm(self.conv(x)))
 
-    @block_registry.register(ConvConfig)
-    @staticmethod
-    def from_config(
-        config: ConvConfig,
-        input_channels: int,
-        input_height: int,
-    ):
-        return ConvBlock(
-            in_channels=input_channels,
-            out_channels=config.out_channels,
-            kernel_size=config.kernel_size,
-            pad_size=config.pad_size,
-        )
-
-
-class VerticalConvConfig(BaseConfig):
-    """Configuration for a ``VerticalConv`` block.
-
-    Attributes
-    ----------
-    name : str
-        Discriminator field; always ``"VerticalConv"``.
-    channels : int
-        Number of output channels produced by the vertical convolution.
-    """
-
-    name: Literal["VerticalConv"] = "VerticalConv"
-    channels: int
-
-
-class VerticalConv(Block):
+class VerticalConv(nn.Module):
     """Convolutional layer that aggregates features across the entire height.
 
     Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
@@ -430,8 +288,6 @@ class VerticalConv(Block):
         input_height: int,
     ):
         super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
         self.conv = nn.Conv2d(
             in_channels=in_channels,
             out_channels=out_channels,
@@ -456,19 +312,6 @@ class VerticalConv(Block):
         """
         return F.relu_(self.bn(self.conv(x)))
 
-    @block_registry.register(VerticalConvConfig)
-    @staticmethod
-    def from_config(
-        config: VerticalConvConfig,
-        input_channels: int,
-        input_height: int,
-    ):
-        return VerticalConv(
-            in_channels=input_channels,
-            out_channels=config.channels,
-            input_height=input_height,
-        )
-
 
 class FreqCoordConvDownConfig(BaseConfig):
     """Configuration for a FreqCoordConvDownBlock."""
@@ -486,7 +329,7 @@ class FreqCoordConvDownConfig(BaseConfig):
     """Padding size."""
 
 
-class FreqCoordConvDownBlock(Block):
+class FreqCoordConvDownBlock(nn.Module):
     """Downsampling Conv Block incorporating Frequency Coordinate features.
 
     This block implements a downsampling step (Conv2d + MaxPool2d) commonly
@@ -525,8 +368,6 @@ class FreqCoordConvDownBlock(Block):
         pad_size: int = 1,
     ):
         super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
 
         self.coords = nn.Parameter(
             torch.linspace(-1, 1, input_height)[None, None, ..., None],
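The `self.coords` parameter in the context above is a normalized frequency ramp, `torch.linspace(-1, 1, input_height)`, reshaped for broadcasting. A plain-Python sketch of the coordinate plane that CoordConv-style blocks concatenate as an extra channel (`freq_coord_channel` is an illustrative helper name, not part of the library):

```python
def freq_coord_channel(height, width):
    """Build an H x W plane where every row holds its normalized frequency
    coordinate in [-1, 1], repeated across the width (time) axis. This is
    a plain-Python analogue of broadcasting torch.linspace(-1, 1, height)."""
    if height == 1:
        col = [0.0]
    else:
        step = 2.0 / (height - 1)
        col = [-1.0 + i * step for i in range(height)]
    return [[col[i]] * width for i in range(height)]


plane = freq_coord_channel(5, 3)
# plane[0] is the lowest-frequency row (-1.0), plane[-1] the highest (1.0).
```

Concatenating such a plane to the input gives every filter explicit access to where it sits along the frequency axis.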
@@ -561,24 +402,6 @@ class FreqCoordConvDownBlock(Block):
         x = F.relu(self.batch_norm(x), inplace=True)
         return x
 
-    def get_output_height(self, input_height: int) -> int:
-        return input_height // 2
-
-    @block_registry.register(FreqCoordConvDownConfig)
-    @staticmethod
-    def from_config(
-        config: FreqCoordConvDownConfig,
-        input_channels: int,
-        input_height: int,
-    ):
-        return FreqCoordConvDownBlock(
-            in_channels=input_channels,
-            out_channels=config.out_channels,
-            input_height=input_height,
-            kernel_size=config.kernel_size,
-            pad_size=config.pad_size,
-        )
-
 
 class StandardConvDownConfig(BaseConfig):
     """Configuration for a StandardConvDownBlock."""
@@ -596,7 +419,7 @@ class StandardConvDownConfig(BaseConfig):
     """Padding size."""
 
 
-class StandardConvDownBlock(Block):
+class StandardConvDownBlock(nn.Module):
     """Standard Downsampling Convolutional Block.
 
     A basic downsampling block consisting of a 2D convolution, followed by
@@ -624,8 +447,6 @@ class StandardConvDownBlock(Block):
         pad_size: int = 1,
     ):
         super(StandardConvDownBlock, self).__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
         self.conv = nn.Conv2d(
             in_channels,
             out_channels,
@@ -651,23 +472,6 @@ class StandardConvDownBlock(Block):
         x = F.max_pool2d(self.conv(x), 2, 2)
         return F.relu(self.batch_norm(x), inplace=True)
 
-    def get_output_height(self, input_height: int) -> int:
-        return input_height // 2
-
-    @block_registry.register(StandardConvDownConfig)
-    @staticmethod
-    def from_config(
-        config: StandardConvDownConfig,
-        input_channels: int,
-        input_height: int,
-    ):
-        return StandardConvDownBlock(
-            in_channels=input_channels,
-            out_channels=config.out_channels,
-            kernel_size=config.kernel_size,
-            pad_size=config.pad_size,
-        )
-
 
 class FreqCoordConvUpConfig(BaseConfig):
     """Configuration for a FreqCoordConvUpBlock."""
@@ -684,14 +488,8 @@ class FreqCoordConvUpConfig(BaseConfig):
     pad_size: int = 1
     """Padding size."""
 
-    up_mode: str = "bilinear"
-    """Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
-
-    up_scale: tuple[int, int] = (2, 2)
-    """Scaling factor for height and width during upsampling."""
-
 
-class FreqCoordConvUpBlock(Block):
+class FreqCoordConvUpBlock(nn.Module):
     """Upsampling Conv Block incorporating Frequency Coordinate features.
 
     This block implements an upsampling step followed by a convolution,
@@ -706,22 +504,22 @@ class FreqCoordConvUpBlock(Block):
 
     Parameters
     ----------
-    in_channels
+    in_channels : int
         Number of channels in the input tensor (before upsampling).
-    out_channels
+    out_channels : int
         Number of output channels after the convolution.
-    input_height
+    input_height : int
         Height (H dimension, frequency bins) of the tensor *before* upsampling.
         Used to calculate the height for coordinate feature generation after
         upsampling.
-    kernel_size
+    kernel_size : int, default=3
         Size of the square convolutional kernel.
-    pad_size
+    pad_size : int, default=1
         Padding added before convolution.
-    up_mode
+    up_mode : str, default="bilinear"
         Interpolation mode for upsampling (e.g., "nearest", "bilinear",
         "bicubic").
-    up_scale
+    up_scale : Tuple[int, int], default=(2, 2)
         Scaling factor for height and width during upsampling
         (typically (2, 2)).
     """
@@ -734,11 +532,9 @@ class FreqCoordConvUpBlock(Block):
         kernel_size: int = 3,
         pad_size: int = 1,
         up_mode: str = "bilinear",
-        up_scale: tuple[int, int] = (2, 2),
+        up_scale: Tuple[int, int] = (2, 2),
     ):
         super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
 
         self.up_scale = up_scale
         self.up_mode = up_mode
@@ -785,26 +581,6 @@ class FreqCoordConvUpBlock(Block):
         op = F.relu(self.batch_norm(op), inplace=True)
         return op
 
-    def get_output_height(self, input_height: int) -> int:
-        return input_height * 2
-
-    @block_registry.register(FreqCoordConvUpConfig)
-    @staticmethod
-    def from_config(
-        config: FreqCoordConvUpConfig,
-        input_channels: int,
-        input_height: int,
-    ):
-        return FreqCoordConvUpBlock(
-            in_channels=input_channels,
-            out_channels=config.out_channels,
-            input_height=input_height,
-            kernel_size=config.kernel_size,
-            pad_size=config.pad_size,
-            up_mode=config.up_mode,
-            up_scale=config.up_scale,
-        )
-
 
 class StandardConvUpConfig(BaseConfig):
     """Configuration for a StandardConvUpBlock."""
@@ -821,14 +597,8 @@ class StandardConvUpConfig(BaseConfig):
     pad_size: int = 1
     """Padding size."""
 
-    up_mode: str = "bilinear"
-    """Interpolation mode for upsampling (e.g., "nearest", "bilinear")."""
-
-    up_scale: tuple[int, int] = (2, 2)
-    """Scaling factor for height and width during upsampling."""
-
 
-class StandardConvUpBlock(Block):
+class StandardConvUpBlock(nn.Module):
     """Standard Upsampling Convolutional Block.
 
     A basic upsampling block used in CNN decoders. It first upsamples the input
@@ -839,17 +609,17 @@ class StandardConvUpBlock(Block):
 
     Parameters
     ----------
-    in_channels
+    in_channels : int
         Number of channels in the input tensor (before upsampling).
-    out_channels
+    out_channels : int
         Number of output channels after the convolution.
-    kernel_size
+    kernel_size : int, default=3
         Size of the square convolutional kernel.
-    pad_size
+    pad_size : int, default=1
         Padding added before convolution.
-    up_mode
+    up_mode : str, default="bilinear"
         Interpolation mode for upsampling (e.g., "nearest", "bilinear").
-    up_scale
+    up_scale : Tuple[int, int], default=(2, 2)
         Scaling factor for height and width during upsampling.
     """
 
@@ -860,11 +630,9 @@ class StandardConvUpBlock(Block):
         kernel_size: int = 3,
         pad_size: int = 1,
         up_mode: str = "bilinear",
-        up_scale: tuple[int, int] = (2, 2),
+        up_scale: Tuple[int, int] = (2, 2),
     ):
         super(StandardConvUpBlock, self).__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
         self.up_scale = up_scale
         self.up_mode = up_mode
         self.conv = nn.Conv2d(
@ -901,195 +669,155 @@ class StandardConvUpBlock(Block):
|
|||||||
op = F.relu(self.batch_norm(op), inplace=True)
|
op = F.relu(self.batch_norm(op), inplace=True)
|
||||||
return op
|
return op
|
||||||
|
|
||||||
def get_output_height(self, input_height: int) -> int:
|
|
||||||
return input_height * 2
|
|
||||||
|
|
||||||
@block_registry.register(StandardConvUpConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(
|
|
||||||
config: StandardConvUpConfig,
|
|
||||||
input_channels: int,
|
|
||||||
input_height: int,
|
|
||||||
):
|
|
||||||
return StandardConvUpBlock(
|
|
||||||
in_channels=input_channels,
|
|
||||||
out_channels=config.out_channels,
|
|
||||||
kernel_size=config.kernel_size,
|
|
||||||
pad_size=config.pad_size,
|
|
||||||
up_mode=config.up_mode,
|
|
||||||
up_scale=config.up_scale,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LayerGroupConfig(BaseConfig):
|
|
||||||
"""Configuration for a ``LayerGroup`` — a sequential chain of blocks.
|
|
||||||
|
|
||||||
Use this when a single encoder or decoder stage needs more than one
|
|
||||||
block. The blocks are executed in the order they appear in ``layers``,
|
|
||||||
with channel counts and heights propagated automatically.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
name : str
|
|
||||||
Discriminator field; always ``"LayerGroup"``.
|
|
||||||
layers : List[LayerConfig]
|
|
||||||
Ordered list of block configurations to chain together.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: Literal["LayerGroup"] = "LayerGroup"
|
|
||||||
layers: list["LayerConfig"]
|
|
||||||
|
|
||||||
|
|
||||||
LayerConfig = Annotated[
|
LayerConfig = Annotated[
|
||||||
ConvConfig
|
Union[
|
||||||
| FreqCoordConvDownConfig
|
ConvConfig,
|
||||||
| StandardConvDownConfig
|
FreqCoordConvDownConfig,
|
||||||
| FreqCoordConvUpConfig
|
StandardConvDownConfig,
|
||||||
| StandardConvUpConfig
|
FreqCoordConvUpConfig,
|
||||||
| SelfAttentionConfig
|
StandardConvUpConfig,
|
||||||
| LayerGroupConfig,
|
SelfAttentionConfig,
|
||||||
|
"LayerGroupConfig",
|
||||||
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configuration models."""
|
"""Type alias for the discriminated union of block configuration models."""
|
||||||
|
|
||||||
|
|
||||||
class LayerGroup(nn.Module):
|
class LayerGroupConfig(BaseConfig):
|
||||||
"""Sequential chain of blocks that acts as a single composite block.
|
name: Literal["LayerGroup"] = "LayerGroup"
|
||||||
|
layers: List[LayerConfig]
|
||||||
Wraps multiple ``Block`` instances in an ``nn.Sequential`` container,
|
|
||||||
exposing the same ``in_channels``, ``out_channels``, and
|
|
||||||
``get_output_height`` interface as a regular ``Block`` so it can be
|
|
||||||
used transparently wherever a single block is expected.
|
|
||||||
|
|
||||||
Instances are typically constructed by ``build_layer`` when given a
|
|
||||||
``LayerGroupConfig``; you rarely need to create them directly.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
layers : list[Block]
|
|
||||||
Pre-built block instances to chain, in execution order.
|
|
||||||
input_height : int
|
|
||||||
Height of the tensor entering the first block.
|
|
||||||
input_channels : int
|
|
||||||
Number of channels in the tensor entering the first block.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
in_channels : int
|
|
||||||
Number of input channels (taken from the first block).
|
|
||||||
out_channels : int
|
|
||||||
Number of output channels (taken from the last block).
|
|
||||||
layers : nn.Sequential
|
|
||||||
The wrapped sequence of block modules.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
layers: list[Block],
|
|
||||||
input_height: int,
|
|
||||||
input_channels: int,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.in_channels = input_channels
|
|
||||||
self.out_channels = (
|
|
||||||
layers[-1].out_channels if layers else input_channels
|
|
||||||
)
|
|
||||||
self.layers = nn.Sequential(*layers)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Pass input through all blocks in sequence.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
x : torch.Tensor
|
|
||||||
Input feature map, shape ``(B, C_in, H, W)``.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
torch.Tensor
|
|
||||||
Output feature map after all blocks have been applied.
|
|
||||||
"""
|
|
||||||
return self.layers(x)
|
|
||||||
|
|
||||||
def get_output_height(self, input_height: int) -> int:
|
|
||||||
"""Compute the output height by propagating through all blocks.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
input_height : int
|
|
||||||
Height of the input feature map.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
int
|
|
||||||
Height after all blocks in the group have been applied.
|
|
||||||
"""
|
|
||||||
for block in self.layers:
|
|
||||||
input_height = block.get_output_height(input_height) # type: ignore
|
|
||||||
return input_height
|
|
||||||
|
|
||||||
    @block_registry.register(LayerGroupConfig)
    @staticmethod
    def from_config(
        config: LayerGroupConfig,
        input_channels: int,
        input_height: int,
    ):
        """Build a ``LayerGroup`` from its configuration object."""
        layers = []

        for layer_config in config.layers:
            layer = build_layer(
                input_height=input_height,
                in_channels=input_channels,
                config=layer_config,
            )
            layers.append(layer)
            input_height = layer.get_output_height(input_height)
            input_channels = layer.out_channels

        return LayerGroup(
            layers=layers,
            input_height=input_height,
            input_channels=input_channels,
        )
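The channel and height bookkeeping that `LayerGroup.from_config` performs can be sketched with plain Python. `FakeBlock` here is a hypothetical stand-in; the real blocks are `nn.Module`s that expose `out_channels` and `get_output_height`:

```python
# Stdlib-only sketch of the shape bookkeeping done by LayerGroup.from_config:
# each block reports its output channels and how it transforms the height,
# and the builder threads those values through the sequence.

class FakeBlock:
    def __init__(self, out_channels, height_factor):
        self.out_channels = out_channels
        self.height_factor = height_factor  # e.g. 0.5 for a downsampling block

    def get_output_height(self, input_height):
        return int(input_height * self.height_factor)


def propagate_shapes(blocks, in_channels, input_height):
    """Thread (channels, height) through a sequence of blocks."""
    for block in blocks:
        input_height = block.get_output_height(input_height)
        in_channels = block.out_channels
    return in_channels, input_height


blocks = [FakeBlock(32, 0.5), FakeBlock(64, 0.5), FakeBlock(64, 1.0)]
channels, height = propagate_shapes(blocks, in_channels=1, input_height=128)
# height: 128 -> 64 -> 32 -> 32; channels: 1 -> 32 -> 64 -> 64
```

The same propagation is what lets `DecoderConfig` and `EncoderConfig` omit `in_channels` for every layer: each layer's input dimensions are inferred from the previous layer's output.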
def build_layer(
    input_height: int,
    in_channels: int,
    config: LayerConfig,
) -> Block:
    """Build a block from its configuration object.

    Looks up the block class corresponding to ``config.name`` in the
    internal block registry and instantiates it with the given input
    dimensions. This is the standard way to construct blocks when
    assembling an encoder or decoder from a configuration file.

    Parameters
    ----------
    input_height : int
        Height (number of frequency bins) of the input tensor to this
        block. Required for blocks whose kernel size depends on the input
        height (e.g. ``VerticalConv``) and for coordinate-aware blocks.
    in_channels : int
        Number of channels in the input tensor to this block.
    config : LayerConfig
        A configuration object for the desired block type. The ``name``
        field selects the block class; remaining fields supply its
        parameters.

    Returns
    -------
    Block
        An initialised block module ready to be added to an
        ``nn.Sequential`` or ``nn.ModuleList``.

    Raises
    ------
    KeyError
        If ``config.name`` does not correspond to a registered block type.
    ValueError
        If the configuration parameters are invalid for the chosen block.
    """
    return block_registry.build(config, in_channels, input_height)
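The registry lookup that `build_layer` relies on can be sketched in a few lines. This is a simplified, hypothetical `BlockRegistry` (the real `block_registry` lives in `batdetect2.models.blocks` and may differ in detail); factories are keyed by config class and `build` dispatches on `type(config)`:

```python
from dataclasses import dataclass

# Simplified sketch of a config-driven block registry. Hypothetical names;
# this mimics the register/build pattern, not the actual implementation.

class BlockRegistry:
    def __init__(self):
        self._factories = {}

    def register(self, config_cls):
        def decorator(factory):
            self._factories[config_cls] = factory
            return factory
        return decorator

    def build(self, config, in_channels, input_height):
        try:
            factory = self._factories[type(config)]
        except KeyError:
            raise KeyError(f"Unknown block config: {type(config).__name__}")
        return factory(config, in_channels, input_height)


registry = BlockRegistry()

@dataclass
class ToyConvConfig:
    out_channels: int

@registry.register(ToyConvConfig)
def build_toy_conv(config, in_channels, input_height):
    # A real factory would return an nn.Module; a tuple suffices here.
    return ("conv", in_channels, config.out_channels, input_height)


block = registry.build(ToyConvConfig(out_channels=32), in_channels=1, input_height=128)
# block == ("conv", 1, 32, 128)
```

Compared with a long `if config.name == ...` chain, the registry keeps block definitions and their construction logic in one place, so adding a block type does not require touching the factory function.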
@@ -1,24 +1,20 @@
"""Bottleneck component for encoder-decoder network architectures.

The bottleneck sits between the encoder (downsampling path) and the decoder
(upsampling path) and processes the lowest-resolution, highest-channel feature
map produced by the encoder.

This module provides:

- ``BottleneckConfig`` – configuration dataclass describing the number of
  internal channels and an optional sequence of additional layers (currently
  only ``SelfAttention`` is supported).
- ``Bottleneck`` – the ``torch.nn.Module`` implementation. It first applies a
  ``VerticalConv`` to collapse the frequency axis to a single bin, optionally
  runs one or more additional layers (e.g. self-attention along the time axis),
  then repeats the output along the height dimension to restore the original
  frequency resolution before passing features to the decoder.
- ``build_bottleneck`` – factory function that constructs a ``Bottleneck``
  instance from a ``BottleneckConfig`` and the encoder's output dimensions.
"""

from typing import Annotated, List

import torch
from pydantic import Field
@@ -26,12 +22,10 @@ from torch import nn
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
    Block,
    SelfAttentionConfig,
    VerticalConv,
    build_layer,
)
from batdetect2.models.types import BottleneckProtocol

__all__ = [
    "BottleneckConfig",
@@ -40,52 +34,43 @@ __all__ = [
]


class Bottleneck(Block):
    """Bottleneck module for encoder-decoder architectures.

    Processes the lowest-resolution feature map that links the encoder and
    decoder. The sequence of operations is:

    1. ``VerticalConv`` – collapses the frequency axis (height) to a single
       bin by applying a convolution whose kernel spans the full height.
    2. Optional additional layers (e.g. ``SelfAttention``) – applied while
       the feature map has height 1, so they operate purely along the time
       axis.
    3. Height restoration – the single-bin output is repeated along the
       height axis to restore the original frequency resolution, producing
       a tensor that the decoder can accept.

    Parameters
    ----------
    input_height : int
        Height (number of frequency bins) of the input tensor. Must be
        positive.
    in_channels : int
        Number of channels in the input tensor from the encoder. Must be
        positive.
    out_channels : int
        Number of output channels after the bottleneck. Must be positive.
    bottleneck_channels : int, optional
        Number of internal channels used by the ``VerticalConv`` layer.
        Defaults to ``out_channels`` if not provided.
    layers : List[torch.nn.Module], optional
        Additional modules (e.g. ``SelfAttention``) to apply after the
        ``VerticalConv`` and before height restoration.

    Attributes
    ----------
    in_channels : int
        Number of input channels accepted by the bottleneck.
    out_channels : int
        Number of output channels produced by the bottleneck.
    input_height : int
        Expected height of the input tensor.
    bottleneck_channels : int
        Number of channels used internally by the vertical convolution.
    conv_vert : VerticalConv
        The vertical convolution layer.
    layers : nn.ModuleList
        Additional layers applied after the vertical convolution.
    """

    def __init__(
@@ -93,31 +78,14 @@ class Bottleneck(Block):
        input_height: int,
        in_channels: int,
        out_channels: int,
        bottleneck_channels: int | None = None,
        layers: List[torch.nn.Module] | None = None,
    ) -> None:
        """Initialise the Bottleneck layer.

        Parameters
        ----------
        input_height : int
            Height (number of frequency bins) of the input tensor.
        in_channels : int
            Number of channels in the input tensor.
        out_channels : int
            Number of channels in the output tensor.
        bottleneck_channels : int, optional
            Number of internal channels for the ``VerticalConv``. Defaults
            to ``out_channels``.
        layers : List[torch.nn.Module], optional
            Additional modules applied after the ``VerticalConv``, such as
            a ``SelfAttention`` block.
        """
        super().__init__()
        self.in_channels = in_channels
        self.input_height = input_height
        self.out_channels = out_channels

        self.bottleneck_channels = (
            bottleneck_channels
            if bottleneck_channels is not None
@@ -132,24 +100,23 @@ class Bottleneck(Block):
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Process the encoder's bottleneck features.

        Applies vertical convolution, optional additional layers, then
        restores the height dimension by repetition.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor from the encoder, shape
            ``(B, C_in, H_in, W)``. ``C_in`` must match
            ``self.in_channels`` and ``H_in`` must match
            ``self.input_height``.

        Returns
        -------
        torch.Tensor
            Output tensor with shape ``(B, C_out, H_in, W)``. The height
            ``H_in`` is restored by repeating the single-bin result.
        """
        x = self.conv_vert(x)
@@ -160,29 +127,37 @@ class Bottleneck(Block):


BottleneckLayerConfig = Annotated[
    SelfAttentionConfig,
    Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in the Bottleneck."""


class BottleneckConfig(BaseConfig):
    """Configuration for the bottleneck component.

    Attributes
    ----------
    channels : int
        Number of output channels produced by the bottleneck. This value
        is also used as the dimensionality of any optional layers (e.g.
        self-attention). Must be positive.
    layers : List[BottleneckLayerConfig]
        Ordered list of additional block configurations to apply after the
        initial ``VerticalConv``. Currently only ``SelfAttentionConfig`` is
        supported. Defaults to an empty list (no extra layers).
    """

    channels: int
    layers: List[BottleneckLayerConfig] = Field(default_factory=list)


DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
@@ -196,39 +171,32 @@ DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
def build_bottleneck(
    input_height: int,
    in_channels: int,
    config: BottleneckConfig | None = None,
) -> BottleneckProtocol:
    """Build a ``Bottleneck`` module from configuration.

    Constructs a ``Bottleneck`` instance whose internal channel count and
    optional extra layers (e.g. self-attention) are controlled by
    ``config``. If no configuration is provided, the default
    ``DEFAULT_BOTTLENECK_CONFIG`` is used, which includes a
    ``SelfAttention`` layer.

    Parameters
    ----------
    input_height : int
        Height (number of frequency bins) of the input tensor from the
        encoder. Must be positive.
    in_channels : int
        Number of channels in the input tensor from the encoder. Must be
        positive.
    config : BottleneckConfig, optional
        Configuration specifying the output channel count and any
        additional layers. Uses ``DEFAULT_BOTTLENECK_CONFIG`` if ``None``.

    Returns
    -------
    BottleneckProtocol
        An initialised ``Bottleneck`` module.

    Raises
    ------
    AssertionError
        If any configured layer changes the height of the feature map
        (bottleneck layers must preserve height so that it can be restored
        by repetition).
    """
    config = config or DEFAULT_BOTTLENECK_CONFIG
@@ -238,13 +206,11 @@ def build_bottleneck(
    layers = []

    for layer_config in config.layers:
        layer = build_layer(
            input_height=current_height,
            in_channels=current_channels,
            config=layer_config,
        )
        current_height = layer.get_output_height(current_height)
        current_channels = layer.out_channels
        assert current_height == input_height, (
            "Bottleneck layers should not change the spectrogram height"
        )
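The collapse-then-repeat trick at the heart of the bottleneck can be shown without torch. This is a stdlib sketch operating on a plain `H x W` grid instead of a real `(B, C, H, W)` tensor, and a simple column mean stands in for the learned `VerticalConv`:

```python
# Sketch of the bottleneck's height handling: collapse the frequency axis (H)
# to a single row, optionally transform that row along time (W), then repeat
# it H times so the decoder sees the original height again.

def collapse_height(grid):
    """Average each column of an H x W grid down to a single row (H = 1)."""
    height = len(grid)
    width = len(grid[0])
    return [sum(grid[h][w] for h in range(height)) / height for w in range(width)]


def bottleneck(grid, time_layer=None):
    row = collapse_height(grid)          # (H, W) -> (1, W)
    if time_layer is not None:
        row = time_layer(row)            # operates purely along the time axis
    return [list(row) for _ in grid]     # repeat to restore (H, W)


spec = [[1.0, 2.0], [3.0, 4.0]]          # H = 2, W = 2
out = bottleneck(spec)
# columns averaged to [2.0, 3.0], then repeated: [[2.0, 3.0], [2.0, 3.0]]
```

Because the extra layers run while the height is 1, something like `SelfAttention` attends only across time steps, which is why the builder asserts that bottleneck layers never change the height.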
src/batdetect2/models/config.py (new file, 98 lines)
@@ -0,0 +1,98 @@
from typing import Optional

from soundevent import data

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
    DEFAULT_BOTTLENECK_CONFIG,
    BottleneckConfig,
)
from batdetect2.models.decoder import (
    DEFAULT_DECODER_CONFIG,
    DecoderConfig,
)
from batdetect2.models.encoder import (
    DEFAULT_ENCODER_CONFIG,
    EncoderConfig,
)

__all__ = [
    "BackboneConfig",
    "load_backbone_config",
]


class BackboneConfig(BaseConfig):
    """Configuration for the Encoder-Decoder Backbone network.

    Aggregates configurations for the encoder, bottleneck, and decoder
    components, along with defining the input and final output dimensions
    for the complete backbone.

    Attributes
    ----------
    input_height : int, default=128
        Expected height (frequency bins) of the input spectrograms to the
        backbone. Must be positive.
    in_channels : int, default=1
        Expected number of channels in the input spectrograms (e.g., 1 for
        mono). Must be positive.
    encoder : EncoderConfig, optional
        Configuration for the encoder. If None or omitted, the default
        encoder configuration (`DEFAULT_ENCODER_CONFIG` from the encoder
        module) will be used.
    bottleneck : BottleneckConfig, optional
        Configuration for the bottleneck layer connecting encoder and
        decoder. If None or omitted, the default bottleneck configuration
        will be used.
    decoder : DecoderConfig, optional
        Configuration for the decoder. If None or omitted, the default
        decoder configuration (`DEFAULT_DECODER_CONFIG` from the decoder
        module) will be used.
    out_channels : int, default=32
        Desired number of channels in the final feature map output by the
        backbone. Must be positive.
    """

    input_height: int = 128
    in_channels: int = 1
    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
    out_channels: int = 32


def load_backbone_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> BackboneConfig:
    """Load the backbone configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `BackboneConfig` schema, potentially extracting data from a nested
    field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing
        the backbone configuration (e.g., "model.backbone"). If None, the
        entire file content is used.

    Returns
    -------
    BackboneConfig
        The loaded and validated backbone configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded config data does not conform to `BackboneConfig`.
    KeyError, TypeError
        If `field` specifies an invalid path.
    """
    return load_config(path, schema=BackboneConfig, field=field)
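The dot-path behaviour of the `field` argument can be approximated with a small stdlib helper. This is an illustrative sketch of the walk that `field="model.backbone"` implies, not the actual `load_config` implementation:

```python
# Illustrative sketch: a dotted `field` path selects a nested section of a
# parsed YAML document (here a plain dict) before schema validation.

def select_field(document, field=None):
    if field is None:
        return document
    node = document
    for key in field.split("."):
        if not isinstance(node, dict):
            raise TypeError(f"Cannot descend into {key!r}: not a mapping")
        node = node[key]                 # KeyError if the path is invalid
    return node


parsed = {"model": {"backbone": {"input_height": 128, "out_channels": 32}}}
section = select_field(parsed, field="model.backbone")
# section == {"input_height": 128, "out_channels": 32}
```

This mirrors the documented failure modes: a bad key raises `KeyError`, and descending into a non-mapping raises `TypeError`.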
@@ -1,24 +1,24 @@
"""Decoder (upsampling path) for the BatDetect2 backbone.

This module defines ``DecoderConfig`` and the ``Decoder`` ``nn.Module``,
together with the ``build_decoder`` factory function.

In a U-Net-style network the decoder progressively restores the spatial
resolution of the feature map back towards the input resolution. At each
stage it combines the upsampled features with the corresponding skip-connection
tensor from the encoder (the residual) by element-wise addition before passing
the result to the upsampling block.

The decoder is fully configurable: the type, number, and parameters of the
upsampling blocks are described by a ``DecoderConfig`` object containing an
ordered list of block configuration objects (see ``batdetect2.models.blocks``
for available block types).

A default configuration ``DEFAULT_DECODER_CONFIG`` is provided and used by
``build_decoder`` when no explicit configuration is supplied.
"""

from typing import Annotated, List

import torch
from pydantic import Field
@@ -30,7 +30,7 @@ from batdetect2.models.blocks import (
    FreqCoordConvUpConfig,
    LayerGroupConfig,
    StandardConvUpConfig,
    build_layer,
)

__all__ = [
@@ -41,57 +41,63 @@ __all__ = [
]

DecoderLayerConfig = Annotated[
    ConvConfig
    | FreqCoordConvUpConfig
    | StandardConvUpConfig
    | LayerGroupConfig,
    Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""


class DecoderConfig(BaseConfig):
    """Configuration for the sequential ``Decoder`` module.

    Attributes
    ----------
    layers : List[DecoderLayerConfig]
        Ordered list of block configuration objects defining the decoder's
        upsampling stages (from deepest to shallowest). Each entry
        specifies the block type (via its ``name`` field) and any
        block-specific parameters such as ``out_channels``. Input channels
        for each block are inferred automatically from the output of the
        previous block. Must contain at least one entry.
    """

    layers: List[DecoderLayerConfig] = Field(min_length=1)


class Decoder(nn.Module):
    """Sequential decoder module composed of configurable upsampling layers.

    Executes a series of upsampling blocks in order, adding the
    corresponding encoder skip-connection tensor (residual) to the feature
    map before each block. The residuals are consumed in reverse order (from
    deepest encoder layer to shallowest) to match the spatial resolutions at
    each decoder stage.

    Instances are typically created by ``build_decoder``.

    Attributes
    ----------
    in_channels : int
        Number of channels expected in the input tensor (bottleneck output).
    out_channels : int
        Number of channels in the final output feature map.
    input_height : int
        Height (frequency bins) of the input tensor.
    output_height : int
        Height (frequency bins) of the output tensor.
    layers : nn.ModuleList
        Sequence of instantiated upsampling block modules.
    depth : int
        Number of upsampling layers.
    """

    def __init__(
@ -102,24 +108,23 @@ class Decoder(nn.Module):
|
|||||||
output_height: int,
|
output_height: int,
|
||||||
layers: List[nn.Module],
|
layers: List[nn.Module],
|
||||||
):
|
):
|
||||||
"""Initialise the Decoder module.
|
"""Initialize the Decoder module.
|
||||||
|
|
||||||
This constructor is typically called by the ``build_decoder``
|
Note: This constructor is typically called internally by the
|
||||||
factory function.
|
`build_decoder` factory function.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
|
||||||
Number of channels in the input tensor (bottleneck output).
|
|
||||||
out_channels : int
|
out_channels : int
|
||||||
Number of channels produced by the final layer.
|
Number of channels produced by the final layer.
|
||||||
input_height : int
|
input_height : int
|
||||||
Height of the input tensor (bottleneck output height).
|
Expected height of the input tensor (bottleneck).
|
||||||
output_height : int
|
in_channels : int
|
||||||
Height of the output tensor after all layers have been applied.
|
Expected number of channels in the input tensor (bottleneck).
|
||||||
layers : List[nn.Module]
|
layers : List[nn.Module]
|
||||||
Pre-built upsampling block modules in execution order (deepest
|
A list of pre-instantiated upscaling layer modules (e.g.,
|
||||||
stage first).
|
`StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired
|
||||||
|
sequence (from bottleneck towards output resolution).
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -137,35 +142,43 @@ class Decoder(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residuals: List[torch.Tensor],
|
residuals: List[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Pass input through all decoder layers, incorporating skip connections.
|
"""Pass input through decoder layers, incorporating skip connections.
|
||||||
|
|
||||||
At each stage the corresponding residual tensor is added
|
Processes the input tensor `x` sequentially through the upscaling
|
||||||
element-wise to ``x`` before it is passed to the upsampling block.
|
layers. At each stage, the corresponding skip connection tensor from
|
||||||
Residuals are consumed in reverse order — the last element of
|
the `residuals` list is added element-wise to the input before passing
|
||||||
``residuals`` (the output of the shallowest encoder layer) is added
|
it to the upscaling block.
|
||||||
at the first decoder stage, and the first element (output of the
|
|
||||||
deepest encoder layer) is added at the last decoder stage.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
x : torch.Tensor
|
x : torch.Tensor
|
||||||
Bottleneck feature map, shape ``(B, C_in, H_in, W)``.
|
Input tensor from the previous stage (e.g., encoder bottleneck).
|
||||||
|
Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
|
||||||
|
`self.in_channels`.
|
||||||
residuals : List[torch.Tensor]
|
residuals : List[torch.Tensor]
|
||||||
Skip-connection tensors from the encoder, ordered from shallowest
|
List containing the skip connection tensors from the corresponding
|
||||||
(index 0) to deepest (index -1). Must contain exactly
|
encoder stages. Should be ordered from the deepest encoder layer
|
||||||
``self.depth`` tensors. Each tensor must have the same spatial
|
output (lowest resolution) to the shallowest (highest resolution
|
||||||
dimensions and channel count as ``x`` at the corresponding
|
near input). The number of tensors in this list must match the
|
||||||
decoder stage.
|
number of decoder layers (`self.depth`). Each residual tensor's
|
||||||
|
channel count must be compatible with the input tensor `x` for
|
||||||
|
element-wise addition (or concatenation if the blocks were designed
|
||||||
|
for it).
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
torch.Tensor
|
torch.Tensor
|
||||||
Decoded feature map, shape ``(B, C_out, H_out, W)``.
|
The final decoded feature map tensor produced by the last layer.
|
||||||
|
Shape `(B, C_out, H_out, W_out)`.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If the number of ``residuals`` does not equal ``self.depth``.
|
If the number of `residuals` provided does not match the decoder
|
||||||
|
depth.
|
||||||
|
RuntimeError
|
||||||
|
If shapes mismatch during skip connection addition or layer
|
||||||
|
processing.
|
||||||
"""
|
"""
|
||||||
if len(residuals) != len(self.layers):
|
if len(residuals) != len(self.layers):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -174,7 +187,7 @@ class Decoder(nn.Module):
|
|||||||
f"but got {len(residuals)}."
|
f"but got {len(residuals)}."
|
||||||
)
|
)
|
||||||
|
|
||||||
for layer, res in zip(self.layers, residuals[::-1], strict=False):
|
for layer, res in zip(self.layers, residuals[::-1]):
|
||||||
x = layer(x + res)
|
x = layer(x + res)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
@ -192,55 +205,53 @@ DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
"""Default decoder configuration used in standard BatDetect2 models.
|
"""A default configuration for the Decoder's *layer sequence*.
|
||||||
|
|
||||||
Mirrors ``DEFAULT_ENCODER_CONFIG`` in reverse. Assumes the bottleneck
|
Specifies an architecture often used in BatDetect2, consisting of three
|
||||||
output has 256 channels and height 16, and produces:
|
frequency coordinate-aware upsampling blocks followed by a standard
|
||||||
|
convolutional block.
|
||||||
- Stage 1 (``FreqCoordConvUp``): 64 channels, height 32.
|
|
||||||
- Stage 2 (``FreqCoordConvUp``): 32 channels, height 64.
|
|
||||||
- Stage 3 (``LayerGroup``):
|
|
||||||
|
|
||||||
- ``FreqCoordConvUp``: 32 channels, height 128.
|
|
||||||
- ``ConvBlock``: 32 channels, height 128 (final feature map).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def build_decoder(
|
def build_decoder(
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
input_height: int,
|
input_height: int,
|
||||||
config: DecoderConfig | None = None,
|
config: Optional[DecoderConfig] = None,
|
||||||
) -> Decoder:
|
) -> Decoder:
|
||||||
"""Build a ``Decoder`` from configuration.
|
"""Factory function to build a Decoder instance from configuration.
|
||||||
|
|
||||||
Constructs a sequential ``Decoder`` by iterating over the block
|
Constructs a sequential `Decoder` module based on the layer sequence
|
||||||
configurations in ``config.layers``, building each block with
|
defined in a `DecoderConfig` object and the provided input dimensions
|
||||||
``build_layer``, and tracking the channel count and feature-map height
|
(bottleneck channels and height). If no config is provided, uses the
|
||||||
as they change through the sequence.
|
default layer sequence from `DEFAULT_DECODER_CONFIG`.
|
||||||
|
|
||||||
|
It iteratively builds the layers using the unified `build_layer_from_config`
|
||||||
|
factory (from `.blocks`), tracking the changing number of channels and
|
||||||
|
feature map height required for each subsequent layer.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of channels in the input tensor (bottleneck output). Must
|
The number of channels in the input tensor to the decoder. Must be > 0.
|
||||||
be positive.
|
|
||||||
input_height : int
|
input_height : int
|
||||||
Height (number of frequency bins) of the input tensor. Must be
|
The height (frequency bins) of the input tensor to the decoder. Must be
|
||||||
positive.
|
> 0.
|
||||||
config : DecoderConfig, optional
|
config : DecoderConfig, optional
|
||||||
Configuration specifying the layer sequence. Defaults to
|
The configuration object detailing the sequence of layers and their
|
||||||
``DEFAULT_DECODER_CONFIG`` if not provided.
|
parameters. If None, `DEFAULT_DECODER_CONFIG` is used.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Decoder
|
Decoder
|
||||||
An initialised ``Decoder`` module.
|
An initialized `Decoder` module.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If ``in_channels`` or ``input_height`` are not positive.
|
If `in_channels` or `input_height` are not positive, or if the layer
|
||||||
KeyError
|
configuration is invalid (e.g., empty list, unknown `name`).
|
||||||
If a layer configuration specifies an unknown block type.
|
NotImplementedError
|
||||||
|
If `build_layer_from_config` encounters an unknown `name`.
|
||||||
"""
|
"""
|
||||||
config = config or DEFAULT_DECODER_CONFIG
|
config = config or DEFAULT_DECODER_CONFIG
|
||||||
|
|
||||||
@ -250,13 +261,11 @@ def build_decoder(
|
|||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
for layer_config in config.layers:
|
for layer_config in config.layers:
|
||||||
layer = build_layer(
|
layer, current_channels, current_height = build_layer_from_config(
|
||||||
in_channels=current_channels,
|
in_channels=current_channels,
|
||||||
input_height=current_height,
|
input_height=current_height,
|
||||||
config=layer_config,
|
config=layer_config,
|
||||||
)
|
)
|
||||||
current_height = layer.get_output_height(current_height)
|
|
||||||
current_channels = layer.out_channels
|
|
||||||
layers.append(layer)
|
layers.append(layer)
|
||||||
|
|
||||||
return Decoder(
|
return Decoder(
|
||||||
|
|||||||
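Both sides of this diff agree on the core of `Decoder.forward`: the encoder residuals are consumed in reverse order via `residuals[::-1]` and added element-wise before each upsampling block. A minimal pure-Python sketch of that control flow (toy integer "feature maps" and doubling callables stand in for tensors and `nn.Module` blocks; the names here are illustrative, not part of the library):

```python
def decoder_forward(layers, residuals, x):
    # Mirror of Decoder.forward: validate depth, then walk the layers while
    # consuming the encoder residuals from deepest to shallowest.
    if len(residuals) != len(layers):
        raise ValueError(
            f"Expected {len(layers)} residuals, but got {len(residuals)}."
        )
    for layer, res in zip(layers, residuals[::-1]):
        x = layer(x + res)  # skip connection: element-wise add, then upsample
    return x


# Toy "layers" that double their input, standing in for upsampling blocks.
layers = [lambda v: v * 2, lambda v: v * 2]

# Residuals ordered shallowest -> deepest, as the encoder produces them.
residuals = [10, 1]

out = decoder_forward(layers, residuals, x=0)
# Stage 1 adds the deepest residual:    (0 + 1) * 2 = 2
# Stage 2 adds the shallowest residual: (2 + 10) * 2 = 24
print(out)  # 24
```

This makes the ordering contract concrete: `residuals[-1]` (deepest encoder output) is used at the first decoder stage, matching its spatial resolution.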
@@ -1,32 +1,27 @@
-"""Assembles the complete BatDetect2 detection model.
+"""Assembles the complete BatDetect2 Detection Model.
 
-This module defines the ``Detector`` class, which combines a backbone
-feature extractor with prediction heads for detection, classification, and
-bounding-box size regression.
+This module defines the concrete `Detector` class, which implements the
+`DetectionModel` interface defined in `.types`. It combines a feature
+extraction backbone with specific prediction heads to create the end-to-end
+neural network used for detecting bat calls, predicting their size, and
+classifying them.
 
-Components
-----------
-- ``Detector`` – the ``torch.nn.Module`` that wires together a backbone
-  (``BackboneModel``) with a ``ClassifierHead`` and a ``BBoxHead`` to
-  produce a ``ModelOutput`` tuple from an input spectrogram.
-- ``build_detector`` – factory function that builds a ready-to-use
-  ``Detector`` from a backbone configuration and a target class count.
+The primary components are:
+- `Detector`: The `torch.nn.Module` subclass representing the complete model.
 
-Note that ``Detector`` operates purely on spectrogram tensors; raw audio
-preprocessing and output postprocessing are handled by
-``batdetect2.preprocess`` and ``batdetect2.postprocess`` respectively.
+This module focuses purely on the neural network architecture definition. The
+logic for preprocessing inputs and postprocessing/decoding outputs resides in
+the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
 """
 
+from typing import Optional
+
 import torch
 from loguru import logger
 
-from batdetect2.models.backbones import (
-    BackboneConfig,
-    UNetBackboneConfig,
-    build_backbone,
-)
+from batdetect2.models.backbones import BackboneConfig, build_backbone
 from batdetect2.models.heads import BBoxHead, ClassifierHead
-from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
+from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
 
 __all__ = [
     "Detector",
@@ -35,30 +30,25 @@ __all__ = [
 
 
 class Detector(DetectionModel):
-    """Complete BatDetect2 detection and classification model.
+    """Concrete implementation of the BatDetect2 Detection Model.
 
-    Combines a backbone feature extractor with two prediction heads:
-
-    - ``ClassifierHead``: predicts per-class probabilities at each
-      time–frequency location.
-    - ``BBoxHead``: predicts call duration and bandwidth at each location.
-
-    The detection probability map is derived from the class probabilities by
-    summing across the class dimension (i.e. the probability that *any* class
-    is present), rather than from a separate detection head.
-
-    Instances are typically created via ``build_detector``.
+    Assembles a complete detection and classification model by combining a
+    feature extraction backbone network with specific prediction heads for
+    detection probability, bounding box size regression, and class
+    probabilities.
 
     Attributes
     ----------
     backbone : BackboneModel
-        The feature extraction backbone.
+        The feature extraction backbone network module.
     num_classes : int
-        Number of target classes (inferred from the classifier head).
+        The number of specific target classes the model predicts (derived from
+        the `classifier_head`).
     classifier_head : ClassifierHead
-        Produces per-class probability maps from backbone features.
+        The prediction head responsible for generating class probabilities.
    bbox_head : BBoxHead
-        Produces duration and bandwidth predictions from backbone features.
+        The prediction head responsible for generating bounding box size
+        predictions.
     """
 
     backbone: BackboneModel
@@ -69,21 +59,26 @@ class Detector(DetectionModel):
         classifier_head: ClassifierHead,
         bbox_head: BBoxHead,
     ):
-        """Initialise the Detector model.
+        """Initialize the Detector model.
 
-        This constructor is typically called by the ``build_detector``
+        Note: Instances are typically created using the `build_detector`
         factory function.
 
         Parameters
         ----------
         backbone : BackboneModel
-            An initialised backbone module (e.g. built by
-            ``build_backbone``).
+            An initialized feature extraction backbone module (e.g., built by
+            `build_backbone` from the `.backbone` module).
         classifier_head : ClassifierHead
-            An initialised classification head. The ``num_classes``
-            attribute is read from this head.
+            An initialized classification head module. The number of classes
+            is inferred from this head.
         bbox_head : BBoxHead
-            An initialised bounding-box size prediction head.
+            An initialized bounding box size prediction head module.
+
+        Raises
+        ------
+        TypeError
+            If the provided modules are not of the expected types.
         """
         super().__init__()
 
@@ -93,34 +88,31 @@ class Detector(DetectionModel):
         self.bbox_head = bbox_head
 
     def forward(self, spec: torch.Tensor) -> ModelOutput:
-        """Run the complete detection model on an input spectrogram.
+        """Perform the forward pass of the complete detection model.
 
-        Passes the spectrogram through the backbone to produce a feature
-        map, then applies the classifier and bounding-box heads. The
-        detection probability map is derived by summing the per-class
-        probability maps across the class dimension; no separate detection
-        head is used.
+        Processes the input spectrogram through the backbone to extract
+        features, then passes these features through the separate prediction
+        heads to generate detection probabilities, class probabilities, and
+        size predictions.
 
         Parameters
         ----------
         spec : torch.Tensor
-            Input spectrogram tensor, shape
-            ``(batch_size, channels, frequency_bins, time_bins)``.
+            Input spectrogram tensor, typically with shape
+            `(batch_size, input_channels, frequency_bins, time_bins)`. The
+            shape must be compatible with the `self.backbone` input
+            requirements.
 
         Returns
        -------
         ModelOutput
-            A named tuple with four fields:
-
-            - ``detection_probs`` – ``(B, 1, H, W)`` – probability that a
-              call of any class is present at each location. Derived by
-              summing ``class_probs`` over the class dimension.
-            - ``size_preds`` – ``(B, 2, H, W)`` – scaled duration (channel
-              0) and bandwidth (channel 1) predictions at each location.
-            - ``class_probs`` – ``(B, num_classes, H, W)`` – per-class
-              probabilities at each location.
-            - ``features`` – ``(B, C_out, H, W)`` – raw backbone feature
-              map.
+            A NamedTuple containing the four output tensors:
+            - `detection_probs`: Detection probability heatmap `(B, 1, H, W)`.
+            - `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`.
+            - `class_probs`: Class probabilities (excluding background)
+              `(B, num_classes, H, W)`.
+            - `features`: Output feature map from the backbone
+              `(B, C_out, H, W)`.
         """
         features = self.backbone(spec)
         classification = self.classifier_head(features)
@@ -135,46 +127,40 @@ class Detector(DetectionModel):
 
 
 def build_detector(
-    num_classes: int,
-    config: BackboneConfig | None = None,
-    backbone: BackboneModel | None = None,
+    num_classes: int, config: Optional[BackboneConfig] = None
 ) -> DetectionModel:
-    """Build a complete BatDetect2 detection model.
+    """Build the complete BatDetect2 detection model.
 
-    Constructs a backbone from ``config``, attaches a ``ClassifierHead``
-    and a ``BBoxHead`` sized to the backbone's output channel count, and
-    returns them wrapped in a ``Detector``.
-
     Parameters
     ----------
     num_classes : int
-        Number of target bat species or call types to predict. Must be
-        positive.
+        The number of specific target classes the model should predict
+        (required for the `ClassifierHead`). Must be positive.
     config : BackboneConfig, optional
-        Backbone architecture configuration. Defaults to
-        ``UNetBackboneConfig()`` (the standard BatDetect2 architecture) if
-        not provided.
+        Configuration object specifying the architecture of the backbone
+        (encoder, bottleneck, decoder). If None, default configurations defined
+        within the respective builder functions (`build_encoder`, etc.) will be
+        used to construct a default backbone architecture.
 
     Returns
     -------
     DetectionModel
-        An initialised ``Detector`` instance ready for training or
-        inference.
+        An initialized `Detector` model instance.
 
     Raises
    ------
     ValueError
-        If ``num_classes`` is not positive, or if the backbone
-        configuration is invalid.
+        If `num_classes` is not positive, or if errors occur during the
+        construction of the backbone or detector components (e.g., incompatible
+        configurations, invalid parameters).
     """
-    if backbone is None:
-        config = config or UNetBackboneConfig()
+    config = config or BackboneConfig()
     logger.opt(lazy=True).debug(
         "Building model with config: \n{}",
-        lambda: config.to_yaml_string(),  # type: ignore
+        lambda: config.to_yaml_string(),
     )
     backbone = build_backbone(config=config)
 
     classifier_head = ClassifierHead(
         num_classes=num_classes,
         in_channels=backbone.out_channels,
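The 5a14b29-side docstrings state that `detection_probs` is not produced by a dedicated detection head but derived by summing `class_probs` over the class dimension, giving the probability that a call of *any* class is present at each time-frequency location. A small pure-Python sketch of that reduction (plain nested lists stand in for a `(num_classes, H, W)` tensor; this mirrors the documented behaviour only, not the library's actual tensor code):

```python
def detection_from_class_probs(class_probs):
    # class_probs: [num_classes][H][W] for a single spectrogram.
    # Detection score at each (h, w) = sum over classes, i.e. the total
    # probability mass assigned to "some class is present" at that location.
    num_classes = len(class_probs)
    height = len(class_probs[0])
    width = len(class_probs[0][0])
    return [
        [
            sum(class_probs[c][h][w] for c in range(num_classes))
            for w in range(width)
        ]
        for h in range(height)
    ]


# Two classes over a 1x3 time-frequency grid.
class_probs = [
    [[0.1, 0.6, 0.0]],  # class 0
    [[0.2, 0.3, 0.0]],  # class 1
]
print(detection_from_class_probs(class_probs))
```

In the real model this would be a single tensor reduction over the class axis; the sketch just makes the "any class present" interpretation explicit.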
@@ -1,27 +1,26 @@
-"""Encoder (downsampling path) for the BatDetect2 backbone.
+"""Constructs the Encoder part of a configurable neural network backbone.
 
-This module defines ``EncoderConfig`` and the ``Encoder`` ``nn.Module``,
-together with the ``build_encoder`` factory function.
+This module defines the configuration structure (`EncoderConfig`) and provides
+the `Encoder` class (an `nn.Module`) along with a factory function
+(`build_encoder`) to create sequential encoders. Encoders typically form the
+downsampling path in architectures like U-Nets, processing input feature maps
+(like spectrograms) to produce lower-resolution, higher-dimensionality feature
+representations (bottleneck features).
 
-In a U-Net-style network the encoder progressively reduces the spatial
-resolution of the spectrogram whilst increasing the number of feature
-channels. Each layer in the encoder produces a feature map that is stored
-for use as a skip connection in the corresponding decoder layer.
+The encoder is built dynamically by stacking neural network blocks based on a
+list of configuration objects provided in `EncoderConfig.layers`. Each
+configuration object specifies the type of block (e.g., standard convolution,
+coordinate-feature convolution with downsampling) and its parameters
+(e.g., output channels). This allows for flexible definition of encoder
+architectures via configuration files.
 
-The encoder is fully configurable: the type, number, and parameters of the
-downsampling blocks are described by an ``EncoderConfig`` object containing
-an ordered list of block configuration objects (see ``batdetect2.models.blocks``
-for available block types).
-
-``Encoder.forward`` returns the outputs of *all* encoder layers as a list,
-so that skip connections are available to the decoder.
-``Encoder.encode`` returns only the final output (the input to the bottleneck).
-
-A default configuration ``DEFAULT_ENCODER_CONFIG`` is provided and used by
-``build_encoder`` when no explicit configuration is supplied.
+The `Encoder`'s `forward` method returns outputs from all intermediate layers,
+suitable for skip connections, while the `encode` method returns only the final
+bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
+provided.
 """
 
-from typing import Annotated, List
+from typing import Annotated, List, Optional, Union
 
 import torch
 from pydantic import Field
@@ -33,7 +32,7 @@ from batdetect2.models.blocks import (
     FreqCoordConvDownConfig,
     LayerGroupConfig,
     StandardConvDownConfig,
-    build_layer,
+    build_layer_from_config,
 )
 
 __all__ = [
@@ -44,42 +43,47 @@ __all__ = [
 ]
 
 EncoderLayerConfig = Annotated[
-    ConvConfig
-    | FreqCoordConvDownConfig
-    | StandardConvDownConfig
-    | LayerGroupConfig,
+    Union[
+        ConvConfig,
+        FreqCoordConvDownConfig,
+        StandardConvDownConfig,
+        LayerGroupConfig,
+    ],
     Field(discriminator="name"),
 ]
 """Type alias for the discriminated union of block configs usable in Encoder."""
 
 
 class EncoderConfig(BaseConfig):
-    """Configuration for the sequential ``Encoder`` module.
+    """Configuration for building the sequential Encoder module.
+
+    Defines the sequence of neural network blocks that constitute the encoder
+    (downsampling path).
 
     Attributes
     ----------
     layers : List[EncoderLayerConfig]
-        Ordered list of block configuration objects defining the encoder's
-        downsampling stages. Each entry specifies the block type (via its
-        ``name`` field) and any block-specific parameters such as
-        ``out_channels``. Input channels for each block are inferred
-        automatically from the output of the previous block. Must contain
-        at least one entry.
+        An ordered list of configuration objects, each defining one layer or
+        block in the encoder sequence. Each item must be a valid block config
+        (e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
+        `StandardConvDownConfig`) including a `name` field and necessary
+        parameters like `out_channels`. Input channels for each layer are
+        inferred sequentially. The list must contain at least one layer.
     """
 
     layers: List[EncoderLayerConfig] = Field(min_length=1)
 
 
 class Encoder(nn.Module):
-    """Sequential encoder module composed of configurable downsampling layers.
+    """Sequential Encoder module composed of configurable downscaling layers.
 
-    Executes a series of downsampling blocks in order, storing the output of
-    each block so that it can be passed as a skip connection to the
-    corresponding decoder layer.
+    Constructs the downsampling path of an encoder-decoder network by stacking
+    multiple downscaling blocks.
 
-    ``forward`` returns the outputs of *all* layers (useful when skip
-    connections are needed). ``encode`` returns only the final output
-    (the input to the bottleneck).
+    The `forward` method executes the sequence and returns the output feature
+    map from *each* downscaling stage, facilitating the implementation of skip
+    connections in U-Net-like architectures. The `encode` method returns only
+    the final output tensor (bottleneck features).
 
     Attributes
     ----------
@@ -87,14 +91,14 @@ class Encoder(nn.Module):
         Number of channels expected in the input tensor.
     input_height : int
         Height (frequency bins) expected in the input tensor.
-    out_channels : int
-        Number of channels in the final output tensor (bottleneck input).
+    output_channels : int
+        Number of channels in the final output tensor (bottleneck).
     output_height : int
-        Height (frequency bins) of the final output tensor.
+        Height (frequency bins) expected in the output tensor.
     layers : nn.ModuleList
-        Sequence of instantiated downsampling block modules.
+        The sequence of instantiated downscaling layer modules.
     depth : int
-        Number of downsampling layers.
+        The number of downscaling layers in the encoder.
     """
 
     def __init__(
@@ -105,22 +109,23 @@ class Encoder(nn.Module):
         input_height: int = 128,
         in_channels: int = 1,
     ):
-        """Initialise the Encoder module.
+        """Initialize the Encoder module.
 
-        This constructor is typically called by the ``build_encoder`` factory
-        function, which takes care of building the ``layers`` list from a
-        configuration object.
+        Note: This constructor is typically called internally by the
+        `build_encoder` factory function, which prepares the `layers` list.
 
         Parameters
        ----------
         output_channels : int
             Number of channels produced by the final layer.
         output_height : int
-            Height of the output tensor after all layers have been applied.
+            The expected height of the output tensor.
         layers : List[nn.Module]
-            Pre-built downsampling block modules in execution order.
+            A list of pre-instantiated downscaling layer modules (e.g.,
+            `StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
+            sequence.
         input_height : int, default=128
-            Expected height of the input tensor (frequency bins).
+            Expected height of the input tensor.
         in_channels : int, default=1
             Expected number of channels in the input tensor.
         """
@@ -135,30 +140,29 @@ class Encoder(nn.Module):
         self.depth = len(self.layers)
 
     def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
-        """Pass input through all encoder layers and return every output.
+        """Pass input through encoder layers, returns all intermediate outputs.
 
-        Used when skip connections are needed (e.g. in a U-Net decoder).
+        This method is typically used when the Encoder is part of a U-Net or
+        similar architecture requiring skip connections.
 
         Parameters
        ----------
         x : torch.Tensor
-            Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
-            ``C_in`` must match ``self.in_channels`` and ``H_in`` must
-            match ``self.input_height``.
+            Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
+            `self.in_channels` and `H_in` must match `self.input_height`.
 
         Returns
        -------
         List[torch.Tensor]
-            Output tensors from every layer in order.
-            ``outputs[0]`` is the output of the first (shallowest) layer;
-            ``outputs[-1]`` is the output of the last (deepest) layer,
-            which serves as the input to the bottleneck.
+            A list containing the output tensors from *each* downscaling layer
+            in the sequence. `outputs[0]` is the output of the first layer,
+            `outputs[-1]` is the final output (bottleneck) of the encoder.
 
         Raises
        ------
         ValueError
-            If the input channel count or height does not match the
-            expected values.
+            If input tensor channel count or height does not match expected
+            values.
         """
         if x.shape[1] != self.in_channels:
             raise ValueError(
@@ -181,29 +185,28 @@ class Encoder(nn.Module):
         return outputs
 
     def encode(self, x: torch.Tensor) -> torch.Tensor:
-        """Pass input through all encoder layers and return only the final output.
+        """Pass input through encoder layers, returning only the final output.
 
-        Use this when skip connections are not needed and you only require
-        the bottleneck feature map.
+        This method provides access to the bottleneck features produced after
+        the last downscaling layer.
 
         Parameters
        ----------
         x : torch.Tensor
-            Input spectrogram feature map, shape ``(B, C_in, H_in, W)``.
-            Must satisfy the same shape requirements as ``forward``.
+            Input tensor, shape `(B, C_in, H_in, W)`. Must match expected
+            `in_channels` and `input_height`.
 
         Returns
        -------
         torch.Tensor
-            Output of the last encoder layer, shape
-            ``(B, C_out, H_out, W)``, where ``C_out`` is
-            ``self.out_channels`` and ``H_out`` is ``self.output_height``.
+            The final output tensor (bottleneck features) from the last layer
+            of the encoder. Shape `(B, C_out, H_out, W_out)`.
 
         Raises
        ------
         ValueError
-            If the input channel count or height does not match the
+            If input tensor channel count or height does not match expected
|
||||||
expected values.
|
values.
|
||||||
"""
|
"""
|
||||||
if x.shape[1] != self.in_channels:
|
if x.shape[1] != self.in_channels:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -235,57 +238,58 @@ DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
"""Default encoder configuration used in standard BatDetect2 models.
|
"""Default configuration for the Encoder.
|
||||||
|
|
||||||
Assumes a 1-channel input with 128 frequency bins and produces the
|
Specifies an architecture typically used in BatDetect2:
|
||||||
following feature maps:
|
- Input: 1 channel, 128 frequency bins.
|
||||||
|
- Layer 1: FreqCoordConvDown -> 32 channels, H=64
|
||||||
- Stage 1 (``FreqCoordConvDown``): 32 channels, height 64.
|
- Layer 2: FreqCoordConvDown -> 64 channels, H=32
|
||||||
- Stage 2 (``FreqCoordConvDown``): 64 channels, height 32.
|
- Layer 3: FreqCoordConvDown -> 128 channels, H=16
|
||||||
- Stage 3 (``LayerGroup``):
|
- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck)
|
||||||
|
|
||||||
- ``FreqCoordConvDown``: 128 channels, height 16.
|
|
||||||
- ``ConvBlock``: 256 channels, height 16 (bottleneck input).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def build_encoder(
|
def build_encoder(
|
||||||
in_channels: int,
|
in_channels: int,
|
||||||
input_height: int,
|
input_height: int,
|
||||||
config: EncoderConfig | None = None,
|
config: Optional[EncoderConfig] = None,
|
||||||
) -> Encoder:
|
) -> Encoder:
|
||||||
"""Build an ``Encoder`` from configuration.
|
"""Factory function to build an Encoder instance from configuration.
|
||||||
|
|
||||||
Constructs a sequential ``Encoder`` by iterating over the block
|
Constructs a sequential `Encoder` module based on the layer sequence
|
||||||
configurations in ``config.layers``, building each block with
|
defined in an `EncoderConfig` object and the provided input dimensions.
|
||||||
``build_layer``, and tracking the channel count and feature-map height
|
If no config is provided, uses the default layer sequence from
|
||||||
as they change through the sequence.
|
`DEFAULT_ENCODER_CONFIG`.
|
||||||
|
|
||||||
|
It iteratively builds the layers using the unified
|
||||||
|
`build_layer_from_config` factory (from `.blocks`), tracking the changing
|
||||||
|
number of channels and feature map height required for each subsequent
|
||||||
|
layer, especially for coordinate- aware blocks.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
in_channels : int
|
in_channels : int
|
||||||
Number of channels in the input spectrogram tensor. Must be
|
The number of channels expected in the input tensor to the encoder.
|
||||||
positive.
|
Must be > 0.
|
||||||
input_height : int
|
input_height : int
|
||||||
Height (number of frequency bins) of the input spectrogram.
|
The height (frequency bins) expected in the input tensor. Must be > 0.
|
||||||
Must be positive and should be divisible by
|
Crucial for initializing coordinate-aware layers correctly.
|
||||||
``2 ** (number of downsampling stages)`` to avoid size mismatches
|
|
||||||
later in the network.
|
|
||||||
config : EncoderConfig, optional
|
config : EncoderConfig, optional
|
||||||
Configuration specifying the layer sequence. Defaults to
|
The configuration object detailing the sequence of layers and their
|
||||||
``DEFAULT_ENCODER_CONFIG`` if not provided.
|
parameters. If None, `DEFAULT_ENCODER_CONFIG` is used.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Encoder
|
Encoder
|
||||||
An initialised ``Encoder`` module.
|
An initialized `Encoder` module.
|
||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
ValueError
|
ValueError
|
||||||
If ``in_channels`` or ``input_height`` are not positive.
|
If `in_channels` or `input_height` are not positive, or if the layer
|
||||||
KeyError
|
configuration is invalid (e.g., empty list, unknown `name`).
|
||||||
If a layer configuration specifies an unknown block type.
|
NotImplementedError
|
||||||
|
If `build_layer_from_config` encounters an unknown `name`.
|
||||||
"""
|
"""
|
||||||
if in_channels <= 0 or input_height <= 0:
|
if in_channels <= 0 or input_height <= 0:
|
||||||
raise ValueError("in_channels and input_height must be positive.")
|
raise ValueError("in_channels and input_height must be positive.")
|
||||||
@ -298,14 +302,12 @@ def build_encoder(
|
|||||||
layers = []
|
layers = []
|
||||||
|
|
||||||
for layer_config in config.layers:
|
for layer_config in config.layers:
|
||||||
layer = build_layer(
|
layer, current_channels, current_height = build_layer_from_config(
|
||||||
in_channels=current_channels,
|
in_channels=current_channels,
|
||||||
input_height=current_height,
|
input_height=current_height,
|
||||||
config=layer_config,
|
config=layer_config,
|
||||||
)
|
)
|
||||||
layers.append(layer)
|
layers.append(layer)
|
||||||
current_height = layer.get_output_height(current_height)
|
|
||||||
current_channels = layer.out_channels
|
|
||||||
|
|
||||||
return Encoder(
|
return Encoder(
|
||||||
input_height=input_height,
|
input_height=input_height,
|
||||||
|
|||||||
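The bookkeeping that `build_encoder` does — tracking channels and feature-map height through each downscaling block — can be sketched without the library. This is an illustrative stand-alone model of the default config described above (the layer names and the halving rule are assumptions drawn from the diff, not the library's actual classes):

```python
# Sketch of the channel/height tracking build_encoder performs.
# Layer specs mirror the default config in the diff: three downscaling
# blocks that halve the frequency axis, then a plain conv block.
DEFAULT_LAYERS = [
    ("FreqCoordConvDown", 32),   # downscales: height halves
    ("FreqCoordConvDown", 64),
    ("FreqCoordConvDown", 128),
    ("ConvBlock", 256),          # plain conv: height unchanged
]

def trace_encoder(in_channels: int, input_height: int, layers=DEFAULT_LAYERS):
    """Return (channels, height) after each layer, as the factory tracks them."""
    if in_channels <= 0 or input_height <= 0:
        raise ValueError("in_channels and input_height must be positive.")
    trace = []
    channels, height = in_channels, input_height
    for kind, out_channels in layers:
        if kind.endswith("Down"):  # downscaling blocks halve the height
            height //= 2
        channels = out_channels
        trace.append((channels, height))
    return trace

print(trace_encoder(1, 128))
# [(32, 64), (64, 32), (128, 16), (256, 16)]
```

This also shows why the docstring asks for `input_height` divisible by `2 ** (number of downsampling stages)`: 128 halves cleanly to 16 over three stages.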
@@ -1,19 +1,20 @@
-"""Prediction heads attached to the backbone feature map.
+"""Prediction Head modules for BatDetect2 models.

-Each head is a lightweight ``torch.nn.Module`` that applies a 1×1
-convolution to map backbone feature channels to one specific type of
-output required by BatDetect2:
+This module defines simple `torch.nn.Module` subclasses that serve as
+prediction heads, typically attached to the output feature map of a backbone
+network

-- ``DetectorHead``: single-channel detection probability heatmap (sigmoid
-  activation).
-- ``ClassifierHead``: multi-class probability map over the target bat
-  species / call types (softmax activation).
-- ``BBoxHead``: two-channel map of predicted call duration (time axis) and
-  bandwidth (frequency axis) at each location (no activation; raw
-  regression output).
+Each head is responsible for generating one specific type of output required
+by the BatDetect2 task:
+- `DetectorHead`: Predicts the probability of sound event presence.
+- `ClassifierHead`: Predicts the probability distribution over target classes.
+- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding
+  box.

-All three heads share the same input feature map produced by the backbone,
-so they can be evaluated in parallel in a single forward pass.
+These heads use 1x1 convolutions to map the backbone feature channels
+to the desired number of output channels for each prediction task at each
+spatial location, followed by an appropriate activation function (e.g., sigmoid
+for detection, softmax for classification, none for size regression).
 """

 import torch
@@ -27,35 +28,42 @@ __all__ = [


 class ClassifierHead(nn.Module):
-    """Prediction head for species / call-type classification probabilities.
+    """Prediction head for multi-class classification probabilities.

-    Takes a backbone feature map and produces a probability map where each
-    channel corresponds to a target class. Internally the 1×1 convolution
-    maps ``in_channels`` to ``num_classes + 1`` logits (the extra channel
-    represents a generic background / unknown category); a softmax is then
-    applied across the channel dimension and the background channel is
-    discarded before returning.
+    Takes an input feature map and produces a probability map where each
+    channel corresponds to a specific target class. It uses a 1x1 convolution
+    to map input channels to `num_classes + 1` outputs (one for each target
+    class plus an assumed background/generic class), applies softmax across the
+    channels, and returns the probabilities for the specific target classes
+    (excluding the last background/generic channel).

     Parameters
     ----------
     num_classes : int
-        Number of target classes (bat species or call types) to predict,
-        excluding the background category. Must be positive.
+        The number of specific target classes the model should predict
+        (excluding any background or generic category). Must be positive.
     in_channels : int
-        Number of channels in the backbone feature map. Must be positive.
+        Number of channels in the input feature map tensor from the backbone.
+        Must be positive.

     Attributes
     ----------
     num_classes : int
-        Number of specific output classes (background excluded).
+        Number of specific output classes.
     in_channels : int
         Number of input channels expected.
     classifier : nn.Conv2d
-        1×1 convolution with ``num_classes + 1`` output channels.
+        The 1x1 convolutional layer used for prediction.
+        Output channels = num_classes + 1.

+    Raises
+    ------
+    ValueError
+        If `num_classes` or `in_channels` are not positive.
     """

     def __init__(self, num_classes: int, in_channels: int):
-        """Initialise the ClassifierHead."""
+        """Initialize the ClassifierHead."""
         super().__init__()

         self.num_classes = num_classes
@@ -68,20 +76,20 @@ class ClassifierHead(nn.Module):
         )

     def forward(self, features: torch.Tensor) -> torch.Tensor:
-        """Compute per-class probabilities from backbone features.
+        """Compute class probabilities from input features.

         Parameters
         ----------
         features : torch.Tensor
-            Backbone feature map, shape ``(B, C_in, H, W)``.
+            Input feature map tensor from the backbone, typically with shape
+            `(B, C_in, H, W)`. `C_in` must match `self.in_channels`.

         Returns
         -------
         torch.Tensor
-            Class probability map, shape ``(B, num_classes, H, W)``.
-            Values are softmax probabilities in the range [0, 1] and
-            sum to less than 1 per location (the background probability
-            is discarded).
+            Class probability map tensor with shape `(B, num_classes, H, W)`.
+            Contains probabilities for the specific target classes after
+            softmax, excluding the implicit background/generic class channel.
         """
         logits = self.classifier(features)
         probs = torch.softmax(logits, dim=1)
@@ -89,30 +97,36 @@ class ClassifierHead(nn.Module):


 class DetectorHead(nn.Module):
-    """Prediction head for detection probability (is a call present here?).
+    """Prediction head for sound event detection probability.

-    Produces a single-channel heatmap where each value indicates the
-    probability ([0, 1]) that a bat call of *any* species is present at
-    that time–frequency location in the spectrogram.
+    Takes an input feature map and produces a single-channel heatmap where
+    each value represents the probability ([0, 1]) of a relevant sound event
+    (of any class) being present at that spatial location.

-    Applies a 1×1 convolution mapping ``in_channels`` → 1, followed by
-    sigmoid activation.
+    Uses a 1x1 convolution to map input channels to 1 output channel, followed
+    by a sigmoid activation function.

     Parameters
     ----------
     in_channels : int
-        Number of channels in the backbone feature map. Must be positive.
+        Number of channels in the input feature map tensor from the backbone.
+        Must be positive.

     Attributes
     ----------
     in_channels : int
         Number of input channels expected.
     detector : nn.Conv2d
-        1×1 convolution with a single output channel.
+        The 1x1 convolutional layer mapping to a single output channel.

+    Raises
+    ------
+    ValueError
+        If `in_channels` is not positive.
     """

     def __init__(self, in_channels: int):
-        """Initialise the DetectorHead."""
+        """Initialize the DetectorHead."""
         super().__init__()
         self.in_channels = in_channels

@@ -124,49 +138,62 @@ class DetectorHead(nn.Module):
         )

     def forward(self, features: torch.Tensor) -> torch.Tensor:
-        """Compute detection probabilities from backbone features.
+        """Compute detection probabilities from input features.

         Parameters
         ----------
         features : torch.Tensor
-            Backbone feature map, shape ``(B, C_in, H, W)``.
+            Input feature map tensor from the backbone, typically with shape
+            `(B, C_in, H, W)`. `C_in` must match `self.in_channels`.

         Returns
         -------
         torch.Tensor
-            Detection probability heatmap, shape ``(B, 1, H, W)``.
-            Values are in the range [0, 1].
+            Detection probability heatmap tensor with shape `(B, 1, H, W)`.
+            Values are in the range [0, 1] due to the sigmoid activation.

+        Raises
+        ------
+        RuntimeError
+            If input channel count does not match `self.in_channels`.
         """
         return torch.sigmoid(self.detector(features))


 class BBoxHead(nn.Module):
-    """Prediction head for bounding box size (duration and bandwidth).
+    """Prediction head for bounding box size dimensions.

-    Produces a two-channel map where channel 0 predicts the scaled duration
-    (time-axis extent) and channel 1 predicts the scaled bandwidth
-    (frequency-axis extent) of the call at each spectrogram location.
+    Takes an input feature map and produces a two-channel map where each
+    channel represents a predicted size dimension (typically width/duration and
+    height/bandwidth) for a potential sound event at that spatial location.

-    Applies a 1×1 convolution mapping ``in_channels`` → 2 with no
-    activation function (raw regression output). The predicted values are
-    in a scaled space and must be converted to real units (seconds and Hz)
-    during postprocessing.
+    Uses a 1x1 convolution to map input channels to 2 output channels. No
+    activation function is typically applied, as size prediction is often
+    treated as a direct regression task. The output values usually represent
+    *scaled* dimensions that need to be un-scaled during postprocessing.

     Parameters
     ----------
     in_channels : int
-        Number of channels in the backbone feature map. Must be positive.
+        Number of channels in the input feature map tensor from the backbone.
+        Must be positive.

     Attributes
     ----------
     in_channels : int
         Number of input channels expected.
     bbox : nn.Conv2d
-        1×1 convolution with 2 output channels (duration, bandwidth).
+        The 1x1 convolutional layer mapping to 2 output channels
+        (width, height).

+    Raises
+    ------
+    ValueError
+        If `in_channels` is not positive.
     """

     def __init__(self, in_channels: int):
-        """Initialise the BBoxHead."""
+        """Initialize the BBoxHead."""
         super().__init__()
         self.in_channels = in_channels

@@ -178,19 +205,19 @@ class BBoxHead(nn.Module):
         )

     def forward(self, features: torch.Tensor) -> torch.Tensor:
-        """Predict call duration and bandwidth from backbone features.
+        """Compute predicted bounding box dimensions from input features.

         Parameters
         ----------
         features : torch.Tensor
-            Backbone feature map, shape ``(B, C_in, H, W)``.
+            Input feature map tensor from the backbone, typically with shape
+            `(B, C_in, H, W)`. `C_in` must match `self.in_channels`.

         Returns
         -------
         torch.Tensor
-            Size prediction tensor, shape ``(B, 2, H, W)``. Channel 0 is
-            the predicted scaled duration; channel 1 is the predicted
-            scaled bandwidth. Values must be rescaled to real units during
-            postprocessing.
+            Predicted size tensor with shape `(B, 2, H, W)`. Channel 0 usually
+            represents scaled width, Channel 1 scaled height. These values
+            need to be un-scaled during postprocessing.
         """
         return self.bbox(features)
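The `ClassifierHead` docstrings above describe softmax over `num_classes + 1` channels followed by dropping the background/generic channel. That activation logic can be sketched for a single spatial location without torch (the function and the example logits are illustrative, not part of the library):

```python
import math

def class_probs(logits):
    """Softmax over num_classes + 1 logits, then drop the final
    background/generic channel -- the activation ClassifierHead applies
    per spatial location."""
    exps = [math.exp(v) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return probs[:-1]  # discard the background/generic channel

# 3 target classes plus one background logit:
p = class_probs([2.0, 1.0, 0.5, 0.0])
assert all(0.0 <= v <= 1.0 for v in p)
assert sum(p) < 1.0  # background mass was dropped, so the rest sum to < 1
```

This is why the left-hand docstring notes that the returned probabilities "sum to less than 1 per location": the background probability is computed but discarded.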
@@ -1,90 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import NamedTuple, Protocol
-
-import torch
-
-__all__ = [
-    "BackboneModel",
-    "BlockProtocol",
-    "BottleneckProtocol",
-    "DecoderProtocol",
-    "DetectionModel",
-    "EncoderDecoderModel",
-    "EncoderProtocol",
-    "ModelOutput",
-]
-
-
-class BlockProtocol(Protocol):
-    in_channels: int
-    out_channels: int
-
-    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
-
-    def get_output_height(self, input_height: int) -> int: ...
-
-
-class EncoderProtocol(Protocol):
-    in_channels: int
-    out_channels: int
-    input_height: int
-    output_height: int
-
-    def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
-
-
-class BottleneckProtocol(Protocol):
-    in_channels: int
-    out_channels: int
-    input_height: int
-
-    def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
-
-
-class DecoderProtocol(Protocol):
-    in_channels: int
-    out_channels: int
-    input_height: int
-    output_height: int
-    depth: int
-
-    def __call__(
-        self,
-        x: torch.Tensor,
-        residuals: list[torch.Tensor],
-    ) -> torch.Tensor: ...
-
-
-class ModelOutput(NamedTuple):
-    detection_probs: torch.Tensor
-    size_preds: torch.Tensor
-    class_probs: torch.Tensor
-    features: torch.Tensor
-
-
-class BackboneModel(ABC, torch.nn.Module):
-    input_height: int
-    out_channels: int
-
-    @abstractmethod
-    def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        raise NotImplementedError
-
-
-class EncoderDecoderModel(BackboneModel):
-    bottleneck_channels: int
-
-    @abstractmethod
-    def encode(self, spec: torch.Tensor) -> torch.Tensor: ...
-
-    @abstractmethod
-    def decode(self, encoded: torch.Tensor) -> torch.Tensor: ...
-
-
-class DetectionModel(ABC, torch.nn.Module):
-    backbone: BackboneModel
-    classifier_head: torch.nn.Module
-    bbox_head: torch.nn.Module
-
-    @abstractmethod
-    def forward(self, spec: torch.Tensor) -> ModelOutput: ...
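The deleted file above relies on `typing.Protocol` for structural typing: any block with the right attributes and methods conforms, with no shared base class. A cut-down illustration of that pattern (the `DummyBlock` class is a hypothetical stand-in, not a batdetect2 class):

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class BlockProtocol(Protocol):
    """Minimal version of the deleted BlockProtocol."""
    in_channels: int
    out_channels: int

    def get_output_height(self, input_height: int) -> int: ...

class DummyBlock:
    """Conforms structurally -- no inheritance from BlockProtocol needed."""
    def __init__(self):
        self.in_channels = 1
        self.out_channels = 32

    def get_output_height(self, input_height: int) -> int:
        return input_height // 2  # behaves like a downscaling block

block = DummyBlock()
assert isinstance(block, BlockProtocol)  # structural runtime check
assert block.get_output_height(128) == 64
```

Decorating the protocol with `@runtime_checkable` is what makes the `isinstance` check possible; it verifies only the presence of the members, not their signatures.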
Some files were not shown because too many files have changed in this diff.