Compare commits


81 Commits

Author SHA1 Message Date
mbsantiago 4b4c3ecdf5 Added loguru as a dependency 2025-04-24 00:20:58 +01:00
mbsantiago 13afac65a5 Added an example conf file 2025-04-24 00:20:46 +01:00
mbsantiago d8cf1db19f Fix augmentations 2025-04-24 00:20:38 +01:00
mbsantiago 8a6ed3dec7 Starting to add logging to preprocess 2025-04-24 00:20:30 +01:00
mbsantiago 7dd35d6e3e Refine training config 2025-04-23 23:35:42 +01:00
mbsantiago d51e3f8bbd Rollback to single configs module 2025-04-23 23:35:27 +01:00
mbsantiago f3999fbba2 Starting to create a dedicated configs module 2025-04-23 23:19:23 +01:00
mbsantiago 2a45859393 Run lint fix 2025-04-23 23:19:08 +01:00
mbsantiago 86d56d65f4 Added jupyter and marimo group dependencies 2025-04-23 23:15:40 +01:00
mbsantiago 27ba8de463 Updated soundevent 2025-04-23 23:15:29 +01:00
mbsantiago 59bd14bc79 Added clips for random cliping and augmentations 2025-04-23 23:15:08 +01:00
mbsantiago 2396815c13 Add a default target config 2025-04-23 23:14:39 +01:00
mbsantiago ac4bb8f023 Added min and max freq attributes to preprocessor protocol 2025-04-23 23:14:31 +01:00
mbsantiago 6498b6ca37 Added BlockGroups 2025-04-23 23:14:11 +01:00
mbsantiago bfcab0331e Attend test warnings 2025-04-22 09:17:17 +01:00
mbsantiago c276230bff Fix batdetect2 test 2025-04-22 09:08:29 +01:00
mbsantiago e38c446f59 Update gitignore 2025-04-22 09:06:12 +01:00
mbsantiago 24c4831745 Add example annotations 2025-04-22 09:06:06 +01:00
mbsantiago 9fc713d390 Temporary remove compat params module 2025-04-22 09:01:58 +01:00
mbsantiago 285c6a3347 Updating to new Augmentation object 2025-04-22 09:01:46 +01:00
mbsantiago 541be15c9e Update augmentation tests to new structure 2025-04-22 09:00:57 +01:00
mbsantiago 8a463e3942 Remove test migration 2025-04-22 09:00:44 +01:00
mbsantiago 257e1e01bf Fixing small errors in tests 2025-04-22 08:51:21 +01:00
mbsantiago ece1a2073d Fixed errors with extraction test 2025-04-22 08:50:56 +01:00
mbsantiago b82973ca1d Move array and tensor util test to corresponding folder 2025-04-22 08:50:43 +01:00
mbsantiago 7c89e82579 Fixing imports after restructuring 2025-04-22 00:36:34 +01:00
mbsantiago dcae411ccb Add pytest coverage to dev dependencies 2025-04-21 21:26:01 +01:00
mbsantiago ce15afc231 Restructured models module 2025-04-21 21:25:50 +01:00
mbsantiago 096d180ea3 Create an encoder module 2025-04-21 15:28:56 +01:00
mbsantiago ffa4c2e5e9 Remove unnecessary config and build modules 2025-04-21 15:28:47 +01:00
mbsantiago 6c744eaac5 Improved docstrings for blocks 2025-04-21 15:28:26 +01:00
mbsantiago e00674f628 Added better docstrings to types module 2025-04-21 14:06:16 +01:00
mbsantiago 907e05ea48 Minor changes to makefile 2025-04-21 13:22:08 +01:00
mbsantiago 3123d105fd Created a Makefile 2025-04-21 13:18:41 +01:00
mbsantiago 9b6b8a0bf9 Add postprocessing docs 2025-04-20 18:18:58 +01:00
mbsantiago 4aa2e6905c Remove old post_process module 2025-04-20 17:56:07 +01:00
mbsantiago 3abebc9c17 Added working postprocess decoding tests 2025-04-20 17:53:56 +01:00
mbsantiago 1f4454693e Working on postprocess tests 2025-04-20 15:52:25 +01:00
mbsantiago bcf339c40d Updated postprocess module with docstrings 2025-04-20 13:56:18 +01:00
mbsantiago 089328a4f0 Update targets docs 2025-04-19 20:35:34 +01:00
mbsantiago 6236e78414 Update TargetProtocol and related to include rois 2025-04-19 20:26:18 +01:00
mbsantiago 9410112e41 Create targets.rois module 2025-04-19 19:43:30 +01:00
mbsantiago 07f065cf93 Rename Preprocessor to PreprocessorProtocol 2025-04-19 12:25:47 +01:00
mbsantiago ae6063918c Add documentation on how to use the legacy format 2025-04-19 12:25:20 +01:00
mbsantiago a0a77cada1 Attach a data_source tag to clips loaded from different sources 2025-04-19 12:25:03 +01:00
mbsantiago 355847346e Added a data_source term. 2025-04-19 12:24:38 +01:00
mbsantiago dfd14df7b9 Added a targets protocol 2025-04-19 12:24:24 +01:00
mbsantiago f353aaa08c Added unit tests to legacy annotation loader 2025-04-18 18:39:58 +01:00
mbsantiago bf14f4d37e Added docstrings for the batdetect2 legacy annotation format 2025-04-18 15:14:48 +01:00
mbsantiago b78e5a3a2f Add AOEF loading documentation 2025-04-18 13:53:00 +01:00
mbsantiago f9e005ec8b Add tests for aoef loading 2025-04-18 13:32:50 +01:00
mbsantiago fd7f2b0081 Added unit tests for spectrogram preprocessing 2025-04-17 18:31:24 +01:00
mbsantiago f314942628 Added more docs for preprocessing module 2025-04-17 16:28:48 +01:00
mbsantiago 638f93fe92 Documented the preprocessing module 2025-04-17 15:56:07 +01:00
mbsantiago 19febf2216 Separated the protocols to separate types module 2025-04-17 15:36:21 +01:00
mbsantiago 3417c496db Changed name of PcenScaleConfig to PcenConfig 2025-04-17 15:35:42 +01:00
mbsantiago 4a9af72580 Moved arrays and tensor operations to utils module 2025-04-17 15:35:17 +01:00
mbsantiago 2212246b11 Add documentation to spectrogram 2025-04-17 14:40:20 +01:00
mbsantiago aca0b58443 Add audio test suite 2025-04-17 13:48:21 +01:00
mbsantiago f5071d00a1 Add audio documentation 2025-04-16 19:51:10 +01:00
mbsantiago 23620c2233 Added docstrings to audio module 2025-04-16 19:44:30 +01:00
mbsantiago a9f91322d4 Moved labels.py back to training 2025-04-16 19:44:23 +01:00
mbsantiago 22036743d1 Add target index documentation 2025-04-16 11:01:21 +01:00
mbsantiago eda5f91c86 Update classes 2025-04-16 00:01:37 +01:00
mbsantiago a2ec190b73 Add decode functions to classes module 2025-04-15 23:56:24 +01:00
mbsantiago 04ed669c4f Update gitignore 2025-04-15 22:34:50 +01:00
mbsantiago b796e0bc7b add docs for configs module 2025-04-15 22:34:45 +01:00
mbsantiago 5d4d9a5edf Begin documentation 2025-04-15 20:29:53 +01:00
mbsantiago 55eff0cebd Fix import error 2025-04-15 20:29:39 +01:00
mbsantiago f99653d68f Added labels documentation 2025-04-15 19:33:45 +01:00
mbsantiago 0778663a2c Add extensive documentation for the labels module 2025-04-15 19:25:58 +01:00
mbsantiago 62471664fa Add tests for target.classes 2025-04-15 18:22:19 +01:00
mbsantiago af48c33307 Add target classes module 2025-04-15 07:32:58 +01:00
mbsantiago 02d4779207 Add test for target filtering 2025-04-14 13:31:30 +01:00
mbsantiago d97614a10d Add transform docs 2025-04-13 17:30:39 +01:00
mbsantiago 991529cf86 Create transform module 2025-04-13 16:43:13 +01:00
mbsantiago 02c1d97e5a Add documentation for batdetect2.targets.terms 2025-04-13 14:11:48 +01:00
mbsantiago 3f3f7cd9c8 Add documentation 2025-04-12 18:05:26 +01:00
mbsantiago 2fb3039f17 Create dedicated filtering module 2025-04-12 18:05:20 +01:00
mbsantiago 26a2c5c851 Modify to numpydoc style 2025-04-12 18:05:02 +01:00
mbsantiago b93d4c65c2 Create separate targets module 2025-04-12 16:48:40 +01:00
126 changed files with 21389 additions and 2995 deletions

.gitignore (vendored): 3 changed lines

@@ -95,6 +95,8 @@ dmypy.json
*.json
plots/*
!example_data/anns/*.json
# Model experiments
experiments/*
@@ -111,3 +113,4 @@ experiments/*
!tests/data/**/*.wav
notebooks/lightning_logs
example_data/preprocessed
.aider*

Makefile (new file): 94 lines

@@ -0,0 +1,94 @@
# Variables
SOURCE_DIR = batdetect2
TESTS_DIR = tests
PYTHON_DIRS = batdetect2 tests
DOCS_SOURCE = docs/source
DOCS_BUILD = docs/build
HTML_COVERAGE_DIR = htmlcov
# Default target (optional, often 'help' or 'all')
.DEFAULT_GOAL := help
# Phony targets (targets that don't produce a file with the same name)
.PHONY: help test coverage coverage-html coverage-serve docs docs-serve format format-check lint lint-fix typecheck check clean clean-pyc clean-test clean-docs clean-build
help:
@echo "Makefile Targets:"
@echo " help Show this help message."
@echo " test Run tests using pytest."
@echo " coverage Run tests and generate coverage data (.coverage, coverage.xml)."
@echo " coverage-html Generate an HTML coverage report in $(HTML_COVERAGE_DIR)/."
@echo " coverage-serve Serve the HTML coverage report locally."
@echo " docs Build documentation using Sphinx."
@echo " docs-serve Serve documentation with live reload using sphinx-autobuild."
@echo " format Format code using ruff."
@echo " format-check Check code formatting using ruff."
@echo " lint Lint code using ruff."
@echo " lint-fix Lint code using ruff and apply automatic fixes."
@echo " typecheck Type check code using pyright."
@echo " check Run all checks (format-check, lint, typecheck)."
@echo " clean Remove all build, test, documentation, and Python artifacts."
@echo " clean-pyc Remove Python bytecode and cache."
@echo " clean-test Remove test and coverage artifacts."
@echo " clean-docs Remove built documentation."
@echo " clean-build Remove package build artifacts."
# Testing & Coverage
test:
pytest tests
coverage:
pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml tests
coverage-html: coverage
@echo "Generating HTML coverage report..."
coverage html -d $(HTML_COVERAGE_DIR)
@echo "HTML coverage report generated in $(HTML_COVERAGE_DIR)/"
coverage-serve: coverage-html
@echo "Serving report at http://localhost:8000/ ..."
python -m http.server --directory $(HTML_COVERAGE_DIR) 8000
# Documentation
docs:
sphinx-build -b html $(DOCS_SOURCE) $(DOCS_BUILD)
docs-serve:
sphinx-autobuild $(DOCS_SOURCE) $(DOCS_BUILD) --watch $(SOURCE_DIR) --open-browser
# Formatting & Linting
format:
ruff format $(PYTHON_DIRS)
format-check:
ruff format --check $(PYTHON_DIRS)
lint:
ruff check $(PYTHON_DIRS)
lint-fix:
ruff check --fix $(PYTHON_DIRS)
# Type Checking
typecheck:
pyright $(PYTHON_DIRS)
# Combined Checks
check: format-check lint typecheck test
# Cleaning tasks
clean-pyc:
find . -type f -name "*.py[co]" -delete
find . -type d -name "__pycache__" -exec rm -rf {} +
clean-test:
rm -f .coverage coverage.xml
rm -rf .pytest_cache htmlcov/
clean-docs:
rm -rf $(DOCS_BUILD)
clean-build:
rm -rf build/ dist/ *.egg-info/
clean: clean-build clean-pyc clean-test clean-docs


@@ -35,6 +35,8 @@ def summary(
):
base_dir = base_dir or Path.cwd()
dataset = load_dataset_from_config(
dataset_config, field=field, base_dir=base_dir
dataset_config,
field=field,
base_dir=base_dir,
)
print(f"Number of annotated clips: {len(dataset.clip_annotations)}")
print(f"Number of annotated clips: {len(dataset)}")


@@ -2,17 +2,14 @@ from pathlib import Path
from typing import Optional
import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import (
load_preprocessing_config,
)
from batdetect2.train import (
load_label_config,
load_target_config,
preprocess_annotations,
)
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train import load_label_config, preprocess_annotations
from batdetect2.train.labels import build_clip_labeler
__all__ = ["train"]
@@ -127,9 +124,9 @@ def train(): ...
def preprocess(
dataset_config: Path,
output: Path,
target_config: Optional[Path] = None,
base_dir: Optional[Path] = None,
preprocess_config: Optional[Path] = None,
target_config: Optional[Path] = None,
label_config: Optional[Path] = None,
force: bool = False,
num_workers: Optional[int] = None,
@@ -138,8 +135,13 @@ def preprocess(
label_config_field: Optional[str] = None,
dataset_field: Optional[str] = None,
):
logger.info("Starting preprocessing.")
output = Path(output)
logger.info("Will save outputs to {output}", output=output)
base_dir = base_dir or Path.cwd()
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
preprocess = (
load_preprocessing_config(
@@ -174,15 +176,25 @@ base_dir=base_dir,
base_dir=base_dir,
)
logger.info(
"Loaded {num_examples} annotated clips from the configured dataset",
num_examples=len(dataset),
)
targets = build_targets(config=target)
preprocessor = build_preprocessor(config=preprocess)
labeller = build_clip_labeler(targets, config=label)
if not output.exists():
logger.debug("Creating directory {directory}", directory=output)
output.mkdir(parents=True)
logger.info("Will start preprocessing")
preprocess_annotations(
dataset.clip_annotations,
dataset,
output_dir=output,
preprocessor=preprocessor,
labeller=labeller,
replace=force,
preprocessing_config=preprocess,
label_config=label,
target_config=target,
max_workers=num_workers,
)
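The rewired preprocess command above composes the new builder functions (build_targets, build_preprocessor, build_clip_labeler) before handing the loaded dataset to preprocess_annotations. The following is a minimal, hedged sketch of the same flow used programmatically; the config file names are hypothetical, and it assumes each load_*_config helper accepts a plain path, as the imports in the diff suggest.

from pathlib import Path

from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train import load_label_config, preprocess_annotations
from batdetect2.train.labels import build_clip_labeler

base_dir = Path.cwd()

# Load the annotated clips described by a (hypothetical) dataset config file.
dataset = load_dataset_from_config("dataset.yaml", base_dir=base_dir)

# Build the processing components from their respective config files.
preprocessor = build_preprocessor(config=load_preprocessing_config("preprocess.yaml"))
targets = build_targets(config=load_target_config("targets.yaml"))
labeller = build_clip_labeler(targets, config=load_label_config("labels.yaml"))

output = Path("preprocessed")
output.mkdir(parents=True, exist_ok=True)

# Generate one preprocessed training example per annotated clip.
preprocess_annotations(
    dataset,
    output_dir=output,
    preprocessor=preprocessor,
    labeller=labeller,
    replace=False,
    max_workers=None,
)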


@@ -1,151 +1,152 @@
from batdetect2.preprocess import (
AmplitudeScaleConfig,
AudioConfig,
FrequencyConfig,
LogScaleConfig,
PcenScaleConfig,
PreprocessingConfig,
ResampleConfig,
Scales,
SpecSizeConfig,
SpectrogramConfig,
STFTConfig,
)
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
from batdetect2.terms import TagInfo
from batdetect2.train.preprocess import (
HeatmapsConfig,
TargetConfig,
TrainPreprocessingConfig,
)
def get_spectrogram_scale(scale: str) -> Scales:
if scale == "pcen":
return PcenScaleConfig()
if scale == "log":
return LogScaleConfig()
return AmplitudeScaleConfig()
def get_preprocessing_config(params: dict) -> PreprocessingConfig:
return PreprocessingConfig(
audio=AudioConfig(
resample=ResampleConfig(
samplerate=params["target_samp_rate"],
mode="poly",
),
scale=params["scale_raw_audio"],
center=params["scale_raw_audio"],
duration=None,
),
spectrogram=SpectrogramConfig(
stft=STFTConfig(
window_duration=params["fft_win_length"],
window_overlap=params["fft_overlap"],
window_fn="hann",
),
frequencies=FrequencyConfig(
min_freq=params["min_freq"],
max_freq=params["max_freq"],
),
scale=get_spectrogram_scale(params["spec_scale"]),
denoise=params["denoise_spec_avg"],
size=SpecSizeConfig(
height=params["spec_height"],
resize_factor=params["resize_factor"],
),
max_scale=params["max_scale_spec"],
),
)
def get_training_preprocessing_config(
params: dict,
) -> TrainPreprocessingConfig:
generic = params["generic_class"][0]
preprocessing = get_preprocessing_config(params)
freq_bin_width, time_bin_width = get_spectrogram_resolution(
preprocessing.spectrogram
)
return TrainPreprocessingConfig(
preprocessing=preprocessing,
target=TargetConfig(
classes=[
TagInfo(key="class", value=class_name, label=class_name)
for class_name in params["class_names"]
],
generic_class=TagInfo(
key="class",
value=generic,
label=generic,
),
include=[
TagInfo(key="event", value=event)
for event in params["events_of_interest"]
],
exclude=[
TagInfo(key="class", value=value)
for value in params["classes_to_ignore"]
],
),
heatmaps=HeatmapsConfig(
position="bottom-left",
time_scale=1 / time_bin_width,
frequency_scale=1 / freq_bin_width,
sigma=params["target_sigma"],
),
)
# 'standardize_classs_names_ip',
# 'convert_to_genus',
# 'genus_mapping',
# 'standardize_classs_names',
# 'genus_names',
# ['data_dir',
# 'ann_dir',
# 'train_split',
# 'model_name',
# 'num_filters',
# 'experiment',
# 'model_file_name',
# 'op_im_dir',
# 'op_im_dir_test',
# 'notes',
# 'spec_divide_factor',
# 'detection_overlap',
# 'ignore_start_end',
# 'detection_threshold',
# 'nms_kernel_size',
# 'nms_top_k_per_sec',
# 'aug_prob',
# 'augment_at_train',
# 'augment_at_train_combine',
# 'echo_max_delay',
# 'stretch_squeeze_delta',
# 'mask_max_time_perc',
# 'mask_max_freq_perc',
# 'spec_amp_scaling',
# 'aug_sampling_rates',
# 'train_loss',
# 'det_loss_weight',
# 'size_loss_weight',
# 'class_loss_weight',
# 'individual_loss_weight',
# 'emb_dim',
# 'lr',
# 'batch_size',
# 'num_workers',
# 'num_epochs',
# 'num_eval_epochs',
# 'device',
# 'save_test_image_during_train',
# 'save_test_image_after_train',
# 'train_sets',
# 'test_sets',
# 'class_inv_freq',
# 'ip_height']
# from batdetect2.preprocess import (
# AmplitudeScaleConfig,
# AudioConfig,
# FrequencyConfig,
# LogScaleConfig,
# PcenConfig,
# PreprocessingConfig,
# ResampleConfig,
# Scales,
# SpecSizeConfig,
# SpectrogramConfig,
# STFTConfig,
# )
# from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
# from batdetect2.targets import (
# LabelConfig,
# TagInfo,
# TargetConfig,
# )
# from batdetect2.train.preprocess import (
# TrainPreprocessingConfig,
# )
#
#
# def get_spectrogram_scale(scale: str) -> Scales:
# if scale == "pcen":
# return PcenConfig()
# if scale == "log":
# return LogScaleConfig()
# return AmplitudeScaleConfig()
#
#
# def get_preprocessing_config(params: dict) -> PreprocessingConfig:
# return PreprocessingConfig(
# audio=AudioConfig(
# resample=ResampleConfig(
# samplerate=params["target_samp_rate"],
# method="poly",
# ),
# scale=params["scale_raw_audio"],
# center=params["scale_raw_audio"],
# duration=None,
# ),
# spectrogram=SpectrogramConfig(
# stft=STFTConfig(
# window_duration=params["fft_win_length"],
# window_overlap=params["fft_overlap"],
# window_fn="hann",
# ),
# frequencies=FrequencyConfig(
# min_freq=params["min_freq"],
# max_freq=params["max_freq"],
# ),
# scale=get_spectrogram_scale(params["spec_scale"]),
# spectral_mean_substraction=params["denoise_spec_avg"],
# size=SpecSizeConfig(
# height=params["spec_height"],
# resize_factor=params["resize_factor"],
# ),
# peak_normalize=params["max_scale_spec"],
# ),
# )
#
#
# def get_training_preprocessing_config(
# params: dict,
# ) -> TrainPreprocessingConfig:
# generic = params["generic_class"][0]
# preprocessing = get_preprocessing_config(params)
#
# freq_bin_width, time_bin_width = get_spectrogram_resolution(
# preprocessing.spectrogram
# )
#
# return TrainPreprocessingConfig(
# preprocessing=preprocessing,
# target=TargetConfig(
# classes=[
# TagInfo(key="class", value=class_name)
# for class_name in params["class_names"]
# ],
# generic_class=TagInfo(
# key="class",
# value=generic,
# ),
# include=[
# TagInfo(key="event", value=event)
# for event in params["events_of_interest"]
# ],
# exclude=[
# TagInfo(key="class", value=value)
# for value in params["classes_to_ignore"]
# ],
# ),
# labels=LabelConfig(
# position="bottom-left",
# time_scale=1 / time_bin_width,
# frequency_scale=1 / freq_bin_width,
# sigma=params["target_sigma"],
# ),
# )
#
#
# # 'standardize_classs_names_ip',
# # 'convert_to_genus',
# # 'genus_mapping',
# # 'standardize_classs_names',
# # 'genus_names',
#
# # ['data_dir',
# # 'ann_dir',
# # 'train_split',
# # 'model_name',
# # 'num_filters',
# # 'experiment',
# # 'model_file_name',
# # 'op_im_dir',
# # 'op_im_dir_test',
# # 'notes',
# # 'spec_divide_factor',
# # 'detection_overlap',
# # 'ignore_start_end',
# # 'detection_threshold',
# # 'nms_kernel_size',
# # 'nms_top_k_per_sec',
# # 'aug_prob',
# # 'augment_at_train',
# # 'augment_at_train_combine',
# # 'echo_max_delay',
# # 'stretch_squeeze_delta',
# # 'mask_max_time_perc',
# # 'mask_max_freq_perc',
# # 'spec_amp_scaling',
# # 'aug_sampling_rates',
# # 'train_loss',
# # 'det_loss_weight',
# # 'size_loss_weight',
# # 'class_loss_weight',
# # 'individual_loss_weight',
# # 'emb_dim',
# # 'lr',
# # 'batch_size',
# # 'num_workers',
# # 'num_epochs',
# # 'num_eval_epochs',
# # 'device',
# # 'save_test_image_during_train',
# # 'save_test_image_after_train',
# # 'train_sets',
# # 'test_sets',
# # 'class_inv_freq',
# # 'ip_height']


@@ -1,23 +1,105 @@
"""Provides base classes and utilities for loading configurations in BatDetect2.
This module leverages Pydantic for robust configuration handling, ensuring
that configuration files (typically YAML) adhere to predefined schemas. It
defines a base configuration class (`BaseConfig`) that enforces strict schema
validation and a utility function (`load_config`) to load and validate
configuration data from files, with optional support for accessing nested
configuration sections.
"""
from typing import Any, Optional, Type, TypeVar
import yaml
from pydantic import BaseModel, ConfigDict
from soundevent.data import PathLike
__all__ = [
"BaseConfig",
"load_config",
]
class BaseConfig(BaseModel):
"""Base class for all configuration models in BatDetect2.
Inherits from Pydantic's `BaseModel` to provide data validation, parsing,
and serialization capabilities.
It sets `extra='forbid'` in its model configuration, meaning that any
fields provided in a configuration file that are *not* explicitly defined
in the specific configuration schema will raise a validation error. This
helps catch typos and ensures configurations strictly adhere to the expected
structure.
Attributes
----------
model_config : ConfigDict
Pydantic model configuration dictionary. Set to forbid extra fields.
"""
model_config = ConfigDict(extra="forbid")
T = TypeVar("T", bound=BaseModel)
def get_object_field(obj: dict, field: str) -> Any:
if "." not in field:
return obj[field]
def get_object_field(obj: dict, current_key: str) -> Any:
"""Access a potentially nested field within a dictionary using dot notation.
Parameters
----------
obj : dict
The dictionary (or nested dictionaries) to access.
field : str
The field name to retrieve. Nested fields are specified using dots
(e.g., "parent_key.child_key.target_field").
Returns
-------
Any
The value found at the specified field path.
Raises
------
KeyError
If any part of the field path does not exist in the dictionary
structure.
TypeError
If an intermediate part of the path exists but is not a dictionary,
preventing further nesting.
Examples
--------
>>> data = {"a": {"b": {"c": 10}}}
>>> get_object_field(data, "a.b.c")
10
>>> get_object_field(data, "a.b")
{'c': 10}
>>> get_object_field(data, "a")
{'b': {'c': 10}}
>>> get_object_field(data, "x")
Traceback (most recent call last):
...
KeyError: 'x'
>>> get_object_field(data, "a.x.c")
Traceback (most recent call last):
...
KeyError: 'x'
"""
if "." not in current_key:
return obj[current_key]
current_key, rest = current_key.split(".", 1)
subobj = obj[current_key]
if not isinstance(subobj, dict):
raise TypeError(
f"Intermediate key '{current_key}' in path '{current_key}' "
f"does not lead to a dictionary (found type: {type(subobj)}). "
"Cannot access further nested field."
)
field, rest = field.split(".", 1)
subobj = obj[field]
return get_object_field(subobj, rest)
@@ -26,6 +108,49 @@ def load_config(
schema: Type[T],
field: Optional[str] = None,
) -> T:
"""Load and validate configuration data from a file against a schema.
Reads a YAML file, optionally extracts a specific section using dot
notation, and then validates the resulting data against the provided
Pydantic `schema`.
Parameters
----------
path : PathLike
The path to the configuration file (typically `.yaml`).
schema : Type[T]
The Pydantic `BaseModel` subclass that defines the expected structure
and types for the configuration data.
field : str, optional
A dot-separated string indicating a nested section within the YAML
file to extract before validation. If None (default), the entire
file content is validated against the schema.
Example: `"training.optimizer"` would extract the `optimizer` section
within the `training` section.
Returns
-------
T
An instance of the provided `schema`, populated and validated with
data from the configuration file.
Raises
------
FileNotFoundError
If the file specified by `path` does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data (after optionally extracting the
`field`) does not conform to the provided `schema` (e.g., missing
required fields, incorrect types, extra fields if using BaseConfig).
KeyError
If `field` is provided and specifies a path where intermediate keys
do not exist in the loaded YAML data.
TypeError
If `field` is provided and specifies a path where an intermediate
value is not a dictionary, preventing access to nested fields.
"""
with open(path, "r") as file:
config = yaml.safe_load(file)
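To ground the load_config docstring above, here is a small illustrative sketch: a hypothetical schema derived from BaseConfig and a call that selects a nested section with the dotted field argument. The schema and YAML contents are examples, not part of this changeset.

from batdetect2.configs import BaseConfig, load_config


class OptimizerConfig(BaseConfig):
    # Hypothetical schema. Because BaseConfig sets extra="forbid",
    # unexpected keys in the YAML raise a validation error.
    learning_rate: float = 1e-3
    weight_decay: float = 0.0


# Given a file `train.yaml` containing:
#
#   training:
#     optimizer:
#       learning_rate: 0.0005
#
# the nested section is selected with a dot-separated `field` path.
config = load_config("train.yaml", OptimizerConfig, field="training.optimizer")
print(config.learning_rate)  # 0.0005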


@@ -1,13 +1,22 @@
from batdetect2.data.annotations import (
AnnotatedDataset,
AOEFAnnotations,
BatDetect2FilesAnnotations,
BatDetect2MergedAnnotations,
load_annotated_dataset,
)
from batdetect2.data.data import load_dataset, load_dataset_from_config
from batdetect2.data.types import Dataset
from batdetect2.data.datasets import (
DatasetConfig,
load_dataset,
load_dataset_from_config,
)
__all__ = [
"AOEFAnnotations",
"AnnotatedDataset",
"Dataset",
"BatDetect2FilesAnnotations",
"BatDetect2MergedAnnotations",
"DatasetConfig",
"load_annotated_dataset",
"load_dataset",
"load_dataset_from_config",


@@ -1,36 +0,0 @@
import json
from pathlib import Path
from typing import Literal, Union
from batdetect2.configs import BaseConfig
__all__ = [
"AOEFAnnotationFile",
"AnnotationFormats",
"BatDetect2AnnotationFile",
"BatDetect2AnnotationFiles",
]
class BatDetect2AnnotationFiles(BaseConfig):
format: Literal["batdetect2"] = "batdetect2"
path: Path
class BatDetect2AnnotationFile(BaseConfig):
format: Literal["batdetect2_file"] = "batdetect2_file"
path: Path
class AOEFAnnotationFile(BaseConfig):
format: Literal["aoef"] = "aoef"
path: Path
AnnotationFormats = Union[
BatDetect2AnnotationFiles,
BatDetect2AnnotationFile,
AOEFAnnotationFile,
]


@@ -1,29 +1,44 @@
"""Handles loading of annotation data from various formats.
This module serves as the central dispatcher for parsing annotation data
associated with BatDetect2 datasets. Datasets can be composed of multiple
sources, each potentially using a different annotation format (e.g., the
standard AOEF/soundevent format, or legacy BatDetect2 formats).
This module defines the `AnnotationFormats` type, which represents the union
of possible configuration models for these different formats (each identified
by a unique `format` field). The primary function, `load_annotated_dataset`,
inspects the configuration for a single data source and calls the appropriate
format-specific loading function to retrieve the annotations as a standard
`soundevent.data.AnnotationSet`.
"""
from pathlib import Path
from typing import Optional, Union
from soundevent import data
from batdetect2.data.annotations.aeof import (
from batdetect2.data.annotations.aoef import (
AOEFAnnotations,
load_aoef_annotated_dataset,
)
from batdetect2.data.annotations.batdetect2_files import (
from batdetect2.data.annotations.batdetect2 import (
AnnotationFilter,
BatDetect2FilesAnnotations,
load_batdetect2_files_annotated_dataset,
)
from batdetect2.data.annotations.batdetect2_merged import (
BatDetect2MergedAnnotations,
load_batdetect2_files_annotated_dataset,
load_batdetect2_merged_annotated_dataset,
)
from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [
"load_annotated_dataset",
"AnnotatedDataset",
"AOEFAnnotations",
"AnnotatedDataset",
"AnnotationFilter",
"AnnotationFormats",
"BatDetect2FilesAnnotations",
"BatDetect2MergedAnnotations",
"AnnotationFormats",
"load_annotated_dataset",
]
@@ -32,12 +47,52 @@ AnnotationFormats = Union[
BatDetect2FilesAnnotations,
AOEFAnnotations,
]
"""Type Alias representing all supported data source configurations.
Each specific configuration model within this union (e.g., `AOEFAnnotations`,
`BatDetect2FilesAnnotations`) corresponds to a different annotation format
or storage structure. These models are typically discriminated by a `format`
field (e.g., `format="aoef"`, `format="batdetect2_files"`), allowing Pydantic
and functions like `load_annotated_dataset` to determine which format a given
source configuration represents.
"""
def load_annotated_dataset(
dataset: AnnotatedDataset,
base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
"""Load annotations for a single data source based on its configuration.
This function acts as a dispatcher. It inspects the type of the input
`dataset` object (which corresponds to a specific annotation format)
and calls the appropriate loading function (e.g.,
`load_aoef_annotated_dataset` for `AOEFAnnotations`).
Parameters
----------
dataset : AnnotationFormats
The configuration object for the data source, specifying its format
and necessary details (like paths). Must be an instance of one of the
types included in the `AnnotationFormats` union.
base_dir : Path, optional
An optional base directory path. If provided, relative paths within
the `dataset` configuration might be resolved relative to this
directory by the underlying loading functions. Defaults to None.
Returns
-------
soundevent.data.AnnotationSet
An AnnotationSet containing the `ClipAnnotation` objects loaded and
parsed from the specified data source.
Raises
------
NotImplementedError
If the type of the `source_config` object does not match any of the
known format-specific loading functions implemented in the dispatch
logic.
"""
if isinstance(dataset, AOEFAnnotations):
return load_aoef_annotated_dataset(dataset, base_dir=base_dir)
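As a concrete illustration of the dispatch described above, a hedged sketch of defining one AOEF source and loading it; the names and paths are placeholders, not part of the diff.

from pathlib import Path

from batdetect2.data.annotations import AOEFAnnotations, load_annotated_dataset

# Illustrative source definition; paths are placeholders.
source = AOEFAnnotations(
    name="example-source",
    audio_dir=Path("audio"),
    annotations_path=Path("annotations.json"),
)

# The dispatcher inspects the config type, calls the matching loader,
# and returns a soundevent AnnotationSet.
annotation_set = load_annotated_dataset(source, base_dir=Path("/data/project"))
print(len(annotation_set.clip_annotations))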


@@ -1,37 +0,0 @@
from pathlib import Path
from typing import Literal, Optional
from soundevent import data, io
from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [
"AOEFAnnotations",
"load_aoef_annotated_dataset",
]
class AOEFAnnotations(AnnotatedDataset):
format: Literal["aoef"] = "aoef"
annotations_path: Path
def load_aoef_annotated_dataset(
dataset: AOEFAnnotations,
base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
audio_dir = dataset.audio_dir
path = dataset.annotations_path
if base_dir:
audio_dir = base_dir / audio_dir
path = base_dir / path
loaded = io.load(path, audio_dir=audio_dir)
if not isinstance(loaded, (data.AnnotationSet, data.AnnotationProject)):
raise ValueError(
f"The AOEF file at {path} does not contain a set of annotations"
)
return loaded


@@ -0,0 +1,270 @@
"""Loads annotation data specifically from the AOEF / soundevent format.
This module provides the necessary configuration model and loading function
to handle data sources where annotations are stored in the standard format
used by the `soundevent` library (often as `.json` or `.aoef` files),
which includes outputs from annotation tools like Whombat.
It supports loading both simple `AnnotationSet` files and more complex
`AnnotationProject` files. For `AnnotationProject` files, it offers optional
filtering capabilities to select only annotations associated with tasks
that meet specific status criteria (e.g., completed, verified, without issues).
"""
from pathlib import Path
from typing import Literal, Optional
from uuid import uuid5
from pydantic import Field
from soundevent import data, io
from batdetect2.configs import BaseConfig
from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [
"AOEFAnnotations",
"load_aoef_annotated_dataset",
"AnnotationTaskFilter",
]
class AnnotationTaskFilter(BaseConfig):
"""Configuration for filtering Annotation Tasks within an AnnotationProject.
Specifies criteria based on task status badges to select relevant
annotations, typically used when loading data from annotation projects
that might contain work-in-progress.
Attributes
----------
only_completed : bool, default=True
If True, only include annotations from tasks marked as 'completed'.
only_verified : bool, default=False
If True, only include annotations from tasks marked as 'verified'.
exclude_issues : bool, default=True
If True, exclude annotations from tasks marked as 'rejected' (indicating
issues).
"""
only_completed: bool = True
only_verified: bool = False
exclude_issues: bool = True
class AOEFAnnotations(AnnotatedDataset):
"""Configuration defining a data source stored in AOEF format.
This model specifies how to load annotations from an AOEF (JSON) file
compatible with the `soundevent` library. It inherits `name`,
`description`, and `audio_dir` from `AnnotatedDataset`.
Attributes
----------
format : Literal["aoef"]
The fixed format identifier for this configuration type.
annotations_path : Path
The file system path to the `.aoef` or `.json` file containing the
`AnnotationSet` or `AnnotationProject`.
filter : AnnotationTaskFilter, optional
Configuration for filtering tasks if the `annotations_path` points to
an `AnnotationProject`. If omitted, default filtering
(only completed, exclude issues, verification not required) is applied
to projects. Set explicitly to `None` in config (e.g., `filter: null`)
to disable filtering for projects entirely.
"""
format: Literal["aoef"] = "aoef"
annotations_path: Path
filter: Optional[AnnotationTaskFilter] = Field(
default_factory=AnnotationTaskFilter
)
def load_aoef_annotated_dataset(
dataset: AOEFAnnotations,
base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
"""Load annotations from an AnnotationSet or AnnotationProject file.
Reads the file specified in the `dataset` configuration using
`soundevent.io.load`. If the loaded file contains an `AnnotationProject`
and filtering is enabled via `dataset.filter`, it applies the filter
criteria based on task status and returns a new `AnnotationSet` containing
only the selected annotations. If the file contains an `AnnotationSet`,
or if it is a project and filtering is disabled, all annotations are
returned.
Parameters
----------
dataset : AOEFAnnotations
The configuration object describing the AOEF data source, including
the path to the annotation file and optional filtering settings.
base_dir : Path, optional
An optional base directory. If provided, `dataset.annotations_path`
and `dataset.audio_dir` will be resolved relative to this
directory. Defaults to None.
Returns
-------
soundevent.data.AnnotationSet
An AnnotationSet containing the loaded (and potentially filtered)
`ClipAnnotation` objects.
Raises
------
FileNotFoundError
If the specified `annotations_path` (after resolving `base_dir`)
does not exist.
ValueError
If the loaded file does not contain a valid `AnnotationSet` or
`AnnotationProject`.
Exception
May re-raise errors from `soundevent.io.load` related to parsing
or file format issues.
Notes
-----
- The `soundevent` library handles parsing of `.json` or `.aoef` formats.
- If an `AnnotationProject` is loaded and `dataset.filter` is *not* None,
a *new* `AnnotationSet` instance is created containing only the filtered
clip annotations.
"""
audio_dir = dataset.audio_dir
path = dataset.annotations_path
if base_dir:
audio_dir = base_dir / audio_dir
path = base_dir / path
loaded = io.load(path, audio_dir=audio_dir)
if not isinstance(loaded, (data.AnnotationSet, data.AnnotationProject)):
raise ValueError(
f"The file at {path} loaded successfully but does not "
"contain a soundevent AnnotationSet or AnnotationProject "
f"(loaded type: {type(loaded).__name__})."
)
if isinstance(loaded, data.AnnotationProject) and dataset.filter:
loaded = filter_ready_clips(
loaded,
only_completed=dataset.filter.only_completed,
only_verified=dataset.filter.only_verified,
exclude_issues=dataset.filter.exclude_issues,
)
return loaded
def select_task(
annotation_task: data.AnnotationTask,
only_completed: bool = True,
only_verified: bool = False,
exclude_issues: bool = True,
) -> bool:
"""Check if an AnnotationTask meets specified status criteria.
Evaluates the `status_badges` of the task against the filter flags.
Parameters
----------
annotation_task : data.AnnotationTask
The annotation task to check.
only_completed : bool, default=True
Task must be marked 'completed' to pass.
only_verified : bool, default=False
Task must be marked 'verified' to pass.
exclude_issues : bool, default=True
Task must *not* be marked 'rejected' (have issues) to pass.
Returns
-------
bool
True if the task meets all active filter criteria, False otherwise.
"""
has_issues = False
is_completed = False
is_verified = False
for badge in annotation_task.status_badges:
if badge.state == data.AnnotationState.completed:
is_completed = True
continue
if badge.state == data.AnnotationState.rejected:
has_issues = True
continue
if badge.state == data.AnnotationState.verified:
is_verified = True
if exclude_issues and has_issues:
return False
if only_verified and not is_verified:
return False
if only_completed and not is_completed:
return False
return True
def filter_ready_clips(
annotation_project: data.AnnotationProject,
only_completed: bool = True,
only_verified: bool = False,
exclude_issues: bool = True,
) -> data.AnnotationSet:
"""Filter AnnotationProject to create an AnnotationSet of 'ready' clips.
Iterates through tasks in the project, selects tasks meeting the status
criteria using `select_task`, and creates a new `AnnotationSet` containing
only the `ClipAnnotation` objects associated with those selected tasks.
Parameters
----------
annotation_project : data.AnnotationProject
The input annotation project.
only_completed : bool, default=True
Filter flag passed to `select_task`.
only_verified : bool, default=False
Filter flag passed to `select_task`.
exclude_issues : bool, default=True
Filter flag passed to `select_task`.
Returns
-------
data.AnnotationSet
A new annotation set containing only the clip annotations linked to
tasks that satisfied the filtering criteria. The returned set has a
deterministic UUID based on the project UUID and filter settings.
"""
ready_clip_uuids = set()
for annotation_task in annotation_project.tasks:
if not select_task(
annotation_task,
only_completed=only_completed,
only_verified=only_verified,
exclude_issues=exclude_issues,
):
continue
ready_clip_uuids.add(annotation_task.clip.uuid)
return data.AnnotationSet(
uuid=uuid5(
annotation_project.uuid,
f"{only_completed}_{only_verified}_{exclude_issues}",
),
name=annotation_project.name,
description=annotation_project.description,
clip_annotations=[
annotation
for annotation in annotation_project.clip_annotations
if annotation.clip.uuid in ready_clip_uuids
],
)
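To make the task filtering behaviour above concrete, a small hedged sketch contrasting a source that requires verified tasks with one that disables filtering altogether; paths and names are illustrative.

from pathlib import Path

from batdetect2.data.annotations.aoef import (
    AnnotationTaskFilter,
    AOEFAnnotations,
    load_aoef_annotated_dataset,
)

# Keep only annotations from tasks that are completed and verified,
# and skip tasks flagged with issues.
strict_source = AOEFAnnotations(
    name="verified-only",
    audio_dir=Path("audio"),
    annotations_path=Path("project.aoef"),
    filter=AnnotationTaskFilter(only_completed=True, only_verified=True),
)

# Disable task filtering entirely (the YAML equivalent is `filter: null`),
# so every clip annotation in the project is kept.
unfiltered_source = AOEFAnnotations(
    name="everything",
    audio_dir=Path("audio"),
    annotations_path=Path("project.aoef"),
    filter=None,
)

annotations = load_aoef_annotated_dataset(strict_source, base_dir=Path("/data"))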


@@ -0,0 +1,328 @@
"""Loads annotation data from legacy BatDetect2 JSON formats.
This module provides backward compatibility for loading annotation data stored
in two related formats used by older BatDetect2 tools:
1. **`batdetect2` format** (Directory-based): Annotations are stored in
individual JSON files (one per audio recording) within a specified
directory.
Each JSON file contains a `FileAnnotation` structure. Configured by
`BatDetect2FilesAnnotations` and loaded via
`load_batdetect2_files_annotated_dataset`.
2. **`batdetect2_file` format** (Single-file): Annotations for multiple
recordings are merged into a single JSON file containing a list of
`FileAnnotation` objects. Configured by `BatDetect2MergedAnnotations`
and loaded via `load_batdetect2_merged_annotated_dataset`.
Both formats use the same internal structure for annotations per file and
support filtering based on `annotated` and `issues` flags within that
structure.
The loading functions convert data from these legacy formats into the modern
`soundevent` data model (primarily `ClipAnnotation`) and return the results
aggregated into a `soundevent.data.AnnotationSet`.
"""
import json
import os
from pathlib import Path
from typing import Literal, Optional, Union
from loguru import logger
from pydantic import Field, ValidationError
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.data.annotations.legacy import (
FileAnnotation,
file_annotation_to_clip,
file_annotation_to_clip_annotation,
list_file_annotations,
load_file_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset
PathLike = Union[Path, str, os.PathLike]
__all__ = [
"load_batdetect2_files_annotated_dataset",
"load_batdetect2_merged_annotated_dataset",
"BatDetect2FilesAnnotations",
"BatDetect2MergedAnnotations",
"AnnotationFilter",
]
class AnnotationFilter(BaseConfig):
"""Configuration for filtering legacy FileAnnotations based on flags.
Specifies criteria based on boolean flags (`annotated` and `issues`)
present within the legacy `FileAnnotation` JSON structure to select which
entries (either files or records within a merged file) should be loaded and
converted.
Attributes
----------
only_annotated : bool, default=True
If True, only process entries where the `annotated` flag in the JSON
is set to `True`.
exclude_issues : bool, default=True
If True, skip processing entries where the `issues` flag in the JSON
is set to `True`.
"""
only_annotated: bool = True
exclude_issues: bool = True
class BatDetect2FilesAnnotations(AnnotatedDataset):
"""Configuration for the legacy 'batdetect2' format (directory-based).
Defines a data source where annotations are stored as individual JSON files
(one per recording, containing a `FileAnnotation` structure) within the
`annotations_dir`. Requires a corresponding `audio_dir`. Assumes a naming
convention links audio files to JSON files
(e.g., `rec.wav` -> `rec.wav.json`).
Attributes
----------
format : Literal["batdetect2"]
The fixed format identifier for this configuration type.
annotations_dir : Path
Path to the directory containing the individual JSON annotation files.
filter : AnnotationFilter, optional
Configuration for filtering which files to process based on their
`annotated` and `issues` flags. Defaults to requiring `annotated=True`
and `issues=False`. Set explicitly to `None` in config (e.g.,
`filter: null`) to disable filtering.
"""
format: Literal["batdetect2"] = "batdetect2"
annotations_dir: Path
filter: Optional[AnnotationFilter] = Field(
default_factory=AnnotationFilter,
)
class BatDetect2MergedAnnotations(AnnotatedDataset):
"""Configuration for the legacy 'batdetect2_file' format (merged file).
Defines a data source where annotations for multiple recordings (each as a
`FileAnnotation` structure) are stored within a single JSON file specified
by `annotations_path`. Audio files are expected in `audio_dir`.
Inherits `name`, `description`, and `audio_dir` from `AnnotatedDataset`.
Attributes
----------
format : Literal["batdetect2_file"]
The fixed format identifier for this configuration type.
annotations_path : Path
Path to the single JSON file containing a list of `FileAnnotation`
objects.
filter : AnnotationFilter, optional
Configuration for filtering which `FileAnnotation` entries within the
merged file to process based on their `annotated` and `issues` flags.
Defaults to requiring `annotated=True` and `issues=False`. Set to `None`
in config (e.g., `filter: null`) to disable filtering.
"""
format: Literal["batdetect2_file"] = "batdetect2_file"
annotations_path: Path
filter: Optional[AnnotationFilter] = Field(
default_factory=AnnotationFilter,
)
def load_batdetect2_files_annotated_dataset(
dataset: BatDetect2FilesAnnotations,
base_dir: Optional[PathLike] = None,
) -> data.AnnotationSet:
"""Load and convert 'batdetect2_file' annotations into an AnnotationSet.
Scans the specified `annotations_dir` for individual JSON annotation files.
For each file: loads the legacy `FileAnnotation`, applies filtering based
on `dataset.filter` (`annotated`/`issues` flags), attempts to find the
corresponding audio file, converts valid entries to `ClipAnnotation`, and
collects them into a single `soundevent.data.AnnotationSet`.
Parameters
----------
dataset : BatDetect2FilesAnnotations
Configuration describing the 'batdetect2' (directory) data source.
base_dir : PathLike, optional
Optional base directory to resolve relative paths in `dataset.audio_dir`
and `dataset.annotations_dir`. Defaults to None.
Returns
-------
soundevent.data.AnnotationSet
An AnnotationSet containing all successfully loaded, filtered, and
converted `ClipAnnotation` objects.
Raises
------
FileNotFoundError
If the `annotations_dir` or `audio_dir` does not exist. Errors finding
individual JSON or audio files during iteration are logged and skipped.
"""
audio_dir = dataset.audio_dir
path = dataset.annotations_dir
if base_dir:
audio_dir = base_dir / audio_dir
path = base_dir / path
paths = list_file_annotations(path)
logger.debug(
"Found {num_files} files in the annotations directory {path}",
num_files=len(paths),
path=path,
)
annotations = []
for p in paths:
try:
file_annotation = load_file_annotation(p)
except (FileNotFoundError, ValidationError):
logger.warning("Could not load annotations in file {path}", path=p)
continue
if (
dataset.filter
and dataset.filter.only_annotated
and not file_annotation.annotated
):
logger.debug(
"Annotation in file {path} omited: not annotated",
path=p,
)
continue
if (
dataset.filter
and dataset.filter.exclude_issues
and file_annotation.issues
):
logger.debug(
"Annotation in file {path} omited: has issues",
path=p,
)
continue
try:
clip = file_annotation_to_clip(
file_annotation,
audio_dir=audio_dir,
)
except FileNotFoundError as err:
logger.warning(
"Did not find the audio related to the annotation file {path}. Error: {err}",
path=p,
err=err,
)
continue
annotations.append(
file_annotation_to_clip_annotation(
file_annotation,
clip,
)
)
return data.AnnotationSet(
name=dataset.name,
description=dataset.description,
clip_annotations=annotations,
)
def load_batdetect2_merged_annotated_dataset(
dataset: BatDetect2MergedAnnotations,
base_dir: Optional[PathLike] = None,
) -> data.AnnotationSet:
"""Load and convert 'batdetect2_merged' annotations into an AnnotationSet.
Loads a single JSON file containing a list of legacy `FileAnnotation`
objects. For each entry in the list: applies filtering based on
`dataset.filter` (`annotated`/`issues` flags), attempts to find the
corresponding audio file, converts valid entries to `ClipAnnotation`, and
collects them into a single `soundevent.data.AnnotationSet`.
Parameters
----------
dataset : BatDetect2MergedAnnotations
Configuration describing the 'batdetect2_file' (merged) data source.
base_dir : PathLike, optional
Optional base directory to resolve relative paths in `dataset.audio_dir`
and `dataset.annotations_path`. Defaults to None.
Returns
-------
soundevent.data.AnnotationSet
An AnnotationSet containing all successfully loaded, filtered, and
converted `ClipAnnotation` objects from the merged file.
Raises
------
FileNotFoundError
If the `annotations_path` or `audio_dir` does not exist. Errors
finding individual audio files referenced within the JSON are logged
and skipped.
json.JSONDecodeError
If the annotations file is not valid JSON.
TypeError
If the root JSON structure is not a list.
pydantic.ValidationError
If entries within the JSON list do not conform to the legacy
`FileAnnotation` structure.
"""
audio_dir = dataset.audio_dir
path = dataset.annotations_path
if base_dir:
audio_dir = base_dir / audio_dir
path = base_dir / path
content = json.loads(Path(path).read_text())
if not isinstance(content, list):
raise TypeError(
f"Expected a list of FileAnnotations, but got {type(content)}",
)
annotations = []
for ann in content:
try:
ann = FileAnnotation.model_validate(ann)
except ValueError:
continue
if (
dataset.filter
and dataset.filter.only_annotated
and not ann.annotated
):
continue
if dataset.filter and dataset.filter.exclude_issues and ann.issues:
continue
try:
clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
except FileNotFoundError:
continue
annotations.append(file_annotation_to_clip_annotation(ann, clip))
return data.AnnotationSet(
name=dataset.name,
description=dataset.description,
clip_annotations=annotations,
)
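For orientation, a hedged sketch of configuring and loading both legacy formats documented above; directory layout and file names are placeholders.

from pathlib import Path

from batdetect2.data.annotations.batdetect2 import (
    AnnotationFilter,
    BatDetect2FilesAnnotations,
    BatDetect2MergedAnnotations,
    load_batdetect2_files_annotated_dataset,
    load_batdetect2_merged_annotated_dataset,
)

# Directory-based legacy format: one JSON annotation file per recording.
files_source = BatDetect2FilesAnnotations(
    name="legacy-files",
    audio_dir=Path("audio"),
    annotations_dir=Path("anns"),
    # Keep annotated entries even if they were flagged with issues.
    filter=AnnotationFilter(only_annotated=True, exclude_issues=False),
)
files_set = load_batdetect2_files_annotated_dataset(files_source, base_dir=Path("/data"))

# Merged legacy format: a single JSON file with a list of FileAnnotations.
merged_source = BatDetect2MergedAnnotations(
    name="legacy-merged",
    audio_dir=Path("audio"),
    annotations_path=Path("merged_annotations.json"),
)
merged_set = load_batdetect2_merged_annotated_dataset(merged_source, base_dir=Path("/data"))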


@@ -1,80 +0,0 @@
import os
from pathlib import Path
from typing import Literal, Optional, Union
from soundevent import data
from batdetect2.data.annotations.legacy import (
file_annotation_to_annotation_task,
file_annotation_to_clip,
file_annotation_to_clip_annotation,
list_file_annotations,
load_file_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset
PathLike = Union[Path, str, os.PathLike]
__all__ = [
"load_batdetect2_files_annotated_dataset",
"BatDetect2FilesAnnotations",
]
class BatDetect2FilesAnnotations(AnnotatedDataset):
format: Literal["batdetect2"] = "batdetect2"
annotations_dir: Path
def load_batdetect2_files_annotated_dataset(
dataset: BatDetect2FilesAnnotations,
base_dir: Optional[PathLike] = None,
) -> data.AnnotationProject:
"""Convert annotations to annotation project."""
audio_dir = dataset.audio_dir
path = dataset.annotations_dir
if base_dir:
audio_dir = base_dir / audio_dir
path = base_dir / path
paths = list_file_annotations(path)
annotations = []
tasks = []
for p in paths:
try:
file_annotation = load_file_annotation(p)
except FileNotFoundError:
continue
try:
clip = file_annotation_to_clip(
file_annotation,
audio_dir=audio_dir,
)
except FileNotFoundError:
continue
annotations.append(
file_annotation_to_clip_annotation(
file_annotation,
clip,
)
)
tasks.append(
file_annotation_to_annotation_task(
file_annotation,
clip,
)
)
return data.AnnotationProject(
name=dataset.name,
description=dataset.description,
clip_annotations=annotations,
tasks=tasks,
)


@@ -1,64 +0,0 @@
import json
import os
from pathlib import Path
from typing import Literal, Optional, Union
from soundevent import data
from batdetect2.data.annotations.legacy import (
FileAnnotation,
file_annotation_to_annotation_task,
file_annotation_to_clip,
file_annotation_to_clip_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset
PathLike = Union[Path, str, os.PathLike]
__all__ = [
"BatDetect2MergedAnnotations",
"load_batdetect2_merged_annotated_dataset",
]
class BatDetect2MergedAnnotations(AnnotatedDataset):
format: Literal["batdetect2_file"] = "batdetect2_file"
annotations_path: Path
def load_batdetect2_merged_annotated_dataset(
dataset: BatDetect2MergedAnnotations,
base_dir: Optional[PathLike] = None,
) -> data.AnnotationProject:
audio_dir = dataset.audio_dir
path = dataset.annotations_path
if base_dir:
audio_dir = base_dir / audio_dir
path = base_dir / path
content = json.loads(Path(path).read_text())
annotations = []
tasks = []
for ann in content:
try:
ann = FileAnnotation.model_validate(ann)
except ValueError:
continue
try:
clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
except FileNotFoundError:
continue
annotations.append(file_annotation_to_clip_annotation(ann, clip))
tasks.append(file_annotation_to_annotation_task(ann, clip))
return data.AnnotationProject(
name=dataset.name,
description=dataset.description,
clip_annotations=annotations,
tasks=tasks,
)


@@ -1,11 +1,9 @@
from pathlib import Path
from typing import Literal, Union
from batdetect2.configs import BaseConfig
__all__ = [
"AnnotatedDataset",
"BatDetect2MergedAnnotations",
]
@@ -19,11 +17,10 @@ class AnnotatedDataset(BaseConfig):
Annotations associated with these recordings are defined by the
`annotations` field, which supports various formats (e.g., AOEF files,
specific CSV
structures).
Crucially, file paths referenced within the annotation data *must* be
relative to the `audio_dir`. This ensures that the dataset definition
remains portable across different systems and base directories.
specific CSV structures). Crucially, file paths referenced within the
annotation data *must* be relative to the `audio_dir`. This ensures that
the dataset definition remains portable across different systems and base
directories.
Attributes:
name: A unique identifier for this data source.
@@ -37,5 +34,3 @@ class AnnotatedDataset(BaseConfig):
name: str
audio_dir: Path
description: str = ""


@@ -1,37 +0,0 @@
from pathlib import Path
from typing import Optional
from soundevent import data
from batdetect2.configs import load_config
from batdetect2.data.annotations import load_annotated_dataset
from batdetect2.data.types import Dataset
__all__ = [
"load_dataset",
"load_dataset_from_config",
]
def load_dataset(
dataset: Dataset,
base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
clip_annotations = []
for source in dataset.sources:
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
clip_annotations.extend(annotated_source.clip_annotations)
return data.AnnotationSet(clip_annotations=clip_annotations)
def load_dataset_from_config(
path: data.PathLike,
field: Optional[str] = None,
base_dir: Optional[Path] = None,
):
config = load_config(
path=path,
schema=Dataset,
field=field,
)
return load_dataset(config, base_dir=base_dir)

batdetect2/data/datasets.py (new file): 251 lines

@@ -0,0 +1,251 @@
"""Defines the overall dataset structure and provides loading/saving utilities.
This module focuses on defining what constitutes a BatDetect2 dataset,
potentially composed of multiple distinct data sources with varying annotation
formats. It provides mechanisms to load the annotation metadata from these
sources into a unified representation.
The core components are:
- `DatasetConfig`: A configuration class (typically loaded from YAML) that
describes the dataset's name, description, and constituent sources.
- `Dataset`: A type alias representing the loaded dataset as a list of
`soundevent.data.ClipAnnotation` objects. Note that this implies all
annotation metadata is loaded into memory.
- Loading functions (`load_dataset`, `load_dataset_from_config`): To parse
a `DatasetConfig` and load the corresponding annotation metadata.
- Saving function (`save_dataset`): To save a loaded list of annotations
into a standard `soundevent` format.
"""
from pathlib import Path
from typing import Annotated, List, Optional
from loguru import logger
from pydantic import Field
from soundevent import data, io
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.annotations import (
AnnotatedDataset,
AnnotationFormats,
load_annotated_dataset,
)
from batdetect2.targets.terms import data_source
__all__ = [
"load_dataset",
"load_dataset_from_config",
"save_dataset",
"Dataset",
"DatasetConfig",
]
Dataset = List[data.ClipAnnotation]
"""Type alias for a loaded dataset representation.
Represents an entire dataset *after loading* as a flat Python list containing
all `soundevent.data.ClipAnnotation` objects gathered from all configured data
sources.
"""
class DatasetConfig(BaseConfig):
"""Configuration model defining the structure of a BatDetect2 dataset.
This class is typically loaded from a YAML file and describes the components
of the dataset, including metadata and a list of data sources.
Attributes
----------
name : str
A descriptive name for the dataset (e.g., "UK_Bats_Project_2024").
description : str
A longer description of the dataset's contents, origin, purpose, etc.
sources : List[AnnotationFormats]
A list defining the different data sources contributing to this
dataset. Each item in the list must conform to one of the Pydantic
models defined in the `AnnotationFormats` type union. The specific
model used for each source is determined by the mandatory `format`
field within the source's configuration, allowing BatDetect2 to use the
correct parser for different annotation styles.
"""
name: str
description: str
sources: List[
Annotated[AnnotationFormats, Field(..., discriminator="format")]
]
def load_dataset(
dataset: DatasetConfig,
base_dir: Optional[Path] = None,
) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig.
Iterates through each data source specified in the `dataset` configuration,
delegates the loading and parsing of that source's annotations to
`batdetect2.data.annotations.load_annotated_dataset` (which handles
different data formats), and aggregates all resulting `ClipAnnotation`
objects into a single flat list.
Parameters
----------
dataset : DatasetConfig
The configuration object describing the dataset and its sources.
base_dir : Path, optional
An optional base directory path. If provided, relative paths for
metadata files or data directories within the `dataset`'s sources
might be resolved relative to this directory. Defaults to None.
Returns
-------
Dataset (List[data.ClipAnnotation])
A flat list containing all loaded `ClipAnnotation` metadata objects
from all specified sources.
Raises
------
Exception
Can raise various exceptions during the delegated loading process
(`load_annotated_dataset`) if files are not found, cannot be parsed
according to the specified format, or other I/O errors occur.
"""
clip_annotations = []
for source in dataset.sources:
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
logger.debug(
"Loaded {num_examples} from dataset source '{source_name}'",
num_examples=len(annotated_source.clip_annotations),
source_name=source.name,
)
clip_annotations.extend(
insert_source_tag(clip_annotation, source)
for clip_annotation in annotated_source.clip_annotations
)
return clip_annotations
def insert_source_tag(
clip_annotation: data.ClipAnnotation,
source: AnnotatedDataset,
) -> data.ClipAnnotation:
"""Insert the source tag into a ClipAnnotation.
This function adds a tag to the `ClipAnnotation` object, indicating the
source from which it was loaded. The tag value is taken from the `name`
of the source configuration.
Parameters
----------
clip_annotation : data.ClipAnnotation
The `ClipAnnotation` object to which the source tag will be added.
source : AnnotatedDataset
The data source configuration whose `name` is stored in the tag.
Returns
-------
data.ClipAnnotation
The modified `ClipAnnotation` object with the source tag added.
"""
return clip_annotation.model_copy(
update=dict(
tags=[
*clip_annotation.tags,
data.Tag(
term=data_source,
value=source.name,
),
]
),
)
def load_dataset_from_config(
path: data.PathLike,
field: Optional[str] = None,
base_dir: Optional[Path] = None,
) -> Dataset:
"""Load dataset annotation metadata from a configuration file.
This is a convenience function that first loads the `DatasetConfig` from
the specified file path and optional nested field, and then calls
`load_dataset` to load all corresponding `ClipAnnotation` objects.
Parameters
----------
path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing the
dataset configuration (e.g., "data.training_set"). If None, the
entire file content is assumed to be the `DatasetConfig`.
base_dir : Path, optional
An optional base directory path to resolve relative paths within the
configuration sources. Passed to `load_dataset`. Defaults to None.
Returns
-------
Dataset (List[data.ClipAnnotation])
A flat list containing all loaded `ClipAnnotation` metadata objects.
Raises
------
FileNotFoundError
If the config file `path` does not exist.
yaml.YAMLError, pydantic.ValidationError, KeyError, TypeError
If the configuration file is invalid, cannot be parsed, or does not
match the `DatasetConfig` schema.
Exception
Can raise exceptions from `load_dataset` if loading data from sources
fails.
"""
config = load_config(
path=path,
schema=DatasetConfig,
field=field,
)
return load_dataset(config, base_dir=base_dir)
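As a usage sketch (the file path, `field` value, and base directory are illustrative), loading the training split described in a project configuration might look like:
from pathlib import Path

train_clips = load_dataset_from_config(
    "config.yaml",
    field="datasets.train",
    base_dir=Path("/data/bat-recordings"),
)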
def save_dataset(
dataset: Dataset,
path: data.PathLike,
name: Optional[str] = None,
description: Optional[str] = None,
audio_dir: Optional[Path] = None,
) -> None:
"""Save a loaded dataset (list of ClipAnnotations) to a file.
Wraps the provided list of `ClipAnnotation` objects into a
`soundevent.data.AnnotationSet` and saves it using `soundevent.io.save`.
This saves the aggregated annotation metadata in the standard soundevent
format.
Note: This function saves the *loaded annotation data*, not the original
`DatasetConfig` structure that defined how the data was assembled from
various sources.
Parameters
----------
dataset : Dataset (List[data.ClipAnnotation])
The list of clip annotations to save (typically the result of
`load_dataset` or a split thereof).
path : data.PathLike
The output file path (e.g., 'train_annotations.json',
'val_annotations.json'). The format is determined by `soundevent.io`.
name : str, optional
An optional name to assign to the saved `AnnotationSet`.
description : str, optional
An optional description to assign to the saved `AnnotationSet`.
audio_dir : Path, optional
Passed to `soundevent.io.save`. May be used to relativize audio file
paths within the saved annotations if applicable to the save format.
"""
annotation_set = data.AnnotationSet(
name=name,
description=description,
clip_annotations=dataset,
)
io.save(annotation_set, path, audio_dir=audio_dir)
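A matching sketch for the saving side, persisting a loaded split as a soundevent `AnnotationSet` (file name and metadata are illustrative):
save_dataset(
    train_clips,
    "train_annotations.json",
    name="training-set",
    description="Aggregated training annotations",
)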

View File

@ -1,29 +0,0 @@
from typing import Annotated, List
from pydantic import Field
from batdetect2.configs import BaseConfig
from batdetect2.data.annotations import AnnotationFormats
class Dataset(BaseConfig):
"""Represents a collection of one or more DatasetSources.
In the context of batdetect2, a Dataset aggregates multiple `DatasetSource`
instances. It serves as the primary unit for defining data splits,
typically used for model training, validation, or testing phases.
Attributes:
name: A descriptive name for the overall dataset
(e.g., "UK Training Set").
description: A detailed explanation of the dataset's purpose,
composition, how it was assembled, or any specific characteristics.
sources: A list containing the `DatasetSource` objects included in this
dataset.
"""
name: str
description: str
sources: List[
Annotated[AnnotationFormats, Field(..., discriminator="format")]
]

View File

@ -0,0 +1,9 @@
from batdetect2.evaluate.evaluate import (
compute_error_auc,
match_predictions_and_annotations,
)
__all__ = [
"compute_error_auc",
"match_predictions_and_annotations",
]

View File

@ -1,13 +1,10 @@
from typing import List
import numpy as np
import pandas as pd
from sklearn.metrics import auc, roc_curve
from soundevent import data
from soundevent.evaluation import match_geometries
from batdetect2.train.targets import build_target_encoder, get_class_names
def match_predictions_and_annotations(
clip_annotation: data.ClipAnnotation,
@ -51,20 +48,13 @@ def match_predictions_and_annotations(
return matches
def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
ret = []
for match in matches:
pass
def compute_error_auc(op_str, gt, pred, prob):
# classification error
pred_int = (pred > prob).astype(np.int32)
class_acc = (pred_int == gt).mean() * 100.0
# ROC - area under curve
fpr, tpr, thresholds = roc_curve(gt, pred)
fpr, tpr, _ = roc_curve(gt, pred)
roc_auc = auc(fpr, tpr)
print(
@ -177,7 +167,7 @@ def compute_pre_rec(
file_ids.append([pid] * valid_inds.sum())
confidence = np.hstack(confidence)
file_ids = np.hstack(file_ids).astype(np.int)
file_ids = np.hstack(file_ids).astype(int)
pred_boxes = np.vstack(pred_boxes)
if len(pred_class) > 0:
pred_class = np.hstack(pred_class)
@ -197,8 +187,7 @@ def compute_pre_rec(
# note, files with the incorrect duration will cause a problem
if (gg["start_times"] > file_dur).sum() > 0:
print("Error: file duration incorrect for", gg["id"])
assert False
raise ValueError(f"Error: file duration incorrect for {gg['id']}")
boxes = np.vstack(
(
@ -244,6 +233,8 @@ def compute_pre_rec(
gt_id = file_ids[ind]
valid_det = False
det_ind = 0
if gt_boxes[gt_id].shape[0] > 0:
# compute overlap
valid_det, det_ind = compute_affinity_1d(
@ -273,7 +264,7 @@ def compute_pre_rec(
# store threshold values - used for plotting
conf_sorted = np.sort(confidence)[::-1][valid_inds]
thresholds = np.linspace(0.1, 0.9, 9)
thresholds_inds = np.zeros(len(thresholds), dtype=np.int)
thresholds_inds = np.zeros(len(thresholds), dtype=int)
for ii, tt in enumerate(thresholds):
thresholds_inds[ii] = np.argmin(conf_sorted > tt)
thresholds_inds[thresholds_inds == 0] = -1
@ -385,7 +376,7 @@ def compute_file_accuracy(gts, preds, num_classes):
).mean(0)
best_thresh = np.argmax(acc_per_thresh)
best_acc = acc_per_thresh[best_thresh]
pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist()
pred_valid = pred_valid_all[:, best_thresh].astype(int).tolist()
res = {}
res["num_valid_files"] = len(gt_valid)

View File

@ -1,92 +1,135 @@
from enum import Enum
from typing import Optional, Tuple
"""Defines and builds the neural network models used in BatDetect2.
from soundevent.data import PathLike
This package (`batdetect2.models`) contains the PyTorch implementations of the
deep neural network architectures used for detecting and classifying bat calls
from spectrograms. It provides modular components and configuration-driven
assembly, allowing for experimentation and use of different architectural
variants.
Key Submodules:
- `.types`: Defines core data structures (`ModelOutput`) and abstract base
classes (`BackboneModel`, `DetectionModel`) establishing interfaces.
- `.blocks`: Provides reusable neural network building blocks.
- `.encoder`: Defines and builds the downsampling path (encoder) of the network.
- `.bottleneck`: Defines and builds the central bottleneck component.
- `.decoder`: Defines and builds the upsampling path (decoder) of the network.
- `.backbone`: Assembles the encoder, bottleneck, and decoder into a complete
feature extraction backbone (e.g., a U-Net like structure).
- `.heads`: Defines simple prediction heads (detection, classification, size)
that attach to the backbone features.
- `.detectors`: Assembles the backbone and prediction heads into the final,
end-to-end `Detector` model.
This module re-exports the most important classes, configurations, and builder
functions from these submodules for convenient access. The primary entry point
for creating a standard BatDetect2 model instance is the `build_model` function
provided here.
"""
from typing import Optional
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.backbones import (
Net2DFast,
Net2DFastNoAttn,
Net2DFastNoCoordConv,
Net2DPlain,
Backbone,
BackboneConfig,
build_backbone,
load_backbone_config,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.typing import BackboneModel
from batdetect2.models.blocks import (
ConvConfig,
FreqCoordConvDownConfig,
FreqCoordConvUpConfig,
StandardConvDownConfig,
StandardConvUpConfig,
)
from batdetect2.models.bottleneck import (
Bottleneck,
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
build_decoder,
)
from batdetect2.models.detectors import (
Detector,
build_detector,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
__all__ = [
"BBoxHead",
"Backbone",
"BackboneConfig",
"BackboneModel",
"BackboneModel",
"Bottleneck",
"BottleneckConfig",
"ClassifierHead",
"ModelConfig",
"ModelType",
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
"build_architecture",
"load_model_config",
"ConvConfig",
"DEFAULT_DECODER_CONFIG",
"DEFAULT_ENCODER_CONFIG",
"DecoderConfig",
"DetectionModel",
"Detector",
"DetectorHead",
"EncoderConfig",
"FreqCoordConvDownConfig",
"FreqCoordConvUpConfig",
"ModelOutput",
"StandardConvDownConfig",
"StandardConvUpConfig",
"build_backbone",
"build_bottleneck",
"build_decoder",
"build_detector",
"build_encoder",
"build_model",
"load_backbone_config",
]
class ModelType(str, Enum):
Net2DFast = "Net2DFast"
Net2DFastNoAttn = "Net2DFastNoAttn"
Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
Net2DPlain = "Net2DPlain"
def build_model(
num_classes: int,
config: Optional[BackboneConfig] = None,
) -> DetectionModel:
"""Build the complete BatDetect2 detection model.
This high-level factory function constructs the standard BatDetect2 model
architecture. It first builds the feature extraction backbone (typically an
encoder-bottleneck-decoder structure) based on the provided
`BackboneConfig` (or defaults if None), and then attaches the standard
prediction heads (`DetectorHead`, `ClassifierHead`, `BBoxHead`) using the
`build_detector` function.
class ModelConfig(BaseConfig):
name: ModelType = ModelType.Net2DFast
input_height: int = 128
encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
bottleneck_channels: int = 256
decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
out_channels: int = 32
Parameters
----------
num_classes : int
The number of specific target classes the model should predict
(required for the `ClassifierHead`). Must be positive.
config : BackboneConfig, optional
Configuration object specifying the architecture of the backbone
(encoder, bottleneck, decoder). If None, default configurations defined
within the respective builder functions (`build_encoder`, etc.) will be
used to construct a default backbone architecture.
Returns
-------
DetectionModel
An initialized `Detector` model instance.
def load_model_config(
path: PathLike, field: Optional[str] = None
) -> ModelConfig:
return load_config(path, schema=ModelConfig, field=field)
def build_architecture(
config: Optional[ModelConfig] = None,
) -> BackboneModel:
config = config or ModelConfig()
if config.name == ModelType.Net2DFast:
return Net2DFast(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
if config.name == ModelType.Net2DFastNoAttn:
return Net2DFastNoAttn(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
if config.name == ModelType.Net2DFastNoCoordConv:
return Net2DFastNoCoordConv(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
if config.name == ModelType.Net2DPlain:
return Net2DPlain(
input_height=config.input_height,
encoder_channels=config.encoder_channels,
bottleneck_channels=config.bottleneck_channels,
decoder_channels=config.decoder_channels,
out_channels=config.out_channels,
)
raise ValueError(f"Unknown model type: {config.name}")
Raises
------
ValueError
If `num_classes` is not positive, or if errors occur during the
construction of the backbone or detector components (e.g., incompatible
configurations, invalid parameters).
"""
backbone = build_backbone(config or BackboneConfig())
return build_detector(num_classes, backbone)
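As a usage sketch, a detector can be built with the default backbone or with a lightly customised one; the class count and the `out_channels` override below are only examples of valid arguments.
# Default backbone, 5 target classes (value is illustrative).
model = build_model(num_classes=5)

# Customised backbone via BackboneConfig.
model = build_model(num_classes=5, config=BackboneConfig(out_channels=64))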

View File

@ -1,181 +1,353 @@
from typing import Sequence, Tuple
"""Assembles a complete Encoder-Decoder Backbone network.
This module defines the configuration (`BackboneConfig`) and implementation
(`Backbone`) for a standard encoder-decoder style neural network backbone.
It orchestrates the connection between three main components, built using their
respective configurations and factory functions from sibling modules:
1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features
at multiple resolutions and provides skip connections.
2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the
lowest resolution, optionally applying self-attention.
3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs high-
resolution features using bottleneck features and skip connections.
The resulting `Backbone` module takes a spectrogram as input and outputs a
final feature map, typically used by subsequent prediction heads. It includes
automatic padding to handle input sizes not perfectly divisible by the
network's total downsampling factor.
"""
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from soundevent import data
from torch import nn
from batdetect2.models.blocks import (
ConvBlock,
Decoder,
DownscalingLayer,
Encoder,
SelfAttention,
UpscalingLayer,
VerticalConv,
)
from batdetect2.models.typing import BackboneModel
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.blocks import ConvBlock
from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck
from batdetect2.models.decoder import Decoder, DecoderConfig, build_decoder
from batdetect2.models.encoder import Encoder, EncoderConfig, build_encoder
from batdetect2.models.types import BackboneModel
__all__ = [
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
"Backbone",
"BackboneConfig",
"load_backbone_config",
"build_backbone",
]
class Net2DPlain(BackboneModel):
downscaling_layer_type: DownscalingLayer = "ConvBlockDownStandard"
upscaling_layer_type: UpscalingLayer = "ConvBlockUpStandard"
class Backbone(BackboneModel):
"""Encoder-Decoder Backbone Network Implementation.
Combines an Encoder, Bottleneck, and Decoder module sequentially, using
skip connections between the Encoder and Decoder. Implements the standard
U-Net style forward pass. Includes automatic input padding to handle
various input sizes and a final convolutional block to adjust the output
channels.
This class inherits from `BackboneModel` and implements its `forward`
method. Instances are typically created using the `build_backbone` factory
function.
Attributes
----------
input_height : int
Expected height of the input spectrogram.
out_channels : int
Number of channels in the final output feature map.
encoder : Encoder
The instantiated encoder module.
decoder : Decoder
The instantiated decoder module.
bottleneck : nn.Module
The instantiated bottleneck module.
final_conv : ConvBlock
Final convolutional block applied after the decoder.
divide_factor : int
The total downsampling factor (2^depth) applied by the encoder,
used for automatic input padding.
"""
def __init__(
self,
input_height: int = 128,
encoder_channels: Sequence[int] = (1, 32, 64, 128),
bottleneck_channels: int = 256,
decoder_channels: Sequence[int] = (256, 64, 32, 32),
out_channels: int = 32,
input_height: int,
out_channels: int,
encoder: Encoder,
decoder: Decoder,
bottleneck: nn.Module,
):
super().__init__()
"""Initialize the Backbone network.
Parameters
----------
input_height : int
Expected height of the input spectrogram.
out_channels : int
Desired number of output channels for the backbone's feature map.
encoder : Encoder
An initialized Encoder module.
decoder : Decoder
An initialized Decoder module.
bottleneck : nn.Module
An initialized Bottleneck module.
Raises
------
ValueError
If component output/input channels or heights are incompatible.
"""
super().__init__()
self.input_height = input_height
self.encoder_channels = tuple(encoder_channels)
self.decoder_channels = tuple(decoder_channels)
self.out_channels = out_channels
if len(encoder_channels) != len(decoder_channels):
raise ValueError(
f"Mismatched encoder and decoder channel lists. "
f"The encoder has {len(encoder_channels)} channels "
f"(implying {len(encoder_channels) - 1} layers), "
f"while the decoder has {len(decoder_channels)} channels "
f"(implying {len(decoder_channels) - 1} layers). "
f"These lengths must be equal."
)
self.encoder = encoder
self.decoder = decoder
self.bottleneck = bottleneck
self.divide_factor = 2 ** (len(encoder_channels) - 1)
if self.input_height % self.divide_factor != 0:
raise ValueError(
f"Input height ({self.input_height}) must be divisible by "
f"the divide factor ({self.divide_factor}). "
f"This ensures proper upscaling after downscaling to recover "
f"the original input height."
)
self.encoder = Encoder(
channels=encoder_channels,
input_height=self.input_height,
layer_type=self.downscaling_layer_type,
)
self.conv_same_1 = ConvBlock(
in_channels=encoder_channels[-1],
out_channels=bottleneck_channels,
)
# bottleneck
self.conv_vert = VerticalConv(
in_channels=bottleneck_channels,
out_channels=bottleneck_channels,
input_height=self.input_height // (2**self.encoder.depth),
)
self.decoder = Decoder(
channels=decoder_channels,
input_height=self.input_height,
layer_type=self.upscaling_layer_type,
)
self.conv_same_2 = ConvBlock(
in_channels=decoder_channels[-1],
self.final_conv = ConvBlock(
in_channels=decoder.out_channels,
out_channels=out_channels,
)
# Down/Up scaling factor. Need to ensure inputs are divisible by
# this factor in order to be processed by the down/up scaling layers
# and recover the correct shape
self.divide_factor = input_height // self.encoder.output_height
def forward(self, spec: torch.Tensor) -> torch.Tensor:
spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
"""Perform the forward pass through the encoder-decoder backbone.
Applies padding, runs encoder, bottleneck, decoder (with skip
connections), removes padding, and applies a final convolution.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, shape `(B, C_in, H_in, W_in)`. Must match
`self.encoder.input_channels` and `self.input_height`.
Returns
-------
torch.Tensor
Output feature map tensor, shape `(B, C_out, H_in, W_in)`, where
`C_out` is `self.out_channels`.
"""
spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor)
# encoder
residuals = self.encoder(spec)
residuals[-1] = self.conv_same_1(residuals[-1])
# bottleneck
x = self.conv_vert(residuals[-1])
x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
x = self.bottleneck(residuals[-1])
# decoder
x = self.decoder(x, residuals=residuals)
# Restore original size
x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
x = _restore_pad(x, h_pad=h_pad, w_pad=w_pad)
return self.conv_same_2(x)
return self.final_conv(x)
class Net2DFast(Net2DPlain):
downscaling_layer_type = "ConvBlockDownCoordF"
upscaling_layer_type = "ConvBlockUpF"
class BackboneConfig(BaseConfig):
"""Configuration for the Encoder-Decoder Backbone network.
def __init__(
self,
input_height: int = 128,
encoder_channels: Sequence[int] = (1, 32, 64, 128),
bottleneck_channels: int = 256,
decoder_channels: Sequence[int] = (256, 64, 32, 32),
out_channels: int = 32,
):
super().__init__(
input_height=input_height,
encoder_channels=encoder_channels,
bottleneck_channels=bottleneck_channels,
decoder_channels=decoder_channels,
out_channels=out_channels,
Aggregates configurations for the encoder, bottleneck, and decoder
components, along with defining the input and final output dimensions
for the complete backbone.
Attributes
----------
input_height : int, default=128
Expected height (frequency bins) of the input spectrograms to the
backbone. Must be positive.
in_channels : int, default=1
Expected number of channels in the input spectrograms (e.g., 1 for
mono). Must be positive.
encoder : EncoderConfig, optional
Configuration for the encoder. If None or omitted,
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
encoder module) will be used.
bottleneck : BottleneckConfig, optional
Configuration for the bottleneck layer connecting encoder and decoder.
If None or omitted, the default bottleneck configuration will be used.
decoder : DecoderConfig, optional
Configuration for the decoder. If None or omitted,
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
decoder module) will be used.
out_channels : int, default=32
Desired number of channels in the final feature map output by the
backbone. Must be positive.
"""
input_height: int = 128
in_channels: int = 1
encoder: Optional[EncoderConfig] = None
bottleneck: Optional[BottleneckConfig] = None
decoder: Optional[DecoderConfig] = None
out_channels: int = 32
def load_backbone_config(
path: data.PathLike,
field: Optional[str] = None,
) -> BackboneConfig:
"""Load the backbone configuration from a file.
Reads a configuration file (YAML) and validates it against the
`BackboneConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
backbone configuration (e.g., "model.backbone"). If None, the entire
file content is used.
Returns
-------
BackboneConfig
The loaded and validated backbone configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to `BackboneConfig`.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=BackboneConfig, field=field)
def build_backbone(config: BackboneConfig) -> BackboneModel:
"""Factory function to build a Backbone from configuration.
Constructs the `Encoder`, `Bottleneck`, and `Decoder` components based on
the provided `BackboneConfig`, validates their compatibility, and assembles
them into a `Backbone` instance.
Parameters
----------
config : BackboneConfig
The configuration object detailing the backbone architecture, including
input dimensions and configurations for encoder, bottleneck, and
decoder.
Returns
-------
BackboneModel
An initialized `Backbone` module ready for use.
Raises
------
ValueError
If sub-component configurations are incompatible
(e.g., channel mismatches, decoder output height doesn't match backbone
input height).
NotImplementedError
If an unknown block type is specified in sub-configs.
"""
encoder = build_encoder(
in_channels=config.in_channels,
input_height=config.input_height,
config=config.encoder,
)
bottleneck = build_bottleneck(
input_height=encoder.output_height,
in_channels=encoder.out_channels,
config=config.bottleneck,
)
decoder = build_decoder(
in_channels=bottleneck.out_channels,
input_height=encoder.output_height,
config=config.decoder,
)
if decoder.output_height != config.input_height:
raise ValueError(
"Invalid configuration: Decoder output height "
f"({decoder.output_height}) must match the Backbone input height "
f"({config.input_height}). Check encoder/decoder layer "
"configurations and input/bottleneck heights."
)
self.att = SelfAttention(bottleneck_channels, bottleneck_channels)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
# encoder
residuals = self.encoder(spec)
residuals[-1] = self.conv_same_1(residuals[-1])
# bottleneck
x = self.conv_vert(residuals[-1])
x = self.att(x)
x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
# decoder
x = self.decoder(x, residuals=residuals)
# Restore original size
x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
return self.conv_same_2(x)
return Backbone(
input_height=config.input_height,
out_channels=config.out_channels,
encoder=encoder,
decoder=decoder,
bottleneck=bottleneck,
)
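A quick shape-check sketch for the assembled backbone; the random spectrogram batch is synthetic and the expected output shape follows the `forward` docstring above together with the default `BackboneConfig` values.
import torch

backbone = build_backbone(BackboneConfig())
spec = torch.rand(2, 1, 128, 500)  # (batch, channels, freq bins, time steps)
features = backbone(spec)          # expected shape: (2, 32, 128, 500)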
class Net2DFastNoAttn(Net2DPlain):
downscaling_layer_type = "ConvBlockDownCoordF"
upscaling_layer_type = "ConvBlockUpF"
class Net2DFastNoCoordConv(Net2DFast):
downscaling_layer_type = "ConvBlockDownStandard"
upscaling_layer_type = "ConvBlockUpStandard"
def pad_adjust(
def _pad_adjust(
spec: torch.Tensor,
factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
print(spec.shape)
"""Pad tensor height and width to be divisible by a factor.
Calculates the required padding for the last two dimensions (H, W) to make
them divisible by `factor` and applies right/bottom padding using
`torch.nn.functional.pad`.
Parameters
----------
spec : torch.Tensor
Input tensor, typically shape `(B, C, H, W)`.
factor : int, default=32
The factor to make height and width divisible by.
Returns
-------
Tuple[torch.Tensor, int, int]
A tuple containing:
- The padded tensor.
- The amount of padding added to height (`h_pad`).
- The amount of padding added to width (`w_pad`).
"""
h, w = spec.shape[2:]
h_pad = -h % factor
w_pad = -w % factor
if h_pad == 0 and w_pad == 0:
return spec, 0, 0
return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
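A quick arithmetic check of the padding rule used above (purely illustrative numbers):
h, w, factor = 100, 257, 32
h_pad, w_pad = -h % factor, -w % factor  # 28 rows, 31 columns of padding
assert (h + h_pad, w + w_pad) == (128, 288)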
def restore_pad(
def _restore_pad(
x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
) -> torch.Tensor:
# Restore original size
"""Remove padding added by _pad_adjust.
Removes padding from the bottom and right edges of the tensor.
Parameters
----------
x : torch.Tensor
Padded tensor, typically shape `(B, C, H_padded, W_padded)`.
h_pad : int, default=0
Amount of padding previously added to the height (bottom).
w_pad : int, default=0
Amount of padding previously added to the width (right).
Returns
-------
torch.Tensor
Tensor with padding removed, shape `(B, C, H_original, W_original)`.
"""
if h_pad > 0:
x = x[:, :, :-h_pad, :]

View File

@ -1,42 +1,107 @@
"""Module containing custom NN blocks.
"""Commonly used neural network building blocks for BatDetect2 models.
All these classes are subclasses of `torch.nn.Module` and can be used to build
complex neural network architectures.
This module provides various reusable `torch.nn.Module` subclasses that form
the fundamental building blocks for constructing convolutional neural network
architectures, particularly encoder-decoder backbones used in BatDetect2.
It includes standard components like basic convolutional blocks (`ConvBlock`),
blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with
upsampling (`StandardConvUpBlock`).
Additionally, it features specialized layers investigated in BatDetect2
research:
- `SelfAttention`: Applies self-attention along the time dimension, enabling
the model to weigh information across the entire temporal context, often
used in the bottleneck of an encoder-decoder.
- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv"
concept by concatenating normalized frequency coordinate information as an
extra channel to the input of convolutional layers. This explicitly provides
spatial frequency information to filters, potentially enabling them to learn
frequency-dependent patterns more effectively.
These blocks can be utilized directly in custom PyTorch model definitions or
assembled into larger architectures.
A unified factory function `build_layer_from_config` allows creating instances
of these blocks based on configuration objects.
"""
import sys
from typing import Iterable, List, Literal, Sequence, Tuple
from typing import Annotated, List, Literal, Tuple, Union
import torch
import torch.nn.functional as F
from pydantic import Field
from torch import nn
if sys.version_info >= (3, 10):
from itertools import pairwise
else:
def pairwise(iterable: Sequence) -> Iterable:
for x, y in zip(iterable[:-1], iterable[1:]):
yield x, y
from batdetect2.configs import BaseConfig
__all__ = [
"ConvBlock",
"ConvBlockDownCoordF",
"ConvBlockDownStandard",
"ConvBlockUpF",
"ConvBlockUpStandard",
"SelfAttention",
"BlockGroupConfig",
"VerticalConv",
"DownscalingLayer",
"UpscalingLayer",
"FreqCoordConvDownBlock",
"StandardConvDownBlock",
"FreqCoordConvUpBlock",
"StandardConvUpBlock",
"SelfAttention",
"ConvConfig",
"FreqCoordConvDownConfig",
"StandardConvDownConfig",
"FreqCoordConvUpConfig",
"StandardConvUpConfig",
"LayerConfig",
"build_layer_from_config",
]
class SelfAttention(nn.Module):
"""Self-Attention module.
"""Self-Attention mechanism operating along the time dimension.
This module implements self-attention mechanism.
This module implements a scaled dot-product self-attention mechanism,
specifically designed here to operate across the time steps of an input
feature map, typically after spatial dimensions (like frequency) have been
condensed or squeezed.
By calculating attention weights between all pairs of time steps, it allows
the model to capture long-range temporal dependencies and focus on relevant
parts of the sequence. It's often employed in the bottleneck or
intermediate layers of an encoder-decoder architecture to integrate global
temporal context.
The implementation uses linear projections to create query, key, and value
representations, computes scaled dot-product attention scores, applies
softmax, and produces an output by weighting the values according to the
attention scores, followed by a final linear projection. Positional encoding
is not explicitly included in this block.
Parameters
----------
in_channels : int
Number of input channels (features per time step after spatial squeeze).
attention_channels : int
Number of channels for the query, key, and value projections. Also the
dimension of the output projection's input.
temperature : float, default=1.0
Scaling factor applied *before* the final projection layer. Can be used
to adjust the sharpness or focus of the attention mechanism, although
scaling within the softmax (dividing by sqrt(dim)) is more common for
standard transformers. Here it scales the weighted values.
Attributes
----------
key_fun : nn.Linear
Linear layer for key projection.
value_fun : nn.Linear
Linear layer for value projection.
query_fun : nn.Linear
Linear layer for query projection.
pro_fun : nn.Linear
Final linear projection layer applied after attention weighting.
temperature : float
Scaling factor applied before final projection.
att_dim : int
Dimensionality of the attention space (`attention_channels`).
"""
def __init__(
@ -56,6 +121,27 @@ class SelfAttention(nn.Module):
self.pro_fun = nn.Linear(attention_channels, in_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply self-attention along the time dimension.
Parameters
----------
x : torch.Tensor
Input tensor, expected shape `(B, C, H, W)`, where H is typically
squeezed (e.g., H=1 after a `VerticalConv` or pooling) before
applying attention along the W (time) dimension.
Returns
-------
torch.Tensor
Output tensor of the same shape as the input `(B, C, H, W)`, where
attention has been applied across the W dimension.
Raises
------
RuntimeError
If input tensor dimensions are incompatible with operations.
"""
x = x.squeeze(2).permute(0, 2, 1)
key = torch.matmul(
@ -82,7 +168,42 @@ class SelfAttention(nn.Module):
return op
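A minimal shape sanity check for the attention block; the input is random and its shape follows the `forward` docstring (height already squeezed to 1 before attention over time).
import torch

att = SelfAttention(in_channels=256, attention_channels=256)
x = torch.rand(2, 256, 1, 100)  # (B, C, H=1 after vertical squeeze, W=time)
assert att(x).shape == x.shape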
class ConvConfig(BaseConfig):
"""Configuration for a basic ConvBlock."""
block_type: Literal["ConvBlock"] = "ConvBlock"
"""Discriminator field indicating the block type."""
out_channels: int
"""Number of output channels."""
kernel_size: int = 3
"""Size of the square convolutional kernel."""
pad_size: int = 1
"""Padding size."""
class ConvBlock(nn.Module):
"""Basic Convolutional Block.
A standard building block consisting of a 2D convolution, followed by
batch normalization and a ReLU activation function.
Sequence: Conv2d -> BatchNorm2d -> ReLU.
Parameters
----------
in_channels : int
Number of channels in the input tensor.
out_channels : int
Number of channels produced by the convolution.
kernel_size : int, default=3
Size of the square convolutional kernel.
pad_size : int, default=1
Amount of padding added to preserve spatial dimensions.
"""
def __init__(
self,
in_channels: int,
@ -100,27 +221,41 @@ class ConvBlock(nn.Module):
self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Conv -> BN -> ReLU.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H, W)`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H, W)`.
"""
return F.relu_(self.conv_bn(self.conv(x)))
class VerticalConv(nn.Module):
"""Convolutional layer over full height.
"""Convolutional layer that aggregates features across the entire height.
This layer applies a convolution that captures information across the
entire height of the input image. It uses a kernel with the same height as
the input, effectively condensing the vertical information into a single
output row.
Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
This collapses the height dimension (H) to 1 while preserving the width (W),
effectively summarizing features across the full vertical extent (e.g.,
frequency axis) at each time step. Followed by BatchNorm and ReLU.
More specifically:
Useful for summarizing frequency information before applying operations
along the time axis (like SelfAttention).
* **Input:** (B, C, H, W) where B is the batch size, C is the number of
input channels, H is the image height, and W is the image width.
* **Kernel:** (C', H, 1) where C' is the number of output channels.
* **Output:** (B, C', 1, W) - The height dimension is 1 because the
convolution integrates information from all rows of the input.
This process effectively extracts features that span the full height of
the input image.
Parameters
----------
in_channels : int
Number of channels in the input tensor.
out_channels : int
Number of channels produced by the convolution.
input_height : int
The height (H dimension) of the input tensor. The convolutional kernel
will be sized `(input_height, 1)`.
"""
def __init__(
@ -139,14 +274,66 @@ class VerticalConv(nn.Module):
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Vertical Conv -> BN -> ReLU.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H, W)`, where H must match the
`input_height` provided during initialization.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, 1, W)`.
"""
return F.relu_(self.bn(self.conv(x)))
class ConvBlockDownCoordF(nn.Module):
"""Convolutional Block with Downsampling and Coord Feature.
class FreqCoordConvDownConfig(BaseConfig):
"""Configuration for a FreqCoordConvDownBlock."""
This block performs convolution followed by downsampling
and concatenates with coordinate information.
block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
"""Discriminator field indicating the block type."""
out_channels: int
"""Number of output channels."""
kernel_size: int = 3
"""Size of the square convolutional kernel."""
pad_size: int = 1
"""Padding size."""
class FreqCoordConvDownBlock(nn.Module):
"""Downsampling Conv Block incorporating Frequency Coordinate features.
This block implements a downsampling step (Conv2d + MaxPool2d) commonly
used in CNN encoders. Before the convolution, it concatenates an extra
channel representing the normalized vertical coordinate (frequency) to the
input tensor.
The purpose of adding coordinate features is to potentially help the
convolutional filters become spatially aware, allowing them to learn
patterns that might depend on the relative frequency position within the
spectrogram.
Sequence: Concat Coords -> Conv -> MaxPool -> BatchNorm -> ReLU.
Parameters
----------
in_channels : int
Number of channels in the input tensor.
out_channels : int
Number of output channels after the convolution.
input_height : int
Height (H dimension, frequency bins) of the input tensor to this block.
Used to generate the coordinate features.
kernel_size : int, default=3
Size of the square convolutional kernel.
pad_size : int, default=1
Padding added before convolution.
"""
def __init__(
@ -156,7 +343,6 @@ class ConvBlockDownCoordF(nn.Module):
input_height: int,
kernel_size: int = 3,
pad_size: int = 1,
stride: int = 1,
):
super().__init__()
@ -169,11 +355,24 @@ class ConvBlockDownCoordF(nn.Module):
out_channels,
kernel_size=kernel_size,
padding=pad_size,
stride=stride,
stride=1,
)
self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply CoordF -> Conv -> MaxPool -> BN -> ReLU.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H, W)`, where H must match
`input_height`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H/2, W/2)` (due to MaxPool).
"""
freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
x = torch.cat((x, freq_info), 1)
x = F.max_pool2d(self.conv(x), 2, 2)
@ -181,10 +380,40 @@ class ConvBlockDownCoordF(nn.Module):
return x
class ConvBlockDownStandard(nn.Module):
"""Convolutional Block with Downsampling.
class StandardConvDownConfig(BaseConfig):
"""Configuration for a StandardConvDownBlock."""
This block performs convolution followed by downsampling.
block_type: Literal["StandardConvDown"] = "StandardConvDown"
"""Discriminator field indicating the block type."""
out_channels: int
"""Number of output channels."""
kernel_size: int = 3
"""Size of the square convolutional kernel."""
pad_size: int = 1
"""Padding size."""
class StandardConvDownBlock(nn.Module):
"""Standard Downsampling Convolutional Block.
A basic downsampling block consisting of a 2D convolution, followed by
2x2 max pooling, batch normalization, and ReLU activation.
Sequence: Conv -> MaxPool -> BN -> ReLU.
Parameters
----------
in_channels : int
Number of channels in the input tensor.
out_channels : int
Number of output channels after the convolution.
kernel_size : int, default=3
Size of the square convolutional kernel.
pad_size : int, default=1
Padding added before convolution.
"""
def __init__(
@ -193,31 +422,83 @@ class ConvBlockDownStandard(nn.Module):
out_channels: int,
kernel_size: int = 3,
pad_size: int = 1,
stride: int = 1,
):
super(ConvBlockDownStandard, self).__init__()
super(StandardConvDownBlock, self).__init__()
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=pad_size,
stride=stride,
stride=1,
)
self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
"""Apply Conv -> MaxPool -> BN -> ReLU.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H, W)`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H/2, W/2)`.
"""
x = F.max_pool2d(self.conv(x), 2, 2)
return F.relu(self.conv_bn(x), inplace=True)
DownscalingLayer = Literal["ConvBlockDownStandard", "ConvBlockDownCoordF"]
class FreqCoordConvUpConfig(BaseConfig):
"""Configuration for a FreqCoordConvUpBlock."""
block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
"""Discriminator field indicating the block type."""
out_channels: int
"""Number of output channels."""
kernel_size: int = 3
"""Size of the square convolutional kernel."""
pad_size: int = 1
"""Padding size."""
class ConvBlockUpF(nn.Module):
"""Convolutional Block with Upsampling and Coord Feature.
class FreqCoordConvUpBlock(nn.Module):
"""Upsampling Conv Block incorporating Frequency Coordinate features.
This block performs convolution followed by upsampling
and concatenates with coordinate information.
This block implements an upsampling step followed by a convolution,
commonly used in CNN decoders. Before the convolution, it concatenates an
extra channel representing the normalized vertical coordinate (frequency)
of the *upsampled* feature map.
The goal is to provide spatial awareness (frequency position) to the
filters during the decoding/upsampling process.
Sequence: Interpolate -> Concat Coords -> Conv -> BatchNorm -> ReLU.
Parameters
----------
in_channels : int
Number of channels in the input tensor (before upsampling).
out_channels : int
Number of output channels after the convolution.
input_height : int
Height (H dimension, frequency bins) of the tensor *before* upsampling.
Used to calculate the height for coordinate feature generation after
upsampling.
kernel_size : int, default=3
Size of the square convolutional kernel.
pad_size : int, default=1
Padding added before convolution.
up_mode : str, default="bilinear"
Interpolation mode for upsampling (e.g., "nearest", "bilinear",
"bicubic").
up_scale : Tuple[int, int], default=(2, 2)
Scaling factor for height and width during upsampling
(typically (2, 2)).
"""
def __init__(
@ -249,6 +530,19 @@ class ConvBlockUpF(nn.Module):
self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Interpolate -> Concat Coords -> Conv -> BN -> ReLU.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W_in)`, where H_in should match
`input_height` used during initialization.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in * scale_h, W_in * scale_w)`.
"""
op = F.interpolate(
x,
size=(
@ -265,10 +559,45 @@ class ConvBlockUpF(nn.Module):
return op
class ConvBlockUpStandard(nn.Module):
"""Convolutional Block with Upsampling.
class StandardConvUpConfig(BaseConfig):
"""Configuration for a StandardConvUpBlock."""
This block performs convolution followed by upsampling.
block_type: Literal["StandardConvUp"] = "StandardConvUp"
"""Discriminator field indicating the block type."""
out_channels: int
"""Number of output channels."""
kernel_size: int = 3
"""Size of the square convolutional kernel."""
pad_size: int = 1
"""Padding size."""
class StandardConvUpBlock(nn.Module):
"""Standard Upsampling Convolutional Block.
A basic upsampling block used in CNN decoders. It first upsamples the input
feature map using interpolation, then applies a 2D convolution, batch
normalization, and ReLU activation. Does not use coordinate features.
Sequence: Interpolate -> Conv -> BN -> ReLU.
Parameters
----------
in_channels : int
Number of channels in the input tensor (before upsampling).
out_channels : int
Number of output channels after the convolution.
kernel_size : int, default=3
Size of the square convolutional kernel.
pad_size : int, default=1
Padding added before convolution.
up_mode : str, default="bilinear"
Interpolation mode for upsampling (e.g., "nearest", "bilinear").
up_scale : Tuple[int, int], default=(2, 2)
Scaling factor for height and width during upsampling.
"""
def __init__(
@ -280,7 +609,7 @@ class ConvBlockUpStandard(nn.Module):
up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2),
):
super(ConvBlockUpStandard, self).__init__()
super(StandardConvUpBlock, self).__init__()
self.up_scale = up_scale
self.up_mode = up_mode
self.conv = nn.Conv2d(
@ -292,6 +621,18 @@ class ConvBlockUpStandard(nn.Module):
self.conv_bn = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Interpolate -> Conv -> BN -> ReLU.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W_in)`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in * scale_h, W_in * scale_w)`.
"""
op = F.interpolate(
x,
size=(
@ -306,141 +647,142 @@ class ConvBlockUpStandard(nn.Module):
return op
UpscalingLayer = Literal["ConvBlockUpStandard", "ConvBlockUpF"]
LayerConfig = Annotated[
Union[
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
"BlockGroupConfig",
],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configuration models."""
def build_downscaling_layer(
in_channels: int,
out_channels: int,
class BlockGroupConfig(BaseConfig):
block_type: Literal["group"] = "group"
blocks: List[LayerConfig]
def build_layer_from_config(
input_height: int,
layer_type: DownscalingLayer,
) -> nn.Module:
if layer_type == "ConvBlockDownStandard":
return ConvBlockDownStandard(
in_channels=in_channels,
out_channels=out_channels,
)
if layer_type == "ConvBlockDownCoordF":
return ConvBlockDownCoordF(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height,
)
raise ValueError(
f"Invalid downscaling layer type {layer_type}. "
f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
)
class Encoder(nn.Module):
def __init__(
self,
channels: Sequence[int] = (1, 32, 62, 128),
input_height: int = 128,
layer_type: Literal[
"ConvBlockDownStandard", "ConvBlockDownCoordF"
] = "ConvBlockDownStandard",
):
super().__init__()
self.channels = channels
self.input_height = input_height
self.layers = nn.ModuleList(
[
build_downscaling_layer(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height // (2**layer_num),
layer_type=layer_type,
)
for layer_num, (in_channels, out_channels) in enumerate(
pairwise(channels)
)
]
)
self.depth = len(self.layers)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
outputs = []
for layer in self.layers:
x = layer(x)
outputs.append(x)
return outputs
def build_upscaling_layer(
in_channels: int,
out_channels: int,
input_height: int,
layer_type: UpscalingLayer,
) -> nn.Module:
if layer_type == "ConvBlockUpStandard":
return ConvBlockUpStandard(
in_channels=in_channels,
out_channels=out_channels,
config: LayerConfig,
) -> Tuple[nn.Module, int, int]:
"""Factory function to build a specific nn.Module block from its config.
Takes a configuration object (one of the types included in the `LayerConfig`
union) and instantiates the corresponding nn.Module block with the correct
parameters derived from the config and the current pipeline state
(`input_height`, `in_channels`).
It uses the `block_type` field within the `config` object to determine
which block class to instantiate.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor *to this layer*.
in_channels : int
Number of channels in the input tensor *to this layer*.
config : LayerConfig
A Pydantic configuration object for the desired block (e.g., an
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
by its `block_type` field.
Returns
-------
Tuple[nn.Module, int, int]
A tuple containing:
- The instantiated `nn.Module` block.
- The number of output channels produced by the block.
- The calculated height of the output produced by the block.
Raises
------
NotImplementedError
If the `config.block_type` does not correspond to a known block type.
ValueError
If parameters derived from the config are invalid for the block.
"""
if config.block_type == "ConvBlock":
return (
ConvBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height,
)
if layer_type == "ConvBlockUpF":
return ConvBlockUpF(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height,
if config.block_type == "FreqCoordConvDown":
return (
FreqCoordConvDownBlock(
in_channels=in_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height // 2,
)
raise ValueError(
f"Invalid upscaling layer type {layer_type}. "
f"Valid values: ConvBlockUpStandard, ConvBlockUpF"
)
class Decoder(nn.Module):
def __init__(
self,
channels: Sequence[int] = (256, 62, 32, 32),
input_height: int = 128,
layer_type: Literal[
"ConvBlockUpStandard", "ConvBlockUpF"
] = "ConvBlockUpStandard",
):
super().__init__()
self.channels = channels
self.input_height = input_height
self.depth = len(self.channels) - 1
self.layers = nn.ModuleList(
[
build_upscaling_layer(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height
// (2 ** (self.depth - layer_num)),
layer_type=layer_type,
)
for layer_num, (in_channels, out_channels) in enumerate(
pairwise(channels)
)
]
if config.block_type == "StandardConvDown":
return (
StandardConvDownBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height // 2,
)
def forward(
self,
x: torch.Tensor,
residuals: List[torch.Tensor],
) -> torch.Tensor:
if len(residuals) != len(self.layers):
raise ValueError(
f"Incorrect number of residuals provided. "
f"Expected {len(self.layers)} (matching the number of layers), "
f"but got {len(residuals)}."
if config.block_type == "FreqCoordConvUp":
return (
FreqCoordConvUpBlock(
in_channels=in_channels,
out_channels=config.out_channels,
input_height=input_height,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height * 2,
)
if config.block_type == "StandardConvUp":
return (
StandardConvUpBlock(
in_channels=in_channels,
out_channels=config.out_channels,
kernel_size=config.kernel_size,
pad_size=config.pad_size,
),
config.out_channels,
input_height * 2,
)
if config.block_type == "group":
current_channels = in_channels
current_height = input_height
blocks = []
for block_config in config.blocks:
block, current_channels, current_height = build_layer_from_config(
input_height=current_height,
in_channels=current_channels,
config=block_config,
)
blocks.append(block)
for layer, res in zip(self.layers, residuals[::-1]):
x = layer(x + res)
return nn.Sequential(*blocks), current_channels, current_height
return x
raise NotImplementedError(f"Unknown block type {config.block_type}")

View File

@ -0,0 +1,254 @@
"""Defines the Bottleneck component of an Encoder-Decoder architecture.
This module provides the configuration (`BottleneckConfig`) and
`torch.nn.Module` implementations (`Bottleneck`, `BottleneckAttn`) for the
bottleneck layer(s) that typically connect the Encoder (downsampling path) and
Decoder (upsampling path) in networks like U-Nets.
The bottleneck processes the lowest-resolution, highest-dimensionality feature
map produced by the Encoder. This module offers a configurable option to include
a `SelfAttention` layer within the bottleneck, allowing the model to capture
global temporal context before features are passed to the Decoder.
A factory function `build_bottleneck` constructs the appropriate bottleneck
module based on the provided configuration.
"""
from typing import Optional
import torch
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import SelfAttention, VerticalConv
__all__ = [
"BottleneckConfig",
"Bottleneck",
"BottleneckAttn",
"build_bottleneck",
]
class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
Attributes
----------
channels : int
The number of output channels produced by the main convolutional layer
within the bottleneck. This often matches the number of channels coming
from the last encoder stage, but can be different. Must be positive.
This also defines the channel dimensions used within the optional
`SelfAttention` layer.
self_attention : bool
If True, includes a `SelfAttention` layer operating on the time
dimension after an initial `VerticalConv` layer within the bottleneck.
If False, only the initial `VerticalConv` (and height repetition) is
performed.
"""
channels: int
self_attention: bool
class Bottleneck(nn.Module):
"""Base Bottleneck module for Encoder-Decoder architectures.
This implementation represents the simplest bottleneck structure
considered, primarily consisting of a `VerticalConv` layer. This layer
collapses the frequency dimension (height) to 1, summarizing information
across frequencies at each time step. The output is then repeated along the
height dimension to match the original bottleneck input height before being
passed to the decoder.
This base version does *not* include self-attention.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor. Must be positive.
in_channels : int
Number of channels in the input tensor from the encoder. Must be
positive.
out_channels : int
Number of output channels. Must be positive.
Attributes
----------
in_channels : int
Number of input channels accepted by the bottleneck.
input_height : int
Expected height of the input tensor.
channels : int
Number of output channels.
conv_vert : VerticalConv
The vertical convolution layer.
Raises
------
ValueError
If `input_height`, `in_channels`, or `out_channels` are not positive.
"""
def __init__(
self,
input_height: int,
in_channels: int,
out_channels: int,
) -> None:
"""Initialize the base Bottleneck layer."""
super().__init__()
self.in_channels = in_channels
self.input_height = input_height
self.out_channels = out_channels
self.conv_vert = VerticalConv(
in_channels=in_channels,
out_channels=out_channels,
input_height=input_height,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input features through the bottleneck.
Applies vertical convolution and repeats the output height.
Parameters
----------
x : torch.Tensor
Input tensor from the encoder bottleneck, shape
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
`H_in` must match `self.input_height`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in, W)`. Note that the height
dimension `H_in` is restored via repetition after the vertical
convolution.
"""
x = self.conv_vert(x)
return x.repeat([1, 1, self.input_height, 1])
class BottleneckAttn(Bottleneck):
"""Bottleneck module including a Self-Attention layer.
Extends the base `Bottleneck` by inserting a `SelfAttention` layer after
the initial `VerticalConv`. This allows the bottleneck to capture global
temporal dependencies in the summarized frequency features before passing
them to the decoder.
Sequence: VerticalConv -> SelfAttention -> Repeat Height.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor from the encoder.
in_channels : int
Number of channels in the input tensor from the encoder.
out_channels : int
Number of output channels produced by the `VerticalConv` and
subsequently processed and output by this bottleneck. Also determines
the input/output channels of the internal `SelfAttention` layer.
attention : nn.Module
An initialized `SelfAttention` module instance.
Raises
------
ValueError
If `input_height`, `in_channels`, or `out_channels` are not positive.
"""
def __init__(
self,
input_height: int,
in_channels: int,
out_channels: int,
attention: nn.Module,
) -> None:
"""Initialize the Bottleneck with Self-Attention."""
super().__init__(input_height, in_channels, out_channels)
self.attention = attention
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input tensor.
Parameters
----------
x : torch.Tensor
Input tensor from the encoder bottleneck, shape
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
`H_in` must match `self.input_height`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in, W)`, after applying attention
and repeating the height dimension.
"""
x = self.conv_vert(x)
x = self.attention(x)
return x.repeat([1, 1, self.input_height, 1])
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
channels=256,
self_attention=True,
)
def build_bottleneck(
input_height: int,
in_channels: int,
config: Optional[BottleneckConfig] = None,
) -> nn.Module:
"""Factory function to build the Bottleneck module from configuration.
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
on the `config.self_attention` flag.
Parameters
----------
input_height : int
Height (frequency bins) of the input tensor. Must be positive.
in_channels : int
Number of channels in the input tensor. Must be positive.
config : BottleneckConfig, optional
Configuration object specifying the bottleneck channels and whether
to use self-attention. Uses `DEFAULT_BOTTLENECK_CONFIG` if None.
Returns
-------
nn.Module
An initialized bottleneck module (`Bottleneck` or `BottleneckAttn`).
Raises
------
ValueError
If `input_height` or `in_channels` are not positive.
"""
config = config or DEFAULT_BOTTLENECK_CONFIG
if config.self_attention:
attention = SelfAttention(
in_channels=config.channels,
attention_channels=config.channels,
)
return BottleneckAttn(
input_height=input_height,
in_channels=in_channels,
out_channels=config.channels,
attention=attention,
)
return Bottleneck(
input_height=input_height,
in_channels=in_channels,
out_channels=config.channels,
)
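A usage sketch for the bottleneck factory; the sizes are illustrative, and with the default configuration the output has 256 channels with the input height restored by repetition, per the `forward` docstrings above.
import torch

bottleneck = build_bottleneck(input_height=8, in_channels=64)
x = torch.rand(2, 64, 8, 100)
out = bottleneck(x)  # expected shape: (2, 256, 8, 100)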

View File

@ -1,15 +1,277 @@
import sys
from typing import Iterable, List, Literal, Sequence
"""Constructs the Decoder part of an Encoder-Decoder neural network.
This module defines the configuration structure (`DecoderConfig`) for the layer
sequence and provides the `Decoder` class (an `nn.Module`) along with a factory
function (`build_decoder`). Decoders typically form the upsampling path in
architectures like U-Nets, taking bottleneck features
(usually from an `Encoder`) and skip connections to reconstruct
higher-resolution feature maps.
The decoder is built dynamically by stacking neural network blocks based on a
list of configuration objects provided in `DecoderConfig.layers`. Each config
object specifies the type of block (e.g., standard convolution,
coordinate-feature convolution with upsampling) and its parameters. This allows
flexible definition of decoder architectures via configuration files.
The `Decoder`'s `forward` method is designed to accept skip connection tensors
(`residuals`) from the encoder, merging them with the upsampled feature maps
at each stage.
"""
from typing import Annotated, List, Optional, Union
import torch
from pydantic import Field
from torch import nn
from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
BlockGroupConfig,
ConvConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
build_layer_from_config,
)
if sys.version_info >= (3, 10):
from itertools import pairwise
else:
__all__ = [
"DecoderConfig",
"Decoder",
"build_decoder",
"DEFAULT_DECODER_CONFIG",
]
def pairwise(iterable: Sequence) -> Iterable:
for x, y in zip(iterable[:-1], iterable[1:]):
yield x, y
DecoderLayerConfig = Annotated[
Union[
ConvConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
BlockGroupConfig,
],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
class DecoderConfig(BaseConfig):
"""Configuration for the sequence of layers in the Decoder module.
Defines the types and parameters of the neural network blocks that
constitute the decoder's upsampling path.
Attributes
----------
layers : List[DecoderLayerConfig]
An ordered list of configuration objects, each defining one layer or
block in the decoder sequence. Each item must be a valid block
config including a `block_type` field and necessary parameters like
`out_channels`. Input channels for each layer are inferred sequentially.
The list must contain at least one layer.
"""
layers: List[DecoderLayerConfig] = Field(min_length=1)
class Decoder(nn.Module):
"""Sequential Decoder module composed of configurable upsampling layers.
Constructs the upsampling path of an encoder-decoder network by stacking
multiple blocks (e.g., `StandardConvUpBlock`, `FreqCoordConvUpBlock`)
based on a list of layer modules provided during initialization (typically
created by the `build_decoder` factory function).
The `forward` method is designed to integrate skip connection tensors
(`residuals`) from the corresponding encoder stages, by adding them
element-wise to the input of each decoder layer before processing.
Attributes
----------
in_channels : int
Number of channels expected in the input tensor.
out_channels : int
Number of channels in the final output tensor produced by the last
layer.
input_height : int
Height (frequency bins) expected in the input tensor.
output_height : int
Height (frequency bins) expected in the output tensor.
layers : nn.ModuleList
The sequence of instantiated upscaling layer modules.
depth : int
The number of upscaling layers (depth) in the decoder.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
input_height: int,
output_height: int,
layers: List[nn.Module],
):
"""Initialize the Decoder module.
Note: This constructor is typically called internally by the
`build_decoder` factory function.
Parameters
----------
in_channels : int
Expected number of channels in the input tensor (bottleneck).
out_channels : int
Number of channels produced by the final layer.
input_height : int
Expected height of the input tensor (bottleneck).
output_height : int
Expected height of the tensor produced by the final layer.
layers : List[nn.Module]
A list of pre-instantiated upscaling layer modules (e.g.,
`StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired
sequence (from bottleneck towards output resolution).
"""
super().__init__()
self.input_height = input_height
self.output_height = output_height
self.in_channels = in_channels
self.out_channels = out_channels
self.layers = nn.ModuleList(layers)
self.depth = len(self.layers)
def forward(
self,
x: torch.Tensor,
residuals: List[torch.Tensor],
) -> torch.Tensor:
"""Pass input through decoder layers, incorporating skip connections.
Processes the input tensor `x` sequentially through the upscaling
layers. At each stage, the corresponding skip connection tensor from
the `residuals` list is added element-wise to the input before passing
it to the upscaling block.
Parameters
----------
x : torch.Tensor
Input tensor from the previous stage (e.g., encoder bottleneck).
Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
`self.in_channels`.
residuals : List[torch.Tensor]
List containing the skip connection tensors from the corresponding
encoder stages, ordered as returned by the `Encoder` (shallowest,
highest-resolution output first; deepest, bottleneck output last).
The list is reversed internally so the deepest features are merged
first. The number of tensors must match the number of decoder layers
(`self.depth`), and each residual's shape must be compatible with the
tensor it is added to element-wise.
Returns
-------
torch.Tensor
The final decoded feature map tensor produced by the last layer.
Shape `(B, C_out, H_out, W_out)`.
Raises
------
ValueError
If the number of `residuals` provided does not match the decoder
depth.
RuntimeError
If shapes mismatch during skip connection addition or layer
processing.
"""
if len(residuals) != len(self.layers):
raise ValueError(
f"Incorrect number of residuals provided. "
f"Expected {len(self.layers)} (matching the number of layers), "
f"but got {len(residuals)}."
)
for layer, res in zip(self.layers, residuals[::-1]):
x = layer(x + res)
return x
DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
layers=[
FreqCoordConvUpConfig(out_channels=64),
FreqCoordConvUpConfig(out_channels=32),
BlockGroupConfig(
blocks=[
FreqCoordConvUpConfig(out_channels=32),
ConvConfig(out_channels=32),
]
),
],
)
"""A default configuration for the Decoder's *layer sequence*.
Specifies an architecture often used in BatDetect2, consisting of three
frequency coordinate-aware upsampling blocks followed by a standard
convolutional block.
"""
def build_decoder(
in_channels: int,
input_height: int,
config: Optional[DecoderConfig] = None,
) -> Decoder:
"""Factory function to build a Decoder instance from configuration.
Constructs a sequential `Decoder` module based on the layer sequence
defined in a `DecoderConfig` object and the provided input dimensions
(bottleneck channels and height). If no config is provided, uses the
default layer sequence from `DEFAULT_DECODER_CONFIG`.
It iteratively builds the layers using the unified `build_layer_from_config`
factory (from `.blocks`), tracking the changing number of channels and
feature map height required for each subsequent layer.
Parameters
----------
in_channels : int
The number of channels in the input tensor to the decoder. Must be > 0.
input_height : int
The height (frequency bins) of the input tensor to the decoder. Must be
> 0.
config : DecoderConfig, optional
The configuration object detailing the sequence of layers and their
parameters. If None, `DEFAULT_DECODER_CONFIG` is used.
Returns
-------
Decoder
An initialized `Decoder` module.
Raises
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `block_type`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `block_type`.
"""
config = config or DEFAULT_DECODER_CONFIG
current_channels = in_channels
current_height = input_height
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
in_channels=current_channels,
input_height=current_height,
config=layer_config,
)
layers.append(layer)
return Decoder(
in_channels=in_channels,
out_channels=current_channels,
input_height=input_height,
output_height=current_height,
layers=layers,
)
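# Usage sketch for `build_decoder` (illustrative; assumes an `encoder` built
# with the matching default config so that the residual shapes line up):
#
#     decoder = build_decoder(in_channels=256, input_height=16)
#     residuals = encoder(spec)                  # one tensor per encoder stage
#     features = decoder(residuals[-1], residuals)
#     # features: (B, decoder.out_channels, decoder.output_height, W)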

View File

@ -0,0 +1,173 @@
"""Assembles the complete BatDetect2 Detection Model.
This module defines the concrete `Detector` class, which implements the
`DetectionModel` interface defined in `.types`. It combines a feature
extraction backbone with specific prediction heads to create the end-to-end
neural network used for detecting bat calls, predicting their size, and
classifying them.
The primary components are:
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
- `build_detector`: A factory function to conveniently construct a standard
`Detector` instance given a backbone and the number of target classes.
This module focuses purely on the neural network architecture definition. The
logic for preprocessing inputs and postprocessing/decoding outputs resides in
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
"""
import torch
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
class Detector(DetectionModel):
"""Concrete implementation of the BatDetect2 Detection Model.
Assembles a complete detection and classification model by combining a
feature extraction backbone network with specific prediction heads for
detection probability, bounding box size regression, and class
probabilities.
Attributes
----------
backbone : BackboneModel
The feature extraction backbone network module.
num_classes : int
The number of specific target classes the model predicts (derived from
the `classifier_head`).
classifier_head : ClassifierHead
The prediction head responsible for generating class probabilities.
detector_head : DetectorHead
The prediction head responsible for generating detection probabilities.
bbox_head : BBoxHead
The prediction head responsible for generating bounding box size
predictions.
"""
backbone: BackboneModel
def __init__(
self,
backbone: BackboneModel,
classifier_head: ClassifierHead,
detector_head: DetectorHead,
bbox_head: BBoxHead,
):
"""Initialize the Detector model.
Note: Instances are typically created using the `build_detector`
factory function.
Parameters
----------
backbone : BackboneModel
An initialized feature extraction backbone module (e.g., built by
`build_backbone` from the `.backbone` module).
classifier_head : ClassifierHead
An initialized classification head module. The number of classes
is inferred from this head.
detector_head : DetectorHead
An initialized detection head module.
bbox_head : BBoxHead
An initialized bounding box size prediction head module.
Raises
------
TypeError
If the provided modules are not of the expected types.
"""
super().__init__()
self.backbone = backbone
self.num_classes = classifier_head.num_classes
self.classifier_head = classifier_head
self.detector_head = detector_head
self.bbox_head = bbox_head
def forward(self, spec: torch.Tensor) -> ModelOutput:
"""Perform the forward pass of the complete detection model.
Processes the input spectrogram through the backbone to extract
features, then passes these features through the separate prediction
heads to generate detection probabilities, class probabilities, and
size predictions.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, typically with shape
`(batch_size, input_channels, frequency_bins, time_bins)`. The
shape must be compatible with the `self.backbone` input
requirements.
Returns
-------
ModelOutput
A NamedTuple containing the four output tensors:
- `detection_probs`: Detection probability heatmap `(B, 1, H, W)`.
- `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`.
- `class_probs`: Class probabilities (excluding background)
`(B, num_classes, H, W)`.
- `features`: Output feature map from the backbone
`(B, C_out, H, W)`.
"""
features = self.backbone(spec)
detection = self.detector_head(features)
classification = self.classifier_head(features)
size_preds = self.bbox_head(features)
return ModelOutput(
detection_probs=detection,
size_preds=size_preds,
class_probs=classification,
features=features,
)
def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
"""Factory function to build a standard Detector model instance.
Creates the standard prediction heads (`ClassifierHead`, `DetectorHead`,
`BBoxHead`) configured appropriately based on the output channels of the
provided `backbone` and the specified `num_classes`. It then assembles
these components into a `Detector` model.
Parameters
----------
num_classes : int
The number of specific target classes for the classification head
(excluding any implicit background class). Must be positive.
backbone : BackboneModel
An initialized feature extraction backbone module instance. The number
of output channels from this backbone (`backbone.out_channels`) is used
to configure the input channels for the prediction heads.
Returns
-------
Detector
An initialized `Detector` model instance.
Raises
------
ValueError
If `num_classes` is not positive.
AttributeError
If `backbone` does not have the required `out_channels` attribute.
"""
classifier_head = ClassifierHead(
num_classes=num_classes,
in_channels=backbone.out_channels,
)
detector_head = DetectorHead(
in_channels=backbone.out_channels,
)
bbox_head = BBoxHead(
in_channels=backbone.out_channels,
)
return Detector(
backbone=backbone,
classifier_head=classifier_head,
detector_head=detector_head,
bbox_head=bbox_head,
)
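# Usage sketch for `build_detector` (illustrative; `build_backbone` lives in the
# `.backbone` module referenced above and its arguments are omitted here):
#
#     backbone = build_backbone(...)             # any BackboneModel implementation
#     detector = build_detector(num_classes=10, backbone=backbone)
#     outputs = detector(spec)                   # ModelOutput
#     outputs.detection_probs.shape              # (B, 1, H, W)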

View File

@ -1,15 +1,318 @@
import sys
from typing import Iterable, List, Literal, Sequence
"""Constructs the Encoder part of a configurable neural network backbone.
This module defines the configuration structure (`EncoderConfig`) and provides
the `Encoder` class (an `nn.Module`) along with a factory function
(`build_encoder`) to create sequential encoders. Encoders typically form the
downsampling path in architectures like U-Nets, processing input feature maps
(like spectrograms) to produce lower-resolution, higher-dimensionality feature
representations (bottleneck features).
The encoder is built dynamically by stacking neural network blocks based on a
list of configuration objects provided in `EncoderConfig.layers`. Each
configuration object specifies the type of block (e.g., standard convolution,
coordinate-feature convolution with downsampling) and its parameters
(e.g., output channels). This allows for flexible definition of encoder
architectures via configuration files.
The `Encoder`'s `forward` method returns outputs from all intermediate layers,
suitable for skip connections, while the `encode` method returns only the final
bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
provided.
"""
from typing import Annotated, List, Optional, Union
import torch
from pydantic import Field
from torch import nn
from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
BlockGroupConfig,
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
build_layer_from_config,
)
if sys.version_info >= (3, 10):
from itertools import pairwise
else:
__all__ = [
"EncoderConfig",
"Encoder",
"build_encoder",
"DEFAULT_ENCODER_CONFIG",
]
def pairwise(iterable: Sequence) -> Iterable:
for x, y in zip(iterable[:-1], iterable[1:]):
yield x, y
EncoderLayerConfig = Annotated[
Union[
ConvConfig,
FreqCoordConvDownConfig,
StandardConvDownConfig,
BlockGroupConfig,
],
Field(discriminator="block_type"),
]
"""Type alias for the discriminated union of block configs usable in Encoder."""
class EncoderConfig(BaseConfig):
"""Configuration for building the sequential Encoder module.
Defines the sequence of neural network blocks that constitute the encoder
(downsampling path).
Attributes
----------
layers : List[EncoderLayerConfig]
An ordered list of configuration objects, each defining one layer or
block in the encoder sequence. Each item must be a valid block config
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
`StandardConvDownConfig`) including a `block_type` field and necessary
parameters like `out_channels`. Input channels for each layer are
inferred sequentially. The list must contain at least one layer.
"""
layers: List[EncoderLayerConfig] = Field(min_length=1)
class Encoder(nn.Module):
"""Sequential Encoder module composed of configurable downscaling layers.
Constructs the downsampling path of an encoder-decoder network by stacking
multiple downscaling blocks.
The `forward` method executes the sequence and returns the output feature
map from *each* downscaling stage, facilitating the implementation of skip
connections in U-Net-like architectures. The `encode` method returns only
the final output tensor (bottleneck features).
Attributes
----------
in_channels : int
Number of channels expected in the input tensor.
input_height : int
Height (frequency bins) expected in the input tensor.
out_channels : int
Number of channels in the final output tensor (bottleneck).
output_height : int
Height (frequency bins) of the final output tensor (bottleneck).
layers : nn.ModuleList
The sequence of instantiated downscaling layer modules.
depth : int
The number of downscaling layers in the encoder.
"""
def __init__(
self,
output_channels: int,
output_height: int,
layers: List[nn.Module],
input_height: int = 128,
in_channels: int = 1,
):
"""Initialize the Encoder module.
Note: This constructor is typically called internally by the
`build_encoder` factory function, which prepares the `layers` list.
Parameters
----------
output_channels : int
Number of channels produced by the final layer.
output_height : int
The expected height of the output tensor.
layers : List[nn.Module]
A list of pre-instantiated downscaling layer modules (e.g.,
`StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
sequence.
input_height : int, default=128
Expected height of the input tensor.
in_channels : int, default=1
Expected number of channels in the input tensor.
"""
super().__init__()
self.in_channels = in_channels
self.input_height = input_height
self.out_channels = output_channels
self.output_height = output_height
self.layers = nn.ModuleList(layers)
self.depth = len(self.layers)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
"""Pass input through encoder layers, returns all intermediate outputs.
This method is typically used when the Encoder is part of a U-Net or
similar architecture requiring skip connections.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
`self.in_channels` and `H_in` must match `self.input_height`.
Returns
-------
List[torch.Tensor]
A list containing the output tensors from *each* downscaling layer
in the sequence. `outputs[0]` is the output of the first layer,
`outputs[-1]` is the final output (bottleneck) of the encoder.
Raises
------
ValueError
If input tensor channel count or height does not match expected
values.
"""
if x.shape[1] != self.in_channels:
raise ValueError(
f"Input tensor has {x.shape[1]} channels, "
f"but encoder expects {self.in_channels}."
)
if x.shape[2] != self.input_height:
raise ValueError(
f"Input tensor height {x.shape[2]} does not match "
f"encoder expected input_height {self.input_height}."
)
outputs = []
for layer in self.layers:
x = layer(x)
outputs.append(x)
return outputs
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Pass input through encoder layers, returning only the final output.
This method provides access to the bottleneck features produced after
the last downscaling layer.
Parameters
----------
x : torch.Tensor
Input tensor, shape `(B, C_in, H_in, W)`. Must match expected
`in_channels` and `input_height`.
Returns
-------
torch.Tensor
The final output tensor (bottleneck features) from the last layer
of the encoder. Shape `(B, C_out, H_out, W_out)`.
Raises
------
ValueError
If input tensor channel count or height does not match expected
values.
"""
if x.shape[1] != self.in_channels:
raise ValueError(
f"Input tensor has {x.shape[1]} channels, "
f"but encoder expects {self.in_channels}."
)
if x.shape[2] != self.input_height:
raise ValueError(
f"Input tensor height {x.shape[2]} does not match "
f"encoder expected input_height {self.input_height}."
)
for layer in self.layers:
x = layer(x)
return x
DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
layers=[
FreqCoordConvDownConfig(out_channels=32),
FreqCoordConvDownConfig(out_channels=64),
BlockGroupConfig(
blocks=[
FreqCoordConvDownConfig(out_channels=128),
ConvConfig(out_channels=256),
]
),
],
)
"""Default configuration for the Encoder.
Specifies an architecture typically used in BatDetect2:
- Input: 1 channel, 128 frequency bins.
- Layer 1: FreqCoordConvDown -> 32 channels, H=64
- Layer 2: FreqCoordConvDown -> 64 channels, H=32
- Layer 3: FreqCoordConvDown -> 128 channels, H=16
- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck)
"""
def build_encoder(
in_channels: int,
input_height: int,
config: Optional[EncoderConfig] = None,
) -> Encoder:
"""Factory function to build an Encoder instance from configuration.
Constructs a sequential `Encoder` module based on the layer sequence
defined in an `EncoderConfig` object and the provided input dimensions.
If no config is provided, uses the default layer sequence from
`DEFAULT_ENCODER_CONFIG`.
It iteratively builds the layers using the unified
`build_layer_from_config` factory (from `.blocks`), tracking the changing
number of channels and feature map height required for each subsequent
layer, especially for coordinate-aware blocks.
Parameters
----------
in_channels : int
The number of channels expected in the input tensor to the encoder.
Must be > 0.
input_height : int
The height (frequency bins) expected in the input tensor. Must be > 0.
Crucial for initializing coordinate-aware layers correctly.
config : EncoderConfig, optional
The configuration object detailing the sequence of layers and their
parameters. If None, `DEFAULT_ENCODER_CONFIG` is used.
Returns
-------
Encoder
An initialized `Encoder` module.
Raises
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `block_type`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `block_type`.
"""
if in_channels <= 0 or input_height <= 0:
raise ValueError("in_channels and input_height must be positive.")
config = config or DEFAULT_ENCODER_CONFIG
current_channels = in_channels
current_height = input_height
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
in_channels=current_channels,
input_height=current_height,
config=layer_config,
)
layers.append(layer)
return Encoder(
input_height=input_height,
layers=layers,
in_channels=in_channels,
output_channels=current_channels,
output_height=current_height,
)
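# Usage sketch for `build_encoder` with the default config (illustrative; the
# 512 time bins are an arbitrary assumption):
#
#     encoder = build_encoder(in_channels=1, input_height=128)
#     spec = torch.rand(2, 1, 128, 512)
#     skips = encoder(spec)                  # one tensor per stage, for skip connections
#     bottleneck = encoder.encode(spec)      # same as skips[-1]
#     # bottleneck has 256 channels and height 16 with the default config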

View File

@ -1,42 +1,199 @@
from typing import NamedTuple
"""Prediction Head modules for BatDetect2 models.
This module defines simple `torch.nn.Module` subclasses that serve as
prediction heads, typically attached to the output feature map of a backbone
network.
Each head is responsible for generating one specific type of output required
by the BatDetect2 task:
- `DetectorHead`: Predicts the probability of sound event presence.
- `ClassifierHead`: Predicts the probability distribution over target classes.
- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding
box.
These heads use 1x1 convolutions to map the backbone feature channels
to the desired number of output channels for each prediction task at each
spatial location, followed by an appropriate activation function (e.g., sigmoid
for detection, softmax for classification, none for size regression).
"""
import torch
from torch import nn
__all__ = ["ClassifierHead"]
class Output(NamedTuple):
detection: torch.Tensor
classification: torch.Tensor
__all__ = [
"ClassifierHead",
"DetectorHead",
"BBoxHead",
]
class ClassifierHead(nn.Module):
"""Prediction head for multi-class classification probabilities.
Takes an input feature map and produces a probability map where each
channel corresponds to a specific target class. It uses a 1x1 convolution
to map input channels to `num_classes + 1` outputs (one for each target
class plus an assumed background/generic class), applies softmax across the
channels, and returns the probabilities for the specific target classes
(excluding the last background/generic channel).
Parameters
----------
num_classes : int
The number of specific target classes the model should predict
(excluding any background or generic category). Must be positive.
in_channels : int
Number of channels in the input feature map tensor from the backbone.
Must be positive.
Attributes
----------
num_classes : int
Number of specific output classes.
in_channels : int
Number of input channels expected.
classifier : nn.Conv2d
The 1x1 convolutional layer used for prediction.
Output channels = num_classes + 1.
Raises
------
ValueError
If `num_classes` or `in_channels` are not positive.
"""
def __init__(self, num_classes: int, in_channels: int):
"""Initialize the ClassifierHead."""
super().__init__()
self.num_classes = num_classes
self.in_channels = in_channels
self.classifier = nn.Conv2d(
self.in_channels,
# Add one to account for the background class
self.num_classes + 1,
kernel_size=1,
padding=0,
)
def forward(self, features: torch.Tensor) -> Output:
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute class probabilities from input features.
Parameters
----------
features : torch.Tensor
Input feature map tensor from the backbone, typically with shape
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Returns
-------
torch.Tensor
Class probability map tensor with shape `(B, num_classes, H, W)`.
Contains probabilities for the specific target classes after
softmax, excluding the implicit background/generic class channel.
"""
logits = self.classifier(features)
probs = torch.softmax(logits, dim=1)
detection_probs = probs[:, :-1].sum(dim=1, keepdim=True)
return Output(
detection=detection_probs,
classification=probs[:, :-1],
return probs[:, :-1]
class DetectorHead(nn.Module):
"""Prediction head for sound event detection probability.
Takes an input feature map and produces a single-channel heatmap where
each value represents the probability ([0, 1]) of a relevant sound event
(of any class) being present at that spatial location.
Uses a 1x1 convolution to map input channels to 1 output channel, followed
by a sigmoid activation function.
Parameters
----------
in_channels : int
Number of channels in the input feature map tensor from the backbone.
Must be positive.
Attributes
----------
in_channels : int
Number of input channels expected.
detector : nn.Conv2d
The 1x1 convolutional layer mapping to a single output channel.
Raises
------
ValueError
If `in_channels` is not positive.
"""
def __init__(self, in_channels: int):
"""Initialize the DetectorHead."""
super().__init__()
self.in_channels = in_channels
self.detector = nn.Conv2d(
in_channels=self.in_channels,
out_channels=1,
kernel_size=1,
padding=0,
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute detection probabilities from input features.
Parameters
----------
features : torch.Tensor
Input feature map tensor from the backbone, typically with shape
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Returns
-------
torch.Tensor
Detection probability heatmap tensor with shape `(B, 1, H, W)`.
Values are in the range [0, 1] due to the sigmoid activation.
Raises
------
RuntimeError
If input channel count does not match `self.in_channels`.
"""
return torch.sigmoid(self.detector(features))
class BBoxHead(nn.Module):
"""Prediction head for bounding box size dimensions.
Takes an input feature map and produces a two-channel map where each
channel represents a predicted size dimension (typically width/duration and
height/bandwidth) for a potential sound event at that spatial location.
Uses a 1x1 convolution to map input channels to 2 output channels. No
activation function is typically applied, as size prediction is often
treated as a direct regression task. The output values usually represent
*scaled* dimensions that need to be un-scaled during postprocessing.
Parameters
----------
in_channels : int
Number of channels in the input feature map tensor from the backbone.
Must be positive.
Attributes
----------
in_channels : int
Number of input channels expected.
bbox : nn.Conv2d
The 1x1 convolutional layer mapping to 2 output channels
(width, height).
Raises
------
ValueError
If `in_channels` is not positive.
"""
def __init__(self, in_channels: int):
"""Initialize the BBoxHead."""
super().__init__()
self.in_channels = in_channels
@ -48,4 +205,19 @@ class BBoxHead(nn.Module):
)
def forward(self, features: torch.Tensor) -> torch.Tensor:
"""Compute predicted bounding box dimensions from input features.
Parameters
----------
features : torch.Tensor
Input feature map tensor from the backbone, typically with shape
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
Returns
-------
torch.Tensor
Predicted size tensor with shape `(B, 2, H, W)`. Channel 0 usually
represents scaled width, Channel 1 scaled height. These values
need to be un-scaled during postprocessing.
"""
return self.bbox(features)
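# Usage sketch for the three heads on a shared feature map (illustrative;
# channel count and grid size are assumptions):
#
#     features = torch.rand(2, 32, 128, 512)                             # (B, C, H, W)
#     probs = ClassifierHead(num_classes=10, in_channels=32)(features)   # (2, 10, 128, 512)
#     dets = DetectorHead(in_channels=32)(features)                      # (2, 1, 128, 512)
#     sizes = BBoxHead(in_channels=32)(features)                         # (2, 2, 128, 512)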

225 batdetect2/models/types.py Normal file
View File

@ -0,0 +1,225 @@
"""Defines shared interfaces (ABCs) and data structures for models.
This module centralizes the definitions of core data structures, like the
standard model output container (`ModelOutput`), and establishes abstract base
classes (ABCs) using `abc.ABC` and `torch.nn.Module`. These define contracts
for fundamental model components, ensuring modularity and consistent
interaction within the `batdetect2.models` package.
Key components:
- `ModelOutput`: Standard structure for outputs from detection models.
- `BackboneModel`: Generic interface for any feature extraction backbone.
- `EncoderDecoderModel`: Specialized interface for backbones with distinct
encoder-decoder stages (e.g., U-Net), providing access to intermediate
features.
- `DetectionModel`: Interface for the complete end-to-end detection model.
"""
from abc import ABC, abstractmethod
from typing import NamedTuple
import torch
import torch.nn as nn
__all__ = [
"ModelOutput",
"BackboneModel",
"DetectionModel",
]
class ModelOutput(NamedTuple):
"""Standard container for the outputs of a BatDetect2 detection model.
This structure groups the different prediction tensors produced by the
model for a batch of input spectrograms. All tensors typically share the
same spatial dimensions (height H, width W) corresponding to the model's
output resolution, and the same batch size (N).
Attributes
----------
detection_probs : torch.Tensor
Tensor containing the probability of sound event presence at each
location in the output grid.
Shape: `(N, 1, H, W)`
size_preds : torch.Tensor
Tensor containing the predicted size dimensions
(e.g., width and height) for a potential bounding box at each location.
Shape: `(N, 2, H, W)` (Channel 0 typically width, Channel 1 height)
class_probs : torch.Tensor
Tensor containing the predicted probabilities (or logits, depending on
the final activation) for each target class at each location.
The number of channels corresponds to the number of specific classes
defined in the `Targets` configuration.
Shape: `(N, num_classes, H, W)`
features : torch.Tensor
Tensor containing features extracted by the model's backbone. These
might be used for downstream tasks or analysis. The number of channels
depends on the specific model architecture.
Shape: `(N, num_features, H, W)`
"""
detection_probs: torch.Tensor
size_preds: torch.Tensor
class_probs: torch.Tensor
features: torch.Tensor
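# Construction sketch (illustrative; a 10-class model with a 32-channel feature
# map on a 128 x 512 output grid is assumed):
#
#     output = ModelOutput(
#         detection_probs=torch.rand(1, 1, 128, 512),
#         size_preds=torch.rand(1, 2, 128, 512),
#         class_probs=torch.rand(1, 10, 128, 512),
#         features=torch.rand(1, 32, 128, 512),
#     )
#     detection, sizes, classes, feats = output  # NamedTuple unpacking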
class BackboneModel(ABC, nn.Module):
"""Abstract Base Class for generic feature extraction backbone models.
Defines the minimal interface for a feature extractor network within a
BatDetect2 model. Its primary role is to process an input spectrogram
tensor and produce a spatially rich feature map tensor, which is then
typically consumed by separate prediction heads (for detection,
classification, size).
This base class is agnostic to the specific internal architecture (e.g.,
it could be a simple CNN, a U-Net, a Transformer, etc.). Concrete
implementations must inherit from this class and `torch.nn.Module`,
implement the `forward` method, and define the required attributes.
Attributes
----------
input_height : int
Expected height (number of frequency bins) of the input spectrogram
tensor that the backbone is designed to process.
out_channels : int
Number of channels in the final feature map tensor produced by the
backbone's `forward` method.
"""
input_height: int
"""Expected input spectrogram height (frequency bins)."""
out_channels: int
"""Number of output channels in the final feature map."""
@abstractmethod
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Perform the forward pass to extract features from the spectrogram.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, typically with shape
`(batch_size, 1, frequency_bins, time_bins)`.
`frequency_bins` should match `self.input_height`.
Returns
-------
torch.Tensor
Output feature map tensor, typically with shape
`(batch_size, self.out_channels, output_height, output_width)`.
The spatial dimensions (`output_height`, `output_width`) depend
on the specific backbone architecture (e.g., they might match the
input or be downsampled).
"""
raise NotImplementedError
class EncoderDecoderModel(BackboneModel):
"""Abstract Base Class for Encoder-Decoder style backbone models.
This class specializes `BackboneModel` for architectures that have distinct
encoder stages (downsampling path), a bottleneck, and decoder stages
(upsampling path).
It provides separate abstract methods for the `encode` and `decode` steps,
allowing access to the intermediate "bottleneck" features produced by the
encoder. This can be useful for tasks like transfer learning or specialized
analyses.
Attributes
----------
input_height : int
(Inherited from BackboneModel) Expected input spectrogram height.
out_channels : int
(Inherited from BackboneModel) Number of output channels in the final
feature map produced by the decoder/forward pass.
bottleneck_channels : int
Number of channels in the feature map produced by the encoder at its
deepest point (the bottleneck), before the decoder starts.
"""
bottleneck_channels: int
"""Number of channels at the encoder's bottleneck."""
@abstractmethod
def encode(self, spec: torch.Tensor) -> torch.Tensor:
"""Process the input spectrogram through the encoder part.
Takes the input spectrogram and passes it through the downsampling path
of the network up to the bottleneck layer.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, typically with shape
`(batch_size, 1, frequency_bins, time_bins)`.
Returns
-------
torch.Tensor
The encoded feature map from the bottleneck layer, typically with
shape `(batch_size, self.bottleneck_channels, bottleneck_height,
bottleneck_width)`. The spatial dimensions are usually downsampled
relative to the input.
"""
...
@abstractmethod
def decode(self, encoded: torch.Tensor) -> torch.Tensor:
"""Process the bottleneck features through the decoder part.
Takes the encoded feature map from the bottleneck and passes it through
the upsampling path (potentially using skip connections from the
encoder) to produce the final output feature map.
Parameters
----------
encoded : torch.Tensor
The bottleneck feature map tensor produced by the `encode` method.
Returns
-------
torch.Tensor
The final output feature map tensor, typically with shape
`(batch_size, self.out_channels, output_height, output_width)`.
This should match the output shape of the `forward` method.
"""
...
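# Sketch of how a concrete subclass would tie the two abstract stages together
# (illustrative only; a real U-Net-style backbone would also route encoder skip
# connections internally rather than relying solely on the bottleneck tensor):
#
#     class MyBackbone(EncoderDecoderModel):
#         ...
#         def forward(self, spec: torch.Tensor) -> torch.Tensor:
#             return self.decode(self.encode(spec))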
class DetectionModel(ABC, nn.Module):
"""Abstract Base Class for complete BatDetect2 detection models.
Defines the interface for the overall model that takes an input spectrogram
and produces all necessary outputs for detection, classification, and size
prediction, packaged within a `ModelOutput` object.
Concrete implementations typically combine a `BackboneModel` for feature
extraction with specific prediction heads for each output type. They must
inherit from this class and `torch.nn.Module`, and implement the `forward`
method.
"""
@abstractmethod
def forward(self, spec: torch.Tensor) -> ModelOutput:
"""Perform the forward pass of the full detection model.
Processes the input spectrogram through the backbone and prediction
heads to generate all required output tensors.
Parameters
----------
spec : torch.Tensor
Input spectrogram tensor, typically with shape
`(batch_size, 1, frequency_bins, time_bins)`.
Returns
-------
ModelOutput
A NamedTuple containing the prediction tensors: `detection_probs`,
`size_preds`, `class_probs`, and `features`.
"""

View File

@ -1,74 +0,0 @@
from abc import ABC, abstractmethod
from typing import NamedTuple, Tuple
import torch
import torch.nn as nn
__all__ = [
"ModelOutput",
"BackboneModel",
]
class ModelOutput(NamedTuple):
"""Output of the detection model.
Each of the tensors has a shape of
`(batch_size, num_channels, spec_height, spec_width)`.
Where `spec_height` and `spec_width` are the height and width of the
input spectrograms.
They contain localised information of:
1. The probability of a bounding box detection at the given location.
2. The predicted size of the bounding box at the given location.
3. The probabilities of each class at the given location before softmax.
4. Features used to make the predictions at the given location.
"""
detection_probs: torch.Tensor
"""Tensor with predict detection probabilities."""
size_preds: torch.Tensor
"""Tensor with predicted bounding box sizes."""
class_probs: torch.Tensor
"""Tensor with predicted class probabilities."""
features: torch.Tensor
"""Tensor with intermediate features."""
class BackboneModel(ABC, nn.Module):
input_height: int
"""Height of the input spectrogram."""
encoder_channels: Tuple[int, ...]
"""Tuple specifying the number of channels for each convolutional layer
in the encoder. The length of this tuple determines the number of
encoder layers."""
decoder_channels: Tuple[int, ...]
"""Tuple specifying the number of channels for each convolutional layer
in the decoder. The length of this tuple determines the number of
decoder layers."""
bottleneck_channels: int
"""Number of channels in the bottleneck layer, which connects the
encoder and decoder."""
out_channels: int
"""Number of channels in the final output feature map produced by the
backbone model."""
@abstractmethod
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Forward pass of the model."""
class DetectionModel(ABC, nn.Module):
@abstractmethod
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Forward pass of the detection model."""

View File

@ -1,181 +0,0 @@
from pathlib import Path
from typing import Optional
import lightning as L
import torch
from pydantic import Field
from soundevent import data
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from batdetect2.configs import BaseConfig
from batdetect2.evaluate.evaluate import match_predictions_and_annotations
from batdetect2.models import (
BBoxHead,
ClassifierHead,
ModelConfig,
build_architecture,
)
from batdetect2.models.typing import ModelOutput
from batdetect2.post_process import (
PostprocessConfig,
postprocess_model_outputs,
)
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.losses import compute_loss
from batdetect2.train.targets import (
TargetConfig,
build_decoder,
build_target_encoder,
get_class_names,
)
__all__ = [
"DetectorModel",
]
class ModuleConfig(BaseConfig):
train: TrainingConfig = Field(default_factory=TrainingConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
architecture: ModelConfig = Field(default_factory=ModelConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocessing: PostprocessConfig = Field(
default_factory=PostprocessConfig
)
class DetectorModel(L.LightningModule):
config: ModuleConfig
def __init__(
self,
config: Optional[ModuleConfig] = None,
):
super().__init__()
self.config = config or ModuleConfig()
self.save_hyperparameters()
self.backbone = build_architecture(self.config.architecture)
self.classifier = ClassifierHead(
num_classes=len(self.config.targets.classes),
in_channels=self.backbone.out_channels,
)
self.bbox = BBoxHead(in_channels=self.backbone.out_channels)
conf = self.config.train.loss.classification
self.class_weights = (
torch.tensor(conf.class_weights) if conf.class_weights else None
)
# Training targets
self.class_names = get_class_names(self.config.targets.classes)
self.encoder = build_target_encoder(
self.config.targets.classes,
replacement_rules=self.config.targets.replace,
)
self.decoder = build_decoder(self.config.targets.classes)
self.validation_predictions = []
self.example_input_array = torch.randn([1, 1, 128, 512])
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
features = self.backbone(spec)
detection_probs, classification_probs = self.classifier(features)
size_preds = self.bbox(features)
return ModelOutput(
detection_probs=detection_probs,
size_preds=size_preds,
class_probs=classification_probs,
features=features,
)
def training_step(self, batch: TrainExample):
outputs = self.forward(batch.spec)
losses = compute_loss(
batch,
outputs,
conf=self.config.train.loss,
class_weights=self.class_weights,
)
self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
self.log("train/loss/detection", losses.total, logger=True)
self.log("train/loss/size", losses.total, logger=True)
self.log("train/loss/classification", losses.total, logger=True)
return losses.total
def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
outputs = self.forward(batch.spec)
losses = compute_loss(
batch,
outputs,
conf=self.config.train.loss,
class_weights=self.class_weights,
)
self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
self.log("val/loss/detection", losses.total, logger=True)
self.log("val/loss/size", losses.total, logger=True)
self.log("val/loss/classification", losses.total, logger=True)
dataloaders = self.trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, LabeledDataset)
clip_annotation = dataset.get_clip_annotation(batch_idx)
clip_prediction = postprocess_model_outputs(
outputs,
clips=[clip_annotation.clip],
classes=self.class_names,
decoder=self.decoder,
config=self.config.postprocessing,
)[0]
matches = match_predictions_and_annotations(
clip_annotation,
clip_prediction,
)
self.validation_predictions.extend(matches)
def on_validation_epoch_end(self) -> None:
self.validation_predictions.clear()
def configure_optimizers(self):
conf = self.config.train.optimizer
optimizer = Adam(self.parameters(), lr=conf.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
return [optimizer], [scheduler]
def process_clip(
self,
clip: data.Clip,
audio_dir: Optional[Path] = None,
) -> data.ClipPrediction:
spec = preprocess_audio_clip(
clip,
config=self.config.preprocessing,
audio_dir=audio_dir,
)
tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
outputs = self.forward(tensor)
return postprocess_model_outputs(
outputs,
clips=[clip],
classes=self.class_names,
decoder=self.decoder,
config=self.config.postprocessing,
)[0]

View File

@ -2,7 +2,6 @@
from typing import Optional, Tuple
import matplotlib.pyplot as plt
import xarray as xr
from matplotlib import axes

View File

@ -1,398 +0,0 @@
"""Module for postprocessing model outputs."""
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
import numpy as np
import torch
from pydantic import Field
from soundevent import data
from torch import nn
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.typing import ModelOutput
__all__ = [
"PostprocessConfig",
"load_postprocess_config",
"postprocess_model_outputs",
]
NMS_KERNEL_SIZE = 9
DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 200
class PostprocessConfig(BaseConfig):
"""Configuration for postprocessing model outputs."""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
min_freq: int = Field(default=10000, gt=0)
max_freq: int = Field(default=120000, gt=0)
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PostprocessConfig:
return load_config(path, schema=PostprocessConfig, field=field)
class RawPrediction(NamedTuple):
start_time: float
end_time: float
low_freq: float
high_freq: float
detection_score: float
class_scores: Dict[str, float]
features: np.ndarray
def postprocess_model_outputs(
outputs: ModelOutput,
clips: List[data.Clip],
classes: List[str],
decoder: Callable[[str], List[data.Tag]],
config: Optional[PostprocessConfig] = None,
) -> List[data.ClipPrediction]:
"""Postprocesses model outputs to generate clip predictions.
This function takes the output from the model, applies non-maximum suppression,
selects the top-k scores, computes sound events from the outputs, and returns
clip predictions based on these processed outputs.
Parameters
----------
outputs
Output from the model containing detection probabilities, size
predictions, class logits, and features. All tensors are expected
to have a batch dimension.
clips
List of clips for which predictions are made. The number of clips
must match the batch dimension of the model outputs.
config
Configuration for postprocessing model outputs.
Returns
-------
predictions: List[data.ClipPrediction]
List of clip predictions containing predicted sound events.
Raises
------
ValueError
If the number of predictions does not match the number of clips.
"""
config = config or PostprocessConfig()
num_predictions = len(outputs.detection_probs)
if num_predictions == 0:
return []
if num_predictions != len(clips):
raise ValueError(
"Number of predictions must match the number of clips."
)
detection_probs = non_max_suppression(
outputs.detection_probs,
kernel_size=config.nms_kernel_size,
)
duration = clips[0].end_time - clips[0].start_time
scores_batch, y_pos_batch, x_pos_batch = get_topk_scores(
detection_probs,
int(config.top_k_per_sec * duration / 2),
)
predictions: List[data.ClipPrediction] = []
for scores, y_pos, x_pos, size_preds, class_probs, features, clip in zip(
scores_batch,
y_pos_batch,
x_pos_batch,
outputs.size_preds,
outputs.class_probs,
outputs.features,
clips,
):
sound_events = compute_sound_events_from_outputs(
clip,
scores,
y_pos,
x_pos,
size_preds,
class_probs,
features,
classes=classes,
decoder=decoder,
min_freq=config.min_freq,
max_freq=config.max_freq,
detection_threshold=config.detection_threshold,
)
predictions.append(
data.ClipPrediction(
clip=clip,
sound_events=sound_events,
)
)
return predictions
def compute_predictions_from_outputs(
start: float,
end: float,
scores: torch.Tensor,
y_pos: torch.Tensor,
x_pos: torch.Tensor,
size_preds: torch.Tensor,
class_probs: torch.Tensor,
features: torch.Tensor,
classes: List[str],
min_freq: int = 10000,
max_freq: int = 120000,
detection_threshold: float = DETECTION_THRESHOLD,
) -> List[RawPrediction]:
_, freq_bins, time_bins = size_preds.shape
sorted_indices = torch.argsort(x_pos)
valid_indices = sorted_indices[
scores[sorted_indices] > detection_threshold
]
scores = scores[valid_indices]
x_pos = x_pos[valid_indices]
y_pos = y_pos[valid_indices]
predictions: List[RawPrediction] = []
for score, x, y in zip(scores, x_pos, y_pos):
width, height = size_preds[:, y, x]
class_prob = class_probs[:, y, x].detach().numpy()
feats = features[:, y, x].detach().numpy()
start_time = np.interp(
x.item(),
[0, time_bins],
[start, end],
)
end_time = np.interp(
x.item() + width.item(),
[0, time_bins],
[start, end],
)
low_freq = np.interp(
y.item(),
[0, freq_bins],
[max_freq, min_freq],
)
high_freq = np.interp(
y.item() - height.item(),
[0, freq_bins],
[max_freq, min_freq],
)
start_time, end_time = sorted([float(start_time), float(end_time)])
low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
predictions.append(
RawPrediction(
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
detection_score=score.item(),
features=feats,
class_scores={
class_name: prob
for class_name, prob in zip(classes, class_prob)
},
)
)
return predictions
def compute_sound_events_from_outputs(
clip: data.Clip,
scores: torch.Tensor,
y_pos: torch.Tensor,
x_pos: torch.Tensor,
size_preds: torch.Tensor,
class_probs: torch.Tensor,
features: torch.Tensor,
classes: List[str],
decoder: Callable[[str], List[data.Tag]],
min_freq: int = 10000,
max_freq: int = 120000,
detection_threshold: float = DETECTION_THRESHOLD,
) -> List[data.SoundEventPrediction]:
_, freq_bins, time_bins = size_preds.shape
sorted_indices = torch.argsort(x_pos)
valid_indices = sorted_indices[
scores[sorted_indices] > detection_threshold
]
scores = scores[valid_indices]
x_pos = x_pos[valid_indices]
y_pos = y_pos[valid_indices]
predictions: List[data.SoundEventPrediction] = []
for score, x, y in zip(scores, x_pos, y_pos):
width, height = size_preds[:, y, x]
class_prob = class_probs[:, y, x]
feature = features[:, y, x]
start_time = np.interp(
x.item(),
[0, time_bins],
[clip.start_time, clip.end_time],
)
end_time = np.interp(
x.item() + width.item(),
[0, time_bins],
[clip.start_time, clip.end_time],
)
low_freq = np.interp(
y.item(),
[0, freq_bins],
[max_freq, min_freq],
)
high_freq = np.interp(
y.item() - height.item(),
[0, freq_bins],
[max_freq, min_freq],
)
predicted_tags: List[data.PredictedTag] = []
for label_id, class_score in enumerate(class_prob):
class_name = classes[label_id]
corresponding_tags = decoder(class_name)
predicted_tags.extend(
[
data.PredictedTag(
tag=tag,
score=max(min(class_score.item(), 1), 0),
)
for tag in corresponding_tags
]
)
start_time, end_time = sorted([float(start_time), float(end_time)])
low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
sound_event = data.SoundEvent(
recording=clip.recording,
geometry=data.BoundingBox(
coordinates=[
start_time,
low_freq,
end_time,
high_freq,
]
),
features=[
data.Feature(
term=data.term_from_key(f"batdetect2_{i}"),
value=value.item(),
)
for i, value in enumerate(feature)
],
)
predictions.append(
data.SoundEventPrediction(
sound_event=sound_event,
score=max(min(score.item(), 1), 0),
tags=predicted_tags,
)
)
return predictions
def non_max_suppression(
tensor: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> torch.Tensor:
"""Run non-maximum suppression on a tensor.
This function removes values from the input tensor that are not local
maxima in the neighborhood of the given kernel size.
All non-maximum values are set to zero.
Parameters
----------
tensor : torch.Tensor
Input tensor.
kernel_size : Union[int, Tuple[int, int]], optional
Size of the neighborhood to consider for non-maximum suppression.
If an integer is given, the neighborhood will be a square of the
given size. If a tuple is given, the neighborhood will be a
rectangle with the given height and width.
Returns
-------
torch.Tensor
Tensor with non-maximum suppressed values.
"""
if isinstance(kernel_size, int):
kernel_size_h = kernel_size
kernel_size_w = kernel_size
else:
kernel_size_h, kernel_size_w = kernel_size
pad_h = (kernel_size_h - 1) // 2
pad_w = (kernel_size_w - 1) // 2
hmax = nn.functional.max_pool2d(
tensor,
(kernel_size_h, kernel_size_w),
stride=1,
padding=(pad_h, pad_w),
)
keep = (hmax == tensor).float()
return tensor * keep
def get_topk_scores(
scores: torch.Tensor,
K: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Get the top-k scores and their indices.
Parameters
----------
scores : torch.Tensor
Tensor with scores. Expects input of size: `batch x 1 x height x width`.
K : int
Number of top scores to return.
Returns
-------
scores : torch.Tensor
Top-k scores.
ys : torch.Tensor
Y coordinates of the top-k scores.
xs : torch.Tensor
X coordinates of the top-k scores.
"""
batch, _, height, width = scores.size()
topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
topk_inds = topk_inds % (height * width)
topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
topk_xs = (topk_inds % width).long()
return topk_scores, topk_ys, topk_xs

View File

@ -0,0 +1,566 @@
"""Main entry point for the BatDetect2 Postprocessing pipeline.
This package (`batdetect2.postprocess`) takes the raw outputs from a trained
BatDetect2 neural network model and transforms them into meaningful, structured
predictions, typically in the form of `soundevent.data.ClipPrediction` objects
containing detected sound events with associated class tags and geometry.
The pipeline involves several configurable steps, implemented in submodules:
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency
coordinates to raw model output arrays.
3. Detection Extraction (`.detection`): Identifies candidate detection points
(location and score) based on thresholds and score ranking (top-k).
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
class probabilities, features) at the detected locations.
5. Decoding & Formatting (`.decoding`): Converts extracted numerical data and
class predictions into interpretable `soundevent` objects, including
recovering geometry (ROIs) and decoding class names back to standard tags.
This module provides the primary interface:
- `PostprocessConfig`: A configuration object for postprocessing parameters
(thresholds, NMS kernel size, etc.).
- `load_postprocess_config`: Function to load the configuration from a file.
- `Postprocessor`: The main class (implementing `PostprocessorProtocol`) that
holds the configured pipeline logic.
- `build_postprocessor`: A factory function to create a `Postprocessor`
instance, linking it to the necessary target definitions (`TargetProtocol`).
It also re-exports key components from submodules for convenience.
"""
from typing import List, Optional
import xarray as xr
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.types import ModelOutput
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_predictions_to_clip_prediction,
convert_xr_dataset_to_raw_prediction,
)
from batdetect2.postprocess.detection import (
DEFAULT_DETECTION_THRESHOLD,
TOP_K_PER_SEC,
extract_detections_from_array,
get_max_detections,
)
from batdetect2.postprocess.extraction import (
extract_detection_xr_dataset,
)
from batdetect2.postprocess.nms import (
NMS_KERNEL_SIZE,
non_max_suppression,
)
from batdetect2.postprocess.remapping import (
classification_to_xarray,
detection_to_xarray,
features_to_xarray,
sizes_to_xarray,
)
from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets.types import TargetProtocol
__all__ = [
"DEFAULT_CLASSIFICATION_THRESHOLD",
"DEFAULT_DETECTION_THRESHOLD",
"MAX_FREQ",
"MIN_FREQ",
"ModelOutput",
"NMS_KERNEL_SIZE",
"PostprocessConfig",
"Postprocessor",
"PostprocessorProtocol",
"RawPrediction",
"TOP_K_PER_SEC",
"build_postprocessor",
"classification_to_xarray",
"convert_raw_predictions_to_clip_prediction",
"convert_xr_dataset_to_raw_prediction",
"detection_to_xarray",
"extract_detection_xr_dataset",
"extract_detections_from_array",
"features_to_xarray",
"get_max_detections",
"load_postprocess_config",
"non_max_suppression",
"sizes_to_xarray",
]
class PostprocessConfig(BaseConfig):
"""Configuration settings for the postprocessing pipeline.
Defines tunable parameters that control how raw model outputs are
converted into final detections.
Attributes
----------
nms_kernel_size : int, default=NMS_KERNEL_SIZE
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
Used to suppress weaker detections near stronger peaks. Must be
positive.
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
Minimum confidence score from the detection heatmap required to
consider a point as a potential detection. Must be >= 0.
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
Minimum confidence score for a specific class prediction to be included
in the decoded tags for a detection. Must be >= 0.
top_k_per_sec : int, default=TOP_K_PER_SEC
Desired maximum number of detections per second of audio. Used by
`get_max_detections` to calculate an absolute limit based on clip
duration before applying `extract_detections_from_array`. Must be
positive.
"""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
detection_threshold: float = Field(
default=DEFAULT_DETECTION_THRESHOLD,
ge=0,
)
classification_threshold: float = Field(
default=DEFAULT_CLASSIFICATION_THRESHOLD,
ge=0,
)
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PostprocessConfig:
"""Load the postprocessing configuration from a file.
Reads a configuration file (YAML) and validates it against the
`PostprocessConfig` schema, potentially extracting data from a nested
field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
postprocessing configuration (e.g., "inference.postprocessing").
If None, the entire file content is used.
Returns
-------
PostprocessConfig
The loaded and validated postprocessing configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data does not conform to the
`PostprocessConfig` schema.
KeyError, TypeError
If `field` specifies an invalid path within the loaded data.
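Examples
--------
Illustrative only; the file name and nested field are placeholders:

>>> config = load_postprocess_config(
...     "config.yaml",
...     field="inference.postprocessing",
... )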
"""
return load_config(path, schema=PostprocessConfig, field=field)
def build_postprocessor(
targets: TargetProtocol,
config: Optional[PostprocessConfig] = None,
max_freq: int = MAX_FREQ,
min_freq: int = MIN_FREQ,
) -> PostprocessorProtocol:
"""Factory function to build the standard postprocessor.
Creates and initializes the `Postprocessor` instance, providing it with the
necessary `targets` object and the `PostprocessConfig`.
Parameters
----------
targets : TargetProtocol
An initialized object conforming to the `TargetProtocol`, providing
methods like `.decode()` and `.recover_roi()`, and attributes like
`.class_names` and `.generic_class_tags`. This links postprocessing
to the defined target semantics and geometry mappings.
config : PostprocessConfig, optional
Configuration object specifying postprocessing parameters (thresholds,
NMS kernel size, etc.). If None, default settings defined in
`PostprocessConfig` will be used.
min_freq : int, default=MIN_FREQ
The minimum frequency (Hz) corresponding to the frequency axis of the
model outputs. Required for coordinate remapping. Consider setting via
`PostprocessConfig` instead for better encapsulation.
max_freq : int, default=MAX_FREQ
The maximum frequency (Hz) corresponding to the frequency axis of the
model outputs. Required for coordinate remapping. Consider setting via
`PostprocessConfig`.
Returns
-------
PostprocessorProtocol
An initialized `Postprocessor` instance ready to process model outputs.
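Examples
--------
A sketch assuming a `targets` object conforming to `TargetProtocol` has
already been built elsewhere:

>>> postprocessor = build_postprocessor(
...     targets,
...     config=PostprocessConfig(detection_threshold=0.2),
... )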
"""
return Postprocessor(
targets=targets,
config=config or PostprocessConfig(),
min_freq=min_freq,
max_freq=max_freq,
)
class Postprocessor(PostprocessorProtocol):
"""Standard implementation of the postprocessing pipeline.
This class orchestrates the steps required to convert raw model outputs
into interpretable `soundevent` predictions. It uses configured parameters
and leverages functions from the `batdetect2.postprocess` submodules for
each stage (NMS, remapping, detection, extraction, decoding).
It requires a `TargetProtocol` object during initialization to access
necessary decoding information (class name to tag mapping,
ROI recovery logic) ensuring consistency with the target definitions used
during training or specified for inference.
Instances are typically created using the `build_postprocessor` factory
function.
Attributes
----------
targets : TargetProtocol
The configured target definition object providing decoding and ROI
recovery.
config : PostprocessConfig
Configuration object holding parameters for NMS, thresholds, etc.
min_freq : int
Minimum frequency (Hz) assumed for the model output's frequency axis.
max_freq : int
Maximum frequency (Hz) assumed for the model output's frequency axis.
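Examples
--------
Typical batch usage, assuming `postprocessor` was created via
`build_postprocessor` and that `model_output` and `clips` come from an
inference step (both names are placeholders):

>>> raw_predictions = postprocessor.get_raw_predictions(model_output, clips)
>>> clip_predictions = postprocessor.get_predictions(model_output, clips)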
"""
targets: TargetProtocol
def __init__(
self,
targets: TargetProtocol,
config: PostprocessConfig,
min_freq: int = MIN_FREQ,
max_freq: int = MAX_FREQ,
):
"""Initialize the Postprocessor.
Parameters
----------
targets : TargetProtocol
Initialized target definition object.
config : PostprocessConfig
Configuration for postprocessing parameters.
min_freq : int, default=MIN_FREQ
Minimum frequency (Hz) for coordinate remapping.
max_freq : int, default=MAX_FREQ
Maximum frequency (Hz) for coordinate remapping.
"""
self.targets = targets
self.config = config
self.min_freq = min_freq
self.max_freq = max_freq
def get_feature_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Extract and remap raw feature tensors for a batch.
Parameters
----------
output : ModelOutput
Raw model output containing `output.features` tensor for the batch.
clips : List[data.Clip]
List of Clip objects corresponding to the batch items.
Returns
-------
List[xr.DataArray]
List of coordinate-aware feature DataArrays, one per clip.
Raises
------
ValueError
If batch sizes of `output.features` and `clips` do not match.
"""
if len(clips) != len(output.features):
raise ValueError(
"Number of clips and batch size of feature array"
"do not match. "
f"(clips: {len(clips)}, features: {len(output.features)})"
)
return [
features_to_xarray(
feats,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for feats, clip in zip(output.features, clips)
]
def get_detection_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Apply NMS and remap detection heatmaps for a batch.
Parameters
----------
output : ModelOutput
Raw model output containing `output.detection_probs` tensor for the
batch.
clips : List[data.Clip]
List of Clip objects corresponding to the batch items.
Returns
-------
List[xr.DataArray]
List of NMS-applied, coordinate-aware detection heatmaps, one per
clip.
Raises
------
ValueError
If batch sizes of `output.detection_probs` and `clips` do not match.
"""
detections = output.detection_probs
if len(clips) != len(output.detection_probs):
raise ValueError(
"Number of clips and batch size of detection array "
"do not match. "
f"(clips: {len(clips)}, detection: {len(detections)})"
)
detections = non_max_suppression(
detections,
kernel_size=self.config.nms_kernel_size,
)
return [
detection_to_xarray(
dets,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for dets, clip in zip(detections, clips)
]
def get_classification_arrays(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[xr.DataArray]:
"""Extract and remap raw classification tensors for a batch.
Parameters
----------
output : ModelOutput
Raw model output containing `output.class_probs` tensor for the
batch.
clips : List[data.Clip]
List of Clip objects corresponding to the batch items.
Returns
-------
List[xr.DataArray]
List of coordinate-aware class probability maps, one per clip.
Raises
------
ValueError
If batch sizes of `output.class_probs` and `clips` do not match, or
if number of classes mismatches `self.targets.class_names`.
"""
classifications = output.class_probs
if len(clips) != len(classifications):
raise ValueError(
"Number of clips and batch size of classification array "
"do not match. "
f"(clips: {len(clips)}, classification: {len(classifications)})"
)
return [
classification_to_xarray(
class_probs,
start_time=clip.start_time,
end_time=clip.end_time,
class_names=self.targets.class_names,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for class_probs, clip in zip(classifications, clips)
]
def get_sizes_arrays(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[xr.DataArray]:
"""Extract and remap raw size prediction tensors for a batch.
Parameters
----------
output : ModelOutput
Raw model output containing `output.size_preds` tensor for the
batch.
clips : List[data.Clip]
List of Clip objects corresponding to the batch items.
Returns
-------
List[xr.DataArray]
List of coordinate-aware size prediction maps, one per clip.
Raises
------
ValueError
If batch sizes of `output.size_preds` and `clips` do not match.
"""
sizes = output.size_preds
if len(clips) != len(sizes):
raise ValueError(
"Number of clips and batch size of sizes array do not match. "
f"(clips: {len(clips)}, sizes: {len(sizes)})"
)
return [
sizes_to_xarray(
size_preds,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for size_preds, clip in zip(sizes, clips)
]
def get_detection_datasets(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[xr.Dataset]:
"""Perform NMS, remapping, detection, and data extraction for a batch.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[xr.Dataset]
List of xarray Datasets (one per clip). Each Dataset contains
aligned scores, dimensions, class probabilities, and features for
detections found in that clip.
"""
detection_arrays = self.get_detection_arrays(output, clips)
classification_arrays = self.get_classification_arrays(output, clips)
size_arrays = self.get_sizes_arrays(output, clips)
features_arrays = self.get_feature_arrays(output, clips)
datasets = []
for det_array, class_array, sizes_array, feats_array in zip(
detection_arrays,
classification_arrays,
size_arrays,
features_arrays,
):
max_detections = get_max_detections(
det_array,
top_k_per_sec=self.config.top_k_per_sec,
)
positions = extract_detections_from_array(
det_array,
max_detections=max_detections,
threshold=self.config.detection_threshold,
)
datasets.append(
extract_detection_xr_dataset(
positions,
sizes_array,
class_array,
feats_array,
)
)
return datasets
def get_raw_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Processes raw model output through remapping, NMS, detection, data
extraction, and geometry recovery via the configured
`targets.recover_roi`.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[List[RawPrediction]]
List of lists (one inner list per input clip). Each inner list
contains `RawPrediction` objects for detections in that clip.
"""
detection_datasets = self.get_detection_datasets(output, clips)
return [
convert_xr_dataset_to_raw_prediction(
dataset,
self.targets.recover_roi,
)
for dataset in detection_datasets
]
def get_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output and corresponding clips, applies the entire
configured chain (NMS, remapping, extraction, geometry recovery, class
decoding), producing final `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[data.ClipPrediction]
List containing one `ClipPrediction` object for each input clip,
populated with `SoundEventPrediction` objects.
"""
raw_predictions = self.get_raw_predictions(output, clips)
return [
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
sound_event_decoder=self.targets.decode,
generic_class_tags=self.targets.generic_class_tags,
classification_threshold=self.config.classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)
]

View File

@ -0,0 +1,409 @@
"""Decodes extracted detection data into standard soundevent predictions.
This module handles the final stages of the BatDetect2 postprocessing pipeline.
It takes the structured detection data extracted by the `extraction` module
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
class probabilities, and features for each detection point) and converts it
into meaningful, standardized prediction objects based on the `soundevent` data
model.
The process involves:
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
objects, using a configured geometry builder to recover bounding boxes from
predicted positions and sizes (`convert_xr_dataset_to_raw_prediction`).
2. Converting each `RawPrediction` into a
`soundevent.data.SoundEventPrediction`, which involves:
- Creating the `soundevent.data.SoundEvent` with geometry and features.
- Decoding the predicted class probabilities into representative tags using
a configured class decoder (`SoundEventDecoder`).
- Applying a classification threshold.
- Optionally selecting only the single highest-scoring class (top-1) or
including tags for all classes above the threshold (multi-label).
- Adding generic class tags as a baseline.
- Associating scores with the final prediction and tags.
(`convert_raw_prediction_to_sound_event_prediction`)
3. Grouping the `SoundEventPrediction` objects for a given audio segment into
a `soundevent.data.ClipPrediction`
(`convert_raw_predictions_to_clip_prediction`).
"""
from typing import List, Optional
import numpy as np
import xarray as xr
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
from batdetect2.targets.classes import SoundEventDecoder
__all__ = [
"convert_xr_dataset_to_raw_prediction",
"convert_raw_predictions_to_clip_prediction",
"convert_raw_prediction_to_sound_event_prediction",
"DEFAULT_CLASSIFICATION_THRESHOLD",
]
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
"""Default threshold applied to classification scores.
Class predictions with scores below this value are typically ignored during
decoding.
"""
def convert_xr_dataset_to_raw_prediction(
detection_dataset: xr.Dataset,
geometry_builder: GeometryBuilder,
) -> List[RawPrediction]:
"""Convert an xarray.Dataset of detections to RawPrediction objects.
Takes the output of the extraction step (`extract_detection_xr_dataset`)
and transforms each detection entry into an intermediate `RawPrediction`
object. This involves recovering the geometry (e.g., bounding box) from
the predicted position and scaled size dimensions using the provided
`geometry_builder` function.
Parameters
----------
detection_dataset : xr.Dataset
An xarray Dataset containing aligned detection information, typically
output by `extract_detection_xr_dataset`. Expected variables include
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
Must have a 'detection' dimension.
geometry_builder : GeometryBuilder
A function that takes a position tuple `(time, freq)` and a NumPy array
of dimensions, and returns the corresponding reconstructed
`soundevent.data.Geometry`.
Returns
-------
List[RawPrediction]
A list of `RawPrediction` objects, each containing the detection score,
recovered bounding box coordinates (start/end time, low/high freq),
the vector of class scores, and the feature vector for one detection.
Raises
------
AttributeError, KeyError, ValueError
If `detection_dataset` is missing expected variables ('scores',
'dimensions', 'classes', 'features') or coordinates ('time', 'freq'
associated with 'scores'), or if `geometry_builder` fails.
"""
detections = []
for det_num in range(detection_dataset.sizes["detection"]):
det_info = detection_dataset.sel(detection=det_num)
geom = geometry_builder(
(det_info.time, det_info.freq),
det_info.dimensions,
)
start_time, low_freq, end_time, high_freq = compute_bounds(geom)
detections.append(
RawPrediction(
detection_score=det_info.score,
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
class_scores=det_info.classes,
features=det_info.features,
)
)
return detections
def convert_raw_predictions_to_clip_prediction(
raw_predictions: List[RawPrediction],
clip: data.Clip,
sound_event_decoder: SoundEventDecoder,
generic_class_tags: List[data.Tag],
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False,
) -> data.ClipPrediction:
"""Convert a list of RawPredictions into a soundevent ClipPrediction.
Iterates through `raw_predictions` (assumed to belong to a single clip),
converts each one into a `soundevent.data.SoundEventPrediction` using
`convert_raw_prediction_to_sound_event_prediction`, and packages them
into a `soundevent.data.ClipPrediction` associated with the original `clip`.
Parameters
----------
raw_predictions : List[RawPrediction]
List of raw prediction objects for a single clip.
clip : data.Clip
The original `soundevent.data.Clip` object these predictions belong to.
sound_event_decoder : SoundEventDecoder
Function to decode class names into representative tags.
generic_class_tags : List[data.Tag]
List of tags representing the generic class category.
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
Threshold applied to class scores during decoding.
top_class_only : bool, default=False
If True, only decode tags for the single highest-scoring class above
the threshold. If False, decode tags for all classes above threshold.
Returns
-------
data.ClipPrediction
A `ClipPrediction` object containing a list of `SoundEventPrediction`
objects corresponding to the input `raw_predictions`.
"""
return data.ClipPrediction(
clip=clip,
sound_events=[
convert_raw_prediction_to_sound_event_prediction(
prediction,
recording=clip.recording,
sound_event_decoder=sound_event_decoder,
generic_class_tags=generic_class_tags,
classification_threshold=classification_threshold,
top_class_only=top_class_only,
)
for prediction in raw_predictions
],
)
def convert_raw_prediction_to_sound_event_prediction(
raw_prediction: RawPrediction,
recording: data.Recording,
sound_event_decoder: SoundEventDecoder,
generic_class_tags: List[data.Tag],
classification_threshold: Optional[
float
] = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False,
) -> data.SoundEventPrediction:
"""Convert a single RawPrediction into a soundevent SoundEventPrediction.
This function performs the core decoding steps for a single detected event:
1. Creates a `soundevent.data.SoundEvent` containing the geometry
(BoundingBox derived from `raw_prediction` bounds) and any associated
feature vectors.
2. Initializes a list of predicted tags using the provided
`generic_class_tags`, assigning the overall `detection_score` from the
`raw_prediction` to these generic tags.
3. Processes the `class_scores` from the `raw_prediction`:
a. Optionally filters out scores below `classification_threshold`
(if it's not None).
b. Sorts the remaining scores in descending order.
c. Iterates through the sorted, thresholded class scores.
d. For each class, uses the `sound_event_decoder` to get the
representative base tags for that class name.
e. Wraps these base tags in `soundevent.data.PredictedTag`, associating
the specific `score` of that class prediction.
f. Appends these specific predicted tags to the list.
g. If `top_class_only` is True, stops after processing the first
(highest-scoring) class that passed the threshold.
4. Creates and returns the final `soundevent.data.SoundEventPrediction`,
associating the `SoundEvent`, the overall `detection_score`, and the
compiled list of `PredictedTag` objects.
Parameters
----------
raw_prediction : RawPrediction
The raw prediction object containing score, bounds, class scores,
features. Assumes `class_scores` is an `xr.DataArray` with a 'category'
coordinate. Assumes `features` is an `xr.DataArray` with a 'feature'
coordinate.
recording : data.Recording
The recording the sound event belongs to.
sound_event_decoder : SoundEventDecoder
Configured function mapping class names (str) to lists of base
`data.Tag` objects.
generic_class_tags : List[data.Tag]
List of base tags representing the generic category.
classification_threshold : float, optional
The minimum score a class prediction must have to be considered
significant enough to have its tags decoded and added. If None, no
thresholding is applied based on class score (all predicted classes,
or the top one if `top_class_only` is True, will be processed).
Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`.
top_class_only : bool, default=False
If True, only includes tags for the single highest-scoring class that
exceeds the threshold. If False (default), includes tags for all classes
exceeding the threshold.
Returns
-------
data.SoundEventPrediction
The fully formed sound event prediction object.
Raises
------
ValueError
If `raw_prediction.features` has an unexpected structure, if
`data.term_from_key` (when used internally) fails, or if
`sound_event_decoder` fails for a class name and is configured to raise
errors.
"""
sound_event = data.SoundEvent(
recording=recording,
geometry=data.BoundingBox(
coordinates=[
raw_prediction.start_time,
raw_prediction.low_freq,
raw_prediction.end_time,
raw_prediction.high_freq,
]
),
features=get_prediction_features(raw_prediction.features),
)
tags = [
*get_generic_tags(
raw_prediction.detection_score,
generic_class_tags=generic_class_tags,
),
*get_class_tags(
raw_prediction.class_scores,
sound_event_decoder,
top_class_only=top_class_only,
threshold=classification_threshold,
),
]
return data.SoundEventPrediction(
sound_event=sound_event,
score=raw_prediction.detection_score,
tags=tags,
)
def get_generic_tags(
detection_score: float,
generic_class_tags: List[data.Tag],
) -> List[data.PredictedTag]:
"""Create PredictedTag objects for the generic category.
Takes the base list of generic tags and assigns the overall detection
score to each one, wrapping them in `PredictedTag` objects.
Parameters
----------
detection_score : float
The overall confidence score of the detection event.
generic_class_tags : List[data.Tag]
The list of base `soundevent.data.Tag` objects that define the
generic category (e.g., ['call_type:Echolocation', 'order:Chiroptera']).
Returns
-------
List[data.PredictedTag]
A list of `PredictedTag` objects for the generic category, each
assigned the `detection_score`.
"""
return [
data.PredictedTag(tag=tag, score=detection_score)
for tag in generic_class_tags
]
def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
"""Convert an extracted feature vector DataArray into soundevent Features.
Parameters
----------
features : xr.DataArray
A 1D xarray DataArray containing feature values, indexed by a coordinate
named 'feature' which holds the feature names (e.g., output of selecting
features for one detection from `extract_detection_xr_dataset`).
Returns
-------
List[data.Feature]
A list of `soundevent.data.Feature` objects.
Notes
-----
- This function creates basic `Term` objects using the feature coordinate
names with a "batdetect2:" prefix.
"""
return [
data.Feature(
term=data.Term(
name=f"batdetect2:{feat_name}",
label=feat_name,
definition="Automatically extracted features by BatDetect2",
),
value=value,
)
for feat_name, value in _iterate_over_array(features)
]
def get_class_tags(
class_scores: xr.DataArray,
sound_event_decoder: SoundEventDecoder,
top_class_only: bool = False,
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.PredictedTag]:
"""Generate specific PredictedTags based on class scores and decoder.
Filters class scores by the threshold, sorts remaining scores descending,
decodes the class name(s) into base tags using the `sound_event_decoder`,
and creates `PredictedTag` objects associating the class score. Stops after
the first (top) class if `top_class_only` is True.
Parameters
----------
class_scores : xr.DataArray
A 1D xarray DataArray containing class probabilities/scores, indexed
by a 'category' coordinate holding the class names.
sound_event_decoder : SoundEventDecoder
Function to map a class name string to a list of base `data.Tag`
objects.
top_class_only : bool, default=False
If True, only generate tags for the single highest-scoring class above
the threshold.
threshold : float, optional
Minimum score for a class to be considered. If None, all classes are
processed (or top-1 if `top_class_only` is True). Defaults to
`DEFAULT_CLASSIFICATION_THRESHOLD`.
Returns
-------
List[data.PredictedTag]
A list of `PredictedTag` objects for the class(es) that passed the
threshold, ordered by score if `top_class_only` is False.
"""
tags = []
if threshold is not None:
class_scores = class_scores.where(class_scores > threshold, drop=True)
for class_name, score in _iterate_sorted(class_scores):
class_tags = sound_event_decoder(class_name)
for tag in class_tags:
tags.append(
data.PredictedTag(
tag=tag,
score=score,
)
)
if top_class_only:
break
return tags
def _iterate_over_array(array: xr.DataArray):
dim_name = array.dims[0]
coords = array.coords[dim_name]
for value, coord in zip(array.values, coords.values):
yield coord, float(value)
def _iterate_sorted(array: xr.DataArray):
dim_name = array.dims[0]
coords = array.coords[dim_name].values
indices = np.argsort(-array.values)
for index in indices:
yield str(coords[index]), float(array.values[index])

View File

@ -0,0 +1,162 @@
"""Extracts candidate detection points from a model output heatmap.
This module implements a specific step within the BatDetect2 postprocessing
pipeline. Its primary function is to identify potential sound event locations
by finding peaks (local maxima or high-scoring points) in the detection heatmap
produced by the neural network (usually after Non-Maximum Suppression and
coordinate remapping have been applied).
It provides functionality to:
- Identify the locations (time, frequency) of the highest-scoring points.
- Filter these points based on a minimum confidence score threshold.
- Limit the maximum number of detection points returned (top-k).
The main output is an `xarray.DataArray` containing the scores and
corresponding time/frequency coordinates for the extracted detection points.
This output serves as the input for subsequent postprocessing steps, such as
extracting predicted class probabilities and bounding box sizes at these
specific locations.
"""
from typing import Optional
import numpy as np
import xarray as xr
from soundevent.arrays import Dimensions, get_dim_width
__all__ = [
"extract_detections_from_array",
"get_max_detections",
"DEFAULT_DETECTION_THRESHOLD",
"TOP_K_PER_SEC",
]
DEFAULT_DETECTION_THRESHOLD = 0.01
"""Default confidence score threshold used for filtering detections."""
TOP_K_PER_SEC = 200
"""Default desired maximum number of detections per second of audio."""
def extract_detections_from_array(
detection_array: xr.DataArray,
max_detections: Optional[int] = None,
threshold: Optional[float] = DEFAULT_DETECTION_THRESHOLD,
) -> xr.DataArray:
"""Extract detection locations (time, freq) and scores from a heatmap.
Identifies the pixels with the highest scores in the input detection
heatmap, filters them based on an optional score `threshold`, limits the
number to an optional `max_detections`, and returns their scores along with
their corresponding time and frequency coordinates.
Parameters
----------
detection_array : xr.DataArray
A 2D xarray DataArray representing the detection heatmap. Must have
dimensions and coordinates named 'time' and 'frequency'. Higher values
are assumed to indicate higher detection confidence.
max_detections : int, optional
The absolute maximum number of detections to return. If specified, only
the top `max_detections` highest-scoring detections (passing the
threshold) are returned. If None (default), all detections passing
the threshold are returned, sorted by score.
threshold : float, optional
The minimum confidence score required for a detection peak to be
kept. Detections with scores below this value are discarded.
Defaults to `DEFAULT_DETECTION_THRESHOLD`. If set to None, no
thresholding is applied.
Returns
-------
xr.DataArray
A 1D xarray DataArray named 'score' with a 'detection' dimension.
- The data values are the scores of the extracted detections, sorted
in descending order.
- It has coordinates 'time' and 'frequency' (also indexed by the
'detection' dimension) indicating the location of each detection
peak in the original coordinate system.
- Returns an empty DataArray if no detections pass the criteria.
Raises
------
ValueError
If `max_detections` is not None and not a positive integer, or if
`detection_array` lacks required dimensions/coordinates.
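Examples
--------
A minimal sketch on a tiny synthetic heatmap (coordinate values are
arbitrary):

>>> import numpy as np
>>> import xarray as xr
>>> heatmap = xr.DataArray(
...     np.array([[0.0, 0.9], [0.3, 0.0]]),
...     dims=["frequency", "time"],
...     coords={"frequency": [10000, 20000], "time": [0.0, 0.5]},
... )
>>> extract_detections_from_array(heatmap, threshold=0.2).values
array([0.9, 0.3])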
"""
if max_detections is not None:
if max_detections <= 0:
raise ValueError("Max detections must be positive")
values = detection_array.values.flatten()
if max_detections is not None:
# Guard against requesting more detections than there are pixels.
kth = min(max_detections, values.size) - 1
top_indices = np.argpartition(-values, kth)[:max_detections]
top_sorted_indices = top_indices[np.argsort(-values[top_indices])]
else:
top_sorted_indices = np.argsort(-values)
top_values = values[top_sorted_indices]
if threshold is not None:
mask = top_values > threshold
top_values = top_values[mask]
top_sorted_indices = top_sorted_indices[mask]
freq_indices, time_indices = np.unravel_index(
top_sorted_indices,
detection_array.shape,
)
times = detection_array.coords[Dimensions.time.value].values[time_indices]
freqs = detection_array.coords[Dimensions.frequency.value].values[
freq_indices
]
return xr.DataArray(
data=top_values,
coords={
Dimensions.frequency.value: ("detection", freqs),
Dimensions.time.value: ("detection", times),
},
dims="detection",
name="score",
)
def get_max_detections(
detection_array: xr.DataArray,
top_k_per_sec: int = TOP_K_PER_SEC,
) -> int:
"""Calculate max detections allowed based on duration and rate.
Determines the total maximum number of detections to extract from a
heatmap based on its time duration and a desired rate of detections
per second.
Parameters
----------
detection_array : xr.DataArray
The detection heatmap, requiring 'time' coordinates from which the
total duration can be calculated using
`soundevent.arrays.get_dim_width`.
top_k_per_sec : int, default=TOP_K_PER_SEC
The desired maximum number of detections to allow per second of audio.
Returns
-------
int
The calculated total maximum number of detections allowed for the
entire duration of the `detection_array`.
Raises
------
ValueError
If the duration cannot be calculated from the `detection_array` (e.g.,
missing or invalid 'time' coordinates/dimension).
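Examples
--------
For a detection heatmap covering roughly 2 seconds of audio at the
default rate of 200 detections per second, the cap works out to
``int(2.0 * 200) == 400`` detections.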
"""
if top_k_per_sec < 0:
raise ValueError("top_k_per_sec cannot be negative.")
duration = get_dim_width(detection_array, Dimensions.time.value)
return int(duration * top_k_per_sec)

View File

@ -0,0 +1,122 @@
"""Extracts associated data for detected points from model output arrays.
This module implements a key step (Step 4) in the BatDetect2 postprocessing
pipeline. After candidate detection points (time, frequency, score) have been
identified, this module extracts the corresponding values from other raw model
output arrays, such as:
- Predicted bounding box sizes (width, height).
- Class probability scores for each defined target class.
- Intermediate feature vectors.
It uses coordinate-based indexing provided by `xarray` to ensure that the
correct values are retrieved from the original heatmaps/feature maps at the
precise time-frequency location of each detection. The final output aggregates
all extracted information into a structured `xarray.Dataset`.
"""
import xarray as xr
from soundevent.arrays import Dimensions
__all__ = [
"extract_values_at_positions",
"extract_detection_xr_dataset",
]
def extract_values_at_positions(
array: xr.DataArray,
positions: xr.DataArray,
) -> xr.DataArray:
"""Extract values from an array at specified time-frequency positions.
Uses coordinate-based indexing to retrieve values from a source `array`
(e.g., class probabilities, size predictions, features) at the time and
frequency coordinates defined in the `positions` array.
Parameters
----------
array : xr.DataArray
The source DataArray from which to extract values. Must have 'time'
and 'frequency' dimensions and coordinates matching the space of
`positions`.
positions : xr.DataArray
A 1D DataArray whose 'time' and 'frequency' coordinates specify the
locations from which to extract values.
Returns
-------
xr.DataArray
A DataArray containing the values extracted from `array` at the given
positions.
Raises
------
ValueError, IndexError, KeyError
If dimensions or coordinates are missing or incompatible between
`array` and `positions`, or if selection fails.
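Examples
--------
A small sketch with hand-built arrays, following the 'time'/'frequency'
naming convention used throughout this module:

>>> import numpy as np
>>> import xarray as xr
>>> source = xr.DataArray(
...     np.arange(6).reshape(2, 3),
...     dims=["frequency", "time"],
...     coords={"frequency": [10.0, 20.0], "time": [0.0, 0.1, 0.2]},
... )
>>> positions = xr.DataArray(
...     [0.9, 0.5],
...     dims="detection",
...     coords={
...         "frequency": ("detection", [20.0, 10.0]),
...         "time": ("detection", [0.2, 0.0]),
...     },
... )
>>> extract_values_at_positions(source, positions).values
array([5, 0])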
"""
return array.sel(
**{
Dimensions.frequency.value: positions.coords[
Dimensions.frequency.value
],
Dimensions.time.value: positions.coords[Dimensions.time.value],
}
).T
def extract_detection_xr_dataset(
positions: xr.DataArray,
sizes: xr.DataArray,
classes: xr.DataArray,
features: xr.DataArray,
) -> xr.Dataset:
"""Combine extracted detection information into a structured xr.Dataset.
Takes the detection positions/scores and the full model output heatmaps
(sizes, classes, features), extracts the relevant data at the
detection positions, and packages everything into a single `xarray.Dataset`
where all variables are indexed by a common 'detection' dimension.
Parameters
----------
positions : xr.DataArray
Output from `extract_detections_from_array`, containing detection
scores as data and 'time', 'frequency' coordinates along the
'detection' dimension.
sizes : xr.DataArray
The full size prediction heatmap from the model, with dimensions like
('dimension', 'time', 'frequency').
classes : xr.DataArray
The full class probability heatmap from the model, with dimensions like
('category', 'time', 'frequency').
features : xr.DataArray
The full feature map from the model, with
dimensions like ('feature', 'time', 'frequency').
Returns
-------
xr.Dataset
An xarray Dataset containing aligned information for each detection:
- 'scores': DataArray from `positions` (score data, time/freq coords).
- 'dimensions': DataArray with extracted size values
(dims: 'detection', 'dimension').
- 'classes': DataArray with extracted class probabilities
(dims: 'detection', 'category').
- 'features': DataArray with extracted feature vectors
(dims: 'detection', 'feature').
All DataArrays share the 'detection' dimension and associated
time/frequency coordinates.
"""
sizes = extract_values_at_positions(sizes, positions)
classes = extract_values_at_positions(classes, positions)
features = extract_values_at_positions(features, positions)
return xr.Dataset(
{
"scores": positions,
"dimensions": sizes,
"classes": classes,
"features": features,
}
)

View File

@ -0,0 +1,96 @@
"""Performs Non-Maximum Suppression (NMS) on detection heatmaps.
This module provides functionality to apply Non-Maximum Suppression, a common
technique used after model inference, particularly in object detection and peak
detection tasks.
In the context of BatDetect2 postprocessing, NMS is applied
to the raw detection heatmap output by the neural network. Its purpose is to
isolate distinct detection peaks by suppressing (setting to zero) nearby heatmap
activations that have lower scores than a local maximum. This helps prevent
multiple, overlapping detections originating from the same sound event.
"""
from typing import Tuple, Union
import torch
NMS_KERNEL_SIZE = 9
"""Default kernel size (pixels) for Non-Maximum Suppression.
Specifies the side length of the square neighborhood used by default in
`non_max_suppression` to find local maxima. A 9x9 neighborhood is often
a reasonable starting point for typical spectrogram resolutions used in
BatDetect2.
"""
def non_max_suppression(
tensor: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> torch.Tensor:
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.
This function identifies local maxima within a defined neighborhood for
each point in the input tensor. Values that are *not* the maximum within
their neighborhood are suppressed (set to zero). This is commonly used on
detection probability heatmaps to isolate distinct peaks corresponding to
individual detections and remove redundant lower scores nearby.
The implementation uses efficient 2D max pooling to find the maximum value
in the neighborhood of each point.
Parameters
----------
tensor : torch.Tensor
Input tensor, typically representing a detection heatmap. Must be a
3D (C, H, W) or 4D (N, C, H, W) tensor as required by the underlying
`torch.nn.functional.max_pool2d` operation.
kernel_size : Union[int, Tuple[int, int]], default=NMS_KERNEL_SIZE
Size of the sliding window neighborhood used to find local maxima.
If an integer `k` is provided, a square kernel of size `(k, k)` is used.
If a tuple `(h, w)` is provided, a rectangular kernel of height `h`
and width `w` is used. The kernel size should typically be odd to
have a well-defined center.
Returns
-------
torch.Tensor
A tensor of the same shape as the input, where only local maxima within
their respective neighborhoods (defined by `kernel_size`) retain their
original values. All other values are set to zero.
Raises
------
TypeError
If `kernel_size` is not an int or a tuple of two ints.
RuntimeError
If the input `tensor` does not have 3 or 4 dimensions (as required
by `max_pool2d`).
Notes
-----
- The function assumes higher values in the tensor indicate stronger peaks.
- Choosing an appropriate `kernel_size` is important. It should be large
enough to cover the typical "footprint" of a single detection peak plus
some surrounding context, effectively preventing multiple detections for
the same event. A size that is too large might suppress nearby distinct
events.
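Examples
--------
A small sketch on a single-channel 3x3 heatmap; only the strongest local
peak keeps its value, all weaker neighbours are zeroed:

>>> import torch
>>> heatmap = torch.tensor(
...     [[[0.1, 0.9, 0.2],
...       [0.0, 0.3, 0.0],
...       [0.0, 0.0, 0.0]]]
... )
>>> out = non_max_suppression(heatmap, kernel_size=3)
>>> # Only the 0.9 peak at position (0, 1) remains non-zero.
>>> torch.count_nonzero(out).item()
1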
"""
if isinstance(kernel_size, int):
kernel_size_h = kernel_size
kernel_size_w = kernel_size
else:
kernel_size_h, kernel_size_w = kernel_size
pad_h = (kernel_size_h - 1) // 2
pad_w = (kernel_size_w - 1) // 2
hmax = torch.nn.functional.max_pool2d(
tensor,
(kernel_size_h, kernel_size_w),
stride=1,
padding=(pad_h, pad_w),
)
keep = (hmax == tensor).float()
return tensor * keep

View File

@ -0,0 +1,316 @@
"""Remaps raw model output tensors to coordinate-aware xarray DataArrays.
This module provides utility functions to convert the raw numerical outputs
(typically PyTorch tensors) from the BatDetect2 DNN model into
`xarray.DataArray` objects. This step adds coordinate information
(time in seconds, frequency in Hz) back to the model's predictions, making them
interpretable in the context of the original audio signal and facilitating
subsequent processing steps.
Functions are provided for common BatDetect2 output types: detection heatmaps,
classification probability maps, size prediction maps, and potentially
intermediate features.
"""
from typing import List
import numpy as np
import torch
import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
__all__ = [
"features_to_xarray",
"detection_to_xarray",
"classification_to_xarray",
"sizes_to_xarray",
]
def features_to_xarray(
features: torch.Tensor,
start_time: float,
end_time: float,
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
features_prefix: str = "batdetect2_feature_",
) -> xr.DataArray:
"""Convert a multi-channel feature tensor to a coordinate-aware DataArray.
Assigns time, frequency, and feature coordinates to a raw feature tensor
output by the model.
Parameters
----------
features : torch.Tensor
The raw feature tensor from the model. Expected shape is
(num_features, num_freq_bins, num_time_bins).
start_time : float
The start time (in seconds) corresponding to the first time bin of
the tensor.
end_time : float
The end time (in seconds) corresponding to the *end* of the last time
bin.
min_freq : float, default=MIN_FREQ
The minimum frequency (in Hz) corresponding to the first frequency bin.
max_freq : float, default=MAX_FREQ
The maximum frequency (in Hz) corresponding to the *end* of the last
frequency bin.
features_prefix : str, default="batdetect2_feature_"
Prefix used to generate names for the feature coordinate dimension
(e.g., "batdetect2_feature_0", "batdetect2_feature_1", ...).
Returns
-------
xr.DataArray
An xarray DataArray containing the feature data with named dimensions
('feature', 'frequency', 'time') and calculated coordinates.
Raises
------
ValueError
If the input tensor does not have 3 dimensions.
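Examples
--------
A shape-only sketch with random features (tensor sizes are arbitrary):

>>> import torch
>>> feats = torch.rand(32, 128, 512)
>>> arr = features_to_xarray(feats, start_time=0.0, end_time=0.5)
>>> arr.shape
(32, 128, 512)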
"""
if features.ndim != 3:
raise ValueError(
"Input features tensor must have 3 dimensions (C, T, F), "
f"got shape {features.shape}"
)
num_features, height, width = features.shape
times = np.linspace(start_time, end_time, width, endpoint=False)
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray(
data=features.detach().numpy(),
dims=[
Dimensions.feature.value,
Dimensions.frequency.value,
Dimensions.time.value,
],
coords={
Dimensions.feature.value: [
f"{features_prefix}{i}" for i in range(num_features)
],
Dimensions.frequency.value: freqs,
Dimensions.time.value: times,
},
name="features",
)
def detection_to_xarray(
detection: torch.Tensor,
start_time: float,
end_time: float,
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
) -> xr.DataArray:
"""Convert a single-channel detection heatmap tensor to a DataArray.
Assigns time and frequency coordinates to a raw detection heatmap tensor.
Parameters
----------
detection : torch.Tensor
Raw detection heatmap tensor from the model. Expected shape is
(1, num_freq_bins, num_time_bins).
start_time : float
Start time (seconds) corresponding to the first time bin.
end_time : float
End time (seconds) corresponding to the end of the last time bin.
min_freq : float, default=MIN_FREQ
Minimum frequency (Hz) corresponding to the first frequency bin.
max_freq : float, default=MAX_FREQ
Maximum frequency (Hz) corresponding to the end of the last frequency
bin.
Returns
-------
xr.DataArray
An xarray DataArray containing the detection scores with named
dimensions ('frequency', 'time') and calculated coordinates.
Raises
------
ValueError
If the input tensor does not have 3 dimensions or if the first
dimension size is not 1.
"""
if detection.ndim != 3:
raise ValueError(
"Input detection tensor must have 3 dimensions (1, T, F), "
f"got shape {detection.shape}"
)
num_channels, height, width = detection.shape
if num_channels != 1:
raise ValueError(
"Expected a single channel output, instead got "
f"{num_channels} channels"
)
times = np.linspace(start_time, end_time, width, endpoint=False)
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray(
data=detection.squeeze(dim=0).detach().numpy(),
dims=[
Dimensions.frequency.value,
Dimensions.time.value,
],
coords={
Dimensions.frequency.value: freqs,
Dimensions.time.value: times,
},
name="detection_score",
)
def classification_to_xarray(
classes: torch.Tensor,
start_time: float,
end_time: float,
class_names: List[str],
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
) -> xr.DataArray:
"""Convert multi-channel class probability tensor to a DataArray.
Assigns category (class name), frequency, and time coordinates to a raw
class probability tensor output by the model.
Parameters
----------
classes : torch.Tensor
Raw class probability tensor. Expected shape is
(num_classes, num_freq_bins, num_time_bins).
start_time : float
Start time (seconds) corresponding to the first time bin.
end_time : float
End time (seconds) corresponding to the end of the last time bin.
class_names : List[str]
Ordered list of class names corresponding to the first dimension
of the `classes` tensor. The length must match `classes.shape[0]`.
min_freq : float, default=MIN_FREQ
Minimum frequency (Hz) corresponding to the first frequency bin.
max_freq : float, default=MAX_FREQ
Maximum frequency (Hz) corresponding to the end of the last frequency
bin.
Returns
-------
xr.DataArray
An xarray DataArray containing class probabilities with named
dimensions ('category', 'frequency', 'time') and calculated
coordinates.
Raises
------
ValueError
If the input tensor does not have 3 dimensions, or if the size of the
first dimension does not match the length of `class_names`.
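Examples
--------
A shape-only sketch (class names and tensor sizes are illustrative):

>>> import torch
>>> probs = torch.rand(2, 128, 256)
>>> arr = classification_to_xarray(
...     probs,
...     start_time=0.0,
...     end_time=0.5,
...     class_names=["Myotis myotis", "Pipistrellus pipistrellus"],
... )
>>> arr.shape
(2, 128, 256)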
"""
if classes.ndim != 3:
raise ValueError(
"Input classes tensor must have 3 dimensions (C, F, T), "
f"got shape {classes.shape}"
)
num_classes, height, width = classes.shape
if num_classes != len(class_names):
raise ValueError(
"The number of classes does not coincide with the number of "
"class names provided: "
f"({num_classes = }) != ({len(class_names) = })"
)
times = np.linspace(start_time, end_time, width, endpoint=False)
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray(
data=classes.detach().numpy(),
dims=[
"category",
Dimensions.frequency.value,
Dimensions.time.value,
],
coords={
"category": class_names,
Dimensions.frequency.value: freqs,
Dimensions.time.value: times,
},
name="class_scores",
)
def sizes_to_xarray(
sizes: torch.Tensor,
start_time: float,
end_time: float,
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
) -> xr.DataArray:
"""Convert the 2-channel size prediction tensor to a DataArray.
Assigns dimension ('width', 'height'), frequency, and time coordinates
to the raw size prediction tensor output by the model.
Parameters
----------
sizes : torch.Tensor
Raw size prediction tensor. Expected shape is
(2, num_freq_bins, num_time_bins), where the first dimension
corresponds to predicted width and height respectively.
start_time : float
Start time (seconds) corresponding to the first time bin.
end_time : float
End time (seconds) corresponding to the end of the last time bin.
min_freq : float, default=MIN_FREQ
Minimum frequency (Hz) corresponding to the first frequency bin.
max_freq : float, default=MAX_FREQ
Maximum frequency (Hz) corresponding to the end of the last frequency
bin.
Returns
-------
xr.DataArray
An xarray DataArray containing predicted sizes with named dimensions
('dimension', 'frequency', 'time') and calculated time/frequency
coordinates. The 'dimension' coordinate will have values
['width', 'height'].
Raises
------
ValueError
If the input tensor does not have 3 dimensions or if the first
dimension size is not exactly 2.
"""
if sizes.ndim != 3:
raise ValueError(
"Input sizes tensor must have 3 dimensions (2, F, T), "
f"got shape {sizes.shape}"
)
num_channels, height, width = sizes.shape
if num_channels != 2:
raise ValueError(
"Expected a two-channel output, instead got "
f"{num_channels} channels"
)
times = np.linspace(start_time, end_time, width, endpoint=False)
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
return xr.DataArray(
data=sizes.detach().numpy(),
dims=[
"dimension",
Dimensions.frequency.value,
Dimensions.time.value,
],
coords={
"dimension": ["width", "height"],
Dimensions.frequency.value: freqs,
Dimensions.time.value: times,
},
)

View File

@ -0,0 +1,284 @@
"""Defines shared interfaces and data structures for postprocessing.
This module centralizes the Protocol definitions and common data structures
used throughout the `batdetect2.postprocess` module.
The main component is the `PostprocessorProtocol`, which outlines the standard
interface for an object responsible for executing the entire postprocessing
pipeline. This pipeline transforms raw neural network outputs into interpretable
detections represented as `soundevent` objects. Using protocols ensures
modularity and consistent interaction between different parts of the BatDetect2
system that deal with model predictions.
"""
from typing import Callable, List, NamedTuple, Protocol
import numpy as np
import xarray as xr
from soundevent import data
from batdetect2.models.types import ModelOutput
__all__ = [
"RawPrediction",
"PostprocessorProtocol",
"GeometryBuilder",
]
GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry]
"""Type alias for a function that recovers geometry from position and size.
This callable takes:
1. A position tuple `(time, frequency)`.
2. A NumPy array of size dimensions (e.g., `[width, height]`).
It should return the reconstructed `soundevent.data.Geometry` (typically a
`BoundingBox`).
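Examples
--------
A minimal sketch of such a callable; the centre-plus-size convention and
the ``[start_time, low_freq, end_time, high_freq]`` coordinate order are
illustrative assumptions, not the configured BatDetect2 mapping:

>>> def center_box_builder(position, dims):
...     time, freq = position
...     width, height = float(dims[0]), float(dims[1])
...     return data.BoundingBox(
...         coordinates=[
...             time - width / 2,
...             freq - height / 2,
...             time + width / 2,
...             freq + height / 2,
...         ]
...     )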
"""
class RawPrediction(NamedTuple):
"""Intermediate representation of a single detected sound event.
Holds extracted information about a detection after initial processing
(like peak finding, coordinate remapping, geometry recovery) but before
final class decoding and conversion into a `SoundEventPrediction`. This
can be useful for evaluation or simpler data handling formats.
Attributes
----------
start_time : float
Start time of the recovered bounding box in seconds.
end_time : float
End time of the recovered bounding box in seconds.
low_freq : float
Lowest frequency of the recovered bounding box in Hz.
high_freq : float
Highest frequency of the recovered bounding box in Hz.
detection_score : float
The confidence score associated with this detection, typically from
the detection heatmap peak.
class_scores : xr.DataArray
An xarray DataArray containing the predicted probabilities or scores
for each target class at the detection location. Indexed by a
'category' coordinate containing class names.
features : xr.DataArray
An xarray DataArray containing extracted feature vectors at the
detection location. Indexed by a 'feature' coordinate.
"""
start_time: float
end_time: float
low_freq: float
high_freq: float
detection_score: float
class_scores: xr.DataArray
features: xr.DataArray
class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline.
This protocol outlines the standard methods for an object that takes raw
output from a BatDetect2 model and the corresponding input clip metadata,
and processes it through various stages (e.g., coordinate remapping, NMS,
detection extraction, data extraction, decoding) to produce interpretable
results at different levels of completion.
Implementations manage the configured logic for all postprocessing steps.
"""
def get_feature_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap feature tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch, expected
to contain the necessary feature tensors.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects, one for each item in the
processed batch. This list provides the timing, recording, and
other metadata context needed to calculate real-world coordinates
(seconds, Hz) for the output arrays. The length of this list must
correspond to the batch size of the `output`.
Returns
-------
List[xr.DataArray]
A list of xarray DataArrays, one for each input clip in the batch,
in the same order. Each DataArray contains the feature vectors
with dimensions like ('feature', 'time', 'frequency') and
corresponding real-world coordinates.
"""
...
def get_detection_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap detection tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch,
containing detection heatmaps.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing coordinate context. Must match the batch size of
`output`.
Returns
-------
List[xr.DataArray]
A list of 2D xarray DataArrays (one per input clip, in order),
representing the detection heatmap with 'time' and 'frequency'
coordinates. Values typically indicate detection confidence.
"""
...
def get_classification_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap classification tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch,
containing class probability tensors.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing coordinate context. Must match the batch size of
`output`.
Returns
-------
List[xr.DataArray]
A list of 3D xarray DataArrays (one per input clip, in order),
representing class probabilities with 'category', 'time', and
'frequency' dimensions and coordinates.
"""
...
def get_sizes_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap size prediction tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch,
containing predicted size tensors (e.g., width and height).
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing coordinate context. Must match the batch size of
`output`.
Returns
-------
List[xr.DataArray]
A list of 3D xarray DataArrays (one per input clip, in order),
representing predicted sizes with 'dimension'
(e.g., ['width', 'height']), 'time', and 'frequency' dimensions and
coordinates. Values represent estimated detection sizes.
"""
...
def get_detection_datasets(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.Dataset]:
"""Perform remapping, NMS, detection, and data extraction for a batch.
Processes the raw model output for a batch to identify detection peaks
and extract all associated information (score, position, size, class
probs, features) at those peak locations, returning a structured
dataset for each input clip in the batch.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[xr.Dataset]
A list of xarray Datasets (one per input clip, in order). Each
Dataset contains multiple DataArrays ('scores', 'dimensions',
'classes', 'features') sharing a common 'detection' dimension,
providing aligned data for each detected event in that clip.
"""
...
def get_raw_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Processes the raw model output for a batch through remapping, NMS,
detection, data extraction, and geometry recovery to produce a list of
`RawPrediction` objects for each corresponding input clip. This provides
a simplified, intermediate representation before final tag decoding.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[List[RawPrediction]]
A list of lists (one inner list per input clip, in order). Each
inner list contains the `RawPrediction` objects extracted for the
corresponding input clip.
"""
...
def get_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output for a batch and corresponding clips, applies the
entire postprocessing chain, and returns the final, interpretable
predictions as a list of `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[data.ClipPrediction]
A list containing one `ClipPrediction` object for each input clip
(in the same order), populated with `SoundEventPrediction` objects
representing the final detections with decoded tags and geometry.
"""
...

View File

@ -1,68 +1,448 @@
"""Module containing functions for preprocessing audio clips."""
"""Main entry point for the BatDetect2 Preprocessing subsystem.
from typing import Optional
This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline
for converting raw audio input (from files or data objects) into processed
spectrograms suitable for input to BatDetect2 models. This ensures consistent
data handling between model training and inference.
The preprocessing pipeline consists of two main stages, configured via nested
data structures:
1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial
processing like resampling, duration adjustment, centering, and scaling.
Configured via `AudioConfig`.
2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from
the processed waveform using STFT, followed by frequency cropping, optional
PCEN, amplitude scaling (dB, power, linear), optional denoising, optional
resizing, and optional peak normalization. Configured via
`SpectrogramConfig`.
This module provides the primary interface:
- `PreprocessingConfig`: A unified configuration object holding `AudioConfig`
and `SpectrogramConfig`.
- `load_preprocessing_config`: Function to load the unified configuration.
- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline.
- `StandardPreprocessor`: The default implementation of the `Preprocessor`.
- `build_preprocessor`: A factory function to create a `StandardPreprocessor`
instance from a `PreprocessingConfig`.
"""
from typing import Optional, Union
import numpy as np
import xarray as xr
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
DEFAULT_DURATION,
SCALE_RAW_AUDIO,
TARGET_SAMPLERATE_HZ,
AudioConfig,
ResampleConfig,
load_clip_audio,
)
from batdetect2.preprocess.config import (
PreprocessingConfig,
load_preprocessing_config,
build_audio_loader,
)
from batdetect2.preprocess.spectrogram import (
AmplitudeScaleConfig,
MAX_FREQ,
MIN_FREQ,
ConfigurableSpectrogramBuilder,
FrequencyConfig,
LogScaleConfig,
PcenScaleConfig,
Scales,
PcenConfig,
SpecSizeConfig,
SpectrogramConfig,
STFTConfig,
compute_spectrogram,
build_spectrogram_builder,
get_spectrogram_resolution,
)
from batdetect2.preprocess.types import (
AudioLoader,
PreprocessorProtocol,
SpectrogramBuilder,
)
__all__ = [
"AmplitudeScaleConfig",
"AudioConfig",
"AudioLoader",
"ConfigurableSpectrogramBuilder",
"DEFAULT_DURATION",
"FrequencyConfig",
"LogScaleConfig",
"PcenScaleConfig",
"MAX_FREQ",
"MIN_FREQ",
"PcenConfig",
"PreprocessingConfig",
"ResampleConfig",
"SCALE_RAW_AUDIO",
"STFTConfig",
"Scales",
"SpecSizeConfig",
"SpectrogramBuilder",
"SpectrogramConfig",
"StandardPreprocessor",
"TARGET_SAMPLERATE_HZ",
"build_audio_loader",
"build_preprocessor",
"build_spectrogram_builder",
"get_spectrogram_resolution",
"load_preprocessing_config",
"preprocess_audio_clip",
]
def preprocess_audio_clip(
clip: data.Clip,
config: Optional[PreprocessingConfig] = None,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Preprocesses audio clip to generate spectrogram.
class PreprocessingConfig(BaseConfig):
"""Unified configuration for the audio preprocessing pipeline.
Aggregates the configuration for both the initial audio processing stage
and the subsequent spectrogram generation stage.
Attributes
----------
audio : AudioConfig
Configuration settings for the audio loading and initial waveform
processing steps (e.g., resampling, duration adjustment, scaling).
Defaults to default `AudioConfig` settings if omitted.
spectrogram : SpectrogramConfig
Configuration settings for the spectrogram generation process
(e.g., STFT parameters, frequency cropping, scaling, denoising,
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
"""
audio: AudioConfig = Field(default_factory=AudioConfig)
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
class StandardPreprocessor(PreprocessorProtocol):
"""Standard implementation of the `Preprocessor` protocol.
Orchestrates the audio loading and spectrogram generation pipeline using
an `AudioLoader` and a `SpectrogramBuilder` internally, which are
configured according to a `PreprocessingConfig`.
This class is typically instantiated using the `build_preprocessor`
factory function.
Attributes
----------
audio_loader : AudioLoader
The configured audio loader instance used for waveform loading and
initial processing.
spectrogram_builder : SpectrogramBuilder
The configured spectrogram builder instance used for generating
spectrograms from waveforms.
default_samplerate : int
The sample rate (in Hz) assumed for input waveforms when they are
provided as raw NumPy arrays without coordinate information (e.g.,
when calling `compute_spectrogram` directly with `np.ndarray`).
This value is derived from the `AudioConfig` (target resample rate
or default if resampling is off) and also serves as documentation
for the pipeline's intended operating sample rate. Note that when
processing `xr.DataArray` inputs that have coordinate information
(the standard internal workflow), the sample rate embedded in the
coordinates takes precedence over this default value during
spectrogram calculation.
"""
audio_loader: AudioLoader
spectrogram_builder: SpectrogramBuilder
default_samplerate: int
max_freq: float
min_freq: float
def __init__(
self,
audio_loader: AudioLoader,
spectrogram_builder: SpectrogramBuilder,
default_samplerate: int,
max_freq: float,
min_freq: float,
) -> None:
"""Initialize the StandardPreprocessor.
Parameters
----------
audio_loader : AudioLoader
An initialized audio loader conforming to the AudioLoader protocol.
spectrogram_builder : SpectrogramBuilder
An initialized spectrogram builder conforming to the
SpectrogramBuilder protocol.
default_samplerate : int
The sample rate to assume for NumPy array inputs, typically
reflecting the target rate of the audio configuration.
max_freq : float
The upper frequency bound (in Hz) of the spectrograms produced by
this preprocessor.
min_freq : float
The lower frequency bound (in Hz) of the spectrograms produced by
this preprocessor.
"""
self.audio_loader = audio_loader
self.spectrogram_builder = spectrogram_builder
self.default_samplerate = default_samplerate
self.max_freq = max_freq
self.min_freq = min_freq
def load_file_audio(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform from a file path.
Delegates to the internal `audio_loader`.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform (typically first
channel).
"""
return self.audio_loader.load_file(
path,
audio_dir=audio_dir,
)
def load_recording_audio(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Recording.
Delegates to the internal `audio_loader`.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform (typically first
channel).
"""
return self.audio_loader.load_recording(
recording,
audio_dir=audio_dir,
)
def load_clip_audio(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Clip.
Delegates to the internal `audio_loader`.
Parameters
----------
clip : data.Clip
The Clip object defining the segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform segment (typically first
channel).
"""
return self.audio_loader.load_clip(
clip,
audio_dir=audio_dir,
)
def preprocess_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio from a file and compute the final processed spectrogram.
Performs the full pipeline:
Load -> Preprocess Audio -> Compute Spectrogram.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The final processed spectrogram.
"""
wav = self.load_file_audio(path, audio_dir=audio_dir)
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def preprocess_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio for a Recording and compute the processed spectrogram.
Performs the full pipeline for the entire duration of the recording.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
"""
wav = self.load_recording_audio(recording, audio_dir=audio_dir)
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def preprocess_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio for a Clip and compute the final processed spectrogram.
Performs the full pipeline for the specified clip segment.
Parameters
----------
clip : data.Clip
The Clip object defining the audio segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
"""
wav = self.load_clip_audio(clip, audio_dir=audio_dir)
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def compute_spectrogram(
self, wav: Union[xr.DataArray, np.ndarray]
) -> xr.DataArray:
"""Compute the spectrogram from a pre-loaded audio waveform.
Applies the configured spectrogram generation steps
(STFT, scaling, etc.) using the internal `spectrogram_builder`.
If `wav` is a NumPy array, the `default_samplerate` stored in this
preprocessor instance will be used. If `wav` is an xarray DataArray
with time coordinates, the sample rate derived from those coordinates
will take precedence over `default_samplerate`.
Parameters
----------
wav : Union[xr.DataArray, np.ndarray]
The input audio waveform. If a NumPy array is provided, the
`default_samplerate` stored in this object is assumed.
Returns
-------
xr.DataArray
The computed spectrogram.
"""
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def load_preprocessing_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PreprocessingConfig:
"""Load the unified preprocessing configuration from a file.
Reads a configuration file (YAML) and validates it against the
`PreprocessingConfig` schema, potentially extracting data from a nested
field.
Parameters
----------
clip
The audio clip to preprocess.
config
Configuration for preprocessing.
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
preprocessing configuration (e.g., "train.preprocessing"). If None, the
entire file content is validated as the PreprocessingConfig.
Returns
-------
xr.DataArray
Preprocessed spectrogram.
PreprocessingConfig
Loaded and validated preprocessing configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to PreprocessingConfig.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=PreprocessingConfig, field=field)
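# Example (hypothetical file name and nested field; assumes the YAML file has a
# top-level `preprocessing:` section matching the PreprocessingConfig schema):
config = load_preprocessing_config("config.yaml", field="preprocessing")
print(config.audio.resample, config.spectrogram.scale)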
def build_preprocessor(
config: Optional[PreprocessingConfig] = None,
) -> PreprocessorProtocol:
"""Factory function to build the standard preprocessor from configuration.
Creates instances of the required `AudioLoader` and `SpectrogramBuilder`
based on the provided `PreprocessingConfig` (or defaults if config is None),
determines the effective default sample rate, and initializes the
`StandardPreprocessor`.
Parameters
----------
config : PreprocessingConfig, optional
The unified preprocessing configuration object. If None, default
configurations for audio and spectrogram processing will be used.
Returns
-------
PreprocessorProtocol
An initialized `StandardPreprocessor` instance ready to process audio
according to the configuration.
"""
config = config or PreprocessingConfig()
wav = load_clip_audio(clip, config=config.audio, audio_dir=audio_dir)
return compute_spectrogram(wav, config=config.spectrogram)
default_samplerate = (
config.audio.resample.samplerate
if config.audio.resample
else TARGET_SAMPLERATE_HZ
)
min_freq = config.spectrogram.frequencies.min_freq
max_freq = config.spectrogram.frequencies.max_freq
return StandardPreprocessor(
audio_loader=build_audio_loader(config.audio),
spectrogram_builder=build_spectrogram_builder(config.spectrogram),
default_samplerate=default_samplerate,
min_freq=min_freq,
max_freq=max_freq,
)
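# Usage sketch (hypothetical audio path; default audio and spectrogram configs):
preprocessor = build_preprocessor()
spec = preprocessor.preprocess_file("example.wav")
print(spec.sizes, preprocessor.default_samplerate)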

View File

@ -1,3 +1,25 @@
"""Handles loading and initial preprocessing of audio waveforms.
This module provides components for loading audio data associated with
`soundevent` objects (Clips, Recordings, or raw files) and applying
fundamental waveform processing steps. These steps typically include:
1. Loading the raw audio data.
2. Adjusting the audio clip to a fixed duration (optional).
3. Resampling the audio to a target sample rate (optional).
4. Centering the waveform (DC offset removal) (optional).
5. Scaling the waveform amplitude (optional).
The processing pipeline is configurable via the `AudioConfig` data structure,
allowing for reproducible preprocessing consistent between model training and
inference. It uses the `soundevent` library for audio loading and basic array
operations, and `scipy` for resampling implementations.
The primary interface is the `AudioLoader` protocol, with
`ConfigurableAudioLoader` providing a concrete implementation driven by the
`AudioConfig`.
"""
from typing import Optional
import numpy as np
@ -7,33 +29,246 @@ from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import arrays, audio, data
from soundevent.arrays import operations as ops
from soundfile import LibsndfileError
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.types import AudioLoader
__all__ = [
"ResampleConfig",
"AudioConfig",
"ConfigurableAudioLoader",
"build_audio_loader",
"load_file_audio",
"load_recording_audio",
"load_clip_audio",
"adjust_audio_duration",
"resample_audio",
"TARGET_SAMPLERATE_HZ",
"SCALE_RAW_AUDIO",
"DEFAULT_DURATION",
"convert_to_xr",
]
TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""
SCALE_RAW_AUDIO = False
"""Default setting for whether to perform peak normalization."""
DEFAULT_DURATION = None
"""Default setting for target audio duration in seconds."""
class ResampleConfig(BaseConfig):
"""Configuration for audio resampling.
Attributes
----------
samplerate : int, default=256000
The target sample rate in Hz to resample the audio to. Must be > 0.
method : str, default="poly"
The resampling algorithm to use. Options:
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
Generally fast.
- "fourier": Resampling via Fourier method using
`scipy.signal.resample`. May handle non-integer
resampling factors differently.
"""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
mode: str = "poly"
method: str = "poly"
class AudioConfig(BaseConfig):
"""Configuration for loading and initial audio preprocessing.
Defines the sequence of operations applied to raw audio waveforms after
loading, controlling steps like resampling, scaling, centering, and
duration adjustment.
Attributes
----------
resample : ResampleConfig, optional
Configuration for resampling. If provided (or defaulted), audio will
be resampled to the specified `samplerate` using the specified
`method`. If set to `None` in the config file, resampling is skipped.
Defaults to a ResampleConfig instance with standard settings.
scale : bool, default=False
If True, scales the audio waveform using peak normalization so that
its maximum absolute amplitude is approximately 1.0. If False
(default), no amplitude scaling is applied.
center : bool, default=True
If True (default), centers the waveform by subtracting its mean
(DC offset removal). If False, the waveform is not centered.
duration : float, optional
If set to a float value (seconds), the loaded audio clip will be
adjusted (cropped or padded with zeros) to exactly this duration.
If None (default), the original duration is kept.
"""
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
scale: bool = SCALE_RAW_AUDIO
center: bool = True
duration: Optional[float] = DEFAULT_DURATION
class ConfigurableAudioLoader:
"""Concrete implementation of the `AudioLoader` driven by `AudioConfig`.
This class loads audio and applies preprocessing steps (resampling,
scaling, centering, duration adjustment) based on the settings provided
in an `AudioConfig` object during initialization. It delegates the actual
work to module-level functions.
"""
def __init__(
self,
config: AudioConfig,
):
"""Initialize the ConfigurableAudioLoader.
Parameters
----------
config : AudioConfig
The configuration object specifying the desired preprocessing steps
and parameters.
"""
self.config = config
def load_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess audio directly from a file path.
Implements the `AudioLoader.load_file` method by delegating to the
`load_file_audio` function, passing the stored configuration.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
Loaded and preprocessed waveform (first channel).
"""
return load_file_audio(path, config=self.config, audio_dir=audio_dir)
def load_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess the entire audio for a Recording object.
Implements the `AudioLoader.load_recording` method by delegating to the
`load_recording_audio` function, passing the stored configuration.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
Loaded and preprocessed waveform (first channel).
"""
return load_recording_audio(
recording, config=self.config, audio_dir=audio_dir
)
def load_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess the audio segment defined by a Clip object.
Implements the `AudioLoader.load_clip` method by delegating to the
`load_clip_audio` function, passing the stored configuration.
Parameters
----------
clip : data.Clip
The Clip object specifying the segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
Loaded and preprocessed waveform segment (first channel).
"""
return load_clip_audio(clip, config=self.config, audio_dir=audio_dir)
def build_audio_loader(
config: AudioConfig,
) -> AudioLoader:
"""Factory function to create an AudioLoader based on configuration.
Instantiates and returns a `ConfigurableAudioLoader` initialized with
the provided `AudioConfig`. The return type is `AudioLoader`, adhering
to the protocol.
Parameters
----------
config : AudioConfig
The configuration object specifying preprocessing steps.
Returns
-------
AudioLoader
An instance of `ConfigurableAudioLoader` ready to load and process audio
according to the configuration.
"""
return ConfigurableAudioLoader(config=config)
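# Illustrative example (values are arbitrary): a loader that resamples to
# 192 kHz with the polyphase method, keeps DC-offset removal enabled, skips
# peak normalization, and pads/crops every clip to exactly 1 second.
loader = build_audio_loader(
    AudioConfig(
        resample=ResampleConfig(samplerate=192_000, method="poly"),
        center=True,
        scale=False,
        duration=1.0,
    )
)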
def load_file_audio(
path: data.PathLike,
config: Optional[AudioConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32,
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
recording = data.Recording.from_file(path)
"""Load and preprocess audio from a file path using specified config.
Creates a `soundevent.data.Recording` object from the file path and then
delegates the loading and processing to `load_recording_audio`.
Parameters
----------
path : PathLike
Path to the audio file.
config : AudioConfig, optional
Audio processing configuration. If None, default settings defined
in `AudioConfig` are used.
audio_dir : PathLike, optional
Directory prefix if `path` is relative.
dtype : DTypeLike, default=np.float32
Target NumPy data type for the loaded audio array.
Returns
-------
xr.DataArray
Loaded and preprocessed waveform (first channel only).
"""
try:
recording = data.Recording.from_file(path)
except LibsndfileError as e:
raise FileNotFoundError(
f"Could not load the recording at path: {path}. Error: {e}"
) from e
return load_recording_audio(
recording,
config=config,
@ -46,8 +281,31 @@ def load_recording_audio(
recording: data.Recording,
config: Optional[AudioConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32,
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Load and preprocess the entire audio content of a recording using config.
Creates a `soundevent.data.Clip` spanning the full duration of the
recording and then delegates the loading and processing to
`load_clip_audio`.
Parameters
----------
recording : data.Recording
The Recording object containing metadata and file path.
config : AudioConfig, optional
Audio processing configuration. If None, default settings are used.
audio_dir : PathLike, optional
Directory containing the audio file, used if the path in `recording`
is relative.
dtype : DTypeLike, default=np.float32
Target NumPy data type for the loaded audio array.
Returns
-------
xr.DataArray
Loaded and preprocessed waveform (first channel only).
"""
clip = data.Clip(
recording=recording,
start_time=0,
@ -65,16 +323,64 @@ def load_clip_audio(
clip: data.Clip,
config: Optional[AudioConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32,
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Load and preprocess a specific audio clip segment based on config.
This is the core function performing the configured processing pipeline:
1. Loads the specified clip segment using `soundevent.audio.load_clip`.
2. Selects the first audio channel.
3. Resamples if `config.resample` is configured.
4. Centers (DC offset removal) if `config.center` is True.
5. Scales (peak normalization) if `config.scale` is True.
6. Adjusts duration (crop/pad) if `config.duration` is set.
Parameters
----------
clip : data.Clip
The Clip object defining the audio segment and source recording.
config : AudioConfig, optional
Audio processing configuration. If None, a default `AudioConfig` is
used.
audio_dir : PathLike, optional
Directory containing the source audio file specified in the clip's
recording.
dtype : DTypeLike, default=np.float32
Target NumPy data type for the processed audio array.
Returns
-------
xr.DataArray
The loaded and preprocessed waveform segment as an xarray DataArray
with time coordinates.
Raises
------
FileNotFoundError
If the underlying audio file cannot be found.
Exception
If audio loading or processing fails for other reasons (e.g., invalid
format, resampling error).
Notes
-----
- **Mono Processing:** This function currently loads and processes only the
**first channel** (channel 0) of the audio file. Any other channels
are ignored.
"""
config = config or AudioConfig()
wav = (
audio.load_clip(clip, audio_dir=audio_dir).sel(channel=0).astype(dtype)
)
if config.duration is not None:
wav = adjust_audio_duration(wav, duration=config.duration)
try:
wav = (
audio.load_clip(clip, audio_dir=audio_dir)
.sel(channel=0)
.astype(dtype)
)
except LibsndfileError as e:
raise FileNotFoundError(
f"Could not load the recording at path: {clip.recording.path}. "
f"Error: {e}"
) from e
if config.resample:
wav = resample_audio(
@ -87,43 +393,126 @@ def load_clip_audio(
wav = ops.center(wav)
if config.scale:
wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
wav = scale_audio(wav)
if config.duration is not None:
wav = adjust_audio_duration(wav, duration=config.duration)
return wav.astype(dtype)
def scale_audio(
wave: xr.DataArray,
) -> xr.DataArray:
"""
Scale the audio waveform to have a maximum absolute value of 1.0.
This function normalizes the waveform by dividing it by its maximum
absolute value. If the maximum value is zero, the waveform is returned
unchanged. Also known as peak normalization, this process ensures that the
waveform's amplitude is within a standard range, which can be useful for
audio processing and analysis.
"""
max_val = np.max(np.abs(wave))
if max_val == 0:
return wave
return ops.scale(wave, 1 / max_val)
def adjust_audio_duration(
wave: xr.DataArray,
duration: float,
) -> xr.DataArray:
"""Adjust the duration of an audio waveform array via cropping or padding.
If the current duration is longer than the target, it crops the array
from the beginning. If shorter, it pads the array with zeros at the end
using `soundevent.arrays.extend_dim`.
Parameters
----------
wave : xr.DataArray
The input audio waveform with a 'time' dimension and coordinates.
duration : float
The target duration in seconds.
Returns
-------
xr.DataArray
The waveform adjusted to the target duration. Returns the input
unmodified if duration already matches or if the wave is empty.
Raises
------
ValueError
If `duration` is negative.
"""
start_time, end_time = arrays.get_dim_range(wave, dim="time")
current_duration = end_time - start_time
step = arrays.get_dim_step(wave, dim="time")
current_duration = end_time - start_time + step
if current_duration == duration:
return wave
if current_duration > duration:
return arrays.crop_dim(
with xr.set_options(keep_attrs=True):
if current_duration > duration:
return arrays.crop_dim(
wave,
dim="time",
start=start_time,
stop=start_time + duration - step / 2,
right_closed=True,
)
return arrays.extend_dim(
wave,
dim="time",
start=start_time,
stop=start_time + duration,
stop=start_time + duration - step / 2,
eps=0,
right_closed=True,
)
return arrays.extend_dim(
wave,
dim="time",
start=start_time,
stop=start_time + duration,
)
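# Illustrative behaviour (assumes `convert_to_xr`, defined later in this
# module, and a 0.3 s waveform at 256 kHz): padding appends zeros at the end,
# cropping keeps the start of the clip.
wav = convert_to_xr(np.zeros(int(0.3 * 256_000)), samplerate=256_000)
padded = adjust_audio_duration(wav, duration=0.5)   # extended to 0.5 s
cropped = adjust_audio_duration(wav, duration=0.2)  # cropped to 0.2 s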
def resample_audio(
wav: xr.DataArray,
samplerate: int = TARGET_SAMPLERATE_HZ,
mode: str = "poly",
dtype: DTypeLike = np.float32,
method: str = "poly",
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Resample an audio waveform DataArray to a target sample rate.
Updates the 'time' coordinate axis according to the new sample rate and
number of samples. Uses either polyphase (`scipy.signal.resample_poly`)
or Fourier method (`scipy.signal.resample`) based on the `method`.
Parameters
----------
wav : xr.DataArray
Input audio waveform with 'time' dimension and coordinates.
samplerate : int, default=TARGET_SAMPLERATE_HZ
Target sample rate in Hz.
method : str, default="poly"
Resampling algorithm: "poly" or "fourier".
dtype : DTypeLike, default=np.float32
Target data type for the resampled array.
Returns
-------
xr.DataArray
Resampled waveform with updated time coordinates. Returns the input
unmodified (but dtype cast) if the sample rate is already correct or
if the input array is empty.
Raises
------
ValueError
If `wav` lacks a 'time' dimension, the original sample rate cannot
be determined, `samplerate` is non-positive, or `method` is invalid.
"""
if "time" not in wav.dims:
raise ValueError("Audio must have a time dimension")
@ -134,14 +523,14 @@ def resample_audio(
if original_samplerate == samplerate:
return wav.astype(dtype)
if mode == "poly":
if method == "poly":
resampled = resample_audio_poly(
wav,
sr_orig=original_samplerate,
sr_new=samplerate,
axis=time_axis,
)
elif mode == "fourier":
elif method == "fourier":
resampled = resample_audio_fourier(
wav,
sr_orig=original_samplerate,
@ -149,7 +538,9 @@ def resample_audio(
axis=time_axis,
)
else:
raise NotImplementedError(f"Resampling mode '{mode}' not implemented")
raise NotImplementedError(
f"Resampling method '{method}' not implemented"
)
start, stop = arrays.get_dim_range(wav, dim="time")
times = np.linspace(
@ -170,7 +561,7 @@ def resample_audio(
samplerate=samplerate,
),
},
attrs=wav.attrs,
attrs={**wav.attrs, "samplerate": samplerate},
)
@ -180,6 +571,33 @@ def resample_audio_poly(
sr_new: int,
axis: int = -1,
) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample_poly`.
This method is often preferred for signals when the ratio of new
to old sample rates can be expressed as a rational number. It uses
polyphase filtering.
Parameters
----------
array : np.ndarray
The input array to resample.
sr_orig : int
The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1
The axis of `array` along which to resample.
Returns
-------
np.ndarray
The array resampled to the target sample rate.
Raises
------
ValueError
If sample rates are not positive.
"""
gcd = np.gcd(sr_orig, sr_new)
return resample_poly(
array.values,
@ -195,5 +613,97 @@ def resample_audio_fourier(
sr_new: int,
axis: int = -1,
) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample`.
This method uses FFTs to resample the signal.
Parameters
----------
array : np.ndarray
The input array to resample.
sr_orig : int
The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1
The axis of `array` along which to resample.
Returns
-------
np.ndarray
The array resampled to the target sample rate along `axis`.
Raises
------
ValueError
If sample rates are not positive.
"""
ratio = sr_new / sr_orig
return resample(array, int(array.shape[axis] * ratio), axis=axis) # type: ignore
return resample( # type: ignore
array,
int(array.shape[axis] * ratio),
axis=axis,
)
def convert_to_xr(
wav: np.ndarray,
samplerate: int,
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Convert a NumPy array to an xarray DataArray with time coordinates.
Parameters
----------
wav : np.ndarray
The input waveform array. Expected to be 1D or 2D (with the first
axis as the channel dimension).
samplerate : int
The sample rate in Hz.
dtype : DTypeLike, default=np.float32
Target data type for the xarray DataArray.
Returns
-------
xr.DataArray
The waveform as an xarray DataArray with time coordinates.
Raises
------
ValueError
If the input array is not 1D or 2D, if it is empty, or if the
sample rate is non-positive.
"""
if wav.ndim == 2:
wav = wav[0, :]
if wav.ndim != 1:
raise ValueError(
"Audio must be 1D array or 2D channel where the first "
"axis is the channel dimension"
)
if wav.size == 0:
raise ValueError("Audio array is empty")
if samplerate <= 0:
raise ValueError("Sample rate must be positive")
times = np.linspace(
0,
wav.shape[0] / samplerate,
wav.shape[0],
endpoint=False,
dtype=dtype,
)
return xr.DataArray(
data=wav.astype(dtype),
dims=["time"],
coords={
"time": arrays.create_time_dim_from_array(
times,
samplerate=samplerate,
),
},
attrs={"samplerate": samplerate},
)
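# Example (arbitrary values): wrap a raw NumPy buffer recorded at 384 kHz and
# resample it to the default target rate.
raw = np.zeros(384_000, dtype=np.float32)
wav = convert_to_xr(raw, samplerate=384_000)
resampled = resample_audio(wav, samplerate=TARGET_SAMPLERATE_HZ, method="poly")
print(resampled.attrs["samplerate"])  # 256000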

View File

@ -1,31 +0,0 @@
from typing import Optional
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
AudioConfig,
)
from batdetect2.preprocess.spectrogram import (
SpectrogramConfig,
)
__all__ = [
"PreprocessingConfig",
"load_preprocessing_config",
]
class PreprocessingConfig(BaseConfig):
"""Configuration for preprocessing data."""
audio: AudioConfig = Field(default_factory=AudioConfig)
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
def load_preprocessing_config(
path: PathLike,
field: Optional[str] = None,
) -> PreprocessingConfig:
return load_config(path, schema=PreprocessingConfig, field=field)

View File

@ -1,7 +1,26 @@
"""Computes spectrograms from audio waveforms with configurable parameters.
This module provides the functionality to convert preprocessed audio waveforms
(typically output from the `batdetect2.preprocessing.audio` module) into
spectrogram representations suitable for input into deep learning models like
BatDetect2.
It offers a configurable pipeline including:
1. Short-Time Fourier Transform (STFT) calculation to get magnitude.
2. Frequency axis cropping to a relevant range.
3. Per-Channel Energy Normalization (PCEN) (optional).
4. Amplitude scaling/representation (dB, power, or linear amplitude).
5. Simple spectral mean subtraction denoising (optional).
6. Resizing to target dimensions (optional).
7. Final peak normalization (optional).
Configuration is managed via the `SpectrogramConfig` class, allowing for
reproducible spectrogram generation consistent between training and inference.
The core computation is performed by `compute_spectrogram`.
"""
from typing import Literal, Optional, Union
import librosa
import librosa.core.spectrum
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
@ -10,68 +29,302 @@ from soundevent import arrays, audio
from soundevent.arrays import operations as ops
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.audio import convert_to_xr
from batdetect2.preprocess.types import SpectrogramBuilder
__all__ = [
"STFTConfig",
"FrequencyConfig",
"SpecSizeConfig",
"PcenConfig",
"SpectrogramConfig",
"ConfigurableSpectrogramBuilder",
"build_spectrogram_builder",
"compute_spectrogram",
"get_spectrogram_resolution",
"MIN_FREQ",
"MAX_FREQ",
]
MIN_FREQ = 10_000
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""
MAX_FREQ = 120_000
"""Default maximum frequency (Hz) for spectrogram frequency cropping."""
class STFTConfig(BaseConfig):
"""Configuration for the Short-Time Fourier Transform (STFT).
Attributes
----------
window_duration : float, default=0.002
Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
> 0. Determines frequency resolution (longer window = finer frequency
resolution).
window_overlap : float, default=0.75
Fraction of overlap between consecutive STFT windows (e.g., 0.75
for 75%). Must be >= 0 and < 1. Determines time resolution
(higher overlap = finer time resolution).
window_fn : str, default="hann"
Name of the window function to apply before FFT calculation. Common
options include "hann", "hamming", "blackman". See
`scipy.signal.get_window`.
"""
window_duration: float = Field(default=0.002, gt=0)
window_overlap: float = Field(default=0.75, ge=0, lt=1)
window_fn: str = "hann"
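# Worked example (assuming audio at the default 256 kHz target sample rate):
# window_duration=0.002 s -> n_fft = 0.002 * 256_000 = 512 samples,
# window_overlap=0.75     -> hop  = 512 * (1 - 0.75) = 128 samples = 0.5 ms,
# giving roughly 500 Hz per frequency bin (256_000 / 512) and one spectrogram
# column every 0.5 ms before any resizing.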
class FrequencyConfig(BaseConfig):
max_freq: int = Field(default=120_000, gt=0)
min_freq: int = Field(default=10_000, gt=0)
"""Configuration for frequency axis parameters.
Attributes
----------
max_freq : int, default=120000
Maximum frequency in Hz to retain in the spectrogram after STFT.
Frequencies above this value will be cropped. Must be >= 0.
min_freq : int, default=10000
Minimum frequency in Hz to retain in the spectrogram after STFT.
Frequencies below this value will be cropped. Must be >= 0.
"""
max_freq: int = Field(default=120_000, ge=0)
min_freq: int = Field(default=10_000, ge=0)
class SpecSizeConfig(BaseConfig):
"""Configuration for the final size and shape of the spectrogram.
Attributes
----------
height : int, default=128
Target height of the spectrogram in pixels (frequency bins). The
frequency axis will be resized (e.g., via interpolation) to match this
height after frequency cropping and amplitude scaling. Must be > 0.
resize_factor : float, optional, default=0.5
Factor by which to resize the spectrogram along the time axis *after*
STFT calculation. A value of 0.5 halves the number of time bins,
2.0 doubles it. If None, no resizing along the time axis is performed
relative to the STFT output width. Must be > 0 if provided.
"""
height: int = 128
"""Height of the spectrogram in pixels. This value determines the
number of frequency bands and corresponds to the vertical dimension
of the spectrogram."""
resize_factor: Optional[float] = 0.5
"""Factor by which to resize the spectrogram along the time axis.
A value of 0.5 reduces the temporal dimension by half, while a
value of 2.0 doubles it. If None, no resizing is performed."""
class LogScaleConfig(BaseConfig):
name: Literal["log"] = "log"
class PcenConfig(BaseConfig):
"""Configuration for Per-Channel Energy Normalization (PCEN).
PCEN is an adaptive gain control method that can help emphasize transients
and suppress stationary noise. Applied after STFT and frequency cropping,
but before final amplitude scaling (dB, power, amplitude).
Attributes
----------
time_constant : float, default=0.4
Time constant (in seconds) for the PCEN smoothing filter. Controls
how quickly the normalization adapts to energy changes.
gain : float, default=0.98
Gain factor (alpha). Controls the adaptive gain component.
bias : float, default=2.0
Bias factor (delta). Added before the exponentiation.
power : float, default=0.5
Exponent (r). Controls the compression characteristic.
"""
class PcenScaleConfig(BaseConfig):
name: Literal["pcen"] = "pcen"
time_constant: float = 0.4
hop_length: int = 512
gain: float = 0.98
bias: float = 2
power: float = 0.5
class AmplitudeScaleConfig(BaseConfig):
name: Literal["amplitude"] = "amplitude"
Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig]
class SpectrogramConfig(BaseConfig):
"""Unified configuration for spectrogram generation pipeline.
Aggregates settings for all steps involved in converting a preprocessed
audio waveform into a final spectrogram representation suitable for model
input.
Attributes
----------
stft : STFTConfig
Configuration for the initial Short-Time Fourier Transform.
Defaults to standard settings via `STFTConfig`.
frequencies : FrequencyConfig
Configuration for cropping the frequency range after STFT.
Defaults to standard settings via `FrequencyConfig`.
pcen : PcenConfig, optional
Configuration for applying Per-Channel Energy Normalization (PCEN). If
provided, PCEN is applied after frequency cropping. If None or omitted
(default), PCEN is skipped.
scale : Literal["dB", "amplitude", "power"], default="amplitude"
Determines the final amplitude representation *after* optional PCEN.
- "amplitude": Use linear magnitude values (output of STFT or PCEN).
- "power": Use power values (magnitude squared).
- "dB": Use logarithmic (decibel-like) scaling applied to the magnitude
(or PCEN output if enabled). Calculated as `log1p(C * S)`.
size : SpecSizeConfig, optional
Configuration for resizing the spectrogram dimensions
(frequency height, optional time width factor). Applied after PCEN and
scaling. If None (default), no resizing is performed.
spectral_mean_substraction : bool, default=True
If True (default), applies simple spectral mean subtraction denoising
*after* PCEN and amplitude scaling, but *before* resizing.
peak_normalize : bool, default=False
If True, applies a final peak normalization to the spectrogram *after*
all other steps (including resizing), scaling the overall maximum value
to 1.0. If False (default), this final normalization is skipped.
"""
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
scale: Scales = Field(
default_factory=PcenScaleConfig,
discriminator="name",
)
pcen: Optional[PcenConfig] = Field(default_factory=PcenConfig)
scale: Literal["dB", "amplitude", "power"] = "amplitude"
size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
denoise: bool = True
max_scale: bool = False
spectral_mean_substraction: bool = True
peak_normalize: bool = False
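# Illustrative configuration (values are arbitrary): PCEN disabled, dB scaling,
# a narrower frequency band, and no resizing along the time axis.
spectrogram_config = SpectrogramConfig(
    frequencies=FrequencyConfig(min_freq=15_000, max_freq=100_000),
    pcen=None,
    scale="dB",
    size=SpecSizeConfig(height=128, resize_factor=None),
    spectral_mean_substraction=True,
    peak_normalize=False,
)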
class ConfigurableSpectrogramBuilder(SpectrogramBuilder):
"""Implementation of `SpectrogramBuilder` driven by `SpectrogramConfig`.
This class computes spectrograms according to the parameters specified in a
`SpectrogramConfig` object provided during initialization. It handles both
numpy array and xarray DataArray inputs for the waveform.
"""
def __init__(
self,
config: SpectrogramConfig,
dtype: DTypeLike = np.float32, # type: ignore
) -> None:
"""Initialize the ConfigurableSpectrogramBuilder.
Parameters
----------
config : SpectrogramConfig
The configuration object specifying all spectrogram parameters.
dtype : DTypeLike, default=np.float32
The target NumPy data type for the computed spectrogram array.
"""
self.config = config
self.dtype = dtype
def __call__(
self,
wav: Union[np.ndarray, xr.DataArray],
samplerate: Optional[int] = None,
) -> xr.DataArray:
"""Generate a spectrogram from an audio waveform using the config.
Implements the `SpectrogramBuilder` protocol. If the input `wav` is
a numpy array, `samplerate` must be provided; the array will be
converted to an xarray DataArray internally. If `wav` is already an
xarray DataArray with time coordinates, `samplerate` is ignored.
Delegates the main computation to `compute_spectrogram`.
Parameters
----------
wav : Union[np.ndarray, xr.DataArray]
The input audio waveform.
samplerate : int, optional
The sample rate in Hz (required only if `wav` is np.ndarray).
Returns
-------
xr.DataArray
The computed spectrogram.
Raises
------
ValueError
If `wav` is np.ndarray and `samplerate` is None.
"""
if isinstance(wav, np.ndarray):
if samplerate is None:
raise ValueError(
"Samplerate must be provided when passing a numpy array."
)
wav = convert_to_xr(
wav,
samplerate=samplerate,
dtype=self.dtype,
)
return compute_spectrogram(
wav,
config=self.config,
dtype=self.dtype,
)
def build_spectrogram_builder(
config: SpectrogramConfig,
dtype: DTypeLike = np.float32, # type: ignore
) -> SpectrogramBuilder:
"""Factory function to create a SpectrogramBuilder based on configuration.
Instantiates and returns a `ConfigurableSpectrogramBuilder` initialized
with the provided `SpectrogramConfig`.
Parameters
----------
config : SpectrogramConfig
The configuration object specifying spectrogram parameters.
dtype : DTypeLike, default=np.float32
The target NumPy data type for the computed spectrogram array.
Returns
-------
SpectrogramBuilder
An instance of `ConfigurableSpectrogramBuilder` ready to compute
spectrograms according to the configuration.
"""
return ConfigurableSpectrogramBuilder(config=config, dtype=dtype)
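# Usage sketch: `samplerate` must be given because the input below is a plain
# NumPy array rather than an xarray DataArray with time coordinates.
builder = build_spectrogram_builder(SpectrogramConfig())
spec = builder(np.zeros(256_000, dtype=np.float32), samplerate=256_000)
print(spec.dims)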
def compute_spectrogram(
wav: xr.DataArray,
config: Optional[SpectrogramConfig] = None,
dtype: DTypeLike = np.float32,
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Compute a spectrogram from a waveform using specified configurations.
Applies a sequence of operations based on the `config`:
1. Compute STFT magnitude (`stft`).
2. Crop frequency axis (`crop_spectrogram_frequencies`).
3. Apply PCEN if configured (`apply_pcen`).
4. Apply final amplitude scaling (dB, power, amplitude)
(`scale_spectrogram`).
5. Apply spectral mean subtraction denoising if enabled.
6. Resize dimensions if specified (`resize_spectrogram`).
7. Apply final peak normalization if enabled.
Parameters
----------
wav : xr.DataArray
Input audio waveform with a 'time' dimension and coordinates from
which the sample rate can be inferred.
config : SpectrogramConfig, optional
Configuration object specifying spectrogram parameters. If None,
default settings from `SpectrogramConfig` are used.
dtype : DTypeLike, default=np.float32
Target NumPy data type for the final spectrogram array.
Returns
-------
xr.DataArray
The computed and processed spectrogram with 'time' and 'frequency'
coordinates.
Raises
------
ValueError
If `wav` lacks necessary 'time' coordinates or dimensions.
"""
config = config or SpectrogramConfig()
spec = stft(
@ -79,7 +332,6 @@ def compute_spectrogram(
window_duration=config.stft.window_duration,
window_overlap=config.stft.window_overlap,
window_fn=config.stft.window_fn,
dtype=dtype,
)
spec = crop_spectrogram_frequencies(
@ -88,10 +340,19 @@ def compute_spectrogram(
max_freq=config.frequencies.max_freq,
)
if config.pcen:
spec = apply_pcen(
spec,
time_constant=config.pcen.time_constant,
gain=config.pcen.gain,
power=config.pcen.power,
bias=config.pcen.bias,
)
spec = scale_spectrogram(spec, scale=config.scale)
if config.denoise:
spec = denoise_spectrogram(spec)
if config.spectral_mean_substraction:
spec = remove_spectral_mean(spec)
if config.size:
spec = resize_spectrogram(
@ -100,7 +361,7 @@ def compute_spectrogram(
resize_factor=config.size.resize_factor,
)
if config.max_scale:
if config.peak_normalize:
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
return spec.astype(dtype)
@ -111,11 +372,32 @@ def crop_spectrogram_frequencies(
min_freq: int = 10_000,
max_freq: int = 120_000,
) -> xr.DataArray:
"""Crop the frequency axis of a spectrogram to a specified range.
Uses `soundevent.arrays.crop_dim` to select the frequency bins
corresponding to the range [`min_freq`, `max_freq`].
Parameters
----------
spec : xr.DataArray
Input spectrogram with 'frequency' dimension and coordinates.
min_freq : int, default=MIN_FREQ
Minimum frequency (Hz) to keep.
max_freq : int, default=MAX_FREQ
Maximum frequency (Hz) to keep.
Returns
-------
xr.DataArray
Spectrogram cropped along the frequency axis. Preserves dtype.
"""
start_freq, end_freq = arrays.get_dim_range(spec, dim="frequency")
return arrays.crop_dim(
spec,
dim="frequency",
start=min_freq,
stop=max_freq,
start=min_freq if start_freq < min_freq else None,
stop=max_freq if end_freq > max_freq else None,
).astype(spec.dtype)
@ -124,61 +406,61 @@ def stft(
window_duration: float,
window_overlap: float,
window_fn: str = "hann",
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
start_time, end_time = arrays.get_dim_range(wave, dim="time")
step = arrays.get_dim_step(wave, dim="time")
sampling_rate = 1 / step
"""Compute the Short-Time Fourier Transform (STFT) magnitude spectrogram.
nfft = int(window_duration * sampling_rate)
noverlap = int(window_overlap * nfft)
hop_len = nfft - noverlap
hop_duration = hop_len / sampling_rate
Calculates STFT parameters (N-FFT, hop length) based on the window
duration, overlap, and waveform sample rate. Returns an xarray DataArray
with correctly calculated 'time' and 'frequency' coordinates.
spec, _ = librosa.core.spectrum._spectrogram(
y=wave.data.astype(dtype),
power=1,
n_fft=nfft,
hop_length=nfft - noverlap,
center=False,
window=window_fn,
)
Parameters
----------
wave : xr.DataArray
Input audio waveform with 'time' coordinates.
window_duration : float
Duration of the STFT window in seconds.
window_overlap : float
Fractional overlap between consecutive windows.
window_fn : str, default="hann"
Name of the window function (e.g., "hann", "hamming").
return xr.DataArray(
data=spec.astype(dtype),
dims=["frequency", "time"],
coords={
"frequency": arrays.create_frequency_dim_from_array(
np.linspace(
0,
sampling_rate / 2,
spec.shape[0],
endpoint=False,
dtype=dtype,
),
step=sampling_rate / nfft,
),
"time": arrays.create_time_dim_from_array(
np.linspace(
start_time,
end_time - (window_duration - hop_duration),
spec.shape[1],
endpoint=False,
dtype=dtype,
),
step=hop_duration,
),
},
attrs={
**wave.attrs,
"original_samplerate": sampling_rate,
"nfft": nfft,
"noverlap": noverlap,
},
Returns
-------
xr.DataArray
Magnitude spectrogram with 'time' and 'frequency' dimensions and
coordinates. STFT parameters are stored in the `attrs`.
Raises
------
ValueError
If sample rate cannot be determined from `wave` coordinates.
"""
return audio.compute_spectrogram(
wave,
window_size=window_duration,
hop_size=(1 - window_overlap) * window_duration,
window_type=window_fn,
scale="amplitude",
sort_dims=False,
)
def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
def remove_spectral_mean(spec: xr.DataArray) -> xr.DataArray:
"""Apply simple spectral mean subtraction for denoising.
Subtracts the mean value of each frequency bin (calculated across time)
from that bin, then clips negative values to zero.
Parameters
----------
spec : xr.DataArray
Input spectrogram with 'time' and 'frequency' dimensions.
Returns
-------
xr.DataArray
Denoised spectrogram with the same dimensions, coordinates, and dtype.
"""
return xr.DataArray(
data=(spec - spec.mean("time")).clip(0),
dims=spec.dims,
@ -189,34 +471,82 @@ def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
def scale_spectrogram(
spec: xr.DataArray,
scale: Scales,
dtype: DTypeLike = np.float32,
scale: Literal["dB", "power", "amplitude"],
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
if scale.name == "log":
"""Apply final amplitude scaling/representation to the spectrogram.
Converts the input magnitude spectrogram based on the `scale` type:
- "dB": Applies logarithmic scaling `log1p(C * S)`.
- "power": Squares the magnitude values `S^2`.
- "amplitude": Returns the input magnitude values `S` unchanged.
Parameters
----------
spec : xr.DataArray
Input magnitude spectrogram (potentially after PCEN).
scale : Literal["dB", "power", "amplitude"]
The target amplitude representation.
dtype : DTypeLike, default=np.float32
Target data type for the output scaled spectrogram.
Returns
-------
xr.DataArray
Spectrogram with the specified amplitude scaling applied.
"""
if scale == "dB":
return scale_log(spec, dtype=dtype)
if scale.name == "pcen":
return scale_pcen(
spec,
time_constant=scale.time_constant,
hop_length=scale.hop_length,
gain=scale.gain,
power=scale.power,
bias=scale.bias,
)
if scale == "power":
return spec**2
return spec
def scale_pcen(
def apply_pcen(
spec: xr.DataArray,
time_constant: float = 0.4,
hop_length: int = 512,
gain: float = 0.98,
bias: float = 2,
power: float = 0.5,
) -> xr.DataArray:
samplerate = spec.attrs["original_samplerate"]
"""Apply Per-Channel Energy Normalization (PCEN) to a spectrogram.
Parameters
----------
spec : xr.DataArray
Input magnitude spectrogram with the 'samplerate' and 'hop_size'
attributes set during STFT computation.
time_constant : float, default=0.4
PCEN time constant in seconds.
gain : float, default=0.98
Gain factor (alpha).
bias : float, default=2.0
Bias factor (delta).
power : float, default=0.5
Exponent (r).
Returns
-------
xr.DataArray
PCEN-scaled spectrogram.
Notes
-----
- The input spectrogram magnitude `spec` is multiplied by `2**31` before
being passed to `audio.pcen`. This suggests the underlying implementation
might expect values in a range typical of 16-bit or 32-bit signed integers,
even though the input here might be float. This scaling factor should be
verified against the specific `soundevent.audio.pcen` implementation
details.
"""
samplerate = spec.attrs["samplerate"]
hop_size = spec.attrs["hop_size"]
hop_length = int(hop_size * samplerate)
t_frames = time_constant * samplerate / (float(hop_length) * 10)
smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
return audio.pcen(
@ -230,8 +560,34 @@ def scale_pcen(
def scale_log(
spec: xr.DataArray,
dtype: DTypeLike = np.float32,
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Apply logarithmic scaling to a magnitude spectrogram.
Calculates `log(1 + C * S)`, where S is the input magnitude spectrogram
and C is a scaling factor derived from the original STFT parameters
(sample rate, N-FFT, window function) stored in `spec.attrs`.
Parameters
----------
spec : xr.DataArray
Input magnitude spectrogram with the 'original_samplerate' and
'nfft' attributes set during STFT computation.
dtype : DTypeLike, default=np.float32
Target data type for the output spectrogram.
Returns
-------
xr.DataArray
Log-scaled spectrogram.
Raises
------
KeyError
If required attributes are missing from `spec.attrs`.
ValueError
If attributes are non-numeric or window function is invalid.
"""
samplerate = spec.attrs["original_samplerate"]
nfft = spec.attrs["nfft"]
log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
@ -248,6 +604,28 @@ def resize_spectrogram(
height: int = 128,
resize_factor: Optional[float] = 0.5,
) -> xr.DataArray:
"""Resize a spectrogram to target dimensions using interpolation.
Resizes the frequency axis to `height` bins and optionally resizes the
time axis by `resize_factor`.
Parameters
----------
spec : xr.DataArray
Input spectrogram with 'time' and 'frequency' dimensions.
height : int, default=128
Target number of frequency bins (vertical dimension).
resize_factor : float, optional
Factor to resize the time dimension. If 1.0 or None, time dimension
is unchanged. If 0.5, time dimension is halved, etc.
Returns
-------
xr.DataArray
Resized spectrogram, with coordinates adjusted by the underlying
`ops.resize` operation. The output dtype is currently fixed to
float32 by the `ops.resize` call.
"""
resize_factor = resize_factor or 1
current_width = spec.sizes["time"]
return ops.resize(
@ -258,61 +636,41 @@ def resize_spectrogram(
)
def adjust_spectrogram_width(
spec: xr.DataArray,
divide_factor: int = 32,
time_period: float = 0.001,
) -> xr.DataArray:
time_width = spec.sizes["time"]
if time_width % divide_factor == 0:
return spec
target_size = int(
np.ceil(spec.sizes["time"] / divide_factor) * divide_factor
)
extra_duration = (target_size - time_width) * time_period
_, stop = arrays.get_dim_range(spec, dim="time")
resized = ops.extend_dim(
spec,
dim="time",
stop=stop + extra_duration,
)
return resized
def duration_to_spec_width(
duration: float,
samplerate: int,
window_duration: float,
window_overlap: float,
) -> int:
samples = int(duration * samplerate)
fft_len = int(window_duration * samplerate)
fft_overlap = int(window_overlap * fft_len)
hop_len = fft_len - fft_overlap
width = (samples - fft_len + hop_len) / hop_len
return int(np.floor(width))
def spec_width_to_samples(
width: int,
samplerate: int,
window_duration: float,
window_overlap: float,
) -> int:
fft_len = int(window_duration * samplerate)
fft_overlap = int(window_overlap * fft_len)
hop_len = fft_len - fft_overlap
return width * hop_len + fft_len - hop_len
def get_spectrogram_resolution(
config: SpectrogramConfig,
) -> tuple[float, float]:
"""Calculate the approximate resolution of the final spectrogram.
Computes the width of each frequency bin (Hz/bin) and the duration
of each time bin (seconds/bin) based on the configuration parameters.
Parameters
----------
config : SpectrogramConfig
The spectrogram configuration object.
Returns
-------
Tuple[float, float]
A tuple containing:
- frequency_resolution (float): Approximate Hz per frequency bin.
- time_resolution (float): Approximate seconds per time bin.
Raises
------
ValueError
If required configuration fields (like `config.size`) are missing
or invalid.
"""
max_freq = config.frequencies.max_freq
min_freq = config.frequencies.min_freq
assert config.size is not None
if config.size is None:
raise ValueError("Spectrogram size configuration is required.")
spec_height = config.size.height
resize_factor = config.size.resize_factor or 1

View File

@ -0,0 +1,381 @@
"""Defines common interfaces (Protocols) for preprocessing components.
This module centralizes the Protocol definitions used throughout the
`batdetect2.preprocess` package. Protocols define expected methods and
signatures, allowing for flexible and interchangeable implementations of
components like audio loaders and spectrogram builders.
Using these protocols ensures that different parts of the preprocessing
pipeline can interact consistently, regardless of the specific underlying
implementation (e.g., different libraries or custom configurations).
"""
from typing import Optional, Protocol, Union
import numpy as np
import xarray as xr
from soundevent import data
class AudioLoader(Protocol):
"""Defines the interface for an audio loading and processing component.
An AudioLoader is responsible for retrieving audio data corresponding to
different soundevent objects (files, Recordings, Clips) and applying a
configured set of initial preprocessing steps. Adhering to this protocol
allows for different loading strategies or implementations.
"""
def load_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess audio directly from a file path.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix to prepend to the path if `path` is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform as an xarray DataArray
with time coordinates. Typically loads only the first channel.
Raises
------
FileNotFoundError
If the audio file cannot be found.
Exception
If the audio file cannot be loaded or processed.
"""
...
def load_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess the entire audio for a Recording object.
Parameters
----------
recording : data.Recording
The Recording object containing metadata about the audio file.
audio_dir : PathLike, optional
A directory where the audio file associated with the recording
can be found, especially if the path in the recording is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform. Typically loads only
the first channel.
Raises
------
FileNotFoundError
If the audio file associated with the recording cannot be found.
Exception
If the audio file cannot be loaded or processed.
"""
...
def load_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess the audio segment defined by a Clip object.
Parameters
----------
clip : data.Clip
The Clip object specifying the recording and the start/end times
of the segment to load.
audio_dir : PathLike, optional
A directory where the audio file associated with the clip's
recording can be found.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform for the specified clip
duration. Typically loads only the first channel.
Raises
------
FileNotFoundError
If the audio file associated with the clip cannot be found.
Exception
If the audio file cannot be loaded or processed.
"""
...
class SpectrogramBuilder(Protocol):
"""Defines the interface for a spectrogram generation component.
A SpectrogramBuilder takes a waveform (as numpy array or xarray DataArray)
and produces a spectrogram (as an xarray DataArray) based on its internal
configuration or implementation.
"""
def __call__(
self,
wav: Union[np.ndarray, xr.DataArray],
samplerate: Optional[int] = None,
) -> xr.DataArray:
"""Generate a spectrogram from an audio waveform.
Parameters
----------
wav : Union[np.ndarray, xr.DataArray]
The input audio waveform. If a numpy array, `samplerate` must
also be provided. If an xarray DataArray, it must have a 'time'
coordinate from which the sample rate can be inferred.
samplerate : int, optional
The sample rate of the audio in Hz. Required if `wav` is a
numpy array. If `wav` is an xarray DataArray, this parameter is
ignored as the sample rate is derived from the coordinates.
Returns
-------
xr.DataArray
The computed spectrogram as an xarray DataArray with 'time' and
'frequency' coordinates.
Raises
------
ValueError
If `wav` is a numpy array and `samplerate` is not provided, or
if `wav` is an xarray DataArray without a valid 'time' coordinate.
"""
...
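# --- Editor's illustration (not part of the batdetect2 source) ---
# A minimal class satisfying `SpectrogramBuilder`, using `scipy.signal` as a
# stand-in for whatever STFT configuration batdetect2 actually uses. Window
# length, overlap, denoising and amplitude scaling are all omitted.
class _ExampleSpectrogramBuilder:
    def __call__(
        self,
        wav: Union[np.ndarray, xr.DataArray],
        samplerate: Optional[int] = None,
    ) -> xr.DataArray:
        from scipy import signal  # assumed dependency, for illustration only

        if isinstance(wav, xr.DataArray):
            # Infer the sample rate from the spacing of the time coordinate.
            step = float(wav.time[1] - wav.time[0])
            samplerate = int(round(1 / step))
            wav = wav.values
        elif samplerate is None:
            raise ValueError("samplerate is required when `wav` is a numpy array")

        freqs, times, spec = signal.spectrogram(wav, fs=samplerate)
        return xr.DataArray(
            spec,
            coords={"frequency": freqs, "time": times},
            dims=("frequency", "time"),
        )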
class PreprocessorProtocol(Protocol):
"""Defines a high-level interface for the complete preprocessing pipeline.
A Preprocessor combines audio loading and spectrogram generation steps.
It provides methods to go directly from source descriptions (file paths,
Recording objects, Clip objects) to the final spectrogram representation
needed by the model. It may also expose intermediate steps like audio
loading or spectrogram computation from a waveform.
"""
max_freq: float
min_freq: float
def preprocess_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio from a file and compute the final processed spectrogram.
Performs the full pipeline:
Load -> Preprocess Audio -> Compute Spectrogram.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The final processed spectrogram.
Raises
------
FileNotFoundError
If the audio file cannot be found.
Exception
If any step in the loading or preprocessing fails.
"""
...
def preprocess_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio for a Recording and compute the processed spectrogram.
Performs the full pipeline for the entire duration of the recording.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
Raises
------
FileNotFoundError
If the audio file cannot be found.
Exception
If any step in the loading or preprocessing fails.
"""
...
def preprocess_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio for a Clip and compute the final processed spectrogram.
Performs the full pipeline for the specified clip segment.
Parameters
----------
clip : data.Clip
The Clip object defining the audio segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
Raises
------
FileNotFoundError
If the audio file cannot be found.
Exception
If any step in the loading or preprocessing fails.
"""
...
def load_file_audio(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform from a file path.
Performs the initial audio loading and waveform processing steps
(like resampling, scaling), but stops *before* spectrogram generation.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform.
Raises
------
FileNotFoundError, Exception
If audio loading/preprocessing fails.
"""
...
def load_recording_audio(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Recording.
Performs the initial audio loading and waveform processing steps
for the entire recording duration.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform.
Raises
------
FileNotFoundError, Exception
If audio loading/preprocessing fails.
"""
...
def load_clip_audio(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Clip.
Performs the initial audio loading and waveform processing steps
for the specified clip segment.
Parameters
----------
clip : data.Clip
The Clip object defining the segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform segment.
Raises
------
FileNotFoundError, Exception
If audio loading/preprocessing fails.
"""
...
def compute_spectrogram(
self,
wav: Union[xr.DataArray, np.ndarray],
) -> xr.DataArray:
"""Compute the spectrogram from a pre-loaded audio waveform.
Applies the spectrogram generation steps (STFT, scaling, etc.) defined
by the `SpectrogramBuilder` component of the preprocessor to an
already loaded (and potentially preprocessed) waveform.
Parameters
----------
        wav : Union[xr.DataArray, np.ndarray]
            The pre-loaded (and already preprocessed) input audio waveform.
Returns
-------
xr.DataArray
The computed spectrogram.
Raises
------
ValueError, Exception
If waveform input is invalid or spectrogram computation fails.
"""
...
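# --- Editor's illustration (not part of the batdetect2 source) ---
# How the high-level and intermediate methods of a concrete
# `PreprocessorProtocol` implementation relate to each other. The concrete
# `preprocessor` instance is assumed to be built elsewhere in batdetect2.
def _example_preprocess(
    preprocessor: PreprocessorProtocol,
    clip: data.Clip,
) -> xr.DataArray:
    # One-shot route: clip -> final spectrogram.
    spectrogram = preprocessor.preprocess_clip(clip)

    # Equivalent two-step route, useful when the intermediate waveform is
    # needed (e.g. for waveform-level augmentations).
    waveform = preprocessor.load_clip_audio(clip)
    spectrogram = preprocessor.compute_spectrogram(waveform)

    return spectrogram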

@ -0,0 +1,666 @@
"""Main entry point for the BatDetect2 Target Definition subsystem.
This package (`batdetect2.targets`) provides the tools and configurations
necessary to define precisely what the BatDetect2 model should learn to detect,
classify, and localize from audio data. It involves several conceptual steps,
managed through configuration files and culminating in an executable pipeline:
1. **Terms (`.terms`)**: Defining vocabulary for annotation tags.
2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations.
3. **Transformation (`.transform`)**: Modifying tags (standardization,
derivation).
4. **ROI Mapping (`.roi`)**: Defining how annotation geometry (ROIs) maps to
target position and size representations, and back.
5. **Class Definition (`.classes`)**: Mapping tags to target class names
(encoding) and mapping predicted names back to tags (decoding).
This module exposes the key components for users to configure and utilize this
target definition pipeline, primarily through the `TargetConfig` data structure
and the `Targets` class (implementing `TargetProtocol`), which encapsulates the
configured processing steps. The main way to create a functional `Targets`
object is via the `build_targets` or `load_targets` functions.
"""
from typing import List, Optional
import numpy as np
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.classes import (
ClassesConfig,
SoundEventDecoder,
SoundEventEncoder,
TargetClass,
build_generic_class_tags,
build_sound_event_decoder,
build_sound_event_encoder,
get_class_names_from_config,
load_classes_config,
load_decoder_from_config,
load_encoder_from_config,
)
from batdetect2.targets.filtering import (
FilterConfig,
FilterRule,
SoundEventFilter,
build_sound_event_filter,
load_filter_config,
load_filter_from_config,
)
from batdetect2.targets.rois import (
ROIConfig,
ROITargetMapper,
build_roi_mapper,
)
from batdetect2.targets.terms import (
TagInfo,
TermInfo,
TermRegistry,
call_type,
get_tag_from_info,
get_term_from_key,
individual,
register_term,
term_registry,
)
from batdetect2.targets.transform import (
DerivationRegistry,
DeriveTagRule,
MapValueRule,
ReplaceRule,
SoundEventTransformation,
TransformConfig,
build_transformation_from_config,
derivation_registry,
get_derivation,
load_transformation_config,
load_transformation_from_config,
register_derivation,
)
from batdetect2.targets.types import TargetProtocol
__all__ = [
"ClassesConfig",
"DEFAULT_TARGET_CONFIG",
"DeriveTagRule",
"FilterConfig",
"FilterRule",
"MapValueRule",
"ROIConfig",
"ROITargetMapper",
"ReplaceRule",
"SoundEventDecoder",
"SoundEventEncoder",
"SoundEventFilter",
"SoundEventTransformation",
"TagInfo",
"TargetClass",
"TargetConfig",
"TargetProtocol",
"Targets",
"TermInfo",
"TransformConfig",
"build_generic_class_tags",
"build_roi_mapper",
"build_sound_event_decoder",
"build_sound_event_encoder",
"build_sound_event_filter",
"build_transformation_from_config",
"call_type",
"get_class_names_from_config",
"get_derivation",
"get_tag_from_info",
"get_term_from_key",
"individual",
"load_classes_config",
"load_decoder_from_config",
"load_encoder_from_config",
"load_filter_config",
"load_filter_from_config",
"load_target_config",
"load_transformation_config",
"load_transformation_from_config",
"register_derivation",
"register_term",
]
class TargetConfig(BaseConfig):
"""Unified configuration for the entire target definition pipeline.
This model aggregates the configurations for semantic processing (filtering,
transformation, class definition) and geometric processing (ROI mapping).
It serves as the primary input for building a complete `Targets` object
via `build_targets` or `load_targets`.
Attributes
----------
filtering : FilterConfig, optional
Configuration for filtering sound event annotations based on tags.
If None or omitted, no filtering is applied.
transforms : TransformConfig, optional
Configuration for transforming annotation tags
(mapping, derivation, etc.). If None or omitted, no tag transformations
are applied.
classes : ClassesConfig
Configuration defining the specific target classes, their tag matching
rules for encoding, their representative tags for decoding
(`output_tags`), and the definition of the generic class tags.
This section is mandatory.
roi : ROIConfig, optional
Configuration defining how geometric ROIs (e.g., bounding boxes) are
mapped to target representations (reference point, scaled size).
Controls `position`, `time_scale`, `frequency_scale`. If None or
omitted, default ROI mapping settings are used.
"""
filtering: Optional[FilterConfig] = None
transforms: Optional[TransformConfig] = None
classes: ClassesConfig
roi: Optional[ROIConfig] = None
def load_target_config(
path: data.PathLike,
field: Optional[str] = None,
) -> TargetConfig:
"""Load the unified target configuration from a file.
Reads a configuration file (typically YAML) and validates it against the
`TargetConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : data.PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
target configuration. If None, the entire file content is used.
Returns
-------
TargetConfig
The loaded and validated unified target configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data does not conform to the
`TargetConfig` schema (including validation within nested configs
like `ClassesConfig`).
KeyError, TypeError
If `field` specifies an invalid path within the loaded data.
"""
return load_config(path=path, schema=TargetConfig, field=field)
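# --- Editor's illustration (not part of the batdetect2 source) ---
# A hypothetical YAML file that `load_target_config` could parse into a
# `TargetConfig`. Field names follow the models in this module; the tag keys,
# tag values and file names are made-up examples.
#
#   # targets.yaml
#   filtering:
#     rules:
#       - match_type: all
#         tags:
#           - key: event
#             value: Echolocation
#   classes:
#     classes:
#       - name: pippip
#         tags:
#           - value: Pipistrellus pipistrellus
#     generic_class:
#       - value: Bat
#   roi:
#     position: bottom-left
#     time_scale: 1000.0
#
# config = load_target_config("targets.yaml")
# config = load_target_config("batdetect2.yaml", field="targets")  # nested section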
class Targets(TargetProtocol):
"""Encapsulates the complete configured target definition pipeline.
This class implements the `TargetProtocol`, holding the configured
functions for filtering, transforming, encoding (tags to class name),
decoding (class name to tags), and mapping ROIs (geometry to position/size
and back). It provides a high-level interface to apply these steps and
access relevant metadata like class names and dimension names.
Instances are typically created using the `build_targets` factory function
or the `load_targets` convenience loader.
Attributes
----------
class_names : List[str]
An ordered list of the unique names of the specific target classes
defined in the configuration.
generic_class_tags : List[data.Tag]
A list of `soundevent.data.Tag` objects representing the configured
generic class category (used when no specific class matches).
dimension_names : List[str]
The names of the size dimensions handled by the ROI mapper
(e.g., ['width', 'height']).
"""
class_names: List[str]
generic_class_tags: List[data.Tag]
dimension_names: List[str]
def __init__(
self,
encode_fn: SoundEventEncoder,
decode_fn: SoundEventDecoder,
roi_mapper: ROITargetMapper,
class_names: list[str],
generic_class_tags: List[data.Tag],
filter_fn: Optional[SoundEventFilter] = None,
transform_fn: Optional[SoundEventTransformation] = None,
):
"""Initialize the Targets object.
Note: This constructor is typically called internally by the
`build_targets` factory function.
Parameters
----------
encode_fn : SoundEventEncoder
Configured function to encode annotations to class names.
decode_fn : SoundEventDecoder
Configured function to decode class names to tags.
roi_mapper : ROITargetMapper
Configured object for mapping geometry to/from position/size.
class_names : list[str]
Ordered list of specific target class names.
generic_class_tags : List[data.Tag]
List of tags representing the generic class.
filter_fn : SoundEventFilter, optional
Configured function to filter annotations. Defaults to None.
transform_fn : SoundEventTransformation, optional
Configured function to transform annotation tags. Defaults to None.
"""
self.class_names = class_names
self.generic_class_tags = generic_class_tags
self.dimension_names = roi_mapper.dimension_names
self._roi_mapper = roi_mapper
self._filter_fn = filter_fn
self._encode_fn = encode_fn
self._decode_fn = decode_fn
self._transform_fn = transform_fn
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the configured filter to a sound event annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation to filter.
Returns
-------
bool
True if the annotation should be kept (passes the filter),
False otherwise. If no filter was configured, always returns True.
"""
if not self._filter_fn:
return True
return self._filter_fn(sound_event)
def encode(self, sound_event: data.SoundEventAnnotation) -> Optional[str]:
"""Encode a sound event annotation to its target class name.
Applies the configured class definition rules (including priority)
to determine the specific class name for the annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation to encode. Note: This should typically be called
*after* applying any transformations via the `transform` method.
Returns
-------
str or None
The name of the matched target class, or None if the annotation
does not match any specific class rule (i.e., it belongs to the
generic category).
"""
return self._encode_fn(sound_event)
def decode(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags.
Uses the configured mapping (based on `TargetClass.output_tags` or
`TargetClass.tags`) to convert a class name string into a list of
`soundevent.data.Tag` objects.
Parameters
----------
class_label : str
The class name to decode.
Returns
-------
List[data.Tag]
The list of tags corresponding to the input class name.
"""
return self._decode_fn(class_label)
def transform(
self, sound_event: data.SoundEventAnnotation
) -> data.SoundEventAnnotation:
"""Apply the configured tag transformations to an annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation whose tags should be transformed.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the transformed tags. If no
transformations were configured, the original annotation object is
returned.
"""
if self._transform_fn:
return self._transform_fn(sound_event)
return sound_event
def get_position(
self, sound_event: data.SoundEventAnnotation
) -> tuple[float, float]:
"""Extract the target reference position from the annotation's roi.
Delegates to the internal ROI mapper's `get_roi_position` method.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
Tuple[float, float]
The reference position `(time, frequency)`.
Raises
------
ValueError
If the annotation lacks geometry.
"""
geom = sound_event.sound_event.geometry
if geom is None:
raise ValueError(
"Sound event has no geometry, cannot get its position."
)
return self._roi_mapper.get_roi_position(geom)
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
"""Calculate the target size dimensions from the annotation's geometry.
Delegates to the internal ROI mapper's `get_roi_size` method, which
applies configured scaling factors.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
np.ndarray
NumPy array containing the size dimensions, matching the
order in `self.dimension_names` (e.g., `[width, height]`).
Raises
------
ValueError
If the annotation lacks geometry.
"""
geom = sound_event.sound_event.geometry
if geom is None:
raise ValueError(
"Sound event has no geometry, cannot get its size."
)
return self._roi_mapper.get_roi_size(geom)
def recover_roi(
self,
pos: tuple[float, float],
dims: np.ndarray,
) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions.
Delegates to the internal ROI mapper's `recover_roi` method, which
un-scales the dimensions and reconstructs the geometry (typically a
`BoundingBox`).
Parameters
----------
pos : Tuple[float, float]
The reference position `(time, frequency)`.
dims : np.ndarray
NumPy array with size dimensions (e.g., from model prediction),
matching the order in `self.dimension_names`.
Returns
-------
data.Geometry
The reconstructed geometry (typically `BoundingBox`).
"""
return self._roi_mapper.recover_roi(pos, dims)
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
filtering=FilterConfig(
rules=[
FilterRule(
match_type="all",
tags=[TagInfo(key="event", value="Echolocation")],
),
FilterRule(
match_type="exclude",
tags=[
TagInfo(key="event", value="Feeding"),
TagInfo(key="event", value="Unknown"),
TagInfo(key="event", value="Not Bat"),
],
),
]
),
classes=ClassesConfig(
classes=[
TargetClass(
tags=[TagInfo(value="Myotis mystacinus")],
name="myomys",
),
TargetClass(
tags=[TagInfo(value="Myotis alcathoe")],
name="myoalc",
),
TargetClass(
tags=[TagInfo(value="Eptesicus serotinus")],
name="eptser",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus nathusii")],
name="pipnat",
),
TargetClass(
tags=[TagInfo(value="Barbastellus barbastellus")],
name="barbar",
),
TargetClass(
tags=[TagInfo(value="Myotis nattereri")],
name="myonat",
),
TargetClass(
tags=[TagInfo(value="Myotis daubentonii")],
name="myodau",
),
TargetClass(
tags=[TagInfo(value="Myotis brandtii")],
name="myobra",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus pipistrellus")],
name="pippip",
),
TargetClass(
tags=[TagInfo(value="Myotis bechsteinii")],
name="myobec",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus pygmaeus")],
name="pippyg",
),
TargetClass(
tags=[TagInfo(value="Rhinolophus hipposideros")],
name="rhihip",
),
TargetClass(
tags=[TagInfo(value="Nyctalus leisleri")],
name="nyclei",
),
TargetClass(
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
name="rhifer",
),
TargetClass(
tags=[TagInfo(value="Plecotus auritus")],
name="pleaur",
),
TargetClass(
tags=[TagInfo(value="Nyctalus noctula")],
name="nycnoc",
),
TargetClass(
tags=[TagInfo(value="Plecotus austriacus")],
name="pleaus",
),
],
generic_class=[TagInfo(value="Bat")],
),
)
def build_targets(
config: Optional[TargetConfig] = None,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
This factory function takes the unified `TargetConfig` and constructs all
necessary functional components (filter, transform, encoder,
decoder, ROI mapper) by calling their respective builder functions. It also
extracts metadata (class names, generic tags, dimension names) to create
and return a fully initialized `Targets` instance, ready to process
annotations.
Parameters
----------
    config : TargetConfig, optional
        The loaded and validated unified target configuration object. If
        None, the built-in `DEFAULT_TARGET_CONFIG` is used.
term_registry : TermRegistry, optional
The TermRegistry instance to use for resolving term keys. Defaults
to the global `batdetect2.targets.terms.term_registry`.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use for resolving derivation
function names. Defaults to the global
`batdetect2.targets.transform.derivation_registry`.
Returns
-------
Targets
An initialized `Targets` object ready for use.
Raises
------
KeyError
If term keys or derivation function keys specified in the `config`
are not found in their respective registries.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function fails (when configured).
"""
config = config or DEFAULT_TARGET_CONFIG
filter_fn = (
build_sound_event_filter(
config.filtering,
term_registry=term_registry,
)
if config.filtering
else None
)
encode_fn = build_sound_event_encoder(
config.classes,
term_registry=term_registry,
)
decode_fn = build_sound_event_decoder(
config.classes,
term_registry=term_registry,
)
transform_fn = (
build_transformation_from_config(
config.transforms,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
if config.transforms
else None
)
roi_mapper = build_roi_mapper(config.roi or ROIConfig())
class_names = get_class_names_from_config(config.classes)
generic_class_tags = build_generic_class_tags(
config.classes,
term_registry=term_registry,
)
return Targets(
filter_fn=filter_fn,
encode_fn=encode_fn,
decode_fn=decode_fn,
class_names=class_names,
roi_mapper=roi_mapper,
generic_class_tags=generic_class_tags,
transform_fn=transform_fn,
)
def load_targets(
config_path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
derivation_registry: DerivationRegistry = derivation_registry,
) -> Targets:
"""Load a Targets object directly from a configuration file.
    This convenience factory function loads the `TargetConfig` from the
    specified file path and then calls `build_targets` to construct the
    fully initialized `Targets` object.
Parameters
----------
config_path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing
the target configuration. If None, the entire file content is used.
term_registry : TermRegistry, optional
The TermRegistry instance to use. Defaults to the global default.
derivation_registry : DerivationRegistry, optional
The DerivationRegistry instance to use. Defaults to the global
default.
Returns
-------
Targets
An initialized `Targets` object ready for use.
Raises
------
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
TypeError
Errors raised during file loading, validation, or extraction via
`load_target_config`.
KeyError, ImportError, AttributeError, TypeError
        Errors raised during the build process by `build_targets`
(e.g., missing keys in registries, failed imports).
"""
config = load_target_config(
config_path,
field=field,
)
return build_targets(
config,
term_registry=term_registry,
derivation_registry=derivation_registry,
)
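# --- Editor's illustration (not part of the batdetect2 source) ---
# One plausible end-to-end flow through a `Targets` object built from the
# default configuration: filter, then transform, then encode, plus the
# geometric helpers used when building training targets.
def _example_encode_annotation(
    sound_event: data.SoundEventAnnotation,
) -> Optional[str]:
    targets = build_targets()  # falls back to DEFAULT_TARGET_CONFIG

    if not targets.filter(sound_event):
        return None  # discarded by the filtering rules

    sound_event = targets.transform(sound_event)
    class_name = targets.encode(sound_event)

    position = targets.get_position(sound_event)
    size = targets.get_size(sound_event)
    _ = targets.recover_roi(position, size)  # approximate geometric round-trip

    return class_name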

@ -0,0 +1,618 @@
from collections import Counter
from functools import partial
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple
from pydantic import Field, field_validator
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
GENERIC_CLASS_KEY,
TagInfo,
TermRegistry,
get_tag_from_info,
term_registry,
)
__all__ = [
"SoundEventEncoder",
"SoundEventDecoder",
"TargetClass",
"ClassesConfig",
"load_classes_config",
"load_encoder_from_config",
"load_decoder_from_config",
"build_sound_event_encoder",
"build_sound_event_decoder",
"build_generic_class_tags",
"get_class_names_from_config",
"DEFAULT_SPECIES_LIST",
]
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
"""Type alias for a sound event class encoder function.
An encoder function takes a sound event annotation and returns the string name
of the target class it belongs to, based on a predefined set of rules.
If the annotation does not match any defined target class according to the
rules, the function returns None.
"""
SoundEventDecoder = Callable[[str], List[data.Tag]]
"""Type alias for a sound event class decoder function.
A decoder function takes a class name string (as predicted by the model or
assigned during encoding) and returns a list of `soundevent.data.Tag` objects
that represent that class according to the configuration. This is used to
translate model outputs back into meaningful annotations.
"""
DEFAULT_SPECIES_LIST = [
"Barbastella barbastellus",
"Eptesicus serotinus",
"Myotis alcathoe",
"Myotis bechsteinii",
"Myotis brandtii",
"Myotis daubentonii",
"Myotis mystacinus",
"Myotis nattereri",
"Nyctalus leisleri",
"Nyctalus noctula",
"Pipistrellus nathusii",
"Pipistrellus pipistrellus",
"Pipistrellus pygmaeus",
"Plecotus auritus",
"Plecotus austriacus",
"Rhinolophus ferrumequinum",
"Rhinolophus hipposideros",
]
"""A default list of common bat species names found in the UK."""
class TargetClass(BaseConfig):
"""Defines criteria for encoding annotations and decoding predictions.
Each instance represents one potential output class for the classification
model. It specifies:
1. A unique `name` for the class.
2. The tag conditions (`tags` and `match_type`) an annotation must meet to
be assigned this class name during training data preparation (encoding).
3. An optional, alternative set of tags (`output_tags`) to be used when
converting a model's prediction of this class name back into annotation
tags (decoding).
Attributes
----------
name : str
The unique name assigned to this target class (e.g., 'pippip',
'myodau', 'noise'). This name is used as the label during model
training and is the expected output from the model's prediction.
Should be unique across all TargetClass definitions in a configuration.
tags : List[TagInfo]
A list of one or more tags (defined using `TagInfo`) used to identify
if an existing annotation belongs to this class during encoding (data
preparation for training). The `match_type` attribute determines how
these tags are evaluated.
match_type : Literal["all", "any"], default="all"
Determines how the `tags` list is evaluated during encoding:
- "all": The annotation must have *all* the tags listed to match.
- "any": The annotation must have *at least one* of the tags listed
to match.
output_tags: Optional[List[TagInfo]], default=None
An optional list of tags (defined using `TagInfo`) to be assigned to a
new annotation when the model predicts this class `name`. If `None`
(default), the tags listed in the `tags` field will be used for
decoding. If provided, this list overrides the `tags` field for the
purpose of decoding predictions back into meaningful annotation tags.
This allows, for example, training on broader categories but decoding
to more specific representative tags.
"""
name: str
tags: List[TagInfo] = Field(min_length=1)
match_type: Literal["all", "any"] = Field(default="all")
output_tags: Optional[List[TagInfo]] = None
def _get_default_classes() -> List[TargetClass]:
"""Generate a list of default target classes.
Returns
-------
List[TargetClass]
A list of TargetClass objects, one for each species in
DEFAULT_SPECIES_LIST. The class names are simplified versions of the
species names.
"""
return [
TargetClass(
name=_get_default_class_name(value),
tags=[TagInfo(key=GENERIC_CLASS_KEY, value=value)],
)
for value in DEFAULT_SPECIES_LIST
]
def _get_default_class_name(species: str) -> str:
"""Generate a default class name from a species name.
Parameters
----------
species : str
The species name (e.g., "Myotis daubentonii").
Returns
-------
str
A simplified class name (e.g., "myodau").
The genus and species names are converted to lowercase,
the first three letters of each are taken, and concatenated.
"""
genus, species = species.strip().split(" ")
return f"{genus.lower()[:3]}{species.lower()[:3]}"
def _get_default_generic_class() -> List[TagInfo]:
"""Generate the default list of TagInfo objects for the generic class.
Provides a default set of tags used to represent the generic "Bat" category
when decoding predictions that didn't match a specific class.
Returns
-------
List[TagInfo]
A list containing default TagInfo objects, typically representing
`call_type: Echolocation` and `order: Chiroptera`.
"""
return [
TagInfo(key="call_type", value="Echolocation"),
TagInfo(key="order", value="Chiroptera"),
]
class ClassesConfig(BaseConfig):
"""Configuration defining target classes and the generic fallback category.
Holds the ordered list of specific target class definitions (`TargetClass`)
and defines the tags representing the generic category for sounds that pass
filtering but do not match any specific class.
The order of `TargetClass` objects in the `classes` list defines the
priority for classification during encoding. The system checks annotations
against these definitions sequentially and assigns the name of the *first*
matching class.
Attributes
----------
classes : List[TargetClass]
An ordered list of specific target class definitions. The order
determines matching priority (first match wins). Defaults to a
        standard set of classes via `_get_default_classes`.
generic_class : List[TagInfo]
A list of tags defining the "generic" or "unclassified but relevant"
category (e.g., representing a generic 'Bat' call that wasn't
assigned to a specific species). These tags are typically assigned
during decoding when a sound event was detected and passed filtering
but did not match any specific class rule defined in the `classes` list.
        Defaults to a standard set of tags via `_get_default_generic_class`.
Raises
------
ValueError
If validation fails (e.g., non-unique class names in the `classes`
list).
Notes
-----
- It is crucial that the `name` attribute of each `TargetClass` in the
`classes` list is unique. This configuration includes a validator to
enforce this uniqueness.
- The `generic_class` tags provide a baseline identity for relevant sounds
that don't fit into more specific defined categories.
"""
classes: List[TargetClass] = Field(default_factory=_get_default_classes)
generic_class: List[TagInfo] = Field(
default_factory=_get_default_generic_class
)
@field_validator("classes")
def check_unique_class_names(cls, v: List[TargetClass]):
"""Ensure all defined class names are unique."""
names = [c.name for c in v]
if len(names) != len(set(names)):
name_counts = Counter(names)
duplicates = [
name for name, count in name_counts.items() if count > 1
]
raise ValueError(
"Class names must be unique. Found duplicates: "
f"{', '.join(duplicates)}"
)
return v
def _is_target_class(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
match_all: bool = True,
) -> bool:
"""Check if a sound event annotation matches a set of required tags.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
    tags : Set[data.Tag]
        A set of `soundevent.data.Tag` objects that define the class criteria.
    match_all : bool, default=True
        If True, checks if *all* `tags` are present in the annotation's tags
        (subset check). If False, checks if *at least one* of the `tags` is
        present (intersection check).
Returns
-------
bool
True if the annotation meets the tag criteria, False otherwise.
"""
annotation_tags = set(sound_event_annotation.tags)
if match_all:
return tags <= annotation_tags
return bool(tags & annotation_tags)
def get_class_names_from_config(config: ClassesConfig) -> List[str]:
"""Extract the list of class names from a ClassesConfig object.
Parameters
----------
config : ClassesConfig
The loaded classes configuration object.
Returns
-------
List[str]
An ordered list of unique class names defined in the configuration.
"""
return [class_info.name for class_info in config.classes]
def _encode_with_multiple_classifiers(
sound_event_annotation: data.SoundEventAnnotation,
classifiers: List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]],
) -> Optional[str]:
"""Encode an annotation by checking against a list of classifiers.
Internal helper function used by the `SoundEventEncoder`. It iterates
through the provided list of (class_name, classifier_function) pairs.
Returns the name associated with the first classifier function that
returns True for the given annotation.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to encode.
classifiers : List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]]
An ordered list where each tuple contains a class name and a function
that returns True if the annotation matches that class. The order
determines priority.
Returns
-------
str or None
The name of the first matching class, or None if no classifier matches.
"""
for class_name, classifier in classifiers:
if classifier(sound_event_annotation):
return class_name
return None
def build_sound_event_encoder(
config: ClassesConfig,
term_registry: TermRegistry = term_registry,
) -> SoundEventEncoder:
"""Build a sound event encoder function from the classes configuration.
The returned encoder function iterates through the class definitions in the
order specified in the config. It assigns an annotation the name of the
first class definition it matches.
Parameters
----------
config : ClassesConfig
The loaded and validated classes configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance used to look up term keys specified in the
`TagInfo` objects within the configuration. Defaults to the global
        `batdetect2.targets.terms.term_registry`.
Returns
-------
SoundEventEncoder
A callable function that takes a `SoundEventAnnotation` and returns
an optional string representing the matched class name, or None if no
class matches.
Raises
------
KeyError
If a term key specified in the configuration is not found in the
provided `term_registry`.
"""
binary_classifiers = [
(
class_info.name,
partial(
_is_target_class,
tags={
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in class_info.tags
},
match_all=class_info.match_type == "all",
),
)
for class_info in config.classes
]
return partial(
_encode_with_multiple_classifiers,
classifiers=binary_classifiers,
)
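# --- Editor's illustration (not part of the batdetect2 source) ---
# Building an encoder from a hand-written config. Class order sets priority:
# the species-level rule is checked before the broader genus-level rule. The
# class names and tag values here are hypothetical examples.
def _example_build_encoder() -> SoundEventEncoder:
    config = ClassesConfig(
        classes=[
            TargetClass(
                name="pippip",
                tags=[TagInfo(value="Pipistrellus pipistrellus")],
            ),
            TargetClass(
                name="pipistrellus",
                match_type="any",
                tags=[
                    TagInfo(value="Pipistrellus pygmaeus"),
                    TagInfo(value="Pipistrellus nathusii"),
                ],
            ),
        ],
    )
    return build_sound_event_encoder(config)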
def _decode_class(
name: str,
mapping: Dict[str, List[data.Tag]],
raise_on_error: bool = True,
) -> List[data.Tag]:
"""Decode a class name into a list of representative tags using a mapping.
Internal helper function used by the `SoundEventDecoder`. Looks up the
provided class `name` in the `mapping` dictionary.
Parameters
----------
name : str
The class name to decode.
mapping : Dict[str, List[data.Tag]]
A dictionary mapping class names to lists of `soundevent.data.Tag`
objects.
raise_on_error : bool, default=True
If True, raises a ValueError if the `name` is not found in the
`mapping`. If False, returns an empty list if the `name` is not found.
Returns
-------
List[data.Tag]
The list of tags associated with the class name, or an empty list if
not found and `raise_on_error` is False.
Raises
------
ValueError
If `name` is not found in `mapping` and `raise_on_error` is True.
"""
if name not in mapping and raise_on_error:
raise ValueError(f"Class {name} not found in mapping.")
if name not in mapping:
return []
return mapping[name]
def build_sound_event_decoder(
config: ClassesConfig,
term_registry: TermRegistry = term_registry,
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Build a sound event decoder function from the classes configuration.
Creates a callable `SoundEventDecoder` that maps a class name string
back to a list of representative `soundevent.data.Tag` objects based on
the `ClassesConfig`. It uses the `output_tags` field if provided in a
`TargetClass`, otherwise falls back to the `tags` field.
Parameters
----------
config : ClassesConfig
The loaded and validated classes configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance used to look up term keys. Defaults to the
        global `batdetect2.targets.terms.term_registry`.
raise_on_unmapped : bool, default=False
If True, the returned decoder function will raise a ValueError if asked
to decode a class name that is not in the configuration. If False, it
will return an empty list for unmapped names.
Returns
-------
SoundEventDecoder
A callable function that takes a class name string and returns a list
of `soundevent.data.Tag` objects.
Raises
------
KeyError
If a term key specified in the configuration (`output_tags`, `tags`, or
`generic_class`) is not found in the provided `term_registry`.
"""
mapping = {}
for class_info in config.classes:
tags_to_use = (
class_info.output_tags
if class_info.output_tags is not None
else class_info.tags
)
mapping[class_info.name] = [
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in tags_to_use
]
return partial(
_decode_class,
mapping=mapping,
raise_on_error=raise_on_unmapped,
)
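# --- Editor's illustration (not part of the batdetect2 source) ---
# Decoding with `output_tags`: the class is matched on a coarse tag during
# encoding but decoded back to a richer, hand-picked tag set. Tag values are
# hypothetical examples.
def _example_build_decoder() -> SoundEventDecoder:
    config = ClassesConfig(
        classes=[
            TargetClass(
                name="myotis",
                tags=[TagInfo(value="Myotis")],
                output_tags=[
                    TagInfo(value="Myotis"),
                    TagInfo(key="call_type", value="Echolocation"),
                ],
            ),
        ],
    )
    return build_sound_event_decoder(config)
# Calling the returned decoder with "myotis" yields the two `output_tags` above.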
def build_generic_class_tags(
config: ClassesConfig,
term_registry: TermRegistry = term_registry,
) -> List[data.Tag]:
"""Extract and build the list of tags for the generic class from config.
Converts the list of `TagInfo` objects defined in `config.generic_class`
into a list of `soundevent.data.Tag` objects using the term registry.
Parameters
----------
config : ClassesConfig
The loaded classes configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance for term lookups. Defaults to the global
        `batdetect2.targets.terms.term_registry`.
Returns
-------
List[data.Tag]
The list of fully constructed tags representing the generic class.
Raises
------
KeyError
If a term key specified in `config.generic_class` is not found in the
provided `term_registry`.
"""
return [
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in config.generic_class
]
def load_classes_config(
path: data.PathLike,
field: Optional[str] = None,
) -> ClassesConfig:
"""Load the target classes configuration from a file.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : str, optional
If the classes configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
Returns
-------
ClassesConfig
The loaded and validated classes configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the ClassesConfig schema
or if class names are not unique.
"""
return load_config(path, schema=ClassesConfig, field=field)
def load_encoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
) -> SoundEventEncoder:
"""Load a class encoder function directly from a configuration file.
This is a convenience function that combines loading the `ClassesConfig`
from a file and building the final `SoundEventEncoder` function.
Parameters
----------
path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
If the classes configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
term_registry : TermRegistry, optional
The TermRegistry instance used for term lookups. Defaults to the
        global `batdetect2.targets.terms.term_registry`.
Returns
-------
SoundEventEncoder
The final encoder function ready to classify annotations.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the ClassesConfig schema
or if class names are not unique.
KeyError
If a term key specified in the configuration is not found in the
provided `term_registry` during the build process.
"""
config = load_classes_config(path, field=field)
return build_sound_event_encoder(config, term_registry=term_registry)
def load_decoder_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
raise_on_unmapped: bool = False,
) -> SoundEventDecoder:
"""Load a class decoder function directly from a configuration file.
This is a convenience function that combines loading the `ClassesConfig`
from a file and building the final `SoundEventDecoder` function.
Parameters
----------
path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
If the classes configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
term_registry : TermRegistry, optional
The TermRegistry instance used for term lookups. Defaults to the
        global `batdetect2.targets.terms.term_registry`.
raise_on_unmapped : bool, default=False
If True, the returned decoder function will raise a ValueError if asked
to decode a class name that is not in the configuration. If False, it
will return an empty list for unmapped names.
Returns
-------
SoundEventDecoder
The final decoder function ready to convert class names back into tags.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the ClassesConfig schema
or if class names are not unique.
KeyError
If a term key specified in the configuration is not found in the
provided `term_registry` during the build process.
"""
config = load_classes_config(path, field=field)
return build_sound_event_decoder(
config,
term_registry=term_registry,
raise_on_unmapped=raise_on_unmapped,
)

@ -0,0 +1,315 @@
import logging
from functools import partial
from typing import Callable, List, Literal, Optional, Set
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
term_registry,
)
__all__ = [
"FilterConfig",
"FilterRule",
"SoundEventFilter",
"build_sound_event_filter",
"build_filter_from_rule",
"load_filter_config",
"load_filter_from_config",
]
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
"""Type alias for a filter function.
A filter function accepts a soundevent.data.SoundEventAnnotation object
and returns True if the annotation should be kept based on the filter's
criteria, or False if it should be discarded.
"""
logger = logging.getLogger(__name__)
class FilterRule(BaseConfig):
"""Defines a single rule for filtering sound event annotations.
Based on the `match_type`, this rule checks if the tags associated with a
sound event annotation meet certain criteria relative to the `tags` list
defined in this rule.
Attributes
----------
match_type : Literal["any", "all", "exclude", "equal"]
Determines how the `tags` list is used:
- "any": Pass if the annotation has at least one tag from the list.
- "all": Pass if the annotation has all tags from the list (it can
have others too).
- "exclude": Pass if the annotation has none of the tags from the list.
- "equal": Pass if the annotation's tags are exactly the same set as
provided in the list.
tags : List[TagInfo]
A list of tags (defined using TagInfo for configuration) that this
rule operates on.
"""
match_type: Literal["any", "all", "exclude", "equal"]
tags: List[TagInfo]
def has_any_tag(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
) -> bool:
"""Check if the annotation has at least one of the specified tags.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
tags : Set[data.Tag]
The set of tags to look for.
Returns
-------
bool
True if the annotation has one or more tags from the specified set,
False otherwise.
"""
sound_event_tags = set(sound_event_annotation.tags)
return bool(tags & sound_event_tags)
def contains_tags(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
) -> bool:
"""Check if the annotation contains all of the specified tags.
The annotation may have additional tags beyond those specified.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
tags : Set[data.Tag]
The set of tags that must all be present in the annotation.
Returns
-------
bool
True if the annotation's tags are a superset of the specified tags,
False otherwise.
"""
sound_event_tags = set(sound_event_annotation.tags)
    return tags <= sound_event_tags
def does_not_have_tags(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
) -> bool:
"""Check if the annotation has none of the specified tags.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
tags : Set[data.Tag]
The set of tags that must *not* be present in the annotation.
Returns
-------
bool
True if the annotation has zero tags in common with the specified set,
False otherwise.
"""
return not has_any_tag(sound_event_annotation, tags)
def equal_tags(
sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag],
) -> bool:
"""Check if the annotation's tags are exactly equal to the specified set.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
tags : Set[data.Tag]
The exact set of tags the annotation must have.
Returns
-------
bool
True if the annotation's tags set is identical to the specified set,
False otherwise.
"""
sound_event_tags = set(sound_event_annotation.tags)
return tags == sound_event_tags
def build_filter_from_rule(
rule: FilterRule,
term_registry: TermRegistry = term_registry,
) -> SoundEventFilter:
"""Creates a callable filter function from a single FilterRule.
Parameters
----------
    rule : FilterRule
        The filter rule configuration object.
    term_registry : TermRegistry, optional
        The TermRegistry instance used to resolve the tag keys in the rule.
        Defaults to the global `batdetect2.targets.terms.term_registry`.
Returns
-------
SoundEventFilter
A function that takes a SoundEventAnnotation and returns True if it
passes the rule, False otherwise.
Raises
------
ValueError
If the rule contains an invalid `match_type`.
"""
tag_set = {
get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in rule.tags
}
if rule.match_type == "any":
return partial(has_any_tag, tags=tag_set)
if rule.match_type == "all":
return partial(contains_tags, tags=tag_set)
if rule.match_type == "exclude":
return partial(does_not_have_tags, tags=tag_set)
if rule.match_type == "equal":
return partial(equal_tags, tags=tag_set)
raise ValueError(
f"Invalid match type {rule.match_type}. Valid types "
"are: 'any', 'all', 'exclude' and 'equal'"
)
def _passes_all_filters(
sound_event_annotation: data.SoundEventAnnotation,
filters: List[SoundEventFilter],
) -> bool:
"""Check if the annotation passes all provided filters.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to check.
filters : List[SoundEventFilter]
A list of filter functions to apply.
Returns
-------
bool
True if the annotation passes all filters, False otherwise.
"""
for filter_fn in filters:
if not filter_fn(sound_event_annotation):
            logger.debug(
f"Sound event annotation {sound_event_annotation.uuid} "
f"excluded due to rule {filter_fn}",
)
return False
return True
class FilterConfig(BaseConfig):
"""Configuration model for defining a list of filter rules.
Attributes
----------
rules : List[FilterRule]
A list of FilterRule objects. An annotation must pass all rules in
this list to be considered valid by the filter built from this config.
"""
rules: List[FilterRule] = Field(default_factory=list)
def build_sound_event_filter(
config: FilterConfig,
term_registry: TermRegistry = term_registry,
) -> SoundEventFilter:
"""Builds a merged filter function from a FilterConfig object.
Creates individual filter functions for each rule in the configuration
and merges them using AND logic.
Parameters
----------
config : FilterConfig
The configuration object containing the list of filter rules.
Returns
-------
SoundEventFilter
A single callable filter function that applies all defined rules.
"""
filters = [
build_filter_from_rule(rule, term_registry=term_registry)
for rule in config.rules
]
return partial(_passes_all_filters, filters=filters)
def load_filter_config(
path: data.PathLike, field: Optional[str] = None
) -> FilterConfig:
"""Loads the filter configuration from a file.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : Optional[str], optional
If the filter configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
Returns
-------
FilterConfig
The loaded and validated filter configuration object.
"""
return load_config(path, schema=FilterConfig, field=field)
def load_filter_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
) -> SoundEventFilter:
"""Loads filter configuration from a file and builds the filter function.
This is a convenience function that combines loading the configuration
and building the final callable filter function.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : Optional[str], optional
If the filter configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
Returns
-------
SoundEventFilter
The final merged filter function ready to be used.
"""
config = load_filter_config(path=path, field=field)
return build_sound_event_filter(config, term_registry=term_registry)
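# --- Editor's illustration (not part of the batdetect2 source) ---
# Building a filter equivalent to "keep echolocation calls, drop anything
# tagged as noise". The `event: Echolocation` tag mirrors the default target
# configuration; the `event: Noise` tag is a hypothetical example. The
# returned callable yields True only for annotations that pass *all* rules.
def _example_build_filter() -> SoundEventFilter:
    config = FilterConfig(
        rules=[
            FilterRule(
                match_type="all",
                tags=[TagInfo(key="event", value="Echolocation")],
            ),
            FilterRule(
                match_type="exclude",
                tags=[TagInfo(key="event", value="Noise")],
            ),
        ]
    )
    return build_sound_event_filter(config)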

batdetect2/targets/rois.py Normal file
@ -0,0 +1,462 @@
"""Handles mapping between geometric ROIs and target representations.
This module defines the interface and provides implementation for converting
a sound event's Region of Interest (ROI), typically represented by a
`soundevent.data.Geometry` object like a `BoundingBox`, into a format
suitable for use as a machine learning target. This usually involves:
1. Extracting a single reference point (time, frequency) from the geometry.
2. Calculating relevant size dimensions (e.g., duration/width,
bandwidth/height) and applying scaling factors.
It also provides the inverse operation: recovering an approximate geometric ROI
(like a `BoundingBox`) from a predicted reference point and predicted size
dimensions.
This logic is encapsulated within components adhering to the `ROITargetMapper`
protocol. Configuration for this mapping (e.g., which reference point to use,
scaling factors) is managed by the `ROIConfig`. This module separates the
*geometric* aspect of target definition from the *semantic* classification
handled in `batdetect2.targets.classes`.
"""
from typing import List, Optional, Protocol, Tuple
import numpy as np
from soundevent import data, geometry
from soundevent.geometry.operations import Positions
from batdetect2.configs import BaseConfig, load_config
__all__ = [
"ROITargetMapper",
"ROIConfig",
"BBoxEncoder",
"build_roi_mapper",
"load_roi_mapper",
"DEFAULT_POSITION",
"SIZE_WIDTH",
"SIZE_HEIGHT",
"SIZE_ORDER",
"DEFAULT_TIME_SCALE",
"DEFAULT_FREQUENCY_SCALE",
]
SIZE_WIDTH = "width"
"""Standard name for the width/time dimension component ('width')."""
SIZE_HEIGHT = "height"
"""Standard name for the height/frequency dimension component ('height')."""
SIZE_ORDER = (SIZE_WIDTH, SIZE_HEIGHT)
"""Standard order of dimensions for size arrays ([width, height])."""
DEFAULT_TIME_SCALE = 1000.0
"""Default scaling factor for time duration."""
DEFAULT_FREQUENCY_SCALE = 1 / 859.375
"""Default scaling factor for frequency bandwidth."""
DEFAULT_POSITION = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner)."""
class ROITargetMapper(Protocol):
"""Protocol defining the interface for ROI-to-target mapping.
Specifies the methods required for converting a geometric region of interest
(`soundevent.data.Geometry`) into a target representation (reference point
and scaled dimensions) and for recovering an approximate ROI from that
representation.
Attributes
----------
dimension_names : List[str]
A list containing the names of the dimensions returned by
`get_roi_size` and expected by `recover_roi`
(e.g., ['width', 'height']).
"""
dimension_names: List[str]
def get_roi_position(self, geom: data.Geometry) -> tuple[float, float]:
"""Extract the reference position from a geometry.
Parameters
----------
geom : soundevent.data.Geometry
The input geometry (e.g., BoundingBox, Polygon).
Returns
-------
Tuple[float, float]
The calculated reference position as (time, frequency) coordinates,
based on the implementing class's configuration (e.g., "center",
"bottom-left").
Raises
------
ValueError
If the position cannot be calculated for the given geometry type
or configured reference point.
"""
...
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
"""Calculate the scaled target dimensions from a geometry.
Computes the relevant size measures.
Parameters
----------
geom : soundevent.data.Geometry
The input geometry.
Returns
-------
np.ndarray
A NumPy array containing the scaled dimensions corresponding to
`dimension_names`. For bounding boxes, typically contains
`[scaled_width, scaled_height]`.
Raises
------
TypeError, ValueError
If the size cannot be computed for the given geometry type.
"""
...
def recover_roi(
self, pos: tuple[float, float], dims: np.ndarray
) -> data.Geometry:
"""Recover an approximate ROI from a position and target dimensions.
Performs the inverse mapping: takes a reference position and the
predicted dimensions and reconstructs a geometric representation.
Parameters
----------
pos : Tuple[float, float]
The reference position (time, frequency).
dims : np.ndarray
The NumPy array containing the dimensions, matching the order
specified by `dimension_names`.
Returns
-------
soundevent.data.Geometry
The reconstructed geometry.
Raises
------
ValueError
If the number of provided dimensions `dims` does not match
`dimension_names` or if reconstruction fails.
"""
...
class ROIConfig(BaseConfig):
"""Configuration for mapping Regions of Interest (ROIs).
Defines parameters controlling how geometric ROIs are converted into
target representations (reference points and scaled sizes).
Attributes
----------
position : Positions, default="bottom-left"
Specifies the reference point within the geometry (e.g., bounding box)
to use as the target location (e.g., "center", "bottom-left").
See `soundevent.geometry.operations.Positions`.
time_scale : float, default=1000.0
Scaling factor applied to the time duration (width) of the ROI
when calculating the target size representation. Must match model
expectations.
frequency_scale : float, default=1/859.375
Scaling factor applied to the frequency bandwidth (height) of the ROI
when calculating the target size representation. Must match model
expectations.
"""
position: Positions = DEFAULT_POSITION
time_scale: float = DEFAULT_TIME_SCALE
frequency_scale: float = DEFAULT_FREQUENCY_SCALE
class BBoxEncoder(ROITargetMapper):
"""Concrete implementation of `ROITargetMapper` focused on Bounding Boxes.
This class implements the ROI mapping protocol primarily for
`soundevent.data.BoundingBox` geometry. It extracts reference points,
calculates scaled width/height, and recovers bounding boxes based on
configured position and scaling factors.
Attributes
----------
dimension_names : List[str]
Specifies the output dimension names as ['width', 'height'].
position : Positions
The configured reference point type (e.g., "center", "bottom-left").
time_scale : float
The configured scaling factor for the time dimension (width).
frequency_scale : float
The configured scaling factor for the frequency dimension (height).
"""
dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]
def __init__(
self,
position: Positions = DEFAULT_POSITION,
time_scale: float = DEFAULT_TIME_SCALE,
frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
):
"""Initialize the BBoxEncoder.
Parameters
----------
position : Positions, default="bottom-left"
Reference point type within the bounding box.
time_scale : float, default=1000.0
Scaling factor for time duration (width).
frequency_scale : float, default=1/859.375
Scaling factor for frequency bandwidth (height).
"""
self.position: Positions = position
self.time_scale = time_scale
self.frequency_scale = frequency_scale
def get_roi_position(self, geom: data.Geometry) -> Tuple[float, float]:
"""Extract the configured reference position from the geometry.
Uses `soundevent.geometry.get_geometry_point`.
Parameters
----------
geom : soundevent.data.Geometry
Input geometry (e.g., BoundingBox).
Returns
-------
Tuple[float, float]
Reference position (time, frequency).
"""
return geometry.get_geometry_point(geom, position=self.position)
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
"""Calculate the scaled [width, height] from the geometry's bounds.
Computes the bounding box, extracts duration and bandwidth, and applies
the configured `time_scale` and `frequency_scale`.
Parameters
----------
geom : soundevent.data.Geometry
Input geometry.
Returns
-------
np.ndarray
A 1D NumPy array: `[scaled_width, scaled_height]`.
"""
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
geom
)
return np.array(
[
(end_time - start_time) * self.time_scale,
(high_freq - low_freq) * self.frequency_scale,
]
)
def recover_roi(
self,
pos: tuple[float, float],
dims: np.ndarray,
) -> data.Geometry:
"""Recover a BoundingBox from a position and scaled dimensions.
Un-scales the input dimensions using the configured factors and
reconstructs a `soundevent.data.BoundingBox` centered or anchored at
the given reference `pos` according to the configured `position` type.
Parameters
----------
pos : Tuple[float, float]
Reference position (time, frequency).
dims : np.ndarray
NumPy array containing the *scaled* dimensions, expected order is
[scaled_width, scaled_height].
Returns
-------
soundevent.data.BoundingBox
The reconstructed bounding box.
Raises
------
ValueError
If `dims` does not have the expected shape (length 2).
"""
if dims.ndim != 1 or dims.shape[0] != 2:
raise ValueError(
"Dimension array does not have the expected shape. "
f"({dims.shape = }) != ([2])"
)
width, height = dims
return _build_bounding_box(
pos,
duration=width / self.time_scale,
bandwidth=height / self.frequency_scale,
position=self.position,
)
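The encode/decode cycle above is easiest to see with concrete numbers. The following is a minimal round-trip sketch; the bounding box values are invented, and the import path `batdetect2.targets.rois` is assumed from the commit history rather than shown in this diff.

```python
import numpy as np
from soundevent import data

from batdetect2.targets.rois import BBoxEncoder  # assumed module path

encoder = BBoxEncoder()  # defaults: "bottom-left", 1000.0, 1/859.375

# Hypothetical 20 ms call spanning 30-60 kHz.
bbox = data.BoundingBox(coordinates=[0.100, 30_000.0, 0.120, 60_000.0])

pos = encoder.get_roi_position(bbox)  # (0.100, 30_000.0) for "bottom-left"
size = encoder.get_roi_size(bbox)     # [0.020 * 1000, 30_000 / 859.375]

recovered = encoder.recover_roi(pos, size)
np.testing.assert_allclose(recovered.coordinates, bbox.coordinates)
```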
def build_roi_mapper(config: ROIConfig) -> ROITargetMapper:
"""Factory function to create an ROITargetMapper from configuration.
Currently creates a `BBoxEncoder` instance based on the provided
`ROIConfig`.
Parameters
----------
config : ROIConfig
Configuration object specifying ROI mapping parameters.
Returns
-------
ROITargetMapper
An initialized `BBoxEncoder` instance configured with the settings
from `config`.
"""
return BBoxEncoder(
position=config.position,
time_scale=config.time_scale,
frequency_scale=config.frequency_scale,
)
def load_roi_mapper(
path: data.PathLike, field: Optional[str] = None
) -> ROITargetMapper:
"""Load ROI mapping configuration from a file and build the mapper.
Convenience function that loads an `ROIConfig` from the specified file
(and optional field) and then uses `build_roi_mapper` to create the
corresponding `ROITargetMapper` instance.
Parameters
----------
path : PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing the
ROI configuration. If None, the entire file content is used.
Returns
-------
ROITargetMapper
An initialized ROI mapper instance based on the configuration file.
Raises
------
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
TypeError
If the configuration file cannot be found, parsed, validated, or if
the specified `field` is invalid.
"""
config = load_config(path=path, schema=ROIConfig, field=field)
return build_roi_mapper(config)
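In configuration-driven code the same encoder would normally be obtained through the factory path instead. A hedged sketch, assuming `ROIConfig` and `build_roi_mapper` live in `batdetect2.targets.rois`:

```python
from batdetect2.targets.rois import ROIConfig, build_roi_mapper  # assumed path

config = ROIConfig(position="center")  # keep the default scaling factors
mapper = build_roi_mapper(config)
# `mapper` is a BBoxEncoder whose reference point is the box center, so
# get_roi_position / recover_roi are anchored at the centroid instead of
# the bottom-left corner.
```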
VALID_POSITIONS = [
"bottom-left",
"bottom-right",
"top-left",
"top-right",
"center-left",
"center-right",
"top-center",
"bottom-center",
"center",
"centroid",
"point_on_surface",
]
def _build_bounding_box(
pos: tuple[float, float],
duration: float,
bandwidth: float,
position: Positions = DEFAULT_POSITION,
) -> data.BoundingBox:
"""Construct a BoundingBox from a reference point, size, and position type.
Internal helper for `BBoxEncoder.recover_roi`. Calculates the box
coordinates [start_time, low_freq, end_time, high_freq] based on where
the input `pos` (time, freq) is located relative to the box (e.g.,
center, corner).
Parameters
----------
pos : Tuple[float, float]
Reference position (time, frequency).
duration : float
The required *unscaled* duration (width) of the bounding box.
bandwidth : float
The required *unscaled* frequency bandwidth (height) of the bounding
box.
position : Positions, default="bottom-left"
Specifies which part of the bounding box the input `pos` corresponds to.
Returns
-------
data.BoundingBox
The constructed bounding box object.
Raises
------
ValueError
If `position` is not a recognized value or format.
"""
time, freq = pos
if position in ["center", "centroid", "point_on_surface"]:
return data.BoundingBox(
coordinates=[
time - duration / 2,
freq - bandwidth / 2,
time + duration / 2,
freq + bandwidth / 2,
]
)
if position not in VALID_POSITIONS:
raise ValueError(
f"Invalid position: {position}. "
f"Valid options are: {VALID_POSITIONS}"
)
y, x = position.split("-")
start_time = {
"left": time,
"center": time - duration / 2,
"right": time - duration,
}[x]
low_freq = {
"bottom": freq,
"center": freq - bandwidth / 2,
"top": freq - bandwidth,
}[y]
return data.BoundingBox(
coordinates=[
start_time,
low_freq,
start_time + duration,
low_freq + bandwidth,
]
)

batdetect2/targets/terms.py Normal file

@ -0,0 +1,495 @@
"""Manages the vocabulary (Terms and Tags) for defining training targets.
This module provides the necessary tools to declare, register, and manage the
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
sub-package. It establishes a consistent vocabulary for filtering,
transforming, and classifying sound events based on their annotations (Tags).
The core component is the `TermRegistry`, which maps unique string keys
(aliases) to specific `Term` definitions. This allows users to refer to complex
terms using simple, consistent keys in configuration files and code.
Terms can be pre-defined, loaded from the `soundevent.terms` library, defined
programmatically, or loaded from external configuration files (e.g., YAML).
"""
from collections.abc import Mapping
from inspect import getmembers
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from soundevent import data, terms
from batdetect2.configs import load_config
__all__ = [
"call_type",
"individual",
"data_source",
"get_tag_from_info",
"TermInfo",
"TagInfo",
]
# The default key used to reference the 'generic_class' term.
# Often used implicitly when defining classification targets.
GENERIC_CLASS_KEY = "class"
data_source = data.Term(
name="soundevent:data_source",
label="Data Source",
definition=(
"A unique identifier for the source of the data, typically "
"representing the project, site, or deployment context."
),
)
call_type = data.Term(
name="soundevent:call_type",
label="Call Type",
definition=(
"A broad categorization of animal vocalizations based on their "
"intended function or purpose (e.g., social, distress, mating, "
"territorial, echolocation)."
),
)
"""Term representing the broad functional category of a vocalization."""
individual = data.Term(
name="soundevent:individual",
label="Individual",
definition=(
"An id for an individual animal. In the context of bioacoustic "
"annotation, this term is used to label vocalizations that are "
"attributed to a specific individual."
),
)
"""Term used for tags identifying a specific individual animal."""
generic_class = data.Term(
name="soundevent:class",
label="Class",
definition=(
"A generic term representing the name of a class within a "
"classification model. Its specific meaning is determined by "
"the model's application."
),
)
"""Generic term representing a classification model's output class label."""
class TermRegistry(Mapping[str, data.Term]):
"""Manages a registry mapping unique keys to Term definitions.
This class acts as the central repository for the vocabulary of terms
used within the target definition process. It allows registering terms
with simple string keys and retrieving them consistently.
"""
def __init__(self, terms: Optional[Dict[str, data.Term]] = None):
"""Initializes the TermRegistry.
Parameters
----------
terms : dict[str, soundevent.data.Term], optional
An optional dictionary of initial key-to-Term mappings
to populate the registry with. Defaults to an empty registry.
"""
self._terms: Dict[str, data.Term] = terms or {}
def __getitem__(self, key: str) -> data.Term:
return self._terms[key]
def __len__(self) -> int:
return len(self._terms)
def __iter__(self):
return iter(self._terms)
def add_term(self, key: str, term: data.Term) -> None:
"""Adds a Term object to the registry with the specified key.
Parameters
----------
key : str
The unique string key to associate with the term.
term : soundevent.data.Term
The soundevent.data.Term object to register.
Raises
------
KeyError
If a term with the provided key already exists in the
registry.
"""
if key in self._terms:
raise KeyError("A term with the provided key already exists.")
self._terms[key] = term
def get_term(self, key: str) -> data.Term:
"""Retrieves a registered term by its unique key.
Parameters
----------
key : str
The unique string key of the term to retrieve.
Returns
-------
soundevent.data.Term
The corresponding soundevent.data.Term object.
Raises
------
KeyError
If no term with the specified key is found. The error message
lists the currently registered keys.
"""
try:
return self._terms[key]
except KeyError as err:
raise KeyError(
"No term found for key "
f"'{key}'. Ensure it is registered or loaded. "
f"Available keys: {', '.join(self.get_keys())}"
) from err
def add_custom_term(
self,
key: str,
name: Optional[str] = None,
uri: Optional[str] = None,
label: Optional[str] = None,
definition: Optional[str] = None,
) -> data.Term:
"""Creates a new Term from attributes and adds it to the registry.
This is useful for defining terms directly in code or when loading
from configuration files where only attributes are provided.
If optional fields (`name`, `label`, `definition`) are not provided,
reasonable defaults are used (`key` for name/label, "Unknown" for
definition).
Parameters
----------
key : str
The unique string key for the new term.
name : str, optional
The name for the new term (defaults to `key`).
uri : str, optional
The URI for the new term (optional).
label : str, optional
The display label for the new term (defaults to `key`).
definition : str, optional
The definition for the new term (defaults to "Unknown").
Returns
-------
soundevent.data.Term
The newly created and registered soundevent.data.Term object.
Raises
------
KeyError
If a term with the provided key already exists.
"""
term = data.Term(
name=name or key,
label=label or key,
uri=uri,
definition=definition or "Unknown",
)
self.add_term(key, term)
return term
def get_keys(self) -> List[str]:
"""Returns a list of all keys currently registered.
Returns
-------
list[str]
A list of strings representing the keys of all registered terms.
"""
return list(self._terms.keys())
def get_terms(self) -> List[data.Term]:
"""Returns a list of all registered terms.
Returns
-------
list[soundevent.data.Term]
A list containing all registered Term objects.
"""
return list(self._terms.values())
def remove_key(self, key: str) -> None:
"""Remove the term registered under `key` from the registry."""
del self._terms[key]
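A short usage sketch of the registry API defined above; the key, URI, and label are illustrative only:

```python
from batdetect2.targets.terms import TermRegistry

registry = TermRegistry()
species = registry.add_custom_term(
    "species",
    uri="dwc:scientificName",
    label="Scientific Name",
)

assert registry.get_term("species") is species
assert registry.get_keys() == ["species"]
assert "species" in registry  # Mapping interface also gives len() and iteration
```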
term_registry = TermRegistry(
terms=dict(
[
*getmembers(terms, lambda x: isinstance(x, data.Term)),
("event", call_type),
("individual", individual),
("data_source", data_source),
(GENERIC_CLASS_KEY, generic_class),
]
)
)
"""The default, globally accessible TermRegistry instance.
It is pre-populated with standard terms from `soundevent.terms` and common
terms defined in this module (`call_type`, `individual`, `generic_class`).
Functions in this module use this registry by default unless another instance
is explicitly passed.
"""
def get_term_from_key(
key: str,
term_registry: TermRegistry = term_registry,
) -> data.Term:
"""Convenience function to retrieve a term by key from a registry.
Uses the global default registry unless a specific `term_registry`
instance is provided.
Parameters
----------
key : str
The unique key of the term to retrieve.
term_registry : TermRegistry, optional
The TermRegistry instance to search in. Defaults to the global
`term_registry`.
Returns
-------
soundevent.data.Term
The corresponding soundevent.data.Term object.
Raises
------
KeyError
If the key is not found in the specified registry.
"""
return term_registry.get_term(key)
def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
"""Convenience function to get all registered keys from a registry.
Uses the global default registry unless a specific `term_registry`
instance is provided.
Parameters
----------
term_registry : TermRegistry, optional
The TermRegistry instance to query. Defaults to the global
`term_registry`.
Returns
-------
list[str]
A list of strings representing the keys of all registered terms.
"""
return term_registry.get_keys()
def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]:
"""Convenience function to get all registered terms from a registry.
Uses the global default registry unless a specific `term_registry`
instance is provided.
Parameters
----------
term_registry : TermRegistry, optional
The TermRegistry instance to query. Defaults to the global
`term_registry`.
Returns
-------
list[soundevent.data.Term]
A list containing all registered Term objects.
"""
return term_registry.get_terms()
class TagInfo(BaseModel):
"""Represents information needed to define a specific Tag.
This model is typically used in configuration files (e.g., YAML) to
specify tags used for filtering, target class definition, or associating
tags with output classes. It links a tag value to a term definition
via the term's registry key.
Attributes
----------
value : str
The value of the tag (e.g., "Myotis myotis", "Echolocation").
key : str, default="class"
The key (alias) of the term associated with this tag, as
registered in the TermRegistry. Defaults to "class", meaning the tag
is treated as a classification target label unless specified otherwise.
"""
value: str
key: str = GENERIC_CLASS_KEY
def get_tag_from_info(
tag_info: TagInfo,
term_registry: TermRegistry = term_registry,
) -> data.Tag:
"""Creates a soundevent.data.Tag object from TagInfo data.
Looks up the term using the key in the provided `tag_info` from the
specified registry and constructs a Tag object.
Parameters
----------
tag_info : TagInfo
The TagInfo object containing the value and term key.
term_registry : TermRegistry, optional
The TermRegistry instance to use for term lookup. Defaults to the
global `term_registry`.
Returns
-------
soundevent.data.Tag
A soundevent.data.Tag object corresponding to the input info.
Raises
------
KeyError
If the term key specified in `tag_info.key` is not found
in the registry.
"""
term = get_term_from_key(tag_info.key, term_registry=term_registry)
return data.Tag(term=term, value=tag_info.value)
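A hypothetical example of turning configuration data into a `soundevent.data.Tag`, relying only on the default registry and the `"class"` key defined in this module:

```python
from batdetect2.targets.terms import TagInfo, get_tag_from_info

# `key` defaults to "class", which maps to the `generic_class` term above.
tag = get_tag_from_info(TagInfo(value="Myotis myotis"))

assert tag.term.label == "Class"
assert tag.value == "Myotis myotis"
```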
class TermInfo(BaseModel):
"""Represents the definition of a Term within a configuration file.
This model allows users to define custom terms directly in configuration
files (e.g., YAML) which can then be loaded into the TermRegistry.
It mirrors the parameters of `TermRegistry.add_custom_term`.
Attributes
----------
key : str
The unique key (alias) that will be used to register and
reference this term.
label : str, optional
The optional display label for the term. Defaults to `key`
if not provided during registration.
name : str, optional
The optional formal name for the term. Defaults to `key`
if not provided during registration.
uri : str, optional
The optional URI identifying the term (e.g., from a standard
vocabulary).
definition : str, optional
The optional textual definition of the term. Defaults to
"Unknown" if not provided during registration.
"""
key: str
label: Optional[str] = None
name: Optional[str] = None
uri: Optional[str] = None
definition: Optional[str] = None
class TermConfig(BaseModel):
"""Pydantic schema for loading a list of term definitions from config.
This model typically corresponds to a section in a configuration file
(e.g., YAML) containing a list of term definitions to be registered.
Attributes
----------
terms : list[TermInfo]
A list of TermInfo objects, each defining a term to be
registered. Defaults to an empty list.
Examples
--------
Example YAML structure:
```yaml
terms:
- key: species
uri: dwc:scientificName
label: Scientific Name
- key: my_custom_term
name: My Custom Term
definition: Describes a specific project attribute.
# ... more TermInfo definitions
```
"""
terms: List[TermInfo] = Field(default_factory=list)
def load_terms_from_config(
path: data.PathLike,
field: Optional[str] = None,
term_registry: TermRegistry = term_registry,
) -> Dict[str, data.Term]:
"""Loads term definitions from a configuration file and registers them.
Parses a configuration file (e.g., YAML) using the TermConfig schema,
extracts the list of TermInfo definitions, and adds each one as a
custom term to the specified TermRegistry instance.
Parameters
----------
path : data.PathLike
The path to the configuration file.
field : str, optional
Optional dot-separated key indicating the section of the config
file where the 'terms' list is located. If None, the top level of
the file is expected to match the TermConfig schema.
term_registry : TermRegistry, optional
The TermRegistry instance to add the loaded terms to. Defaults to
the global `term_registry`.
Returns
-------
dict[str, soundevent.data.Term]
A dictionary mapping the keys of the newly added terms to their
corresponding Term objects.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the TermConfig schema.
KeyError
If a term key loaded from the config conflicts with a key
already present in the registry.
"""
config = load_config(path, schema=TermConfig, field=field)
return {
info.key: term_registry.add_custom_term(
info.key,
name=info.name,
uri=info.uri,
label=info.label,
definition=info.definition,
)
for info in config.terms
}
def register_term(
key: str, term: data.Term, registry: TermRegistry = term_registry
) -> None:
"""Register an existing `Term` object under `key` in the given registry."""
registry.add_term(key, term)


@ -0,0 +1,699 @@
import importlib
from functools import partial
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Union,
)
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
get_term_from_key,
)
from batdetect2.targets.terms import (
term_registry as default_term_registry,
)
__all__ = [
"DerivationRegistry",
"DeriveTagRule",
"MapValueRule",
"ReplaceRule",
"SoundEventTransformation",
"TransformConfig",
"build_transform_from_rule",
"build_transformation_from_config",
"derivation_registry",
"get_derivation",
"load_transformation_config",
"load_transformation_from_config",
"register_derivation",
]
SoundEventTransformation = Callable[
[data.SoundEventAnnotation], data.SoundEventAnnotation
]
"""Type alias for a sound event transformation function.
A function that accepts a sound event annotation object and returns a
(potentially) modified sound event annotation object. Transformations
should generally return a copy of the annotation rather than modifying
it in place.
"""
Derivation = Callable[[str], str]
"""Type alias for a derivation function.
A function that accepts a single string (typically a tag value) and returns
a new string (the derived value).
"""
class MapValueRule(BaseConfig):
"""Configuration for mapping specific values of a source term.
This rule replaces tags matching a specific term and one of the
original values with a new tag (potentially having a different term)
containing the corresponding replacement value. Useful for standardizing
or grouping tag values.
Attributes
----------
rule_type : Literal["map_value"]
Discriminator field identifying this rule type.
source_term_key : str
The key (registered in `TermRegistry`) of the term whose tags' values
should be checked against the `value_mapping`.
value_mapping : Dict[str, str]
A dictionary mapping original string values to replacement string
values. Only tags whose value is a key in this dictionary will be
affected.
target_term_key : str, optional
The key (registered in `TermRegistry`) for the term of the *output*
tag. If None (default), the output tag uses the same term as the
source (`source_term_key`). If provided, the term of the affected
tag is changed to this target term upon replacement.
"""
rule_type: Literal["map_value"] = "map_value"
source_term_key: str
value_mapping: Dict[str, str]
target_term_key: Optional[str] = None
def map_value_transform(
sound_event_annotation: data.SoundEventAnnotation,
source_term: data.Term,
target_term: data.Term,
mapping: Dict[str, str],
) -> data.SoundEventAnnotation:
"""Apply a value mapping transformation to an annotation's tags.
Iterates through the annotation's tags. If a tag matches the `source_term`
and its value is found in the `mapping`, it is replaced by a new tag with
the `target_term` and the mapped value. Other tags are kept unchanged.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to transform.
source_term : data.Term
The term of tags whose values should be mapped.
target_term : data.Term
The term to use for the newly created tags after mapping.
mapping : Dict[str, str]
The dictionary mapping original values to new values.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the transformed tags.
"""
tags = []
for tag in sound_event_annotation.tags:
if tag.term != source_term or tag.value not in mapping:
tags.append(tag)
continue
new_value = mapping[tag.value]
tags.append(data.Tag(term=target_term, value=new_value))
return sound_event_annotation.model_copy(update=dict(tags=tags))
class DeriveTagRule(BaseConfig):
"""Configuration for deriving a new tag from an existing tag's value.
This rule applies a specified function (`derivation_function`) to the
value of tags matching the `source_term_key`. It then adds a new tag
with the `target_term_key` and the derived value.
Attributes
----------
rule_type : Literal["derive_tag"]
Discriminator field identifying this rule type.
source_term_key : str
The key (registered in `TermRegistry`) of the term whose tag values
will be used as input to the derivation function.
derivation_function : str
The name/key identifying the derivation function to use. This can be
a key registered in the `DerivationRegistry` or, if
`import_derivation` is True, a full Python path like
`'my_module.my_submodule.my_function'`.
target_term_key : str, optional
The key (registered in `TermRegistry`) for the term of the new tag
that will be created with the derived value. If None (default), the
derived tag uses the same term as the source (`source_term_key`),
effectively performing an in-place value transformation.
import_derivation : bool, default=False
If True, treat `derivation_function` as a Python import path and
attempt to dynamically import it if not found in the registry.
Requires the function to be accessible in the Python environment.
keep_source : bool, default=True
If True, the original source tag (whose value was used for derivation)
is kept in the annotation's tag list alongside the newly derived tag.
If False, the original source tag is removed.
"""
rule_type: Literal["derive_tag"] = "derive_tag"
source_term_key: str
derivation_function: str
target_term_key: Optional[str] = None
import_derivation: bool = False
keep_source: bool = True
def derivation_tag_transform(
sound_event_annotation: data.SoundEventAnnotation,
source_term: data.Term,
target_term: data.Term,
derivation: Derivation,
keep_source: bool = True,
) -> data.SoundEventAnnotation:
"""Apply a derivation transformation to an annotation's tags.
Iterates through the annotation's tags. For each tag matching the
`source_term`, its value is passed to the `derivation` function.
A new tag is created with the `target_term` and the derived value,
and added to the output tag list. The original source tag is kept
or discarded based on `keep_source`. Other tags are kept unchanged.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to transform.
source_term : data.Term
The term of tags whose values serve as input for the derivation.
target_term : data.Term
The term to use for the newly created derived tags.
derivation : Derivation
The function to apply to the source tag's value.
keep_source : bool, default=True
Whether to keep the original source tag in the output.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the transformed tags (including derived
ones).
"""
tags = []
for tag in sound_event_annotation.tags:
if tag.term != source_term:
tags.append(tag)
continue
if keep_source:
tags.append(tag)
new_value = derivation(tag.value)
tags.append(data.Tag(term=target_term, value=new_value))
return sound_event_annotation.model_copy(update=dict(tags=tags))
class ReplaceRule(BaseConfig):
"""Configuration for exactly replacing one specific tag with another.
This rule looks for an exact match of the `original` tag (both term and
value) and replaces it with the specified `replacement` tag.
Attributes
----------
rule_type : Literal["replace"]
Discriminator field identifying this rule type.
original : TagInfo
The exact tag to search for, defined using its value and term key.
replacement : TagInfo
The tag to substitute in place of the original tag, defined using
its value and term key.
"""
rule_type: Literal["replace"] = "replace"
original: TagInfo
replacement: TagInfo
def replace_tag_transform(
sound_event_annotation: data.SoundEventAnnotation,
source: data.Tag,
target: data.Tag,
) -> data.SoundEventAnnotation:
"""Apply an exact tag replacement transformation.
Iterates through the annotation's tags. If a tag exactly matches the
`source` tag, it is replaced by the `target` tag. Other tags are kept
unchanged.
Parameters
----------
sound_event_annotation : data.SoundEventAnnotation
The annotation to transform.
source : data.Tag
The exact tag to find and replace.
target : data.Tag
The tag to replace the source tag with.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the replaced tag (if found).
"""
tags = []
for tag in sound_event_annotation.tags:
if tag == source:
tags.append(target)
else:
tags.append(tag)
return sound_event_annotation.model_copy(update=dict(tags=tags))
class TransformConfig(BaseConfig):
"""Configuration model for defining a sequence of transformation rules.
Attributes
----------
rules : List[Union[ReplaceRule, MapValueRule, DeriveTagRule]]
A list of transformation rules to apply. The rules are applied
sequentially in the order they appear in the list. The output of
one rule becomes the input for the next. The `rule_type` field
discriminates between the different rule models.
"""
rules: List[
Annotated[
Union[ReplaceRule, MapValueRule, DeriveTagRule],
Field(discriminator="rule_type"),
]
] = Field(
default_factory=list,
)
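To illustrate how the three rule types combine, here is a hedged configuration sketch. The term keys (`species`, `genus`) and the derivation name (`to_genus`) are placeholders that would need to be registered before the rules are built, and the module path in the import is assumed:

```python
from batdetect2.targets.terms import TagInfo
from batdetect2.targets.transform import (  # assumed module path
    DeriveTagRule,
    MapValueRule,
    ReplaceRule,
    TransformConfig,
)

config = TransformConfig(
    rules=[
        # Collapse shorthand spellings into a canonical species name.
        MapValueRule(
            source_term_key="species",
            value_mapping={"Pippip": "Pipistrellus pipistrellus"},
        ),
        # Derive a genus tag from the species tag via a registered function.
        DeriveTagRule(
            source_term_key="species",
            derivation_function="to_genus",
            target_term_key="genus",
        ),
        # Replace one exact tag with another.
        ReplaceRule(
            original=TagInfo(key="event", value="Echolocation call"),
            replacement=TagInfo(key="event", value="Echolocation"),
        ),
    ]
)
```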
class DerivationRegistry(Mapping[str, Derivation]):
"""A registry for managing named derivation functions.
Derivation functions are callables that take a string value and return
a transformed string value, used by `DeriveTagRule`. This registry
allows functions to be registered with a key and retrieved later.
"""
def __init__(self):
"""Initialize an empty DerivationRegistry."""
self._derivations: Dict[str, Derivation] = {}
def __getitem__(self, key: str) -> Derivation:
"""Retrieve a derivation function by key."""
return self._derivations[key]
def __len__(self) -> int:
"""Return the number of registered derivation functions."""
return len(self._derivations)
def __iter__(self):
"""Return an iterator over the keys of registered functions."""
return iter(self._derivations)
def register(self, key: str, derivation: Derivation) -> None:
"""Register a derivation function with a unique key.
Parameters
----------
key : str
The unique key to associate with the derivation function.
derivation : Derivation
The callable derivation function (takes str, returns str).
Raises
------
KeyError
If a derivation function with the same key is already registered.
"""
if key in self._derivations:
raise KeyError(
f"A derivation with the provided key {key} already exists"
)
self._derivations[key] = derivation
def get_derivation(self, key: str) -> Derivation:
"""Retrieve a derivation function by its registered key.
Parameters
----------
key : str
The key of the derivation function to retrieve.
Returns
-------
Derivation
The requested derivation function.
Raises
------
KeyError
If no derivation function with the specified key is registered.
"""
try:
return self._derivations[key]
except KeyError as err:
raise KeyError(
f"No derivation with key {key} is registered."
) from err
def get_keys(self) -> List[str]:
"""Get a list of all registered derivation function keys.
Returns
-------
List[str]
The keys of all registered functions.
"""
return list(self._derivations.keys())
def get_derivations(self) -> List[Derivation]:
"""Get a list of all registered derivation functions.
Returns
-------
List[Derivation]
The registered derivation function objects.
"""
return list(self._derivations.values())
derivation_registry = DerivationRegistry()
"""Global instance of the DerivationRegistry.
Register custom derivation functions here to make them available by key
in `DeriveTagRule` configuration.
"""
def get_derivation(
key: str,
import_derivation: bool = False,
registry: DerivationRegistry = derivation_registry,
):
"""Retrieve a derivation function by key, optionally importing it.
First attempts to find the function in the provided `registry`.
If not found and `import_derivation` is True, attempts to dynamically
import the function using the `key` as a full Python path
(e.g., 'my_module.submodule.my_func').
Parameters
----------
key : str
The key or Python path of the derivation function.
import_derivation : bool, default=False
If True, attempt dynamic import if key is not in the registry.
registry : DerivationRegistry, optional
The registry instance to check first. Defaults to the global
`derivation_registry`.
Returns
-------
Derivation
The requested derivation function.
Raises
------
KeyError
If the key is not found in the registry and `import_derivation` is
False, or if the dynamic import fails because the module cannot be
imported.
AttributeError
If the module is imported but does not expose the named function.
"""
if not import_derivation or key in registry:
return registry.get_derivation(key)
try:
module_path, func_name = key.rsplit(".", 1)
module = importlib.import_module(module_path)
func = getattr(module, func_name)
return func
except ImportError as err:
raise KeyError(
f"Unable to load derivation '{key}'. Check the path and ensure "
"it points to a valid callable function in an importable module."
) from err
def build_transform_from_rule(
rule: Union[ReplaceRule, MapValueRule, DeriveTagRule],
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventTransformation:
"""Build a specific SoundEventTransformation function from a rule config.
Selects the appropriate transformation logic based on the rule's
`rule_type`, fetches necessary terms and derivation functions, and
returns a partially applied function ready to transform an annotation.
Parameters
----------
rule : Union[ReplaceRule, MapValueRule, DeriveTagRule]
The configuration object for a single transformation rule.
derivation_registry : DerivationRegistry, optional
The derivation registry to use for `DeriveTagRule`. Defaults to the
global `derivation_registry`.
term_registry : TermRegistry, optional
The term registry used to resolve term keys referenced by the rule.
Defaults to the global term registry.
Returns
-------
SoundEventTransformation
A callable that applies the specified rule to a SoundEventAnnotation.
Raises
------
KeyError
If required term keys or derivation keys are not found.
ValueError
If the rule has an unknown `rule_type`.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function fails.
"""
if rule.rule_type == "replace":
source = get_tag_from_info(
rule.original,
term_registry=term_registry,
)
target = get_tag_from_info(
rule.replacement,
term_registry=term_registry,
)
return partial(replace_tag_transform, source=source, target=target)
if rule.rule_type == "derive_tag":
source_term = get_term_from_key(
rule.source_term_key,
term_registry=term_registry,
)
target_term = (
get_term_from_key(
rule.target_term_key,
term_registry=term_registry,
)
if rule.target_term_key
else source_term
)
derivation = get_derivation(
key=rule.derivation_function,
import_derivation=rule.import_derivation,
registry=derivation_registry,
)
return partial(
derivation_tag_transform,
source_term=source_term,
target_term=target_term,
derivation=derivation,
keep_source=rule.keep_source,
)
if rule.rule_type == "map_value":
source_term = get_term_from_key(
rule.source_term_key,
term_registry=term_registry,
)
target_term = (
get_term_from_key(
rule.target_term_key,
term_registry=term_registry,
)
if rule.target_term_key
else source_term
)
return partial(
map_value_transform,
source_term=source_term,
target_term=target_term,
mapping=rule.value_mapping,
)
# Unknown rule type: normally unreachable because Pydantic validates
# `rule_type`, but kept as a defensive fallback.
valid_options = ["replace", "derive_tag", "map_value"]
raise ValueError(
f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. "
f"Valid options are: {valid_options}"
)
def build_transformation_from_config(
config: TransformConfig,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventTransformation:
"""Build a composite transformation function from a TransformConfig.
Creates a sequence of individual transformation functions based on the
rules defined in the configuration. Returns a single function that
applies these transformations sequentially to an annotation.
Parameters
----------
config : TransformConfig
The configuration object containing the list of transformation rules.
derivation_registry : DerivationRegistry, optional
The derivation registry to use when building `DeriveTagRule`
transformations. Defaults to the global `derivation_registry`.
term_registry : TermRegistry, optional
The term registry used to resolve term keys referenced by the rules.
Defaults to the global term registry.
Returns
-------
SoundEventTransformation
A single function that applies all configured transformations in order.
"""
transforms = [
build_transform_from_rule(
rule,
derivation_registry=derivation_registry,
term_registry=term_registry,
)
for rule in config.rules
]
def transformation(
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
for transform in transforms:
sound_event_annotation = transform(sound_event_annotation)
return sound_event_annotation
return transformation
def load_transformation_config(
path: data.PathLike, field: Optional[str] = None
) -> TransformConfig:
"""Load the transformation configuration from a file.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : str, optional
If the transformation configuration is nested under a specific key
in the file, specify the key here. Defaults to None.
Returns
-------
TransformConfig
The loaded and validated transformation configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the TransformConfig schema.
"""
return load_config(path=path, schema=TransformConfig, field=field)
def load_transformation_from_config(
path: data.PathLike,
field: Optional[str] = None,
derivation_registry: DerivationRegistry = derivation_registry,
term_registry: TermRegistry = default_term_registry,
) -> SoundEventTransformation:
"""Load transformation config from a file and build the final function.
This is a convenience function that combines loading the configuration
and building the final callable transformation function that applies
all rules sequentially.
Parameters
----------
path : data.PathLike
Path to the configuration file (YAML).
field : str, optional
If the transformation configuration is nested under a specific key
in the file, specify the key here. Defaults to None.
derivation_registry : DerivationRegistry, optional
The derivation registry used when building derivation-based rules.
Defaults to the global `derivation_registry`.
term_registry : TermRegistry, optional
The term registry used to resolve term keys referenced by the rules.
Defaults to the global term registry.
Returns
-------
SoundEventTransformation
The final composite transformation function ready to be used.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the TransformConfig schema.
KeyError
If required term keys or derivation keys specified in the config
are not found during the build process.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function specified in the config
fails.
"""
config = load_transformation_config(path=path, field=field)
return build_transformation_from_config(
config,
derivation_registry=derivation_registry,
term_registry=term_registry,
)
def register_derivation(
key: str,
derivation: Derivation,
derivation_registry: DerivationRegistry = derivation_registry,
) -> None:
"""Register a new derivation function in the global registry.
Parameters
----------
key : str
The unique key to associate with the derivation function.
derivation : Derivation
The callable derivation function (takes str, returns str).
derivation_registry : DerivationRegistry, optional
The registry instance to register the derivation function with.
Defaults to the global `derivation_registry`.
Raises
------
KeyError
If a derivation function with the same key is already registered.
"""
derivation_registry.register(key, derivation)
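Putting the module together end to end: a sketch that registers a derivation, declares a rule, and builds the composite transformation. Everything named here (`species`, `genus`, `to_genus`) is hypothetical, a fresh `TermRegistry` is used to avoid clashing with the defaults, and the import path is assumed:

```python
from batdetect2.targets.terms import TermRegistry
from batdetect2.targets.transform import (  # assumed module path
    DeriveTagRule,
    TransformConfig,
    build_transformation_from_config,
    register_derivation,
)

# Use a fresh registry so the hypothetical keys cannot collide with the
# defaults; pass it explicitly to the build step below.
registry = TermRegistry()
registry.add_custom_term("species", label="Species")
registry.add_custom_term("genus", label="Genus")

# A derivation is just a str -> str callable; "to_genus" is a made-up key.
register_derivation("to_genus", lambda value: value.split(" ")[0])

config = TransformConfig(
    rules=[
        DeriveTagRule(
            source_term_key="species",
            derivation_function="to_genus",
            target_term_key="genus",
        )
    ]
)

transform = build_transformation_from_config(config, term_registry=registry)
# `transform` maps a SoundEventAnnotation to a copy whose species tags are
# accompanied by a derived genus tag.
```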

batdetect2/targets/types.py Normal file

@ -0,0 +1,233 @@
"""Defines the core interface (Protocol) for the target definition pipeline.
This module specifies the standard structure, attributes, and methods expected
from an object that encapsulates the complete configured logic for processing
sound event annotations within the `batdetect2.targets` system.
The main component defined here is the `TargetProtocol`. This protocol acts as
a contract for the entire target definition process, covering semantic aspects
(filtering, tag transformation, class encoding/decoding) as well as geometric
aspects (mapping regions of interest to target positions and sizes). It ensures
that components responsible for these tasks can be interacted with consistently
throughout BatDetect2.
"""
from typing import List, Optional, Protocol
import numpy as np
from soundevent import data
__all__ = [
"TargetProtocol",
]
class TargetProtocol(Protocol):
"""Protocol defining the interface for the target definition pipeline.
This protocol outlines the standard attributes and methods for an object
that encapsulates the complete, configured process for handling sound event
annotations (both tags and geometry). It defines how to:
- Filter relevant annotations.
- Transform annotation tags.
- Encode an annotation into a specific target class name.
- Decode a class name back into representative tags.
- Extract a target reference position from an annotation's geometry (ROI).
- Calculate target size dimensions from an annotation's geometry.
- Recover an approximate geometry (ROI) from a position and size
dimensions.
Implementations of this protocol bundle all configured logic for these
steps.
Attributes
----------
class_names : List[str]
An ordered list of the unique names of the specific target classes
defined by the configuration.
generic_class_tags : List[data.Tag]
A list of `soundevent.data.Tag` objects representing the configured
generic class category (e.g., used when no specific class matches).
dimension_names : List[str]
A list containing the names of the size dimensions returned by
`get_size` and expected by `recover_roi` (e.g., ['width', 'height']).
"""
class_names: List[str]
"""Ordered list of unique names for the specific target classes."""
generic_class_tags: List[data.Tag]
"""List of tags representing the generic (unclassified) category."""
dimension_names: List[str]
"""Names of the size dimensions (e.g., ['width', 'height'])."""
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the filter to a sound event annotation.
Determines if the annotation should be included in further processing
and training based on the configured filtering rules.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation to filter.
Returns
-------
bool
True if the annotation should be kept (passes the filter),
False otherwise. Implementations should return True if no
filtering is configured.
"""
...
def transform(
self,
sound_event: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
"""Apply tag transformations to an annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation whose tags should be transformed.
Returns
-------
data.SoundEventAnnotation
A new annotation object with the transformed tags. Implementations
should return the original annotation object if no transformations
were configured.
"""
...
def encode(
self,
sound_event: data.SoundEventAnnotation,
) -> Optional[str]:
"""Encode a sound event annotation to its target class name.
Parameters
----------
sound_event : data.SoundEventAnnotation
The (potentially filtered and transformed) annotation to encode.
Returns
-------
str or None
The string name of the matched target class if the annotation
matches a specific class definition. Returns None if the annotation
does not match any specific class rule (indicating it may belong
to a generic category or should be handled differently downstream).
"""
...
def decode(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags.
Parameters
----------
class_label : str
The class name string (e.g., predicted by a model) to decode.
Returns
-------
List[data.Tag]
The list of tags corresponding to the input class name according
to the configuration. May return an empty list or raise an error
for unmapped labels, depending on the implementation's configuration
(e.g., `raise_on_unmapped` flag during building).
Raises
------
ValueError, KeyError
Implementations might raise an error if the `class_label` is not
found in the configured mapping and error raising is enabled.
"""
...
def get_position(
self, sound_event: data.SoundEventAnnotation
) -> tuple[float, float]:
"""Extract the target reference position from the annotation's geometry.
Calculates the `(time, frequency)` coordinate representing the primary
location of the sound event.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI) to process.
Returns
-------
Tuple[float, float]
The calculated reference position `(time, frequency)`.
Raises
------
ValueError
If the annotation lacks geometry or if the position cannot be
calculated for the geometry type or configured reference point.
"""
...
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
"""Calculate the target size dimensions from the annotation's geometry.
Computes the relevant physical size (e.g., duration/width,
bandwidth/height from a bounding box) to produce
the numerical target values expected by the model.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI) to process.
Returns
-------
np.ndarray
A NumPy array containing the size dimensions, matching the
order specified by the `dimension_names` attribute (e.g.,
`[width, height]`).
Raises
------
ValueError
If the annotation lacks geometry or if the size cannot be computed.
TypeError
If geometry type is unsupported.
"""
...
def recover_roi(
self, pos: tuple[float, float], dims: np.ndarray
) -> data.Geometry:
"""Recover the ROI geometry from a position and dimensions.
Performs the inverse mapping of `get_position` and `get_size`. It takes
a reference position `(time, frequency)` and an array of size
dimensions and reconstructs an approximate geometric representation.
Parameters
----------
pos : Tuple[float, float]
The reference position `(time, frequency)`.
dims : np.ndarray
The NumPy array containing the dimensions (e.g., predicted
by the model), corresponding to the order in `dimension_names`.
Returns
-------
soundevent.data.Geometry
The reconstructed geometry.
Raises
------
ValueError
If the number of provided `dims` does not match `dimension_names`,
if dimensions are invalid (e.g., negative after unscaling), or
if reconstruction fails based on the configured position type.
"""
...
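A brief sketch of how downstream code can lean on the protocol without knowing the concrete implementation; the helper function below is illustrative, not part of BatDetect2:

```python
from typing import Optional, Tuple

import numpy as np
from soundevent import data

from batdetect2.targets.types import TargetProtocol


def encode_annotation(
    targets: TargetProtocol,
    annotation: data.SoundEventAnnotation,
) -> Optional[Tuple[Optional[str], Tuple[float, float], np.ndarray]]:
    """Run one annotation through the semantic and geometric steps."""
    if not targets.filter(annotation):
        return None  # dropped by the configured filter
    annotation = targets.transform(annotation)
    class_name = targets.encode(annotation)  # None -> generic class
    position = targets.get_position(annotation)
    size = targets.get_size(annotation)
    return class_name, position, size
```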


@ -1,88 +0,0 @@
from inspect import getmembers
from typing import Optional
from pydantic import BaseModel
from soundevent import data, terms
__all__ = [
"call_type",
"individual",
"get_term_from_info",
"get_tag_from_info",
"TermInfo",
"TagInfo",
]
class TermInfo(BaseModel):
label: Optional[str]
name: Optional[str]
uri: Optional[str]
class TagInfo(BaseModel):
value: str
term: Optional[TermInfo] = None
key: Optional[str] = None
label: Optional[str] = None
call_type = data.Term(
name="soundevent:call_type",
label="Call Type",
definition="A broad categorization of animal vocalizations based on their intended function or purpose (e.g., social, distress, mating, territorial, echolocation).",
)
individual = data.Term(
name="soundevent:individual",
label="Individual",
definition="An id for an individual animal. In the context of bioacoustic annotation, this term is used to label vocalizations that are attributed to a specific individual.",
)
ALL_TERMS = [
*getmembers(terms, lambda x: isinstance(x, data.Term)),
call_type,
individual,
]
def get_term_from_info(term_info: TermInfo) -> data.Term:
for term in ALL_TERMS:
if term_info.name and term_info.name == term.name:
return term
if term_info.label and term_info.label == term.label:
return term
if term_info.uri and term_info.uri == term.uri:
return term
if term_info.name is None:
if term_info.label is None:
raise ValueError("At least one of name or label must be provided.")
term_info.name = (
f"soundevent:{term_info.label.lower().replace(' ', '_')}"
)
if term_info.label is None:
term_info.label = term_info.name
return data.Term(
name=term_info.name,
label=term_info.label,
uri=term_info.uri,
definition="Unknown",
)
def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
if tag_info.term:
term = get_term_from_info(tag_info.term)
elif tag_info.key:
term = data.term_from_key(tag_info.key)
else:
raise ValueError("Either term or key must be provided in tag info.")
return data.Tag(term=term, value=tag_info.value)


@ -1,53 +1,54 @@
from batdetect2.train.augmentations import (
AugmentationsConfig,
EchoAugmentationConfig,
FrequencyMaskAugmentationConfig,
TimeMaskAugmentationConfig,
VolumeAugmentationConfig,
WarpAugmentationConfig,
add_echo,
augment_example,
load_agumentation_config,
build_augmentations,
mask_frequency,
mask_time,
mix_examples,
scale_volume,
select_subclip,
warp_spectrogram,
)
from batdetect2.train.clips import build_clipper, select_subclip
from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.dataset import (
LabeledDataset,
SubclipConfig,
RandomExampleSource,
TrainExample,
get_preprocessed_files,
list_preprocessed_files,
)
from batdetect2.train.labels import LabelConfig, load_label_config
from batdetect2.train.labels import load_label_config
from batdetect2.train.losses import LossFunction, build_loss
from batdetect2.train.preprocess import (
generate_train_example,
preprocess_annotations,
)
from batdetect2.train.targets import (
TagInfo,
TargetConfig,
build_target_encoder,
load_target_config,
)
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
__all__ = [
"AugmentationsConfig",
"LabelConfig",
"EchoAugmentationConfig",
"FrequencyMaskAugmentationConfig",
"LabeledDataset",
"SubclipConfig",
"TagInfo",
"TargetConfig",
"LossFunction",
"RandomExampleSource",
"TimeMaskAugmentationConfig",
"TrainExample",
"TrainerConfig",
"TrainingConfig",
"VolumeAugmentationConfig",
"WarpAugmentationConfig",
"add_echo",
"augment_example",
"build_target_encoder",
"build_augmentations",
"build_clipper",
"build_loss",
"generate_train_example",
"get_preprocessed_files",
"load_agumentation_config",
"list_preprocessed_files",
"load_label_config",
"load_target_config",
"load_train_config",
"load_trainer_config",
"mask_frequency",

File diff suppressed because it is too large.


@ -0,0 +1,55 @@
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from torch.utils.data import DataLoader
from batdetect2.postprocess import PostprocessorProtocol
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.types import ModelOutput
class ValidationMetrics(Callback):
def __init__(self, postprocessor: PostprocessorProtocol):
super().__init__()
self.postprocessor = postprocessor
self.predictions = []
def on_validation_epoch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
) -> None:
self.predictions = []
return super().on_validation_epoch_start(trainer, pl_module)
def on_validation_batch_end( # type: ignore
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: ModelOutput,
batch: TrainExample,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
dataloaders = trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, LabeledDataset)
clip_annotation = dataset.get_clip_annotation(batch_idx)
# clip_prediction = postprocess_model_outputs(
# outputs,
# clips=[clip_annotation.clip],
# classes=self.class_names,
# decoder=self.decoder,
# config=self.config.postprocessing,
# )[0]
#
# matches = match_predictions_and_annotations(
# clip_annotation,
# clip_prediction,
# )
#
# self.validation_predictions.extend(matches)
# return super().on_validation_batch_end(
# trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
# )

batdetect2/train/clips.py Normal file

@ -0,0 +1,184 @@
from typing import Optional, Tuple, Union
import numpy as np
import xarray as xr
from soundevent import arrays
from batdetect2.configs import BaseConfig
from batdetect2.train.types import ClipperProtocol
DEFAULT_TRAIN_CLIP_DURATION = 0.513
DEFAULT_MAX_EMPTY_CLIP = 0.1
class ClipingConfig(BaseConfig):
"""Configuration for extracting training clips from preprocessed examples."""
duration: float = DEFAULT_TRAIN_CLIP_DURATION
random: bool = True
max_empty: float = DEFAULT_MAX_EMPTY_CLIP
class Clipper(ClipperProtocol):
"""Extract fixed-duration clips, optionally starting at a random offset."""
def __init__(
self,
duration: float = 0.5,
max_empty: float = 0.2,
random: bool = True,
):
self.duration = duration
self.random = random
self.max_empty = max_empty
def extract_clip(
self, example: xr.Dataset
) -> Tuple[xr.Dataset, float, float]:
step = arrays.get_dim_step(
example.spectrogram,
dim=arrays.Dimensions.time.value,
)
duration = (
arrays.get_dim_width(
example.spectrogram,
dim=arrays.Dimensions.time.value,
)
+ step
)
start_time = 0
if self.random:
start_time = np.random.uniform(
-self.max_empty,
duration - self.duration + self.max_empty,
)
subclip = select_subclip(
example,
start=start_time,
span=self.duration,
dim="time",
)
return (
select_subclip(
subclip,
start=start_time,
span=self.duration,
dim="audio_time",
),
start_time,
start_time + self.duration,
)
def build_clipper(config: Optional[ClipingConfig] = None) -> ClipperProtocol:
config = config or ClipingConfig()
return Clipper(
duration=config.duration,
max_empty=config.max_empty,
random=config.random,
)
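A minimal construction sketch using only the config and factory defined above:

```python
from batdetect2.train.clips import ClipingConfig, build_clipper

clipper = build_clipper(ClipingConfig(duration=0.513, random=True))
# clipper.extract_clip(example) returns (clipped_dataset, start_time, end_time)
# for an xr.Dataset that holds "spectrogram" (and optionally audio) variables.
```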
def select_subclip(
dataset: xr.Dataset,
span: float,
start: float,
fill_value: float = 0,
dim: str = "time",
) -> xr.Dataset:
"""Select a `span`-long window along `dim` starting at `start`.
Parts of the window that fall outside the dataset are padded with
`fill_value`; variables that do not carry `dim` are passed through
unchanged.
"""
width = _compute_expected_width(
dataset, # type: ignore
span,
dim=dim,
)
coord = dataset.coords[dim]
if len(coord) == width:
return dataset
new_coords, start_pad, end_pad, dim_slice = _extract_coordinate(
coord, start, span
)
data_vars = {}
for name, data_array in dataset.data_vars.items():
if dim not in data_array.dims:
data_vars[name] = data_array
continue
if width == data_array.sizes[dim]:
data_vars[name] = data_array
continue
sliced = data_array.isel({dim: dim_slice}).data
if start_pad > 0 or end_pad > 0:
padding = [
[0, 0] if other_dim != dim else [start_pad, end_pad]
for other_dim in data_array.dims
]
sliced = np.pad(sliced, padding, constant_values=fill_value)
data_vars[name] = xr.DataArray(
data=sliced,
dims=data_array.dims,
coords={**data_array.coords, dim: new_coords},
attrs=data_array.attrs,
)
return xr.Dataset(data_vars=data_vars, attrs=dataset.attrs)
def _extract_coordinate(
coord: xr.DataArray,
start: float,
span: float,
) -> Tuple[xr.Variable, int, int, slice]:
step = arrays.get_dim_step(coord, str(coord.name))
current_width = len(coord)
expected_width = int(np.floor(span / step))
coord_start = float(coord[0])
offset = start - coord_start
start_index = int(np.floor(offset / step))
end_index = start_index + expected_width
if start_index > current_width:
raise ValueError("Requested span does not overlap with current range")
if end_index < 0:
raise ValueError("Requested span does not overlap with current range")
corrected_start = float(start_index * step)
corrected_end = float(end_index * step)
start_index_offset = max(0, -start_index)
end_index_offset = max(0, end_index - current_width)
sl = slice(
start_index if start_index >= 0 else None,
end_index if end_index < current_width else None,
)
return (
arrays.create_range_dim(
str(coord.name),
start=corrected_start,
stop=corrected_end,
step=step,
),
start_index_offset,
end_index_offset,
sl,
)
def _compute_expected_width(
array: Union[xr.DataArray, xr.Dataset],
duration: float,
dim: str,
) -> int:
step = arrays.get_dim_step(array, dim) # type: ignore
return int(np.floor(duration / step))


@ -4,6 +4,11 @@ from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
)
from batdetect2.train.clips import ClipingConfig
from batdetect2.train.losses import LossConfig
__all__ = [
@ -20,9 +25,17 @@ class OptimizerConfig(BaseConfig):
class TrainingConfig(BaseConfig):
batch_size: int = 32
loss: LossConfig = Field(default_factory=LossConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
)
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
def load_train_config(
path: PathLike,


@ -1,87 +1,40 @@
import os
from pathlib import Path
from typing import NamedTuple, Optional, Sequence, Union
from typing import List, Optional, Sequence, Tuple
import numpy as np
import torch
import xarray as xr
from pydantic import Field
from soundevent import data
from torch.utils.data import Dataset
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.tensors import adjust_width
from batdetect2.train.augmentations import (
AugmentationsConfig,
augment_example,
select_subclip,
)
from batdetect2.train.preprocess import PreprocessingConfig
from batdetect2.train.augmentations import Augmentation
from batdetect2.train.types import ClipperProtocol, TrainExample
__all__ = [
"TrainExample",
"LabeledDataset",
]
PathLike = Union[Path, str, os.PathLike]
class TrainExample(NamedTuple):
spec: torch.Tensor
detection_heatmap: torch.Tensor
class_heatmap: torch.Tensor
size_heatmap: torch.Tensor
idx: torch.Tensor
class SubclipConfig(BaseConfig):
duration: Optional[float] = None
width: int = 512
random: bool = False
class DatasetConfig(BaseConfig):
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
augmentation: AugmentationsConfig = Field(
default_factory=AugmentationsConfig
)
class LabeledDataset(Dataset):
def __init__(
self,
filenames: Sequence[PathLike],
subclip: Optional[SubclipConfig] = None,
augmentation: Optional[AugmentationsConfig] = None,
preprocessing: Optional[PreprocessingConfig] = None,
filenames: Sequence[data.PathLike],
clipper: ClipperProtocol,
augmentation: Optional[Augmentation] = None,
):
self.filenames = filenames
self.subclip = subclip
self.clipper = clipper
self.augmentation = augmentation
self.preprocessing = preprocessing or PreprocessingConfig()
def __len__(self):
return len(self.filenames)
def __getitem__(self, idx) -> TrainExample:
dataset = self.get_dataset(idx)
if self.subclip:
dataset = select_subclip(
dataset,
duration=self.subclip.duration,
width=self.subclip.width,
random=self.subclip.random,
)
dataset, start_time, end_time = self.clipper.extract_clip(dataset)
if self.augmentation:
dataset = augment_example(
dataset,
self.augmentation,
preprocessing_config=self.preprocessing,
others=self.get_random_example,
)
dataset = self.augmentation(dataset)
return TrainExample(
spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
@ -89,37 +42,31 @@ class LabeledDataset(Dataset):
class_heatmap=self.to_tensor(dataset["class"]),
size_heatmap=self.to_tensor(dataset["size"]),
idx=torch.tensor(idx),
start_time=start_time,
end_time=end_time,
)
@classmethod
def from_directory(
cls,
directory: PathLike,
directory: data.PathLike,
clipper: ClipperProtocol,
extension: str = ".nc",
subclip: Optional[SubclipConfig] = None,
augmentation: Optional[AugmentationsConfig] = None,
preprocessing: Optional[PreprocessingConfig] = None,
augmentation: Optional[Augmentation] = None,
):
return cls(
get_preprocessed_files(directory, extension),
subclip=subclip,
filenames=list_preprocessed_files(directory, extension),
clipper=clipper,
augmentation=augmentation,
preprocessing=preprocessing,
)
def get_random_example(self) -> xr.Dataset:
def get_random_example(self) -> Tuple[xr.Dataset, float, float]:
idx = np.random.randint(0, len(self))
dataset = self.get_dataset(idx)
if self.subclip:
dataset = select_subclip(
dataset,
duration=self.subclip.duration,
width=self.subclip.width,
random=self.subclip.random,
)
dataset, start_time, end_time = self.clipper.extract_clip(dataset)
return dataset
return dataset, start_time, end_time
def get_dataset(self, idx) -> xr.Dataset:
return xr.open_dataset(self.filenames[idx])
@ -134,16 +81,23 @@ class LabeledDataset(Dataset):
array: xr.DataArray,
dtype=np.float32,
) -> torch.Tensor:
tensor = torch.tensor(array.values.astype(dtype))
if not self.subclip:
return tensor
width = self.subclip.width
return adjust_width(tensor, width)
return torch.tensor(array.values.astype(dtype))
def get_preprocessed_files(
directory: PathLike, extension: str = ".nc"
def list_preprocessed_files(
directory: data.PathLike, extension: str = ".nc"
) -> Sequence[Path]:
return list(Path(directory).glob(f"*{extension}"))
class RandomExampleSource:
def __init__(self, filenames: List[str], clipper: ClipperProtocol):
self.filenames = filenames
self.clipper = clipper
def __call__(self):
index = int(np.random.randint(len(self.filenames)))
filename = self.filenames[index]
dataset = xr.open_dataset(filename)
example, _, _ = self.clipper.extract_clip(dataset)
return example
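A hedged usage sketch of the reworked dataset API shown above; `my_clipper` is a placeholder for any object implementing `ClipperProtocol` (not a name from this diff):

```python
# Assumed: `my_clipper` implements ClipperProtocol and is built elsewhere.
dataset = LabeledDataset.from_directory("train_examples/", clipper=my_clipper)
example = dataset[0]  # TrainExample: spectrogram, heatmaps, idx, clip start/end times

# Random clips from the same files, e.g. as the `others` source for mixing-style
# augmentations.
random_source = RandomExampleSource(
    [str(f) for f in list_preprocessed_files("train_examples/")],
    clipper=my_clipper,
)
other_clip = random_source()  # an xr.Dataset clipped by the same clipper
```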

View File

@ -1,45 +1,216 @@
"""Generate heatmap training targets for BatDetect2 models.
This module is responsible for creating the target labels used for training
BatDetect2 models. It converts sound event annotations for an audio clip into
the specific multi-channel heatmap formats required by the neural network.
It uses a pre-configured object adhering to the `TargetProtocol` (from
`batdetect2.targets`) which encapsulates all the logic for filtering
annotations, transforming tags, encoding class names, and mapping annotation
geometry (ROIs) to target positions and sizes. This module then focuses on
rendering this information onto the heatmap grids.
The pipeline generates three core outputs for a given spectrogram:
1. **Detection Heatmap**: Indicates presence/location of relevant sound events.
2. **Class Heatmap**: Indicates location and class identity for specifically
classified events.
3. **Size Heatmap**: Encodes the target dimensions (width, height) of events.
The primary function generated by this module is a `ClipLabeller` (defined in
`.types`), which takes a `ClipAnnotation` object and its corresponding
spectrogram and returns the calculated `Heatmaps` tuple. The main configurable
parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
defined in `LabelConfig`.
"""
import logging
from collections.abc import Iterable
from typing import Callable, List, Optional, Sequence, Tuple
from functools import partial
from typing import Optional
import numpy as np
import xarray as xr
from pydantic import Field
from scipy.ndimage import gaussian_filter
from soundevent import arrays, data, geometry
from soundevent.geometry.operations import Positions
from soundevent import arrays, data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.types import (
ClipLabeller,
Heatmaps,
)
__all__ = [
"HeatmapsConfig",
"LabelConfig",
"build_clip_labeler",
"generate_clip_label",
"generate_heatmaps",
"load_label_config",
]
class HeatmapsConfig(BaseConfig):
position: Positions = "bottom-left"
sigma: float = 3.0
time_scale: float = 1000.0
frequency_scale: float = 1 / 859.375
SIZE_DIMENSION = "dimension"
"""Dimension name for the size heatmap."""
logger = logging.getLogger(__name__)
class LabelConfig(BaseConfig):
heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
"""Configuration parameters for heatmap generation.
Attributes
----------
sigma : float, default=3.0
The standard deviation (in pixels/bins) of the Gaussian kernel applied
to smooth the detection and class heatmaps. Larger values create more
diffuse targets.
"""
sigma: float = 3.0
def build_clip_labeler(
targets: TargetProtocol,
config: Optional[LabelConfig] = None,
) -> ClipLabeller:
"""Construct the final clip labelling function.
This factory function prepares the callable that will perform the
end-to-end heatmap generation for a given clip and spectrogram during
training data loading. It takes the fully configured `targets` object and
the `LabelConfig` and binds them to the `generate_clip_label` function.
Parameters
----------
targets : TargetProtocol
An initialized object conforming to the `TargetProtocol`, providing all
necessary methods for filtering, transforming, encoding, and ROI
mapping.
config : LabelConfig, optional
Configuration object containing heatmap generation parameters. If None,
a default `LabelConfig` is used.
Returns
-------
ClipLabeller
A function that accepts a `data.ClipAnnotation` and `xr.DataArray`
(spectrogram) and returns the generated `Heatmaps`.
"""
return partial(
generate_clip_label,
targets=targets,
config=config or LabelConfig(),
)
def generate_clip_label(
clip_annotation: data.ClipAnnotation,
spec: xr.DataArray,
targets: TargetProtocol,
config: LabelConfig,
) -> Heatmaps:
"""Generate training heatmaps for a single annotated clip.
This function orchestrates the target generation process for one clip:
1. Filters and transforms sound events using `targets.filter` and
`targets.transform`.
2. Passes the resulting processed annotations, along with the spectrogram,
the `targets` object, and the Gaussian `sigma` from `config`, to the
core `generate_heatmaps` function.
Parameters
----------
clip_annotation : data.ClipAnnotation
The complete annotation data for the audio clip, including the list
of `sound_events` to process.
spec : xr.DataArray
The spectrogram corresponding to the `clip_annotation`. Must have
'time' and 'frequency' dimensions/coordinates.
targets : TargetProtocol
The fully configured target definition object, providing methods for
filtering, transformation, encoding, and ROI mapping.
config : LabelConfig
Configuration object providing heatmap parameters (primarily `sigma`).
Returns
-------
Heatmaps
A NamedTuple containing the generated 'detection', 'classes', and 'size'
heatmaps for this clip.
"""
return generate_heatmaps(
(
targets.transform(sound_event_annotation)
for sound_event_annotation in clip_annotation.sound_events
if targets.filter(sound_event_annotation)
),
spec=spec,
targets=targets,
target_sigma=config.sigma,
)
def generate_heatmaps(
sound_events: Sequence[data.SoundEventAnnotation],
sound_events: Iterable[data.SoundEventAnnotation],
spec: xr.DataArray,
class_names: List[str],
encoder: Callable[[Iterable[data.Tag]], Optional[str]],
targets: TargetProtocol,
target_sigma: float = 3.0,
position: Positions = "bottom-left",
time_scale: float = 1000.0,
frequency_scale: float = 1 / 859.375,
dtype=np.float32,
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
) -> Heatmaps:
"""Generate detection, class, and size heatmaps from sound events.
Creates heatmap representations from an iterable of sound event
annotations. This function relies on the provided `targets` object to get
the reference position, scaled size, and class encoding for each
annotation.
Process:
1. Initializes empty heatmap arrays based on `spec` shape and `targets`
metadata.
2. Iterates through `sound_events`.
3. For each event:
a. Gets geometry. Skips if missing.
b. Gets reference position using `targets.get_position()`. Skips if out
of bounds.
c. Places a peak (1.0) on the detection heatmap at the position.
d. Gets scaled size using `targets.get_size()` and places it on the
size heatmap.
e. Encodes class using `targets.encode()` and places a peak (1.0) on
the corresponding class heatmap layer if a specific class is
returned.
4. Applies Gaussian smoothing (using `target_sigma`) to detection and class
heatmaps.
5. Normalizes detection and class heatmaps to range [0, 1].
Parameters
----------
sound_events : Iterable[data.SoundEventAnnotation]
An iterable of sound event annotations to render onto the heatmaps.
spec : xr.DataArray
The spectrogram array corresponding to the time/frequency range of
the annotations. Used for shape and coordinate information. Must have
'time' and 'frequency' dimensions/coordinates.
targets : TargetProtocol
The fully configured target definition object. Used to access class
names, dimension names, and the methods `get_position`, `get_size`,
`encode`.
target_sigma : float, default=3.0
Standard deviation (in pixels/bins) of the Gaussian kernel applied to
smooth the detection and class heatmaps.
dtype : type, default=np.float32
The data type for the generated heatmap arrays (e.g., `np.float32`).
Returns
-------
Heatmaps
A NamedTuple containing the 'detection', 'classes', and 'size'
xarray DataArrays, ready for use in model training.
Raises
------
ValueError
If the input spectrogram `spec` does not have both 'time' and
'frequency' dimensions, or if `targets.class_names` is empty.
"""
shape = dict(zip(spec.dims, spec.shape))
if "time" not in shape or "frequency" not in shape:
@ -50,18 +221,18 @@ def generate_heatmaps(
# Initialize heatmaps
detection_heatmap = xr.zeros_like(spec, dtype=dtype)
class_heatmap = xr.DataArray(
data=np.zeros((len(class_names), *spec.shape), dtype=dtype),
data=np.zeros((len(targets.class_names), *spec.shape), dtype=dtype),
dims=["category", *spec.dims],
coords={
"category": [*class_names],
"category": [*targets.class_names],
**spec.coords,
},
)
size_heatmap = xr.DataArray(
data=np.zeros((2, *spec.shape), dtype=dtype),
dims=["dimension", *spec.dims],
dims=[SIZE_DIMENSION, *spec.dims],
coords={
"dimension": ["width", "height"],
SIZE_DIMENSION: targets.dimension_names,
**spec.coords,
},
)
@ -69,10 +240,14 @@ def generate_heatmaps(
for sound_event_annotation in sound_events:
geom = sound_event_annotation.sound_event.geometry
if geom is None:
logger.debug(
"Skipping annotation %s: missing geometry.",
sound_event_annotation.uuid,
)
continue
# Get the position of the sound event
time, frequency = geometry.get_geometry_point(geom, position=position)
time, frequency = targets.get_position(sound_event_annotation)
# Set 1.0 at the position of the sound event in the detection heatmap
try:
@ -84,19 +259,15 @@ def generate_heatmaps(
)
except KeyError:
# Skip the sound event if the position is outside the spectrogram
logger.debug(
"Skipping annotation %s: position outside spectrogram. "
"Pos: %s",
sound_event_annotation.uuid,
(time, frequency),
)
continue
# Set the size of the sound event at the position in the size heatmap
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
geom
)
size = np.array(
[
(end_time - start_time) * time_scale,
(high_freq - low_freq) * frequency_scale,
]
)
size = targets.get_size(sound_event_annotation)
size_heatmap = arrays.set_value_at_pos(
size_heatmap,
@ -106,19 +277,38 @@ def generate_heatmaps(
)
# Get the class name of the sound event
class_name = encoder(sound_event_annotation.tags)
try:
class_name = targets.encode(sound_event_annotation)
except ValueError as e:
logger.warning(
"Skipping annotation %s: Unexpected error while encoding "
"class name %s",
sound_event_annotation.uuid,
e,
)
continue
if class_name is None:
# If the label is None skip the sound event
continue
class_heatmap = arrays.set_value_at_pos(
class_heatmap,
1.0,
time=time,
frequency=frequency,
category=class_name,
)
try:
class_heatmap = arrays.set_value_at_pos(
class_heatmap,
1.0,
time=time,
frequency=frequency,
category=class_name,
)
except KeyError:
# Skip the sound event if the position is outside the spectrogram
logger.debug(
"Skipping annotation %s for class heatmap: "
"position outside spectrogram. Pos: %s",
sound_event_annotation.uuid,
(class_name, time, frequency),
)
continue
# Apply gaussian filters
detection_heatmap = xr.apply_ufunc(
@ -141,10 +331,36 @@ def generate_heatmaps(
class_heatmap / class_heatmap.max(dim=["time", "frequency"])
).fillna(0.0)
return detection_heatmap, class_heatmap, size_heatmap
return Heatmaps(
detection=detection_heatmap,
classes=class_heatmap,
size=size_heatmap,
)
def load_label_config(
path: data.PathLike, field: Optional[str] = None
) -> LabelConfig:
"""Load the heatmap label generation configuration from a file.
Parameters
----------
path : data.PathLike
Path to the configuration file (e.g., YAML or JSON).
field : str, optional
If the label configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
Returns
-------
LabelConfig
The loaded and validated label configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the LabelConfig schema.
"""
return load_config(path, schema=LabelConfig, field=field)
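A hedged usage sketch of the labelling pipeline defined above; the `targets` object, `clip_annotation`, and `spectrogram` are assumed to come from the rest of the pipeline:

```python
# Assumed inputs: `targets` implements TargetProtocol, `clip_annotation` is a
# data.ClipAnnotation, and `spectrogram` is the corresponding xr.DataArray.
labeller = build_clip_labeler(targets, config=LabelConfig(sigma=3.0))
heatmaps = labeller(clip_annotation, spectrogram)
# heatmaps.detection / heatmaps.classes / heatmaps.size are xr.DataArrays
# aligned with the spectrogram's time/frequency coordinates.
```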

View File

@ -1,4 +1,4 @@
from typing import Callable, NamedTuple, Optional
from typing import NamedTuple, Optional
import torch
from soundevent import data
@ -6,7 +6,7 @@ from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from batdetect2.models.typing import DetectionModel
from batdetect2.models.types import DetectionModel
from batdetect2.train.dataset import LabeledDataset

View File

@ -1,5 +1,5 @@
import sys
import json
import sys
from collections import Counter
from pathlib import Path
from typing import Dict, Generator, List, Optional, Tuple

View File

@ -0,0 +1,71 @@
import lightning as L
import torch
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.models import (
DetectionModel,
ModelOutput,
)
from batdetect2.postprocess.types import PostprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
from batdetect2.train import TrainExample
from batdetect2.train.types import LossProtocol
__all__ = [
"TrainingModule",
]
class TrainingModule(L.LightningModule):
def __init__(
self,
detector: DetectionModel,
loss: LossProtocol,
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
learning_rate: float = 0.001,
t_max: int = 100,
):
super().__init__()
self.loss = loss
self.detector = detector
self.preprocessor = preprocessor
self.targets = targets
self.postprocessor = postprocessor
self.learning_rate = learning_rate
self.t_max = t_max
self.save_hyperparameters()
def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.detector(spec)
def training_step(self, batch: TrainExample):
outputs = self.forward(batch.spec)
losses = self.loss(outputs, batch)
self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
self.log("train/loss/detection", losses.total, logger=True)
self.log("train/loss/size", losses.total, logger=True)
self.log("train/loss/classification", losses.total, logger=True)
return losses.total
def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
outputs = self.forward(batch.spec)
losses = self.loss(outputs, batch)
self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
self.log("val/loss/detection", losses.total, logger=True)
self.log("val/loss/size", losses.total, logger=True)
self.log("val/loss/classification", losses.total, logger=True)
def configure_optimizers(self):
optimizer = Adam(self.parameters(), lr=self.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
return [optimizer], [scheduler]
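A hedged sketch of how this module might be driven by a Lightning `Trainer`; the detector, loss, targets, preprocessor, postprocessor, and data loaders are assumed to be built by the corresponding factories elsewhere in the package:

```python
module = TrainingModule(
    detector=detector,            # assumed: a DetectionModel instance
    loss=loss_fn,                 # assumed: a LossProtocol implementation (e.g. from build_loss)
    targets=targets,              # assumed: TargetProtocol implementation
    preprocessor=preprocessor,    # assumed: PreprocessorProtocol implementation
    postprocessor=postprocessor,  # assumed: PostprocessorProtocol implementation
    learning_rate=1e-3,
    t_max=100,
)
trainer = L.Trainer(max_epochs=100)
trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)
```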

View File

@ -1,115 +1,323 @@
from typing import NamedTuple, Optional
"""Loss functions and configurations for training BatDetect2 models.
This module defines the loss functions used to train BatDetect2 models,
including individual loss components for different prediction tasks (detection,
classification, size regression) and a main coordinating loss function that
combines them.
It utilizes common loss types like L1 loss (`BBoxLoss`) for regression and
Focal Loss (`FocalLoss`) for handling class imbalance in dense detection and
classification tasks. Configuration objects (`LossConfig`, etc.) allow for easy
customization of loss parameters and weights via configuration files.
The primary entry points are:
- `LossFunction`: An `nn.Module` that computes the weighted sum of individual
loss components given model outputs and ground truth targets.
- `build_loss`: A factory function that constructs the `LossFunction` based
on a `LossConfig` object.
- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
"""
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.typing import ModelOutput
from batdetect2.models.types import ModelOutput
from batdetect2.train.dataset import TrainExample
from batdetect2.train.types import Losses, LossProtocol
__all__ = [
"bbox_size_loss",
"compute_loss",
"focal_loss",
"mse_loss",
"BBoxLoss",
"ClassificationLossConfig",
"DetectionLossConfig",
"FocalLoss",
"FocalLossConfig",
"LossConfig",
"LossFunction",
"MSELoss",
"SizeLossConfig",
"build_loss",
]
class SizeLossConfig(BaseConfig):
"""Configuration for the bounding box size loss component.
Attributes
----------
weight : float, default=0.1
The weighting factor applied to the size loss when combining it with
other losses (detection, classification) to form the total training
loss.
"""
weight: float = 0.1
def bbox_size_loss(
pred_size: torch.Tensor,
gt_size: torch.Tensor,
) -> torch.Tensor:
class BBoxLoss(nn.Module):
"""Computes L1 loss for bounding box size regression.
Calculates the Mean Absolute Error (MAE or L1 loss) between the predicted
size dimensions (`pred`) and the ground truth size dimensions (`gt`).
Crucially, the loss is only computed at locations where the ground truth
size heatmap (`gt`) contains non-zero values (i.e., at the reference points
of actual annotated sound events). This prevents the model from being
penalized for size predictions in background regions.
The loss is summed over all valid locations and normalized by the number
of valid locations.
"""
Bounding box size loss. Only compute loss where there is a bounding box.
"""
gt_size_mask = (gt_size > 0).float()
return F.l1_loss(pred_size * gt_size_mask, gt_size, reduction="sum") / (
gt_size_mask.sum() + 1e-5
)
def forward(self, pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
"""Calculate masked L1 loss for size prediction.
Parameters
----------
pred : torch.Tensor
Predicted size tensor, typically shape `(B, 2, H, W)`, where
channels represent scaled width and height.
gt : torch.Tensor
Ground truth size tensor, same shape as `pred`. Non-zero values
indicate locations and target sizes of actual annotations.
Returns
-------
torch.Tensor
Scalar tensor representing the calculated masked L1 loss.
"""
gt_size_mask = (gt > 0).float()
masked_pred = pred * gt_size_mask
loss = F.l1_loss(masked_pred, gt, reduction="sum")
num_pos = gt_size_mask.sum() + 1e-5
return loss / num_pos
class FocalLossConfig(BaseConfig):
"""Configuration parameters for the Focal Loss function.
Attributes
----------
beta : float, default=4
Exponent controlling the down-weighting of easy negative examples.
Higher values increase down-weighting (focus more on hard negatives).
alpha : float, default=2
Exponent controlling the down-weighting based on prediction confidence.
Higher values focus more on misclassified examples (both positive and
negative).
"""
beta: float = 4
alpha: float = 2
def focal_loss(
pred: torch.Tensor,
gt: torch.Tensor,
weights: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
eps: float = 1e-5,
beta: float = 4,
alpha: float = 2,
) -> torch.Tensor:
"""
Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
pred (batch x c x h x w)
gt (batch x c x h x w)
class FocalLoss(nn.Module):
"""Focal Loss implementation, adapted from CornerNet.
Addresses class imbalance in dense object detection/classification tasks by
down-weighting the loss contribution from easy examples (both positive and
negative), allowing the model to focus more on hard-to-classify examples.
Parameters
----------
eps : float, default=1e-5
Small epsilon value added for numerical stability.
beta : float, default=4
Exponent focusing on hard negative examples (modulates `(1-gt)^beta`).
alpha : float, default=2
Exponent focusing on misclassified examples (modulates `(1-p)^alpha`
for positives and `p^alpha` for negatives).
class_weights : torch.Tensor, optional
Optional tensor containing weights for each class (applied to positive
loss). Shape should be broadcastable to the channel dimension of the
input tensors.
mask_zero : bool, default=False
If True, ignores loss contributions from spatial locations where the
ground truth `gt` tensor is zero across *all* channels. Useful for
classification heatmaps where some areas might have no assigned class.
References
----------
- Lin, T. Y., et al. "Focal loss for dense object detection." ICCV 2017.
- Law, H., & Deng, J. "CornerNet: Detecting Objects as Paired Keypoints."
ECCV 2018.
"""
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
def __init__(
self,
eps: float = 1e-5,
beta: float = 4,
alpha: float = 2,
class_weights: Optional[torch.Tensor] = None,
mask_zero: bool = False,
):
super().__init__()
self.class_weights = class_weights
self.eps = eps
self.beta = beta
self.alpha = alpha
self.mask_zero = mask_zero
pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
neg_loss = (
torch.log(1 - pred + eps)
* torch.pow(pred, alpha)
* torch.pow(1 - gt, beta)
* neg_inds
)
def forward(
self,
pred: torch.Tensor,
gt: torch.Tensor,
) -> torch.Tensor:
"""Compute the Focal Loss.
if weights is not None:
pos_loss = pos_loss * torch.tensor(weights)
# neg_loss = neg_loss*weights
Parameters
----------
pred : torch.Tensor
Predicted probabilities or logits (typically sigmoid output for
detection, or softmax/sigmoid for classification). Must be in the
range [0, 1] after potential activation. Shape `(B, C, H, W)`.
gt : torch.Tensor
Ground truth heatmap tensor. Shape `(B, C, H, W)`. Values typically
represent target probabilities (e.g., Gaussian peaks for detection,
one-hot encoding or smoothed labels for classification). For the
adapted CornerNet loss, `gt=1` indicates a positive location, and
values `< 1` indicate negative locations (with potential Gaussian
weighting `(1-gt)^beta` for negatives near positives).
if valid_mask is not None:
pos_loss = pos_loss * valid_mask
neg_loss = neg_loss * valid_mask
Returns
-------
torch.Tensor
Scalar tensor representing the computed focal loss, normalized by
the number of positive locations.
"""
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
num_pos = pos_inds.float().sum()
if num_pos == 0:
loss = -neg_loss
else:
loss = -(pos_loss + neg_loss) / num_pos
return loss
pos_loss = (
torch.log(pred + self.eps)
* torch.pow(1 - pred, self.alpha)
* pos_inds
)
neg_loss = (
torch.log(1 - pred + self.eps)
* torch.pow(pred, self.alpha)
* torch.pow(1 - gt, self.beta)
* neg_inds
)
if self.class_weights is not None:
pos_loss = pos_loss * self.class_weights  # already a tensor; re-wrapping would copy and warn
if self.mask_zero:
valid_mask = gt.any(dim=1, keepdim=True).float()
pos_loss = pos_loss * valid_mask
neg_loss = neg_loss * valid_mask
pos_loss = pos_loss.sum()
neg_loss = neg_loss.sum()
num_pos = pos_inds.float().sum()
if num_pos == 0:
loss = -neg_loss
else:
loss = -(pos_loss + neg_loss) / num_pos
return loss
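A quick, hedged sanity check of `FocalLoss` on random tensors (shapes are illustrative only; in training the inputs come from `ModelOutput` and the target heatmaps):

```python
import torch

focal = FocalLoss(alpha=2, beta=4)
pred = torch.rand(2, 1, 64, 128).clamp(1e-4, 1 - 1e-4)  # fake probabilities in (0, 1)
gt = torch.zeros(2, 1, 64, 128)
gt[0, 0, 10, 20] = 1.0  # a single positive location
loss = focal(pred, gt)  # scalar, normalized by the number of positives
```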
def mse_loss(
pred: torch.Tensor,
gt: torch.Tensor,
valid_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
class MSELoss(nn.Module):
"""Mean Squared Error (MSE) Loss module.
Calculates the mean squared difference between predictions and ground
truth. Optionally masks contributions where the ground truth is zero across
channels.
Parameters
----------
mask_zero : bool, default=False
If True, calculates the loss only over spatial locations (H, W) where
at least one channel in the ground truth `gt` tensor is non-zero. The
loss is then averaged over these valid locations. If False (default),
the standard MSE over all elements is computed.
"""
Mean squared error loss.
"""
if valid_mask is None:
op = ((gt - pred) ** 2).mean()
else:
op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
return op
def __init__(self, mask_zero: bool = False):
super().__init__()
self.mask_zero = mask_zero
def forward(
self,
pred: torch.Tensor,
gt: torch.Tensor,
) -> torch.Tensor:
"""Compute the Mean Squared Error loss.
Parameters
----------
pred : torch.Tensor
Predicted tensor, shape `(B, C, H, W)`.
gt : torch.Tensor
Ground truth tensor, shape `(B, C, H, W)`.
Returns
-------
torch.Tensor
Scalar tensor representing the calculated MSE loss.
"""
if not self.mask_zero:
return ((gt - pred) ** 2).mean()
valid_mask = gt.any(dim=1, keepdim=True).float()
return (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
class DetectionLossConfig(BaseConfig):
"""Configuration for the detection loss component.
Attributes
----------
weight : float, default=1.0
Weighting factor for the detection loss in the combined total loss.
focal : FocalLossConfig
Configuration for the Focal Loss used for detection. Defaults to
standard Focal Loss parameters (`alpha=2`, `beta=4`).
"""
weight: float = 1.0
focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
class ClassificationLossConfig(BaseConfig):
"""Configuration for the classification loss component.
Attributes
----------
weight : float, default=2.0
Weighting factor for the classification loss in the combined total loss.
focal : FocalLossConfig
Configuration for the Focal Loss used for classification. Defaults to
standard Focal Loss parameters (`alpha=2`, `beta=4`).
"""
weight: float = 2.0
focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
class_weights: Optional[list[float]] = None
class LossConfig(BaseConfig):
"""Aggregated configuration for all loss components.
Defines the configuration and weighting for detection, size regression,
and classification losses used in the main `LossFunction`.
Attributes
----------
detection : DetectionLossConfig
Configuration for the detection loss (Focal Loss).
size : SizeLossConfig
Configuration for the size regression loss (L1 loss).
classification : ClassificationLossConfig
Configuration for the classification loss (Focal Loss).
"""
detection: DetectionLossConfig = Field(default_factory=DetectionLossConfig)
size: SizeLossConfig = Field(default_factory=SizeLossConfig)
classification: ClassificationLossConfig = Field(
@ -117,50 +325,157 @@ class LossConfig(BaseConfig):
)
class Losses(NamedTuple):
detection: torch.Tensor
size: torch.Tensor
classification: torch.Tensor
total: torch.Tensor
class LossFunction(nn.Module, LossProtocol):
"""Computes the combined training loss for the BatDetect2 model.
Aggregates individual loss functions for detection, size regression, and
classification tasks. Calculates each component loss based on model outputs
and ground truth targets, applies configured weights, and sums them to get
the final total loss used for optimization. Also returns individual
components for monitoring.
Parameters
----------
size_loss : nn.Module
Instantiated loss module for size regression (e.g., `BBoxLoss`).
detection_loss : nn.Module
Instantiated loss module for detection (e.g., `FocalLoss`).
classification_loss : nn.Module
Instantiated loss module for classification (e.g., `FocalLoss`).
size_weight : float, default=0.1
Weighting factor for the size loss component.
detection_weight : float, default=1.0
Weighting factor for the detection loss component.
classification_weight : float, default=2.0
Weighting factor for the classification loss component.
Attributes
----------
size_loss_fn : nn.Module
detection_loss_fn : nn.Module
classification_loss_fn : nn.Module
size_weight : float
detection_weight : float
classification_weight : float
"""
def __init__(
self,
size_loss: nn.Module,
detection_loss: nn.Module,
classification_loss: nn.Module,
size_weight: float = 0.1,
detection_weight: float = 1.0,
classification_weight: float = 2.0,
):
super().__init__()
self.size_loss_fn = size_loss
self.detection_loss_fn = detection_loss
self.classification_loss_fn = classification_loss
self.size_weight = size_weight
self.detection_weight = detection_weight
self.classification_weight = classification_weight
def forward(
self,
pred: ModelOutput,
gt: TrainExample,
) -> Losses:
"""Calculate the combined loss and individual components.
Parameters
----------
pred: ModelOutput
A NamedTuple containing the model's prediction tensors for the
batch: `detection_probs`, `size_preds`, `class_probs`.
gt: TrainExample
A structure containing the ground truth targets for the batch,
expected to have attributes like `detection_heatmap`,
`size_heatmap`, and `class_heatmap` (as `torch.Tensor`).
Returns
-------
Losses
A NamedTuple containing the scalar loss values for detection, size,
classification, and the total weighted loss.
"""
size_loss = self.size_loss_fn(pred.size_preds, gt.size_heatmap)
detection_loss = self.detection_loss_fn(
pred.detection_probs,
gt.detection_heatmap,
)
classification_loss = self.classification_loss_fn(
pred.class_probs,
gt.class_heatmap,
)
total_loss = (
size_loss * self.size_weight
+ classification_loss * self.classification_weight
+ detection_loss * self.detection_weight
)
return Losses(
detection=detection_loss,
size=size_loss,
classification=classification_loss,
total=total_loss,
)
def compute_loss(
batch: TrainExample,
outputs: ModelOutput,
conf: LossConfig,
class_weights: Optional[torch.Tensor] = None,
) -> Losses:
detection_loss = focal_loss(
outputs.detection_probs,
batch.detection_heatmap,
beta=conf.detection.focal.beta,
alpha=conf.detection.focal.alpha,
def build_loss(
config: Optional[LossConfig] = None,
class_weights: Optional[np.ndarray] = None,
) -> nn.Module:
"""Factory function to build the main LossFunction from configuration.
Instantiates the necessary loss components (`BBoxLoss`, `FocalLoss`) based
on the provided `LossConfig` (or defaults) and optional `class_weights`,
then assembles them into the main `LossFunction` module with the specified
component weights.
Parameters
----------
config : LossConfig, optional
Configuration object defining weights and parameters (e.g., alpha, beta
for Focal Loss) for each loss component. If None, default settings
from `LossConfig` and its nested configs are used.
class_weights : np.ndarray, optional
An array of weights for each specific class, used to adjust the
classification loss (typically Focal Loss). If provided, this overrides
any `class_weights` specified within `config.classification`. If None,
weights from the config (or default of equal weights) are used.
Returns
-------
LossFunction
An initialized `LossFunction` module ready for training.
"""
config = config or LossConfig()
class_weights_tensor = (
torch.tensor(class_weights) if class_weights is not None else None
)
size_loss = bbox_size_loss(
outputs.size_preds,
batch.size_heatmap,
detection_loss_fn = FocalLoss(
beta=config.detection.focal.beta,
alpha=config.detection.focal.alpha,
mask_zero=False,
)
valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
classification_loss = focal_loss(
outputs.class_probs,
batch.class_heatmap,
weights=class_weights,
valid_mask=valid_mask,
beta=conf.classification.focal.beta,
alpha=conf.classification.focal.alpha,
classification_loss_fn = FocalLoss(
beta=config.classification.focal.beta,
alpha=config.classification.focal.alpha,
class_weights=class_weights_tensor,
mask_zero=True,
)
total = (
detection_loss * conf.detection.weight
+ size_loss * conf.size.weight
+ classification_loss * conf.classification.weight
)
size_loss_fn = BBoxLoss()
return Losses(
detection=detection_loss,
size=size_loss,
classification=classification_loss,
total=total,
return LossFunction(
size_loss=size_loss_fn,
classification_loss=classification_loss_fn,
detection_loss=detection_loss_fn,
size_weight=config.size.weight,
detection_weight=config.detection.weight,
classification_weight=config.classification.weight,
)
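A hedged sketch of building and applying the combined loss; `outputs` (a `ModelOutput`) and `batch` (a `TrainExample`) are assumed to come from the model forward pass and the data loader:

```python
import numpy as np

loss_fn = build_loss(
    LossConfig(classification=ClassificationLossConfig(weight=2.0)),
    class_weights=np.ones(17, dtype=np.float32),  # one weight per class; count is illustrative
)
losses = loss_fn(outputs, batch)  # Losses(detection, size, classification, total)
losses.total.backward()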

View File

@ -1,97 +1,102 @@
"""Module for preprocessing data for training."""
"""Preprocesses datasets for BatDetect2 model training.
This module provides functions to take a collection of annotated audio clips
(`soundevent.data.ClipAnnotation`) and process them into the final format
required for training a BatDetect2 model. This typically involves:
1. Loading the relevant audio segment for each annotation using a configured
`PreprocessorProtocol`.
2. Generating the corresponding input spectrogram using the
`PreprocessorProtocol`.
3. Generating the target heatmaps (detection, classification, size) using a
configured `ClipLabeller` (which encapsulates the `TargetProtocol` logic).
4. Packaging the input spectrogram, target heatmaps, and potentially the
processed audio waveform into an `xarray.Dataset`.
5. Saving each processed `xarray.Dataset` to a separate file (typically NetCDF)
in an output directory.
This offline preprocessing is often preferred for large datasets as it avoids
computationally intensive steps during the actual training loop. The module
includes utilities for parallel processing using `multiprocessing`.
"""
import os
from functools import partial
from multiprocessing import Pool
from pathlib import Path
from typing import Callable, Optional, Sequence, Union
from typing import Callable, Optional, Sequence
import xarray as xr
from pydantic import Field
from loguru import logger
from soundevent import data
from tqdm.auto import tqdm
from batdetect2.configs import BaseConfig
from batdetect2.preprocess import (
PreprocessingConfig,
compute_spectrogram,
load_clip_audio,
)
from batdetect2.train.labels import LabelConfig, generate_heatmaps
from batdetect2.train.targets import (
TargetConfig,
build_target_encoder,
build_sound_event_filter,
get_class_names,
)
PathLike = Union[Path, str, os.PathLike]
FilenameFn = Callable[[data.ClipAnnotation], str]
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.train.types import ClipLabeller
__all__ = [
"preprocess_annotations",
"preprocess_single_annotation",
"generate_train_example",
"TrainPreprocessingConfig",
]
class TrainPreprocessingConfig(BaseConfig):
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
target: TargetConfig = Field(default_factory=TargetConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)
FilenameFn = Callable[[data.ClipAnnotation], str]
"""Type alias for a function that generates an output filename."""
def generate_train_example(
clip_annotation: data.ClipAnnotation,
preprocessing_config: Optional[PreprocessingConfig] = None,
target_config: Optional[TargetConfig] = None,
label_config: Optional[LabelConfig] = None,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
) -> xr.Dataset:
"""Generate a training example."""
config = TrainPreprocessingConfig(
preprocessing=preprocessing_config or PreprocessingConfig(),
target=target_config or TargetConfig(),
labels=label_config or LabelConfig(),
)
"""Generate a complete training example for one annotation.
wave = load_clip_audio(
clip_annotation.clip,
config=config.preprocessing.audio,
)
This function takes a single `ClipAnnotation`, applies the configured
preprocessing (`PreprocessorProtocol`) to get the processed waveform and
input spectrogram, applies the configured target generation
(`ClipLabeller`) to get the target heatmaps, and packages them all into a
single `xr.Dataset`.
spectrogram = compute_spectrogram(
wave,
config=config.preprocessing.spectrogram,
)
Parameters
----------
clip_annotation : data.ClipAnnotation
The annotated clip to process. Contains the reference to the `Clip`
(audio segment) and the associated `SoundEventAnnotation` objects.
preprocessor : PreprocessorProtocol
An initialized preprocessor object responsible for loading/processing
audio and computing the input spectrogram.
labeller : ClipLabeller
An initialized clip labeller function responsible for generating the
target heatmaps (detection, class, size) from the `clip_annotation`
and the computed spectrogram.
filter_fn = build_sound_event_filter(
include=config.target.include,
exclude=config.target.exclude,
)
Returns
-------
xr.Dataset
An xarray Dataset containing the following data variables:
- `audio`: The preprocessed audio waveform (dims: 'audio_time').
- `spectrogram`: The computed input spectrogram
(dims: 'time', 'frequency').
- `detection`: The target detection heatmap
(dims: 'time', 'frequency').
- `class`: The target class heatmap
(dims: 'category', 'time', 'frequency').
- `size`: The target size heatmap
(dims: 'dimension', 'time', 'frequency').
The Dataset also includes metadata in its attributes.
selected_events = [
event for event in clip_annotation.sound_events if filter_fn(event)
]
Notes
-----
- The 'time' dimension of the 'audio' DataArray is renamed to 'audio_time'
within the output Dataset to avoid coordinate conflicts with the
spectrogram's 'time' dimension when stored together.
- The original `ClipAnnotation` metadata is stored as a JSON string in the
Dataset's attributes for provenance.
"""
wave = preprocessor.load_clip_audio(clip_annotation.clip)
encoder = build_target_encoder(
config.target.classes,
replacement_rules=config.target.replace,
)
class_names = get_class_names(config.target.classes)
spectrogram = preprocessor.compute_spectrogram(wave)
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
selected_events,
spectrogram,
class_names,
encoder,
target_sigma=config.labels.heatmaps.sigma,
position=config.labels.heatmaps.position,
time_scale=config.labels.heatmaps.time_scale,
frequency_scale=config.labels.heatmaps.frequency_scale,
)
heatmaps = labeller(clip_annotation, spectrogram)
dataset = xr.Dataset(
{
@ -101,15 +106,14 @@ def generate_train_example(
# as the waveform.
"audio": wave.rename({"time": "audio_time"}),
"spectrogram": spectrogram,
"detection": detection_heatmap,
"class": class_heatmap,
"size": size_heatmap,
"detection": heatmaps.detection,
"class": heatmaps.classes,
"size": heatmaps.size,
}
)
return dataset.assign_attrs(
title=f"Training example for {clip_annotation.uuid}",
config=config.model_dump_json(),
clip_annotation=clip_annotation.model_dump_json(
exclude_none=True,
exclude_defaults=True,
@ -118,13 +122,25 @@ def generate_train_example(
)
def save_to_file(
def _save_xr_dataset_to_file(
dataset: xr.Dataset,
path: PathLike,
path: data.PathLike,
) -> None:
"""Save an xarray Dataset to a NetCDF file with compression.
Internal helper function used by `preprocess_single_annotation`.
Parameters
----------
dataset : xr.Dataset
The training example dataset to save.
path : PathLike
The output file path (e.g., 'output/uuid.nc').
"""
dataset.to_netcdf(
path,
encoding={
"audio": {"zlib": True},
"spectrogram": {"zlib": True},
"size": {"zlib": True},
"class": {"zlib": True},
@ -134,20 +150,60 @@ def save_to_file(
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
"""Generate a default output filename based on the annotation UUID."""
return f"{clip_annotation.uuid}.nc"
def preprocess_annotations(
clip_annotations: Sequence[data.ClipAnnotation],
output_dir: PathLike,
output_dir: data.PathLike,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
filename_fn: FilenameFn = _get_filename,
replace: bool = False,
preprocessing_config: Optional[PreprocessingConfig] = None,
target_config: Optional[TargetConfig] = None,
label_config: Optional[LabelConfig] = None,
max_workers: Optional[int] = None,
) -> None:
"""Preprocess annotations and save to disk."""
"""Preprocess a sequence of ClipAnnotations and save results to disk.
Generates the full training example (spectrogram, heatmaps, etc.) for each
`ClipAnnotation` in the input sequence using the provided `preprocessor`
and `labeller`. Saves each example as a separate NetCDF file in the
`output_dir`. Utilizes multiprocessing for potentially faster processing.
Parameters
----------
clip_annotations : Sequence[data.ClipAnnotation]
A sequence (e.g., list) of the clip annotations to preprocess.
output_dir : PathLike
Path to the directory where the processed NetCDF files will be saved.
Will be created if it doesn't exist.
preprocessor : PreprocessorProtocol
Initialized preprocessor object to generate spectrograms.
labeller : ClipLabeller
Initialized labeller function to generate target heatmaps.
filename_fn : FilenameFn, optional
Function to generate the output filename (without extension) for each
`ClipAnnotation`. Defaults to using the annotation UUID via
`_get_filename`.
replace : bool, default=False
If True, existing files in `output_dir` with the same generated name
will be overwritten. If False (default), existing files are skipped.
max_workers : int, optional
Maximum number of worker processes to use for parallel processing.
If None (default), uses the number of CPUs available (`os.cpu_count()`).
Returns
-------
None
This function does not return anything; its side effect is creating
files in the `output_dir`.
Raises
------
RuntimeError
If processing fails for any individual annotation when using
multiprocessing. The original exception will be attached as the cause.
"""
output_dir = Path(output_dir)
if not output_dir.is_dir():
@ -162,9 +218,8 @@ def preprocess_annotations(
output_dir=output_dir,
filename_fn=filename_fn,
replace=replace,
preprocessing_config=preprocessing_config,
target_config=target_config,
label_config=label_config,
preprocessor=preprocessor,
labeller=labeller,
),
clip_annotations,
),
@ -175,13 +230,34 @@ def preprocess_annotations(
def preprocess_single_annotation(
clip_annotation: data.ClipAnnotation,
output_dir: PathLike,
preprocessing_config: Optional[PreprocessingConfig] = None,
target_config: Optional[TargetConfig] = None,
label_config: Optional[LabelConfig] = None,
output_dir: data.PathLike,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
filename_fn: FilenameFn = _get_filename,
replace: bool = False,
) -> None:
"""Process a single ClipAnnotation and save the result to a file.
Internal function designed to be called by `preprocess_annotations`, often
in parallel worker processes. It generates the training example using
`generate_train_example` and saves it using `_save_xr_dataset_to_file`. Handles
file existence checks based on the `replace` flag.
Parameters
----------
clip_annotation : data.ClipAnnotation
The single annotation to process.
output_dir : data.PathLike
The directory where the output NetCDF file should be saved.
preprocessor : PreprocessorProtocol
Initialized preprocessor object.
labeller : ClipLabeller
Initialized labeller function.
filename_fn : FilenameFn, default=_get_filename
Function to determine the output filename.
replace : bool, default=False
Whether to overwrite existing output files.
"""
output_dir = Path(output_dir)
filename = filename_fn(clip_annotation)
@ -196,13 +272,12 @@ def preprocess_single_annotation(
try:
sample = generate_train_example(
clip_annotation,
preprocessing_config=preprocessing_config,
target_config=target_config,
label_config=label_config,
preprocessor=preprocessor,
labeller=labeller,
)
except Exception as error:
raise RuntimeError(
f"Failed to process annotation: {clip_annotation.uuid}"
) from error
save_to_file(sample, path)
_save_xr_dataset_to_file(sample, path)
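A hedged sketch of running the offline preprocessing step; `preprocessor`, `labeller`, and `clip_annotations` are assumed to be created elsewhere (a configured `PreprocessorProtocol`, a `ClipLabeller` such as the one returned by `build_clip_labeler`, and a sequence of `data.ClipAnnotation` objects):

```python
preprocess_annotations(
    clip_annotations,
    output_dir="train_examples/",
    preprocessor=preprocessor,
    labeller=labeller,
    replace=False,   # skip files that already exist
    max_workers=4,   # parallel worker processes
)
# Each annotation is written as one compressed NetCDF file, named <uuid>.nc by default.
```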

View File

@ -1,181 +0,0 @@
from collections.abc import Iterable
from functools import partial
from pathlib import Path
from typing import Callable, List, Optional, Set
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.terms import TagInfo, get_tag_from_info
__all__ = [
"TargetConfig",
"load_target_config",
"build_target_encoder",
"build_decoder",
"filter_sound_event",
]
class ReplaceConfig(BaseConfig):
"""Configuration for replacing tags."""
original: TagInfo
replacement: TagInfo
class TargetConfig(BaseConfig):
"""Configuration for target generation."""
classes: List[TagInfo] = Field(
default_factory=lambda: [
TagInfo(key="class", value=value) for value in DEFAULT_SPECIES_LIST
]
)
generic_class: Optional[TagInfo] = Field(
default_factory=lambda: TagInfo(key="class", value="Bat")
)
include: Optional[List[TagInfo]] = Field(
default_factory=lambda: [TagInfo(key="event", value="Echolocation")]
)
exclude: Optional[List[TagInfo]] = Field(
default_factory=lambda: [
TagInfo(key="class", value=""),
TagInfo(key="class", value=" "),
TagInfo(key="class", value="Unknown"),
]
)
replace: Optional[List[ReplaceConfig]] = None
def build_sound_event_filter(
include: Optional[List[TagInfo]] = None,
exclude: Optional[List[TagInfo]] = None,
) -> Callable[[data.SoundEventAnnotation], bool]:
include_tags = (
{get_tag_from_info(tag) for tag in include} if include else None
)
exclude_tags = (
{get_tag_from_info(tag) for tag in exclude} if exclude else None
)
return partial(
filter_sound_event,
include=include_tags,
exclude=exclude_tags,
)
def get_tag_label(tag_info: TagInfo) -> str:
return tag_info.label if tag_info.label else tag_info.value
def get_class_names(classes: List[TagInfo]) -> List[str]:
return sorted({get_tag_label(tag) for tag in classes})
def build_replacer(
rules: List[ReplaceConfig],
) -> Callable[[data.Tag], data.Tag]:
mapping = {
get_tag_from_info(rule.original): get_tag_from_info(rule.replacement)
for rule in rules
}
def replacer(tag: data.Tag) -> data.Tag:
return mapping.get(tag, tag)
return replacer
def build_target_encoder(
classes: List[TagInfo],
replacement_rules: Optional[List[ReplaceConfig]] = None,
) -> Callable[[Iterable[data.Tag]], Optional[str]]:
target_tags = set([get_tag_from_info(tag) for tag in classes])
tag_mapping = {
tag: get_tag_label(tag_info)
for tag, tag_info in zip(target_tags, classes)
}
replacer = (
build_replacer(replacement_rules) if replacement_rules else lambda x: x
)
def encoder(
tags: Iterable[data.Tag],
) -> Optional[str]:
sanitized_tags = {replacer(tag) for tag in tags}
intersection = sanitized_tags & target_tags
if not intersection:
return None
first = intersection.pop()
return tag_mapping[first]
return encoder
def build_decoder(
classes: List[TagInfo],
) -> Callable[[str], List[data.Tag]]:
target_tags = set([get_tag_from_info(tag) for tag in classes])
tag_mapping = {
get_tag_label(tag_info): tag
for tag, tag_info in zip(target_tags, classes)
}
def decoder(label: str) -> List[data.Tag]:
tag = tag_mapping.get(label)
return [tag] if tag else []
return decoder
def filter_sound_event(
sound_event_annotation: data.SoundEventAnnotation,
include: Optional[Set[data.Tag]] = None,
exclude: Optional[Set[data.Tag]] = None,
) -> bool:
tags = set(sound_event_annotation.tags)
if include is not None and not tags & include:
return False
if exclude is not None and tags & exclude:
return False
return True
def load_target_config(
path: Path, field: Optional[str] = None
) -> TargetConfig:
return load_config(path, schema=TargetConfig, field=field)
DEFAULT_SPECIES_LIST = [
"Barbastellus barbastellus",
"Eptesicus serotinus",
"Myotis alcathoe",
"Myotis bechsteinii",
"Myotis brandtii",
"Myotis daubentonii",
"Myotis mystacinus",
"Myotis nattereri",
"Nyctalus leisleri",
"Nyctalus noctula",
"Pipistrellus nathusii",
"Pipistrellus pipistrellus",
"Pipistrellus pygmaeus",
"Plecotus auritus",
"Plecotus austriacus",
"Rhinolophus ferrumequinum",
"Rhinolophus hipposideros",
]

batdetect2/train/types.py Normal file (100 lines)
View File

@ -0,0 +1,100 @@
from typing import Callable, NamedTuple, Protocol, Tuple
import torch
import xarray as xr
from soundevent import data
from batdetect2.models import ModelOutput
__all__ = [
"Heatmaps",
"ClipLabeller",
"Augmentation",
"LossProtocol",
"TrainExample",
]
class Heatmaps(NamedTuple):
"""Structure holding the generated heatmap targets.
Attributes
----------
detection : xr.DataArray
Heatmap indicating the probability of sound event presence. Typically
smoothed with a Gaussian kernel centered on event reference points.
Shape matches the input spectrogram. Values normalized [0, 1].
classes : xr.DataArray
Heatmap indicating the probability of specific class presence. Has an
additional 'category' dimension corresponding to the target class
names. Each category slice is typically smoothed with a Gaussian
kernel. Values normalized [0, 1] per category.
size : xr.DataArray
Heatmap encoding the size (width, height) of detected events. Has an
additional 'dimension' coordinate ('width', 'height'). Values represent
scaled dimensions placed at the event reference points.
"""
detection: xr.DataArray
classes: xr.DataArray
size: xr.DataArray
ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
"""Type alias for the final clip labelling function.
This function takes the complete annotations for a clip and the corresponding
spectrogram, applies all configured filtering, transformation, and encoding
steps, and returns the final `Heatmaps` used for model training.
"""
Augmentation = Callable[[xr.Dataset], xr.Dataset]
class TrainExample(NamedTuple):
spec: torch.Tensor
detection_heatmap: torch.Tensor
class_heatmap: torch.Tensor
size_heatmap: torch.Tensor
idx: torch.Tensor
start_time: float
end_time: float
class Losses(NamedTuple):
"""Structure to hold the computed loss values.
Allows returning individual loss components along with the total weighted
loss for monitoring and analysis during training.
Attributes
----------
detection : torch.Tensor
Scalar tensor representing the calculated detection loss component
(before weighting).
size : torch.Tensor
Scalar tensor representing the calculated size regression loss component
(before weighting).
classification : torch.Tensor
Scalar tensor representing the calculated classification loss component
(before weighting).
total : torch.Tensor
Scalar tensor representing the final combined loss, computed as the
weighted sum of the detection, size, and classification components.
This is the value typically used for backpropagation.
"""
detection: torch.Tensor
size: torch.Tensor
classification: torch.Tensor
total: torch.Tensor
class LossProtocol(Protocol):
def __call__(self, pred: ModelOutput, gt: TrainExample) -> Losses: ...
class ClipperProtocol(Protocol):
def extract_clip(
self, example: xr.Dataset
) -> Tuple[xr.Dataset, float, float]: ...
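A hedged sketch of a minimal `ClipperProtocol` implementation that returns the example unchanged along with its time bounds; real clippers in the training package may instead crop random fixed-duration windows:

```python
class IdentityClipper:
    """Trivial clipper: keeps the whole example (sketch, not part of this diff)."""

    def extract_clip(self, example: xr.Dataset) -> Tuple[xr.Dataset, float, float]:
        start = float(example["spectrogram"].time.min())
        end = float(example["spectrogram"].time.max())
        return example, start, end
```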

View File

@ -1,14 +1,10 @@
"""Types used in the code base."""
from typing import Any, List, NamedTuple, Optional
from typing import Any, List, NamedTuple, Optional, TypedDict
import numpy as np
import torch
from typing import TypedDict
try:
from typing import Protocol
except ImportError:

docs/Makefile Normal file (20 lines)
View File

@ -0,0 +1,20 @@
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

docs/make.bat Normal file (35 lines)
View File

@ -0,0 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

docs/source/conf.py Normal file (61 lines)
View File

@ -0,0 +1,61 @@
# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = "batdetect2"
copyright = "2025, Oisin Mac Aodha, Santiago Martinez Balvanera"
author = "Oisin Mac Aodha, Santiago Martinez Balvanera"
release = "1.1.1"
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.intersphinx",
"sphinxcontrib.autodoc_pydantic",
"numpydoc",
"myst_parser",
"sphinx_autodoc_typehints",
]
templates_path = ["_templates"]
exclude_patterns = []
source_suffix = {
".rst": "restructuredtext",
".txt": "markdown",
".md": "markdown",
}
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = "sphinx_book_theme"
html_static_path = ["_static"]
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"soundevent": ("https://mbsantiago.github.io/soundevent/", None),
"pydantic": ("https://docs.pydantic.dev/latest/", None),
"xarray": ("https://docs.xarray.dev/en/stable/", None),
}
# -- Options for autodoc ------------------------------------------------------
autosummary_generate = True
autosummary_imported_members = True
autodoc_default_options = {
"members": True,
"undoc-members": False,
"private-members": False,
"special-members": False,
"inherited-members": False,
"show-inheritance": True,
"module-first": True,
}

docs/source/data/aoef.md Normal file (106 lines)
View File

@ -0,0 +1,106 @@
# Using AOEF / Soundevent Data Sources
## Introduction
The **AOEF (Acoustic Open Event Format)**, stored as `.json` files, is the annotation format used by the underlying `soundevent` library and is compatible with annotation tools like **Whombat**.
BatDetect2 can directly load annotation data stored in this format.
This format can represent two main types of annotation collections:
1. `AnnotationSet`: A straightforward collection of annotations for various audio clips.
2. `AnnotationProject`: A more structured format often exported by annotation tools (like Whombat).
It includes not only the annotations but also information about annotation _tasks_ (work assigned to annotators) and their status (e.g., in-progress, completed, verified, rejected).
This section explains how to configure a data source in your `DatasetConfig` to load data from either type of AOEF file.
## Configuration
To define a data source using the AOEF format, you add an entry to the `sources` list in your main `DatasetConfig` (usually within your primary YAML configuration file) and set the `format` field to `"aoef"`.
Here are the key fields you need to specify for an AOEF source:
- `format: "aoef"`: **(Required)** Tells BatDetect2 to use the AOEF loader for this source.
- `name: your_source_name`: **(Required)** A unique name you choose for this data source (e.g., `"whombat_project_export"`, `"final_annotations"`).
- `audio_dir: path/to/audio/files`: **(Required)** The path to the directory where the actual audio `.wav` files referenced in the annotations are located.
- `annotations_path: path/to/your/annotations.aoef`: **(Required)** The path to the single `.aoef` or `.json` file containing the annotation data (either an `AnnotationSet` or an `AnnotationProject`).
- `description: "Details about this source..."`: (Optional) A brief description of the data source.
- `filter: ...`: **(Optional)** Specific settings used _only if_ the `annotations_path` file contains an `AnnotationProject`.
See details below.
## Filtering Annotation Projects (Optional)
When working with annotation projects, especially collaborative ones or those still in progress (like exports from Whombat), you often want to train only on annotations that are considered complete and reliable.
The optional `filter:` section allows you to specify criteria based on the status of the annotation _tasks_ within the project.
**If `annotations_path` points to a simple `AnnotationSet` file, the `filter:` section is ignored.**
If `annotations_path` points to an `AnnotationProject`, you can add a `filter:` block with the following options:
- `only_completed: <true_or_false>`:
- `true` (Default): Only include annotations from tasks that have been marked as "completed".
- `false`: Include annotations regardless of task completion status.
- `only_verified: <true_or_false>`:
- `false` (Default): Verification status is not considered.
- `true`: Only include annotations from tasks that have _also_ been marked as "verified" (typically meaning they passed a review step).
- `exclude_issues: <true_or_false>`:
- `true` (Default): Exclude annotations from any task that has been marked as "rejected" or flagged with issues.
- `false`: Include annotations even if their task was marked as having issues (use with caution).
**Default Filtering:** If you include the `filter:` block but omit some options, or if you _omit the entire `filter:` block_, the default settings are applied to `AnnotationProject` files: `only_completed: true`, `only_verified: false`, `exclude_issues: true`.
This common default selects annotations from completed tasks that haven't been rejected, without requiring separate verification.
**Disabling Filtering:** If you want to load _all_ annotations from an `AnnotationProject` regardless of task status, you can explicitly disable filtering by setting `filter: null` in your YAML configuration.
## YAML Configuration Examples
**Example 1: Loading a standard AnnotationSet (or a Project with default filtering)**
```yaml
# In your main DatasetConfig YAML file
sources:
- name: "MyFinishedAnnotations"
format: "aoef" # Specifies the loader
audio_dir: "/path/to/my/audio/"
annotations_path: "/path/to/my/dataset.soundevent.json" # Path to the AOEF file
description: "Finalized annotations set."
# No 'filter:' block means default filtering applied IF it's an AnnotationProject,
# or no filtering applied if it's an AnnotationSet.
```
**Example 2: Loading an AnnotationProject, requiring verification**
```yaml
# In your main DatasetConfig YAML file
sources:
- name: "WhombatVerifiedExport"
format: "aoef"
audio_dir: "relative/path/to/audio/" # Relative to where BatDetect2 runs or a base_dir
annotations_path: "exports/whombat_project.aoef" # Path to the project file
description: "Annotations from Whombat project, only using verified tasks."
filter: # Customize the filter
only_completed: true # Still require completion
only_verified: true # *Also* require verification
exclude_issues: true # Still exclude rejected tasks
```
**Example 3: Loading an AnnotationProject, disabling all filtering**
```yaml
# In your main DatasetConfig YAML file
sources:
- name: "WhombatRawExport"
format: "aoef"
audio_dir: "data/audio_pool/"
annotations_path: "exports/whombat_project_all.aoef"
description: "All annotations from Whombat, regardless of task status."
filter: null # Explicitly disable task filtering
```
## Summary
To load standard `soundevent` annotations (including Whombat exports), set `format: "aoef"` for your data source in the `DatasetConfig`.
Provide the `audio_dir` and the path to the single `annotations_path` file.
If dealing with `AnnotationProject` files, you can optionally use the `filter:` block to select annotations based on task completion, verification, or issue status.

View File

@ -0,0 +1,9 @@
# Loading Data
```{toctree}
:maxdepth: 1
:caption: Loading Data
aoef
legacy
```

122
docs/source/data/legacy.md Normal file
View File

@ -0,0 +1,122 @@
# Using Legacy BatDetect2 Annotation Formats
## Introduction
If you have annotation data created using older BatDetect2 annotation tools, BatDetect2 provides tools to load these datasets.
These older formats typically use JSON files to store annotation information, including bounding boxes and labels for sound events within recordings.
There are two main variations of this legacy format that BatDetect2 can load:
1. **Directory-Based (`format: "batdetect2"`):** Annotations for each audio recording are stored in a _separate_ JSON file within a dedicated directory.
There's a naming convention linking the JSON file to its corresponding audio file (e.g., `my_recording.wav` annotations are stored in `my_recording.wav.json`).
2. **Single Merged File (`format: "batdetect2_file"`):** Annotations for _multiple_ recordings are aggregated into a _single_ JSON file.
This file contains a list, where each item represents the annotations for one recording, following the same internal structure as the directory-based format.
When you configure BatDetect2 to use these formats, it will read the legacy data and convert it internally into the standard `soundevent` data structures used by the rest of the pipeline.
## Configuration
You specify which legacy format to use within the `sources` list of your main `DatasetConfig` (usually in your primary YAML configuration file).
### Format 1: Directory-Based
Use this when you have a folder containing many individual JSON annotation files, one for each audio file.
**Configuration Fields:**
- `format: "batdetect2"`: **(Required)** Identifies this specific legacy format loader.
- `name: your_source_name`: **(Required)** A unique name for this data source.
- `audio_dir: path/to/audio/files`: **(Required)** Path to the directory containing the `.wav` audio files.
- `annotations_dir: path/to/annotation/jsons`: **(Required)** Path to the directory containing the individual `.json` annotation files.
- `description: "Details..."`: (Optional) Description of this source.
- `filter: ...`: (Optional) Settings to filter which JSON files are processed based on flags within them (see "Filtering Legacy Annotations" below).
**YAML Example:**
```yaml
# In your main DatasetConfig YAML file
sources:
- name: "OldProject_SiteA_Files"
format: "batdetect2" # Use the directory-based loader
audio_dir: "/data/SiteA/Audio/"
annotations_dir: "/data/SiteA/Annotations_JSON/"
description: "Legacy annotations stored as individual JSONs per recording."
# filter: ... # Optional filter settings can be added here
```
### Format 2: Single Merged File
Use this when you have a single JSON file that contains a list of annotations for multiple recordings.
**Configuration Fields:**
- `format: "batdetect2_file"`: **(Required)** Identifies this specific legacy format loader.
- `name: your_source_name`: **(Required)** A unique name for this data source.
- `audio_dir: path/to/audio/files`: **(Required)** Path to the directory containing the `.wav` audio files referenced _within_ the merged JSON file.
- `annotations_path: path/to/your/merged_annotations.json`: **(Required)** Path to the single `.json` file containing the list of annotations.
- `description: "Details..."`: (Optional) Description of this source.
- `filter: ...`: (Optional) Settings to filter which records _within_ the merged file are processed (see "Filtering Legacy Annotations" below).
**YAML Example:**
```yaml
# In your main DatasetConfig YAML file
sources:
- name: "OldProject_Merged"
format: "batdetect2_file" # Use the merged file loader
audio_dir: "/data/AllAudio/"
annotations_path: "/data/CombinedAnnotations/old_project_merged.json"
description: "Legacy annotations aggregated into a single JSON file."
# filter: ... # Optional filter settings can be added here
```
## Filtering Legacy Annotations
The legacy JSON annotation structure (for both formats) included boolean flags indicating the status of the annotation work for each recording:
- `annotated`: Typically `true` if a human had reviewed or created annotations for the file.
- `issues`: Typically `true` if problems were noted during annotation or review.
You can optionally filter the data based on these flags using a `filter:` block within the source configuration.
This applies whether you use `"batdetect2"` or `"batdetect2_file"`.
**Filter Options:**
- `only_annotated: <true_or_false>`:
- `true` (**Default**): Only process entries where the `annotated` flag in the JSON is `true`.
- `false`: Process entries regardless of the `annotated` flag.
- `exclude_issues: <true_or_false>`:
- `true` (**Default**): Skip processing entries where the `issues` flag in the JSON is `true`.
- `false`: Process entries even if they are flagged with `issues`.
**Default Filtering:** If you **omit** the `filter:` block entirely, the default settings (`only_annotated: true`, `exclude_issues: true`) are applied automatically.
This means only entries marked as annotated and not having issues will be loaded.
**Disabling Filtering:** To load _all_ entries from the legacy source regardless of the `annotated` or `issues` flags, explicitly disable the filter:
```yaml
filter: null
```
**YAML Example (Custom Filter):** Only load entries marked as annotated, but _include_ those with issues.
```yaml
sources:
- name: "LegacyData_WithIssues"
format: "batdetect2" # Or "batdetect2_file"
audio_dir: "path/to/audio"
annotations_dir: "path/to/annotations" # Or annotations_path for merged
filter:
only_annotated: true
exclude_issues: false # Include entries even if issues flag is true
```
## Summary
BatDetect2 allows you to incorporate datasets stored in older "BatDetect2" JSON formats.
- Use `format: "batdetect2"` and provide `annotations_dir` if you have one JSON file per recording in a directory.
- Use `format: "batdetect2_file"` and provide `annotations_path` if you have a single JSON file containing annotations for multiple recordings.
- Optionally use the `filter:` block with `only_annotated` and `exclude_issues` to select data based on flags present in the legacy JSON structure.
The system will handle loading, filtering (if configured), and converting this legacy data into the standard `soundevent` format used internally.

14
docs/source/index.md Normal file
View File

@ -0,0 +1,14 @@
# batdetect2 documentation
Hi!
```{toctree}
:maxdepth: 1
:caption: Contents:
data/index
preprocessing/index
postprocessing
targets/index
reference/index
```

View File

@ -0,0 +1,126 @@
# Postprocessing: From Model Output to Predictions
## What is Postprocessing?
After the BatDetect2 neural network analyzes a spectrogram, it doesn't directly output a neat list of bat calls.
Instead, it produces raw numerical data, usually in the form of multi-dimensional arrays or "heatmaps".
These arrays contain information like:
- The probability of a sound event being present at each time-frequency location.
- The probability of each possible target class (e.g., species) at each location.
- Predicted size characteristics (like duration and bandwidth) at each location.
- Internal learned features at each location.
**Postprocessing** is the sequence of steps that takes these numerical model outputs and translates them into a structured list of detected sound events, complete with predicted tags, bounding boxes, and confidence scores.
The {py:mod}`batdetect2.postprocess` module handles this entire workflow.
## Why is Postprocessing Necessary?
1. **Interpretation:** Raw heatmap outputs need interpretation to identify distinct sound events (detections).
A high probability score might spread across several adjacent time-frequency bins, all related to the same call.
2. **Refinement:** Model outputs can be noisy or contain redundancies.
Postprocessing steps like Non-Maximum Suppression (NMS) clean this up, ensuring (ideally) only one detection is reported for each actual sound event.
3. **Contextualization:** Raw outputs lack real-world units.
Postprocessing adds back time (seconds) and frequency (Hz) coordinates, converts predicted sizes to physical units using configured scales, and decodes predicted class indices back into meaningful tags based on your target definitions.
4. **User Control:** Postprocessing includes tunable parameters, most importantly **thresholds**.
By adjusting these, you can control the trade-off between finding more potential calls (sensitivity) versus reducing false positives (specificity) _without needing to retrain the model_.
## The Postprocessing Pipeline
BatDetect2 applies a series of steps to convert the raw model output into final predictions.
Understanding these steps helps interpret the results and configure the process effectively:
1. **Non-Maximum Suppression (NMS):**
- **Goal:** Reduce redundant detections.
If the model outputs high scores for several nearby points corresponding to the same call, NMS selects the single highest peak in a local neighbourhood and suppresses the others (sets their score to zero).
- **Configurable:** The size of the neighbourhood (`nms_kernel_size`) can be adjusted (a sketch of this style of NMS and peak extraction follows this list).
2. **Coordinate Remapping:**
- **Goal:** Add coordinate (time/frequency) information.
This step takes the grid-based model outputs (which just have row/column indices) and associates them with actual time (seconds) and frequency (Hz) coordinates based on the input spectrogram's properties.
The result is coordinate-aware arrays (using {py:class}`xarray.DataArray`).
3. **Detection Extraction:**
- **Goal:** Identify the specific points representing detected events.
- **Process:** Looks for peaks in the NMS-processed detection heatmap that are above a certain confidence level (`detection_threshold`).
It also often limits the maximum number of detections based on a rate (`top_k_per_sec`) to avoid excessive outputs in very busy files.
- **Configurable:** `detection_threshold`, `top_k_per_sec`.
4. **Data Extraction:**
- **Goal:** Gather all relevant information for each detected point.
- **Process:** For each time-frequency location identified in Step 3, this step looks up the corresponding values in the _other_ remapped model output arrays (class probabilities, predicted sizes, internal features).
- **Intermediate Output 1:** The result of this stage (containing aligned scores, positions, sizes, class probabilities, and features for all detections in a clip) is often accessible programmatically as an {py:class}`xarray.Dataset`.
This can be useful for advanced users needing direct access to the numerical outputs.
5. **Decoding & Formatting:**
- **Goal:** Convert the extracted numerical data into interpretable, standard formats.
- **Process:**
- **ROI Recovery:** Uses the predicted position and size values, along with the ROI mapping configuration defined in the `targets` module, to reconstruct an estimated bounding box ({py:class}`soundevent.data.BoundingBox`).
- **Class Decoding:** Translates the numerical class probability vector into meaningful {py:class}`soundevent.data.PredictedTag` objects.
This involves:
- Applying the `classification_threshold` to ignore low-confidence class scores.
- Using the class decoding rules (from the `targets` module) to map the name(s) of the high-scoring class(es) back to standard tags (like `species: Myotis daubentonii`).
- Optionally selecting only the top-scoring class or multiple classes above the threshold.
- Including the generic "Bat" tags if no specific class meets the threshold.
- **Feature Conversion:** Converts raw feature vectors into {py:class}`soundevent.data.Feature` objects.
- **Intermediate Output 2:** This step might internally create a list of simplified `RawPrediction` objects containing the bounding box, scores, and features.
This intermediate list might also be accessible programmatically for users who prefer a simpler structure than the final {py:mod}`soundevent` objects.
6. **Final Output (`ClipPrediction`):**
- **Goal:** Package everything into a standard format.
- **Process:** Collects all the fully processed `SoundEventPrediction` objects (each containing a sound event with geometry, features, overall score, and predicted tags with scores) for a given audio clip into a final {py:class}`soundevent.data.ClipPrediction` object.
This is the standard output format representing the model's findings for that clip.
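As a conceptual illustration of Steps 1 and 3, the sketch below shows one common way to perform peak-based NMS and thresholded peak extraction on a detection heatmap using NumPy and SciPy. It is a simplified stand-in rather than batdetect2's internal implementation, and the `(frequency, time)` array layout is an assumption made for the example.
```python
import numpy as np
from scipy.ndimage import maximum_filter


def heatmap_nms(heatmap: np.ndarray, kernel_size: int = 9) -> np.ndarray:
    """Keep only local maxima within a kernel_size neighbourhood (Step 1)."""
    local_max = maximum_filter(heatmap, size=kernel_size, mode="constant")
    return np.where(heatmap == local_max, heatmap, 0.0)


def extract_detections(heatmap: np.ndarray, detection_threshold: float = 0.1):
    """Return (frequency_idx, time_idx, score) for peaks above the threshold (Step 3).

    A real pipeline would additionally cap the number of detections per
    second (the `top_k_per_sec` setting); that is omitted here for brevity.
    """
    suppressed = heatmap_nms(heatmap)
    freq_idx, time_idx = np.nonzero(suppressed >= detection_threshold)
    scores = suppressed[freq_idx, time_idx]
    order = np.argsort(scores)[::-1]  # strongest detections first
    return freq_idx[order], time_idx[order], scores[order]


# Example on a random "probability" heatmap
heatmap = np.random.rand(128, 1024) ** 4
freqs, times, scores = extract_detections(heatmap, detection_threshold=0.5)
```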
## Configuring Postprocessing
You can control key aspects of this pipeline, especially the thresholds and NMS settings, via a `postprocess:` section in your main configuration YAML file.
Adjusting these **allows you to fine-tune the detection results without retraining the model**.
**Key Configurable Parameters:**
- `detection_threshold`: (Number >= 0, e.g., `0.1`) Minimum score for a peak to be considered a detection.
**Lowering this increases sensitivity (more detections, potentially more false positives); raising it increases specificity (fewer detections, potentially missing faint calls).**
- `classification_threshold`: (Number >= 0, e.g., `0.3`) Minimum score for a _specific class_ prediction to be assigned as a tag.
Affects how confidently the model must identify the class.
- `top_k_per_sec`: (Integer > 0, e.g., `200`) Limits the maximum density of detections reported per second.
Helps manage extremely dense recordings.
- `nms_kernel_size`: (Integer > 0, e.g., `9`) Size of the NMS window in pixels/bins.
Affects how close two distinct peaks can be before one suppresses the other.
**Example YAML Configuration:**
```yaml
# Inside your main configuration file (e.g., config.yaml)
postprocess:
nms_kernel_size: 9
detection_threshold: 0.1 # Lower threshold -> more sensitive
classification_threshold: 0.3 # Higher threshold -> more confident classifications
top_k_per_sec: 200
# ... other sections preprocessing, targets ...
```
**Note:** These parameters can often also be adjusted via Command Line Interface (CLI) arguments when running predictions, or through function arguments if using the Python API, providing flexibility for experimentation.
## Accessing Intermediate Results
While the final `ClipPrediction` objects are the standard output, the `Postprocessor` object used internally provides methods to access results from intermediate stages (like the `xr.Dataset` after Step 4, or the list of `RawPrediction` objects after Step 5).
This can be valuable for:
- Debugging the pipeline.
- Performing custom analyses on the numerical outputs before final decoding.
- **Transfer Learning / Feature Extraction:** Directly accessing the extracted `features` (from Step 4 or 5a) associated with detected events can be highly useful for training other models or further analysis.
Consult the API documentation for details on how to access these intermediate results programmatically if needed.
## Summary
Postprocessing is the conversion between neural network outputs and meaningful, interpretable sound event detections.
BatDetect2 provides a configurable pipeline including NMS, coordinate remapping, peak detection with thresholding, data extraction, and class/geometry decoding.
Researchers can easily tune key parameters like thresholds via configuration files or arguments to adjust the final set of predictions without altering the trained model itself, and advanced users can access intermediate results for custom analyses or feature reuse.

View File

@ -0,0 +1,92 @@
# Audio Loading and Preprocessing
## Purpose
Before BatDetect2 can analyze the sounds in your recordings, the raw audio data needs to be loaded from the file and prepared.
This initial preparation involves several standard waveform processing steps.
This `audio` module handles this first stage of preprocessing.
It's crucial to understand that the _exact same_ preprocessing steps must be applied both when **training** a model and when **using** that trained model later to make predictions (inference).
Consistent preprocessing ensures the model receives data in the format it expects.
BatDetect2 allows you to control these audio preprocessing steps through settings in your main configuration file.
## The Audio Processing Pipeline
When BatDetect2 needs to process an audio segment (either a full recording or a specific clip), it follows a defined sequence of steps:
1. **Load Audio Segment:** The system first reads the specified time segment from the audio file.
- **Note:** BatDetect2 typically works with **mono** audio.
By default, if your file has multiple channels (e.g., stereo), only the **first channel** is loaded and used for subsequent processing.
2. **Adjust Duration (Optional):** If you've specified a target duration in your configuration, the loaded audio segment is either shortened (by cropping from the start) or lengthened (by adding silence, i.e., zeros, at the end) to match that exact duration.
This is sometimes required by specific model architectures that expect fixed-size inputs.
By default, this step is **off**, and the original clip duration is used.
3. **Resample (Optional):** If configured (and usually **on** by default), the audio's sample rate is changed to a specific target value (e.g., 256,000 Hz).
This is vital for standardizing the data, as different recording devices capture audio at different rates.
The model needs to be trained and run on data with a consistent sample rate.
4. **Center Waveform (Optional):** If configured (and typically **on** by default), the system removes any constant shift away from zero in the waveform (known as DC offset).
This is a standard practice that can sometimes improve the quality of later signal processing steps.
5. **Scale Amplitude (Optional):** If configured (typically **off** by default), the waveform's amplitude (loudness) is adjusted.
The standard method used here is "peak normalization," which scales the entire clip so that the loudest point has an absolute value of 1.0.
This can help standardize volume levels across different recordings, although it's not always necessary or desirable depending on your analysis goals (steps 4 and 5 are sketched in code below).
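Steps 4 and 5 are simple element-wise operations on the waveform. The snippet below is a minimal NumPy sketch of what DC offset removal and peak normalization do conceptually; it is not the exact batdetect2 implementation, which works on coordinate-aware arrays.
```python
import numpy as np


def center_waveform(waveform: np.ndarray) -> np.ndarray:
    """Remove any constant (DC) offset by subtracting the mean."""
    return waveform - waveform.mean()


def peak_normalize(waveform: np.ndarray) -> np.ndarray:
    """Scale so the loudest sample has an absolute value of 1.0."""
    peak = np.max(np.abs(waveform))
    return waveform / peak if peak > 0 else waveform


# Example: a short sine wave with a small constant offset
wave = 0.1 + 0.5 * np.sin(np.linspace(0, 2 * np.pi * 10, 1000))
processed = peak_normalize(center_waveform(wave))
```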
## Configuring Audio Processing
You can control these steps via settings in your main configuration file (e.g., `config.yaml`), usually within a dedicated `audio:` section (which might itself be under a broader `preprocessing:` section).
Here are the key options you can set:
- **Resampling (`resample`)**:
- To enable resampling (recommended and usually default), include a `resample:` block.
To disable it completely, you might set `resample: null` or omit the block.
- `samplerate`: (Number) The target sample rate in Hertz (Hz) that all audio will be converted to.
This **must** match the sample rate expected by the BatDetect2 model you are using or training (e.g., `samplerate: 256000`).
- `mode`: (Text, `"poly"` or `"fourier"`) The underlying algorithm used for resampling.
The default `"poly"` is generally a good choice.
You typically don't need to change this unless you have specific reasons.
- **Duration (`duration`)**:
- (Number or `null`) Sets a fixed duration for all audio clips in **seconds**.
If set (e.g., `duration: 4.0`), shorter clips are padded with silence, and longer clips are cropped.
If `null` (default), the original clip duration is used.
- **Centering (`center`)**:
- (Boolean, `true` or `false`) Controls DC offset removal.
Default is usually `true`.
Set to `false` to disable.
- **Scaling (`scale`)**:
- (Boolean, `true` or `false`) Controls peak amplitude normalization.
Default is usually `false`.
Set to `true` to enable scaling so the maximum absolute amplitude becomes 1.0.
**Example YAML Configuration:**
```yaml
# Inside your main configuration file (e.g., training_config.yaml)
preprocessing: # Or this might be at the top level
audio:
# --- Resampling Settings ---
resample: # Settings block to control resampling
samplerate: 256000 # Target sample rate in Hz (Required if resampling)
mode: poly # Algorithm ('poly' or 'fourier', optional, defaults to 'poly')
# To disable resampling entirely, you might use:
# resample: null
# --- Other Settings ---
duration: null # Keep original clip duration (e.g., use 4.0 for 4 seconds)
center: true # Remove DC offset (default is often true)
scale: false # Do not normalize peak amplitude (default is often false)
# ... other configuration sections (like model, dataset, targets) ...
```
## Outcome
After these steps, the output is a standardized audio waveform (represented as a numerical array with time information).
This processed waveform is now ready for the next stage of preprocessing, which typically involves calculating the spectrogram (covered in the next module/section).
Ensuring these audio preprocessing settings are consistent is fundamental for achieving reliable results in both training and inference.

View File

@ -0,0 +1,46 @@
# Preprocessing Audio for BatDetect2
## What is Preprocessing?
Preprocessing refers to the steps taken to transform your raw audio recordings into a standardized format suitable for analysis by the BatDetect2 deep learning model.
This module (`batdetect2.preprocessing`) provides the tools to perform these transformations.
## Why is Preprocessing Important?
Applying a consistent preprocessing pipeline is important for several reasons:
1. **Standardization:** Audio recordings vary significantly depending on the equipment used, recording conditions, and settings (e.g., different sample rates, varying loudness levels, background noise).
Preprocessing helps standardize these aspects, making the data more uniform and allowing the model to learn relevant patterns more effectively.
2. **Model Requirements:** Deep learning models, particularly those like BatDetect2 that analyze 2D patterns in spectrograms, are designed to work with specific input characteristics.
They often expect spectrograms of a certain size (time x frequency bins), with values represented on a particular scale (e.g., logarithmic/dB), and within a defined frequency range.
Preprocessing ensures the data meets these requirements.
3. **Consistency is Key:** Perhaps most importantly, the **exact same preprocessing steps** must be applied both when _training_ the model and when _using the trained model to make predictions_ (inference) on new data.
Any discrepancy between the preprocessing used during training and inference can significantly degrade the model's performance and lead to unreliable results.
BatDetect2's configurable pipeline ensures this consistency.
## How Preprocessing is Done in BatDetect2
BatDetect2 handles preprocessing through a configurable, two-stage pipeline:
1. **Audio Loading & Preparation:** This first stage deals with the raw audio waveform.
It involves loading the specified audio segment (from a file or clip), selecting a single channel (mono), optionally resampling it to a consistent sample rate, optionally adjusting its duration, and applying basic waveform conditioning like centering (DC offset removal) and amplitude scaling.
(Details in the {doc}`audio` section).
2. **Spectrogram Generation:** The prepared audio waveform is then converted into a spectrogram.
This involves calculating the Short-Time Fourier Transform (STFT) and then applying a series of configurable steps like cropping the frequency range, applying amplitude representations (like dB scale or PCEN), optional denoising, optional resizing to the model's required dimensions, and optional final normalization.
(Details in the {doc}`spectrogram` section).
The entire pipeline is controlled via settings in your main configuration file (typically a YAML file), usually grouped under a `preprocessing:` section which contains subsections like `audio:` and `spectrogram:`.
This allows you to easily define, share, and reproduce the exact preprocessing used for a specific model or experiment.
## Next Steps
Explore the following sections for detailed explanations of how to configure each stage of the preprocessing pipeline and how to use the resulting preprocessor:
```{toctree}
:maxdepth: 1
:caption: Preprocessing Steps:
audio
spectrogram
usage
```

View File

@ -0,0 +1,141 @@
# Spectrogram Generation
## Purpose
After loading and performing initial processing on the audio waveform (as described in the Audio Loading section), the next crucial step in the `preprocessing` pipeline is to convert that waveform into a **spectrogram**.
A spectrogram is a visual representation of sound, showing frequency content over time, and it's the primary input format for many deep learning models, including BatDetect2.
This module handles the calculation and subsequent processing of the spectrogram.
Just like the audio processing, these steps need to be applied **consistently** during both model training and later use (inference) to ensure the model performs reliably.
You control this entire process through the configuration file.
## The Spectrogram Generation Pipeline
Once BatDetect2 has a prepared audio waveform, it follows these steps to create the final spectrogram input for the model:
1. **Calculate STFT (Short-Time Fourier Transform):** This is the fundamental step that converts the 1D audio waveform into a 2D time-frequency representation.
It calculates the frequency content within short, overlapping time windows.
The output is typically a **magnitude spectrogram**, showing the intensity (amplitude) of different frequencies at different times.
Key parameters here are the `window_duration` and `window_overlap`, which affect the trade-off between time and frequency resolution.
2. **Crop Frequencies:** The STFT often produces frequency information over a very wide range (e.g., 0 Hz up to half the sample rate).
This step crops the spectrogram to focus only on the frequency range relevant to your target sounds (e.g., 10 kHz to 120 kHz for typical bat echolocation).
3. **Apply PCEN (Optional):** If configured, Per-Channel Energy Normalization is applied.
PCEN is an adaptive technique that adjusts the gain (loudness) in each frequency channel based on its recent history.
It can help suppress stationary background noise and enhance the prominence of transient sounds like echolocation pulses.
This step is optional.
4. **Set Amplitude Scale / Representation:** The values in the spectrogram (either raw magnitude or post-PCEN values) need to be represented on a suitable scale.
You choose one of the following:
- `"amplitude"`: Use the linear magnitude values directly.
(Default)
- `"power"`: Use the squared magnitude values (representing energy).
- `"dB"`: Apply a logarithmic transformation (specifically `log(1 + C*Magnitude)`).
This compresses the range of values, often making variations in quieter sounds more apparent, similar to how humans perceive loudness.
5. **Denoise (Optional):** If configured (and usually **on** by default), a simple noise reduction technique is applied.
This method subtracts the average value of each frequency bin (calculated across time) from that bin, assuming the average represents steady background noise.
Negative values after subtraction are clipped to zero (see the sketch after this list).
6. **Resize (Optional):** If configured, the dimensions (height/frequency bins and width/time bins) of the spectrogram are adjusted using interpolation to match the exact input size expected by the neural network architecture.
7. **Peak Normalize (Optional):** If configured (typically **off** by default), the entire final spectrogram is scaled so that its highest value is exactly 1.0.
This ensures all spectrograms fed to the model have a consistent maximum value, which can sometimes aid training stability.
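For illustration, here is a minimal NumPy sketch of the amplitude transformation from Step 4 (the `log(1 + C*Magnitude)` option) and the spectral mean subtraction from Step 5. The scaling constant and the `(frequency, time)` layout are assumptions made for the example, not batdetect2's exact internals.
```python
import numpy as np


def to_db_like(magnitude: np.ndarray, c: float = 1.0) -> np.ndarray:
    """Logarithmic amplitude compression: log(1 + C * magnitude)."""
    return np.log1p(c * magnitude)


def spectral_mean_subtraction(spec: np.ndarray) -> np.ndarray:
    """Subtract each frequency bin's mean over time; clip negatives to zero.

    Assumes a (frequency, time) layout, so the mean is taken along axis=1.
    """
    noise_profile = spec.mean(axis=1, keepdims=True)
    return np.clip(spec - noise_profile, a_min=0.0, a_max=None)


# Example usage on a random magnitude "spectrogram"
spec = np.abs(np.random.randn(128, 512))
denoised = spectral_mean_subtraction(to_db_like(spec))
```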
## Configuring Spectrogram Generation
You control all these steps via settings in your main configuration file (e.g., `config.yaml`), within the `spectrogram:` section (usually located under the main `preprocessing:` section).
Here are the key configuration options:
- **STFT Settings (`stft`)**:
- `window_duration`: (Number, seconds, e.g., `0.002`) Length of the analysis window.
- `window_overlap`: (Number, 0.0 to <1.0, e.g., `0.75`) Fractional overlap between windows.
- `window_fn`: (Text, e.g., `"hann"`) Name of the windowing function.
- **Frequency Cropping (`frequencies`)**:
- `min_freq`: (Integer, Hz, e.g., `10000`) Minimum frequency to keep.
- `max_freq`: (Integer, Hz, e.g., `120000`) Maximum frequency to keep.
- **PCEN (`pcen`)**:
- This entire section is **optional**.
Include it only if you want to apply PCEN.
If omitted or set to `null`, PCEN is skipped.
- `time_constant`: (Number, seconds, e.g., `0.4`) Controls adaptation speed.
- `gain`: (Number, e.g., `0.98`) Gain factor.
- `bias`: (Number, e.g., `2.0`) Bias factor.
- `power`: (Number, e.g., `0.5`) Compression exponent.
- **Amplitude Scale (`scale`)**:
- (Text: `"dB"`, `"power"`, or `"amplitude"`) Selects the final representation of the spectrogram values.
Default is `"amplitude"`.
- **Denoising (`spectral_mean_substraction`)**:
- (Boolean: `true` or `false`) Enables/disables the spectral mean subtraction denoising step.
Default is usually `true`.
- **Resizing (`size`)**:
- This entire section is **optional**.
Include it only if you need to resize the spectrogram to specific dimensions required by the model.
If omitted or set to `null`, no resizing occurs after frequency cropping.
- `height`: (Integer, e.g., `128`) Target number of frequency bins.
- `resize_factor`: (Number or `null`, e.g., `0.5`) Factor to scale the time dimension by.
`0.5` halves the width, `null` or `1.0` keeps the original width.
- **Peak Normalization (`peak_normalize`)**:
- (Boolean: `true` or `false`) Enables/disables final scaling of the entire spectrogram so the maximum value is 1.0.
Default is usually `false`.
**Example YAML Configuration:**
```yaml
# Inside your main configuration file
preprocessing:
audio:
# ... (your audio configuration settings) ...
resample:
samplerate: 256000 # Ensure this matches model needs
spectrogram:
# --- STFT Parameters ---
stft:
window_duration: 0.002 # 2ms window
window_overlap: 0.75 # 75% overlap
window_fn: hann
# --- Frequency Range ---
frequencies:
min_freq: 10000 # 10 kHz
max_freq: 120000 # 120 kHz
# --- PCEN (Optional) ---
# Include this block to enable PCEN, omit or set to null to disable.
pcen:
time_constant: 0.4
gain: 0.98
bias: 2.0
power: 0.5
# --- Final Amplitude Representation ---
scale: dB # Choose 'dB', 'power', or 'amplitude'
# --- Denoising ---
spectral_mean_substraction: true # Enable spectral mean subtraction
# --- Resizing (Optional) ---
# Include this block to resize, omit or set to null to disable.
size:
height: 128 # Target height in frequency bins
resize_factor: 0.5 # Halve the number of time bins
# --- Final Normalization ---
peak_normalize: false # Do not scale max value to 1.0
```
## Outcome
The output of this module is the final, processed spectrogram (as a 2D numerical array with time and frequency information).
This spectrogram is now in the precise format expected by the BatDetect2 neural network, ready to be used for training the model or for making predictions on new data.
Remember, using the exact same `spectrogram` configuration settings during training and inference is essential for correct model performance.

View File

@ -0,0 +1,175 @@
# Using Preprocessors in BatDetect2
## Overview
In the previous sections ({doc}`audio` and {doc}`spectrogram`), we discussed the individual steps involved in converting raw audio into a processed spectrogram suitable for BatDetect2 models, and how to configure these steps using YAML files (specifically the `audio:` and `spectrogram:` sections within a main `preprocessing:` configuration block).
This page focuses on how this configured pipeline is represented and used within BatDetect2, primarily through the concept of a **`Preprocessor`** object.
This object bundles together your chosen audio loading settings and spectrogram generation settings into a single component that can perform the end-to-end processing.
## Do I Need to Interact with Preprocessors Directly?
**Usually, no.** For standard model training or running inference with BatDetect2 using the provided scripts, the system will automatically:
1. Read your main configuration file (e.g., `config.yaml`).
2. Find the `preprocessing:` section (containing `audio:` and `spectrogram:` settings).
3. Build the appropriate `Preprocessor` object internally based on your settings.
4. Use that internal `Preprocessor` object automatically whenever audio needs to be loaded and converted to a spectrogram.
**However**, understanding the `Preprocessor` object is useful if you want to:
- **Customize:** Go beyond the standard configuration options by interacting with parts of the pipeline programmatically.
- **Integrate:** Use BatDetect2's preprocessing steps within your own custom Python analysis scripts.
- **Inspect/Debug:** Manually apply preprocessing steps to specific files or clips to examine intermediate outputs (like the processed waveform) or the final spectrogram.
## Getting a Preprocessor Object
If you _do_ want to work with the preprocessor programmatically, you first need to get an instance of it.
This is typically done based on a configuration:
1. **Define Configuration:** Create your `preprocessing:` configuration, usually in a YAML file (let's call it `preprocess_config.yaml`), detailing your desired `audio` and `spectrogram` settings.
```yaml
# preprocess_config.yaml
audio:
resample:
samplerate: 256000
# ... other audio settings ...
spectrogram:
frequencies:
min_freq: 15000
max_freq: 120000
scale: dB
# ... other spectrogram settings ...
```
2. **Load Configuration & Build Preprocessor (in Python):**
```python
from batdetect2.preprocessing import load_preprocessing_config, build_preprocessor
from batdetect2.preprocess.types import Preprocessor # Import the type
# Load the configuration from the file
config_path = "path/to/your/preprocess_config.yaml"
preprocessing_config = load_preprocessing_config(config_path)
# Build the actual preprocessor object using the loaded config
preprocessor: Preprocessor = build_preprocessor(preprocessing_config)
# 'preprocessor' is now ready to use!
```
3. **Using Defaults:** If you just want the standard BatDetect2 default preprocessing settings, you can build one without loading a config file:
```python
from batdetect2.preprocessing import build_preprocessor
from batdetect2.preprocess.types import Preprocessor
# Build with default settings
default_preprocessor: Preprocessor = build_preprocessor()
```
## Applying Preprocessing
Once you have a `preprocessor` object, you can use its methods to process audio data:
**1. End-to-End Processing (Common Use Case):**
These methods take an audio source identifier (file path, Recording object, or Clip object) and return the final, processed spectrogram.
- `preprocessor.preprocess_file(path)`: Processes an entire audio file.
- `preprocessor.preprocess_recording(recording_obj)`: Processes the entire audio associated with a `soundevent.data.Recording` object.
- `preprocessor.preprocess_clip(clip_obj)`: Processes only the specific time segment defined by a `soundevent.data.Clip` object.
- **Efficiency Note:** Using `preprocess_clip` is **highly recommended** when you are only interested in analyzing a small portion of a potentially long recording.
It avoids loading the entire audio file into memory, making it much more efficient.
```python
from soundevent import data
# Assume 'preprocessor' is built as shown before
# Assume 'my_clip' is a soundevent.data.Clip object defining a segment
# Process an entire file
spectrogram_from_file = preprocessor.preprocess_file("my_recording.wav")
# Process only a specific clip (more efficient for segments)
spectrogram_from_clip = preprocessor.preprocess_clip(my_clip)
# The results (spectrogram_from_file, spectrogram_from_clip) are xr.DataArrays
print(type(spectrogram_from_clip))
# Output: <class 'xarray.core.dataarray.DataArray'>
```
**2. Intermediate Steps (Advanced Use Cases):**
The preprocessor also allows access to intermediate stages if needed:
- `preprocessor.load_clip_audio(clip_obj)` (and similar for file/recording): Loads the audio and applies _only_ the waveform processing steps (resampling, centering, etc.) defined in the `audio` config.
Returns the processed waveform as an `xr.DataArray`.
This is useful if you want to analyze or manipulate the waveform itself before spectrogram generation.
- `preprocessor.compute_spectrogram(waveform)`: Takes an _already loaded_ waveform (either `np.ndarray` or `xr.DataArray`) and applies _only_ the spectrogram generation steps defined in the `spectrogram` config.
- If you provide an `xr.DataArray` (e.g., from `load_clip_audio`), it uses the sample rate from the array's coordinates.
- If you provide a raw `np.ndarray`, it **must assume a sample rate**.
It uses the `default_samplerate` that was determined when the `preprocessor` was built (based on your `audio` config's resample settings or the global default).
Be cautious when using NumPy arrays to ensure the sample rate assumption is correct for your data!
```python
# Example: Get waveform first, then spectrogram
waveform = preprocessor.load_clip_audio(my_clip)
# waveform is an xr.DataArray
# ...potentially do other things with the waveform...
# Compute spectrogram from the loaded waveform
spectrogram = preprocessor.compute_spectrogram(waveform)
# Example: Process external numpy array (use with caution re: sample rate)
# import soundfile as sf # Requires installing soundfile
# numpy_waveform, original_sr = sf.read("some_other_audio.wav")
# # MUST ensure numpy_waveform's actual sample rate matches
# # preprocessor.default_samplerate for correct results here!
# spec_from_numpy = preprocessor.compute_spectrogram(numpy_waveform)
```
## Understanding the Output: `xarray.DataArray`
All preprocessing methods return the final spectrogram (or the intermediate waveform) as an **`xarray.DataArray`**.
**What is it?** Think of it like a standard NumPy array (holding the numerical data of the spectrogram) but with added "superpowers":
- **Labeled Dimensions:** Instead of just having axis 0 and axis 1, the dimensions have names, typically `"frequency"` and `"time"`.
- **Coordinates:** It stores the actual frequency values (e.g., in Hz) corresponding to each row and the actual time values (e.g., in seconds) corresponding to each column along the dimensions.
**Why is it used?**
- **Clarity:** The data is self-describing.
You don't need to remember which axis is time and which is frequency, or what the units are; that information is stored with the data.
- **Convenience:** You can select, slice, or plot data using the real-world coordinate values (times, frequencies) instead of just numerical indices.
This makes analysis code easier to write and less prone to errors.
- **Metadata:** It can hold additional metadata about the processing steps in its `attrs` (attributes) dictionary.
**Using the Output:**
- **Input to Model:** For standard training or inference, you typically pass this `xr.DataArray` spectrogram directly to the BatDetect2 model functions.
- **Inspection/Analysis:** If you're working programmatically, you can use xarray's powerful features.
For example (these are just illustrations of xarray):
```python
# Get the shape (frequency_bins, time_bins)
# print(spectrogram.shape)
# Get the frequency coordinate values
# print(spectrogram['frequency'].values)
# Select data near a specific time and frequency
# value_at_point = spectrogram.sel(time=0.5, frequency=50000, method="nearest")
# print(value_at_point)
# Select a time slice between 0.2 and 0.3 seconds
# time_slice = spectrogram.sel(time=slice(0.2, 0.3))
# print(time_slice.shape)
```
In summary, while BatDetect2 often handles preprocessing automatically based on your configuration, the underlying `Preprocessor` object provides a flexible interface for applying these steps programmatically if needed, returning results in the convenient and informative `xarray.DataArray` format.

View File

@ -0,0 +1,7 @@
# Config Reference
```{eval-rst}
.. automodule:: batdetect2.configs
:members:
:inherited-members: pydantic.BaseModel
```

View File

@ -0,0 +1,10 @@
# Reference documentation
```{eval-rst}
.. toctree::
:maxdepth: 1
:caption: Contents:
configs
targets
```

View File

@ -0,0 +1,6 @@
# Targets Reference
```{eval-rst}
.. automodule:: batdetect2.targets
:members:
```

View File

@ -0,0 +1,141 @@
# Step 4: Defining Target Classes and Decoding Rules
## Purpose and Context
You've prepared your data by defining your annotation vocabulary (Step 1: Terms), removing irrelevant sounds (Step 2: Filtering), and potentially cleaning up or modifying tags (Step 3: Transforming Tags).
Now, it's time for a crucial step with two related goals:
1. Telling `batdetect2` **exactly what categories (classes) your model should learn to identify** by defining rules that map annotation tags to class names (like `pippip`, `myodau`, or `noise`).
This process is often called **encoding**.
2. Defining how the model's predictions (those same class names) should be translated back into meaningful, structured **annotation tags** when you use the trained model.
This is often called **decoding**.
These definitions are essential for both training the model correctly and interpreting its output later.
## How it Works: Defining Classes with Rules
You define your target classes and their corresponding decoding rules in your main configuration file (e.g., your `.yaml` training config), typically under a section named `classes`.
This section contains:
1. A **list** of specific class definitions.
2. A definition for the **generic class** tags.
Each item in the `classes` list defines one specific class your model should learn.
## Defining a Single Class
Each specific class definition rule requires the following information:
1. `name`: **(Required)** This is the unique, simple name for this class (e.g., `pipistrellus_pipistrellus`, `myotis_daubentonii`, `noise`).
This label is used during training and is what the model predicts.
Choose clear, distinct names.
**Each class name must be unique.**
2. `tags`: **(Required)** This list contains one or more specific tags (using `key` and `value`) used to identify if an _existing_ annotation belongs to this class during the _encoding_ phase (preparing training data).
3. `match_type`: **(Optional, defaults to `"all"`)** Determines how the `tags` list is evaluated during _encoding_:
- `"all"`: The annotation must have **ALL** listed tags to match.
(Default).
- `"any"`: The annotation needs **AT LEAST ONE** listed tag to match.
4. `output_tags`: **(Optional)** This list specifies the tags that should be assigned to an annotation when the model _predicts_ this class `name`.
This is used during the _decoding_ phase (interpreting model output).
- **If you omit `output_tags` (or set it to `null`/~), the system will default to using the same tags listed in the `tags` field for decoding.** This is often what you want.
- Providing `output_tags` allows you to specify a different, potentially more canonical or detailed, set of tags to represent the class upon prediction.
For example, you could match based on simplified tags but output standardized tags.
**Example: Defining Species Classes (Encoding & Default Decoding)**
Here, the `tags` used for matching during encoding will also be used for decoding, as `output_tags` is omitted.
```yaml
# In your main configuration file
classes:
# Definition for the first class
- name: pippip # Simple name for Pipistrellus pipistrellus
tags: # Used for BOTH encoding match and decoding output
- key: species
value: Pipistrellus pipistrellus
# match_type defaults to "all"
# output_tags is omitted, defaults to using 'tags' above
# Definition for the second class
- name: myodau # Simple name for Myotis daubentonii
tags: # Used for BOTH encoding match and decoding output
- key: species
value: Myotis daubentonii
```
**Example: Defining a Class with Separate Encoding and Decoding Tags**
Here, we match based on _either_ of two tags (`match_type: any`), but when the model predicts `'pipistrelle'`, we decode it _only_ to the specific `Pipistrellus pipistrellus` tag plus a genus tag.
```yaml
classes:
- name: pipistrelle # Name for a Pipistrellus group
match_type: any # Match if EITHER tag below is present during encoding
tags:
- key: species
value: Pipistrellus pipistrellus
- key: species
value: Pipistrellus pygmaeus # Match pygmaeus too
output_tags: # BUT, when decoding 'pipistrelle', assign THESE tags:
- key: species
value: Pipistrellus pipistrellus # Canonical species
- key: genus # Assumes 'genus' key exists
value: Pipistrellus # Add genus tag
```
## Handling Overlap During Encoding: Priority Order Matters
As before, when preparing training data (encoding), if an annotation matches the `tags` and `match_type` rules for multiple class definitions, the **order of the class definitions in the configuration list determines the priority**.
- The system checks rules from the **top** of the `classes` list down.
- The annotation gets assigned the `name` of the **first class rule it matches**.
- **Place more specific rules before more general rules.**
_(For example, list a specific species rule above a broader catch-all rule so the species class is assigned first; a sketch follows below.)_
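A minimal sketch of this ordering is shown below; the tag keys and values are illustrative, not a required vocabulary. An annotation tagged with both `species: Myotis daubentonii` and `call_type: Echolocation` is assigned `myodau`, because that rule appears first in the list.
```yaml
classes:
  # Checked first: the more specific species rule
  - name: myodau
    tags:
      - key: species
        value: Myotis daubentonii

  # Checked second: a broader rule for any remaining echolocation call
  - name: echolocation
    tags:
      - key: call_type
        value: Echolocation
```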
## Handling Non-Matches & Decoding the Generic Class
What happens if an annotation passes filtering/transformation but doesn't match any specific class rule during encoding?
- **Encoding:** As explained previously, these annotations are **not ignored**.
They are typically assigned to a generic "relevant sound" category, often called the **"Bat"** class in BatDetect2, intended for all relevant bat calls not specifically classified.
- **Decoding:** When the model predicts this generic "Bat" category (or when processing sounds that weren't assigned a specific class during encoding), we need a way to represent this generic status with tags.
This is defined by the `generic_class` list directly within the main `classes` configuration section.
**Defining the Generic Class Tags:**
You specify the tags for the generic class like this:
```yaml
# In your main configuration file
classes: # Main configuration section for classes
# --- List of specific class definitions ---
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
# ... other specific classes ...
# --- Definition of the generic class tags ---
generic_class: # Define tags for the generic 'Bat' category
- key: call_type
value: Echolocation
- key: order
value: Chiroptera
# These tags will be assigned when decoding the generic category
```
This `generic_class` list provides the standard tags assigned when a sound is identified as relevant (passed filtering) but doesn't belong to one of the specific target classes you defined.
Like the specific classes, sensible defaults are often provided if you don't explicitly define `generic_class`.
**Crucially:** Remember, if sounds should be **completely excluded** from training (not even considered "generic"), use **Filtering rules (Step 2)**.
### Outcome
By defining this list of prioritized class rules (including their `name`, matching `tags`, `match_type`, and optional `output_tags`) and the `generic_class` tags, you provide `batdetect2` with:
1. A clear procedure to assign a target label (`name`) to each relevant annotation for training.
2. A clear mapping to convert predicted class names (including the generic case) back into meaningful annotation tags.
This complete definition prepares your data for the final heatmap generation (Step 5) and enables interpretation of the model's results.

View File

@ -0,0 +1,141 @@
# Step 2: Filtering Sound Events
## Purpose
When preparing your annotated audio data for training a `batdetect2` model, you often want to select only specific sound events.
For example, you might want to:
- Focus only on echolocation calls and ignore social calls or noise.
- Exclude annotations that were marked as low quality.
- Train only on specific species or groups of species.
This filtering module allows you to define rules based on the **tags** associated with each sound event annotation.
Only the events that pass _all_ your defined rules will be kept for further processing and training.
## How it Works: Rules
Filtering is controlled by a list of **rules**.
Each rule defines a condition based on the tags attached to a sound event.
An event must satisfy **all** the rules you define in your configuration to be included.
If an event fails even one rule, it is discarded.
## Defining Rules in Configuration
You define these rules within your main configuration file (usually a `.yaml` file) under a specific section (the exact name might depend on the main training config, but let's assume it's called `filtering`).
The configuration consists of a list named `rules`.
Each item in this list is a single filter rule.
Each **rule** has two parts:
1. `match_type`: Specifies the _kind_ of check to perform.
2. `tags`: A list of specific tags (each with a `key` and `value`) that the rule applies to.
```yaml
# Example structure in your configuration file
filtering:
rules:
- match_type: <TYPE_OF_CHECK_1>
tags:
- key: <tag_key_1a>
value: <tag_value_1a>
- key: <tag_key_1b>
value: <tag_value_1b>
- match_type: <TYPE_OF_CHECK_2>
tags:
- key: <tag_key_2a>
value: <tag_value_2a>
# ... add more rules as needed
```
## Understanding `match_type`
This determines _how_ the list of `tags` in the rule is used to check a sound event.
There are four types (the sketch after this list illustrates the evaluation logic):
1. **`any`**: (Keep if _at least one_ tag matches)
- The sound event **passes** this rule if it has **at least one** of the tags listed in the `tags` section of the rule.
- Think of it as an **OR** condition.
- _Example Use Case:_ Keep events if they are tagged as `Species: Pip Pip` OR `Species: Pip Pyg`.
2. **`all`**: (Keep only if _all_ tags match)
- The sound event **passes** this rule only if it has **all** of the tags listed in the `tags` section.
The event can have _other_ tags as well, but it must contain _all_ the ones specified here.
- Think of it as an **AND** condition.
- _Example Use Case:_ Keep events only if they are tagged with `Sound Type: Echolocation` AND `Quality: Good`.
3. **`exclude`**: (Discard if _any_ tag matches)
- The sound event **passes** this rule only if it does **not** have **any** of the tags listed in the `tags` section.
If it matches even one tag in the list, the event is discarded.
- _Example Use Case:_ Discard events if they are tagged `Quality: Poor` OR `Noise Source: Insect`.
4. **`equal`**: (Keep only if tags match _exactly_)
- The sound event **passes** this rule only if its set of tags is _exactly identical_ to the list of `tags` provided in the rule (no more, no less).
- _Note:_ This is very strict and usually less useful than `all` or `any`.
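The following Python sketch shows one way the four `match_type` checks could be evaluated against an annotation's set of tags. It is purely illustrative of the semantics described above, not the batdetect2 implementation.
```python
def passes_rule(event_tags: set, rule_tags: set, match_type: str) -> bool:
    """Evaluate a single filter rule against an event's set of (key, value) tags."""
    if match_type == "any":
        return bool(event_tags & rule_tags)   # at least one shared tag
    if match_type == "all":
        return rule_tags <= event_tags        # all rule tags present
    if match_type == "exclude":
        return not (event_tags & rule_tags)   # no shared tags allowed
    if match_type == "equal":
        return event_tags == rule_tags        # exactly identical tag sets
    raise ValueError(f"Unknown match_type: {match_type}")


# An event is kept only if it passes every rule
event = {("Sound Type", "Echolocation"), ("Quality", "Good")}
rules = [
    ({("Sound Type", "Echolocation")}, "any"),
    ({("Quality", "Poor")}, "exclude"),
]
keep = all(passes_rule(event, tags, mt) for tags, mt in rules)
```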
## Combining Rules
Remember: A sound event must **pass every single rule** defined in the `rules` list to be kept.
The rules are checked one by one, and if an event fails any rule, it's immediately excluded from further consideration.
## Examples
**Example 1: Keep good quality echolocation calls**
```yaml
filtering:
rules:
# Rule 1: Must have the 'Echolocation' tag
- match_type: any # Could also use 'all' if 'Sound Type' is the only tag expected
tags:
- key: Sound Type
value: Echolocation
# Rule 2: Must NOT have the 'Poor' quality tag
- match_type: exclude
tags:
- key: Quality
value: Poor
```
_Explanation:_ An event is kept only if it passes BOTH rules.
It must have the `Sound Type: Echolocation` tag AND it must NOT have the `Quality: Poor` tag.
**Example 2: Keep calls from Pipistrellus species recorded in a specific project, excluding uncertain IDs**
```yaml
filtering:
rules:
# Rule 1: Must be either Pip pip or Pip pyg
- match_type: any
tags:
- key: Species
value: Pipistrellus pipistrellus
- key: Species
value: Pipistrellus pygmaeus
# Rule 2: Must belong to 'Project Alpha'
- match_type: any # Using 'any' as it likely only has one project tag
tags:
- key: Project ID
value: Project Alpha
# Rule 3: Exclude if ID Certainty is 'Low' or 'Maybe'
- match_type: exclude
tags:
- key: ID Certainty
value: Low
- key: ID Certainty
value: Maybe
```
_Explanation:_ An event is kept only if it passes ALL three rules:
1. It has a `Species` tag that is _either_ `Pipistrellus pipistrellus` OR `Pipistrellus pygmaeus`.
2. It has the `Project ID: Project Alpha` tag.
3. It does _not_ have an `ID Certainty: Low` tag AND it does _not_ have an `ID Certainty: Maybe` tag.
## Usage
You will typically specify the path to the configuration file containing these `filtering` rules when you set up your data processing or training pipeline in `batdetect2`.
The tool will then automatically load these rules and apply them to your annotated sound events.

View File

@ -0,0 +1,78 @@
# Defining Training Targets
A crucial aspect of training any supervised machine learning model, including BatDetect2, is clearly defining the **training targets**.
This process determines precisely what the model should learn to detect, localize, classify, and characterize from the input data (in this case, spectrograms).
The choices made here directly influence the model's focus, its performance, and how its predictions should be interpreted.
For BatDetect2, defining targets involves specifying:
- Which sounds in your annotated dataset are relevant for training.
- How these sounds should be categorized into distinct **classes** (e.g., different species).
- How the geometric **Region of Interest (ROI)** (e.g., bounding box) of each sound maps to the specific **position** and **size** targets the model predicts.
- How these classes and geometric properties relate back to the detailed information stored in your annotation **tags** (using a consistent **vocabulary/terms**).
- How the model's output (predicted class names, positions, sizes) should be translated back into meaningful tags and geometries.
## Sound Event Annotations: The Starting Point
BatDetect2 assumes your training data consists of audio recordings where relevant sound events have been **annotated**.
A typical annotation for a single sound event provides two key pieces of information:
1. **Location & Extent:** Information defining _where_ the sound occurs in time and frequency, usually represented as a **bounding box** (the ROI) drawn on a spectrogram.
2. **Description (Tags):** Information _about_ the sound event, provided as a set of descriptive **tags** (key-value pairs).
For example, an annotation might have a bounding box and tags like:
- `species: Myotis daubentonii`
- `quality: Good`
- `call_type: Echolocation`
A single sound event can have **multiple tags**, allowing for rich descriptions.
This richness requires a structured process to translate the annotation (both tags and geometry) into the precise targets needed for model training.
The **target definition process** provides clear rules to:
- Interpret the meaning of different tag keys (**Terms**).
- Select only the relevant annotations (**Filtering**).
- Potentially standardize or modify the tags (**Transforming**).
- Map the geometric ROI to specific position and size targets (**ROI Mapping**).
- Map the final set of tags on each selected annotation to a single, definitive **target class** label (**Classes**).
## Configuration-Driven Workflow
BatDetect2 is designed so that researchers can configure this entire target definition process primarily through **configuration files** (typically written in YAML format), minimizing the need for direct programming for standard workflows.
These settings are usually grouped under a main `targets:` key within your overall training configuration file.
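For orientation, the overall shape of such a `targets:` block is sketched below; the section names follow the examples used throughout these pages, so treat it as an illustrative skeleton rather than an exhaustive schema:
```yaml
# Skeleton of a target definition block in the main training configuration
targets:
  filtering:          # Step 2: which annotations to keep
    rules: []
  transform:          # Step 3 (optional): tag clean-up and derivation
    rules: []
  classes:            # Step 4: target class definitions and decoding rules
    classes: []
    generic_class: []
  roi:                # Step 5: reference point and size scaling
    position: "bottom-left"
    time_scale: 1000.0
    frequency_scale: 0.00116
```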
## The Target Definition Steps
Defining the targets involves several sequential steps, each configurable and building upon the previous one:
1. **Defining Vocabulary (Terms & Tags):** Understand how annotations use tags (key-value pairs).
This step involves defining the meaning (**Terms**) behind the tag keys (e.g., `species`, `call_type`).
Often, default terms are sufficient, but understanding this is key to using tags in later steps.
(See: {doc}`tags_and_terms`)
2. **Filtering Sound Events:** Select only the relevant sound event annotations based on their tags (e.g., keeping only high-quality calls).
(See: {doc}`filtering`)
3. **Transforming Tags (Optional):** Modify tags on selected annotations for standardization, correction, grouping (e.g., species to genus), or deriving new tags.
(See: {doc}`transform`)
4. **Defining Classes & Decoding Rules:** Map the final tags to specific target **class names** (like `pippip` or `myodau`).
Define priorities for overlap and specify how predicted names map back to tags (decoding).
(See: {doc}`classes`)
5. **Mapping ROIs (Position & Size):** Define how the geometric ROI (e.g., bounding box) of each sound event maps to the specific reference **point** (e.g., center, corner) and scaled **size** values (width, height) used as targets by the model.
(See: {doc}`rois`)
6. **The `Targets` Object:** Understand the outcome of configuring steps 1-5: a functional object used internally by BatDetect2 that encapsulates all your defined rules for filtering, transforming, ROI mapping, encoding, and decoding.
(See: {doc}`use`)
The result of this configuration process is a clear set of instructions that BatDetect2 uses during training data preparation to determine the correct "answer" (the ground truth label and geometry representation) for each relevant sound event.
Explore the detailed steps using the links below:
```{toctree}
:maxdepth: 1
:caption: Target Definition Steps:
tags_and_terms
filtering
transform
classes
rois
use
```

View File

@ -0,0 +1,76 @@
# Step 5: Generating Training Targets
## Purpose and Context
Following the previous steps of defining terms, filtering events, transforming tags, and defining specific class rules, this final stage focuses on **generating the ground truth data** used directly for training the BatDetect2 model.
This involves converting the refined annotation information for each audio clip into specific **heatmap formats** required by the underlying neural network architecture.
This step essentially translates your structured annotations into the precise "answer key" the model learns to replicate during training.
## What are Heatmaps?
Heatmaps, in this context, are multi-dimensional arrays, often visualized as images aligned with the input spectrogram, where the values at different time-frequency coordinates represent specific information about the sound events.
For BatDetect2 training, three primary heatmaps are generated:
1. **Detection Heatmap:**
- **Represents:** The presence or likelihood of relevant sound events across the spectrogram.
- **Structure:** A 2D array matching the spectrogram's time-frequency dimensions.
Peaks (typically smoothed) are generated at the reference locations of all sound events that passed the filtering stage (including both specifically classified events and those falling into the generic "Bat" category).
2. **Class Heatmap:**
- **Represents:** The location and class identity for sounds belonging to the _specific_ target classes you defined in Step 4.
- **Structure:** A 3D array with dimensions for category, time, and frequency.
It contains a separate 2D layer (channel) for each target class name (e.g., 'pippip', 'myodau').
A smoothed peak appears on a specific class layer only if a sound event assigned to that class exists at that location.
Events assigned only to the generic class do not produce peaks here.
3. **Size Heatmap:**
- **Represents:** The target dimensions (duration/width and bandwidth/height) of detected sound events.
- **Structure:** A 3D array with dimensions for size-dimension ('width', 'height'), time, and frequency.
At the reference location of each detected sound event, this heatmap stores two numerical values corresponding to the scaled width and height derived from the event's bounding box.
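To make these shapes concrete, a minimal NumPy sketch of allocating the three arrays is shown below (names and the exact dimension ordering are illustrative, not batdetect2's internal layout):
```python
import numpy as np

# Illustrative spectrogram dimensions and target class list
n_time, n_freq = 512, 128
class_names = ["pippip", "myodau"]

# Detection heatmap: a single 2D array over time and frequency
detection = np.zeros((n_time, n_freq))

# Class heatmap: one 2D layer per specific target class
class_heatmap = np.zeros((len(class_names), n_time, n_freq))

# Size heatmap: two layers, one for scaled width and one for scaled height
size_heatmap = np.zeros((2, n_time, n_freq))
```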
## How Heatmaps are Created
The generation of these heatmaps is an automated process within `batdetect2`, driven by your configurations from all previous steps.
For each audio clip and its corresponding spectrogram in the training dataset:
1. The system retrieves the associated sound event annotations.
2. Configured **filtering rules** (Step 2) are applied to select relevant annotations.
3. Configured **tag transformation rules** (Step 3) are applied to modify the tags of the selected annotations.
4. Configured **class definition rules** (Step 4) are used to assign a specific class name or determine generic "Bat" status for each processed annotation.
5. These final annotations are then mapped onto initialized heatmap arrays:
- A signal (initially a single point) is placed on the **Detection Heatmap** at the reference location for each relevant annotation.
- The scaled width and height values are placed on the **Size Heatmap** at the reference location.
- If an annotation received a specific class name, a signal is placed on the corresponding layer of the **Class Heatmap** at the reference location.
6. Finally, Gaussian smoothing (a blurring effect) is typically applied to the Detection and Class heatmaps to create spatially smoother targets, which often aids model training stability and performance.
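To illustrate steps 5 and 6 in isolation, here is a minimal sketch that places a single point for one annotation and then blurs it into a peak; `scipy.ndimage.gaussian_filter` is used here as a stand-in for whatever smoothing batdetect2 applies internally:
```python
import numpy as np
from scipy.ndimage import gaussian_filter

n_time, n_freq = 512, 128
detection = np.zeros((n_time, n_freq))

# Suppose the annotation's reference point maps to these array indices
t_idx, f_idx = 200, 40

# Step 5: place an initial point signal at the reference location
detection[t_idx, f_idx] = 1.0

# Step 6: smooth the point into a peak; sigma mirrors the `labelling` setting
detection = gaussian_filter(detection, sigma=3.0)
```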
## Configurable Settings for Heatmap Generation
While the content of the heatmaps is primarily determined by the previous configuration steps, a few parameters specific to the heatmap drawing process itself can be adjusted.
These are usually set in your main configuration file under a section like `labelling`:
- `sigma`: (Number, e.g., `3.0`) Defines the standard deviation, in pixels or bins, of the Gaussian kernel used for smoothing the Detection and Class heatmaps.
Larger values result in more diffused heatmap peaks.
- `position`: (Text, e.g., `"bottom-left"`, `"center"`) Specifies the geometric reference point within each sound event's bounding box that anchors its representation on the heatmaps.
- `time_scale` & `frequency_scale`: (Numbers) These crucial scaling factors convert the physical duration (in seconds) and frequency bandwidth (in Hz) of annotation bounding boxes into the numerical values stored in the 'width' and 'height' channels of the Size Heatmap.
- **Important Note:** The appropriate values for these scales are dictated by the requirements of the specific BatDetect2 model architecture being trained.
They ensure the size information is presented in the units or relative scale the model expects.
**Consult the documentation or tutorials for your specific model to determine the correct `time_scale` and `frequency_scale` values.** Mismatched scales can hinder the model's ability to learn size regression accurately.
**Example YAML Configuration for Labelling Settings:**
```yaml
# In your main configuration file
labelling:
sigma: 3.0 # Std. dev. for Gaussian smoothing (pixels/bins)
position: "bottom-left" # Bounding box reference point
time_scale: 1000.0 # Example: Scales seconds to milliseconds
frequency_scale: 0.00116 # Example: Scales Hz relative to ~860 Hz (model specific!)
```
## Outcome: Final Training Targets
Executing this step for all training data yields the complete set of target heatmaps (Detection, Class, Size) for each corresponding input spectrogram.
These arrays constitute the ground truth data that the BatDetect2 model directly compares its predictions against during the training phase, guiding its learning process.

View File

@ -0,0 +1,85 @@
## Defining Target Geometry: Mapping Sound Event Regions
### Introduction
In the previous steps of defining targets, we focused on determining _which_ sound events are relevant (`filtering`), _what_ descriptive tags they should have (`transform`), and _which category_ they belong to (`classes`).
However, for the model to learn effectively, it also needs to know **where** in the spectrogram each sound event is located and approximately **how large** it is.
Your annotations typically define the location and extent of a sound event using a **Region of Interest (ROI)**, most commonly a **bounding box** drawn around the call on the spectrogram.
This ROI contains detailed spatial information (start/end time, low/high frequency).
This section explains how BatDetect2 converts the geometric ROI from your annotations into the specific positional and size information used as targets during model training.
### From ROI to Model Targets: Position & Size
BatDetect2 does not directly predict a full bounding box.
Instead, it is trained to predict:
1. **A Reference Point:** A single point `(time, frequency)` that represents the primary location of the detected sound event within the spectrogram.
2. **Size Dimensions:** Numerical values representing the event's size relative to that reference point, typically its `width` (duration in time) and `height` (bandwidth in frequency).
This step defines _how_ BatDetect2 calculates this specific reference point and these numerical size values from the original annotation's bounding box.
It also handles the reverse process: converting predicted positions and sizes back into bounding boxes for visualization or analysis.
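As a sketch of the arithmetic involved (illustrative only, not the actual batdetect2 implementation), a bounding box can be reduced to a reference point and scaled size values like this:
```python
# Hypothetical bounding box: start time (s), low freq (Hz), end time (s), high freq (Hz)
start_time, low_freq, end_time, high_freq = 0.100, 45_000.0, 0.104, 78_000.0

time_scale = 1000.0        # converts seconds to the scaled width target
frequency_scale = 0.00116  # converts Hz to the scaled height target

# "bottom-left" reference point: earliest time, lowest frequency
position = (start_time, low_freq)

# Scaled size values the model is trained to predict
width = (end_time - start_time) * time_scale        # 4.0
height = (high_freq - low_freq) * frequency_scale   # ~38.3
```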
### Configuring the ROI Mapping
You can control how this conversion happens through settings in your configuration file (e.g., your main `.yaml` file).
These settings are usually placed within the main `targets:` configuration block, under a specific `roi:` key.
Here are the key settings:
- **`position`**:
- **What it does:** Determines which specific point on the annotation's bounding box is used as the single **Reference Point** for training (e.g., `"center"`, `"bottom-left"`).
- **Why configure it?** This affects where the peak signal appears in the target heatmaps used for training.
Different choices might slightly influence model learning.
The default (`"bottom-left"`) is often a good starting point.
- **Example Value:** `position: "center"`
- **`time_scale`**:
- **What it does:** This is a numerical scaling factor that converts the _actual duration_ (width, measured in seconds) of the bounding box into the numerical 'width' value the model learns to predict (and which is stored in the Size Heatmap).
- **Why configure it?** The model predicts raw numbers for size; this scale gives those numbers meaning.
For example, setting `time_scale: 1000.0` means the model will be trained to predict the duration in **milliseconds** instead of seconds.
- **Important Considerations:**
- You can often set this value based on the units you prefer the model to work with internally.
However, having target numerical values roughly centered around 1 (e.g., typically between 0.1 and 10) can sometimes improve numerical stability during model training.
- The default value in BatDetect2 (e.g., `1000.0`) has been chosen to scale the duration relative to the spectrogram width under default STFT settings.
Be aware that if you significantly change STFT parameters (window size or overlap), the relationship between the default scale and the spectrogram dimensions might change.
- Crucially, whatever scale you use during training **must** be used when decoding the model's predictions back into real-world time units (seconds).
BatDetect2 generally handles this consistency for you automatically when using the full pipeline.
- **Example Value:** `time_scale: 1000.0`
- **`frequency_scale`**:
- **What it does:** Similar to `time_scale`, this numerical scaling factor converts the _actual frequency bandwidth_ (height, typically measured in Hz or kHz) of the bounding box into the numerical 'height' value the model learns to predict.
- **Why configure it?** It gives physical meaning to the model's raw numerical prediction for bandwidth and allows you to choose the internal units or scale.
- **Important Considerations:**
- Same as for `time_scale`.
- **Example Value:** `frequency_scale: 0.00116`
**Example YAML Configuration:**
```yaml
# Inside your main configuration file (e.g., training_config.yaml)
targets: # Top-level key for target definition
# ... filtering settings ...
# ... transforms settings ...
# ... classes settings ...
# --- ROI Mapping Settings ---
roi:
position: "bottom-left" # Reference point (e.g., "center", "bottom-left")
time_scale: 1000.0 # e.g., Model predicts width in ms
frequency_scale: 0.00116 # e.g., Model predicts height relative to ~860Hz (or other model-specific scaling)
```
### Decoding Size Predictions
These scaling factors (`time_scale`, `frequency_scale`) are also essential for interpreting the model's output correctly.
When the model predicts numerical values for width and height, BatDetect2 uses these same scales (in reverse) to convert those numbers back into physically meaningful durations (seconds) and bandwidths (Hz/kHz) when reconstructing bounding boxes from predictions.
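Continuing the sketch above, the decoding direction simply divides by the same factors (again purely illustrative arithmetic):
```python
# Values predicted by the model at some reference point (time in s, frequency in Hz)
pred_time, pred_freq = 0.100, 45_000.0
pred_width, pred_height = 4.0, 38.28

time_scale = 1000.0
frequency_scale = 0.00116

# Recover the physical extent and rebuild a box anchored at the
# "bottom-left" reference point
duration = pred_width / time_scale          # 0.004 s
bandwidth = pred_height / frequency_scale   # ~33_000 Hz
bounding_box = (pred_time, pred_freq, pred_time + duration, pred_freq + bandwidth)
```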
### Outcome
By configuring the `roi` settings, you ensure that BatDetect2 consistently translates the geometric information from your annotations into the specific reference points and scaled size values required for training the model.
Using consistent scales that are appropriate for your data and potentially beneficial for training stability allows the model to effectively learn not just _what_ sound is present, but also _where_ it is located and _how large_ it is, and enables meaningful interpretation of the model's spatial and size predictions.

View File

@ -0,0 +1,166 @@
# Step 1: Managing Annotation Vocabulary
## Purpose
To train `batdetect2`, you will need sound events that have been carefully annotated. We annotate sound events using **tags**. A tag is simply a piece of information attached to an annotation, often describing what the sound is or its characteristics. Common examples include `Species: Myotis daubentonii` or `Quality: Good`.
Each tag fundamentally has two parts:
* **Value:** The specific information (e.g., "Myotis daubentonii", "Good").
* **Term:** The *type* of information (e.g., "Species", "Quality"). This defines the context or meaning of the value.
We use this flexible **Term: Value** approach because it allows you to annotate your data with any kind of information relevant to your project, while still providing a structure that makes the meaning clear.
While simple terms like "Species" are easy to understand, sometimes the underlying definition needs to be more precise to ensure everyone interprets it the same way (e.g., using a standard scientific definition for "Species" or clarifying what "Call Type" specifically refers to).
This `terms` module is designed to help manage these definitions effectively:
1. It provides **standard definitions** for common terms used in bioacoustics, ensuring consistency.
2. It lets you **define your own custom terms** if you need concepts specific to your project.
3. Crucially, it allows you to use simple **"keys"** (like shortcuts) in your configuration files to refer to these potentially complex term definitions, making configuration much easier and less error-prone.
## The Problem: Why We Need Defined Terms
Imagine you have a tag that simply says `"Myomyo"`.
If you created this tag, you might know it's a shortcut for the species _Myotis myotis_.
But what about someone else using your data or model later? Does `"Myomyo"` refer to the species? Or maybe it's the name of an individual bat, or even the location where it was recorded? Simple tags like this can be ambiguous.
To make things clearer, it's good practice to provide context.
We can do this by pairing the specific information (the **Value**) with the _type_ of information (the **Term**).
For example, writing the tag as `species: Myomyo` is much less ambiguous.
Here, `species` is the **Term**, explaining that `Myomyo` is a **Value** representing a species.
However, another challenge often comes up when sharing data or collaborating.
You might use the term `species`, while a colleague uses `Species`, and someone else uses the more formal `Scientific Name`.
Even though you all mean the same thing, these inconsistencies make it hard to combine data or reuse analysis pipelines automatically.
This is where standardized **Terms** become very helpful.
Several groups work to create standard definitions for common concepts.
For instance, the Darwin Core standard provides widely accepted terms for biological data, like `dwc:scientificName` for a species name.
Using standard Terms whenever possible makes your data clearer, easier for others (and machines!) to understand correctly, and much more reusable across different projects.
**But here's the practical problem:** While using standard, well-defined Terms is important for clarity and reusability, writing out full definitions or long standard names (like `dwc:scientificName` or "Scientific Name according to Darwin Core standard") every single time you need to refer to a species tag in a configuration file would be extremely tedious and prone to typing errors.
## The Solution: Keys (Shortcuts) and the Registry
This module uses a central **Registry** that stores the full definitions of various Terms.
Each Term in the registry is assigned a unique, short **key** (a simple string).
Think of the **key** as a shortcut.
Instead of using the full Term definition in your configuration files, you just use its **key**.
The system automatically looks up the full definition in the registry using the key when needed.
**Example:**
- **Full Term Definition:** Represents the scientific name of the organism.
- **Key:** `species`
- **In Config:** You just write `species`.
## Available Keys
The registry comes pre-loaded with keys for many standard terms used in bioacoustics, including those from the `soundevent` package and some specific to `batdetect2`. This means you can often use these common concepts without needing to define them yourself.
Common examples of pre-defined keys might include:
* `species`: For scientific species names (e.g., *Myotis daubentonii*).
* `common_name`: For the common name of a species (e.g., "Daubenton's bat").
* `genus`, `family`, `order`: For higher levels of biological taxonomy.
* `call_type`: For functional call types (e.g., 'Echolocation', 'Social').
* `individual`: For identifying specific individuals if tracked.
* `class`: **(Special Key)** This key is often used **by default** in configurations when defining the target classes for your model (e.g., the different species you want the model to classify). If you are specifying a tag that represents a target class label, you often only need to provide the `value`, and the system assumes the `key` is `class`.
This is not an exhaustive list. To discover all the term keys currently available in the registry (including any standard ones loaded automatically and any custom ones you've added in your configuration), you can:
1. Use the function `batdetect2.terms.get_term_keys()` if you are working directly with Python code.
2. Refer to the main `batdetect2` API documentation for a list of commonly included standard terms.
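For example, listing the registered keys from Python might look like this (a small sketch built around the `get_term_keys` function mentioned above; the exact return type is an assumption):
```python
from batdetect2.terms import get_term_keys

# Print every term key currently available in the registry,
# including custom terms loaded from your configuration.
for key in get_term_keys():
    print(key)
```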
## Defining Your Own Terms
While many common terms have pre-defined keys, you might need a term specific to your project or data that isn't already available (e.g., "Recording Setup", "Weather Condition", "Project Phase", "Noise Source"). You can easily define these custom terms directly within a configuration file (usually your main `.yaml` file).
Typically, you define custom terms under a dedicated section (often named `terms`). Inside this section, you create a list, where each item in the list defines one new term using the following fields:
* `key`: **(Required)** This is the unique shortcut key or nickname you will use to refer to this term throughout your configuration (e.g., `weather`, `setup_id`, `noise_src`). Choose something short and memorable.
* `label`: (Optional) A user-friendly label for the term, which might be used in reports or visualizations (e.g., "Weather Condition", "Setup ID"). If you don't provide one, it defaults to using the `key`.
* `name`: (Optional) A more formal or technical name for the term.
* It's good practice, especially if defining terms that might overlap with standard vocabularies, to use a **namespaced format** like `<namespace>:<term_name>`. The `namespace` part helps avoid clashes with terms defined elsewhere. For example, the standard Darwin Core term for scientific name is `dwc:scientificName`, where `dwc` is the namespace for Darwin Core. Using namespaces makes your custom terms more specific and reduces potential confusion.
* If you don't provide a `name`, it defaults to using the `key`.
* `definition`: (Optional) A brief text description explaining what this term represents (e.g., "The primary source of background noise identified", "General weather conditions during recording"). If omitted, it defaults to "Unknown".
* `uri`: (Optional) If your term definition comes directly from a standard online vocabulary (like Darwin Core), you can include its unique web identifier (URI) here.
**Example YAML Configuration for Custom Terms:**
```yaml
# In your main configuration file
# (Optional section to define custom terms)
terms:
- key: weather # Your chosen shortcut
label: Weather Condition
name: myproj:weather # Formal namespaced name
definition: General weather conditions during recording (e.g., Clear, Rain, Fog).
- key: setup_id # Another shortcut
label: Recording Setup ID
name: myproj:setupID # Formal namespaced name
definition: The unique identifier for the specific hardware setup used.
- key: species # Defining a term with a standard URI
label: Scientific Name
name: dwc:scientificName
uri: http://rs.tdwg.org/dwc/terms/scientificName # Example URI
definition: The full scientific name according to Darwin Core.
# ... other configuration sections ...
```
When `batdetect2` loads your configuration, it reads this `terms` section and adds your custom definitions (linked to their unique keys) to the central registry. These keys (`weather`, `setup_id`, etc.) are then ready to be used in other parts of your configuration, like defining filters or target classes.
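If you prefer to register terms programmatically, a rough equivalent using the `TermRegistry` exposed by `batdetect2.targets` could look like the following (only the key argument is shown; how labels and definitions are passed here is an assumption):
```python
from batdetect2.targets import TermRegistry

# Build a registry and register the custom keys used in the YAML above
registry = TermRegistry()
registry.add_custom_term("weather")
registry.add_custom_term("setup_id")
```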
## Using Keys to Specify Tags (in Filters, Class Definitions, etc.)
Now that you have keys for all the terms you need (both pre-defined and custom), you can easily refer to specific **tags** in other parts of your configuration, such as:
- Filtering rules (as seen in the `filtering` module documentation).
- Defining which tags represent your target classes.
- Associating extra information with your classes.
When you need to specify a tag, you typically use a structure with two fields:
- `key`: The **key** (shortcut) for the _Term_ part of the tag (e.g., `species`, `quality`, `weather`).
**It defaults to `class`** if you omit it, which is common when defining the main target classes.
- `value`: The specific _value_ of the tag (e.g., `Myotis daubentonii`, `Good`, `Rain`).
**Example YAML Configuration using TagInfo (e.g., inside a filter rule):**
```yaml
# ... inside a filtering configuration section ...
rules:
# Rule: Exclude events recorded in 'Rain'
- match_type: exclude
tags:
- key: weather # Use the custom term key defined earlier
value: Rain
# Rule: Keep only 'Myotis daubentonii' (using the default 'class' key implicitly)
- match_type: any # Or 'all' depending on logic
tags:
- value: Myotis daubentonii # 'key: class' is assumed by default here
# key: class # Explicitly writing this is also fine
# Rule: Keep only 'Good' quality events
- match_type: any # Or 'all' depending on logic
tags:
- key: quality # Use a likely pre-defined key
value: Good
```
## Summary
- Annotations have **tags** (Term + Value).
- This module uses short **keys** as shortcuts for Term definitions, stored in a **registry**.
- Many **common keys are pre-defined**.
- You can define **custom terms and keys** in your configuration file (using `key`, `label`, `definition`).
- You use these **keys** along with specific **values** to refer to tags in other configuration sections (like filters or class definitions), often defaulting to the `class` key.
This system makes your configurations cleaner, more readable, and less prone to errors by avoiding repetition of complex term definitions.

View File

@ -0,0 +1,118 @@
# Step 3: Transforming Annotation Tags (Optional)
## Purpose and Context
After defining your vocabulary (Step 1: Terms) and filtering out irrelevant sound events (Step 2: Filtering), you have a dataset of annotations ready for the next stages.
Before you select the final target classes for training (Step 4), you might want or need to **modify the tags** associated with your annotations.
This optional step allows you to clean up, standardize, or derive new information from your existing tags.
**Why transform tags?**
- **Correcting Mistakes:** Fix typos or incorrect values in specific tags (e.g., changing an incorrect species label).
- **Standardizing Labels:** Ensure consistency if the same information was tagged using slightly different values (e.g., mapping "echolocation", "Echoloc.", and "Echolocation Call" all to a single standard value: "Echolocation").
- **Grouping Related Concepts:** Combine different specific tags into a broader category (e.g., mapping several different species tags like _Myotis daubentonii_ and _Myotis nattereri_ to a single `genus: Myotis` tag).
- **Deriving New Information:** Automatically create new tags based on existing ones (e.g., automatically generating a `genus: Myotis` tag whenever a `species: Myotis daubentonii` tag is present).
This step uses the `batdetect2.targets.transform` module to apply these changes based on rules you define.
## How it Works: Transformation Rules
You control how tags are transformed by defining a list of **rules** in your configuration file (e.g., your main `.yaml` file, often under a section named `transform`).
Each rule specifies a particular type of transformation to perform.
Importantly, the rules are applied **sequentially**, in the exact order they appear in your configuration list.
The output annotation from one rule becomes the input for the next rule in the list.
This means the order can matter!
## Types of Transformation Rules
Here are the main types of rules you can define:
1. **Replace an Exact Tag (`replace`)**
- **Use Case:** Fixing a specific, known incorrect tag.
- **How it works:** You specify the _exact_ original tag (both its term key and value) and the _exact_ tag you want to replace it with.
- **Example Config:** Replace the informal tag `species: Pip pip` with the correct scientific name tag.
```yaml
transform:
rules:
- rule_type: replace
original:
key: species # Term key of the tag to find
value: "Pip pip" # Value of the tag to find
replacement:
key: species # Term key of the replacement tag
value: "Pipistrellus pipistrellus" # Value of the replacement tag
```
2. **Map Values (`map_value`)**
- **Use Case:** Standardizing different values used for the same concept, or grouping multiple specific values into one category.
- **How it works:** You specify a `source_term_key` (the type of tag to look at, e.g., `call_type`).
Then you provide a `value_mapping` dictionary listing original values and the new values they should be mapped to.
Only tags matching the `source_term_key` and having a value listed in the mapping will be changed.
You can optionally specify a `target_term_key` if you want to change the term type as well (e.g., mapping species to a genus).
- **Example Config:** Standardize different ways "Echolocation" might have been written for the `call_type` term.
```yaml
transform:
rules:
- rule_type: map_value
source_term_key: call_type # Look at 'call_type' tags
# target_term_key is not specified, so the term stays 'call_type'
value_mapping:
echolocation: Echolocation
Echolocation Call: Echolocation
Echoloc.: Echolocation
# Add mappings for other values like 'Social' if needed
```
- **Example Config (Grouping):** Map specific Pipistrellus species tags to a single `genus: Pipistrellus` tag.
```yaml
transform:
rules:
- rule_type: map_value
source_term_key: species # Look at 'species' tags
target_term_key: genus # Change the term to 'genus'
value_mapping:
"Pipistrellus pipistrellus": Pipistrellus
"Pipistrellus pygmaeus": Pipistrellus
"Pipistrellus nathusii": Pipistrellus
```
3. **Derive a New Tag (`derive_tag`)**
- **Use Case:** Automatically creating new information based on existing tags, like getting the genus from a species name.
- **How it works:** You specify a `source_term_key` (e.g., `species`).
You provide a `target_term_key` for the new tag to be created (e.g., `genus`).
You also provide the name of a `derivation_function` (e.g., `"extract_genus"`) that knows how to perform the calculation (e.g., take "Myotis daubentonii" and return "Myotis").
`batdetect2` has some built-in functions, or you can potentially define your own (see advanced documentation).
You can also choose whether to keep the original source tag (`keep_source: true`).
- **Example Config:** Create a `genus` tag from the existing `species` tag, keeping the species tag.
```yaml
transform:
rules:
- rule_type: derive_tag
source_term_key: species # Use the value from the 'species' tag
target_term_key: genus # Create a tag with the 'genus' term
derivation_function: extract_genus # Use the built-in function for this
keep_source: true # Keep the original 'species' tag
```
- **Another Example:** Convert species names to uppercase (modifying the value of the _same_ term).
```yaml
transform:
rules:
- rule_type: derive_tag
source_term_key: species # Use the value from the 'species' tag
# target_term_key is not specified, so the term stays 'species'
derivation_function: to_upper_case # Assume this function exists
keep_source: false # Replace the original species tag
```
## Rule Order Matters
Remember that rules are applied one after another.
If you have multiple rules, make sure they are ordered correctly to achieve the desired outcome.
For instance, you might want to standardize species names _before_ deriving the genus from them.
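Putting that advice into a single configuration, a `transform` section that standardizes a species name and then derives the genus from it could look like this (a sketch combining the rule types shown above):
```yaml
transform:
  rules:
    # 1. Standardize an informal species label first
    - rule_type: replace
      original:
        key: species
        value: "Pip pip"
      replacement:
        key: species
        value: "Pipistrellus pipistrellus"
    # 2. Then derive the genus from the (now standardized) species tag
    - rule_type: derive_tag
      source_term_key: species
      target_term_key: genus
      derivation_function: extract_genus
      keep_source: true
```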
## Outcome
After applying all the transformation rules you've defined, the annotations will proceed to the next step (Step 4: Select Target Tags & Define Classes) with their tags potentially cleaned, standardized, or augmented based on your configuration.
If you don't define any rules, the tags simply pass through this step unchanged.

View File

@ -0,0 +1,91 @@
## Bringing It All Together: The `Targets` Object
### Recap: Defining Your Target Strategy
In the previous sections, we covered the sequential steps to precisely define what your BatDetect2 model should learn, specified within your configuration file:
1. **Terms:** Establishing the vocabulary for annotation tags.
2. **Filtering:** Selecting relevant sound event annotations.
3. **Transforming:** Optionally modifying tags.
4. **Classes:** Defining target categories, setting priorities, and specifying tag decoding rules.
5. **ROI Mapping:** Defining how annotation geometry maps to target position and size values.
You define all these aspects within your configuration file (e.g., YAML), which holds the complete specification for your target definition strategy, typically under a main `targets:` key.
### What is the `Targets` Object?
While the configuration file specifies _what_ you want to happen, BatDetect2 needs an active component to actually _perform_ these steps.
This is the role of the `Targets` object.
The `Targets` object is an organized container that holds all the specific functions and settings derived from your configuration file (`TargetConfig`).
It's created directly from your configuration and provides methods to apply the **filtering**, **transformation**, **ROI mapping** (geometry to position/size and back), **class encoding**, and **class decoding** steps you defined.
It effectively bundles together all the target definition logic determined by your settings into a single, usable object.
### How is it Created and Used?
For most standard training workflows, you typically won't need to create or interact with the `Targets` object directly in Python code.
BatDetect2 usually handles its creation automatically when you provide your main configuration file during training setup.
Conceptually, here's what happens behind the scenes:
1. You provide the path to your configuration file (e.g., `my_training_config.yaml`).
2. BatDetect2 reads this file and finds your `targets:` configuration section.
3. It uses this configuration to build an instance of the `Targets` object using a dedicated function (like `load_targets`), loading it with the appropriate logic based on your settings.
```python
# Conceptual Example: How BatDetect2 might use your configuration
from batdetect2.targets import load_targets # The function to load/build the object
from batdetect2.targets.types import TargetProtocol # The type/interface
# You provide this path, usually as part of the main training setup
target_config_file = "path/to/your/target_config.yaml"
# --- BatDetect2 Internally Does Something Like This: ---
# Loads your config and builds the Targets object using the loader function
# The resulting object adheres to the TargetProtocol interface
targets_processor: TargetProtocol = load_targets(target_config_file)
# ---------------------------------------------------------
# Now, 'targets_processor' holds all your configured logic and is ready
# to be used internally by the training pipeline or for prediction processing.
```
### What Does the `Targets` Object Do? (Its Role)
Once created, the `targets_processor` object plays several vital roles within the BatDetect2 system:
1. **Preparing Training Data:** During the data loading and label generation phase of training, BatDetect2 uses this object to process each annotation from your dataset _before_ the final training format (e.g., heatmaps) is generated.
For each annotation, it internally applies the logic:
- `targets_processor.filter(...)`: To decide whether to keep the annotation.
- `targets_processor.transform(...)`: To apply any tag modifications.
- `targets_processor.encode(...)`: To get the final class name (e.g., `'pippip'`, `'myodau'`, or `None` for the generic class).
- `targets_processor.get_position(...)`: To determine the reference `(time, frequency)` point from the annotation's geometry.
- `targets_processor.get_size(...)`: To calculate the _scaled_ width and height target values from the annotation's geometry.
2. **Interpreting Model Predictions:** When you use a trained model, its raw outputs (like predicted class names, positions, and sizes) need to be translated back into meaningful results.
This object provides the necessary decoding logic:
- `targets_processor.decode(...)`: Converts a predicted class name back into representative annotation tags.
- `targets_processor.recover_roi(...)`: Converts a predicted position and _scaled_ size values back into an estimated geometric bounding box in real-world coordinates (seconds, Hz).
- `targets_processor.generic_class_tags`: Provides the tags for sounds classified into the generic category.
3. **Providing Metadata:** It conveniently holds useful information derived from your configuration:
- `targets_processor.class_names`: The final list of specific target class names.
- `targets_processor.generic_class_tags`: The tags representing the generic class.
- `targets_processor.dimension_names`: The names used for the size dimensions (e.g., `['width', 'height']`).
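A conceptual sketch of how these calls fit together for a single annotation is shown below; the method names follow the list above, while the argument and return types are simplified assumptions:
```python
# Conceptual sketch only - the real training pipeline invokes these internally.
def label_annotation(targets_processor, annotation):
    # 1. Keep or drop the annotation according to the filtering rules
    if not targets_processor.filter(annotation):
        return None

    # 2. Apply any configured tag transformations
    annotation = targets_processor.transform(annotation)

    # 3. Encode the class name ('pippip', 'myodau', ... or None for the generic class)
    class_name = targets_processor.encode(annotation)

    # 4. Map the geometry to a reference point and scaled size targets
    position = targets_processor.get_position(annotation)
    size = targets_processor.get_size(annotation)

    return class_name, position, size
```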
### Why is Understanding This Important?
As a researcher using BatDetect2, your primary interaction is typically through the **configuration file**.
The `Targets` object is the component that materializes your configurations.
Understanding its role can be important:
- It helps connect the settings in your configuration file (covering terms, filtering, transforms, classes, and ROIs) to the actual behavior observed during training or when interpreting model outputs.
If the results aren't as expected (e.g., wrong classifications, incorrect bounding box predictions), reviewing the relevant sections of your `TargetConfig` is the first step in debugging.
- Furthermore, understanding this structure is beneficial if you plan to create custom Python scripts.
While standard training runs handle this object internally, the underlying functions for filtering, transforming, encoding, decoding, and ROI mapping are accessible or can be built individually.
This modular design provides the **flexibility to use or customize specific parts of the target definition workflow programmatically** for advanced analyses, integration tasks, or specialized data processing pipelines, should you need to go beyond the standard configuration-driven approach.
### Summary
The `Targets` object encapsulates the entire configured target definition logic specified in your `TargetConfig` file.
It acts as the central component within BatDetect2 for applying filtering, tag transformation, ROI mapping (geometry to/from position/size), class encoding (for training preparation), and class/ROI decoding (for interpreting predictions).
It bridges the gap between your declarative configuration and the functional steps needed for training and using BatDetect2 models effectively, while also offering components for more advanced, scripted workflows.

example_conf.yaml
View File

@ -0,0 +1,10 @@
datasets:
train:
name: example dataset
description: Only for demonstration purposes
sources:
- format: batdetect2
name: Example Data
description: Examples included for testing batdetect2
annotations_dir: example_data/anns
audio_dir: example_data/audio

View File

@ -0,0 +1,177 @@
{
"annotated": true,
"annotation": [
{
"class": "Myotis mystacinus",
"class_prob": 0.55,
"det_prob": 0.658,
"end_time": 0.028,
"event": "Echolocation",
"high_freq": 107492,
"individual": "-1",
"low_freq": 33203,
"start_time": 0.0225
},
{
"class": "Myotis mystacinus",
"class_prob": 0.679,
"det_prob": 0.742,
"end_time": 0.0583,
"event": "Echolocation",
"high_freq": 113192,
"individual": "-1",
"low_freq": 28046,
"start_time": 0.0525
},
{
"class": "Myotis mystacinus",
"class_prob": 0.488,
"det_prob": 0.585,
"end_time": 0.1211,
"event": "Echolocation",
"high_freq": 107008,
"individual": "-1",
"low_freq": 33203,
"start_time": 0.1155
},
{
"class": "Pipistrellus pipistrellus",
"class_prob": 0.46,
"det_prob": 0.503,
"end_time": 0.145,
"event": "Echolocation",
"high_freq": 59621,
"individual": "-1",
"low_freq": 48671,
"start_time": 0.1385
},
{
"class": "Myotis mystacinus",
"class_prob": 0.656,
"det_prob": 0.704,
"end_time": 0.1513,
"event": "Echolocation",
"high_freq": 113493,
"individual": "-1",
"low_freq": 27187,
"start_time": 0.1445
},
{
"class": "Myotis mystacinus",
"class_prob": 0.549,
"det_prob": 0.63,
"end_time": 0.2076,
"event": "Echolocation",
"high_freq": 108573,
"individual": "-1",
"low_freq": 34062,
"start_time": 0.2025
},
{
"class": "Pipistrellus pipistrellus",
"class_prob": 0.503,
"det_prob": 0.528,
"end_time": 0.224,
"event": "Echolocation",
"high_freq": 57361,
"individual": "-1",
"low_freq": 48671,
"start_time": 0.2195
},
{
"class": "Myotis mystacinus",
"class_prob": 0.672,
"det_prob": 0.737,
"end_time": 0.2374,
"event": "Echolocation",
"high_freq": 116415,
"individual": "-1",
"low_freq": 27187,
"start_time": 0.2315
},
{
"class": "Pipistrellus pipistrellus",
"class_prob": 0.65,
"det_prob": 0.736,
"end_time": 0.3058,
"event": "Echolocation",
"high_freq": 56624,
"individual": "-1",
"low_freq": 48671,
"start_time": 0.2995
},
{
"class": "Myotis mystacinus",
"class_prob": 0.687,
"det_prob": 0.724,
"end_time": 0.3312,
"event": "Echolocation",
"high_freq": 116522,
"individual": "-1",
"low_freq": 27187,
"start_time": 0.3245
},
{
"class": "Myotis mystacinus",
"class_prob": 0.547,
"det_prob": 0.599,
"end_time": 0.3762,
"event": "Echolocation",
"high_freq": 108530,
"individual": "-1",
"low_freq": 34062,
"start_time": 0.3705
},
{
"class": "Myotis mystacinus",
"class_prob": 0.664,
"det_prob": 0.711,
"end_time": 0.4184,
"event": "Echolocation",
"high_freq": 115775,
"individual": "-1",
"low_freq": 28906,
"start_time": 0.4125
},
{
"class": "Myotis mystacinus",
"class_prob": 0.544,
"det_prob": 0.598,
"end_time": 0.4423,
"event": "Echolocation",
"high_freq": 104197,
"individual": "-1",
"low_freq": 36640,
"start_time": 0.4365
},
{
"class": "Pipistrellus pipistrellus",
"class_prob": 0.73,
"det_prob": 0.78,
"end_time": 0.4803,
"event": "Echolocation",
"high_freq": 58290,
"individual": "-1",
"low_freq": 48671,
"start_time": 0.4745
},
{
"class": "Myotis mystacinus",
"class_prob": 0.404,
"det_prob": 0.449,
"end_time": 0.4947,
"event": "Echolocation",
"high_freq": 111336,
"individual": "-1",
"low_freq": 36640,
"start_time": 0.4895
}
],
"class_name": "Myotis mystacinus",
"duration": 0.5,
"id": "20170701_213954-MYOMYS-LR_0_0.5.wav",
"issues": false,
"notes": "Automatically generated. Example data do not assume correct!",
"time_exp": 1
}

View File

@ -0,0 +1,231 @@
{
"annotated": true,
"annotation": [
{
"class": "Eptesicus serotinus",
"class_prob": 0.744,
"det_prob": 0.77,
"end_time": 0.0162,
"event": "Echolocation",
"high_freq": 65592,
"individual": "-1",
"low_freq": 27187,
"start_time": 0.0085
},
{
"class": "Pipistrellus pipistrellus",
"class_prob": 0.453,
"det_prob": 0.459,
"end_time": 0.0255,
"event": "Echolocation",
"high_freq": 59730,
"individual": "-1",
"low_freq": 46093,
"start_time": 0.0205
},
{
"class": "Pipistrellus pipistrellus",
"class_prob": 0.668,
"det_prob": 0.68,
"end_time": 0.0499,
"event": "Echolocation",
"high_freq": 57080,
"individual": "-1",
"low_freq": 46953,
"start_time": 0.0445
},
{
"class": "Pipistrellus pipistrellus",
"class_prob": 0.729,
"det_prob": 0.739,
"end_time": 0.109,
"event": "Echolocation",
"high_freq": 62808,
"individual": "-1",
"low_freq": 44375,
"start_time": 0.1025
},
{
"class": "Pipistrellus pipistrellus",
"class_prob": 0.591,
"det_prob": 0.602,
"end_time": 0.1311,
"event": "Echolocation",
"high_freq": 56848,
"individual": "-1",
"low_freq": 46953,
"start_time": 0.1255
},
{
"class": "Eptesicus serotinus",
"class_prob": 0.696,
"det_prob": 0.735,
"end_time": 0.1694,
"event": "Echolocation",
"high_freq": 67238,
"individual": "-1",
"low_freq": 28046,
"start_time": 0.1625
},
{
"class": "Pipistrellus pipistrellus",
"class_prob": 0.617,
"det_prob": 0.643,
"end_time": 0.2031,
"event": "Echolocation",
"high_freq": 57047,
"individual": "-1",
"low_freq": 46093,
"start_time": 0.1975
},
{
"class": "Pipistrellus pipistrellus",
"class_prob": 0.507,
"det_prob": 0.515,
"end_time": 0.2222,
"event": "Echolocation",
"high_freq": 58214,
"individual": "-1",
"low_freq": 47812,
"start_time": 0.2175
},
{
"class": "Eptesicus serotinus",
"class_prob": 0.201,
"det_prob": 0.372,
"end_time": 0.2839,
"event": "Echolocation",
"high_freq": 55667,
"individual": "-1",
"low_freq": 33203,
"start_time": 0.2775
},
{
"class": "Pipistrellus pipistrellus",
"class_prob": 0.749,
"det_prob": 0.78,
"end_time": 0.2918,
"event": "Echolocation",
"high_freq": 60611,
"individual": "-1",
"low_freq": 45234,
"start_time": 0.2855
},
{
"class": "Eptesicus serotinus",
"class_prob": 0.239,
"det_prob": 0.325,
"end_time": 0.3148,
"event": "Echolocation",
"high_freq": 54100,
"individual": "-1",
"low_freq": 30625,
"start_time": 0.3085
},
{
"class": "Eptesicus serotinus",
"class_prob": 0.621,
"det_prob": 0.652,
"end_time": 0.3227,
"event": "Echolocation",
"high_freq": 63504,
"individual": "-1",
"low_freq": 27187,
"start_time": 0.3155
},
{
"class": "Eptesicus serotinus",
"class_prob": 0.32,
"det_prob": 0.414,
"end_time": 0.3546,
"event": "Echolocation",
"high_freq": 37589,
"individual": "-1",
"low_freq": 27187,
"start_time": 0.3455
},
{
"class": "Pipistrellus pipistrellus",
"class_prob": 0.69,
"det_prob": 0.697,
"end_time": 0.3776,
"event": "Echolocation",
"high_freq": 57262,
"individual": "-1",
"low_freq": 46093,
"start_time": 0.3735
},
{
"class": "Eptesicus serotinus",
"class_prob": 0.34,
"det_prob": 0.415,
"end_time": 0.4069,
"event": "Echolocation",
"high_freq": 52025,
"individual": "-1",
"low_freq": 31484,
"start_time": 0.4005
},
{
"class": "Eptesicus serotinus",
"class_prob": 0.386,
"det_prob": 0.445,
"end_time": 0.4178,
"event": "Echolocation",
"high_freq": 53951,
"individual": "-1",
"low_freq": 27187,
"start_time": 0.4115
},
{
"class": "Eptesicus serotinus",
"class_prob": 0.393,
"det_prob": 0.517,
"end_time": 0.4359,
"event": "Echolocation",
"high_freq": 51724,
"individual": "-1",
"low_freq": 30625,
"start_time": 0.4305
},
{
"class": "Eptesicus serotinus",
"class_prob": 0.332,
"det_prob": 0.396,
"end_time": 0.4502,
"event": "Echolocation",
"high_freq": 58310,
"individual": "-1",
"low_freq": 27187,
"start_time": 0.4435
},
{
"class": "Pipistrellus pipistrellus",
"class_prob": 0.45,
"det_prob": 0.456,
"end_time": 0.4638,
"event": "Echolocation",
"high_freq": 55714,
"individual": "-1",
"low_freq": 46093,
"start_time": 0.4575
},
{
"class": "Eptesicus serotinus",
"class_prob": 0.719,
"det_prob": 0.766,
"end_time": 0.4824,
"event": "Echolocation",
"high_freq": 66101,
"individual": "-1",
"low_freq": 28046,
"start_time": 0.4755
}
],
"class_name": "Pipistrellus pipistrellus",
"duration": 0.5,
"id": "20180530_213516-EPTSER-LR_0_0.5.wav",
"issues": false,
"notes": "Automatically generated. Example data do not assume correct!",
"time_exp": 1
}

View File

@ -0,0 +1,111 @@
{
"annotated": true,
"annotation": [
{
"class": "Rhinolophus ferrumequinum",
"class_prob": 0.407,
"det_prob": 0.407,
"end_time": 0.066,
"event": "Echolocation",
"high_freq": 84254,
"individual": "-1",
"low_freq": 68437,
"start_time": 0.0245
},
{
"class": "Rhinolophus ferrumequinum",
"class_prob": 0.759,
"det_prob": 0.76,
"end_time": 0.1576,
"event": "Echolocation",
"high_freq": 84048,
"individual": "-1",
"low_freq": 68437,
"start_time": 0.0955
},
{
"class": "Rhinolophus ferrumequinum",
"class_prob": 0.754,
"det_prob": 0.755,
"end_time": 0.269,
"event": "Echolocation",
"high_freq": 83768,
"individual": "-1",
"low_freq": 68437,
"start_time": 0.2095
},
{
"class": "Rhinolophus ferrumequinum",
"class_prob": 0.495,
"det_prob": 0.495,
"end_time": 0.2869,
"event": "Echolocation",
"high_freq": 84055,
"individual": "-1",
"low_freq": 68437,
"start_time": 0.2425
},
{
"class": "Rhinolophus ferrumequinum",
"class_prob": 0.73,
"det_prob": 0.73,
"end_time": 0.3631,
"event": "Echolocation",
"high_freq": 84280,
"individual": "-1",
"low_freq": 68437,
"start_time": 0.3055
},
{
"class": "Rhinolophus ferrumequinum",
"class_prob": 0.648,
"det_prob": 0.649,
"end_time": 0.3798,
"event": "Echolocation",
"high_freq": 83030,
"individual": "-1",
"low_freq": 68437,
"start_time": 0.3215
},
{
"class": "Rhinolophus ferrumequinum",
"class_prob": 0.678,
"det_prob": 0.678,
"end_time": 0.4611,
"event": "Echolocation",
"high_freq": 84020,
"individual": "-1",
"low_freq": 68437,
"start_time": 0.4065
},
{
"class": "Rhinolophus ferrumequinum",
"class_prob": 0.717,
"det_prob": 0.718,
"end_time": 0.4987,
"event": "Echolocation",
"high_freq": 83603,
"individual": "-1",
"low_freq": 68437,
"start_time": 0.4365
},
{
"class": "Rhinolophus ferrumequinum",
"class_prob": 0.662,
"det_prob": 0.662,
"end_time": 0.5503,
"event": "Echolocation",
"high_freq": 83710,
"individual": "-1",
"low_freq": 68437,
"start_time": 0.4975
}
],
"class_name": "Rhinolophus ferrumequinum",
"duration": 0.5,
"id": "20180627_215323-RHIFER-LR_0_0.5.wav",
"issues": false,
"notes": "Automatically generated. Example data do not assume correct!",
"time_exp": 1
}

View File

@ -17,7 +17,7 @@ dependencies = [
"torch>=1.13.1,<2.5.0",
"torchaudio>=1.13.1,<2.5.0",
"torchvision>=0.14.0",
"soundevent[audio,geometry,plot]>=2.3",
"soundevent[audio,geometry,plot]>=2.5.0",
"click>=8.1.7",
"netcdf4>=1.6.5",
"tqdm>=4.66.2",
@ -30,6 +30,7 @@ dependencies = [
"pyyaml>=6.0.2",
"hydra-core>=1.3.2",
"numba>=0.60",
"loguru>=0.7.3",
]
requires-python = ">=3.9,<3.13"
readme = "README.md"
@ -72,7 +73,14 @@ dev-dependencies = [
"ruff>=0.7.3",
"ipykernel>=6.29.4",
"setuptools>=69.5.1",
"basedpyright>=1.28.4",
"pyright>=1.1.399",
"myst-parser>=3.0.1",
"sphinx-autobuild>=2024.10.3",
"numpydoc>=1.8.0",
"sphinx-autodoc-typehints>=2.3.0",
"sphinx-book-theme>=1.1.4",
"autodoc-pydantic>=2.2.0",
"pytest-cov>=6.1.1",
]
[tool.ruff]
@ -91,7 +99,24 @@ convention = "numpy"
[tool.pyright]
include = ["batdetect2", "tests"]
venvPath = "."
venv = ".venv"
pythonVersion = "3.9"
pythonPlatform = "All"
exclude = [
"batdetect2/detector/",
"batdetect2/finetune",
"batdetect2/utils",
"batdetect2/plotting",
"batdetect2/plot",
"batdetect2/api",
"batdetect2/evaluate/legacy",
"batdetect2/train/legacy",
]
[dependency-groups]
jupyter = [
"ipywidgets>=8.1.5",
"jupyter>=1.1.1",
]
marimo = [
"marimo>=0.12.2",
]

View File

@ -5,7 +5,22 @@ from typing import Callable, List, Optional
import numpy as np
import pytest
import soundfile as sf
from soundevent import data
from soundevent import data, terms
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import (
TargetConfig,
TermRegistry,
build_targets,
call_type,
)
from batdetect2.targets.classes import ClassesConfig, TargetClass
from batdetect2.targets.filtering import FilterConfig, FilterRule
from batdetect2.targets.terms import TagInfo
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.types import ClipLabeller
@pytest.fixture
@ -84,8 +99,8 @@ def wav_factory(tmp_path: Path):
@pytest.fixture
def recording_factory(wav_factory: Callable[..., Path]):
def _recording_factory(
def create_recording(wav_factory: Callable[..., Path]):
def factory(
tags: Optional[list[data.Tag]] = None,
path: Optional[Path] = None,
recording_id: Optional[uuid.UUID] = None,
@ -94,7 +109,8 @@ def recording_factory(wav_factory: Callable[..., Path]):
samplerate: int = 256_000,
time_expansion: float = 1,
) -> data.Recording:
path = path or wav_factory(
path = wav_factory(
path=path,
duration=duration,
channels=channels,
samplerate=samplerate,
@ -106,4 +122,264 @@ def recording_factory(wav_factory: Callable[..., Path]):
tags=tags or [],
)
return _recording_factory
return factory
@pytest.fixture
def recording(
create_recording: Callable[..., data.Recording],
) -> data.Recording:
return create_recording()
@pytest.fixture
def create_clip():
def factory(
recording: data.Recording,
start_time: float = 0,
end_time: float = 0.5,
) -> data.Clip:
return data.Clip(
recording=recording,
start_time=start_time,
end_time=end_time,
)
return factory
@pytest.fixture
def clip(recording: data.Recording) -> data.Clip:
return data.Clip(recording=recording, start_time=0, end_time=0.5)
@pytest.fixture
def create_sound_event():
def factory(
recording: data.Recording,
coords: Optional[List[float]] = None,
) -> data.SoundEvent:
coords = coords or [0.2, 60_000, 0.3, 70_000]
return data.SoundEvent(
geometry=data.BoundingBox(coordinates=coords),
recording=recording,
)
return factory
@pytest.fixture
def sound_event(recording: data.Recording) -> data.SoundEvent:
return data.SoundEvent(
geometry=data.BoundingBox(coordinates=[0.1, 67_000, 0.11, 73_000]),
recording=recording,
)
@pytest.fixture
def create_sound_event_annotation():
def factory(
sound_event: data.SoundEvent,
tags: Optional[List[data.Tag]] = None,
) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=sound_event,
tags=tags or [],
)
return factory
@pytest.fixture
def echolocation_call(recording: data.Recording) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.BoundingBox(coordinates=[0.1, 67_000, 0.11, 73_000]),
recording=recording,
),
tags=[
data.Tag(term=terms.scientific_name, value="Myotis myotis"),
data.Tag(term=call_type, value="Echolocation"),
],
)
@pytest.fixture
def generic_call(recording: data.Recording) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.BoundingBox(
coordinates=[0.34, 35_000, 0.348, 62_000]
),
recording=recording,
),
tags=[
data.Tag(term=terms.order, value="Chiroptera"),
data.Tag(term=call_type, value="Echolocation"),
],
)
@pytest.fixture
def non_relevant_sound_event(
recording: data.Recording,
) -> data.SoundEventAnnotation:
return data.SoundEventAnnotation(
sound_event=data.SoundEvent(
geometry=data.BoundingBox(
coordinates=[0.22, 50_000, 0.24, 58_000]
),
recording=recording,
),
tags=[
data.Tag(
term=terms.scientific_name,
value="Muscardinus avellanarius",
),
],
)
@pytest.fixture
def create_clip_annotation():
def factory(
clip: data.Clip,
clip_tags: Optional[List[data.Tag]] = None,
sound_events: Optional[List[data.SoundEventAnnotation]] = None,
) -> data.ClipAnnotation:
return data.ClipAnnotation(
clip=clip,
tags=clip_tags or [],
sound_events=sound_events or [],
)
return factory
@pytest.fixture
def clip_annotation(
clip: data.Clip,
echolocation_call: data.SoundEventAnnotation,
generic_call: data.SoundEventAnnotation,
non_relevant_sound_event: data.SoundEventAnnotation,
) -> data.ClipAnnotation:
return data.ClipAnnotation(
clip=clip,
sound_events=[
echolocation_call,
generic_call,
non_relevant_sound_event,
],
)
@pytest.fixture
def create_annotation_set():
def factory(
name: str = "test",
description: str = "Test annotation set",
annotations: Optional[List[data.ClipAnnotation]] = None,
) -> data.AnnotationSet:
return data.AnnotationSet(
name=name,
description=description,
clip_annotations=annotations or [],
)
return factory
@pytest.fixture
def create_annotation_project():
def factory(
name: str = "test_project",
description: str = "Test Annotation Project",
tasks: Optional[List[data.AnnotationTask]] = None,
annotations: Optional[List[data.ClipAnnotation]] = None,
) -> data.AnnotationProject:
return data.AnnotationProject(
name=name,
description=description,
tasks=tasks or [],
clip_annotations=annotations or [],
)
return factory
@pytest.fixture
def sample_term_registry() -> TermRegistry:
"""Fixture for a sample TermRegistry."""
registry = TermRegistry()
registry.add_custom_term("class")
registry.add_custom_term("order")
registry.add_custom_term("species")
registry.add_custom_term("call_type")
registry.add_custom_term("quality")
return registry
@pytest.fixture
def sample_preprocessor() -> PreprocessorProtocol:
return build_preprocessor()
@pytest.fixture
def bat_tag() -> TagInfo:
return TagInfo(key="class", value="bat")
@pytest.fixture
def noise_tag() -> TagInfo:
return TagInfo(key="class", value="noise")
@pytest.fixture
def myomyo_tag() -> TagInfo:
return TagInfo(key="species", value="Myotis myotis")
@pytest.fixture
def pippip_tag() -> TagInfo:
return TagInfo(key="species", value="Pipistrellus pipistrellus")
@pytest.fixture
def sample_target_config(
sample_term_registry: TermRegistry,
bat_tag: TagInfo,
noise_tag: TagInfo,
myomyo_tag: TagInfo,
pippip_tag: TagInfo,
) -> TargetConfig:
return TargetConfig(
filtering=FilterConfig(
rules=[FilterRule(match_type="exclude", tags=[noise_tag])]
),
classes=ClassesConfig(
classes=[
TargetClass(name="pippip", tags=[pippip_tag]),
TargetClass(name="myomyo", tags=[myomyo_tag]),
],
generic_class=[bat_tag],
),
)
@pytest.fixture
def sample_targets(
sample_target_config: TargetConfig,
sample_term_registry: TermRegistry,
) -> TargetProtocol:
return build_targets(
sample_target_config,
term_registry=sample_term_registry,
)
@pytest.fixture
def sample_labeller(
sample_targets: TargetProtocol,
) -> ClipLabeller:
return build_clip_labeler(sample_targets)

View File

@ -8,7 +8,7 @@ from batdetect2.detector import parameters
from batdetect2.utils import audio_utils, detector_utils
@given(duration=st.floats(min_value=0.1, max_value=2))
@given(duration=st.floats(min_value=0.1, max_value=1))
def test_can_compute_correct_spectrogram_width(duration: float):
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS

Some files were not shown because too many files have changed in this diff Show More