Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-30 15:12:06 +02:00)

Compare commits: acda71ea45 ... 4b4c3ecdf5 (81 commits)

Commits in this range (SHA1):
4b4c3ecdf5, 13afac65a5, d8cf1db19f, 8a6ed3dec7, 7dd35d6e3e, d51e3f8bbd, f3999fbba2, 2a45859393,
86d56d65f4, 27ba8de463, 59bd14bc79, 2396815c13, ac4bb8f023, 6498b6ca37, bfcab0331e, c276230bff,
e38c446f59, 24c4831745, 9fc713d390, 285c6a3347, 541be15c9e, 8a463e3942, 257e1e01bf, ece1a2073d,
b82973ca1d, 7c89e82579, dcae411ccb, ce15afc231, 096d180ea3, ffa4c2e5e9, 6c744eaac5, e00674f628,
907e05ea48, 3123d105fd, 9b6b8a0bf9, 4aa2e6905c, 3abebc9c17, 1f4454693e, bcf339c40d, 089328a4f0,
6236e78414, 9410112e41, 07f065cf93, ae6063918c, a0a77cada1, 355847346e, dfd14df7b9, f353aaa08c,
bf14f4d37e, b78e5a3a2f, f9e005ec8b, fd7f2b0081, f314942628, 638f93fe92, 19febf2216, 3417c496db,
4a9af72580, 2212246b11, aca0b58443, f5071d00a1, 23620c2233, a9f91322d4, 22036743d1, eda5f91c86,
a2ec190b73, 04ed669c4f, b796e0bc7b, 5d4d9a5edf, 55eff0cebd, f99653d68f, 0778663a2c, 62471664fa,
af48c33307, 02d4779207, d97614a10d, 991529cf86, 02c1d97e5a, 3f3f7cd9c8, 2fb3039f17, 26a2c5c851,
b93d4c65c2

.gitignore (vendored, 3 lines changed)
@@ -95,6 +95,8 @@ dmypy.json
*.json
plots/*

!example_data/anns/*.json

# Model experiments
experiments/*

@@ -111,3 +113,4 @@ experiments/*
!tests/data/**/*.wav
notebooks/lightning_logs
example_data/preprocessed
.aider*

Makefile (new file, 94 lines)
@@ -0,0 +1,94 @@
# Variables
SOURCE_DIR = batdetect2
TESTS_DIR = tests
PYTHON_DIRS = batdetect2 tests
DOCS_SOURCE = docs/source
DOCS_BUILD = docs/build
HTML_COVERAGE_DIR = htmlcov

# Default target (optional, often 'help' or 'all')
.DEFAULT_GOAL := help

# Phony targets (targets that don't produce a file with the same name)
.PHONY: help test coverage coverage-html coverage-serve docs docs-serve format format-check lint lint-fix typecheck check clean clean-pyc clean-test clean-docs clean-build

help:
	@echo "Makefile Targets:"
	@echo "  help            Show this help message."
	@echo "  test            Run tests using pytest."
	@echo "  coverage        Run tests and generate coverage data (.coverage, coverage.xml)."
	@echo "  coverage-html   Generate an HTML coverage report in $(HTML_COVERAGE_DIR)/."
	@echo "  coverage-serve  Serve the HTML coverage report locally."
	@echo "  docs            Build documentation using Sphinx."
	@echo "  docs-serve      Serve documentation with live reload using sphinx-autobuild."
	@echo "  format          Format code using ruff."
	@echo "  format-check    Check code formatting using ruff."
	@echo "  lint            Lint code using ruff."
	@echo "  lint-fix        Lint code using ruff and apply automatic fixes."
	@echo "  typecheck       Type check code using pyright."
	@echo "  check           Run all checks (format-check, lint, typecheck)."
	@echo "  clean           Remove all build, test, documentation, and Python artifacts."
	@echo "  clean-pyc       Remove Python bytecode and cache."
	@echo "  clean-test      Remove test and coverage artifacts."
	@echo "  clean-docs      Remove built documentation."
	@echo "  clean-build     Remove package build artifacts."

# Testing & Coverage
test:
	pytest tests

coverage:
	pytest --cov=batdetect2 --cov-report=term-missing --cov-report=xml tests

coverage-html: coverage
	@echo "Generating HTML coverage report..."
	coverage html -d $(HTML_COVERAGE_DIR)
	@echo "HTML coverage report generated in $(HTML_COVERAGE_DIR)/"

coverage-serve: coverage-html
	@echo "Serving report at http://localhost:8000/ ..."
	python -m http.server --directory $(HTML_COVERAGE_DIR) 8000

# Documentation
docs:
	sphinx-build -b html $(DOCS_SOURCE) $(DOCS_BUILD)

docs-serve:
	sphinx-autobuild $(DOCS_SOURCE) $(DOCS_BUILD) --watch $(SOURCE_DIR) --open-browser

# Formatting & Linting
format:
	ruff format $(PYTHON_DIRS)

format-check:
	ruff format --check $(PYTHON_DIRS)

lint:
	ruff check $(PYTHON_DIRS)

lint-fix:
	ruff check --fix $(PYTHON_DIRS)

# Type Checking
typecheck:
	pyright $(PYTHON_DIRS)

# Combined Checks
check: format-check lint typecheck test

# Cleaning tasks
clean-pyc:
	find . -type f -name "*.py[co]" -delete
	find . -type d -name "__pycache__" -exec rm -rf {} +

clean-test:
	rm -f .coverage coverage.xml
	rm -rf .pytest_cache htmlcov/

clean-docs:
	rm -rf $(DOCS_BUILD)

clean-build:
	rm -rf build/ dist/ *.egg-info/

clean: clean-build clean-pyc clean-test clean-docs

@@ -35,6 +35,8 @@ def summary(
):
    base_dir = base_dir or Path.cwd()
    dataset = load_dataset_from_config(
-        dataset_config, field=field, base_dir=base_dir
+        dataset_config,
+        field=field,
+        base_dir=base_dir,
    )
-    print(f"Number of annotated clips: {len(dataset.clip_annotations)}")
+    print(f"Number of annotated clips: {len(dataset)}")

@@ -2,17 +2,14 @@ from pathlib import Path
from typing import Optional

import click
+from loguru import logger

from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
-from batdetect2.preprocess import (
-    load_preprocessing_config,
-)
-from batdetect2.train import (
-    load_label_config,
-    load_target_config,
-    preprocess_annotations,
-)
+from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
+from batdetect2.targets import build_targets, load_target_config
+from batdetect2.train import load_label_config, preprocess_annotations
+from batdetect2.train.labels import build_clip_labeler

__all__ = ["train"]


@@ -127,9 +124,9 @@ def train(): ...
def preprocess(
    dataset_config: Path,
    output: Path,
-    target_config: Optional[Path] = None,
    base_dir: Optional[Path] = None,
    preprocess_config: Optional[Path] = None,
+    target_config: Optional[Path] = None,
    label_config: Optional[Path] = None,
    force: bool = False,
    num_workers: Optional[int] = None,

@@ -138,8 +135,13 @@ def preprocess(
    label_config_field: Optional[str] = None,
    dataset_field: Optional[str] = None,
):
+    logger.info("Starting preprocessing.")
+
+    output = Path(output)
+    logger.info("Will save outputs to {output}", output=output)

    base_dir = base_dir or Path.cwd()
+    logger.debug("Current working directory: {base_dir}", base_dir=base_dir)

    preprocess = (
        load_preprocessing_config(

@@ -174,15 +176,25 @@
        base_dir=base_dir,
    )

+    logger.info(
+        "Loaded {num_examples} annotated clips from the configured dataset",
+        num_examples=len(dataset),
+    )
+
+    targets = build_targets(config=target)
+    preprocessor = build_preprocessor(config=preprocess)
+    labeller = build_clip_labeler(targets, config=label)
+
    if not output.exists():
+        logger.debug("Creating directory {directory}", directory=output)
        output.mkdir(parents=True)

+    logger.info("Will start preprocessing")
    preprocess_annotations(
-        dataset.clip_annotations,
+        dataset,
        output_dir=output,
+        preprocessor=preprocessor,
+        labeller=labeller,
        replace=force,
-        preprocessing_config=preprocess,
-        label_config=label,
-        target_config=target,
        max_workers=num_workers,
    )

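Taken together, the new preprocess command wires up the following programmatic flow. A minimal sketch, assuming each load_*_config helper accepts a config file path; the file names, output directory, and worker count below are illustrative, not repository defaults:

from pathlib import Path

from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train import load_label_config, preprocess_annotations
from batdetect2.train.labels import build_clip_labeler

base_dir = Path.cwd()

# Load the dataset definition and the three configuration files (illustrative paths).
dataset = load_dataset_from_config("dataset.yaml", base_dir=base_dir)
preprocess_config = load_preprocessing_config("preprocessing.yaml")
target_config = load_target_config("targets.yaml")
label_config = load_label_config("labels.yaml")

# Build the runtime objects from their configs.
targets = build_targets(config=target_config)
preprocessor = build_preprocessor(config=preprocess_config)
labeller = build_clip_labeler(targets, config=label_config)

# Write one preprocessed example per annotated clip.
output = Path("preprocessed")
output.mkdir(parents=True, exist_ok=True)

preprocess_annotations(
    dataset,
    output_dir=output,
    preprocessor=preprocessor,
    labeller=labeller,
    replace=False,
    max_workers=4,
)
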
@@ -1,151 +1,152 @@
-from batdetect2.preprocess import (
-    AmplitudeScaleConfig,
-    AudioConfig,
-    FrequencyConfig,
-    LogScaleConfig,
-    PcenScaleConfig,
-    PreprocessingConfig,
-    ResampleConfig,
-    Scales,
-    SpecSizeConfig,
-    SpectrogramConfig,
-    STFTConfig,
-)
-from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
-from batdetect2.terms import TagInfo
-from batdetect2.train.preprocess import (
-    HeatmapsConfig,
-    TargetConfig,
-    TrainPreprocessingConfig,
-)


-def get_spectrogram_scale(scale: str) -> Scales:
-    if scale == "pcen":
-        return PcenScaleConfig()
-    if scale == "log":
-        return LogScaleConfig()
-    return AmplitudeScaleConfig()


-def get_preprocessing_config(params: dict) -> PreprocessingConfig:
-    return PreprocessingConfig(
-        audio=AudioConfig(
-            resample=ResampleConfig(
-                samplerate=params["target_samp_rate"],
-                mode="poly",
-            ),
-            scale=params["scale_raw_audio"],
-            center=params["scale_raw_audio"],
-            duration=None,
-        ),
-        spectrogram=SpectrogramConfig(
-            stft=STFTConfig(
-                window_duration=params["fft_win_length"],
-                window_overlap=params["fft_overlap"],
-                window_fn="hann",
-            ),
-            frequencies=FrequencyConfig(
-                min_freq=params["min_freq"],
-                max_freq=params["max_freq"],
-            ),
-            scale=get_spectrogram_scale(params["spec_scale"]),
-            denoise=params["denoise_spec_avg"],
-            size=SpecSizeConfig(
-                height=params["spec_height"],
-                resize_factor=params["resize_factor"],
-            ),
-            max_scale=params["max_scale_spec"],
-        ),
-    )


-def get_training_preprocessing_config(
-    params: dict,
-) -> TrainPreprocessingConfig:
-    generic = params["generic_class"][0]
-    preprocessing = get_preprocessing_config(params)

-    freq_bin_width, time_bin_width = get_spectrogram_resolution(
-        preprocessing.spectrogram
-    )

-    return TrainPreprocessingConfig(
-        preprocessing=preprocessing,
-        target=TargetConfig(
-            classes=[
-                TagInfo(key="class", value=class_name, label=class_name)
-                for class_name in params["class_names"]
-            ],
-            generic_class=TagInfo(
-                key="class",
-                value=generic,
-                label=generic,
-            ),
-            include=[
-                TagInfo(key="event", value=event)
-                for event in params["events_of_interest"]
-            ],
-            exclude=[
-                TagInfo(key="class", value=value)
-                for value in params["classes_to_ignore"]
-            ],
-        ),
-        heatmaps=HeatmapsConfig(
-            position="bottom-left",
-            time_scale=1 / time_bin_width,
-            frequency_scale=1 / freq_bin_width,
-            sigma=params["target_sigma"],
-        ),
-    )


-# 'standardize_classs_names_ip',
-# 'convert_to_genus',
-# 'genus_mapping',
-# 'standardize_classs_names',
-# 'genus_names',

-# ['data_dir',
-# 'ann_dir',
-# 'train_split',
-# 'model_name',
-# 'num_filters',
-# 'experiment',
-# 'model_file_name',
-# 'op_im_dir',
-# 'op_im_dir_test',
-# 'notes',
-# 'spec_divide_factor',
-# 'detection_overlap',
-# 'ignore_start_end',
-# 'detection_threshold',
-# 'nms_kernel_size',
-# 'nms_top_k_per_sec',
-# 'aug_prob',
-# 'augment_at_train',
-# 'augment_at_train_combine',
-# 'echo_max_delay',
-# 'stretch_squeeze_delta',
-# 'mask_max_time_perc',
-# 'mask_max_freq_perc',
-# 'spec_amp_scaling',
-# 'aug_sampling_rates',
-# 'train_loss',
-# 'det_loss_weight',
-# 'size_loss_weight',
-# 'class_loss_weight',
-# 'individual_loss_weight',
-# 'emb_dim',
-# 'lr',
-# 'batch_size',
-# 'num_workers',
-# 'num_epochs',
-# 'num_eval_epochs',
-# 'device',
-# 'save_test_image_during_train',
-# 'save_test_image_after_train',
-# 'train_sets',
-# 'test_sets',
-# 'class_inv_freq',
-# 'ip_height']
+# from batdetect2.preprocess import (
+#     AmplitudeScaleConfig,
+#     AudioConfig,
+#     FrequencyConfig,
+#     LogScaleConfig,
+#     PcenConfig,
+#     PreprocessingConfig,
+#     ResampleConfig,
+#     Scales,
+#     SpecSizeConfig,
+#     SpectrogramConfig,
+#     STFTConfig,
+# )
+# from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
+# from batdetect2.targets import (
+#     LabelConfig,
+#     TagInfo,
+#     TargetConfig,
+# )
+# from batdetect2.train.preprocess import (
+#     TrainPreprocessingConfig,
+# )
+#
+#
+# def get_spectrogram_scale(scale: str) -> Scales:
+#     if scale == "pcen":
+#         return PcenConfig()
+#     if scale == "log":
+#         return LogScaleConfig()
+#     return AmplitudeScaleConfig()
+#
+#
+# def get_preprocessing_config(params: dict) -> PreprocessingConfig:
+#     return PreprocessingConfig(
+#         audio=AudioConfig(
+#             resample=ResampleConfig(
+#                 samplerate=params["target_samp_rate"],
+#                 method="poly",
+#             ),
+#             scale=params["scale_raw_audio"],
+#             center=params["scale_raw_audio"],
+#             duration=None,
+#         ),
+#         spectrogram=SpectrogramConfig(
+#             stft=STFTConfig(
+#                 window_duration=params["fft_win_length"],
+#                 window_overlap=params["fft_overlap"],
+#                 window_fn="hann",
+#             ),
+#             frequencies=FrequencyConfig(
+#                 min_freq=params["min_freq"],
+#                 max_freq=params["max_freq"],
+#             ),
+#             scale=get_spectrogram_scale(params["spec_scale"]),
+#             spectral_mean_substraction=params["denoise_spec_avg"],
+#             size=SpecSizeConfig(
+#                 height=params["spec_height"],
+#                 resize_factor=params["resize_factor"],
+#             ),
+#             peak_normalize=params["max_scale_spec"],
+#         ),
+#     )
+#
+#
+# def get_training_preprocessing_config(
+#     params: dict,
+# ) -> TrainPreprocessingConfig:
+#     generic = params["generic_class"][0]
+#     preprocessing = get_preprocessing_config(params)
+#
+#     freq_bin_width, time_bin_width = get_spectrogram_resolution(
+#         preprocessing.spectrogram
+#     )
+#
+#     return TrainPreprocessingConfig(
+#         preprocessing=preprocessing,
+#         target=TargetConfig(
+#             classes=[
+#                 TagInfo(key="class", value=class_name)
+#                 for class_name in params["class_names"]
+#             ],
+#             generic_class=TagInfo(
+#                 key="class",
+#                 value=generic,
+#             ),
+#             include=[
+#                 TagInfo(key="event", value=event)
+#                 for event in params["events_of_interest"]
+#             ],
+#             exclude=[
+#                 TagInfo(key="class", value=value)
+#                 for value in params["classes_to_ignore"]
+#             ],
+#         ),
+#         labels=LabelConfig(
+#             position="bottom-left",
+#             time_scale=1 / time_bin_width,
+#             frequency_scale=1 / freq_bin_width,
+#             sigma=params["target_sigma"],
+#         ),
+#     )
+#
+#
+# # 'standardize_classs_names_ip',
+# # 'convert_to_genus',
+# # 'genus_mapping',
+# # 'standardize_classs_names',
+# # 'genus_names',
+#
+# # ['data_dir',
+# # 'ann_dir',
+# # 'train_split',
+# # 'model_name',
+# # 'num_filters',
+# # 'experiment',
+# # 'model_file_name',
+# # 'op_im_dir',
+# # 'op_im_dir_test',
+# # 'notes',
+# # 'spec_divide_factor',
+# # 'detection_overlap',
+# # 'ignore_start_end',
+# # 'detection_threshold',
+# # 'nms_kernel_size',
+# # 'nms_top_k_per_sec',
+# # 'aug_prob',
+# # 'augment_at_train',
+# # 'augment_at_train_combine',
+# # 'echo_max_delay',
+# # 'stretch_squeeze_delta',
+# # 'mask_max_time_perc',
+# # 'mask_max_freq_perc',
+# # 'spec_amp_scaling',
+# # 'aug_sampling_rates',
+# # 'train_loss',
+# # 'det_loss_weight',
+# # 'size_loss_weight',
+# # 'class_loss_weight',
+# # 'individual_loss_weight',
+# # 'emb_dim',
+# # 'lr',
+# # 'batch_size',
+# # 'num_workers',
+# # 'num_epochs',
+# # 'num_eval_epochs',
+# # 'device',
+# # 'save_test_image_during_train',
+# # 'save_test_image_after_train',
+# # 'train_sets',
+# # 'test_sets',
+# # 'class_inv_freq',
+# # 'ip_height']

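For reference, the removed (now commented-out) helpers mapped a flat legacy params dictionary onto the new nested configuration objects. A minimal sketch of how get_training_preprocessing_config was invoked; the parameter values below are illustrative placeholders, not repository defaults:

params = {
    "target_samp_rate": 256000,
    "scale_raw_audio": False,
    "fft_win_length": 0.002,
    "fft_overlap": 0.75,
    "min_freq": 10000,
    "max_freq": 120000,
    "spec_scale": "pcen",
    "denoise_spec_avg": True,
    "spec_height": 128,
    "resize_factor": 0.5,
    "max_scale_spec": False,
    "target_sigma": 2.0,
    "generic_class": ["Bat"],
    "class_names": ["Myotis daubentonii", "Pipistrellus pipistrellus"],
    "events_of_interest": ["Echolocation"],
    "classes_to_ignore": ["Unknown"],
}

# Produces nested Pydantic config objects (PreprocessingConfig, TargetConfig,
# HeatmapsConfig) from the flat legacy dictionary.
config = get_training_preprocessing_config(params)
print(config.preprocessing.spectrogram.size.height)  # 128
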
@@ -1,23 +1,105 @@
"""Provides base classes and utilities for loading configurations in BatDetect2.

This module leverages Pydantic for robust configuration handling, ensuring
that configuration files (typically YAML) adhere to predefined schemas. It
defines a base configuration class (`BaseConfig`) that enforces strict schema
validation and a utility function (`load_config`) to load and validate
configuration data from files, with optional support for accessing nested
configuration sections.
"""

from typing import Any, Optional, Type, TypeVar

import yaml
from pydantic import BaseModel, ConfigDict
from soundevent.data import PathLike

__all__ = [
    "BaseConfig",
    "load_config",
]


class BaseConfig(BaseModel):
    """Base class for all configuration models in BatDetect2.

    Inherits from Pydantic's `BaseModel` to provide data validation, parsing,
    and serialization capabilities.

    It sets `extra='forbid'` in its model configuration, meaning that any
    fields provided in a configuration file that are *not* explicitly defined
    in the specific configuration schema will raise a validation error. This
    helps catch typos and ensures configurations strictly adhere to the expected
    structure.

    Attributes
    ----------
    model_config : ConfigDict
        Pydantic model configuration dictionary. Set to forbid extra fields.
    """

    model_config = ConfigDict(extra="forbid")


T = TypeVar("T", bound=BaseModel)


-def get_object_field(obj: dict, field: str) -> Any:
-    if "." not in field:
-        return obj[field]
+def get_object_field(obj: dict, current_key: str) -> Any:
    """Access a potentially nested field within a dictionary using dot notation.

    Parameters
    ----------
    obj : dict
        The dictionary (or nested dictionaries) to access.
    field : str
        The field name to retrieve. Nested fields are specified using dots
        (e.g., "parent_key.child_key.target_field").

    Returns
    -------
    Any
        The value found at the specified field path.

    Raises
    ------
    KeyError
        If any part of the field path does not exist in the dictionary
        structure.
    TypeError
        If an intermediate part of the path exists but is not a dictionary,
        preventing further nesting.

    Examples
    --------
    >>> data = {"a": {"b": {"c": 10}}}
    >>> get_object_field(data, "a.b.c")
    10
    >>> get_object_field(data, "a.b")
    {'c': 10}
    >>> get_object_field(data, "a")
    {'b': {'c': 10}}
    >>> get_object_field(data, "x")
    Traceback (most recent call last):
    ...
    KeyError: 'x'
    >>> get_object_field(data, "a.x.c")
    Traceback (most recent call last):
    ...
    KeyError: 'x'
    """
    if "." not in current_key:
        return obj[current_key]

    current_key, rest = current_key.split(".", 1)
    subobj = obj[current_key]

    if not isinstance(subobj, dict):
        raise TypeError(
            f"Intermediate key '{current_key}' in path '{current_key}' "
            f"does not lead to a dictionary (found type: {type(subobj)}). "
            "Cannot access further nested field."
        )

-    field, rest = field.split(".", 1)
-    subobj = obj[field]
    return get_object_field(subobj, rest)


@@ -26,6 +108,49 @@ def load_config(
    schema: Type[T],
    field: Optional[str] = None,
) -> T:
    """Load and validate configuration data from a file against a schema.

    Reads a YAML file, optionally extracts a specific section using dot
    notation, and then validates the resulting data against the provided
    Pydantic `schema`.

    Parameters
    ----------
    path : PathLike
        The path to the configuration file (typically `.yaml`).
    schema : Type[T]
        The Pydantic `BaseModel` subclass that defines the expected structure
        and types for the configuration data.
    field : str, optional
        A dot-separated string indicating a nested section within the YAML
        file to extract before validation. If None (default), the entire
        file content is validated against the schema.
        Example: `"training.optimizer"` would extract the `optimizer` section
        within the `training` section.

    Returns
    -------
    T
        An instance of the provided `schema`, populated and validated with
        data from the configuration file.

    Raises
    ------
    FileNotFoundError
        If the file specified by `path` does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded configuration data (after optionally extracting the
        `field`) does not conform to the provided `schema` (e.g., missing
        required fields, incorrect types, extra fields if using BaseConfig).
    KeyError
        If `field` is provided and specifies a path where intermediate keys
        do not exist in the loaded YAML data.
    TypeError
        If `field` is provided and specifies a path where an intermediate
        value is not a dictionary, preventing access to nested fields.
    """
    with open(path, "r") as file:
        config = yaml.safe_load(file)

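To illustrate the two behaviours documented above (strict validation and nested-section loading), a minimal sketch; OptimizerConfig and config.yaml are hypothetical and not part of the repository:

from pydantic import ValidationError

from batdetect2.configs import BaseConfig, load_config


class OptimizerConfig(BaseConfig):
    learning_rate: float = 1e-3


# extra="forbid" turns typos into loud validation errors.
try:
    OptimizerConfig(learing_rate=0.01)  # misspelled field name
except ValidationError as err:
    print(err)

# config.yaml is assumed to contain:
#
# training:
#   optimizer:
#     learning_rate: 0.001
config = load_config(path="config.yaml", schema=OptimizerConfig, field="training.optimizer")
print(config.learning_rate)  # 0.001
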
@@ -1,13 +1,22 @@
from batdetect2.data.annotations import (
    AnnotatedDataset,
    AOEFAnnotations,
    BatDetect2FilesAnnotations,
    BatDetect2MergedAnnotations,
    load_annotated_dataset,
)
-from batdetect2.data.data import load_dataset, load_dataset_from_config
-from batdetect2.data.types import Dataset
+from batdetect2.data.datasets import (
+    DatasetConfig,
+    load_dataset,
+    load_dataset_from_config,
+)

__all__ = [
    "AOEFAnnotations",
    "AnnotatedDataset",
    "Dataset",
    "BatDetect2FilesAnnotations",
    "BatDetect2MergedAnnotations",
    "DatasetConfig",
    "load_annotated_dataset",
    "load_dataset",
    "load_dataset_from_config",

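In practice the package-level API is what most code touches. A small sketch, assuming a hypothetical dataset.yaml whose datasets.train section holds a dataset definition:

from pathlib import Path

from batdetect2.data import load_dataset_from_config

dataset = load_dataset_from_config(
    "dataset.yaml",               # hypothetical config file
    field="datasets.train",       # nested section inside the YAML
    base_dir=Path("/data/bats"),  # resolves relative audio/annotation paths
)
print(f"Number of annotated clips: {len(dataset)}")
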
@@ -1,36 +0,0 @@ (file deleted)
import json
from pathlib import Path
from typing import Literal, Union

from batdetect2.configs import BaseConfig

__all__ = [
    "AOEFAnnotationFile",
    "AnnotationFormats",
    "BatDetect2AnnotationFile",
    "BatDetect2AnnotationFiles",
]


class BatDetect2AnnotationFiles(BaseConfig):
    format: Literal["batdetect2"] = "batdetect2"
    path: Path


class BatDetect2AnnotationFile(BaseConfig):
    format: Literal["batdetect2_file"] = "batdetect2_file"
    path: Path


class AOEFAnnotationFile(BaseConfig):
    format: Literal["aoef"] = "aoef"
    path: Path


AnnotationFormats = Union[
    BatDetect2AnnotationFiles,
    BatDetect2AnnotationFile,
    AOEFAnnotationFile,
]

@@ -1,29 +1,44 @@
"""Handles loading of annotation data from various formats.

This module serves as the central dispatcher for parsing annotation data
associated with BatDetect2 datasets. Datasets can be composed of multiple
sources, each potentially using a different annotation format (e.g., the
standard AOEF/soundevent format, or legacy BatDetect2 formats).

This module defines the `AnnotationFormats` type, which represents the union
of possible configuration models for these different formats (each identified
by a unique `format` field). The primary function, `load_annotated_dataset`,
inspects the configuration for a single data source and calls the appropriate
format-specific loading function to retrieve the annotations as a standard
`soundevent.data.AnnotationSet`.
"""

from pathlib import Path
from typing import Optional, Union

from soundevent import data

-from batdetect2.data.annotations.aeof import (
+from batdetect2.data.annotations.aoef import (
    AOEFAnnotations,
    load_aoef_annotated_dataset,
)
-from batdetect2.data.annotations.batdetect2_files import (
+from batdetect2.data.annotations.batdetect2 import (
+    AnnotationFilter,
    BatDetect2FilesAnnotations,
    load_batdetect2_files_annotated_dataset,
)
-from batdetect2.data.annotations.batdetect2_merged import (
    BatDetect2MergedAnnotations,
    load_batdetect2_files_annotated_dataset,
    load_batdetect2_merged_annotated_dataset,
)
from batdetect2.data.annotations.types import AnnotatedDataset

__all__ = [
-    "load_annotated_dataset",
-    "AnnotatedDataset",
    "AOEFAnnotations",
+    "AnnotatedDataset",
+    "AnnotationFilter",
+    "AnnotationFormats",
    "BatDetect2FilesAnnotations",
    "BatDetect2MergedAnnotations",
-    "AnnotationFormats",
+    "load_annotated_dataset",
]


@@ -32,12 +47,52 @@ AnnotationFormats = Union[
    BatDetect2FilesAnnotations,
    AOEFAnnotations,
]
"""Type Alias representing all supported data source configurations.

Each specific configuration model within this union (e.g., `AOEFAnnotations`,
`BatDetect2FilesAnnotations`) corresponds to a different annotation format
or storage structure. These models are typically discriminated by a `format`
field (e.g., `format="aoef"`, `format="batdetect2_files"`), allowing Pydantic
and functions like `load_annotated_dataset` to determine which format a given
source configuration represents.
"""


def load_annotated_dataset(
    dataset: AnnotatedDataset,
    base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
    """Load annotations for a single data source based on its configuration.

    This function acts as a dispatcher. It inspects the type of the input
    `source_config` object (which corresponds to a specific annotation format)
    and calls the appropriate loading function (e.g.,
    `load_aoef_annotated_dataset` for `AOEFAnnotations`).

    Parameters
    ----------
    source_config : AnnotationFormats
        The configuration object for the data source, specifying its format
        and necessary details (like paths). Must be an instance of one of the
        types included in the `AnnotationFormats` union.
    base_dir : Path, optional
        An optional base directory path. If provided, relative paths within
        the `source_config` might be resolved relative to this directory by
        the underlying loading functions. Defaults to None.

    Returns
    -------
    soundevent.data.AnnotationSet
        An AnnotationSet containing the `ClipAnnotation` objects loaded and
        parsed from the specified data source.

    Raises
    ------
    NotImplementedError
        If the type of the `source_config` object does not match any of the
        known format-specific loading functions implemented in the dispatch
        logic.
    """
    if isinstance(dataset, AOEFAnnotations):
        return load_aoef_annotated_dataset(dataset, base_dir=base_dir)

@@ -1,37 +0,0 @@ (file deleted)
from pathlib import Path
from typing import Literal, Optional

from soundevent import data, io

from batdetect2.data.annotations.types import AnnotatedDataset

__all__ = [
    "AOEFAnnotations",
    "load_aoef_annotated_dataset",
]


class AOEFAnnotations(AnnotatedDataset):
    format: Literal["aoef"] = "aoef"
    annotations_path: Path


def load_aoef_annotated_dataset(
    dataset: AOEFAnnotations,
    base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
    audio_dir = dataset.audio_dir
    path = dataset.annotations_path

    if base_dir:
        audio_dir = base_dir / audio_dir
        path = base_dir / path

    loaded = io.load(path, audio_dir=audio_dir)

    if not isinstance(loaded, (data.AnnotationSet, data.AnnotationProject)):
        raise ValueError(
            f"The AOEF file at {path} does not contain a set of annotations"
        )

    return loaded

batdetect2/data/annotations/aoef.py (new file, 270 lines)
@@ -0,0 +1,270 @@
"""Loads annotation data specifically from the AOEF / soundevent format.

This module provides the necessary configuration model and loading function
to handle data sources where annotations are stored in the standard format
used by the `soundevent` library (often as `.json` or `.aoef` files),
which includes outputs from annotation tools like Whombat.

It supports loading both simple `AnnotationSet` files and more complex
`AnnotationProject` files. For `AnnotationProject` files, it offers optional
filtering capabilities to select only annotations associated with tasks
that meet specific status criteria (e.g., completed, verified, without issues).
"""

from pathlib import Path
from typing import Literal, Optional
from uuid import uuid5

from pydantic import Field
from soundevent import data, io

from batdetect2.configs import BaseConfig
from batdetect2.data.annotations.types import AnnotatedDataset

__all__ = [
    "AOEFAnnotations",
    "load_aoef_annotated_dataset",
    "AnnotationTaskFilter",
]


class AnnotationTaskFilter(BaseConfig):
    """Configuration for filtering Annotation Tasks within an AnnotationProject.

    Specifies criteria based on task status badges to select relevant
    annotations, typically used when loading data from annotation projects
    that might contain work-in-progress.

    Attributes
    ----------
    only_completed : bool, default=True
        If True, only include annotations from tasks marked as 'completed'.
    only_verified : bool, default=False
        If True, only include annotations from tasks marked as 'verified'.
    exclude_issues : bool, default=True
        If True, exclude annotations from tasks marked as 'rejected' (indicating
        issues).
    """

    only_completed: bool = True
    only_verified: bool = False
    exclude_issues: bool = True


class AOEFAnnotations(AnnotatedDataset):
    """Configuration defining a data source stored in AOEF format.

    This model specifies how to load annotations from an AOEF (JSON file) file
    compatible with the `soundevent` library. It inherits `name`,
    `description`, and `audio_dir` from `AnnotatedDataset`.

    Attributes
    ----------
    format : Literal["aoef"]
        The fixed format identifier for this configuration type.
    annotations_path : Path
        The file system path to the `.aoef` or `.json` file containing the
        `AnnotationSet` or `AnnotationProject`.
    filter : AnnotationTaskFilter, optional
        Configuration for filtering tasks if the `annotations_path` points to
        an `AnnotationProject`. If omitted, default filtering
        (only completed, exclude issues, verification not required) is applied
        to projects. Set explicitly to `None` in config (e.g., `filter: null`)
        to disable filtering for projects entirely.
    """

    format: Literal["aoef"] = "aoef"

    annotations_path: Path

    filter: Optional[AnnotationTaskFilter] = Field(
        default_factory=AnnotationTaskFilter
    )


def load_aoef_annotated_dataset(
    dataset: AOEFAnnotations,
    base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
    """Load annotations from an AnnotationSet or AnnotationProject file.

    Reads the file specified in the `dataset` configuration using
    `soundevent.io.load`. If the loaded file contains an `AnnotationProject`
    and filtering is enabled via `dataset.filter`, it applies the filter
    criteria based on task status and returns a new `AnnotationSet` containing
    only the selected annotations. If the file contains an `AnnotationSet`,
    or if it's a project and filtering is disabled, the all annotations are
    returned.

    Parameters
    ----------
    dataset : AOEFAnnotations
        The configuration object describing the AOEF data source, including
        the path to the annotation file and optional filtering settings.
    base_dir : Path, optional
        An optional base directory. If provided, `dataset.annotations_path`
        and `dataset.audio_dir` will be resolved relative to this
        directory. Defaults to None.

    Returns
    -------
    soundevent.data.AnnotationSet
        An AnnotationSet containing the loaded (and potentially filtered)
        `ClipAnnotation` objects.

    Raises
    ------
    FileNotFoundError
        If the specified `annotations_path` (after resolving `base_dir`)
        does not exist.
    ValueError
        If the loaded file does not contain a valid `AnnotationSet` or
        `AnnotationProject`.
    Exception
        May re-raise errors from `soundevent.io.load` related to parsing
        or file format issues.

    Notes
    -----
    - The `soundevent` library handles parsing of `.json` or `.aoef` formats.
    - If an `AnnotationProject` is loaded and `dataset.filter` is *not* None,
      a *new* `AnnotationSet` instance is created containing only the filtered
      clip annotations.
    """
    audio_dir = dataset.audio_dir
    path = dataset.annotations_path

    if base_dir:
        audio_dir = base_dir / audio_dir
        path = base_dir / path

    loaded = io.load(path, audio_dir=audio_dir)

    if not isinstance(loaded, (data.AnnotationSet, data.AnnotationProject)):
        raise ValueError(
            f"The file at {path} loaded successfully but does not "
            "contain a soundevent AnnotationSet or AnnotationProject "
            f"(loaded type: {type(loaded).__name__})."
        )

    if isinstance(loaded, data.AnnotationProject) and dataset.filter:
        loaded = filter_ready_clips(
            loaded,
            only_completed=dataset.filter.only_completed,
            only_verified=dataset.filter.only_verified,
            exclude_issues=dataset.filter.exclude_issues,
        )

    return loaded


def select_task(
    annotation_task: data.AnnotationTask,
    only_completed: bool = True,
    only_verified: bool = False,
    exclude_issues: bool = True,
) -> bool:
    """Check if an AnnotationTask meets specified status criteria.

    Evaluates the `status_badges` of the task against the filter flags.

    Parameters
    ----------
    annotation_task : data.AnnotationTask
        The annotation task to check.
    only_completed : bool, default=True
        Task must be marked 'completed' to pass.
    only_verified : bool, default=False
        Task must be marked 'verified' to pass.
    exclude_issues : bool, default=True
        Task must *not* be marked 'rejected' (have issues) to pass.

    Returns
    -------
    bool
        True if the task meets all active filter criteria, False otherwise.
    """
    has_issues = False
    is_completed = False
    is_verified = False

    for badge in annotation_task.status_badges:
        if badge.state == data.AnnotationState.completed:
            is_completed = True
            continue

        if badge.state == data.AnnotationState.rejected:
            has_issues = True
            continue

        if badge.state == data.AnnotationState.verified:
            is_verified = True

    if exclude_issues and has_issues:
        return False

    if only_verified and not is_verified:
        return False

    if only_completed and not is_completed:
        return False

    return True


def filter_ready_clips(
    annotation_project: data.AnnotationProject,
    only_completed: bool = True,
    only_verified: bool = False,
    exclude_issues: bool = True,
) -> data.AnnotationSet:
    """Filter AnnotationProject to create an AnnotationSet of 'ready' clips.

    Iterates through tasks in the project, selects tasks meeting the status
    criteria using `select_task`, and creates a new `AnnotationSet` containing
    only the `ClipAnnotation` objects associated with those selected tasks.

    Parameters
    ----------
    annotation_project : data.AnnotationProject
        The input annotation project.
    only_completed : bool, default=True
        Filter flag passed to `select_task`.
    only_verified : bool, default=False
        Filter flag passed to `select_task`.
    exclude_issues : bool, default=True
        Filter flag passed to `select_task`.

    Returns
    -------
    data.AnnotationSet
        A new annotation set containing only the clip annotations linked to
        tasks that satisfied the filtering criteria. The returned set has a
        deterministic UUID based on the project UUID and filter settings.
    """
    ready_clip_uuids = set()

    for annotation_task in annotation_project.tasks:
        if not select_task(
            annotation_task,
            only_completed=only_completed,
            only_verified=only_verified,
            exclude_issues=exclude_issues,
        ):
            continue

        ready_clip_uuids.add(annotation_task.clip.uuid)

    return data.AnnotationSet(
        uuid=uuid5(
            annotation_project.uuid,
            f"{only_completed}_{only_verified}_{exclude_issues}",
        ),
        name=annotation_project.name,
        description=annotation_project.description,
        clip_annotations=[
            annotation
            for annotation in annotation_project.clip_annotations
            if annotation.clip.uuid in ready_clip_uuids
        ],
    )

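A minimal sketch of configuring and loading an AOEF source with the default task filtering; the paths and names below are illustrative:

from pathlib import Path

from batdetect2.data.annotations.aoef import (
    AnnotationTaskFilter,
    AOEFAnnotations,
    load_aoef_annotated_dataset,
)

# Hypothetical paths, resolved relative to base_dir below.
source = AOEFAnnotations(
    name="whombat-project",
    description="Annotations exported from Whombat",
    audio_dir=Path("audio"),
    annotations_path=Path("annotations/project.aoef.json"),
    filter=AnnotationTaskFilter(
        only_completed=True,
        only_verified=False,
        exclude_issues=True,
    ),
)

annotations = load_aoef_annotated_dataset(source, base_dir=Path("/data/bats"))
print(len(annotations.clip_annotations))
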
batdetect2/data/annotations/batdetect2.py (new file, 328 lines)
@@ -0,0 +1,328 @@
"""Loads annotation data from legacy BatDetect2 JSON formats.

This module provides backward compatibility for loading annotation data stored
in two related formats used by older BatDetect2 tools:

1. **`batdetect2` format** (Directory-based): Annotations are stored in
   individual JSON files (one per audio recording) within a specified
   directory.
   Each JSON file contains a `FileAnnotation` structure. Loaded via
   `load_batdetect2_files_annotated_dataset` defined by
   `BatDetect2FilesAnnotations`.
2. **`batdetect2_file` format** (Single-file): Annotations for multiple
   recordings are merged into a single JSON file, containing a list of
   `FileAnnotation` objects. Loaded via
   `load_batdetect2_merged_annotated_dataset` defined by
   `BatDetect2MergedAnnotations`.

Both formats use the same internal structure for annotations per file and
support filtering based on `annotated` and `issues` flags within that
structure.

The loading functions convert data from these legacy formats into the modern
`soundevent` data model (primarily `ClipAnnotation`) and return the results
aggregated into a `soundevent.data.AnnotationSet`.
"""

import json
import os
from pathlib import Path
from typing import Literal, Optional, Union

from loguru import logger
from pydantic import Field, ValidationError
from soundevent import data

from batdetect2.configs import BaseConfig
from batdetect2.data.annotations.legacy import (
    FileAnnotation,
    file_annotation_to_clip,
    file_annotation_to_clip_annotation,
    list_file_annotations,
    load_file_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset

PathLike = Union[Path, str, os.PathLike]


__all__ = [
    "load_batdetect2_files_annotated_dataset",
    "load_batdetect2_merged_annotated_dataset",
    "BatDetect2FilesAnnotations",
    "BatDetect2MergedAnnotations",
    "AnnotationFilter",
]


class AnnotationFilter(BaseConfig):
    """Configuration for filtering legacy FileAnnotations based on flags.

    Specifies criteria based on boolean flags (`annotated` and `issues`)
    present within the legacy `FileAnnotation` JSON structure to select which
    entries (either files or records within a merged file) should be loaded and
    converted.

    Attributes
    ----------
    only_annotated : bool, default=True
        If True, only process entries where the `annotated` flag in the JSON
        is set to `True`.
    exclude_issues : bool, default=True
        If True, skip processing entries where the `issues` flag in the JSON
        is set to `True`.
    """

    only_annotated: bool = True
    exclude_issues: bool = True


class BatDetect2FilesAnnotations(AnnotatedDataset):
    """Configuration for the legacy 'batdetect2' format (directory-based).

    Defines a data source where annotations are stored as individual JSON files
    (one per recording, containing a `FileAnnotation` structure) within the
    `annotations_dir`. Requires a corresponding `audio_dir`. Assumes a naming
    convention links audio files to JSON files
    (e.g., `rec.wav` -> `rec.wav.json`).

    Attributes
    ----------
    format : Literal["batdetect2"]
        The fixed format identifier for this configuration type.
    annotations_dir : Path
        Path to the directory containing the individual JSON annotation files.
    filter : AnnotationFilter, optional
        Configuration for filtering which files to process based on their
        `annotated` and `issues` flags. Defaults to requiring `annotated=True`
        and `issues=False`. Set explicitly to `None` in config (e.g.,
        `filter: null`) to disable filtering.
    """

    format: Literal["batdetect2"] = "batdetect2"
    annotations_dir: Path

    filter: Optional[AnnotationFilter] = Field(
        default_factory=AnnotationFilter,
    )


class BatDetect2MergedAnnotations(AnnotatedDataset):
    """Configuration for the legacy 'batdetect2_file' format (merged file).

    Defines a data source where annotations for multiple recordings (each as a
    `FileAnnotation` structure) are stored within a single JSON file specified
    by `annotations_path`. Audio files are expected in `audio_dir`.

    Inherits `name`, `description`, and `audio_dir` from `AnnotatedDataset`.

    Attributes
    ----------
    format : Literal["batdetect2_file"]
        The fixed format identifier for this configuration type.
    annotations_path : Path
        Path to the single JSON file containing a list of `FileAnnotation`
        objects.
    filter : AnnotationFilter, optional
        Configuration for filtering which `FileAnnotation` entries within the
        merged file to process based on their `annotated` and `issues` flags.
        Defaults to requiring `annotated=True` and `issues=False`. Set to `None`
        in config (e.g., `filter: null`) to disable filtering.
    """

    format: Literal["batdetect2_file"] = "batdetect2_file"
    annotations_path: Path

    filter: Optional[AnnotationFilter] = Field(
        default_factory=AnnotationFilter,
    )


def load_batdetect2_files_annotated_dataset(
    dataset: BatDetect2FilesAnnotations,
    base_dir: Optional[PathLike] = None,
) -> data.AnnotationSet:
    """Load and convert 'batdetect2_file' annotations into an AnnotationSet.

    Scans the specified `annotations_dir` for individual JSON annotation files.
    For each file: loads the legacy `FileAnnotation`, applies filtering based
    on `dataset.filter` (`annotated`/`issues` flags), attempts to find the
    corresponding audio file, converts valid entries to `ClipAnnotation`, and
    collects them into a single `soundevent.data.AnnotationSet`.

    Parameters
    ----------
    dataset : BatDetect2FilesAnnotations
        Configuration describing the 'batdetect2' (directory) data source.
    base_dir : PathLike, optional
        Optional base directory to resolve relative paths in `dataset.audio_dir`
        and `dataset.annotations_dir`. Defaults to None.

    Returns
    -------
    soundevent.data.AnnotationSet
        An AnnotationSet containing all successfully loaded, filtered, and
        converted `ClipAnnotation` objects.

    Raises
    ------
    FileNotFoundError
        If the `annotations_dir` or `audio_dir` does not exist. Errors finding
        individual JSON or audio files during iteration are logged and skipped.
    """
    audio_dir = dataset.audio_dir
    path = dataset.annotations_dir

    if base_dir:
        audio_dir = base_dir / audio_dir
        path = base_dir / path

    paths = list_file_annotations(path)
    logger.debug(
        "Found {num_files} files in the annotations directory {path}",
        num_files=len(paths),
        path=path,
    )

    annotations = []

    for p in paths:
        try:
            file_annotation = load_file_annotation(p)
        except (FileNotFoundError, ValidationError):
            logger.warning("Could not load annotations in file {path}", path=p)
            continue

        if (
            dataset.filter
            and dataset.filter.only_annotated
            and not file_annotation.annotated
        ):
            logger.debug(
                "Annotation in file {path} omited: not annotated",
                path=p,
            )
            continue

        if (
            dataset.filter
            and dataset.filter.exclude_issues
            and file_annotation.issues
        ):
            logger.debug(
                "Annotation in file {path} omited: has issues",
                path=p,
            )
            continue

        try:
            clip = file_annotation_to_clip(
                file_annotation,
                audio_dir=audio_dir,
            )
        except FileNotFoundError as err:
            logger.warning(
                "Did not find the audio related to the annotation file {path}. Error: {err}",
                path=p,
                err=err,
            )
            continue

        annotations.append(
            file_annotation_to_clip_annotation(
                file_annotation,
                clip,
            )
        )

    return data.AnnotationSet(
        name=dataset.name,
        description=dataset.description,
        clip_annotations=annotations,
    )


def load_batdetect2_merged_annotated_dataset(
    dataset: BatDetect2MergedAnnotations,
    base_dir: Optional[PathLike] = None,
) -> data.AnnotationSet:
    """Load and convert 'batdetect2_merged' annotations into an AnnotationSet.

    Loads a single JSON file containing a list of legacy `FileAnnotation`
    objects. For each entry in the list: applies filtering based on
    `dataset.filter` (`annotated`/`issues` flags), attempts to find the
    corresponding audio file, converts valid entries to `ClipAnnotation`, and
    collects them into a single `soundevent.data.AnnotationSet`.

    Parameters
    ----------
    dataset : BatDetect2MergedAnnotations
        Configuration describing the 'batdetect2_file' (merged) data source.
    base_dir : PathLike, optional
        Optional base directory to resolve relative paths in `dataset.audio_dir`
        and `dataset.annotations_path`. Defaults to None.

    Returns
    -------
    soundevent.data.AnnotationSet
        An AnnotationSet containing all successfully loaded, filtered, and
        converted `ClipAnnotation` objects from the merged file.

    Raises
    ------
    FileNotFoundError
        If the `annotations_path` or `audio_dir` does not exist. Errors
        finding individual audio files referenced within the JSON are logged
        and skipped.
    json.JSONDecodeError
        If the annotations file is not valid JSON.
    TypeError
        If the root JSON structure is not a list.
    pydantic.ValidationError
        If entries within the JSON list do not conform to the legacy
        `FileAnnotation` structure.
    """
    audio_dir = dataset.audio_dir
    path = dataset.annotations_path

    if base_dir:
        audio_dir = base_dir / audio_dir
        path = base_dir / path

    content = json.loads(Path(path).read_text())

    if not isinstance(content, list):
        raise TypeError(
            f"Expected a list of FileAnnotations, but got {type(content)}",
        )

    annotations = []

    for ann in content:
        try:
            ann = FileAnnotation.model_validate(ann)
        except ValueError:
            continue

        if (
            dataset.filter
            and dataset.filter.only_annotated
            and not ann.annotated
        ):
            continue

        if dataset.filter and dataset.filter.exclude_issues and ann.issues:
            continue

        try:
            clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
        except FileNotFoundError:
            continue

        annotations.append(file_annotation_to_clip_annotation(ann, clip))

    return data.AnnotationSet(
        name=dataset.name,
        description=dataset.description,
        clip_annotations=annotations,
    )

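A minimal sketch of loading a legacy merged-file source with the default annotated/issues filtering; the paths and names below are illustrative:

from pathlib import Path

from batdetect2.data.annotations.batdetect2 import (
    AnnotationFilter,
    BatDetect2MergedAnnotations,
    load_batdetect2_merged_annotated_dataset,
)

# Hypothetical layout: one merged JSON file holding many FileAnnotation entries.
source = BatDetect2MergedAnnotations(
    name="legacy-train",
    audio_dir=Path("audio"),
    annotations_path=Path("anns/train_annotations.json"),
    # Keep only entries flagged as annotated and without issues (the default).
    filter=AnnotationFilter(only_annotated=True, exclude_issues=True),
)

annotation_set = load_batdetect2_merged_annotated_dataset(
    source, base_dir=Path("/data/bats")
)
print(len(annotation_set.clip_annotations))
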
@@ -1,80 +0,0 @@ (file deleted)
import os
from pathlib import Path
from typing import Literal, Optional, Union

from soundevent import data

from batdetect2.data.annotations.legacy import (
    file_annotation_to_annotation_task,
    file_annotation_to_clip,
    file_annotation_to_clip_annotation,
    list_file_annotations,
    load_file_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset

PathLike = Union[Path, str, os.PathLike]


__all__ = [
    "load_batdetect2_files_annotated_dataset",
    "BatDetect2FilesAnnotations",
]


class BatDetect2FilesAnnotations(AnnotatedDataset):
    format: Literal["batdetect2"] = "batdetect2"
    annotations_dir: Path


def load_batdetect2_files_annotated_dataset(
    dataset: BatDetect2FilesAnnotations,
    base_dir: Optional[PathLike] = None,
) -> data.AnnotationProject:
    """Convert annotations to annotation project."""
    audio_dir = dataset.audio_dir
    path = dataset.annotations_dir

    if base_dir:
        audio_dir = base_dir / audio_dir
        path = base_dir / path

    paths = list_file_annotations(path)

    annotations = []
    tasks = []

    for p in paths:
        try:
            file_annotation = load_file_annotation(p)
        except FileNotFoundError:
            continue

        try:
            clip = file_annotation_to_clip(
                file_annotation,
                audio_dir=audio_dir,
            )
        except FileNotFoundError:
            continue

        annotations.append(
            file_annotation_to_clip_annotation(
                file_annotation,
                clip,
            )
        )

        tasks.append(
            file_annotation_to_annotation_task(
                file_annotation,
                clip,
            )
        )

    return data.AnnotationProject(
        name=dataset.name,
        description=dataset.description,
        clip_annotations=annotations,
        tasks=tasks,
    )

@@ -1,64 +0,0 @@ (file deleted)
import json
import os
from pathlib import Path
from typing import Literal, Optional, Union

from soundevent import data

from batdetect2.data.annotations.legacy import (
    FileAnnotation,
    file_annotation_to_annotation_task,
    file_annotation_to_clip,
    file_annotation_to_clip_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset

PathLike = Union[Path, str, os.PathLike]

__all__ = [
    "BatDetect2MergedAnnotations",
    "load_batdetect2_merged_annotated_dataset",
]


class BatDetect2MergedAnnotations(AnnotatedDataset):
    format: Literal["batdetect2_file"] = "batdetect2_file"
    annotations_path: Path


def load_batdetect2_merged_annotated_dataset(
    dataset: BatDetect2MergedAnnotations,
    base_dir: Optional[PathLike] = None,
) -> data.AnnotationProject:
    audio_dir = dataset.audio_dir
    path = dataset.annotations_path

    if base_dir:
        audio_dir = base_dir / audio_dir
        path = base_dir / path

    content = json.loads(Path(path).read_text())

    annotations = []
    tasks = []

    for ann in content:
        try:
            ann = FileAnnotation.model_validate(ann)
        except ValueError:
            continue

        try:
            clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
        except FileNotFoundError:
            continue

        annotations.append(file_annotation_to_clip_annotation(ann, clip))
        tasks.append(file_annotation_to_annotation_task(ann, clip))

    return data.AnnotationProject(
        name=dataset.name,
        description=dataset.description,
        clip_annotations=annotations,
        tasks=tasks,
    )

@ -1,11 +1,9 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal, Union
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"AnnotatedDataset",
|
||||
"BatDetect2MergedAnnotations",
|
||||
]
|
||||
|
||||
|
||||
@ -19,11 +17,10 @@ class AnnotatedDataset(BaseConfig):
|
||||
|
||||
Annotations associated with these recordings are defined by the
|
||||
`annotations` field, which supports various formats (e.g., AOEF files,
|
||||
specific CSV
|
||||
structures).
|
||||
Crucially, file paths referenced within the annotation data *must* be
|
||||
relative to the `audio_dir`. This ensures that the dataset definition
|
||||
remains portable across different systems and base directories.
|
||||
specific CSV structures). Crucially, file paths referenced within the
|
||||
annotation data *must* be relative to the `audio_dir`. This ensures that
|
||||
the dataset definition remains portable across different systems and base
|
||||
directories.
|
||||
|
||||
Attributes:
|
||||
name: A unique identifier for this data source.
|
||||
@@ -37,5 +34,3 @@ class AnnotatedDataset(BaseConfig):
|
||||
name: str
|
||||
audio_dir: Path
|
||||
description: str = ""
|
||||
|
||||
|
||||
|
@@ -1,37 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import load_config
|
||||
from batdetect2.data.annotations import load_annotated_dataset
|
||||
from batdetect2.data.types import Dataset
|
||||
|
||||
__all__ = [
|
||||
"load_dataset",
|
||||
"load_dataset_from_config",
|
||||
]
|
||||
|
||||
|
||||
def load_dataset(
|
||||
dataset: Dataset,
|
||||
base_dir: Optional[Path] = None,
|
||||
) -> data.AnnotationSet:
|
||||
clip_annotations = []
|
||||
for source in dataset.sources:
|
||||
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
|
||||
clip_annotations.extend(annotated_source.clip_annotations)
|
||||
return data.AnnotationSet(clip_annotations=clip_annotations)
|
||||
|
||||
|
||||
def load_dataset_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
base_dir: Optional[Path] = None,
|
||||
):
|
||||
config = load_config(
|
||||
path=path,
|
||||
schema=Dataset,
|
||||
field=field,
|
||||
)
|
||||
return load_dataset(config, base_dir=base_dir)
|
251
batdetect2/data/datasets.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""Defines the overall dataset structure and provides loading/saving utilities.
|
||||
|
||||
This module focuses on defining what constitutes a BatDetect2 dataset,
|
||||
potentially composed of multiple distinct data sources with varying annotation
|
||||
formats. It provides mechanisms to load the annotation metadata from these
|
||||
sources into a unified representation.
|
||||
|
||||
The core components are:
|
||||
- `DatasetConfig`: A configuration class (typically loaded from YAML) that
|
||||
describes the dataset's name, description, and constituent sources.
|
||||
- `Dataset`: A type alias representing the loaded dataset as a list of
|
||||
`soundevent.data.ClipAnnotation` objects. Note that this implies all
|
||||
annotation metadata is loaded into memory.
|
||||
- Loading functions (`load_dataset`, `load_dataset_from_config`): To parse
|
||||
a `DatasetConfig` and load the corresponding annotation metadata.
|
||||
- Saving function (`save_dataset`): To save a loaded list of annotations
|
||||
into a standard `soundevent` format.
|
||||
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from soundevent import data, io
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.data.annotations import (
|
||||
AnnotatedDataset,
|
||||
AnnotationFormats,
|
||||
load_annotated_dataset,
|
||||
)
|
||||
from batdetect2.targets.terms import data_source
|
||||
|
||||
__all__ = [
|
||||
"load_dataset",
|
||||
"load_dataset_from_config",
|
||||
"save_dataset",
|
||||
"Dataset",
|
||||
"DatasetConfig",
|
||||
]
|
||||
|
||||
|
||||
Dataset = List[data.ClipAnnotation]
|
||||
"""Type alias for a loaded dataset representation.
|
||||
|
||||
Represents an entire dataset *after loading* as a flat Python list containing
|
||||
all `soundevent.data.ClipAnnotation` objects gathered from all configured data
|
||||
sources.
|
||||
"""
|
||||
|
||||
|
||||
class DatasetConfig(BaseConfig):
|
||||
"""Configuration model defining the structure of a BatDetect2 dataset.
|
||||
|
||||
This class is typically loaded from a YAML file and describes the components
|
||||
of the dataset, including metadata and a list of data sources.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
A descriptive name for the dataset (e.g., "UK_Bats_Project_2024").
|
||||
description : str
|
||||
A longer description of the dataset's contents, origin, purpose, etc.
|
||||
sources : List[AnnotationFormats]
|
||||
A list defining the different data sources contributing to this
|
||||
dataset. Each item in the list must conform to one of the Pydantic
|
||||
models defined in the `AnnotationFormats` type union. The specific
|
||||
model used for each source is determined by the mandatory `format`
|
||||
field within the source's configuration, allowing BatDetect2 to use the
|
||||
correct parser for different annotation styles.
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
sources: List[
|
||||
Annotated[AnnotationFormats, Field(..., discriminator="format")]
|
||||
]
|
||||
|
||||
|
||||
def load_dataset(
|
||||
dataset: DatasetConfig,
|
||||
base_dir: Optional[Path] = None,
|
||||
) -> Dataset:
|
||||
"""Load all clip annotations from the sources defined in a DatasetConfig.
|
||||
|
||||
Iterates through each data source specified in the `dataset_config`,
|
||||
delegates the loading and parsing of that source's annotations to
|
||||
`batdetect2.data.annotations.load_annotated_dataset` (which handles
|
||||
different data formats), and aggregates all resulting `ClipAnnotation`
|
||||
objects into a single flat list.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset_config : DatasetConfig
|
||||
The configuration object describing the dataset and its sources.
|
||||
base_dir : Path, optional
|
||||
An optional base directory path. If provided, relative paths for
|
||||
metadata files or data directories within the `dataset_config`'s
|
||||
sources might be resolved relative to this directory. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dataset (List[data.ClipAnnotation])
|
||||
A flat list containing all loaded `ClipAnnotation` metadata objects
|
||||
from all specified sources.
|
||||
|
||||
Raises
|
||||
------
|
||||
Exception
|
||||
Can raise various exceptions during the delegated loading process
|
||||
(`load_annotated_dataset`) if files are not found, cannot be parsed
|
||||
according to the specified format, or other I/O errors occur.
|
||||
"""
|
||||
clip_annotations = []
|
||||
for source in dataset.sources:
|
||||
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
|
||||
logger.debug(
|
||||
"Loaded {num_examples} from dataset source '{source_name}'",
|
||||
num_examples=len(annotated_source.clip_annotations),
|
||||
source_name=source.name,
|
||||
)
|
||||
clip_annotations.extend(
|
||||
insert_source_tag(clip_annotation, source)
|
||||
for clip_annotation in annotated_source.clip_annotations
|
||||
)
|
||||
return clip_annotations
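A minimal usage sketch of the pieces above (import paths and file locations are assumptions; `BatDetect2MergedAnnotations` is taken to be one of the available `AnnotationFormats`):

from pathlib import Path

from batdetect2.data.annotations import BatDetect2MergedAnnotations
from batdetect2.data.datasets import DatasetConfig, load_dataset

config = DatasetConfig(
    name="example_dataset",
    description="Illustrative dataset definition.",
    sources=[
        BatDetect2MergedAnnotations(
            name="example_source",
            audio_dir=Path("audio"),
            annotations_path=Path("anns/annotations.json"),
        ),
    ],
)

# Relative paths in the sources are resolved against base_dir; the result is
# a flat list of ClipAnnotation objects, each tagged with its data source.
clip_annotations = load_dataset(config, base_dir=Path("/data/example"))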
|
||||
|
||||
|
||||
def insert_source_tag(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
source: AnnotatedDataset,
|
||||
) -> data.ClipAnnotation:
|
||||
"""Insert the source tag into a ClipAnnotation.
|
||||
|
||||
This function adds a tag to the `ClipAnnotation` object, indicating the
|
||||
source from which it was loaded. The tag value is taken from the name of
|
||||
the provided `source` dataset configuration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clip_annotation : data.ClipAnnotation
|
||||
The `ClipAnnotation` object to which the source tag will be added.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.ClipAnnotation
|
||||
The modified `ClipAnnotation` object with the source tag added.
|
||||
"""
|
||||
return clip_annotation.model_copy(
|
||||
update=dict(
|
||||
tags=[
|
||||
*clip_annotation.tags,
|
||||
data.Tag(
|
||||
term=data_source,
|
||||
value=source.name,
|
||||
),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def load_dataset_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
base_dir: Optional[Path] = None,
|
||||
) -> Dataset:
|
||||
"""Load dataset annotation metadata from a configuration file.
|
||||
|
||||
This is a convenience function that first loads the `DatasetConfig` from
|
||||
the specified file path and optional nested field, and then calls
|
||||
`load_dataset` to load all corresponding `ClipAnnotation` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (e.g., YAML).
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing the
|
||||
dataset configuration (e.g., "data.training_set"). If None, the
|
||||
entire file content is assumed to be the `DatasetConfig`.
|
||||
base_dir : Path, optional
|
||||
An optional base directory path to resolve relative paths within the
|
||||
configuration sources. Passed to `load_dataset`. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dataset (List[data.ClipAnnotation])
|
||||
A flat list containing all loaded `ClipAnnotation` metadata objects.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file `path` does not exist.
|
||||
yaml.YAMLError, pydantic.ValidationError, KeyError, TypeError
|
||||
If the configuration file is invalid, cannot be parsed, or does not
|
||||
match the `DatasetConfig` schema.
|
||||
Exception
|
||||
Can raise exceptions from `load_dataset` if loading data from sources
|
||||
fails.
|
||||
"""
|
||||
config = load_config(
|
||||
path=path,
|
||||
schema=DatasetConfig,
|
||||
field=field,
|
||||
)
|
||||
return load_dataset(config, base_dir=base_dir)
|
||||
|
||||
|
||||
def save_dataset(
|
||||
dataset: Dataset,
|
||||
path: data.PathLike,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
audio_dir: Optional[Path] = None,
|
||||
) -> None:
|
||||
"""Save a loaded dataset (list of ClipAnnotations) to a file.
|
||||
|
||||
Wraps the provided list of `ClipAnnotation` objects into a
|
||||
`soundevent.data.AnnotationSet` and saves it using `soundevent.io.save`.
|
||||
This saves the aggregated annotation metadata in the standard soundevent
|
||||
format.
|
||||
|
||||
Note: This function saves the *loaded annotation data*, not the original
|
||||
`DatasetConfig` structure that defined how the data was assembled from
|
||||
various sources.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset : Dataset (List[data.ClipAnnotation])
|
||||
The list of clip annotations to save (typically the result of
|
||||
`load_dataset` or a split thereof).
|
||||
path : data.PathLike
|
||||
The output file path (e.g., 'train_annotations.json',
|
||||
'val_annotations.json'). The format is determined by `soundevent.io`.
|
||||
name : str, optional
|
||||
An optional name to assign to the saved `AnnotationSet`.
|
||||
description : str, optional
|
||||
An optional description to assign to the saved `AnnotationSet`.
|
||||
audio_dir : Path, optional
|
||||
Passed to `soundevent.io.save`. May be used to relativize audio file
|
||||
paths within the saved annotations if applicable to the save format.
|
||||
"""
|
||||
|
||||
annotation_set = data.AnnotationSet(
|
||||
name=name,
|
||||
description=description,
|
||||
clip_annotations=dataset,
|
||||
)
|
||||
io.save(annotation_set, path, audio_dir=audio_dir)
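A hedged sketch of the config-file workflow together with the companion save step; the YAML file name, the nested "train" field, and all paths are placeholders:

from pathlib import Path

from batdetect2.data.datasets import load_dataset_from_config, save_dataset

train_annotations = load_dataset_from_config(
    "datasets.yaml",
    field="train",
    base_dir=Path("/data/example"),
)

# Persist the aggregated metadata as a standard soundevent AnnotationSet.
save_dataset(
    train_annotations,
    "train_annotations.json",
    name="training split",
    audio_dir=Path("/data/example/audio"),
)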
|
@@ -1,29 +0,0 @@
|
||||
from typing import Annotated, List
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.data.annotations import AnnotationFormats
|
||||
|
||||
|
||||
class Dataset(BaseConfig):
|
||||
"""Represents a collection of one or more DatasetSources.
|
||||
|
||||
In the context of batdetect2, a Dataset aggregates multiple `DatasetSource`
|
||||
instances. It serves as the primary unit for defining data splits,
|
||||
typically used for model training, validation, or testing phases.
|
||||
|
||||
Attributes:
|
||||
name: A descriptive name for the overall dataset
|
||||
(e.g., "UK Training Set").
|
||||
description: A detailed explanation of the dataset's purpose,
|
||||
composition, how it was assembled, or any specific characteristics.
|
||||
sources: A list containing the `DatasetSource` objects included in this
|
||||
dataset.
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
sources: List[
|
||||
Annotated[AnnotationFormats, Field(..., discriminator="format")]
|
||||
]
|
@@ -0,0 +1,9 @@
|
||||
from batdetect2.evaluate.evaluate import (
|
||||
compute_error_auc,
|
||||
match_predictions_and_annotations,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"compute_error_auc",
|
||||
"match_predictions_and_annotations",
|
||||
]
|
@@ -1,13 +1,10 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import auc, roc_curve
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_geometries
|
||||
|
||||
from batdetect2.train.targets import build_target_encoder, get_class_names
|
||||
|
||||
|
||||
def match_predictions_and_annotations(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
@@ -51,20 +48,13 @@
|
||||
return matches
|
||||
|
||||
|
||||
def build_evaluation_dataframe(matches: List[data.Match]) -> pd.DataFrame:
|
||||
ret = []
|
||||
|
||||
for match in matches:
|
||||
pass
|
||||
|
||||
|
||||
def compute_error_auc(op_str, gt, pred, prob):
|
||||
# classification error
|
||||
pred_int = (pred > prob).astype(np.int32)
|
||||
class_acc = (pred_int == gt).mean() * 100.0
|
||||
|
||||
# ROC - area under curve
|
||||
fpr, tpr, thresholds = roc_curve(gt, pred)
|
||||
fpr, tpr, _ = roc_curve(gt, pred)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
|
||||
print(
|
||||
@@ -177,7 +167,7 @@ def compute_pre_rec(
|
||||
file_ids.append([pid] * valid_inds.sum())
|
||||
|
||||
confidence = np.hstack(confidence)
|
||||
file_ids = np.hstack(file_ids).astype(np.int)
|
||||
file_ids = np.hstack(file_ids).astype(int)
|
||||
pred_boxes = np.vstack(pred_boxes)
|
||||
if len(pred_class) > 0:
|
||||
pred_class = np.hstack(pred_class)
|
||||
@@ -197,8 +187,7 @@ def compute_pre_rec(
|
||||
|
||||
# note, files with the incorrect duration will cause a problem
|
||||
if (gg["start_times"] > file_dur).sum() > 0:
|
||||
print("Error: file duration incorrect for", gg["id"])
|
||||
assert False
|
||||
raise ValueError(f"Error: file duration incorrect for {gg['id']}")
|
||||
|
||||
boxes = np.vstack(
|
||||
(
|
||||
@@ -244,6 +233,8 @@ def compute_pre_rec(
|
||||
gt_id = file_ids[ind]
|
||||
|
||||
valid_det = False
|
||||
det_ind = 0
|
||||
|
||||
if gt_boxes[gt_id].shape[0] > 0:
|
||||
# compute overlap
|
||||
valid_det, det_ind = compute_affinity_1d(
|
||||
@@ -273,7 +264,7 @@ def compute_pre_rec(
|
||||
# store threshold values - used for plotting
|
||||
conf_sorted = np.sort(confidence)[::-1][valid_inds]
|
||||
thresholds = np.linspace(0.1, 0.9, 9)
|
||||
thresholds_inds = np.zeros(len(thresholds), dtype=np.int)
|
||||
thresholds_inds = np.zeros(len(thresholds), dtype=int)
|
||||
for ii, tt in enumerate(thresholds):
|
||||
thresholds_inds[ii] = np.argmin(conf_sorted > tt)
|
||||
thresholds_inds[thresholds_inds == 0] = -1
|
||||
@@ -385,7 +376,7 @@ def compute_file_accuracy(gts, preds, num_classes):
|
||||
).mean(0)
|
||||
best_thresh = np.argmax(acc_per_thresh)
|
||||
best_acc = acc_per_thresh[best_thresh]
|
||||
pred_valid = pred_valid_all[:, best_thresh].astype(np.int).tolist()
|
||||
pred_valid = pred_valid_all[:, best_thresh].astype(int).tolist()
|
||||
|
||||
res = {}
|
||||
res["num_valid_files"] = len(gt_valid)
|
||||
|
@@ -1,92 +1,135 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
"""Defines and builds the neural network models used in BatDetect2.
|
||||
|
||||
from soundevent.data import PathLike
|
||||
This package (`batdetect2.models`) contains the PyTorch implementations of the
|
||||
deep neural network architectures used for detecting and classifying bat calls
|
||||
from spectrograms. It provides modular components and configuration-driven
|
||||
assembly, allowing for experimentation and use of different architectural
|
||||
variants.
|
||||
|
||||
Key Submodules:
|
||||
- `.types`: Defines core data structures (`ModelOutput`) and abstract base
|
||||
classes (`BackboneModel`, `DetectionModel`) establishing interfaces.
|
||||
- `.blocks`: Provides reusable neural network building blocks.
|
||||
- `.encoder`: Defines and builds the downsampling path (encoder) of the network.
|
||||
- `.bottleneck`: Defines and builds the central bottleneck component.
|
||||
- `.decoder`: Defines and builds the upsampling path (decoder) of the network.
|
||||
- `.backbone`: Assembles the encoder, bottleneck, and decoder into a complete
|
||||
feature extraction backbone (e.g., a U-Net like structure).
|
||||
- `.heads`: Defines simple prediction heads (detection, classification, size)
|
||||
that attach to the backbone features.
|
||||
- `.detectors`: Assembles the backbone and prediction heads into the final,
|
||||
end-to-end `Detector` model.
|
||||
|
||||
This module re-exports the most important classes, configurations, and builder
|
||||
functions from these submodules for convenient access. The primary entry point
|
||||
for creating a standard BatDetect2 model instance is the `build_model` function
|
||||
provided here.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.backbones import (
|
||||
Net2DFast,
|
||||
Net2DFastNoAttn,
|
||||
Net2DFastNoCoordConv,
|
||||
Net2DPlain,
|
||||
Backbone,
|
||||
BackboneConfig,
|
||||
build_backbone,
|
||||
load_backbone_config,
|
||||
)
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||
from batdetect2.models.typing import BackboneModel
|
||||
from batdetect2.models.blocks import (
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
FreqCoordConvUpConfig,
|
||||
StandardConvDownConfig,
|
||||
StandardConvUpConfig,
|
||||
)
|
||||
from batdetect2.models.bottleneck import (
|
||||
Bottleneck,
|
||||
BottleneckConfig,
|
||||
build_bottleneck,
|
||||
)
|
||||
from batdetect2.models.decoder import (
|
||||
DEFAULT_DECODER_CONFIG,
|
||||
DecoderConfig,
|
||||
build_decoder,
|
||||
)
|
||||
from batdetect2.models.detectors import (
|
||||
Detector,
|
||||
build_detector,
|
||||
)
|
||||
from batdetect2.models.encoder import (
|
||||
DEFAULT_ENCODER_CONFIG,
|
||||
EncoderConfig,
|
||||
build_encoder,
|
||||
)
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"BBoxHead",
|
||||
"Backbone",
|
||||
"BackboneConfig",
|
||||
"BackboneModel",
|
||||
"BackboneModel",
|
||||
"Bottleneck",
|
||||
"BottleneckConfig",
|
||||
"ClassifierHead",
|
||||
"ModelConfig",
|
||||
"ModelType",
|
||||
"Net2DFast",
|
||||
"Net2DFastNoAttn",
|
||||
"Net2DFastNoCoordConv",
|
||||
"build_architecture",
|
||||
"load_model_config",
|
||||
"ConvConfig",
|
||||
"DEFAULT_DECODER_CONFIG",
|
||||
"DEFAULT_ENCODER_CONFIG",
|
||||
"DecoderConfig",
|
||||
"DetectionModel",
|
||||
"Detector",
|
||||
"DetectorHead",
|
||||
"EncoderConfig",
|
||||
"FreqCoordConvDownConfig",
|
||||
"FreqCoordConvUpConfig",
|
||||
"ModelOutput",
|
||||
"StandardConvDownConfig",
|
||||
"StandardConvUpConfig",
|
||||
"build_backbone",
|
||||
"build_bottleneck",
|
||||
"build_decoder",
|
||||
"build_detector",
|
||||
"build_encoder",
|
||||
"build_model",
|
||||
"load_backbone_config",
|
||||
]
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
Net2DFast = "Net2DFast"
|
||||
Net2DFastNoAttn = "Net2DFastNoAttn"
|
||||
Net2DFastNoCoordConv = "Net2DFastNoCoordConv"
|
||||
Net2DPlain = "Net2DPlain"
|
||||
def build_model(
|
||||
num_classes: int,
|
||||
config: Optional[BackboneConfig] = None,
|
||||
) -> DetectionModel:
|
||||
"""Build the complete BatDetect2 detection model.
|
||||
|
||||
This high-level factory function constructs the standard BatDetect2 model
|
||||
architecture. It first builds the feature extraction backbone (typically an
|
||||
encoder-bottleneck-decoder structure) based on the provided
|
||||
`BackboneConfig` (or defaults if None), and then attaches the standard
|
||||
prediction heads (`DetectorHead`, `ClassifierHead`, `BBoxHead`) using the
|
||||
`build_detector` function.
|
||||
|
||||
class ModelConfig(BaseConfig):
|
||||
name: ModelType = ModelType.Net2DFast
|
||||
input_height: int = 128
|
||||
encoder_channels: Tuple[int, ...] = (1, 32, 64, 128)
|
||||
bottleneck_channels: int = 256
|
||||
decoder_channels: Tuple[int, ...] = (256, 64, 32, 32)
|
||||
out_channels: int = 32
|
||||
Parameters
|
||||
----------
|
||||
num_classes : int
|
||||
The number of specific target classes the model should predict
|
||||
(required for the `ClassifierHead`). Must be positive.
|
||||
config : BackboneConfig, optional
|
||||
Configuration object specifying the architecture of the backbone
|
||||
(encoder, bottleneck, decoder). If None, default configurations defined
|
||||
within the respective builder functions (`build_encoder`, etc.) will be
|
||||
used to construct a default backbone architecture.
|
||||
|
||||
Returns
|
||||
-------
|
||||
DetectionModel
|
||||
An initialized `Detector` model instance.
|
||||
|
||||
def load_model_config(
|
||||
path: PathLike, field: Optional[str] = None
|
||||
) -> ModelConfig:
|
||||
return load_config(path, schema=ModelConfig, field=field)
|
||||
|
||||
|
||||
def build_architecture(
|
||||
config: Optional[ModelConfig] = None,
|
||||
) -> BackboneModel:
|
||||
config = config or ModelConfig()
|
||||
|
||||
if config.name == ModelType.Net2DFast:
|
||||
return Net2DFast(
|
||||
input_height=config.input_height,
|
||||
encoder_channels=config.encoder_channels,
|
||||
bottleneck_channels=config.bottleneck_channels,
|
||||
decoder_channels=config.decoder_channels,
|
||||
out_channels=config.out_channels,
|
||||
)
|
||||
|
||||
if config.name == ModelType.Net2DFastNoAttn:
|
||||
return Net2DFastNoAttn(
|
||||
input_height=config.input_height,
|
||||
encoder_channels=config.encoder_channels,
|
||||
bottleneck_channels=config.bottleneck_channels,
|
||||
decoder_channels=config.decoder_channels,
|
||||
out_channels=config.out_channels,
|
||||
)
|
||||
|
||||
if config.name == ModelType.Net2DFastNoCoordConv:
|
||||
return Net2DFastNoCoordConv(
|
||||
input_height=config.input_height,
|
||||
encoder_channels=config.encoder_channels,
|
||||
bottleneck_channels=config.bottleneck_channels,
|
||||
decoder_channels=config.decoder_channels,
|
||||
out_channels=config.out_channels,
|
||||
)
|
||||
|
||||
if config.name == ModelType.Net2DPlain:
|
||||
return Net2DPlain(
|
||||
input_height=config.input_height,
|
||||
encoder_channels=config.encoder_channels,
|
||||
bottleneck_channels=config.bottleneck_channels,
|
||||
decoder_channels=config.decoder_channels,
|
||||
out_channels=config.out_channels,
|
||||
)
|
||||
|
||||
raise ValueError(f"Unknown model type: {config.name}")
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `num_classes` is not positive, or if errors occur during the
|
||||
construction of the backbone or detector components (e.g., incompatible
|
||||
configurations, invalid parameters).
|
||||
"""
|
||||
backbone = build_backbone(config or BackboneConfig())
|
||||
return build_detector(num_classes, backbone)
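An illustrative sketch of this entry point; the class count and spectrogram shape are arbitrary:

import torch

from batdetect2.models import BackboneConfig, build_model

model = build_model(num_classes=5, config=BackboneConfig())

# Dummy spectrogram batch: (batch, channels, frequency bins, time bins).
spec = torch.rand(1, 1, 128, 512)
output = model(spec)  # a ModelOutput from the detection, classification and size heads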
|
||||
|
@@ -1,181 +1,353 @@
|
||||
from typing import Sequence, Tuple
|
||||
"""Assembles a complete Encoder-Decoder Backbone network.
|
||||
|
||||
This module defines the configuration (`BackboneConfig`) and implementation
|
||||
(`Backbone`) for a standard encoder-decoder style neural network backbone.
|
||||
|
||||
It orchestrates the connection between three main components, built using their
|
||||
respective configurations and factory functions from sibling modules:
|
||||
1. Encoder (`batdetect2.models.encoder`): Downsampling path, extracts features
|
||||
at multiple resolutions and provides skip connections.
|
||||
2. Bottleneck (`batdetect2.models.bottleneck`): Processes features at the
|
||||
lowest resolution, optionally applying self-attention.
|
||||
3. Decoder (`batdetect2.models.decoder`): Upsampling path, reconstructs high-
|
||||
resolution features using bottleneck features and skip connections.
|
||||
|
||||
The resulting `Backbone` module takes a spectrogram as input and outputs a
|
||||
final feature map, typically used by subsequent prediction heads. It includes
|
||||
automatic padding to handle input sizes not perfectly divisible by the
|
||||
network's total downsampling factor.
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from soundevent import data
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.models.blocks import (
|
||||
ConvBlock,
|
||||
Decoder,
|
||||
DownscalingLayer,
|
||||
Encoder,
|
||||
SelfAttention,
|
||||
UpscalingLayer,
|
||||
VerticalConv,
|
||||
)
|
||||
from batdetect2.models.typing import BackboneModel
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.blocks import ConvBlock
|
||||
from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck
|
||||
from batdetect2.models.decoder import Decoder, DecoderConfig, build_decoder
|
||||
from batdetect2.models.encoder import Encoder, EncoderConfig, build_encoder
|
||||
from batdetect2.models.types import BackboneModel
|
||||
|
||||
__all__ = [
|
||||
"Net2DFast",
|
||||
"Net2DFastNoAttn",
|
||||
"Net2DFastNoCoordConv",
|
||||
"Backbone",
|
||||
"BackboneConfig",
|
||||
"load_backbone_config",
|
||||
"build_backbone",
|
||||
]
|
||||
|
||||
|
||||
class Net2DPlain(BackboneModel):
|
||||
downscaling_layer_type: DownscalingLayer = "ConvBlockDownStandard"
|
||||
upscaling_layer_type: UpscalingLayer = "ConvBlockUpStandard"
|
||||
class Backbone(BackboneModel):
|
||||
"""Encoder-Decoder Backbone Network Implementation.
|
||||
|
||||
Combines an Encoder, Bottleneck, and Decoder module sequentially, using
|
||||
skip connections between the Encoder and Decoder. Implements the standard
|
||||
U-Net style forward pass. Includes automatic input padding to handle
|
||||
various input sizes and a final convolutional block to adjust the output
|
||||
channels.
|
||||
|
||||
This class inherits from `BackboneModel` and implements its `forward`
|
||||
method. Instances are typically created using the `build_backbone` factory
|
||||
function.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
input_height : int
|
||||
Expected height of the input spectrogram.
|
||||
out_channels : int
|
||||
Number of channels in the final output feature map.
|
||||
encoder : Encoder
|
||||
The instantiated encoder module.
|
||||
decoder : Decoder
|
||||
The instantiated decoder module.
|
||||
bottleneck : nn.Module
|
||||
The instantiated bottleneck module.
|
||||
final_conv : ConvBlock
|
||||
Final convolutional block applied after the decoder.
|
||||
divide_factor : int
|
||||
The total downsampling factor (2^depth) applied by the encoder,
|
||||
used for automatic input padding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_height: int = 128,
|
||||
encoder_channels: Sequence[int] = (1, 32, 64, 128),
|
||||
bottleneck_channels: int = 256,
|
||||
decoder_channels: Sequence[int] = (256, 64, 32, 32),
|
||||
out_channels: int = 32,
|
||||
input_height: int,
|
||||
out_channels: int,
|
||||
encoder: Encoder,
|
||||
decoder: Decoder,
|
||||
bottleneck: nn.Module,
|
||||
):
|
||||
super().__init__()
|
||||
"""Initialize the Backbone network.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_height : int
|
||||
Expected height of the input spectrogram.
|
||||
out_channels : int
|
||||
Desired number of output channels for the backbone's feature map.
|
||||
encoder : Encoder
|
||||
An initialized Encoder module.
|
||||
decoder : Decoder
|
||||
An initialized Decoder module.
|
||||
bottleneck : nn.Module
|
||||
An initialized Bottleneck module.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If component output/input channels or heights are incompatible.
|
||||
"""
|
||||
super().__init__()
|
||||
self.input_height = input_height
|
||||
self.encoder_channels = tuple(encoder_channels)
|
||||
self.decoder_channels = tuple(decoder_channels)
|
||||
self.out_channels = out_channels
|
||||
|
||||
if len(encoder_channels) != len(decoder_channels):
|
||||
raise ValueError(
|
||||
f"Mismatched encoder and decoder channel lists. "
|
||||
f"The encoder has {len(encoder_channels)} channels "
|
||||
f"(implying {len(encoder_channels) - 1} layers), "
|
||||
f"while the decoder has {len(decoder_channels)} channels "
|
||||
f"(implying {len(decoder_channels) - 1} layers). "
|
||||
f"These lengths must be equal."
|
||||
)
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.bottleneck = bottleneck
|
||||
|
||||
self.divide_factor = 2 ** (len(encoder_channels) - 1)
|
||||
if self.input_height % self.divide_factor != 0:
|
||||
raise ValueError(
|
||||
f"Input height ({self.input_height}) must be divisible by "
|
||||
f"the divide factor ({self.divide_factor}). "
|
||||
f"This ensures proper upscaling after downscaling to recover "
|
||||
f"the original input height."
|
||||
)
|
||||
|
||||
self.encoder = Encoder(
|
||||
channels=encoder_channels,
|
||||
input_height=self.input_height,
|
||||
layer_type=self.downscaling_layer_type,
|
||||
)
|
||||
|
||||
self.conv_same_1 = ConvBlock(
|
||||
in_channels=encoder_channels[-1],
|
||||
out_channels=bottleneck_channels,
|
||||
)
|
||||
|
||||
# bottleneck
|
||||
self.conv_vert = VerticalConv(
|
||||
in_channels=bottleneck_channels,
|
||||
out_channels=bottleneck_channels,
|
||||
input_height=self.input_height // (2**self.encoder.depth),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
channels=decoder_channels,
|
||||
input_height=self.input_height,
|
||||
layer_type=self.upscaling_layer_type,
|
||||
)
|
||||
|
||||
self.conv_same_2 = ConvBlock(
|
||||
in_channels=decoder_channels[-1],
|
||||
self.final_conv = ConvBlock(
|
||||
in_channels=decoder.out_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
|
||||
# Down/Up scaling factor. Need to ensure inputs are divisible by
|
||||
# this factor in order to be processed by the down/up scaling layers
|
||||
# and recover the correct shape
|
||||
self.divide_factor = input_height // self.encoder.output_height
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
|
||||
"""Perform the forward pass through the encoder-decoder backbone.
|
||||
|
||||
Applies padding, runs encoder, bottleneck, decoder (with skip
|
||||
connections), removes padding, and applies a final convolution.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input spectrogram tensor, shape `(B, C_in, H_in, W_in)`. Must match
|
||||
`self.encoder.input_channels` and `self.input_height`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output feature map tensor, shape `(B, C_out, H_in, W_in)`, where
|
||||
`C_out` is `self.out_channels`.
|
||||
"""
|
||||
spec, h_pad, w_pad = _pad_adjust(spec, factor=self.divide_factor)
|
||||
|
||||
# encoder
|
||||
residuals = self.encoder(spec)
|
||||
residuals[-1] = self.conv_same_1(residuals[-1])
|
||||
|
||||
# bottleneck
|
||||
x = self.conv_vert(residuals[-1])
|
||||
x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
|
||||
x = self.bottleneck(residuals[-1])
|
||||
|
||||
# decoder
|
||||
x = self.decoder(x, residuals=residuals)
|
||||
|
||||
# Restore original size
|
||||
x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
|
||||
x = _restore_pad(x, h_pad=h_pad, w_pad=w_pad)
|
||||
|
||||
return self.conv_same_2(x)
|
||||
return self.final_conv(x)
|
||||
|
||||
|
||||
class Net2DFast(Net2DPlain):
|
||||
downscaling_layer_type = "ConvBlockDownCoordF"
|
||||
upscaling_layer_type = "ConvBlockUpF"
|
||||
class BackboneConfig(BaseConfig):
|
||||
"""Configuration for the Encoder-Decoder Backbone network.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_height: int = 128,
|
||||
encoder_channels: Sequence[int] = (1, 32, 64, 128),
|
||||
bottleneck_channels: int = 256,
|
||||
decoder_channels: Sequence[int] = (256, 64, 32, 32),
|
||||
out_channels: int = 32,
|
||||
):
|
||||
super().__init__(
|
||||
input_height=input_height,
|
||||
encoder_channels=encoder_channels,
|
||||
bottleneck_channels=bottleneck_channels,
|
||||
decoder_channels=decoder_channels,
|
||||
out_channels=out_channels,
|
||||
Aggregates configurations for the encoder, bottleneck, and decoder
|
||||
components, along with defining the input and final output dimensions
|
||||
for the complete backbone.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
input_height : int, default=128
|
||||
Expected height (frequency bins) of the input spectrograms to the
|
||||
backbone. Must be positive.
|
||||
in_channels : int, default=1
|
||||
Expected number of channels in the input spectrograms (e.g., 1 for
|
||||
mono). Must be positive.
|
||||
encoder : EncoderConfig, optional
|
||||
Configuration for the encoder. If None or omitted,
|
||||
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
|
||||
encoder module) will be used.
|
||||
bottleneck : BottleneckConfig, optional
|
||||
Configuration for the bottleneck layer connecting encoder and decoder.
|
||||
If None or omitted, the default bottleneck configuration will be used.
|
||||
decoder : DecoderConfig, optional
|
||||
Configuration for the decoder. If None or omitted,
|
||||
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
|
||||
decoder module) will be used.
|
||||
out_channels : int, default=32
|
||||
Desired number of channels in the final feature map output by the
|
||||
backbone. Must be positive.
|
||||
"""
|
||||
|
||||
input_height: int = 128
|
||||
in_channels: int = 1
|
||||
encoder: Optional[EncoderConfig] = None
|
||||
bottleneck: Optional[BottleneckConfig] = None
|
||||
decoder: Optional[DecoderConfig] = None
|
||||
out_channels: int = 32
|
||||
|
||||
|
||||
def load_backbone_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> BackboneConfig:
|
||||
"""Load the backbone configuration from a file.
|
||||
|
||||
Reads a configuration file (YAML) and validates it against the
|
||||
`BackboneConfig` schema, potentially extracting data from a nested field.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the configuration file.
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing the
|
||||
backbone configuration (e.g., "model.backbone"). If None, the entire
|
||||
file content is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
BackboneConfig
|
||||
The loaded and validated backbone configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
yaml.YAMLError
|
||||
If the file content is not valid YAML.
|
||||
pydantic.ValidationError
|
||||
If the loaded config data does not conform to `BackboneConfig`.
|
||||
KeyError, TypeError
|
||||
If `field` specifies an invalid path.
|
||||
"""
|
||||
return load_config(path, schema=BackboneConfig, field=field)
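A short sketch combining this loader with `build_backbone` (defined just below); the file name and the nested "model.backbone" field are assumptions:

from batdetect2.models.backbones import build_backbone, load_backbone_config

backbone_config = load_backbone_config("config.yaml", field="model.backbone")
backbone = build_backbone(backbone_config)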
|
||||
|
||||
|
||||
def build_backbone(config: BackboneConfig) -> BackboneModel:
|
||||
"""Factory function to build a Backbone from configuration.
|
||||
|
||||
Constructs the `Encoder`, `Bottleneck`, and `Decoder` components based on
|
||||
the provided `BackboneConfig`, validates their compatibility, and assembles
|
||||
them into a `Backbone` instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : BackboneConfig
|
||||
The configuration object detailing the backbone architecture, including
|
||||
input dimensions and configurations for encoder, bottleneck, and
|
||||
decoder.
|
||||
|
||||
Returns
|
||||
-------
|
||||
BackboneModel
|
||||
An initialized `Backbone` module ready for use.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If sub-component configurations are incompatible
|
||||
(e.g., channel mismatches, decoder output height doesn't match backbone
|
||||
input height).
|
||||
NotImplementedError
|
||||
If an unknown block type is specified in sub-configs.
|
||||
"""
|
||||
encoder = build_encoder(
|
||||
in_channels=config.in_channels,
|
||||
input_height=config.input_height,
|
||||
config=config.encoder,
|
||||
)
|
||||
|
||||
bottleneck = build_bottleneck(
|
||||
input_height=encoder.output_height,
|
||||
in_channels=encoder.out_channels,
|
||||
config=config.bottleneck,
|
||||
)
|
||||
|
||||
decoder = build_decoder(
|
||||
in_channels=bottleneck.out_channels,
|
||||
input_height=encoder.output_height,
|
||||
config=config.decoder,
|
||||
)
|
||||
|
||||
if decoder.output_height != config.input_height:
|
||||
raise ValueError(
|
||||
"Invalid configuration: Decoder output height "
|
||||
f"({decoder.output_height}) must match the Backbone input height "
|
||||
f"({config.input_height}). Check encoder/decoder layer "
|
||||
"configurations and input/bottleneck heights."
|
||||
)
|
||||
|
||||
self.att = SelfAttention(bottleneck_channels, bottleneck_channels)
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
spec, h_pad, w_pad = pad_adjust(spec, factor=self.divide_factor)
|
||||
|
||||
# encoder
|
||||
residuals = self.encoder(spec)
|
||||
residuals[-1] = self.conv_same_1(residuals[-1])
|
||||
|
||||
# bottleneck
|
||||
x = self.conv_vert(residuals[-1])
|
||||
x = self.att(x)
|
||||
x = x.repeat([1, 1, residuals[-1].shape[-2], 1])
|
||||
|
||||
# decoder
|
||||
x = self.decoder(x, residuals=residuals)
|
||||
|
||||
# Restore original size
|
||||
x = restore_pad(x, h_pad=h_pad, w_pad=w_pad)
|
||||
|
||||
return self.conv_same_2(x)
|
||||
return Backbone(
|
||||
input_height=config.input_height,
|
||||
out_channels=config.out_channels,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
bottleneck=bottleneck,
|
||||
)
|
||||
|
||||
|
||||
class Net2DFastNoAttn(Net2DPlain):
|
||||
downscaling_layer_type = "ConvBlockDownCoordF"
|
||||
upscaling_layer_type = "ConvBlockUpF"
|
||||
|
||||
|
||||
class Net2DFastNoCoordConv(Net2DFast):
|
||||
downscaling_layer_type = "ConvBlockDownStandard"
|
||||
upscaling_layer_type = "ConvBlockUpStandard"
|
||||
|
||||
|
||||
def pad_adjust(
|
||||
def _pad_adjust(
|
||||
spec: torch.Tensor,
|
||||
factor: int = 32,
|
||||
) -> Tuple[torch.Tensor, int, int]:
|
||||
print(spec.shape)
|
||||
"""Pad tensor height and width to be divisible by a factor.
|
||||
|
||||
Calculates the required padding for the last two dimensions (H, W) to make
|
||||
them divisible by `factor` and applies right/bottom padding using
|
||||
`torch.nn.functional.pad`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input tensor, typically shape `(B, C, H, W)`.
|
||||
factor : int, default=32
|
||||
The factor to make height and width divisible by.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[torch.Tensor, int, int]
|
||||
A tuple containing:
|
||||
- The padded tensor.
|
||||
- The amount of padding added to height (`h_pad`).
|
||||
- The amount of padding added to width (`w_pad`).
|
||||
"""
|
||||
h, w = spec.shape[2:]
|
||||
h_pad = -h % factor
|
||||
w_pad = -w % factor
|
||||
|
||||
if h_pad == 0 and w_pad == 0:
|
||||
return spec, 0, 0
|
||||
|
||||
return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
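As a concrete illustration of these padding helpers (assuming they are imported from this module): with a divide factor of 32, a 100x257 input is padded up to 128x288 and later cropped back to its original size by `_restore_pad`, defined just below.

import torch

from batdetect2.models.backbones import _pad_adjust, _restore_pad

x = torch.rand(1, 1, 100, 257)
padded, h_pad, w_pad = _pad_adjust(x, factor=32)  # h_pad=28, w_pad=31
assert padded.shape[-2:] == (128, 288)

restored = _restore_pad(padded, h_pad=h_pad, w_pad=w_pad)
assert restored.shape == x.shape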
|
||||
|
||||
|
||||
def restore_pad(
|
||||
def _restore_pad(
|
||||
x: torch.Tensor, h_pad: int = 0, w_pad: int = 0
|
||||
) -> torch.Tensor:
|
||||
# Restore original size
|
||||
"""Remove padding added by _pad_adjust.
|
||||
|
||||
Removes padding from the bottom and right edges of the tensor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Padded tensor, typically shape `(B, C, H_padded, W_padded)`.
|
||||
h_pad : int, default=0
|
||||
Amount of padding previously added to the height (bottom).
|
||||
w_pad : int, default=0
|
||||
Amount of padding previously added to the width (right).
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Tensor with padding removed, shape `(B, C, H_original, W_original)`.
|
||||
"""
|
||||
if h_pad > 0:
|
||||
x = x[:, :, :-h_pad, :]
|
||||
|
||||
|
@@ -1,42 +1,107 @@
|
||||
"""Module containing custom NN blocks.
|
||||
"""Commonly used neural network building blocks for BatDetect2 models.
|
||||
|
||||
All these classes are subclasses of `torch.nn.Module` and can be used to build
|
||||
complex neural network architectures.
|
||||
This module provides various reusable `torch.nn.Module` subclasses that form
|
||||
the fundamental building blocks for constructing convolutional neural network
|
||||
architectures, particularly encoder-decoder backbones used in BatDetect2.
|
||||
|
||||
It includes standard components like basic convolutional blocks (`ConvBlock`),
|
||||
blocks incorporating downsampling (`StandardConvDownBlock`), and blocks with
|
||||
upsampling (`StandardConvUpBlock`).
|
||||
|
||||
Additionally, it features specialized layers investigated in BatDetect2
|
||||
research:
|
||||
|
||||
- `SelfAttention`: Applies self-attention along the time dimension, enabling
|
||||
the model to weigh information across the entire temporal context, often
|
||||
used in the bottleneck of an encoder-decoder.
|
||||
- `FreqCoordConvDownBlock` / `FreqCoordConvUpBlock`: Implement the "CoordConv"
|
||||
concept by concatenating normalized frequency coordinate information as an
|
||||
extra channel to the input of convolutional layers. This explicitly provides
|
||||
spatial frequency information to filters, potentially enabling them to learn
|
||||
frequency-dependent patterns more effectively.
|
||||
|
||||
These blocks can be utilized directly in custom PyTorch model definitions or
|
||||
assembled into larger architectures.
|
||||
|
||||
A unified factory function `build_layer_from_config` allows creating instances
|
||||
of these blocks based on configuration objects.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Iterable, List, Literal, Sequence, Tuple
|
||||
from typing import Annotated, List, Literal, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from itertools import pairwise
|
||||
else:
|
||||
|
||||
def pairwise(iterable: Sequence) -> Iterable:
|
||||
for x, y in zip(iterable[:-1], iterable[1:]):
|
||||
yield x, y
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"ConvBlock",
|
||||
"ConvBlockDownCoordF",
|
||||
"ConvBlockDownStandard",
|
||||
"ConvBlockUpF",
|
||||
"ConvBlockUpStandard",
|
||||
"SelfAttention",
|
||||
"BlockGroupConfig",
|
||||
"VerticalConv",
|
||||
"DownscalingLayer",
|
||||
"UpscalingLayer",
|
||||
"FreqCoordConvDownBlock",
|
||||
"StandardConvDownBlock",
|
||||
"FreqCoordConvUpBlock",
|
||||
"StandardConvUpBlock",
|
||||
"SelfAttention",
|
||||
"ConvConfig",
|
||||
"FreqCoordConvDownConfig",
|
||||
"StandardConvDownConfig",
|
||||
"FreqCoordConvUpConfig",
|
||||
"StandardConvUpConfig",
|
||||
"LayerConfig",
|
||||
"build_layer_from_config",
|
||||
]
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
"""Self-Attention module.
|
||||
"""Self-Attention mechanism operating along the time dimension.
|
||||
|
||||
This module implements self-attention mechanism.
|
||||
This module implements a scaled dot-product self-attention mechanism,
|
||||
specifically designed here to operate across the time steps of an input
|
||||
feature map, typically after spatial dimensions (like frequency) have been
|
||||
condensed or squeezed.
|
||||
|
||||
By calculating attention weights between all pairs of time steps, it allows
|
||||
the model to capture long-range temporal dependencies and focus on relevant
|
||||
parts of the sequence. It's often employed in the bottleneck or
|
||||
intermediate layers of an encoder-decoder architecture to integrate global
|
||||
temporal context.
|
||||
|
||||
The implementation uses linear projections to create query, key, and value
|
||||
representations, computes scaled dot-product attention scores, applies
|
||||
softmax, and produces an output by weighting the values according to the
|
||||
attention scores, followed by a final linear projection. Positional encoding
|
||||
is not explicitly included in this block.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels (features per time step after spatial squeeze).
|
||||
attention_channels : int
|
||||
Number of channels for the query, key, and value projections. Also the
|
||||
dimension of the output projection's input.
|
||||
temperature : float, default=1.0
|
||||
Scaling factor applied *before* the final projection layer. Can be used
|
||||
to adjust the sharpness or focus of the attention mechanism, although
|
||||
scaling within the softmax (dividing by sqrt(dim)) is more common for
|
||||
standard transformers. Here it scales the weighted values.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
key_fun : nn.Linear
|
||||
Linear layer for key projection.
|
||||
value_fun : nn.Linear
|
||||
Linear layer for value projection.
|
||||
query_fun : nn.Linear
|
||||
Linear layer for query projection.
|
||||
pro_fun : nn.Linear
|
||||
Final linear projection layer applied after attention weighting.
|
||||
temperature : float
|
||||
Scaling factor applied before final projection.
|
||||
att_dim : int
|
||||
Dimensionality of the attention space (`attention_channels`).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -56,6 +121,27 @@ class SelfAttention(nn.Module):
|
||||
self.pro_fun = nn.Linear(attention_channels, in_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply self-attention along the time dimension.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, expected shape `(B, C, H, W)`, where H is typically
|
||||
squeezed (e.g., H=1 after a `VerticalConv` or pooling) before
|
||||
applying attention along the W (time) dimension.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor of the same shape as the input `(B, C, H, W)`, where
|
||||
attention has been applied across the W dimension.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
If input tensor dimensions are incompatible with operations.
|
||||
"""
|
||||
|
||||
x = x.squeeze(2).permute(0, 2, 1)
|
||||
|
||||
key = torch.matmul(
|
||||
@@ -82,7 +168,42 @@ class SelfAttention(nn.Module):
|
||||
return op
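A minimal shape check, assuming the constructor signature `SelfAttention(in_channels, attention_channels, temperature=1.0)` documented above:

import torch

from batdetect2.models.blocks import SelfAttention

att = SelfAttention(in_channels=256, attention_channels=256)

# Height is squeezed to 1 (e.g. after a VerticalConv) and attention is
# applied across the 100 time steps; the output keeps the input shape.
x = torch.rand(2, 256, 1, 100)
assert att(x).shape == x.shape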
|
||||
|
||||
|
||||
class ConvConfig(BaseConfig):
|
||||
"""Configuration for a basic ConvBlock."""
|
||||
|
||||
block_type: Literal["ConvBlock"] = "ConvBlock"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
"""Number of output channels."""
|
||||
|
||||
kernel_size: int = 3
|
||||
"""Size of the square convolutional kernel."""
|
||||
|
||||
pad_size: int = 1
|
||||
"""Padding size."""
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
"""Basic Convolutional Block.
|
||||
|
||||
A standard building block consisting of a 2D convolution, followed by
|
||||
batch normalization and a ReLU activation function.
|
||||
|
||||
Sequence: Conv2d -> BatchNorm2d -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input tensor.
|
||||
out_channels : int
|
||||
Number of channels produced by the convolution.
|
||||
kernel_size : int, default=3
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
Amount of padding added to preserve spatial dimensions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -100,27 +221,41 @@ class ConvBlock(nn.Module):
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply Conv -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H, W)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H, W)`.
|
||||
"""
|
||||
return F.relu_(self.conv_bn(self.conv(x)))
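With the defaults `kernel_size=3`, `pad_size=1` the spatial size is preserved; a quick sketch:

import torch

from batdetect2.models.blocks import ConvBlock

block = ConvBlock(in_channels=1, out_channels=32)
x = torch.rand(4, 1, 128, 256)
assert block(x).shape == (4, 32, 128, 256)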
|
||||
|
||||
|
||||
class VerticalConv(nn.Module):
|
||||
"""Convolutional layer over full height.
|
||||
"""Convolutional layer that aggregates features across the entire height.
|
||||
|
||||
This layer applies a convolution that captures information across the
|
||||
entire height of the input image. It uses a kernel with the same height as
|
||||
the input, effectively condensing the vertical information into a single
|
||||
output row.
|
||||
Applies a 2D convolution using a kernel with shape `(input_height, 1)`.
|
||||
This collapses the height dimension (H) to 1 while preserving the width (W),
|
||||
effectively summarizing features across the full vertical extent (e.g.,
|
||||
frequency axis) at each time step. Followed by BatchNorm and ReLU.
|
||||
|
||||
More specifically:
|
||||
Useful for summarizing frequency information before applying operations
|
||||
along the time axis (like SelfAttention).
|
||||
|
||||
* **Input:** (B, C, H, W) where B is the batch size, C is the number of
|
||||
input channels, H is the image height, and W is the image width.
|
||||
* **Kernel:** (C', H, 1) where C' is the number of output channels.
|
||||
* **Output:** (B, C', 1, W) - The height dimension is 1 because the
|
||||
convolution integrates information from all rows of the input.
|
||||
|
||||
This process effectively extracts features that span the full height of
|
||||
the input image.
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input tensor.
|
||||
out_channels : int
|
||||
Number of channels produced by the convolution.
|
||||
input_height : int
|
||||
The height (H dimension) of the input tensor. The convolutional kernel
|
||||
will be sized `(input_height, 1)`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -139,14 +274,66 @@ class VerticalConv(nn.Module):
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply Vertical Conv -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H, W)`, where H must match the
|
||||
`input_height` provided during initialization.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, 1, W)`.
|
||||
"""
|
||||
return F.relu_(self.bn(self.conv(x)))
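A small sketch of the height-collapsing behaviour described above; channel counts and sizes are arbitrary:

import torch

from batdetect2.models.blocks import VerticalConv

vconv = VerticalConv(in_channels=32, out_channels=64, input_height=16)
x = torch.rand(2, 32, 16, 100)
assert vconv(x).shape == (2, 64, 1, 100)  # height collapsed to 1, width kept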
|
||||
|
||||
|
||||
class ConvBlockDownCoordF(nn.Module):
|
||||
"""Convolutional Block with Downsampling and Coord Feature.
|
||||
class FreqCoordConvDownConfig(BaseConfig):
|
||||
"""Configuration for a FreqCoordConvDownBlock."""
|
||||
|
||||
This block performs convolution followed by downsampling
|
||||
and concatenates with coordinate information.
|
||||
block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
"""Number of output channels."""
|
||||
|
||||
kernel_size: int = 3
|
||||
"""Size of the square convolutional kernel."""
|
||||
|
||||
pad_size: int = 1
|
||||
"""Padding size."""
|
||||
|
||||
|
||||
class FreqCoordConvDownBlock(nn.Module):
|
||||
"""Downsampling Conv Block incorporating Frequency Coordinate features.
|
||||
|
||||
This block implements a downsampling step (Conv2d + MaxPool2d) commonly
|
||||
used in CNN encoders. Before the convolution, it concatenates an extra
|
||||
channel representing the normalized vertical coordinate (frequency) to the
|
||||
input tensor.
|
||||
|
||||
The purpose of adding coordinate features is to potentially help the
|
||||
convolutional filters become spatially aware, allowing them to learn
|
||||
patterns that might depend on the relative frequency position within the
|
||||
spectrogram.
|
||||
|
||||
Sequence: Concat Coords -> Conv -> MaxPool -> BatchNorm -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input tensor.
|
||||
out_channels : int
|
||||
Number of output channels after the convolution.
|
||||
input_height : int
|
||||
Height (H dimension, frequency bins) of the input tensor to this block.
|
||||
Used to generate the coordinate features.
|
||||
kernel_size : int, default=3
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
Padding added before convolution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -156,7 +343,6 @@ class ConvBlockDownCoordF(nn.Module):
|
||||
input_height: int,
|
||||
kernel_size: int = 3,
|
||||
pad_size: int = 1,
|
||||
stride: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -169,11 +355,24 @@ class ConvBlockDownCoordF(nn.Module):
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=pad_size,
|
||||
stride=stride,
|
||||
stride=1,
|
||||
)
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply CoordF -> Conv -> MaxPool -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H, W)`, where H must match
|
||||
`input_height`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H/2, W/2)` (due to MaxPool).
|
||||
"""
|
||||
freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
|
||||
x = torch.cat((x, freq_info), 1)
|
||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||
@@ -181,10 +380,40 @@ class ConvBlockDownCoordF(nn.Module):
|
||||
return x
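A sketch of the coordinate-feature idea described above: a normalized frequency ramp is concatenated as an extra input channel before the convolution, and the 2x2 max pool then halves both spatial dimensions. Channel counts and height are arbitrary:

import torch

from batdetect2.models.blocks import FreqCoordConvDownBlock

down = FreqCoordConvDownBlock(in_channels=32, out_channels=64, input_height=64)
x = torch.rand(1, 32, 64, 200)
assert down(x).shape == (1, 64, 32, 100)  # H and W halved by the max pool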
|
||||
|
||||
|
||||
class ConvBlockDownStandard(nn.Module):
|
||||
"""Convolutional Block with Downsampling.
|
||||
class StandardConvDownConfig(BaseConfig):
|
||||
"""Configuration for a StandardConvDownBlock."""
|
||||
|
||||
This block performs convolution followed by downsampling.
|
||||
block_type: Literal["StandardConvDown"] = "StandardConvDown"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
"""Number of output channels."""
|
||||
|
||||
kernel_size: int = 3
|
||||
"""Size of the square convolutional kernel."""
|
||||
|
||||
pad_size: int = 1
|
||||
"""Padding size."""
|
||||
|
||||
|
||||
class StandardConvDownBlock(nn.Module):
|
||||
"""Standard Downsampling Convolutional Block.
|
||||
|
||||
A basic downsampling block consisting of a 2D convolution, followed by
|
||||
2x2 max pooling, batch normalization, and ReLU activation.
|
||||
|
||||
Sequence: Conv -> MaxPool -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input tensor.
|
||||
out_channels : int
|
||||
Number of output channels after the convolution.
|
||||
kernel_size : int, default=3
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
Padding added before convolution.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -193,31 +422,83 @@ class ConvBlockDownStandard(nn.Module):
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
pad_size: int = 1,
|
||||
stride: int = 1,
|
||||
):
|
||||
super(ConvBlockDownStandard, self).__init__()
|
||||
super(StandardConvDownBlock, self).__init__()
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
padding=pad_size,
|
||||
stride=stride,
|
||||
stride=1,
|
||||
)
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
"""Apply Conv -> MaxPool -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H, W)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H/2, W/2)`.
|
||||
"""
|
||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||
return F.relu(self.conv_bn(x), inplace=True)
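The plain counterpart without coordinate features has the same shape behaviour; a brief sketch under the same assumptions:

import torch

from batdetect2.models.blocks import StandardConvDownBlock

down = StandardConvDownBlock(in_channels=32, out_channels=64)
x = torch.rand(1, 32, 64, 200)
assert down(x).shape == (1, 64, 32, 100)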
|
||||
|
||||
|
||||
DownscalingLayer = Literal["ConvBlockDownStandard", "ConvBlockDownCoordF"]
|
||||
class FreqCoordConvUpConfig(BaseConfig):
|
||||
"""Configuration for a FreqCoordConvUpBlock."""
|
||||
|
||||
block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
"""Number of output channels."""
|
||||
|
||||
kernel_size: int = 3
|
||||
"""Size of the square convolutional kernel."""
|
||||
|
||||
pad_size: int = 1
|
||||
"""Padding size."""
|
||||
|
||||
|
||||
class ConvBlockUpF(nn.Module):
|
||||
"""Convolutional Block with Upsampling and Coord Feature.
|
||||
class FreqCoordConvUpBlock(nn.Module):
|
||||
"""Upsampling Conv Block incorporating Frequency Coordinate features.
|
||||
|
||||
This block performs convolution followed by upsampling
|
||||
and concatenates with coordinate information.
|
||||
This block implements an upsampling step followed by a convolution,
|
||||
commonly used in CNN decoders. Before the convolution, it concatenates an
|
||||
extra channel representing the normalized vertical coordinate (frequency)
|
||||
of the *upsampled* feature map.
|
||||
|
||||
The goal is to provide spatial awareness (frequency position) to the
|
||||
filters during the decoding/upsampling process.
|
||||
|
||||
Sequence: Interpolate -> Concat Coords -> Conv -> BatchNorm -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input tensor (before upsampling).
|
||||
out_channels : int
|
||||
Number of output channels after the convolution.
|
||||
input_height : int
|
||||
Height (H dimension, frequency bins) of the tensor *before* upsampling.
|
||||
Used to calculate the height for coordinate feature generation after
|
||||
upsampling.
|
||||
kernel_size : int, default=3
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
Padding added before convolution.
|
||||
up_mode : str, default="bilinear"
|
||||
Interpolation mode for upsampling (e.g., "nearest", "bilinear",
|
||||
"bicubic").
|
||||
up_scale : Tuple[int, int], default=(2, 2)
|
||||
Scaling factor for height and width during upsampling
|
||||
(typically (2, 2)).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -249,6 +530,19 @@ class ConvBlockUpF(nn.Module):
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply Interpolate -> Concat Coords -> Conv -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H_in, W_in)`, where H_in should match
|
||||
`input_height` used during initialization.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H_in * scale_h, W_in * scale_w)`.
|
||||
"""
|
||||
op = F.interpolate(
|
||||
x,
|
||||
size=(
|
||||
@@ -265,10 +559,45 @@ class ConvBlockUpF(nn.Module):
|
||||
return op
|
||||
|
||||
|
||||
class ConvBlockUpStandard(nn.Module):
|
||||
"""Convolutional Block with Upsampling.
|
||||
class StandardConvUpConfig(BaseConfig):
|
||||
"""Configuration for a StandardConvUpBlock."""
|
||||
|
||||
This block performs convolution followed by upsampling.
|
||||
block_type: Literal["StandardConvUp"] = "StandardConvUp"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
"""Number of output channels."""
|
||||
|
||||
kernel_size: int = 3
|
||||
"""Size of the square convolutional kernel."""
|
||||
|
||||
pad_size: int = 1
|
||||
"""Padding size."""
|
||||
|
||||
|
||||
class StandardConvUpBlock(nn.Module):
|
||||
"""Standard Upsampling Convolutional Block.
|
||||
|
||||
A basic upsampling block used in CNN decoders. It first upsamples the input
|
||||
feature map using interpolation, then applies a 2D convolution, batch
|
||||
normalization, and ReLU activation. Does not use coordinate features.
|
||||
|
||||
Sequence: Interpolate -> Conv -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input tensor (before upsampling).
|
||||
out_channels : int
|
||||
Number of output channels after the convolution.
|
||||
kernel_size : int, default=3
|
||||
Size of the square convolutional kernel.
|
||||
pad_size : int, default=1
|
||||
Padding added before convolution.
|
||||
up_mode : str, default="bilinear"
|
||||
Interpolation mode for upsampling (e.g., "nearest", "bilinear").
|
||||
up_scale : Tuple[int, int], default=(2, 2)
|
||||
Scaling factor for height and width during upsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -280,7 +609,7 @@ class ConvBlockUpStandard(nn.Module):
|
||||
up_mode: str = "bilinear",
|
||||
up_scale: Tuple[int, int] = (2, 2),
|
||||
):
|
||||
super(ConvBlockUpStandard, self).__init__()
|
||||
super(StandardConvUpBlock, self).__init__()
|
||||
self.up_scale = up_scale
|
||||
self.up_mode = up_mode
|
||||
self.conv = nn.Conv2d(
|
||||
@@ -292,6 +621,18 @@ class ConvBlockUpStandard(nn.Module):
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply Interpolate -> Conv -> BN -> ReLU.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H_in, W_in)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H_in * scale_h, W_in * scale_w)`.
|
||||
"""
|
||||
op = F.interpolate(
|
||||
x,
|
||||
size=(
|
||||
@@ -306,141 +647,142 @@ class ConvBlockUpStandard(nn.Module):
|
||||
return op
|
||||
|
||||
|
||||
UpscalingLayer = Literal["ConvBlockUpStandard", "ConvBlockUpF"]
|
||||
LayerConfig = Annotated[
|
||||
Union[
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
StandardConvDownConfig,
|
||||
FreqCoordConvUpConfig,
|
||||
StandardConvUpConfig,
|
||||
"BlockGroupConfig",
|
||||
],
|
||||
Field(discriminator="block_type"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configuration models."""
|
||||
|
||||
|
||||
def build_downscaling_layer(
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
class BlockGroupConfig(BaseConfig):
|
||||
block_type: Literal["group"] = "group"
|
||||
blocks: List[LayerConfig]
|
||||
|
||||
|
||||
def build_layer_from_config(
|
||||
input_height: int,
|
||||
layer_type: DownscalingLayer,
|
||||
) -> nn.Module:
|
||||
if layer_type == "ConvBlockDownStandard":
|
||||
return ConvBlockDownStandard(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
|
||||
if layer_type == "ConvBlockDownCoordF":
|
||||
return ConvBlockDownCoordF(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height,
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid downscaling layer type {layer_type}. "
|
||||
f"Valid values: ConvBlockDownCoordF, ConvBlockDownStandard"
|
||||
)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: Sequence[int] = (1, 32, 62, 128),
|
||||
input_height: int = 128,
|
||||
layer_type: Literal[
|
||||
"ConvBlockDownStandard", "ConvBlockDownCoordF"
|
||||
] = "ConvBlockDownStandard",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.channels = channels
|
||||
self.input_height = input_height
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
build_downscaling_layer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height // (2**layer_num),
|
||||
layer_type=layer_type,
|
||||
)
|
||||
for layer_num, (in_channels, out_channels) in enumerate(
|
||||
pairwise(channels)
|
||||
)
|
||||
]
|
||||
)
|
||||
self.depth = len(self.layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
outputs = []
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
outputs.append(x)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def build_upscaling_layer(
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
input_height: int,
|
||||
layer_type: UpscalingLayer,
|
||||
) -> nn.Module:
|
||||
if layer_type == "ConvBlockUpStandard":
|
||||
return ConvBlockUpStandard(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
config: LayerConfig,
|
||||
) -> Tuple[nn.Module, int, int]:
|
||||
"""Factory function to build a specific nn.Module block from its config.
|
||||
|
||||
Takes configuration object (one of the types included in the `LayerConfig`
|
||||
union) and instantiates the corresponding nn.Module block with the correct
|
||||
parameters derived from the config and the current pipeline state
|
||||
(`input_height`, `in_channels`).
|
||||
|
||||
It uses the `block_type` field within the `config` object to determine
|
||||
which block class to instantiate.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_height : int
|
||||
Height (frequency bins) of the input tensor *to this layer*.
|
||||
in_channels : int
|
||||
Number of channels in the input tensor *to this layer*.
|
||||
config : LayerConfig
|
||||
A Pydantic configuration object for the desired block (e.g., an
|
||||
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
|
||||
by its `block_type` field.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[nn.Module, int, int]
|
||||
A tuple containing:
|
||||
- The instantiated `nn.Module` block.
|
||||
- The number of output channels produced by the block.
|
||||
- The calculated height of the output produced by the block.
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError
|
||||
If the `config.block_type` does not correspond to a known block type.
|
||||
ValueError
|
||||
If parameters derived from the config are invalid for the block.
|
||||
"""
|
||||
if config.block_type == "ConvBlock":
|
||||
return (
|
||||
ConvBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=config.out_channels,
|
||||
kernel_size=config.kernel_size,
|
||||
pad_size=config.pad_size,
|
||||
),
|
||||
config.out_channels,
|
||||
input_height,
|
||||
)
|
||||
|
||||
if layer_type == "ConvBlockUpF":
|
||||
return ConvBlockUpF(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height,
|
||||
if config.block_type == "FreqCoordConvDown":
|
||||
return (
|
||||
FreqCoordConvDownBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=config.out_channels,
|
||||
input_height=input_height,
|
||||
kernel_size=config.kernel_size,
|
||||
pad_size=config.pad_size,
|
||||
),
|
||||
config.out_channels,
|
||||
input_height // 2,
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid upscaling layer type {layer_type}. "
|
||||
f"Valid values: ConvBlockUpStandard, ConvBlockUpF"
|
||||
)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: Sequence[int] = (256, 62, 32, 32),
|
||||
input_height: int = 128,
|
||||
layer_type: Literal[
|
||||
"ConvBlockUpStandard", "ConvBlockUpF"
|
||||
] = "ConvBlockUpStandard",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.channels = channels
|
||||
self.input_height = input_height
|
||||
self.depth = len(self.channels) - 1
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
build_upscaling_layer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height
|
||||
// (2 ** (self.depth - layer_num)),
|
||||
layer_type=layer_type,
|
||||
)
|
||||
for layer_num, (in_channels, out_channels) in enumerate(
|
||||
pairwise(channels)
|
||||
)
|
||||
]
|
||||
if config.block_type == "StandardConvDown":
|
||||
return (
|
||||
StandardConvDownBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=config.out_channels,
|
||||
kernel_size=config.kernel_size,
|
||||
pad_size=config.pad_size,
|
||||
),
|
||||
config.out_channels,
|
||||
input_height // 2,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residuals: List[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
if len(residuals) != len(self.layers):
|
||||
raise ValueError(
|
||||
f"Incorrect number of residuals provided. "
|
||||
f"Expected {len(self.layers)} (matching the number of layers), "
|
||||
f"but got {len(residuals)}."
|
||||
if config.block_type == "FreqCoordConvUp":
|
||||
return (
|
||||
FreqCoordConvUpBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=config.out_channels,
|
||||
input_height=input_height,
|
||||
kernel_size=config.kernel_size,
|
||||
pad_size=config.pad_size,
|
||||
),
|
||||
config.out_channels,
|
||||
input_height * 2,
|
||||
)
|
||||
|
||||
if config.block_type == "StandardConvUp":
|
||||
return (
|
||||
StandardConvUpBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=config.out_channels,
|
||||
kernel_size=config.kernel_size,
|
||||
pad_size=config.pad_size,
|
||||
),
|
||||
config.out_channels,
|
||||
input_height * 2,
|
||||
)
|
||||
|
||||
if config.block_type == "group":
|
||||
current_channels = in_channels
|
||||
current_height = input_height
|
||||
|
||||
blocks = []
|
||||
|
||||
for block_config in config.blocks:
|
||||
block, current_channels, current_height = build_layer_from_config(
|
||||
input_height=current_height,
|
||||
in_channels=current_channels,
|
||||
config=block_config,
|
||||
)
|
||||
blocks.append(block)
|
||||
|
||||
for layer, res in zip(self.layers, residuals[::-1]):
|
||||
x = layer(x + res)
|
||||
return nn.Sequential(*blocks), current_channels, current_height
|
||||
|
||||
return x
|
||||
raise NotImplementedError(f"Unknown block type {config.block_type}")
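As a minimal usage sketch, the factory can also be driven directly; the encoder and decoder builders below do essentially this in a loop:

from batdetect2.models.blocks import (
    FreqCoordConvDownConfig,
    build_layer_from_config,
)

block, out_channels, out_height = build_layer_from_config(
    input_height=128,
    in_channels=1,
    config=FreqCoordConvDownConfig(out_channels=32),
)
# out_channels is 32 and out_height is 64, since this block halves the
# frequency dimension; both values are fed into the next call in a pipeline.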
|
||||
|
batdetect2/models/bottleneck.py (new file, 254 lines)
@@ -0,0 +1,254 @@
|
||||
"""Defines the Bottleneck component of an Encoder-Decoder architecture.
|
||||
|
||||
This module provides the configuration (`BottleneckConfig`) and
|
||||
`torch.nn.Module` implementations (`Bottleneck`, `BottleneckAttn`) for the
|
||||
bottleneck layer(s) that typically connect the Encoder (downsampling path) and
|
||||
Decoder (upsampling path) in networks like U-Nets.
|
||||
|
||||
The bottleneck processes the lowest-resolution, highest-dimensionality feature
|
||||
map produced by the Encoder. This module offers a configurable option to include
|
||||
a `SelfAttention` layer within the bottleneck, allowing the model to capture
|
||||
global temporal context before features are passed to the Decoder.
|
||||
|
||||
A factory function `build_bottleneck` constructs the appropriate bottleneck
|
||||
module based on the provided configuration.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.blocks import SelfAttention, VerticalConv
|
||||
|
||||
__all__ = [
|
||||
"BottleneckConfig",
|
||||
"Bottleneck",
|
||||
"BottleneckAttn",
|
||||
"build_bottleneck",
|
||||
]
|
||||
|
||||
|
||||
class BottleneckConfig(BaseConfig):
|
||||
"""Configuration for the bottleneck layer(s).
|
||||
|
||||
Defines the number of channels within the bottleneck and whether to include
|
||||
a self-attention mechanism.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
channels : int
|
||||
The number of output channels produced by the main convolutional layer
|
||||
within the bottleneck. This often matches the number of channels coming
|
||||
from the last encoder stage, but can be different. Must be positive.
|
||||
This also defines the channel dimensions used within the optional
|
||||
`SelfAttention` layer.
|
||||
self_attention : bool
|
||||
If True, includes a `SelfAttention` layer operating on the time
|
||||
dimension after an initial `VerticalConv` layer within the bottleneck.
|
||||
If False, only the initial `VerticalConv` (and height repetition) is
|
||||
performed.
|
||||
"""
|
||||
|
||||
channels: int
|
||||
self_attention: bool
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
"""Base Bottleneck module for Encoder-Decoder architectures.
|
||||
|
||||
This implementation represents the simplest bottleneck structure
|
||||
considered, primarily consisting of a `VerticalConv` layer. This layer
|
||||
collapses the frequency dimension (height) to 1, summarizing information
|
||||
across frequencies at each time step. The output is then repeated along the
|
||||
height dimension to match the original bottleneck input height before being
|
||||
passed to the decoder.
|
||||
|
||||
This base version does *not* include self-attention.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_height : int
|
||||
Height (frequency bins) of the input tensor. Must be positive.
|
||||
in_channels : int
|
||||
Number of channels in the input tensor from the encoder. Must be
|
||||
positive.
|
||||
out_channels : int
|
||||
Number of output channels. Must be positive.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels accepted by the bottleneck.
|
||||
input_height : int
|
||||
Expected height of the input tensor.
|
||||
channels : int
|
||||
Number of output channels.
|
||||
conv_vert : VerticalConv
|
||||
The vertical convolution layer.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `input_height`, `in_channels`, or `out_channels` are not positive.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_height: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
) -> None:
|
||||
"""Initialize the base Bottleneck layer."""
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.input_height = input_height
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.conv_vert = VerticalConv(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
input_height=input_height,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Process input features through the bottleneck.
|
||||
|
||||
Applies vertical convolution and repeats the output height.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor from the encoder bottleneck, shape
|
||||
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
|
||||
`H_in` must match `self.input_height`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H_in, W)`. Note that the height
|
||||
dimension `H_in` is restored via repetition after the vertical
|
||||
convolution.
|
||||
"""
|
||||
x = self.conv_vert(x)
|
||||
return x.repeat([1, 1, self.input_height, 1])
|
||||
|
||||
|
||||
class BottleneckAttn(Bottleneck):
|
||||
"""Bottleneck module including a Self-Attention layer.
|
||||
|
||||
Extends the base `Bottleneck` by inserting a `SelfAttention` layer after
|
||||
the initial `VerticalConv`. This allows the bottleneck to capture global
|
||||
temporal dependencies in the summarized frequency features before passing
|
||||
them to the decoder.
|
||||
|
||||
Sequence: VerticalConv -> SelfAttention -> Repeat Height.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_height : int
|
||||
Height (frequency bins) of the input tensor from the encoder.
|
||||
in_channels : int
|
||||
Number of channels in the input tensor from the encoder.
|
||||
out_channels : int
|
||||
Number of output channels produced by the `VerticalConv` and
|
||||
subsequently processed and output by this bottleneck. Also determines
|
||||
the input/output channels of the internal `SelfAttention` layer.
|
||||
attention : nn.Module
|
||||
An initialized `SelfAttention` module instance.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `input_height`, `in_channels`, or `out_channels` are not positive.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_height: int,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
attention: nn.Module,
|
||||
) -> None:
|
||||
"""Initialize the Bottleneck with Self-Attention."""
|
||||
super().__init__(input_height, in_channels, out_channels)
|
||||
self.attention = attention
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Process input tensor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor from the encoder bottleneck, shape
|
||||
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
|
||||
`H_in` must match `self.input_height`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H_in, W)`, after applying attention
|
||||
and repeating the height dimension.
|
||||
"""
|
||||
x = self.conv_vert(x)
|
||||
x = self.attention(x)
|
||||
return x.repeat([1, 1, self.input_height, 1])
|
||||
|
||||
|
||||
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
|
||||
channels=256,
|
||||
self_attention=True,
|
||||
)
|
||||
|
||||
|
||||
def build_bottleneck(
|
||||
input_height: int,
|
||||
in_channels: int,
|
||||
config: Optional[BottleneckConfig] = None,
|
||||
) -> nn.Module:
|
||||
"""Factory function to build the Bottleneck module from configuration.
|
||||
|
||||
Constructs either a base `Bottleneck` or a `BottleneckAttn` instance based
|
||||
on the `config.self_attention` flag.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_height : int
|
||||
Height (frequency bins) of the input tensor. Must be positive.
|
||||
in_channels : int
|
||||
Number of channels in the input tensor. Must be positive.
|
||||
config : BottleneckConfig, optional
|
||||
Configuration object specifying the bottleneck channels and whether
|
||||
to use self-attention. Uses `DEFAULT_BOTTLENECK_CONFIG` if None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
nn.Module
|
||||
An initialized bottleneck module (`Bottleneck` or `BottleneckAttn`).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `input_height` or `in_channels` are not positive.
|
||||
"""
|
||||
config = config or DEFAULT_BOTTLENECK_CONFIG
|
||||
|
||||
if config.self_attention:
|
||||
attention = SelfAttention(
|
||||
in_channels=config.channels,
|
||||
attention_channels=config.channels,
|
||||
)
|
||||
|
||||
return BottleneckAttn(
|
||||
input_height=input_height,
|
||||
in_channels=in_channels,
|
||||
out_channels=config.channels,
|
||||
attention=attention,
|
||||
)
|
||||
|
||||
return Bottleneck(
|
||||
input_height=input_height,
|
||||
in_channels=in_channels,
|
||||
out_channels=config.channels,
|
||||
)
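A minimal usage sketch; the shapes follow the forward docstrings above, and the feature map sizes here are only examples:

import torch

from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck

bottleneck = build_bottleneck(
    input_height=16,
    in_channels=256,
    config=BottleneckConfig(channels=256, self_attention=True),
)

features = torch.randn(1, 256, 16, 64)  # (B, C_in, H_in, W) from the encoder
out = bottleneck(features)
# out has shape (1, 256, 16, 64): frequencies are collapsed, attended over
# time, then repeated back to the original height.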
|
batdetect2/models/decoder.py
@@ -1,15 +1,277 @@
|
||||
import sys
|
||||
from typing import Iterable, List, Literal, Sequence
|
||||
"""Constructs the Decoder part of an Encoder-Decoder neural network.
|
||||
|
||||
This module defines the configuration structure (`DecoderConfig`) for the layer
|
||||
sequence and provides the `Decoder` class (an `nn.Module`) along with a factory
|
||||
function (`build_decoder`). Decoders typically form the upsampling path in
|
||||
architectures like U-Nets, taking bottleneck features
|
||||
(usually from an `Encoder`) and skip connections to reconstruct
|
||||
higher-resolution feature maps.
|
||||
|
||||
The decoder is built dynamically by stacking neural network blocks based on a
|
||||
list of configuration objects provided in `DecoderConfig.layers`. Each config
|
||||
object specifies the type of block (e.g., standard convolution,
|
||||
coordinate-feature convolution with upsampling) and its parameters. This allows
|
||||
flexible definition of decoder architectures via configuration files.
|
||||
|
||||
The `Decoder`'s `forward` method is designed to accept skip connection tensors
|
||||
(`residuals`) from the encoder, merging them with the upsampled feature maps
|
||||
at each stage.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.models.blocks import ConvBlockUpF, ConvBlockUpStandard
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.blocks import (
|
||||
BlockGroupConfig,
|
||||
ConvConfig,
|
||||
FreqCoordConvUpConfig,
|
||||
StandardConvUpConfig,
|
||||
build_layer_from_config,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from itertools import pairwise
|
||||
else:
|
||||
__all__ = [
|
||||
"DecoderConfig",
|
||||
"Decoder",
|
||||
"build_decoder",
|
||||
"DEFAULT_DECODER_CONFIG",
|
||||
]
|
||||
|
||||
def pairwise(iterable: Sequence) -> Iterable:
|
||||
for x, y in zip(iterable[:-1], iterable[1:]):
|
||||
yield x, y
|
||||
DecoderLayerConfig = Annotated[
|
||||
Union[
|
||||
ConvConfig,
|
||||
FreqCoordConvUpConfig,
|
||||
StandardConvUpConfig,
|
||||
BlockGroupConfig,
|
||||
],
|
||||
Field(discriminator="block_type"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||
|
||||
|
||||
class DecoderConfig(BaseConfig):
|
||||
"""Configuration for the sequence of layers in the Decoder module.
|
||||
|
||||
Defines the types and parameters of the neural network blocks that
|
||||
constitute the decoder's upsampling path.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
layers : List[DecoderLayerConfig]
|
||||
An ordered list of configuration objects, each defining one layer or
|
||||
block in the decoder sequence. Each item must be a valid block
|
||||
config including a `block_type` field and necessary parameters like
|
||||
`out_channels`. Input channels for each layer are inferred sequentially.
|
||||
The list must contain at least one layer.
|
||||
"""
|
||||
|
||||
layers: List[DecoderLayerConfig] = Field(min_length=1)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""Sequential Decoder module composed of configurable upsampling layers.
|
||||
|
||||
Constructs the upsampling path of an encoder-decoder network by stacking
|
||||
multiple blocks (e.g., `StandardConvUpBlock`, `FreqCoordConvUpBlock`)
|
||||
based on a list of layer modules provided during initialization (typically
|
||||
created by the `build_decoder` factory function).
|
||||
|
||||
The `forward` method is designed to integrate skip connection tensors
|
||||
(`residuals`) from the corresponding encoder stages, by adding them
|
||||
element-wise to the input of each decoder layer before processing.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels expected in the input tensor.
|
||||
out_channels : int
|
||||
Number of channels in the final output tensor produced by the last
|
||||
layer.
|
||||
input_height : int
|
||||
Height (frequency bins) expected in the input tensor.
|
||||
output_height : int
|
||||
Height (frequency bins) expected in the output tensor.
|
||||
layers : nn.ModuleList
|
||||
The sequence of instantiated upscaling layer modules.
|
||||
depth : int
|
||||
The number of upscaling layers (depth) in the decoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
input_height: int,
|
||||
output_height: int,
|
||||
layers: List[nn.Module],
|
||||
):
|
||||
"""Initialize the Decoder module.
|
||||
|
||||
Note: This constructor is typically called internally by the
|
||||
`build_decoder` factory function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
out_channels : int
|
||||
Number of channels produced by the final layer.
|
||||
input_height : int
|
||||
Expected height of the input tensor (bottleneck).
|
||||
in_channels : int
|
||||
Expected number of channels in the input tensor (bottleneck).
|
||||
layers : List[nn.Module]
|
||||
A list of pre-instantiated upscaling layer modules (e.g.,
|
||||
`StandardConvUpBlock` or `FreqCoordConvUpBlock`) in the desired
|
||||
sequence (from bottleneck towards output resolution).
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.input_height = input_height
|
||||
self.output_height = output_height
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.layers = nn.ModuleList(layers)
|
||||
self.depth = len(self.layers)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residuals: List[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Pass input through decoder layers, incorporating skip connections.
|
||||
|
||||
Processes the input tensor `x` sequentially through the upscaling
|
||||
layers. At each stage, the corresponding skip connection tensor from
|
||||
the `residuals` list is added element-wise to the input before passing
|
||||
it to the upscaling block.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor from the previous stage (e.g., encoder bottleneck).
|
||||
Shape `(B, C_in, H_in, W_in)`, where `C_in` matches
|
||||
`self.in_channels`.
|
||||
residuals : List[torch.Tensor]
|
||||
List containing the skip connection tensors from the corresponding
|
||||
encoder stages. Should be ordered from the shallowest encoder stage
output (highest resolution, near the input) to the deepest (lowest
resolution, bottleneck) - i.e. the list returned by `Encoder.forward`,
which this method reverses internally. The number of tensors in this
list must match the
|
||||
number of decoder layers (`self.depth`). Each residual tensor's
|
||||
channel count must be compatible with the input tensor `x` for
|
||||
element-wise addition (or concatenation if the blocks were designed
|
||||
for it).
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The final decoded feature map tensor produced by the last layer.
|
||||
Shape `(B, C_out, H_out, W_out)`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the number of `residuals` provided does not match the decoder
|
||||
depth.
|
||||
RuntimeError
|
||||
If shapes mismatch during skip connection addition or layer
|
||||
processing.
|
||||
"""
|
||||
if len(residuals) != len(self.layers):
|
||||
raise ValueError(
|
||||
f"Incorrect number of residuals provided. "
|
||||
f"Expected {len(self.layers)} (matching the number of layers), "
|
||||
f"but got {len(residuals)}."
|
||||
)
|
||||
|
||||
for layer, res in zip(self.layers, residuals[::-1]):
|
||||
x = layer(x + res)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
DEFAULT_DECODER_CONFIG: DecoderConfig = DecoderConfig(
|
||||
layers=[
|
||||
FreqCoordConvUpConfig(out_channels=64),
|
||||
FreqCoordConvUpConfig(out_channels=32),
|
||||
BlockGroupConfig(
|
||||
blocks=[
|
||||
FreqCoordConvUpConfig(out_channels=32),
|
||||
ConvConfig(out_channels=32),
|
||||
]
|
||||
),
|
||||
],
|
||||
)
|
||||
"""A default configuration for the Decoder's *layer sequence*.
|
||||
|
||||
Specifies an architecture often used in BatDetect2, consisting of three
|
||||
frequency coordinate-aware upsampling blocks followed by a standard
|
||||
convolutional block.
|
||||
"""
|
||||
|
||||
|
||||
def build_decoder(
|
||||
in_channels: int,
|
||||
input_height: int,
|
||||
config: Optional[DecoderConfig] = None,
|
||||
) -> Decoder:
|
||||
"""Factory function to build a Decoder instance from configuration.
|
||||
|
||||
Constructs a sequential `Decoder` module based on the layer sequence
|
||||
defined in a `DecoderConfig` object and the provided input dimensions
|
||||
(bottleneck channels and height). If no config is provided, uses the
|
||||
default layer sequence from `DEFAULT_DECODER_CONFIG`.
|
||||
|
||||
It iteratively builds the layers using the unified `build_layer_from_config`
|
||||
factory (from `.blocks`), tracking the changing number of channels and
|
||||
feature map height required for each subsequent layer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
The number of channels in the input tensor to the decoder. Must be > 0.
|
||||
input_height : int
|
||||
The height (frequency bins) of the input tensor to the decoder. Must be
|
||||
> 0.
|
||||
config : DecoderConfig, optional
|
||||
The configuration object detailing the sequence of layers and their
|
||||
parameters. If None, `DEFAULT_DECODER_CONFIG` is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Decoder
|
||||
An initialized `Decoder` module.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `in_channels` or `input_height` are not positive, or if the layer
|
||||
configuration is invalid (e.g., empty list, unknown `block_type`).
|
||||
NotImplementedError
|
||||
If `build_layer_from_config` encounters an unknown `block_type`.
|
||||
"""
|
||||
config = config or DEFAULT_DECODER_CONFIG
|
||||
|
||||
current_channels = in_channels
|
||||
current_height = input_height
|
||||
|
||||
layers = []
|
||||
|
||||
for layer_config in config.layers:
|
||||
layer, current_channels, current_height = build_layer_from_config(
|
||||
in_channels=current_channels,
|
||||
input_height=current_height,
|
||||
config=layer_config,
|
||||
)
|
||||
layers.append(layer)
|
||||
|
||||
return Decoder(
|
||||
in_channels=in_channels,
|
||||
out_channels=current_channels,
|
||||
input_height=input_height,
|
||||
output_height=current_height,
|
||||
layers=layers,
|
||||
)
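A minimal sketch of the factory in use; the module path `batdetect2.models.decoder` is assumed from the file layout shown here, and the channel/height values follow the default config:

from batdetect2.models.decoder import build_decoder

# Bottleneck features with 256 channels and 16 frequency bins.
decoder = build_decoder(in_channels=256, input_height=16)

print(decoder.out_channels, decoder.output_height)
# With DEFAULT_DECODER_CONFIG this prints "32 128": the height doubles at
# each of the three upsampling stages (16 -> 32 -> 64 -> 128) and the final
# block outputs 32 channels.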
|
||||
|
batdetect2/models/detectors.py (new file, 173 lines)
@@ -0,0 +1,173 @@
|
||||
"""Assembles the complete BatDetect2 Detection Model.
|
||||
|
||||
This module defines the concrete `Detector` class, which implements the
|
||||
`DetectionModel` interface defined in `.types`. It combines a feature
|
||||
extraction backbone with specific prediction heads to create the end-to-end
|
||||
neural network used for detecting bat calls, predicting their size, and
|
||||
classifying them.
|
||||
|
||||
The primary components are:
|
||||
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
|
||||
- `build_detector`: A factory function to conveniently construct a standard
|
||||
`Detector` instance given a backbone and the number of target classes.
|
||||
|
||||
This module focuses purely on the neural network architecture definition. The
|
||||
logic for preprocessing inputs and postprocessing/decoding outputs resides in
|
||||
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
|
||||
|
||||
|
||||
class Detector(DetectionModel):
|
||||
"""Concrete implementation of the BatDetect2 Detection Model.
|
||||
|
||||
Assembles a complete detection and classification model by combining a
|
||||
feature extraction backbone network with specific prediction heads for
|
||||
detection probability, bounding box size regression, and class
|
||||
probabilities.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
backbone : BackboneModel
|
||||
The feature extraction backbone network module.
|
||||
num_classes : int
|
||||
The number of specific target classes the model predicts (derived from
|
||||
the `classifier_head`).
|
||||
classifier_head : ClassifierHead
|
||||
The prediction head responsible for generating class probabilities.
|
||||
detector_head : DetectorHead
|
||||
The prediction head responsible for generating detection probabilities.
|
||||
bbox_head : BBoxHead
|
||||
The prediction head responsible for generating bounding box size
|
||||
predictions.
|
||||
"""
|
||||
|
||||
backbone: BackboneModel
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
backbone: BackboneModel,
|
||||
classifier_head: ClassifierHead,
|
||||
detector_head: DetectorHead,
|
||||
bbox_head: BBoxHead,
|
||||
):
|
||||
"""Initialize the Detector model.
|
||||
|
||||
Note: Instances are typically created using the `build_detector`
|
||||
factory function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
backbone : BackboneModel
|
||||
An initialized feature extraction backbone module (e.g., built by
|
||||
`build_backbone` from the `.backbone` module).
|
||||
classifier_head : ClassifierHead
|
||||
An initialized classification head module. The number of classes
|
||||
is inferred from this head.
|
||||
detector_head : DetectorHead
|
||||
An initialized detection head module.
|
||||
bbox_head : BBoxHead
|
||||
An initialized bounding box size prediction head module.
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError
|
||||
If the provided modules are not of the expected types.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.backbone = backbone
|
||||
self.num_classes = classifier_head.num_classes
|
||||
self.classifier_head = classifier_head
|
||||
self.detector_head = detector_head
|
||||
self.bbox_head = bbox_head
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
"""Perform the forward pass of the complete detection model.
|
||||
|
||||
Processes the input spectrogram through the backbone to extract
|
||||
features, then passes these features through the separate prediction
|
||||
heads to generate detection probabilities, class probabilities, and
|
||||
size predictions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input spectrogram tensor, typically with shape
|
||||
`(batch_size, input_channels, frequency_bins, time_bins)`. The
|
||||
shape must be compatible with the `self.backbone` input
|
||||
requirements.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ModelOutput
|
||||
A NamedTuple containing the four output tensors:
|
||||
- `detection_probs`: Detection probability heatmap `(B, 1, H, W)`.
|
||||
- `size_preds`: Predicted scaled size dimensions `(B, 2, H, W)`.
|
||||
- `class_probs`: Class probabilities (excluding background)
|
||||
`(B, num_classes, H, W)`.
|
||||
- `features`: Output feature map from the backbone
|
||||
`(B, C_out, H, W)`.
|
||||
"""
|
||||
features = self.backbone(spec)
|
||||
detection = self.detector_head(features)
|
||||
classification = self.classifier_head(features)
|
||||
size_preds = self.bbox_head(features)
|
||||
return ModelOutput(
|
||||
detection_probs=detection,
|
||||
size_preds=size_preds,
|
||||
class_probs=classification,
|
||||
features=features,
|
||||
)
|
||||
|
||||
|
||||
def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
|
||||
"""Factory function to build a standard Detector model instance.
|
||||
|
||||
Creates the standard prediction heads (`ClassifierHead`, `DetectorHead`,
|
||||
`BBoxHead`) configured appropriately based on the output channels of the
|
||||
provided `backbone` and the specified `num_classes`. It then assembles
|
||||
these components into a `Detector` model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_classes : int
|
||||
The number of specific target classes for the classification head
|
||||
(excluding any implicit background class). Must be positive.
|
||||
backbone : BackboneModel
|
||||
An initialized feature extraction backbone module instance. The number
|
||||
of output channels from this backbone (`backbone.out_channels`) is used
|
||||
to configure the input channels for the prediction heads.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Detector
|
||||
An initialized `Detector` model instance.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `num_classes` is not positive.
|
||||
AttributeError
|
||||
If `backbone` does not have the required `out_channels` attribute.
|
||||
"""
|
||||
classifier_head = ClassifierHead(
|
||||
num_classes=num_classes,
|
||||
in_channels=backbone.out_channels,
|
||||
)
|
||||
detector_head = DetectorHead(
|
||||
in_channels=backbone.out_channels,
|
||||
)
|
||||
bbox_head = BBoxHead(
|
||||
in_channels=backbone.out_channels,
|
||||
)
|
||||
return Detector(
|
||||
backbone=backbone,
|
||||
classifier_head=classifier_head,
|
||||
detector_head=detector_head,
|
||||
bbox_head=bbox_head,
|
||||
)
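A minimal usage sketch. Here `backbone` stands in for any pre-built `BackboneModel` (for example one produced by the `.backbone` module's builder mentioned above); it is not constructed in this snippet:

import torch

from batdetect2.models.detectors import build_detector

model = build_detector(num_classes=17, backbone=backbone)  # noqa: F821
spec = torch.randn(1, 1, 128, 512)  # (B, C, freq_bins, time_bins)
output = model(spec)
# output is a ModelOutput with detection_probs (B, 1, H, W),
# size_preds (B, 2, H, W), class_probs (B, 17, H, W), and the backbone
# feature map.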
|
batdetect2/models/encoder.py
@@ -1,15 +1,318 @@
|
||||
import sys
|
||||
from typing import Iterable, List, Literal, Sequence
|
||||
"""Constructs the Encoder part of a configurable neural network backbone.
|
||||
|
||||
This module defines the configuration structure (`EncoderConfig`) and provides
|
||||
the `Encoder` class (an `nn.Module`) along with a factory function
|
||||
(`build_encoder`) to create sequential encoders. Encoders typically form the
|
||||
downsampling path in architectures like U-Nets, processing input feature maps
|
||||
(like spectrograms) to produce lower-resolution, higher-dimensionality feature
|
||||
representations (bottleneck features).
|
||||
|
||||
The encoder is built dynamically by stacking neural network blocks based on a
|
||||
list of configuration objects provided in `EncoderConfig.layers`. Each
|
||||
configuration object specifies the type of block (e.g., standard convolution,
|
||||
coordinate-feature convolution with downsampling) and its parameters
|
||||
(e.g., output channels). This allows for flexible definition of encoder
|
||||
architectures via configuration files.
|
||||
|
||||
The `Encoder`'s `forward` method returns outputs from all intermediate layers,
|
||||
suitable for skip connections, while the `encode` method returns only the final
|
||||
bottleneck output. A default configuration (`DEFAULT_ENCODER_CONFIG`) is also
|
||||
provided.
|
||||
"""
|
||||
|
||||
from typing import Annotated, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockDownStandard
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.blocks import (
|
||||
BlockGroupConfig,
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
StandardConvDownConfig,
|
||||
build_layer_from_config,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from itertools import pairwise
|
||||
else:
|
||||
__all__ = [
|
||||
"EncoderConfig",
|
||||
"Encoder",
|
||||
"build_encoder",
|
||||
"DEFAULT_ENCODER_CONFIG",
|
||||
]
|
||||
|
||||
def pairwise(iterable: Sequence) -> Iterable:
|
||||
for x, y in zip(iterable[:-1], iterable[1:]):
|
||||
yield x, y
|
||||
EncoderLayerConfig = Annotated[
|
||||
Union[
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
StandardConvDownConfig,
|
||||
BlockGroupConfig,
|
||||
],
|
||||
Field(discriminator="block_type"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
||||
|
||||
|
||||
class EncoderConfig(BaseConfig):
|
||||
"""Configuration for building the sequential Encoder module.
|
||||
|
||||
Defines the sequence of neural network blocks that constitute the encoder
|
||||
(downsampling path).
|
||||
|
||||
Attributes
|
||||
----------
|
||||
layers : List[EncoderLayerConfig]
|
||||
An ordered list of configuration objects, each defining one layer or
|
||||
block in the encoder sequence. Each item must be a valid block config
|
||||
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
|
||||
`StandardConvDownConfig`) including a `block_type` field and necessary
|
||||
parameters like `out_channels`. Input channels for each layer are
|
||||
inferred sequentially. The list must contain at least one layer.
|
||||
"""
|
||||
|
||||
layers: List[EncoderLayerConfig] = Field(min_length=1)
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Sequential Encoder module composed of configurable downscaling layers.
|
||||
|
||||
Constructs the downsampling path of an encoder-decoder network by stacking
|
||||
multiple downscaling blocks.
|
||||
|
||||
The `forward` method executes the sequence and returns the output feature
|
||||
map from *each* downscaling stage, facilitating the implementation of skip
|
||||
connections in U-Net-like architectures. The `encode` method returns only
|
||||
the final output tensor (bottleneck features).
|
||||
|
||||
Attributes
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels expected in the input tensor.
|
||||
input_height : int
|
||||
Height (frequency bins) expected in the input tensor.
|
||||
output_channels : int
|
||||
Number of channels in the final output tensor (bottleneck).
|
||||
output_height : int
|
||||
Height (frequency bins) expected in the output tensor.
|
||||
layers : nn.ModuleList
|
||||
The sequence of instantiated downscaling layer modules.
|
||||
depth : int
|
||||
The number of downscaling layers in the encoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_channels: int,
|
||||
output_height: int,
|
||||
layers: List[nn.Module],
|
||||
input_height: int = 128,
|
||||
in_channels: int = 1,
|
||||
):
|
||||
"""Initialize the Encoder module.
|
||||
|
||||
Note: This constructor is typically called internally by the
|
||||
`build_encoder` factory function, which prepares the `layers` list.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output_channels : int
|
||||
Number of channels produced by the final layer.
|
||||
output_height : int
|
||||
The expected height of the output tensor.
|
||||
layers : List[nn.Module]
|
||||
A list of pre-instantiated downscaling layer modules (e.g.,
|
||||
`StandardConvDownBlock` or `FreqCoordConvDownBlock`) in the desired
|
||||
sequence.
|
||||
input_height : int, default=128
|
||||
Expected height of the input tensor.
|
||||
in_channels : int, default=1
|
||||
Expected number of channels in the input tensor.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.input_height = input_height
|
||||
self.out_channels = output_channels
|
||||
self.output_height = output_height
|
||||
|
||||
self.layers = nn.ModuleList(layers)
|
||||
self.depth = len(self.layers)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
"""Pass input through encoder layers, returns all intermediate outputs.
|
||||
|
||||
This method is typically used when the Encoder is part of a U-Net or
|
||||
similar architecture requiring skip connections.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H_in, W)`, where `C_in` must match
|
||||
`self.in_channels` and `H_in` must match `self.input_height`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[torch.Tensor]
|
||||
A list containing the output tensors from *each* downscaling layer
|
||||
in the sequence. `outputs[0]` is the output of the first layer,
|
||||
`outputs[-1]` is the final output (bottleneck) of the encoder.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If input tensor channel count or height does not match expected
|
||||
values.
|
||||
"""
|
||||
if x.shape[1] != self.in_channels:
|
||||
raise ValueError(
|
||||
f"Input tensor has {x.shape[1]} channels, "
|
||||
f"but encoder expects {self.in_channels}."
|
||||
)
|
||||
|
||||
if x.shape[2] != self.input_height:
|
||||
raise ValueError(
|
||||
f"Input tensor height {x.shape[2]} does not match "
|
||||
f"encoder expected input_height {self.input_height}."
|
||||
)
|
||||
|
||||
outputs = []
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
outputs.append(x)
|
||||
|
||||
return outputs
|
||||
|
||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Pass input through encoder layers, returning only the final output.
|
||||
|
||||
This method provides access to the bottleneck features produced after
|
||||
the last downscaling layer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Input tensor, shape `(B, C_in, H_in, W)`. Must match expected
|
||||
`in_channels` and `input_height`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The final output tensor (bottleneck features) from the last layer
|
||||
of the encoder. Shape `(B, C_out, H_out, W_out)`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If input tensor channel count or height does not match expected
|
||||
values.
|
||||
"""
|
||||
if x.shape[1] != self.in_channels:
|
||||
raise ValueError(
|
||||
f"Input tensor has {x.shape[1]} channels, "
|
||||
f"but encoder expects {self.in_channels}."
|
||||
)
|
||||
|
||||
if x.shape[2] != self.input_height:
|
||||
raise ValueError(
|
||||
f"Input tensor height {x.shape[2]} does not match "
|
||||
f"encoder expected input_height {self.input_height}."
|
||||
)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
DEFAULT_ENCODER_CONFIG: EncoderConfig = EncoderConfig(
|
||||
layers=[
|
||||
FreqCoordConvDownConfig(out_channels=32),
|
||||
FreqCoordConvDownConfig(out_channels=64),
|
||||
BlockGroupConfig(
|
||||
blocks=[
|
||||
FreqCoordConvDownConfig(out_channels=128),
|
||||
ConvConfig(out_channels=256),
|
||||
]
|
||||
),
|
||||
],
|
||||
)
|
||||
"""Default configuration for the Encoder.
|
||||
|
||||
Specifies an architecture typically used in BatDetect2:
|
||||
- Input: 1 channel, 128 frequency bins.
|
||||
- Layer 1: FreqCoordConvDown -> 32 channels, H=64
|
||||
- Layer 2: FreqCoordConvDown -> 64 channels, H=32
|
||||
- Layer 3: FreqCoordConvDown -> 128 channels, H=16
|
||||
- Layer 4: ConvBlock -> 256 channels, H=16 (Bottleneck)
|
||||
"""
|
||||
|
||||
|
||||
def build_encoder(
|
||||
in_channels: int,
|
||||
input_height: int,
|
||||
config: Optional[EncoderConfig] = None,
|
||||
) -> Encoder:
|
||||
"""Factory function to build an Encoder instance from configuration.
|
||||
|
||||
Constructs a sequential `Encoder` module based on the layer sequence
|
||||
defined in an `EncoderConfig` object and the provided input dimensions.
|
||||
If no config is provided, uses the default layer sequence from
|
||||
`DEFAULT_ENCODER_CONFIG`.
|
||||
|
||||
It iteratively builds the layers using the unified
|
||||
`build_layer_from_config` factory (from `.blocks`), tracking the changing
|
||||
number of channels and feature map height required for each subsequent
|
||||
layer, especially for coordinate-aware blocks.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
The number of channels expected in the input tensor to the encoder.
|
||||
Must be > 0.
|
||||
input_height : int
|
||||
The height (frequency bins) expected in the input tensor. Must be > 0.
|
||||
Crucial for initializing coordinate-aware layers correctly.
|
||||
config : EncoderConfig, optional
|
||||
The configuration object detailing the sequence of layers and their
|
||||
parameters. If None, `DEFAULT_ENCODER_CONFIG` is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Encoder
|
||||
An initialized `Encoder` module.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `in_channels` or `input_height` are not positive, or if the layer
|
||||
configuration is invalid (e.g., empty list, unknown `block_type`).
|
||||
NotImplementedError
|
||||
If `build_layer_from_config` encounters an unknown `block_type`.
|
||||
"""
|
||||
if in_channels <= 0 or input_height <= 0:
|
||||
raise ValueError("in_channels and input_height must be positive.")
|
||||
|
||||
config = config or DEFAULT_ENCODER_CONFIG
|
||||
|
||||
current_channels = in_channels
|
||||
current_height = input_height
|
||||
|
||||
layers = []
|
||||
|
||||
for layer_config in config.layers:
|
||||
layer, current_channels, current_height = build_layer_from_config(
|
||||
in_channels=current_channels,
|
||||
input_height=current_height,
|
||||
config=layer_config,
|
||||
)
|
||||
layers.append(layer)
|
||||
|
||||
return Encoder(
|
||||
input_height=input_height,
|
||||
layers=layers,
|
||||
in_channels=in_channels,
|
||||
output_channels=current_channels,
|
||||
output_height=current_height,
|
||||
)
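A minimal end-to-end sketch of the U-Net-style wiring described above, using the default configs; the `batdetect2.models.encoder` and `.decoder` module paths are assumed from the file layout shown here:

import torch

from batdetect2.models.bottleneck import build_bottleneck
from batdetect2.models.decoder import build_decoder
from batdetect2.models.encoder import build_encoder

encoder = build_encoder(in_channels=1, input_height=128)
bottleneck = build_bottleneck(
    input_height=encoder.output_height,
    in_channels=encoder.out_channels,
)
decoder = build_decoder(
    in_channels=256,  # DEFAULT_BOTTLENECK_CONFIG outputs 256 channels
    input_height=encoder.output_height,
)

spec = torch.randn(1, 1, 128, 256)
residuals = encoder(spec)               # one output per encoder stage
x = bottleneck(residuals[-1])           # keeps height and width
out = decoder(x, residuals=residuals)   # skip connections, shallowest first
print(out.shape)  # expected (1, 32, 128, 256) with the default configs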
|
||||
|
batdetect2/models/heads.py
@@ -1,42 +1,199 @@
|
||||
from typing import NamedTuple
|
||||
"""Prediction Head modules for BatDetect2 models.
|
||||
|
||||
This module defines simple `torch.nn.Module` subclasses that serve as
|
||||
prediction heads, typically attached to the output feature map of a backbone
|
||||
network.
|
||||
|
||||
Each head is responsible for generating one specific type of output required
|
||||
by the BatDetect2 task:
|
||||
- `DetectorHead`: Predicts the probability of sound event presence.
|
||||
- `ClassifierHead`: Predicts the probability distribution over target classes.
|
||||
- `BBoxHead`: Predicts the size (width, height) of the sound event's bounding
|
||||
box.
|
||||
|
||||
These heads use 1x1 convolutions to map the backbone feature channels
|
||||
to the desired number of output channels for each prediction task at each
|
||||
spatial location, followed by an appropriate activation function (e.g., sigmoid
|
||||
for detection, softmax for classification, none for size regression).
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
__all__ = ["ClassifierHead"]
|
||||
|
||||
|
||||
class Output(NamedTuple):
|
||||
detection: torch.Tensor
|
||||
classification: torch.Tensor
|
||||
__all__ = [
|
||||
"ClassifierHead",
|
||||
"DetectorHead",
|
||||
"BBoxHead",
|
||||
]
|
||||
|
||||
|
||||
class ClassifierHead(nn.Module):
|
||||
"""Prediction head for multi-class classification probabilities.
|
||||
|
||||
Takes an input feature map and produces a probability map where each
|
||||
channel corresponds to a specific target class. It uses a 1x1 convolution
|
||||
to map input channels to `num_classes + 1` outputs (one for each target
|
||||
class plus an assumed background/generic class), applies softmax across the
|
||||
channels, and returns the probabilities for the specific target classes
|
||||
(excluding the last background/generic channel).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_classes : int
|
||||
The number of specific target classes the model should predict
|
||||
(excluding any background or generic category). Must be positive.
|
||||
in_channels : int
|
||||
Number of channels in the input feature map tensor from the backbone.
|
||||
Must be positive.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
num_classes : int
|
||||
Number of specific output classes.
|
||||
in_channels : int
|
||||
Number of input channels expected.
|
||||
classifier : nn.Conv2d
|
||||
The 1x1 convolutional layer used for prediction.
|
||||
Output channels = num_classes + 1.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `num_classes` or `in_channels` are not positive.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes: int, in_channels: int):
|
||||
"""Initialize the ClassifierHead."""
|
||||
super().__init__()
|
||||
|
||||
self.num_classes = num_classes
|
||||
self.in_channels = in_channels
|
||||
self.classifier = nn.Conv2d(
|
||||
self.in_channels,
|
||||
# Add one to account for the background class
|
||||
self.num_classes + 1,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
def forward(self, features: torch.Tensor) -> Output:
|
||||
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute class probabilities from input features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : torch.Tensor
|
||||
Input feature map tensor from the backbone, typically with shape
|
||||
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Class probability map tensor with shape `(B, num_classes, H, W)`.
|
||||
Contains probabilities for the specific target classes after
|
||||
softmax, excluding the implicit background/generic class channel.
|
||||
"""
|
||||
logits = self.classifier(features)
|
||||
probs = torch.softmax(logits, dim=1)
|
||||
detection_probs = probs[:, :-1].sum(dim=1, keepdim=True)
|
||||
return Output(
|
||||
detection=detection_probs,
|
||||
classification=probs[:, :-1],
|
||||
return probs[:, :-1]
|
||||
|
||||
|
||||
class DetectorHead(nn.Module):
|
||||
"""Prediction head for sound event detection probability.
|
||||
|
||||
Takes an input feature map and produces a single-channel heatmap where
|
||||
each value represents the probability ([0, 1]) of a relevant sound event
|
||||
(of any class) being present at that spatial location.
|
||||
|
||||
Uses a 1x1 convolution to map input channels to 1 output channel, followed
|
||||
by a sigmoid activation function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input feature map tensor from the backbone.
|
||||
Must be positive.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels expected.
|
||||
detector : nn.Conv2d
|
||||
The 1x1 convolutional layer mapping to a single output channel.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `in_channels` is not positive.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int):
|
||||
"""Initialize the DetectorHead."""
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.detector = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=1,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute detection probabilities from input features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : torch.Tensor
|
||||
Input feature map tensor from the backbone, typically with shape
|
||||
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Detection probability heatmap tensor with shape `(B, 1, H, W)`.
|
||||
Values are in the range [0, 1] due to the sigmoid activation.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
If input channel count does not match `self.in_channels`.
|
||||
"""
|
||||
return torch.sigmoid(self.detector(features))
|
||||
|
||||
|
||||
class BBoxHead(nn.Module):
|
||||
"""Prediction head for bounding box size dimensions.
|
||||
|
||||
Takes an input feature map and produces a two-channel map where each
|
||||
channel represents a predicted size dimension (typically width/duration and
|
||||
height/bandwidth) for a potential sound event at that spatial location.
|
||||
|
||||
Uses a 1x1 convolution to map input channels to 2 output channels. No
|
||||
activation function is typically applied, as size prediction is often
|
||||
treated as a direct regression task. The output values usually represent
|
||||
*scaled* dimensions that need to be un-scaled during postprocessing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
in_channels : int
|
||||
Number of channels in the input feature map tensor from the backbone.
|
||||
Must be positive.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
in_channels : int
|
||||
Number of input channels expected.
|
||||
bbox : nn.Conv2d
|
||||
The 1x1 convolutional layer mapping to 2 output channels
|
||||
(width, height).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `in_channels` is not positive.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int):
|
||||
"""Initialize the BBoxHead."""
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
@@ -48,4 +205,19 @@ class BBoxHead(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute predicted bounding box dimensions from input features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : torch.Tensor
|
||||
Input feature map tensor from the backbone, typically with shape
|
||||
`(B, C_in, H, W)`. `C_in` must match `self.in_channels`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Predicted size tensor with shape `(B, 2, H, W)`. Channel 0 usually
|
||||
represents scaled width, Channel 1 scaled height. These values
|
||||
need to be un-scaled during postprocessing.
|
||||
"""
|
||||
return self.bbox(features)
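A rough sketch of how the raw output might be interpreted; the scale factors below are placeholders, since the real un-scaling belongs to postprocessing and depends on the target configuration:

import torch

bbox_head = BBoxHead(in_channels=32)
features = torch.randn(2, 32, 64, 128)    # (B, C_in, H, W)
sizes = bbox_head(features)                # (B, 2, H, W): scaled (width, height)

# Hypothetical un-scaling; the actual factors come from the preprocessing /
# target configuration, not from this module.
time_scale, freq_scale = 0.01, 1000.0
durations = sizes[:, 0] * time_scale       # e.g. seconds (illustrative)
bandwidths = sizes[:, 1] * freq_scale      # e.g. Hz (illustrative)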
|
||||
|
225
batdetect2/models/types.py
Normal file
@@ -0,0 +1,225 @@
|
||||
"""Defines shared interfaces (ABCs) and data structures for models.
|
||||
|
||||
This module centralizes the definitions of core data structures, like the
|
||||
standard model output container (`ModelOutput`), and establishes abstract base
|
||||
classes (ABCs) using `abc.ABC` and `torch.nn.Module`. These define contracts
|
||||
for fundamental model components, ensuring modularity and consistent
|
||||
interaction within the `batdetect2.models` package.
|
||||
|
||||
Key components:
|
||||
- `ModelOutput`: Standard structure for outputs from detection models.
|
||||
- `BackboneModel`: Generic interface for any feature extraction backbone.
|
||||
- `EncoderDecoderModel`: Specialized interface for backbones with distinct
|
||||
encoder-decoder stages (e.g., U-Net), providing access to intermediate
|
||||
features.
|
||||
- `DetectionModel`: Interface for the complete end-to-end detection model.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import NamedTuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = [
|
||||
"ModelOutput",
|
||||
"BackboneModel",
|
||||
"DetectionModel",
|
||||
]
|
||||
|
||||
|
||||
class ModelOutput(NamedTuple):
|
||||
"""Standard container for the outputs of a BatDetect2 detection model.
|
||||
|
||||
This structure groups the different prediction tensors produced by the
|
||||
model for a batch of input spectrograms. All tensors typically share the
|
||||
same spatial dimensions (height H, width W) corresponding to the model's
|
||||
output resolution, and the same batch size (N).
|
||||
|
||||
Attributes
|
||||
----------
|
||||
detection_probs : torch.Tensor
|
||||
Tensor containing the probability of sound event presence at each
|
||||
location in the output grid.
|
||||
Shape: `(N, 1, H, W)`
|
||||
size_preds : torch.Tensor
|
||||
Tensor containing the predicted size dimensions
|
||||
(e.g., width and height) for a potential bounding box at each location.
|
||||
Shape: `(N, 2, H, W)` (Channel 0 typically width, Channel 1 height)
|
||||
class_probs : torch.Tensor
|
||||
Tensor containing the predicted probabilities (or logits, depending on
|
||||
the final activation) for each target class at each location.
|
||||
The number of channels corresponds to the number of specific classes
|
||||
defined in the `Targets` configuration.
|
||||
Shape: `(N, num_classes, H, W)`
|
||||
features : torch.Tensor
|
||||
Tensor containing features extracted by the model's backbone. These
|
||||
might be used for downstream tasks or analysis. The number of channels
|
||||
depends on the specific model architecture.
|
||||
Shape: `(N, num_features, H, W)`
|
||||
"""
|
||||
|
||||
detection_probs: torch.Tensor
|
||||
size_preds: torch.Tensor
|
||||
class_probs: torch.Tensor
|
||||
features: torch.Tensor
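A small sketch of assembling a `ModelOutput` from dummy tensors with the documented shapes (all sizes arbitrary):

import torch

batch, num_classes, num_features, height, width = 2, 5, 32, 64, 128
output = ModelOutput(
    detection_probs=torch.rand(batch, 1, height, width),
    size_preds=torch.rand(batch, 2, height, width),
    class_probs=torch.rand(batch, num_classes, height, width),
    features=torch.rand(batch, num_features, height, width),
)
assert output.detection_probs.shape == (batch, 1, height, width)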
|
||||
|
||||
|
||||
class BackboneModel(ABC, nn.Module):
|
||||
"""Abstract Base Class for generic feature extraction backbone models.
|
||||
|
||||
Defines the minimal interface for a feature extractor network within a
|
||||
BatDetect2 model. Its primary role is to process an input spectrogram
|
||||
tensor and produce a spatially rich feature map tensor, which is then
|
||||
typically consumed by separate prediction heads (for detection,
|
||||
classification, size).
|
||||
|
||||
This base class is agnostic to the specific internal architecture (e.g.,
|
||||
it could be a simple CNN, a U-Net, a Transformer, etc.). Concrete
|
||||
implementations must inherit from this class and `torch.nn.Module`,
|
||||
implement the `forward` method, and define the required attributes.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
input_height : int
|
||||
Expected height (number of frequency bins) of the input spectrogram
|
||||
tensor that the backbone is designed to process.
|
||||
out_channels : int
|
||||
Number of channels in the final feature map tensor produced by the
|
||||
backbone's `forward` method.
|
||||
"""
|
||||
|
||||
input_height: int
|
||||
"""Expected input spectrogram height (frequency bins)."""
|
||||
|
||||
out_channels: int
|
||||
"""Number of output channels in the final feature map."""
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Perform the forward pass to extract features from the spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input spectrogram tensor, typically with shape
|
||||
`(batch_size, 1, frequency_bins, time_bins)`.
|
||||
`frequency_bins` should match `self.input_height`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Output feature map tensor, typically with shape
|
||||
`(batch_size, self.out_channels, output_height, output_width)`.
|
||||
The spatial dimensions (`output_height`, `output_width`) depend
|
||||
on the specific backbone architecture (e.g., they might match the
|
||||
input or be downsampled).
|
||||
"""
|
||||
raise NotImplementedError
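A minimal concrete implementation sketch; the architecture is invented purely to show the contract (required attributes set, single-channel spectrogram in, `out_channels` feature map out):

import torch
import torch.nn as nn

class TinyBackbone(BackboneModel):
    """Illustrative backbone, not part of BatDetect2."""

    def __init__(self, input_height: int = 128, out_channels: int = 32):
        super().__init__()
        self.input_height = input_height
        self.out_channels = out_channels
        self.conv = nn.Conv2d(1, out_channels, kernel_size=3, padding=1)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # Same spatial size in and out; real backbones typically downsample.
        return torch.relu(self.conv(spec))

backbone = TinyBackbone()
feats = backbone(torch.randn(1, 1, 128, 512))   # -> (1, 32, 128, 512)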
|
||||
|
||||
|
||||
class EncoderDecoderModel(BackboneModel):
|
||||
"""Abstract Base Class for Encoder-Decoder style backbone models.
|
||||
|
||||
This class specializes `BackboneModel` for architectures that have distinct
|
||||
encoder stages (downsampling path), a bottleneck, and decoder stages
|
||||
(upsampling path).
|
||||
|
||||
It provides separate abstract methods for the `encode` and `decode` steps,
|
||||
allowing access to the intermediate "bottleneck" features produced by the
|
||||
encoder. This can be useful for tasks like transfer learning or specialized
|
||||
analyses.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
input_height : int
|
||||
(Inherited from BackboneModel) Expected input spectrogram height.
|
||||
out_channels : int
|
||||
(Inherited from BackboneModel) Number of output channels in the final
|
||||
feature map produced by the decoder/forward pass.
|
||||
bottleneck_channels : int
|
||||
Number of channels in the feature map produced by the encoder at its
|
||||
deepest point (the bottleneck), before the decoder starts.
|
||||
"""
|
||||
|
||||
bottleneck_channels: int
|
||||
"""Number of channels at the encoder's bottleneck."""
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Process the input spectrogram through the encoder part.
|
||||
|
||||
Takes the input spectrogram and passes it through the downsampling path
|
||||
of the network up to the bottleneck layer.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input spectrogram tensor, typically with shape
|
||||
`(batch_size, 1, frequency_bins, time_bins)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The encoded feature map from the bottleneck layer, typically with
|
||||
shape `(batch_size, self.bottleneck_channels, bottleneck_height,
|
||||
bottleneck_width)`. The spatial dimensions are usually downsampled
|
||||
relative to the input.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, encoded: torch.Tensor) -> torch.Tensor:
|
||||
"""Process the bottleneck features through the decoder part.
|
||||
|
||||
Takes the encoded feature map from the bottleneck and passes it through
|
||||
the upsampling path (potentially using skip connections from the
|
||||
encoder) to produce the final output feature map.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
encoded : torch.Tensor
|
||||
The bottleneck feature map tensor produced by the `encode` method.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
The final output feature map tensor, typically with shape
|
||||
`(batch_size, self.out_channels, output_height, output_width)`.
|
||||
This should match the output shape of the `forward` method.
|
||||
"""
|
||||
...
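For reference, a concrete subclass would conventionally compose the two stages in `forward`; this is only the expected pattern, not code taken from this diff:

import torch

class MyEncoderDecoder(EncoderDecoderModel):
    # encode()/decode() implementations omitted in this sketch.

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # Conventional composition; real implementations may also thread skip
        # connections from encoder stages into the decoder.
        return self.decode(self.encode(spec))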
|
||||
|
||||
|
||||
class DetectionModel(ABC, nn.Module):
|
||||
"""Abstract Base Class for complete BatDetect2 detection models.
|
||||
|
||||
Defines the interface for the overall model that takes an input spectrogram
|
||||
and produces all necessary outputs for detection, classification, and size
|
||||
prediction, packaged within a `ModelOutput` object.
|
||||
|
||||
Concrete implementations typically combine a `BackboneModel` for feature
|
||||
extraction with specific prediction heads for each output type. They must
|
||||
inherit from this class and `torch.nn.Module`, and implement the `forward`
|
||||
method.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
"""Perform the forward pass of the full detection model.
|
||||
|
||||
Processes the input spectrogram through the backbone and prediction
|
||||
heads to generate all required output tensors.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : torch.Tensor
|
||||
Input spectrogram tensor, typically with shape
|
||||
`(batch_size, 1, frequency_bins, time_bins)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ModelOutput
|
||||
A NamedTuple containing the prediction tensors: `detection_probs`,
|
||||
`size_preds`, `class_probs`, and `features`.
|
||||
"""
|
@@ -1,74 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import NamedTuple, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = [
|
||||
"ModelOutput",
|
||||
"BackboneModel",
|
||||
]
|
||||
|
||||
|
||||
class ModelOutput(NamedTuple):
|
||||
"""Output of the detection model.
|
||||
|
||||
Each of the tensors has a shape of
|
||||
|
||||
`(batch_size, num_channels, spec_height, spec_width)`.
|
||||
|
||||
Where `spec_height` and `spec_width` are the height and width of the
|
||||
input spectrograms.
|
||||
|
||||
They contain localised information of:
|
||||
|
||||
1. The probability of a bounding box detection at the given location.
|
||||
2. The predicted size of the bounding box at the given location.
|
||||
3. The probabilities of each class at the given location before softmax.
|
||||
4. Features used to make the predictions at the given location.
|
||||
"""
|
||||
|
||||
detection_probs: torch.Tensor
|
||||
"""Tensor with predict detection probabilities."""
|
||||
|
||||
size_preds: torch.Tensor
|
||||
"""Tensor with predicted bounding box sizes."""
|
||||
|
||||
class_probs: torch.Tensor
|
||||
"""Tensor with predicted class probabilities."""
|
||||
|
||||
features: torch.Tensor
|
||||
"""Tensor with intermediate features."""
|
||||
|
||||
|
||||
class BackboneModel(ABC, nn.Module):
|
||||
input_height: int
|
||||
"""Height of the input spectrogram."""
|
||||
|
||||
encoder_channels: Tuple[int, ...]
|
||||
"""Tuple specifying the number of channels for each convolutional layer
|
||||
in the encoder. The length of this tuple determines the number of
|
||||
encoder layers."""
|
||||
|
||||
decoder_channels: Tuple[int, ...]
|
||||
"""Tuple specifying the number of channels for each convolutional layer
|
||||
in the decoder. The length of this tuple determines the number of
|
||||
decoder layers."""
|
||||
|
||||
bottleneck_channels: int
|
||||
"""Number of channels in the bottleneck layer, which connects the
|
||||
encoder and decoder."""
|
||||
|
||||
out_channels: int
|
||||
"""Number of channels in the final output feature map produced by the
|
||||
backbone model."""
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward pass of the model."""
|
||||
|
||||
|
||||
class DetectionModel(ABC, nn.Module):
|
||||
@abstractmethod
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward pass of the detection model."""
|
@@ -1,181 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import lightning as L
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from torch.optim.adam import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.evaluate.evaluate import match_predictions_and_annotations
|
||||
from batdetect2.models import (
|
||||
BBoxHead,
|
||||
ClassifierHead,
|
||||
ModelConfig,
|
||||
build_architecture,
|
||||
)
|
||||
from batdetect2.models.typing import ModelOutput
|
||||
from batdetect2.post_process import (
|
||||
PostprocessConfig,
|
||||
postprocess_model_outputs,
|
||||
)
|
||||
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||
from batdetect2.train.losses import compute_loss
|
||||
from batdetect2.train.targets import (
|
||||
TargetConfig,
|
||||
build_decoder,
|
||||
build_target_encoder,
|
||||
get_class_names,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DetectorModel",
|
||||
]
|
||||
|
||||
|
||||
class ModuleConfig(BaseConfig):
|
||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
architecture: ModelConfig = Field(default_factory=ModelConfig)
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
postprocessing: PostprocessConfig = Field(
|
||||
default_factory=PostprocessConfig
|
||||
)
|
||||
|
||||
|
||||
class DetectorModel(L.LightningModule):
|
||||
config: ModuleConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[ModuleConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config or ModuleConfig()
|
||||
self.save_hyperparameters()
|
||||
|
||||
self.backbone = build_architecture(self.config.architecture)
|
||||
|
||||
self.classifier = ClassifierHead(
|
||||
num_classes=len(self.config.targets.classes),
|
||||
in_channels=self.backbone.out_channels,
|
||||
)
|
||||
|
||||
self.bbox = BBoxHead(in_channels=self.backbone.out_channels)
|
||||
|
||||
conf = self.config.train.loss.classification
|
||||
self.class_weights = (
|
||||
torch.tensor(conf.class_weights) if conf.class_weights else None
|
||||
)
|
||||
|
||||
# Training targets
|
||||
self.class_names = get_class_names(self.config.targets.classes)
|
||||
self.encoder = build_target_encoder(
|
||||
self.config.targets.classes,
|
||||
replacement_rules=self.config.targets.replace,
|
||||
)
|
||||
self.decoder = build_decoder(self.config.targets.classes)
|
||||
|
||||
self.validation_predictions = []
|
||||
|
||||
self.example_input_array = torch.randn([1, 1, 128, 512])
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
|
||||
features = self.backbone(spec)
|
||||
detection_probs, classification_probs = self.classifier(features)
|
||||
size_preds = self.bbox(features)
|
||||
return ModelOutput(
|
||||
detection_probs=detection_probs,
|
||||
size_preds=size_preds,
|
||||
class_probs=classification_probs,
|
||||
features=features,
|
||||
)
|
||||
|
||||
def training_step(self, batch: TrainExample):
|
||||
outputs = self.forward(batch.spec)
|
||||
losses = compute_loss(
|
||||
batch,
|
||||
outputs,
|
||||
conf=self.config.train.loss,
|
||||
class_weights=self.class_weights,
|
||||
)
|
||||
|
||||
self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
|
||||
self.log("train/loss/detection", losses.total, logger=True)
|
||||
self.log("train/loss/size", losses.total, logger=True)
|
||||
self.log("train/loss/classification", losses.total, logger=True)
|
||||
|
||||
return losses.total
|
||||
|
||||
def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
|
||||
outputs = self.forward(batch.spec)
|
||||
|
||||
losses = compute_loss(
|
||||
batch,
|
||||
outputs,
|
||||
conf=self.config.train.loss,
|
||||
class_weights=self.class_weights,
|
||||
)
|
||||
|
||||
self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
|
||||
self.log("val/loss/detection", losses.total, logger=True)
|
||||
self.log("val/loss/size", losses.total, logger=True)
|
||||
self.log("val/loss/classification", losses.total, logger=True)
|
||||
|
||||
dataloaders = self.trainer.val_dataloaders
|
||||
assert isinstance(dataloaders, DataLoader)
|
||||
dataset = dataloaders.dataset
|
||||
assert isinstance(dataset, LabeledDataset)
|
||||
clip_annotation = dataset.get_clip_annotation(batch_idx)
|
||||
|
||||
clip_prediction = postprocess_model_outputs(
|
||||
outputs,
|
||||
clips=[clip_annotation.clip],
|
||||
classes=self.class_names,
|
||||
decoder=self.decoder,
|
||||
config=self.config.postprocessing,
|
||||
)[0]
|
||||
|
||||
matches = match_predictions_and_annotations(
|
||||
clip_annotation,
|
||||
clip_prediction,
|
||||
)
|
||||
|
||||
self.validation_predictions.extend(matches)
|
||||
|
||||
def on_validation_epoch_end(self) -> None:
|
||||
self.validation_predictions.clear()
|
||||
|
||||
def configure_optimizers(self):
|
||||
conf = self.config.train.optimizer
|
||||
optimizer = Adam(self.parameters(), lr=conf.learning_rate)
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=conf.t_max)
|
||||
return [optimizer], [scheduler]
|
||||
|
||||
def process_clip(
|
||||
self,
|
||||
clip: data.Clip,
|
||||
audio_dir: Optional[Path] = None,
|
||||
) -> data.ClipPrediction:
|
||||
spec = preprocess_audio_clip(
|
||||
clip,
|
||||
config=self.config.preprocessing,
|
||||
audio_dir=audio_dir,
|
||||
)
|
||||
tensor = torch.from_numpy(spec.data).unsqueeze(0).unsqueeze(0)
|
||||
outputs = self.forward(tensor)
|
||||
return postprocess_model_outputs(
|
||||
outputs,
|
||||
clips=[clip],
|
||||
classes=self.class_names,
|
||||
decoder=self.decoder,
|
||||
config=self.config.postprocessing,
|
||||
)[0]
|
@@ -2,7 +2,6 @@
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import xarray as xr
|
||||
from matplotlib import axes
|
||||
|
||||
|
@@ -1,398 +0,0 @@
|
||||
"""Module for postprocessing model outputs."""
|
||||
|
||||
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.typing import ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"PostprocessConfig",
|
||||
"load_postprocess_config",
|
||||
"postprocess_model_outputs",
|
||||
]
|
||||
|
||||
NMS_KERNEL_SIZE = 9
|
||||
DETECTION_THRESHOLD = 0.01
|
||||
TOP_K_PER_SEC = 200
|
||||
|
||||
|
||||
class PostprocessConfig(BaseConfig):
|
||||
"""Configuration for postprocessing model outputs."""
|
||||
|
||||
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
||||
detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
|
||||
min_freq: int = Field(default=10000, gt=0)
|
||||
max_freq: int = Field(default=120000, gt=0)
|
||||
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
||||
|
||||
|
||||
def load_postprocess_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> PostprocessConfig:
|
||||
return load_config(path, schema=PostprocessConfig, field=field)
|
||||
|
||||
|
||||
class RawPrediction(NamedTuple):
|
||||
start_time: float
|
||||
end_time: float
|
||||
low_freq: float
|
||||
high_freq: float
|
||||
detection_score: float
|
||||
class_scores: Dict[str, float]
|
||||
features: np.ndarray
|
||||
|
||||
|
||||
def postprocess_model_outputs(
|
||||
outputs: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
classes: List[str],
|
||||
decoder: Callable[[str], List[data.Tag]],
|
||||
config: Optional[PostprocessConfig] = None,
|
||||
) -> List[data.ClipPrediction]:
|
||||
"""Postprocesses model outputs to generate clip predictions.
|
||||
|
||||
This function takes the output from the model, applies non-maximum suppression,
|
||||
selects the top-k scores, computes sound events from the outputs, and returns
|
||||
clip predictions based on these processed outputs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
outputs
|
||||
Output from the model containing detection probabilities, size
|
||||
predictions, class logits, and features. All tensors are expected
|
||||
to have a batch dimension.
|
||||
clips
|
||||
List of clips for which predictions are made. The number of clips
|
||||
must match the batch dimension of the model outputs.
|
||||
config
|
||||
Configuration for postprocessing model outputs.
|
||||
|
||||
Returns
|
||||
-------
|
||||
predictions: List[data.ClipPrediction]
|
||||
List of clip predictions containing predicted sound events.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the number of predictions does not match the number of clips.
|
||||
"""
|
||||
|
||||
config = config or PostprocessConfig()
|
||||
|
||||
num_predictions = len(outputs.detection_probs)
|
||||
|
||||
if num_predictions == 0:
|
||||
return []
|
||||
|
||||
if num_predictions != len(clips):
|
||||
raise ValueError(
|
||||
"Number of predictions must match the number of clips."
|
||||
)
|
||||
|
||||
detection_probs = non_max_suppression(
|
||||
outputs.detection_probs,
|
||||
kernel_size=config.nms_kernel_size,
|
||||
)
|
||||
|
||||
duration = clips[0].end_time - clips[0].start_time
|
||||
|
||||
scores_batch, y_pos_batch, x_pos_batch = get_topk_scores(
|
||||
detection_probs,
|
||||
int(config.top_k_per_sec * duration / 2),
|
||||
)
|
||||
|
||||
predictions: List[data.ClipPrediction] = []
|
||||
for scores, y_pos, x_pos, size_preds, class_probs, features, clip in zip(
|
||||
scores_batch,
|
||||
y_pos_batch,
|
||||
x_pos_batch,
|
||||
outputs.size_preds,
|
||||
outputs.class_probs,
|
||||
outputs.features,
|
||||
clips,
|
||||
):
|
||||
sound_events = compute_sound_events_from_outputs(
|
||||
clip,
|
||||
scores,
|
||||
y_pos,
|
||||
x_pos,
|
||||
size_preds,
|
||||
class_probs,
|
||||
features,
|
||||
classes=classes,
|
||||
decoder=decoder,
|
||||
min_freq=config.min_freq,
|
||||
max_freq=config.max_freq,
|
||||
detection_threshold=config.detection_threshold,
|
||||
)
|
||||
|
||||
predictions.append(
|
||||
data.ClipPrediction(
|
||||
clip=clip,
|
||||
sound_events=sound_events,
|
||||
)
|
||||
)
|
||||
|
||||
return predictions
|
||||
|
||||
|
||||
def compute_predictions_from_outputs(
|
||||
start: float,
|
||||
end: float,
|
||||
scores: torch.Tensor,
|
||||
y_pos: torch.Tensor,
|
||||
x_pos: torch.Tensor,
|
||||
size_preds: torch.Tensor,
|
||||
class_probs: torch.Tensor,
|
||||
features: torch.Tensor,
|
||||
classes: List[str],
|
||||
min_freq: int = 10000,
|
||||
max_freq: int = 120000,
|
||||
detection_threshold: float = DETECTION_THRESHOLD,
|
||||
) -> List[RawPrediction]:
|
||||
_, freq_bins, time_bins = size_preds.shape
|
||||
|
||||
sorted_indices = torch.argsort(x_pos)
|
||||
valid_indices = sorted_indices[
|
||||
scores[sorted_indices] > detection_threshold
|
||||
]
|
||||
|
||||
scores = scores[valid_indices]
|
||||
x_pos = x_pos[valid_indices]
|
||||
y_pos = y_pos[valid_indices]
|
||||
|
||||
predictions: List[RawPrediction] = []
|
||||
for score, x, y in zip(scores, x_pos, y_pos):
|
||||
width, height = size_preds[:, y, x]
|
||||
class_prob = class_probs[:, y, x].detach().numpy()
|
||||
feats = features[:, y, x].detach().numpy()
|
||||
|
||||
start_time = np.interp(
|
||||
x.item(),
|
||||
[0, time_bins],
|
||||
[start, end],
|
||||
)
|
||||
|
||||
end_time = np.interp(
|
||||
x.item() + width.item(),
|
||||
[0, time_bins],
|
||||
[start, end],
|
||||
)
|
||||
|
||||
low_freq = np.interp(
|
||||
y.item(),
|
||||
[0, freq_bins],
|
||||
[max_freq, min_freq],
|
||||
)
|
||||
|
||||
high_freq = np.interp(
|
||||
y.item() - height.item(),
|
||||
[0, freq_bins],
|
||||
[max_freq, min_freq],
|
||||
)
|
||||
|
||||
start_time, end_time = sorted([float(start_time), float(end_time)])
|
||||
low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
|
||||
|
||||
predictions.append(
|
||||
RawPrediction(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
low_freq=low_freq,
|
||||
high_freq=high_freq,
|
||||
detection_score=score.item(),
|
||||
features=feats,
|
||||
class_scores={
|
||||
class_name: prob
|
||||
for class_name, prob in zip(classes, class_prob)
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return predictions
|
||||
|
||||
|
||||
def compute_sound_events_from_outputs(
|
||||
clip: data.Clip,
|
||||
scores: torch.Tensor,
|
||||
y_pos: torch.Tensor,
|
||||
x_pos: torch.Tensor,
|
||||
size_preds: torch.Tensor,
|
||||
class_probs: torch.Tensor,
|
||||
features: torch.Tensor,
|
||||
classes: List[str],
|
||||
decoder: Callable[[str], List[data.Tag]],
|
||||
min_freq: int = 10000,
|
||||
max_freq: int = 120000,
|
||||
detection_threshold: float = DETECTION_THRESHOLD,
|
||||
) -> List[data.SoundEventPrediction]:
|
||||
_, freq_bins, time_bins = size_preds.shape
|
||||
|
||||
sorted_indices = torch.argsort(x_pos)
|
||||
valid_indices = sorted_indices[
|
||||
scores[sorted_indices] > detection_threshold
|
||||
]
|
||||
|
||||
scores = scores[valid_indices]
|
||||
x_pos = x_pos[valid_indices]
|
||||
y_pos = y_pos[valid_indices]
|
||||
|
||||
predictions: List[data.SoundEventPrediction] = []
|
||||
for score, x, y in zip(scores, x_pos, y_pos):
|
||||
width, height = size_preds[:, y, x]
|
||||
class_prob = class_probs[:, y, x]
|
||||
feature = features[:, y, x]
|
||||
|
||||
start_time = np.interp(
|
||||
x.item(),
|
||||
[0, time_bins],
|
||||
[clip.start_time, clip.end_time],
|
||||
)
|
||||
|
||||
end_time = np.interp(
|
||||
x.item() + width.item(),
|
||||
[0, time_bins],
|
||||
[clip.start_time, clip.end_time],
|
||||
)
|
||||
|
||||
low_freq = np.interp(
|
||||
y.item(),
|
||||
[0, freq_bins],
|
||||
[max_freq, min_freq],
|
||||
)
|
||||
|
||||
high_freq = np.interp(
|
||||
y.item() - height.item(),
|
||||
[0, freq_bins],
|
||||
[max_freq, min_freq],
|
||||
)
|
||||
|
||||
predicted_tags: List[data.PredictedTag] = []
|
||||
|
||||
for label_id, class_score in enumerate(class_prob):
|
||||
class_name = classes[label_id]
|
||||
corresponding_tags = decoder(class_name)
|
||||
predicted_tags.extend(
|
||||
[
|
||||
data.PredictedTag(
|
||||
tag=tag,
|
||||
score=max(min(class_score.item(), 1), 0),
|
||||
)
|
||||
for tag in corresponding_tags
|
||||
]
|
||||
)
|
||||
|
||||
start_time, end_time = sorted([float(start_time), float(end_time)])
|
||||
low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
|
||||
|
||||
sound_event = data.SoundEvent(
|
||||
recording=clip.recording,
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=[
|
||||
start_time,
|
||||
low_freq,
|
||||
end_time,
|
||||
high_freq,
|
||||
]
|
||||
),
|
||||
features=[
|
||||
data.Feature(
|
||||
term=data.term_from_key(f"batdetect2_{i}"),
|
||||
value=value.item(),
|
||||
)
|
||||
for i, value in enumerate(feature)
|
||||
],
|
||||
)
|
||||
|
||||
predictions.append(
|
||||
data.SoundEventPrediction(
|
||||
sound_event=sound_event,
|
||||
score=max(min(score.item(), 1), 0),
|
||||
tags=predicted_tags,
|
||||
)
|
||||
)
|
||||
|
||||
return predictions
|
||||
|
||||
|
||||
def non_max_suppression(
|
||||
tensor: torch.Tensor,
|
||||
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
||||
) -> torch.Tensor:
|
||||
"""Run non-maximum suppression on a tensor.
|
||||
|
||||
This function removes values from the input tensor that are not local
|
||||
maxima in the neighborhood of the given kernel size.
|
||||
|
||||
All non-maximum values are set to zero.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tensor : torch.Tensor
|
||||
Input tensor.
|
||||
kernel_size : Union[int, Tuple[int, int]], optional
|
||||
Size of the neighborhood to consider for non-maximum suppression.
|
||||
If an integer is given, the neighborhood will be a square of the
|
||||
given size. If a tuple is given, the neighborhood will be a
|
||||
rectangle with the given height and width.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Tensor with non-maximum suppressed values.
|
||||
"""
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size_h = kernel_size
|
||||
kernel_size_w = kernel_size
|
||||
else:
|
||||
kernel_size_h, kernel_size_w = kernel_size
|
||||
|
||||
pad_h = (kernel_size_h - 1) // 2
|
||||
pad_w = (kernel_size_w - 1) // 2
|
||||
|
||||
hmax = nn.functional.max_pool2d(
|
||||
tensor,
|
||||
(kernel_size_h, kernel_size_w),
|
||||
stride=1,
|
||||
padding=(pad_h, pad_w),
|
||||
)
|
||||
keep = (hmax == tensor).float()
|
||||
return tensor * keep
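A tiny sketch of the behaviour: a weaker value next to a stronger peak is zeroed, while the local maximum survives.

import torch

heatmap = torch.zeros(1, 1, 5, 5)
heatmap[0, 0, 2, 2] = 0.9   # local maximum, kept
heatmap[0, 0, 2, 3] = 0.5   # weaker neighbour, suppressed by the 3x3 kernel
suppressed = non_max_suppression(heatmap, kernel_size=3)
assert suppressed[0, 0, 2, 2] == 0.9
assert suppressed[0, 0, 2, 3] == 0.0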
|
||||
|
||||
|
||||
def get_topk_scores(
|
||||
scores: torch.Tensor,
|
||||
K: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Get the top-k scores and their indices.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
scores : torch.Tensor
|
||||
Tensor with scores. Expects input of size: `batch x 1 x height x width`.
|
||||
K : int
|
||||
Number of top scores to return.
|
||||
|
||||
Returns
|
||||
-------
|
||||
scores : torch.Tensor
|
||||
Top-k scores.
|
||||
ys : torch.Tensor
|
||||
Y coordinates of the top-k scores.
|
||||
xs : torch.Tensor
|
||||
X coordinates of the top-k scores.
|
||||
"""
|
||||
batch, _, height, width = scores.size()
|
||||
topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
|
||||
topk_inds = topk_inds % (height * width)
|
||||
topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
|
||||
topk_xs = (topk_inds % width).long()
|
||||
return topk_scores, topk_ys, topk_xs
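Usage sketch (array sizes arbitrary):

import torch

scores = torch.rand(1, 1, 16, 32)                 # batch x 1 x height x width
top_scores, ys, xs = get_topk_scores(scores, K=10)
# top_scores: (1, 10) highest values; ys/xs: their row/column grid indices.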
|
566
batdetect2/postprocess/__init__.py
Normal file
@@ -0,0 +1,566 @@
|
||||
"""Main entry point for the BatDetect2 Postprocessing pipeline.
|
||||
|
||||
This package (`batdetect2.postprocess`) takes the raw outputs from a trained
|
||||
BatDetect2 neural network model and transforms them into meaningful, structured
|
||||
predictions, typically in the form of `soundevent.data.ClipPrediction` objects
|
||||
containing detected sound events with associated class tags and geometry.
|
||||
|
||||
The pipeline involves several configurable steps, implemented in submodules:
|
||||
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
|
||||
2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency
|
||||
coordinates to raw model output arrays.
|
||||
3. Detection Extraction (`.detection`): Identifies candidate detection points
|
||||
(location and score) based on thresholds and score ranking (top-k).
|
||||
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
|
||||
class probabilities, features) at the detected locations.
|
||||
5. Decoding & Formatting (`.decoding`): Converts extracted numerical data and
|
||||
class predictions into interpretable `soundevent` objects, including
|
||||
recovering geometry (ROIs) and decoding class names back to standard tags.
|
||||
|
||||
This module provides the primary interface:
|
||||
- `PostprocessConfig`: A configuration object for postprocessing parameters
|
||||
(thresholds, NMS kernel size, etc.).
|
||||
- `load_postprocess_config`: Function to load the configuration from a file.
|
||||
- `Postprocessor`: The main class (implementing `PostprocessorProtocol`) that
|
||||
holds the configured pipeline logic.
|
||||
- `build_postprocessor`: A factory function to create a `Postprocessor`
|
||||
instance, linking it to the necessary target definitions (`TargetProtocol`).
|
||||
It also re-exports key components from submodules for convenience.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.postprocess.decoding import (
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
convert_raw_predictions_to_clip_prediction,
|
||||
convert_xr_dataset_to_raw_prediction,
|
||||
)
|
||||
from batdetect2.postprocess.detection import (
|
||||
DEFAULT_DETECTION_THRESHOLD,
|
||||
TOP_K_PER_SEC,
|
||||
extract_detections_from_array,
|
||||
get_max_detections,
|
||||
)
|
||||
from batdetect2.postprocess.extraction import (
|
||||
extract_detection_xr_dataset,
|
||||
)
|
||||
from batdetect2.postprocess.nms import (
|
||||
NMS_KERNEL_SIZE,
|
||||
non_max_suppression,
|
||||
)
|
||||
from batdetect2.postprocess.remapping import (
|
||||
classification_to_xarray,
|
||||
detection_to_xarray,
|
||||
features_to_xarray,
|
||||
sizes_to_xarray,
|
||||
)
|
||||
from batdetect2.postprocess.types import PostprocessorProtocol, RawPrediction
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||
"DEFAULT_DETECTION_THRESHOLD",
|
||||
"MAX_FREQ",
|
||||
"MIN_FREQ",
|
||||
"ModelOutput",
|
||||
"NMS_KERNEL_SIZE",
|
||||
"PostprocessConfig",
|
||||
"Postprocessor",
|
||||
"PostprocessorProtocol",
|
||||
"RawPrediction",
|
||||
"TOP_K_PER_SEC",
|
||||
"build_postprocessor",
|
||||
"classification_to_xarray",
|
||||
"convert_raw_predictions_to_clip_prediction",
|
||||
"convert_xr_dataset_to_raw_prediction",
|
||||
"detection_to_xarray",
|
||||
"extract_detection_xr_dataset",
|
||||
"extract_detections_from_array",
|
||||
"features_to_xarray",
|
||||
"get_max_detections",
|
||||
"load_postprocess_config",
|
||||
"non_max_suppression",
|
||||
"sizes_to_xarray",
|
||||
]
|
||||
|
||||
|
||||
class PostprocessConfig(BaseConfig):
|
||||
"""Configuration settings for the postprocessing pipeline.
|
||||
|
||||
Defines tunable parameters that control how raw model outputs are
|
||||
converted into final detections.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
nms_kernel_size : int, default=NMS_KERNEL_SIZE
|
||||
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
|
||||
Used to suppress weaker detections near stronger peaks. Must be
|
||||
positive.
|
||||
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
|
||||
Minimum confidence score from the detection heatmap required to
|
||||
consider a point as a potential detection. Must be >= 0.
|
||||
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
|
||||
Minimum confidence score for a specific class prediction to be included
|
||||
in the decoded tags for a detection. Must be >= 0.
|
||||
top_k_per_sec : int, default=TOP_K_PER_SEC
|
||||
Desired maximum number of detections per second of audio. Used by
|
||||
`get_max_detections` to calculate an absolute limit based on clip
|
||||
duration before applying `extract_detections_from_array`. Must be
|
||||
positive.
|
||||
"""
|
||||
|
||||
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
||||
detection_threshold: float = Field(
|
||||
default=DEFAULT_DETECTION_THRESHOLD,
|
||||
ge=0,
|
||||
)
|
||||
classification_threshold: float = Field(
|
||||
default=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
ge=0,
|
||||
)
|
||||
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
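The same parameters can also be set directly in code; a sketch with arbitrary override values:

config = PostprocessConfig(
    detection_threshold=0.3,   # stricter than the default, purely illustrative
    top_k_per_sec=100,
)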
|
||||
|
||||
|
||||
def load_postprocess_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> PostprocessConfig:
|
||||
"""Load the postprocessing configuration from a file.
|
||||
|
||||
Reads a configuration file (YAML) and validates it against the
|
||||
`PostprocessConfig` schema, potentially extracting data from a nested
|
||||
field.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the configuration file.
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing the
|
||||
postprocessing configuration (e.g., "inference.postprocessing").
|
||||
If None, the entire file content is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PostprocessConfig
|
||||
The loaded and validated postprocessing configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
yaml.YAMLError
|
||||
If the file content is not valid YAML.
|
||||
pydantic.ValidationError
|
||||
If the loaded configuration data does not conform to the
|
||||
`PostprocessConfig` schema.
|
||||
KeyError, TypeError
|
||||
If `field` specifies an invalid path within the loaded data.
|
||||
"""
|
||||
return load_config(path, schema=PostprocessConfig, field=field)
|
||||
|
||||
|
||||
def build_postprocessor(
|
||||
targets: TargetProtocol,
|
||||
config: Optional[PostprocessConfig] = None,
|
||||
max_freq: int = MAX_FREQ,
|
||||
min_freq: int = MIN_FREQ,
|
||||
) -> PostprocessorProtocol:
|
||||
"""Factory function to build the standard postprocessor.
|
||||
|
||||
Creates and initializes the `Postprocessor` instance, providing it with the
|
||||
necessary `targets` object and the `PostprocessConfig`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
An initialized object conforming to the `TargetProtocol`, providing
|
||||
methods like `.decode()` and `.recover_roi()`, and attributes like
|
||||
`.class_names` and `.generic_class_tags`. This links postprocessing
|
||||
to the defined target semantics and geometry mappings.
|
||||
config : PostprocessConfig, optional
|
||||
Configuration object specifying postprocessing parameters (thresholds,
|
||||
NMS kernel size, etc.). If None, default settings defined in
|
||||
`PostprocessConfig` will be used.
|
||||
min_freq : int, default=MIN_FREQ
|
||||
The minimum frequency (Hz) corresponding to the frequency axis of the
|
||||
model outputs. Required for coordinate remapping. Consider setting via
|
||||
`PostprocessConfig` instead for better encapsulation.
|
||||
max_freq : int, default=MAX_FREQ
|
||||
The maximum frequency (Hz) corresponding to the frequency axis of the
|
||||
model outputs. Required for coordinate remapping. Consider setting via
|
||||
`PostprocessConfig`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PostprocessorProtocol
|
||||
An initialized `Postprocessor` instance ready to process model outputs.
|
||||
"""
|
||||
return Postprocessor(
|
||||
targets=targets,
|
||||
config=config or PostprocessConfig(),
|
||||
min_freq=min_freq,
|
||||
max_freq=max_freq,
|
||||
)
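A hedged usage sketch; `targets`, `model_output` and `clips` are assumed to be built elsewhere (from the targets module and a model forward pass) and are not defined here:

postprocessor = build_postprocessor(
    targets,
    config=PostprocessConfig(detection_threshold=0.2),
)
clip_predictions = postprocessor.get_predictions(model_output, clips)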
|
||||
|
||||
|
||||
class Postprocessor(PostprocessorProtocol):
|
||||
"""Standard implementation of the postprocessing pipeline.
|
||||
|
||||
This class orchestrates the steps required to convert raw model outputs
|
||||
into interpretable `soundevent` predictions. It uses configured parameters
|
||||
and leverages functions from the `batdetect2.postprocess` submodules for
|
||||
each stage (NMS, remapping, detection, extraction, decoding).
|
||||
|
||||
It requires a `TargetProtocol` object during initialization to access
|
||||
necessary decoding information (class name to tag mapping,
|
||||
ROI recovery logic) ensuring consistency with the target definitions used
|
||||
during training or specified for inference.
|
||||
|
||||
Instances are typically created using the `build_postprocessor` factory
|
||||
function.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
The configured target definition object providing decoding and ROI
|
||||
recovery.
|
||||
config : PostprocessConfig
|
||||
Configuration object holding parameters for NMS, thresholds, etc.
|
||||
min_freq : int
|
||||
Minimum frequency (Hz) assumed for the model output's frequency axis.
|
||||
max_freq : int
|
||||
Maximum frequency (Hz) assumed for the model output's frequency axis.
|
||||
"""
|
||||
|
||||
targets: TargetProtocol
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
targets: TargetProtocol,
|
||||
config: PostprocessConfig,
|
||||
min_freq: int = MIN_FREQ,
|
||||
max_freq: int = MAX_FREQ,
|
||||
):
|
||||
"""Initialize the Postprocessor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
Initialized target definition object.
|
||||
config : PostprocessConfig
|
||||
Configuration for postprocessing parameters.
|
||||
min_freq : int, default=MIN_FREQ
|
||||
Minimum frequency (Hz) for coordinate remapping.
|
||||
max_freq : int, default=MAX_FREQ
|
||||
Maximum frequency (Hz) for coordinate remapping.
|
||||
"""
|
||||
self.targets = targets
|
||||
self.config = config
|
||||
self.min_freq = min_freq
|
||||
self.max_freq = max_freq
|
||||
|
||||
def get_feature_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Extract and remap raw feature tensors for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.features` tensor for the batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of coordinate-aware feature DataArrays, one per clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.features` and `clips` do not match.
|
||||
"""
|
||||
if len(clips) != len(output.features):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of feature array"
|
||||
"do not match. "
|
||||
f"(clips: {len(clips)}, features: {len(output.features)})"
|
||||
)
|
||||
|
||||
return [
|
||||
features_to_xarray(
|
||||
feats,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for feats, clip in zip(output.features, clips)
|
||||
]
|
||||
|
||||
def get_detection_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Apply NMS and remap detection heatmaps for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.detection_probs` tensor for the
|
||||
batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of NMS-applied, coordinate-aware detection heatmaps, one per
|
||||
clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.detection_probs` and `clips` do not match.
|
||||
"""
|
||||
detections = output.detection_probs
|
||||
|
||||
if len(clips) != len(output.detection_probs):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of detection array "
|
||||
"do not match. "
|
||||
f"(clips: {len(clips)}, detection: {len(detections)})"
|
||||
)
|
||||
|
||||
detections = non_max_suppression(
|
||||
detections,
|
||||
kernel_size=self.config.nms_kernel_size,
|
||||
)
|
||||
|
||||
return [
|
||||
detection_to_xarray(
|
||||
dets,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for dets, clip in zip(detections, clips)
|
||||
]
|
||||
|
||||
def get_classification_arrays(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[xr.DataArray]:
|
||||
"""Extract and remap raw classification tensors for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.class_probs` tensor for the
|
||||
batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of coordinate-aware class probability maps, one per clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.class_probs` and `clips` do not match, or
|
||||
if number of classes mismatches `self.targets.class_names`.
|
||||
"""
|
||||
classifications = output.class_probs
|
||||
|
||||
if len(clips) != len(classifications):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of classification array "
|
||||
"do not match. "
|
||||
f"(clips: {len(clips)}, classification: {len(classifications)})"
|
||||
)
|
||||
|
||||
return [
|
||||
classification_to_xarray(
|
||||
class_probs,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
class_names=self.targets.class_names,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for class_probs, clip in zip(classifications, clips)
|
||||
]
|
||||
|
||||
def get_sizes_arrays(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[xr.DataArray]:
|
||||
"""Extract and remap raw size prediction tensors for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.size_preds` tensor for the
|
||||
batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of coordinate-aware size prediction maps, one per clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.size_preds` and `clips` do not match.
|
||||
"""
|
||||
sizes = output.size_preds
|
||||
|
||||
if len(clips) != len(sizes):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of sizes array do not match. "
|
||||
f"(clips: {len(clips)}, sizes: {len(sizes)})"
|
||||
)
|
||||
|
||||
return [
|
||||
sizes_to_xarray(
|
||||
size_preds,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for size_preds, clip in zip(sizes, clips)
|
||||
]
|
||||
|
||||
def get_detection_datasets(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[xr.Dataset]:
|
||||
"""Perform NMS, remapping, detection, and data extraction for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
List of `soundevent.data.Clip` objects corresponding to the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.Dataset]
|
||||
List of xarray Datasets (one per clip). Each Dataset contains
|
||||
aligned scores, dimensions, class probabilities, and features for
|
||||
detections found in that clip.
|
||||
"""
|
||||
detection_arrays = self.get_detection_arrays(output, clips)
|
||||
classification_arrays = self.get_classification_arrays(output, clips)
|
||||
size_arrays = self.get_sizes_arrays(output, clips)
|
||||
features_arrays = self.get_feature_arrays(output, clips)
|
||||
|
||||
datasets = []
|
||||
for det_array, class_array, sizes_array, feats_array in zip(
|
||||
detection_arrays,
|
||||
classification_arrays,
|
||||
size_arrays,
|
||||
features_arrays,
|
||||
):
|
||||
max_detections = get_max_detections(
|
||||
det_array,
|
||||
top_k_per_sec=self.config.top_k_per_sec,
|
||||
)
|
||||
|
||||
positions = extract_detections_from_array(
|
||||
det_array,
|
||||
max_detections=max_detections,
|
||||
threshold=self.config.detection_threshold,
|
||||
)
|
||||
|
||||
datasets.append(
|
||||
extract_detection_xr_dataset(
|
||||
positions,
|
||||
sizes_array,
|
||||
class_array,
|
||||
feats_array,
|
||||
)
|
||||
)
|
||||
|
||||
return datasets
|
||||
|
||||
def get_raw_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[List[RawPrediction]]:
|
||||
"""Extract intermediate RawPrediction objects for a batch.
|
||||
|
||||
Processes raw model output through remapping, NMS, detection, data
|
||||
extraction, and geometry recovery via the configured
|
||||
`targets.recover_roi`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
List of `soundevent.data.Clip` objects corresponding to the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[List[RawPrediction]]
|
||||
List of lists (one inner list per input clip). Each inner list
|
||||
contains `RawPrediction` objects for detections in that clip.
|
||||
"""
|
||||
detection_datasets = self.get_detection_datasets(output, clips)
|
||||
return [
|
||||
convert_xr_dataset_to_raw_prediction(
|
||||
dataset,
|
||||
self.targets.recover_roi,
|
||||
)
|
||||
for dataset in detection_datasets
|
||||
]
|
||||
|
||||
def get_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[data.ClipPrediction]:
|
||||
"""Perform the full postprocessing pipeline for a batch.
|
||||
|
||||
Takes raw model output and corresponding clips, applies the entire
|
||||
configured chain (NMS, remapping, extraction, geometry recovery, class
|
||||
decoding), producing final `soundevent.data.ClipPrediction` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
List of `soundevent.data.Clip` objects corresponding to the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.ClipPrediction]
|
||||
List containing one `ClipPrediction` object for each input clip,
|
||||
populated with `SoundEventPrediction` objects.
|
||||
"""
|
||||
raw_predictions = self.get_raw_predictions(output, clips)
|
||||
return [
|
||||
convert_raw_predictions_to_clip_prediction(
|
||||
prediction,
|
||||
clip,
|
||||
sound_event_decoder=self.targets.decode,
|
||||
generic_class_tags=self.targets.generic_class_tags,
|
||||
classification_threshold=self.config.classification_threshold,
|
||||
)
|
||||
for prediction, clip in zip(raw_predictions, clips)
|
||||
]
|
409
batdetect2/postprocess/decoding.py
Normal file
@@ -0,0 +1,409 @@
|
||||
"""Decodes extracted detection data into standard soundevent predictions.
|
||||
|
||||
This module handles the final stages of the BatDetect2 postprocessing pipeline.
|
||||
It takes the structured detection data extracted by the `extraction` module
|
||||
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
|
||||
class probabilities, and features for each detection point) and converts it
|
||||
into meaningful, standardized prediction objects based on the `soundevent` data
|
||||
model.
|
||||
|
||||
The process involves:
|
||||
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
|
||||
objects, using a configured geometry builder to recover bounding boxes from
|
||||
predicted positions and sizes (`convert_xr_dataset_to_raw_prediction`).
|
||||
2. Converting each `RawPrediction` into a
|
||||
`soundevent.data.SoundEventPrediction`, which involves:
|
||||
- Creating the `soundevent.data.SoundEvent` with geometry and features.
|
||||
- Decoding the predicted class probabilities into representative tags using
|
||||
a configured class decoder (`SoundEventDecoder`).
|
||||
- Applying a classification threshold.
|
||||
- Optionally selecting only the single highest-scoring class (top-1) or
|
||||
including tags for all classes above the threshold (multi-label).
|
||||
- Adding generic class tags as a baseline.
|
||||
- Associating scores with the final prediction and tags.
|
||||
(`convert_raw_prediction_to_sound_event_prediction`)
|
||||
3. Grouping the `SoundEventPrediction` objects for a given audio segment into
|
||||
a `soundevent.data.ClipPrediction`
|
||||
(`convert_raw_predictions_to_clip_prediction`).
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
|
||||
from batdetect2.targets.classes import SoundEventDecoder
|
||||
|
||||
__all__ = [
|
||||
"convert_xr_dataset_to_raw_prediction",
|
||||
"convert_raw_predictions_to_clip_prediction",
|
||||
"convert_raw_prediction_to_sound_event_prediction",
|
||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||
]
|
||||
|
||||
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
|
||||
"""Default threshold applied to classification scores.
|
||||
|
||||
Class predictions with scores below this value are typically ignored during
|
||||
decoding.
|
||||
"""
|
||||
|
||||
|
||||
def convert_xr_dataset_to_raw_prediction(
|
||||
detection_dataset: xr.Dataset,
|
||||
geometry_builder: GeometryBuilder,
|
||||
) -> List[RawPrediction]:
|
||||
"""Convert an xarray.Dataset of detections to RawPrediction objects.
|
||||
|
||||
Takes the output of the extraction step (`extract_detection_xr_dataset`)
|
||||
and transforms each detection entry into an intermediate `RawPrediction`
|
||||
object. This involves recovering the geometry (e.g., bounding box) from
|
||||
the predicted position and scaled size dimensions using the provided
|
||||
`geometry_builder` function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_dataset : xr.Dataset
|
||||
An xarray Dataset containing aligned detection information, typically
|
||||
output by `extract_detection_xr_dataset`. Expected variables include
|
||||
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
|
||||
Must have a 'detection' dimension.
|
||||
geometry_builder : GeometryBuilder
|
||||
A function that takes a position tuple `(time, freq)` and a NumPy array
|
||||
of dimensions, and returns the corresponding reconstructed
|
||||
`soundevent.data.Geometry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[RawPrediction]
|
||||
A list of `RawPrediction` objects, each containing the detection score,
|
||||
recovered bounding box coordinates (start/end time, low/high freq),
|
||||
the vector of class scores, and the feature vector for one detection.
|
||||
|
||||
Raises
|
||||
------
|
||||
AttributeError, KeyError, ValueError
|
||||
If `detection_dataset` is missing expected variables ('scores',
|
||||
'dimensions', 'classes', 'features') or coordinates ('time', 'freq'
|
||||
associated with 'scores'), or if `geometry_builder` fails.
|
||||
"""
|
||||
detections = []
|
||||
|
||||
for det_num in range(detection_dataset.sizes["detection"]):
|
||||
det_info = detection_dataset.sel(detection=det_num)
|
||||
|
||||
geom = geometry_builder(
|
||||
(det_info.time, det_info.freq),
|
||||
det_info.dimensions,
|
||||
)
|
||||
|
||||
start_time, low_freq, end_time, high_freq = compute_bounds(geom)
|
||||
detections.append(
|
||||
RawPrediction(
|
||||
detection_score=det_info.score,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
low_freq=low_freq,
|
||||
high_freq=high_freq,
|
||||
class_scores=det_info.classes,
|
||||
features=det_info.features,
|
||||
)
|
||||
)
|
||||
|
||||
return detections
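A minimal `geometry_builder` sketch matching the signature described above. Whether the position marks a corner or the centre of the box depends on the ROI mapping configured in the targets module, so the corner interpretation here is an assumption:

import numpy as np
from soundevent import data

def corner_geometry_builder(position, dimensions) -> data.BoundingBox:
    """Hypothetical builder: position taken as the (start_time, low_freq) corner."""
    time, freq = position
    width, height = np.asarray(dimensions, dtype=float)
    return data.BoundingBox(
        coordinates=[time, freq, time + width, freq + height]
    )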
|
||||
|
||||
|
||||
def convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions: List[RawPrediction],
|
||||
clip: data.Clip,
|
||||
sound_event_decoder: SoundEventDecoder,
|
||||
generic_class_tags: List[data.Tag],
|
||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only: bool = False,
|
||||
) -> data.ClipPrediction:
|
||||
"""Convert a list of RawPredictions into a soundevent ClipPrediction.
|
||||
|
||||
Iterates through `raw_predictions` (assumed to belong to a single clip),
|
||||
converts each one into a `soundevent.data.SoundEventPrediction` using
|
||||
`convert_raw_prediction_to_sound_event_prediction`, and packages them
|
||||
into a `soundevent.data.ClipPrediction` associated with the original `clip`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw_predictions : List[RawPrediction]
|
||||
List of raw prediction objects for a single clip.
|
||||
clip : data.Clip
|
||||
The original `soundevent.data.Clip` object these predictions belong to.
|
||||
sound_event_decoder : SoundEventDecoder
|
||||
Function to decode class names into representative tags.
|
||||
generic_class_tags : List[data.Tag]
|
||||
List of tags representing the generic class category.
|
||||
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
|
||||
Threshold applied to class scores during decoding.
|
||||
top_class_only : bool, default=False
|
||||
If True, only decode tags for the single highest-scoring class above
|
||||
the threshold. If False, decode tags for all classes above threshold.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.ClipPrediction
|
||||
A `ClipPrediction` object containing a list of `SoundEventPrediction`
|
||||
objects corresponding to the input `raw_predictions`.
|
||||
"""
|
||||
return data.ClipPrediction(
|
||||
clip=clip,
|
||||
sound_events=[
|
||||
convert_raw_prediction_to_sound_event_prediction(
|
||||
prediction,
|
||||
recording=clip.recording,
|
||||
sound_event_decoder=sound_event_decoder,
|
||||
generic_class_tags=generic_class_tags,
|
||||
classification_threshold=classification_threshold,
|
||||
top_class_only=top_class_only,
|
||||
)
|
||||
for prediction in raw_predictions
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction: RawPrediction,
|
||||
recording: data.Recording,
|
||||
sound_event_decoder: SoundEventDecoder,
|
||||
generic_class_tags: List[data.Tag],
|
||||
classification_threshold: Optional[
|
||||
float
|
||||
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only: bool = False,
|
||||
) -> data.SoundEventPrediction:
|
||||
"""Convert a single RawPrediction into a soundevent SoundEventPrediction.
|
||||
|
||||
This function performs the core decoding steps for a single detected event:
|
||||
1. Creates a `soundevent.data.SoundEvent` containing the geometry
|
||||
(BoundingBox derived from `raw_prediction` bounds) and any associated
|
||||
feature vectors.
|
||||
2. Initializes a list of predicted tags using the provided
|
||||
`generic_class_tags`, assigning the overall `detection_score` from the
|
||||
`raw_prediction` to these generic tags.
|
||||
3. Processes the `class_scores` from the `raw_prediction`:
|
||||
a. Optionally filters out scores below `classification_threshold`
|
||||
(if it's not None).
|
||||
b. Sorts the remaining scores in descending order.
|
||||
c. Iterates through the sorted, thresholded class scores.
|
||||
d. For each class, uses the `sound_event_decoder` to get the
|
||||
representative base tags for that class name.
|
||||
e. Wraps these base tags in `soundevent.data.PredictedTag`, associating
|
||||
the specific `score` of that class prediction.
|
||||
f. Appends these specific predicted tags to the list.
|
||||
g. If `top_class_only` is True, stops after processing the first
|
||||
(highest-scoring) class that passed the threshold.
|
||||
4. Creates and returns the final `soundevent.data.SoundEventPrediction`,
|
||||
associating the `SoundEvent`, the overall `detection_score`, and the
|
||||
compiled list of `PredictedTag` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw_prediction : RawPrediction
|
||||
The raw prediction object containing score, bounds, class scores,
|
||||
features. Assumes `class_scores` is an `xr.DataArray` with a 'category'
|
||||
coordinate. Assumes `features` is an `xr.DataArray` with a 'feature'
|
||||
coordinate.
|
||||
recording : data.Recording
|
||||
The recording the sound event belongs to.
|
||||
sound_event_decoder : SoundEventDecoder
|
||||
Configured function mapping class names (str) to lists of base
|
||||
`data.Tag` objects.
|
||||
generic_class_tags : List[data.Tag]
|
||||
List of base tags representing the generic category.
|
||||
classification_threshold : float, optional
|
||||
The minimum score a class prediction must have to be considered
|
||||
significant enough to have its tags decoded and added. If None, no
|
||||
thresholding is applied based on class score (all predicted classes,
|
||||
or the top one if `top_class_only` is True, will be processed).
|
||||
Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`.
|
||||
top_class_only : bool, default=False
|
||||
If True, only includes tags for the single highest-scoring class that
|
||||
exceeds the threshold. If False (default), includes tags for all classes
|
||||
exceeding the threshold.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.SoundEventPrediction
|
||||
The fully formed sound event prediction object.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `raw_prediction.features` has unexpected structure or if
|
||||
`data.term_from_key` (if used internally) fails.
|
||||
If `sound_event_decoder` fails for a class name and errors are raised.
|
||||
"""
|
||||
sound_event = data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=[
|
||||
raw_prediction.start_time,
|
||||
raw_prediction.low_freq,
|
||||
raw_prediction.end_time,
|
||||
raw_prediction.high_freq,
|
||||
]
|
||||
),
|
||||
features=get_prediction_features(raw_prediction.features),
|
||||
)
|
||||
|
||||
tags = [
|
||||
*get_generic_tags(
|
||||
raw_prediction.detection_score,
|
||||
generic_class_tags=generic_class_tags,
|
||||
),
|
||||
*get_class_tags(
|
||||
raw_prediction.class_scores,
|
||||
sound_event_decoder,
|
||||
top_class_only=top_class_only,
|
||||
threshold=classification_threshold,
|
||||
),
|
||||
]
|
||||
|
||||
return data.SoundEventPrediction(
|
||||
sound_event=sound_event,
|
||||
score=raw_prediction.detection_score,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
|
||||
def get_generic_tags(
|
||||
detection_score: float,
|
||||
generic_class_tags: List[data.Tag],
|
||||
) -> List[data.PredictedTag]:
|
||||
"""Create PredictedTag objects for the generic category.
|
||||
|
||||
Takes the base list of generic tags and assigns the overall detection
|
||||
score to each one, wrapping them in `PredictedTag` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_score : float
|
||||
The overall confidence score of the detection event.
|
||||
generic_class_tags : List[data.Tag]
|
||||
The list of base `soundevent.data.Tag` objects that define the
|
||||
generic category (e.g., ['call_type:Echolocation', 'order:Chiroptera']).
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.PredictedTag]
|
||||
A list of `PredictedTag` objects for the generic category, each
|
||||
assigned the `detection_score`.
|
||||
"""
|
||||
return [
|
||||
data.PredictedTag(tag=tag, score=detection_score)
|
||||
for tag in generic_class_tags
|
||||
]
|
||||
|
||||
|
||||
def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
|
||||
"""Convert an extracted feature vector DataArray into soundevent Features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : xr.DataArray
|
||||
A 1D xarray DataArray containing feature values, indexed by a coordinate
|
||||
named 'feature' which holds the feature names (e.g., output of selecting
|
||||
features for one detection from `extract_detection_xr_dataset`).
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.Feature]
|
||||
A list of `soundevent.data.Feature` objects.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- This function creates basic `Term` objects using the feature coordinate
|
||||
names with a "batdetect2:" prefix.
|
||||
"""
|
||||
return [
|
||||
data.Feature(
|
||||
term=data.Term(
|
||||
name=f"batdetect2:{feat_name}",
|
||||
label=feat_name,
|
||||
definition="Automatically extracted features by BatDetect2",
|
||||
),
|
||||
value=value,
|
||||
)
|
||||
for feat_name, value in _iterate_over_array(features)
|
||||
]
|
||||
|
||||
|
||||
def get_class_tags(
|
||||
class_scores: xr.DataArray,
|
||||
sound_event_decoder: SoundEventDecoder,
|
||||
top_class_only: bool = False,
|
||||
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
) -> List[data.PredictedTag]:
|
||||
"""Generate specific PredictedTags based on class scores and decoder.
|
||||
|
||||
Filters class scores by the threshold, sorts remaining scores descending,
|
||||
decodes the class name(s) into base tags using the `sound_event_decoder`,
|
||||
and creates `PredictedTag` objects associating the class score. Stops after
|
||||
the first (top) class if `top_class_only` is True.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
class_scores : xr.DataArray
|
||||
A 1D xarray DataArray containing class probabilities/scores, indexed
|
||||
by a 'category' coordinate holding the class names.
|
||||
sound_event_decoder : SoundEventDecoder
|
||||
Function to map a class name string to a list of base `data.Tag`
|
||||
objects.
|
||||
top_class_only : bool, default=False
|
||||
If True, only generate tags for the single highest-scoring class above
|
||||
the threshold.
|
||||
threshold : float, optional
|
||||
Minimum score for a class to be considered. If None, all classes are
|
||||
processed (or top-1 if `top_class_only` is True). Defaults to
|
||||
`DEFAULT_CLASSIFICATION_THRESHOLD`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.PredictedTag]
|
||||
A list of `PredictedTag` objects for the class(es) that passed the
|
||||
threshold, ordered by score if `top_class_only` is False.
|
||||
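Examples
--------
A minimal sketch; the decoder below is an illustrative stand-in for a
configured ``SoundEventDecoder`` (it maps each class name to a single tag
built with ``soundevent``'s ``term_from_key`` helper):

>>> import xarray as xr
>>> from soundevent import data
>>> scores = xr.DataArray(
...     [0.05, 0.6, 0.3],
...     coords={"category": ["noise", "myotis", "pipistrellus"]},
...     dims="category",
... )
>>> def decoder(class_name):
...     return [data.Tag(term=data.term_from_key("class"), value=class_name)]
>>> tags = get_class_tags(scores, decoder, threshold=0.1)
>>> len(tags)
2
>>> round(tags[0].score, 2)
0.6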
"""
|
||||
tags = []
|
||||
|
||||
if threshold is not None:
|
||||
class_scores = class_scores.where(class_scores > threshold, drop=True)
|
||||
|
||||
for class_name, score in _iterate_sorted(class_scores):
|
||||
class_tags = sound_event_decoder(class_name)
|
||||
|
||||
for tag in class_tags:
|
||||
tags.append(
|
||||
data.PredictedTag(
|
||||
tag=tag,
|
||||
score=score,
|
||||
)
|
||||
)
|
||||
|
||||
if top_class_only:
|
||||
break
|
||||
|
||||
return tags
|
||||
|
||||
|
||||
def _iterate_over_array(array: xr.DataArray):
|
||||
dim_name = array.dims[0]
|
||||
coords = array.coords[dim_name]
|
||||
for value, coord in zip(array.values, coords.values):
|
||||
yield coord, float(value)
|
||||
|
||||
|
||||
def _iterate_sorted(array: xr.DataArray):
|
||||
dim_name = array.dims[0]
|
||||
coords = array.coords[dim_name].values
|
||||
indices = np.argsort(-array.values)
|
||||
for index in indices:
|
||||
yield str(coords[index]), float(array.values[index])
|
batdetect2/postprocess/detection.py (new file, 162 lines)
@@ -0,0 +1,162 @@
|
||||
"""Extracts candidate detection points from a model output heatmap.
|
||||
|
||||
This module implements a specific step within the BatDetect2 postprocessing
|
||||
pipeline. Its primary function is to identify potential sound event locations
|
||||
by finding peaks (local maxima or high-scoring points) in the detection heatmap
|
||||
produced by the neural network (usually after Non-Maximum Suppression and
|
||||
coordinate remapping have been applied).
|
||||
|
||||
It provides functionality to:
|
||||
- Identify the locations (time, frequency) of the highest-scoring points.
|
||||
- Filter these points based on a minimum confidence score threshold.
|
||||
- Limit the maximum number of detection points returned (top-k).
|
||||
|
||||
The main output is an `xarray.DataArray` containing the scores and
|
||||
corresponding time/frequency coordinates for the extracted detection points.
|
||||
This output serves as the input for subsequent postprocessing steps, such as
|
||||
extracting predicted class probabilities and bounding box sizes at these
|
||||
specific locations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions, get_dim_width
|
||||
|
||||
__all__ = [
|
||||
"extract_detections_from_array",
|
||||
"get_max_detections",
|
||||
"DEFAULT_DETECTION_THRESHOLD",
|
||||
"TOP_K_PER_SEC",
|
||||
]
|
||||
|
||||
DEFAULT_DETECTION_THRESHOLD = 0.01
|
||||
"""Default confidence score threshold used for filtering detections."""
|
||||
|
||||
TOP_K_PER_SEC = 200
|
||||
"""Default desired maximum number of detections per second of audio."""
|
||||
|
||||
|
||||
def extract_detections_from_array(
|
||||
detection_array: xr.DataArray,
|
||||
max_detections: Optional[int] = None,
|
||||
threshold: Optional[float] = DEFAULT_DETECTION_THRESHOLD,
|
||||
) -> xr.DataArray:
|
||||
"""Extract detection locations (time, freq) and scores from a heatmap.
|
||||
|
||||
Identifies the pixels with the highest scores in the input detection
|
||||
heatmap, filters them based on an optional score `threshold`, limits the
|
||||
number to an optional `max_detections`, and returns their scores along with
|
||||
their corresponding time and frequency coordinates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_array : xr.DataArray
|
||||
A 2D xarray DataArray representing the detection heatmap. Must have
|
||||
dimensions and coordinates named 'time' and 'frequency'. Higher values
|
||||
are assumed to indicate higher detection confidence.
|
||||
max_detections : int, optional
|
||||
The absolute maximum number of detections to return. If specified, only
|
||||
the top `max_detections` highest-scoring detections (passing the
|
||||
threshold) are returned. If None (default), all detections passing
|
||||
the threshold are returned, sorted by score.
|
||||
threshold : float, optional
|
||||
The minimum confidence score required for a detection peak to be
|
||||
kept. Detections with scores below this value are discarded.
|
||||
Defaults to `DEFAULT_DETECTION_THRESHOLD`. If set to None, no
|
||||
thresholding is applied.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
A 1D xarray DataArray named 'score' with a 'detection' dimension.
|
||||
- The data values are the scores of the extracted detections, sorted
|
||||
in descending order.
|
||||
- It has coordinates 'time' and 'frequency' (also indexed by the
|
||||
'detection' dimension) indicating the location of each detection
|
||||
peak in the original coordinate system.
|
||||
- Returns an empty DataArray if no detections pass the criteria.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `max_detections` is not None and not a positive integer, or if
|
||||
`detection_array` lacks required dimensions/coordinates.
|
||||
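Examples
--------
A minimal sketch using a tiny 2x2 heatmap (values illustrative):

>>> import numpy as np
>>> import xarray as xr
>>> heatmap = xr.DataArray(
...     np.array([[0.9, 0.0], [0.2, 0.6]]),
...     coords={"frequency": [20_000.0, 40_000.0], "time": [0.0, 0.5]},
...     dims=["frequency", "time"],
... )
>>> scores = extract_detections_from_array(heatmap, threshold=0.1)
>>> scores.values.tolist()
[0.9, 0.6, 0.2]
>>> scores.coords["time"].values.tolist()
[0.0, 0.5, 0.0]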
"""
|
||||
if max_detections is not None:
|
||||
if max_detections <= 0:
|
||||
raise ValueError("Max detections must be positive")
|
||||
|
||||
values = detection_array.values.flatten()
|
||||
|
||||
if max_detections is not None:
|
||||
top_indices = np.argpartition(-values, max_detections)[:max_detections]
|
||||
top_sorted_indices = top_indices[np.argsort(-values[top_indices])]
|
||||
else:
|
||||
top_sorted_indices = np.argsort(-values)
|
||||
|
||||
top_values = values[top_sorted_indices]
|
||||
|
||||
if threshold is not None:
|
||||
mask = top_values > threshold
|
||||
top_values = top_values[mask]
|
||||
top_sorted_indices = top_sorted_indices[mask]
|
||||
|
||||
freq_indices, time_indices = np.unravel_index(
|
||||
top_sorted_indices,
|
||||
detection_array.shape,
|
||||
)
|
||||
|
||||
times = detection_array.coords[Dimensions.time.value].values[time_indices]
|
||||
freqs = detection_array.coords[Dimensions.frequency.value].values[
|
||||
freq_indices
|
||||
]
|
||||
|
||||
return xr.DataArray(
|
||||
data=top_values,
|
||||
coords={
|
||||
Dimensions.frequency.value: ("detection", freqs),
|
||||
Dimensions.time.value: ("detection", times),
|
||||
},
|
||||
dims="detection",
|
||||
name="score",
|
||||
)
|
||||
|
||||
|
||||
def get_max_detections(
|
||||
detection_array: xr.DataArray,
|
||||
top_k_per_sec: int = TOP_K_PER_SEC,
|
||||
) -> int:
|
||||
"""Calculate max detections allowed based on duration and rate.
|
||||
|
||||
Determines the total maximum number of detections to extract from a
|
||||
heatmap based on its time duration and a desired rate of detections
|
||||
per second.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_array : xr.DataArray
|
||||
The detection heatmap, requiring 'time' coordinates from which the
|
||||
total duration can be calculated using
|
||||
`soundevent.arrays.get_dim_width`.
|
||||
top_k_per_sec : int, default=TOP_K_PER_SEC
|
||||
The desired maximum number of detections to allow per second of audio.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The calculated total maximum number of detections allowed for the
|
||||
entire duration of the `detection_array`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the duration cannot be calculated from the `detection_array` (e.g.,
|
||||
missing or invalid 'time' coordinates/dimension).
|
||||
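Examples
--------
A usage sketch (``heatmap`` is a placeholder for any detection array with
real time coordinates); the returned cap is typically passed straight to
`extract_detections_from_array`:

>>> max_det = get_max_detections(heatmap, top_k_per_sec=200)
>>> detections = extract_detections_from_array(heatmap, max_detections=max_det)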
"""
|
||||
if top_k_per_sec < 0:
|
||||
raise ValueError("top_k_per_sec cannot be negative.")
|
||||
|
||||
duration = get_dim_width(detection_array, Dimensions.time.value)
|
||||
return int(duration * top_k_per_sec)
|
batdetect2/postprocess/extraction.py (new file, 122 lines)
@@ -0,0 +1,122 @@
|
||||
"""Extracts associated data for detected points from model output arrays.
|
||||
|
||||
This module implements a key step (Step 4) in the BatDetect2 postprocessing
|
||||
pipeline. After candidate detection points (time, frequency, score) have been
|
||||
identified, this module extracts the corresponding values from other raw model
|
||||
output arrays, such as:
|
||||
|
||||
- Predicted bounding box sizes (width, height).
|
||||
- Class probability scores for each defined target class.
|
||||
- Intermediate feature vectors.
|
||||
|
||||
It uses coordinate-based indexing provided by `xarray` to ensure that the
|
||||
correct values are retrieved from the original heatmaps/feature maps at the
|
||||
precise time-frequency location of each detection. The final output aggregates
|
||||
all extracted information into a structured `xarray.Dataset`.
|
||||
"""
|
||||
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
|
||||
__all__ = [
|
||||
"extract_values_at_positions",
|
||||
"extract_detection_xr_dataset",
|
||||
]
|
||||
|
||||
|
||||
def extract_values_at_positions(
|
||||
array: xr.DataArray,
|
||||
positions: xr.DataArray,
|
||||
) -> xr.DataArray:
|
||||
"""Extract values from an array at specified time-frequency positions.
|
||||
|
||||
Uses coordinate-based indexing to retrieve values from a source `array`
|
||||
(e.g., class probabilities, size predictions, features) at the time and
|
||||
frequency coordinates defined in the `positions` array.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
array : xr.DataArray
|
||||
The source DataArray from which to extract values. Must have 'time'
|
||||
and 'frequency' dimensions and coordinates matching the space of
|
||||
`positions`.
|
||||
positions : xr.DataArray
|
||||
A 1D DataArray whose 'time' and 'frequency' coordinates specify the
|
||||
locations from which to extract values.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
A DataArray containing the values extracted from `array` at the given
|
||||
positions.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError, IndexError, KeyError
|
||||
If dimensions or coordinates are missing or incompatible between
|
||||
`array` and `positions`, or if selection fails.
|
||||
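Examples
--------
A minimal sketch; the positions array mirrors the output of
`extract_detections_from_array` (values illustrative):

>>> import numpy as np
>>> import xarray as xr
>>> classes = xr.DataArray(
...     np.random.rand(3, 2, 2),
...     coords={
...         "category": ["a", "b", "c"],
...         "frequency": [20_000.0, 40_000.0],
...         "time": [0.0, 0.5],
...     },
...     dims=["category", "frequency", "time"],
... )
>>> positions = xr.DataArray(
...     [0.9, 0.6],
...     coords={
...         "frequency": ("detection", [40_000.0, 20_000.0]),
...         "time": ("detection", [0.5, 0.0]),
...     },
...     dims="detection",
...     name="score",
... )
>>> per_detection = extract_values_at_positions(classes, positions)
>>> sorted(per_detection.dims)
['category', 'detection']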
"""
|
||||
return array.sel(
|
||||
**{
|
||||
Dimensions.frequency.value: positions.coords[
|
||||
Dimensions.frequency.value
|
||||
],
|
||||
Dimensions.time.value: positions.coords[Dimensions.time.value],
|
||||
}
|
||||
).T
|
||||
|
||||
|
||||
def extract_detection_xr_dataset(
|
||||
positions: xr.DataArray,
|
||||
sizes: xr.DataArray,
|
||||
classes: xr.DataArray,
|
||||
features: xr.DataArray,
|
||||
) -> xr.Dataset:
|
||||
"""Combine extracted detection information into a structured xr.Dataset.
|
||||
|
||||
Takes the detection positions/scores and the full model output heatmaps
|
||||
(sizes, classes, optional features), extracts the relevant data at the
|
||||
detection positions, and packages everything into a single `xarray.Dataset`
|
||||
where all variables are indexed by a common 'detection' dimension.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
positions : xr.DataArray
|
||||
Output from `extract_detections_from_array`, containing detection
|
||||
scores as data and 'time', 'frequency' coordinates along the
|
||||
'detection' dimension.
|
||||
sizes : xr.DataArray
|
||||
The full size prediction heatmap from the model, with dimensions like
|
||||
('dimension', 'time', 'frequency').
|
||||
classes : xr.DataArray
|
||||
The full class probability heatmap from the model, with dimensions like
|
||||
('category', 'time', 'frequency').
|
||||
features : xr.DataArray
|
||||
The full feature map from the model, with
|
||||
dimensions like ('feature', 'time', 'frequency').
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.Dataset
|
||||
An xarray Dataset containing aligned information for each detection:
|
||||
- 'scores': DataArray from `positions` (score data, time/freq coords).
|
||||
- 'dimensions': DataArray with extracted size values
|
||||
(dims: 'detection', 'dimension').
|
||||
- 'classes': DataArray with extracted class probabilities
|
||||
(dims: 'detection', 'category').
|
||||
- 'features': DataArray with extracted feature vectors
|
||||
(dims: 'detection', 'feature'), if `features` was provided. All
|
||||
DataArrays share the 'detection' dimension and associated
|
||||
time/frequency coordinates.
|
||||
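Examples
--------
A minimal end-to-end sketch with tiny random arrays (shapes and names
follow the conventions documented above):

>>> import numpy as np
>>> import xarray as xr
>>> from batdetect2.postprocess.detection import extract_detections_from_array
>>> coords = {"frequency": [20_000.0, 40_000.0], "time": [0.0, 0.5]}
>>> heatmap = xr.DataArray(
...     np.array([[0.9, 0.0], [0.2, 0.6]]),
...     coords=coords, dims=["frequency", "time"],
... )
>>> sizes = xr.DataArray(
...     np.random.rand(2, 2, 2),
...     coords={"dimension": ["width", "height"], **coords},
...     dims=["dimension", "frequency", "time"],
... )
>>> classes = xr.DataArray(
...     np.random.rand(3, 2, 2),
...     coords={"category": ["a", "b", "c"], **coords},
...     dims=["category", "frequency", "time"],
... )
>>> features = xr.DataArray(
...     np.random.rand(4, 2, 2),
...     coords={"feature": [f"f{i}" for i in range(4)], **coords},
...     dims=["feature", "frequency", "time"],
... )
>>> positions = extract_detections_from_array(heatmap, threshold=0.1)
>>> dataset = extract_detection_xr_dataset(positions, sizes, classes, features)
>>> sorted(dataset.data_vars)
['classes', 'dimensions', 'features', 'scores']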
"""
|
||||
sizes = extract_values_at_positions(sizes, positions)
|
||||
classes = extract_values_at_positions(classes, positions)
|
||||
features = extract_values_at_positions(features, positions)
|
||||
return xr.Dataset(
|
||||
{
|
||||
"scores": positions,
|
||||
"dimensions": sizes,
|
||||
"classes": classes,
|
||||
"features": features,
|
||||
}
|
||||
)
|
batdetect2/postprocess/nms.py (new file, 96 lines)
@@ -0,0 +1,96 @@
|
||||
"""Performs Non-Maximum Suppression (NMS) on detection heatmaps.
|
||||
|
||||
This module provides functionality to apply Non-Maximum Suppression, a common
|
||||
technique used after model inference, particularly in object detection and peak
|
||||
detection tasks.
|
||||
|
||||
In the context of BatDetect2 postprocessing, NMS is applied
|
||||
to the raw detection heatmap output by the neural network. Its purpose is to
|
||||
isolate distinct detection peaks by suppressing (setting to zero) nearby heatmap
|
||||
activations that have lower scores than a local maximum. This helps prevent
|
||||
multiple, overlapping detections originating from the same sound event.
|
||||
"""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
NMS_KERNEL_SIZE = 9
|
||||
"""Default kernel size (pixels) for Non-Maximum Suppression.
|
||||
|
||||
Specifies the side length of the square neighborhood used by default in
|
||||
`non_max_suppression` to find local maxima. A 9x9 neighborhood is often
|
||||
a reasonable starting point for typical spectrogram resolutions used in
|
||||
BatDetect2.
|
||||
"""
|
||||
|
||||
|
||||
def non_max_suppression(
|
||||
tensor: torch.Tensor,
|
||||
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
||||
) -> torch.Tensor:
|
||||
"""Apply Non-Maximum Suppression (NMS) to a tensor, typically a heatmap.
|
||||
|
||||
This function identifies local maxima within a defined neighborhood for
|
||||
each point in the input tensor. Values that are *not* the maximum within
|
||||
their neighborhood are suppressed (set to zero). This is commonly used on
|
||||
detection probability heatmaps to isolate distinct peaks corresponding to
|
||||
individual detections and remove redundant lower scores nearby.
|
||||
|
||||
The implementation uses efficient 2D max pooling to find the maximum value
|
||||
in the neighborhood of each point.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tensor : torch.Tensor
|
||||
Input tensor, typically representing a detection heatmap. Must be a
|
||||
3D (C, H, W) or 4D (N, C, H, W) tensor as required by the underlying
|
||||
`torch.nn.functional.max_pool2d` operation.
|
||||
kernel_size : Union[int, Tuple[int, int]], default=NMS_KERNEL_SIZE
|
||||
Size of the sliding window neighborhood used to find local maxima.
|
||||
If an integer `k` is provided, a square kernel of size `(k, k)` is used.
|
||||
If a tuple `(h, w)` is provided, a rectangular kernel of height `h`
|
||||
and width `w` is used. The kernel size should typically be odd to
|
||||
have a well-defined center.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
A tensor of the same shape as the input, where only local maxima within
|
||||
their respective neighborhoods (defined by `kernel_size`) retain their
|
||||
original values. All other values are set to zero.
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError
|
||||
If `kernel_size` is not an int or a tuple of two ints.
|
||||
RuntimeError
|
||||
If the input `tensor` does not have 3 or 4 dimensions (as required
|
||||
by `max_pool2d`).
|
||||
|
||||
Notes
|
||||
-----
|
||||
- The function assumes higher values in the tensor indicate stronger peaks.
|
||||
- Choosing an appropriate `kernel_size` is important. It should be large
|
||||
enough to cover the typical "footprint" of a single detection peak plus
|
||||
some surrounding context, effectively preventing multiple detections for
|
||||
the same event. A size that is too large might suppress nearby distinct
|
||||
events.
|
||||
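Examples
--------
A minimal sketch on a small single-channel heatmap; only the local
maximum survives:

>>> import torch
>>> heatmap = torch.tensor(
...     [[[0.1, 0.5, 0.2],
...       [0.4, 0.9, 0.3],
...       [0.2, 0.6, 0.1]]]
... )
>>> out = non_max_suppression(heatmap, kernel_size=3)
>>> torch.count_nonzero(out).item()
1
>>> round(out[0, 1, 1].item(), 1)
0.9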
"""
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size_h = kernel_size
|
||||
kernel_size_w = kernel_size
|
||||
else:
|
||||
kernel_size_h, kernel_size_w = kernel_size
|
||||
|
||||
pad_h = (kernel_size_h - 1) // 2
|
||||
pad_w = (kernel_size_w - 1) // 2
|
||||
|
||||
hmax = torch.nn.functional.max_pool2d(
|
||||
tensor,
|
||||
(kernel_size_h, kernel_size_w),
|
||||
stride=1,
|
||||
padding=(pad_h, pad_w),
|
||||
)
|
||||
keep = (hmax == tensor).float()
|
||||
return tensor * keep
|
batdetect2/postprocess/remapping.py (new file, 316 lines)
@@ -0,0 +1,316 @@
|
||||
"""Remaps raw model output tensors to coordinate-aware xarray DataArrays.
|
||||
|
||||
This module provides utility functions to convert the raw numerical outputs
|
||||
(typically PyTorch tensors) from the BatDetect2 DNN model into
|
||||
`xarray.DataArray` objects. This step adds coordinate information
|
||||
(time in seconds, frequency in Hz) back to the model's predictions, making them
|
||||
interpretable in the context of the original audio signal and facilitating
|
||||
subsequent processing steps.
|
||||
|
||||
Functions are provided for common BatDetect2 output types: detection heatmaps,
|
||||
classification probability maps, size prediction maps, and potentially
|
||||
intermediate features.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
|
||||
__all__ = [
|
||||
"features_to_xarray",
|
||||
"detection_to_xarray",
|
||||
"classification_to_xarray",
|
||||
"sizes_to_xarray",
|
||||
]
|
||||
|
||||
|
||||
def features_to_xarray(
|
||||
features: torch.Tensor,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
features_prefix: str = "batdetect2_feature_",
|
||||
):
|
||||
"""Convert a multi-channel feature tensor to a coordinate-aware DataArray.
|
||||
|
||||
Assigns time, frequency, and feature coordinates to a raw feature tensor
|
||||
output by the model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : torch.Tensor
|
||||
The raw feature tensor from the model. Expected shape is
|
||||
(num_features, num_freq_bins, num_time_bins).
|
||||
start_time : float
|
||||
The start time (in seconds) corresponding to the first time bin of
|
||||
the tensor.
|
||||
end_time : float
|
||||
The end time (in seconds) corresponding to the *end* of the last time
|
||||
bin.
|
||||
min_freq : float, default=MIN_FREQ
|
||||
The minimum frequency (in Hz) corresponding to the first frequency bin.
|
||||
max_freq : float, default=MAX_FREQ
|
||||
The maximum frequency (in Hz) corresponding to the *end* of the last
|
||||
frequency bin.
|
||||
features_prefix : str, default="batdetect2_feature_"
|
||||
Prefix used to generate names for the feature coordinate dimension
|
||||
(e.g., "batdetect2_feature_0", "batdetect2_feature_1", ...).
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
An xarray DataArray containing the feature data with named dimensions
|
||||
('feature', 'frequency', 'time') and calculated coordinates.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the input tensor does not have 3 dimensions.
|
||||
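Examples
--------
A minimal sketch mapping a random feature tensor onto a 0.5 s clip
(shapes illustrative):

>>> import torch
>>> feats = torch.rand(32, 128, 256)  # (features, freq bins, time bins)
>>> arr = features_to_xarray(
...     feats,
...     start_time=0.0,
...     end_time=0.5,
...     min_freq=10_000,
...     max_freq=120_000,
... )
>>> arr.dims
('feature', 'frequency', 'time')
>>> arr.sizes["time"]
256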
"""
|
||||
if features.ndim != 3:
|
||||
raise ValueError(
|
||||
"Input features tensor must have 3 dimensions (C, T, F), "
|
||||
f"got shape {features.shape}"
|
||||
)
|
||||
|
||||
num_features, height, width = features.shape
|
||||
times = np.linspace(start_time, end_time, width, endpoint=False)
|
||||
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
||||
|
||||
return xr.DataArray(
|
||||
data=features.detach().numpy(),
|
||||
dims=[
|
||||
Dimensions.feature.value,
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
coords={
|
||||
Dimensions.feature.value: [
|
||||
f"{features_prefix}{i}" for i in range(num_features)
|
||||
],
|
||||
Dimensions.frequency.value: freqs,
|
||||
Dimensions.time.value: times,
|
||||
},
|
||||
name="features",
|
||||
)
|
||||
|
||||
|
||||
def detection_to_xarray(
|
||||
detection: torch.Tensor,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
) -> xr.DataArray:
|
||||
"""Convert a single-channel detection heatmap tensor to a DataArray.
|
||||
|
||||
Assigns time and frequency coordinates to a raw detection heatmap tensor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection : torch.Tensor
|
||||
Raw detection heatmap tensor from the model. Expected shape is
|
||||
(1, num_freq_bins, num_time_bins).
|
||||
start_time : float
|
||||
Start time (seconds) corresponding to the first time bin.
|
||||
end_time : float
|
||||
End time (seconds) corresponding to the end of the last time bin.
|
||||
min_freq : float, default=MIN_FREQ
|
||||
Minimum frequency (Hz) corresponding to the first frequency bin.
|
||||
max_freq : float, default=MAX_FREQ
|
||||
Maximum frequency (Hz) corresponding to the end of the last frequency
|
||||
bin.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
An xarray DataArray containing the detection scores with named
|
||||
dimensions ('frequency', 'time') and calculated coordinates.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the input tensor does not have 3 dimensions or if the first
|
||||
dimension size is not 1.
|
||||
"""
|
||||
if detection.ndim != 3:
|
||||
raise ValueError(
|
||||
"Input detection tensor must have 3 dimensions (1, T, F), "
|
||||
f"got shape {detection.shape}"
|
||||
)
|
||||
|
||||
num_channels, height, width = detection.shape
|
||||
|
||||
if num_channels != 1:
|
||||
raise ValueError(
|
||||
"Expected a single channel output, instead got "
|
||||
f"{num_channels} channels"
|
||||
)
|
||||
|
||||
times = np.linspace(start_time, end_time, width, endpoint=False)
|
||||
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
||||
|
||||
return xr.DataArray(
|
||||
data=detection.squeeze(dim=0).detach().numpy(),
|
||||
dims=[
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
coords={
|
||||
Dimensions.frequency.value: freqs,
|
||||
Dimensions.time.value: times,
|
||||
},
|
||||
name="detection_score",
|
||||
)
|
||||
|
||||
|
||||
def classification_to_xarray(
|
||||
classes: torch.Tensor,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
class_names: List[str],
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
) -> xr.DataArray:
|
||||
"""Convert multi-channel class probability tensor to a DataArray.
|
||||
|
||||
Assigns category (class name), frequency, and time coordinates to a raw
|
||||
class probability tensor output by the model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
classes : torch.Tensor
|
||||
Raw class probability tensor. Expected shape is
|
||||
(num_classes, num_freq_bins, num_time_bins).
|
||||
start_time : float
|
||||
Start time (seconds) corresponding to the first time bin.
|
||||
end_time : float
|
||||
End time (seconds) corresponding to the end of the last time bin.
|
||||
class_names : List[str]
|
||||
Ordered list of class names corresponding to the first dimension
|
||||
of the `classes` tensor. The length must match `classes.shape[0]`.
|
||||
min_freq : float, default=MIN_FREQ
|
||||
Minimum frequency (Hz) corresponding to the first frequency bin.
|
||||
max_freq : float, default=MAX_FREQ
|
||||
Maximum frequency (Hz) corresponding to the end of the last frequency
|
||||
bin.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
An xarray DataArray containing class probabilities with named
|
||||
dimensions ('category', 'frequency', 'time') and calculated
|
||||
coordinates.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the input tensor does not have 3 dimensions, or if the size of the
|
||||
first dimension does not match the length of `class_names`.
|
||||
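Examples
--------
A minimal sketch for a model with three target classes (class names
illustrative):

>>> import torch
>>> probs = torch.rand(3, 128, 256)
>>> arr = classification_to_xarray(
...     probs,
...     start_time=0.0,
...     end_time=0.5,
...     class_names=["myotis", "pipistrellus", "nyctalus"],
... )
>>> arr.coords["category"].values.tolist()
['myotis', 'pipistrellus', 'nyctalus']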
"""
|
||||
if classes.ndim != 3:
|
||||
raise ValueError(
|
||||
"Input classes tensor must have 3 dimensions (C, F, T), "
|
||||
f"got shape {classes.shape}"
|
||||
)
|
||||
|
||||
num_classes, height, width = classes.shape
|
||||
|
||||
if num_classes != len(class_names):
|
||||
raise ValueError(
|
||||
"The number of classes does not coincide with the number of "
|
||||
"class names provided: "
|
||||
f"({num_classes = }) != ({len(class_names) = })"
|
||||
)
|
||||
|
||||
times = np.linspace(start_time, end_time, width, endpoint=False)
|
||||
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
||||
|
||||
return xr.DataArray(
|
||||
data=classes.detach().numpy(),
|
||||
dims=[
|
||||
"category",
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
coords={
|
||||
"category": class_names,
|
||||
Dimensions.frequency.value: freqs,
|
||||
Dimensions.time.value: times,
|
||||
},
|
||||
name="class_scores",
|
||||
)
|
||||
|
||||
|
||||
def sizes_to_xarray(
|
||||
sizes: torch.Tensor,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
) -> xr.DataArray:
|
||||
"""Convert the 2-channel size prediction tensor to a DataArray.
|
||||
|
||||
Assigns dimension ('width', 'height'), frequency, and time coordinates
|
||||
to the raw size prediction tensor output by the model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sizes : torch.Tensor
|
||||
Raw size prediction tensor. Expected shape is
|
||||
(2, num_freq_bins, num_time_bins), where the first dimension
|
||||
corresponds to predicted width and height respectively.
|
||||
start_time : float
|
||||
Start time (seconds) corresponding to the first time bin.
|
||||
end_time : float
|
||||
End time (seconds) corresponding to the end of the last time bin.
|
||||
min_freq : float, default=MIN_FREQ
|
||||
Minimum frequency (Hz) corresponding to the first frequency bin.
|
||||
max_freq : float, default=MAX_FREQ
|
||||
Maximum frequency (Hz) corresponding to the end of the last frequency
|
||||
bin.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
An xarray DataArray containing predicted sizes with named dimensions
|
||||
('dimension', 'frequency', 'time') and calculated time/frequency
|
||||
coordinates. The 'dimension' coordinate will have values
|
||||
['width', 'height'].
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the input tensor does not have 3 dimensions or if the first
|
||||
dimension size is not exactly 2.
|
||||
"""
|
||||
num_channels, height, width = sizes.shape
|
||||
|
||||
if num_channels != 2:
|
||||
raise ValueError(
|
||||
"Expected a two-channel output, instead got "
|
||||
f"{num_channels} channels"
|
||||
)
|
||||
|
||||
times = np.linspace(start_time, end_time, width, endpoint=False)
|
||||
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
|
||||
|
||||
return xr.DataArray(
|
||||
data=sizes.detach().numpy(),
|
||||
dims=[
|
||||
"dimension",
|
||||
Dimensions.frequency.value,
|
||||
Dimensions.time.value,
|
||||
],
|
||||
coords={
|
||||
"dimension": ["width", "height"],
|
||||
Dimensions.frequency.value: freqs,
|
||||
Dimensions.time.value: times,
|
||||
},
|
||||
)
|
batdetect2/postprocess/types.py (new file, 284 lines)
@@ -0,0 +1,284 @@
|
||||
"""Defines shared interfaces and data structures for postprocessing.
|
||||
|
||||
This module centralizes the Protocol definitions and common data structures
|
||||
used throughout the `batdetect2.postprocess` module.
|
||||
|
||||
The main component is the `PostprocessorProtocol`, which outlines the standard
|
||||
interface for an object responsible for executing the entire postprocessing
|
||||
pipeline. This pipeline transforms raw neural network outputs into interpretable
|
||||
detections represented as `soundevent` objects. Using protocols ensures
|
||||
modularity and consistent interaction between different parts of the BatDetect2
|
||||
system that deal with model predictions.
|
||||
"""
|
||||
|
||||
from typing import Callable, List, NamedTuple, Protocol
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.models.types import ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"RawPrediction",
|
||||
"PostprocessorProtocol",
|
||||
"GeometryBuilder",
|
||||
]
|
||||
|
||||
|
||||
GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry]
|
||||
"""Type alias for a function that recovers geometry from position and size.
|
||||
|
||||
This callable takes:
|
||||
1. A position tuple `(time, frequency)`.
|
||||
2. A NumPy array of size dimensions (e.g., `[width, height]`).
|
||||
It should return the reconstructed `soundevent.data.Geometry` (typically a
|
||||
`BoundingBox`).
|
||||
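Examples
--------
A minimal conforming implementation (a sketch only: it treats the position
as the start-time/low-frequency corner, while the actual interpretation and
any size un-scaling depend on how targets were encoded during training):

>>> import numpy as np
>>> from soundevent import data
>>> def build_bounding_box(position, dims) -> data.Geometry:
...     time, freq = position
...     width, height = dims
...     return data.BoundingBox(
...         coordinates=[time, freq, time + width, freq + height]
...     )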
"""
|
||||
|
||||
|
||||
class RawPrediction(NamedTuple):
|
||||
"""Intermediate representation of a single detected sound event.
|
||||
|
||||
Holds extracted information about a detection after initial processing
|
||||
(like peak finding, coordinate remapping, geometry recovery) but before
|
||||
final class decoding and conversion into a `SoundEventPrediction`. This
|
||||
can be useful for evaluation or simpler data handling formats.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
start_time : float
|
||||
Start time of the recovered bounding box in seconds.
|
||||
end_time : float
|
||||
End time of the recovered bounding box in seconds.
|
||||
low_freq : float
|
||||
Lowest frequency of the recovered bounding box in Hz.
|
||||
high_freq : float
|
||||
Highest frequency of the recovered bounding box in Hz.
|
||||
detection_score : float
|
||||
The confidence score associated with this detection, typically from
|
||||
the detection heatmap peak.
|
||||
class_scores : xr.DataArray
|
||||
An xarray DataArray containing the predicted probabilities or scores
|
||||
for each target class at the detection location. Indexed by a
|
||||
'category' coordinate containing class names.
|
||||
features : xr.DataArray
|
||||
An xarray DataArray containing extracted feature vectors at the
|
||||
detection location. Indexed by a 'feature' coordinate.
|
||||
"""
|
||||
|
||||
start_time: float
|
||||
end_time: float
|
||||
low_freq: float
|
||||
high_freq: float
|
||||
detection_score: float
|
||||
class_scores: xr.DataArray
|
||||
features: xr.DataArray
|
||||
|
||||
|
||||
class PostprocessorProtocol(Protocol):
|
||||
"""Protocol defining the interface for the full postprocessing pipeline.
|
||||
|
||||
This protocol outlines the standard methods for an object that takes raw
|
||||
output from a BatDetect2 model and the corresponding input clip metadata,
|
||||
and processes it through various stages (e.g., coordinate remapping, NMS,
|
||||
detection extraction, data extraction, decoding) to produce interpretable
|
||||
results at different levels of completion.
|
||||
|
||||
Implementations manage the configured logic for all postprocessing steps.
|
||||
"""
|
||||
|
||||
def get_feature_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Remap feature tensors to coordinate-aware DataArrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch, expected
|
||||
to contain the necessary feature tensors.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects, one for each item in the
|
||||
processed batch. This list provides the timing, recording, and
|
||||
other metadata context needed to calculate real-world coordinates
|
||||
(seconds, Hz) for the output arrays. The length of this list must
|
||||
correspond to the batch size of the `output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
A list of xarray DataArrays, one for each input clip in the batch,
|
||||
in the same order. Each DataArray contains the feature vectors
|
||||
with dimensions like ('feature', 'time', 'frequency') and
|
||||
corresponding real-world coordinates.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_detection_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Remap detection tensors to coordinate-aware DataArrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch,
|
||||
containing detection heatmaps.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing coordinate context. Must match the batch size of
|
||||
`output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
A list of 2D xarray DataArrays (one per input clip, in order),
|
||||
representing the detection heatmap with 'time' and 'frequency'
|
||||
coordinates. Values typically indicate detection confidence.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_classification_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Remap classification tensors to coordinate-aware DataArrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch,
|
||||
containing class probability tensors.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing coordinate context. Must match the batch size of
|
||||
`output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
A list of 3D xarray DataArrays (one per input clip, in order),
|
||||
representing class probabilities with 'category', 'time', and
|
||||
'frequency' dimensions and coordinates.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_sizes_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Remap size prediction tensors to coordinate-aware DataArrays.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch,
|
||||
containing predicted size tensors (e.g., width and height).
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing coordinate context. Must match the batch size of
|
||||
`output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
A list of 3D xarray DataArrays (one per input clip, in order),
|
||||
representing predicted sizes with 'dimension'
|
||||
(e.g., ['width', 'height']), 'time', and 'frequency' dimensions and
|
||||
coordinates. Values represent estimated detection sizes.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_detection_datasets(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.Dataset]:
|
||||
"""Perform remapping, NMS, detection, and data extraction for a batch.
|
||||
|
||||
Processes the raw model output for a batch to identify detection peaks
|
||||
and extract all associated information (score, position, size, class
|
||||
probs, features) at those peak locations, returning a structured
|
||||
dataset for each input clip in the batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing context. Must match the batch size of `output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.Dataset]
|
||||
A list of xarray Datasets (one per input clip, in order). Each
|
||||
Dataset contains multiple DataArrays ('scores', 'dimensions',
|
||||
'classes', 'features') sharing a common 'detection' dimension,
|
||||
providing aligned data for each detected event in that clip.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_raw_predictions(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[List[RawPrediction]]:
|
||||
"""Extract intermediate RawPrediction objects for a batch.
|
||||
|
||||
Processes the raw model output for a batch through remapping, NMS,
|
||||
detection, data extraction, and geometry recovery to produce a list of
|
||||
`RawPrediction` objects for each corresponding input clip. This provides
|
||||
a simplified, intermediate representation before final tag decoding.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing context. Must match the batch size of `output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[List[RawPrediction]]
|
||||
A list of lists (one inner list per input clip, in order). Each
|
||||
inner list contains the `RawPrediction` objects extracted for the
|
||||
corresponding input clip.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_predictions(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[data.ClipPrediction]:
|
||||
"""Perform the full postprocessing pipeline for a batch.
|
||||
|
||||
Takes raw model output for a batch and corresponding clips, applies the
|
||||
entire postprocessing chain, and returns the final, interpretable
|
||||
predictions as a list of `soundevent.data.ClipPrediction` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
The raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
A list of `soundevent.data.Clip` objects corresponding to the batch
|
||||
items, providing context. Must match the batch size of `output`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.ClipPrediction]
|
||||
A list containing one `ClipPrediction` object for each input clip
|
||||
(in the same order), populated with `SoundEventPrediction` objects
|
||||
representing the final detections with decoded tags and geometry.
|
||||
"""
|
||||
...
|
@@ -1,68 +1,448 @@
|
||||
"""Module containing functions for preprocessing audio clips."""
|
||||
"""Main entry point for the BatDetect2 Preprocessing subsystem.
|
||||
|
||||
from typing import Optional
|
||||
This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline
|
||||
for converting raw audio input (from files or data objects) into processed
|
||||
spectrograms suitable for input to BatDetect2 models. This ensures consistent
|
||||
data handling between model training and inference.
|
||||
|
||||
The preprocessing pipeline consists of two main stages, configured via nested
|
||||
data structures:
|
||||
1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial
|
||||
processing like resampling, duration adjustment, centering, and scaling.
|
||||
Configured via `AudioConfig`.
|
||||
2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from
|
||||
the processed waveform using STFT, followed by frequency cropping, optional
|
||||
PCEN, amplitude scaling (dB, power, linear), optional denoising, optional
|
||||
resizing, and optional peak normalization. Configured via
|
||||
`SpectrogramConfig`.
|
||||
|
||||
This module provides the primary interface:
|
||||
|
||||
- `PreprocessingConfig`: A unified configuration object holding `AudioConfig`
|
||||
and `SpectrogramConfig`.
|
||||
- `load_preprocessing_config`: Function to load the unified configuration.
|
||||
- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline.
|
||||
- `StandardPreprocessor`: The default implementation of the `Preprocessor`.
|
||||
- `build_preprocessor`: A factory function to create a `StandardPreprocessor`
|
||||
instance from a `PreprocessingConfig`.
|
||||
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.preprocess.audio import (
|
||||
DEFAULT_DURATION,
|
||||
SCALE_RAW_AUDIO,
|
||||
TARGET_SAMPLERATE_HZ,
|
||||
AudioConfig,
|
||||
ResampleConfig,
|
||||
load_clip_audio,
|
||||
)
|
||||
from batdetect2.preprocess.config import (
|
||||
PreprocessingConfig,
|
||||
load_preprocessing_config,
|
||||
build_audio_loader,
|
||||
)
|
||||
from batdetect2.preprocess.spectrogram import (
|
||||
AmplitudeScaleConfig,
|
||||
MAX_FREQ,
|
||||
MIN_FREQ,
|
||||
ConfigurableSpectrogramBuilder,
|
||||
FrequencyConfig,
|
||||
LogScaleConfig,
|
||||
PcenScaleConfig,
|
||||
Scales,
|
||||
PcenConfig,
|
||||
SpecSizeConfig,
|
||||
SpectrogramConfig,
|
||||
STFTConfig,
|
||||
compute_spectrogram,
|
||||
build_spectrogram_builder,
|
||||
get_spectrogram_resolution,
|
||||
)
|
||||
from batdetect2.preprocess.types import (
|
||||
AudioLoader,
|
||||
PreprocessorProtocol,
|
||||
SpectrogramBuilder,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AmplitudeScaleConfig",
|
||||
"AudioConfig",
|
||||
"AudioLoader",
|
||||
"ConfigurableSpectrogramBuilder",
|
||||
"DEFAULT_DURATION",
|
||||
"FrequencyConfig",
|
||||
"LogScaleConfig",
|
||||
"PcenScaleConfig",
|
||||
"MAX_FREQ",
|
||||
"MIN_FREQ",
|
||||
"PcenConfig",
|
||||
"PreprocessingConfig",
|
||||
"ResampleConfig",
|
||||
"SCALE_RAW_AUDIO",
|
||||
"STFTConfig",
|
||||
"Scales",
|
||||
"SpecSizeConfig",
|
||||
"SpectrogramBuilder",
|
||||
"SpectrogramConfig",
|
||||
"StandardPreprocessor",
|
||||
"TARGET_SAMPLERATE_HZ",
|
||||
"build_audio_loader",
|
||||
"build_preprocessor",
|
||||
"build_spectrogram_builder",
|
||||
"get_spectrogram_resolution",
|
||||
"load_preprocessing_config",
|
||||
"preprocess_audio_clip",
|
||||
]
|
||||
|
||||
|
||||
def preprocess_audio_clip(
|
||||
clip: data.Clip,
|
||||
config: Optional[PreprocessingConfig] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Preprocesses audio clip to generate spectrogram.
|
||||
class PreprocessingConfig(BaseConfig):
|
||||
"""Unified configuration for the audio preprocessing pipeline.
|
||||
|
||||
Aggregates the configuration for both the initial audio processing stage
|
||||
and the subsequent spectrogram generation stage.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
audio : AudioConfig
|
||||
Configuration settings for the audio loading and initial waveform
|
||||
processing steps (e.g., resampling, duration adjustment, scaling).
|
||||
Defaults to default `AudioConfig` settings if omitted.
|
||||
spectrogram : SpectrogramConfig
|
||||
Configuration settings for the spectrogram generation process
|
||||
(e.g., STFT parameters, frequency cropping, scaling, denoising,
|
||||
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
|
||||
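Examples
--------
Construct a configuration with all defaults and build the corresponding
preprocessor (using the factory described in the module overview):

>>> config = PreprocessingConfig()
>>> preprocessor = build_preprocessor(config)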
"""
|
||||
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
|
||||
|
||||
|
||||
class StandardPreprocessor(PreprocessorProtocol):
|
||||
"""Standard implementation of the `Preprocessor` protocol.
|
||||
|
||||
Orchestrates the audio loading and spectrogram generation pipeline using
|
||||
an `AudioLoader` and a `SpectrogramBuilder` internally, which are
|
||||
configured according to a `PreprocessingConfig`.
|
||||
|
||||
This class is typically instantiated using the `build_preprocessor`
|
||||
factory function.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
audio_loader : AudioLoader
|
||||
The configured audio loader instance used for waveform loading and
|
||||
initial processing.
|
||||
spectrogram_builder : SpectrogramBuilder
|
||||
The configured spectrogram builder instance used for generating
|
||||
spectrograms from waveforms.
|
||||
default_samplerate : int
|
||||
The sample rate (in Hz) assumed for input waveforms when they are
|
||||
provided as raw NumPy arrays without coordinate information (e.g.,
|
||||
when calling `compute_spectrogram` directly with `np.ndarray`).
|
||||
This value is derived from the `AudioConfig` (target resample rate
|
||||
or default if resampling is off) and also serves as documentation
|
||||
for the pipeline's intended operating sample rate. Note that when
|
||||
processing `xr.DataArray` inputs that have coordinate information
|
||||
(the standard internal workflow), the sample rate embedded in the
|
||||
coordinates takes precedence over this default value during
|
||||
spectrogram calculation.
|
||||
"""
|
||||
|
||||
audio_loader: AudioLoader
|
||||
spectrogram_builder: SpectrogramBuilder
|
||||
default_samplerate: int
|
||||
max_freq: float
|
||||
min_freq: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_loader: AudioLoader,
|
||||
spectrogram_builder: SpectrogramBuilder,
|
||||
default_samplerate: int,
|
||||
max_freq: float,
|
||||
min_freq: float,
|
||||
) -> None:
|
||||
"""Initialize the StandardPreprocessor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio_loader : AudioLoader
|
||||
An initialized audio loader conforming to the AudioLoader protocol.
|
||||
spectrogram_builder : SpectrogramBuilder
|
||||
An initialized spectrogram builder conforming to the
|
||||
SpectrogramBuilder protocol.
|
||||
default_samplerate : int
|
||||
The sample rate to assume for NumPy array inputs and potentially
|
||||
reflecting the target rate of the audio config.
max_freq : float
Maximum frequency (in Hz) represented in the generated spectrograms.
min_freq : float
Minimum frequency (in Hz) represented in the generated spectrograms.
|
||||
"""
|
||||
self.audio_loader = audio_loader
|
||||
self.spectrogram_builder = spectrogram_builder
|
||||
self.default_samplerate = default_samplerate
|
||||
self.max_freq = max_freq
|
||||
self.min_freq = min_freq

    def load_file_audio(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess *only* the audio waveform from a file path.

        Delegates to the internal `audio_loader`.

        Parameters
        ----------
        path : PathLike
            Path to the audio file.
        audio_dir : PathLike, optional
            A directory prefix if `path` is relative.

        Returns
        -------
        xr.DataArray
            The loaded and preprocessed audio waveform (typically first
            channel).
        """
        return self.audio_loader.load_file(
            path,
            audio_dir=audio_dir,
        )

    def load_recording_audio(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess *only* the audio waveform for a Recording.

        Delegates to the internal `audio_loader`.

        Parameters
        ----------
        recording : data.Recording
            The Recording object.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The loaded and preprocessed audio waveform (typically first
            channel).
        """
        return self.audio_loader.load_recording(
            recording,
            audio_dir=audio_dir,
        )

    def load_clip_audio(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess *only* the audio waveform for a Clip.

        Delegates to the internal `audio_loader`.

        Parameters
        ----------
        clip : data.Clip
            The Clip object defining the segment.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The loaded and preprocessed audio waveform segment (typically
            first channel).
        """
        return self.audio_loader.load_clip(
            clip,
            audio_dir=audio_dir,
        )

    def preprocess_file(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load audio from a file and compute the final processed spectrogram.

        Performs the full pipeline:

        Load -> Preprocess Audio -> Compute Spectrogram.

        Parameters
        ----------
        path : PathLike
            Path to the audio file.
        audio_dir : PathLike, optional
            A directory prefix if `path` is relative.

        Returns
        -------
        xr.DataArray
            The final processed spectrogram.
        """
        wav = self.load_file_audio(path, audio_dir=audio_dir)
        return self.spectrogram_builder(
            wav,
            samplerate=self.default_samplerate,
        )

    def preprocess_recording(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load audio for a Recording and compute the processed spectrogram.

        Performs the full pipeline for the entire duration of the recording.

        Parameters
        ----------
        recording : data.Recording
            The Recording object.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The final processed spectrogram.
        """
        wav = self.load_recording_audio(recording, audio_dir=audio_dir)
        return self.spectrogram_builder(
            wav,
            samplerate=self.default_samplerate,
        )

    def preprocess_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load audio for a Clip and compute the final processed spectrogram.

        Performs the full pipeline for the specified clip segment.

        Parameters
        ----------
        clip : data.Clip
            The Clip object defining the audio segment.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The final processed spectrogram.
        """
        wav = self.load_clip_audio(clip, audio_dir=audio_dir)
        return self.spectrogram_builder(
            wav,
            samplerate=self.default_samplerate,
        )

    def compute_spectrogram(
        self, wav: Union[xr.DataArray, np.ndarray]
    ) -> xr.DataArray:
        """Compute the spectrogram from a pre-loaded audio waveform.

        Applies the configured spectrogram generation steps
        (STFT, scaling, etc.) using the internal `spectrogram_builder`.

        If `wav` is a NumPy array, the `default_samplerate` stored in this
        preprocessor instance will be used. If `wav` is an xarray DataArray
        with time coordinates, the sample rate derived from those coordinates
        will take precedence over `default_samplerate`.

        Parameters
        ----------
        wav : Union[xr.DataArray, np.ndarray]
            The input audio waveform. If numpy array, `default_samplerate`
            stored in this object will be assumed.

        Returns
        -------
        xr.DataArray
            The computed spectrogram.
        """
        return self.spectrogram_builder(
            wav,
            samplerate=self.default_samplerate,
        )
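To make the NumPy-versus-DataArray behaviour of `compute_spectrogram` concrete, here is a minimal usage sketch. It is not part of the diff: the file path and random waveform are placeholders, and it assumes a preprocessor built with the `build_preprocessor` factory defined further down in this changeset.

```python
import numpy as np

preprocessor = build_preprocessor()  # default PreprocessingConfig

# Raw NumPy input: the stored default_samplerate (256 kHz by default) is assumed.
raw = np.random.randn(256_000).astype(np.float32)  # ~1 s of synthetic audio
spec_from_numpy = preprocessor.compute_spectrogram(raw)

# xarray input from the loader: the sample rate encoded in the time
# coordinates takes precedence over default_samplerate.
wav = preprocessor.load_file_audio("recordings/example.wav")  # hypothetical path
spec_from_xr = preprocessor.compute_spectrogram(wav)
```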


def load_preprocessing_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> PreprocessingConfig:
    """Load the unified preprocessing configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `PreprocessingConfig` schema, potentially extracting data from a nested
    field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        preprocessing configuration (e.g., "train.preprocessing"). If None, the
        entire file content is validated as the PreprocessingConfig.

    Returns
    -------
    PreprocessingConfig
        Loaded and validated preprocessing configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded config data does not conform to PreprocessingConfig.
    KeyError, TypeError
        If `field` specifies an invalid path.
    """
    return load_config(path, schema=PreprocessingConfig, field=field)
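A short usage sketch for the loader; the file names and the nested "train.preprocessing" key are hypothetical, chosen only to mirror the `field` example in the docstring.

```python
# Validate the whole file as a PreprocessingConfig...
config = load_preprocessing_config("preprocessing.yaml")

# ...or pull the config out of a nested section of a larger training file.
train_config = load_preprocessing_config(
    "train.yaml",
    field="train.preprocessing",
)
print(train_config.spectrogram.frequencies.min_freq)
```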


def build_preprocessor(
    config: Optional[PreprocessingConfig] = None,
) -> PreprocessorProtocol:
    """Factory function to build the standard preprocessor from configuration.

    Creates instances of the required `AudioLoader` and `SpectrogramBuilder`
    based on the provided `PreprocessingConfig` (or defaults if config is
    None), determines the effective default sample rate, and initializes the
    `StandardPreprocessor`.

    Parameters
    ----------
    config : PreprocessingConfig, optional
        The unified preprocessing configuration object. If None, default
        configurations for audio and spectrogram processing will be used.

    Returns
    -------
    PreprocessorProtocol
        An initialized `StandardPreprocessor` instance ready to process audio
        according to the configuration.
    """
    config = config or PreprocessingConfig()

    default_samplerate = (
        config.audio.resample.samplerate
        if config.audio.resample
        else TARGET_SAMPLERATE_HZ
    )

    min_freq = config.spectrogram.frequencies.min_freq
    max_freq = config.spectrogram.frequencies.max_freq

    return StandardPreprocessor(
        audio_loader=build_audio_loader(config.audio),
        spectrogram_builder=build_spectrogram_builder(config.spectrogram),
        default_samplerate=default_samplerate,
        min_freq=min_freq,
        max_freq=max_freq,
    )
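An end-to-end sketch of the factory in use. The WAV path is a placeholder; everything else uses names introduced in this changeset.

```python
preprocessor = build_preprocessor()  # defaults throughout

# File -> preprocessed waveform -> spectrogram in one call.
spec = preprocessor.preprocess_file("recordings/night_survey_001.wav")

# The frequency range the model will see is exposed on the preprocessor.
print(preprocessor.min_freq, preprocessor.max_freq)  # 10000 120000 with defaults
```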

[file: batdetect2/preprocess/audio.py]
@@ -1,3 +1,25 @@
"""Handles loading and initial preprocessing of audio waveforms.

This module provides components for loading audio data associated with
`soundevent` objects (Clips, Recordings, or raw files) and applying
fundamental waveform processing steps. These steps typically include:

1. Loading the raw audio data.
2. Adjusting the audio clip to a fixed duration (optional).
3. Resampling the audio to a target sample rate (optional).
4. Centering the waveform (DC offset removal) (optional).
5. Scaling the waveform amplitude (optional).

The processing pipeline is configurable via the `AudioConfig` data structure,
allowing for reproducible preprocessing consistent between model training and
inference. It uses the `soundevent` library for audio loading and basic array
operations, and `scipy` for resampling implementations.

The primary interface is the `AudioLoader` protocol, with
`ConfigurableAudioLoader` providing a concrete implementation driven by the
`AudioConfig`.
"""

from typing import Optional

import numpy as np
@@ -7,33 +29,246 @@ from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import arrays, audio, data
from soundevent.arrays import operations as ops
from soundfile import LibsndfileError

from batdetect2.configs import BaseConfig
from batdetect2.preprocess.types import AudioLoader

__all__ = [
    "ResampleConfig",
    "AudioConfig",
    "ConfigurableAudioLoader",
    "build_audio_loader",
    "load_file_audio",
    "load_recording_audio",
    "load_clip_audio",
    "adjust_audio_duration",
    "resample_audio",
    "TARGET_SAMPLERATE_HZ",
    "SCALE_RAW_AUDIO",
    "DEFAULT_DURATION",
    "convert_to_xr",
]

TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""

SCALE_RAW_AUDIO = False
"""Default setting for whether to perform peak normalization."""

DEFAULT_DURATION = None
"""Default setting for target audio duration in seconds."""


class ResampleConfig(BaseConfig):
    """Configuration for audio resampling.

    Attributes
    ----------
    samplerate : int, default=256000
        The target sample rate in Hz to resample the audio to. Must be > 0.
    method : str, default="poly"
        The resampling algorithm to use. Options:
        - "poly": Polyphase resampling using `scipy.signal.resample_poly`.
          Generally fast.
        - "fourier": Resampling via Fourier method using
          `scipy.signal.resample`. May handle non-integer
          resampling factors differently.
    """

    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
    mode: str = "poly"
    method: str = "poly"


class AudioConfig(BaseConfig):
    """Configuration for loading and initial audio preprocessing.

    Defines the sequence of operations applied to raw audio waveforms after
    loading, controlling steps like resampling, scaling, centering, and
    duration adjustment.

    Attributes
    ----------
    resample : ResampleConfig, optional
        Configuration for resampling. If provided (or defaulted), audio will
        be resampled to the specified `samplerate` using the specified
        `method`. If set to `None` in the config file, resampling is skipped.
        Defaults to a ResampleConfig instance with standard settings.
    scale : bool, default=False
        If True, scales the audio waveform using peak normalization so that
        its maximum absolute amplitude is approximately 1.0. If False
        (default), no amplitude scaling is applied.
    center : bool, default=True
        If True (default), centers the waveform by subtracting its mean
        (DC offset removal). If False, the waveform is not centered.
    duration : float, optional
        If set to a float value (seconds), the loaded audio clip will be
        adjusted (cropped or padded with zeros) to exactly this duration.
        If None (default), the original duration is kept.
    """

    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
    scale: bool = SCALE_RAW_AUDIO
    center: bool = True
    duration: Optional[float] = DEFAULT_DURATION
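A hypothetical configuration showing how these fields combine (the values are chosen for the example rather than taken from the repository defaults): resample to 256 kHz with the polyphase method, keep DC-offset removal, skip peak normalization, and force every clip to exactly one second.

```python
config = AudioConfig(
    resample=ResampleConfig(samplerate=256_000, method="poly"),
    center=True,     # subtract the mean (DC offset removal)
    scale=False,     # leave the amplitude untouched
    duration=1.0,    # crop or zero-pad every clip to 1 s
)
loader = build_audio_loader(config)
```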
|
||||
|
||||
|
||||
class ConfigurableAudioLoader:
|
||||
"""Concrete implementation of the `AudioLoader` driven by `AudioConfig`.
|
||||
|
||||
This class loads audio and applies preprocessing steps (resampling,
|
||||
scaling, centering, duration adjustment) based on the settings provided
|
||||
in an `AudioConfig` object during initialization. It delegates the actual
|
||||
work to module-level functions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AudioConfig,
|
||||
):
|
||||
"""Initialize the ConfigurableAudioLoader.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : AudioConfig
|
||||
The configuration object specifying the desired preprocessing steps
|
||||
and parameters.
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
def load_file(
|
||||
self,
|
||||
path: data.PathLike,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Load and preprocess audio directly from a file path.
|
||||
|
||||
Implements the `AudioLoader.load_file` method by delegating to the
|
||||
`load_file_audio` function, passing the stored configuration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the audio file.
|
||||
audio_dir : PathLike, optional
|
||||
A directory prefix if `path` is relative.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Loaded and preprocessed waveform (first channel).
|
||||
"""
|
||||
return load_file_audio(path, config=self.config, audio_dir=audio_dir)
|
||||
|
||||
def load_recording(
|
||||
self,
|
||||
recording: data.Recording,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Load and preprocess the entire audio for a Recording object.
|
||||
|
||||
Implements the `AudioLoader.load_recording` method by delegating to the
|
||||
`load_recording_audio` function, passing the stored configuration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recording : data.Recording
|
||||
The Recording object.
|
||||
audio_dir : PathLike, optional
|
||||
Directory containing the audio file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Loaded and preprocessed waveform (first channel).
|
||||
"""
|
||||
return load_recording_audio(
|
||||
recording, config=self.config, audio_dir=audio_dir
|
||||
)
|
||||
|
||||
def load_clip(
|
||||
self,
|
||||
clip: data.Clip,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Load and preprocess the audio segment defined by a Clip object.
|
||||
|
||||
Implements the `AudioLoader.load_clip` method by delegating to the
|
||||
`load_clip_audio` function, passing the stored configuration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clip : data.Clip
|
||||
The Clip object specifying the segment.
|
||||
audio_dir : PathLike, optional
|
||||
Directory containing the audio file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Loaded and preprocessed waveform segment (first channel).
|
||||
"""
|
||||
return load_clip_audio(clip, config=self.config, audio_dir=audio_dir)
|
||||
|
||||
|
||||
def build_audio_loader(
|
||||
config: AudioConfig,
|
||||
) -> AudioLoader:
|
||||
"""Factory function to create an AudioLoader based on configuration.
|
||||
|
||||
Instantiates and returns a `ConfigurableAudioLoader` initialized with
|
||||
the provided `AudioConfig`. The return type is `AudioLoader`, adhering
|
||||
to the protocol.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : AudioConfig
|
||||
The configuration object specifying preprocessing steps.
|
||||
|
||||
Returns
|
||||
-------
|
||||
AudioLoader
|
||||
An instance of `ConfigurableAudioLoader` ready to load and process audio
|
||||
according to the configuration.
|
||||
"""
|
||||
return ConfigurableAudioLoader(config=config)
|
||||
|
||||
|
||||
def load_file_audio(
|
||||
path: data.PathLike,
|
||||
config: Optional[AudioConfig] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
dtype: DTypeLike = np.float32,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> xr.DataArray:
|
||||
recording = data.Recording.from_file(path)
|
||||
"""Load and preprocess audio from a file path using specified config.
|
||||
|
||||
Creates a `soundevent.data.Recording` object from the file path and then
|
||||
delegates the loading and processing to `load_recording_audio`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the audio file.
|
||||
config : AudioConfig, optional
|
||||
Audio processing configuration. If None, default settings defined
|
||||
in `AudioConfig` are used.
|
||||
audio_dir : PathLike, optional
|
||||
Directory prefix if `path` is relative.
|
||||
dtype : DTypeLike, default=np.float32
|
||||
Target NumPy data type for the loaded audio array.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Loaded and preprocessed waveform (first channel only).
|
||||
"""
|
||||
try:
|
||||
recording = data.Recording.from_file(path)
|
||||
except LibsndfileError as e:
|
||||
raise FileNotFoundError(
|
||||
f"Could not load the recording at path: {path}. Error: {e}"
|
||||
) from e
|
||||
|
||||
return load_recording_audio(
|
||||
recording,
|
||||
config=config,
|
||||
@@ -46,8 +281,31 @@ def load_recording_audio(
|
||||
recording: data.Recording,
|
||||
config: Optional[AudioConfig] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
dtype: DTypeLike = np.float32,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> xr.DataArray:
|
||||
"""Load and preprocess the entire audio content of a recording using config.
|
||||
|
||||
Creates a `soundevent.data.Clip` spanning the full duration of the
|
||||
recording and then delegates the loading and processing to
|
||||
`load_clip_audio`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recording : data.Recording
|
||||
The Recording object containing metadata and file path.
|
||||
config : AudioConfig, optional
|
||||
Audio processing configuration. If None, default settings are used.
|
||||
audio_dir : PathLike, optional
|
||||
Directory containing the audio file, used if the path in `recording`
|
||||
is relative.
|
||||
dtype : DTypeLike, default=np.float32
|
||||
Target NumPy data type for the loaded audio array.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Loaded and preprocessed waveform (first channel only).
|
||||
"""
|
||||
clip = data.Clip(
|
||||
recording=recording,
|
||||
start_time=0,
|
||||
@@ -65,16 +323,64 @@ def load_clip_audio(
|
||||
clip: data.Clip,
|
||||
config: Optional[AudioConfig] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
dtype: DTypeLike = np.float32,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> xr.DataArray:
|
||||
"""Load and preprocess a specific audio clip segment based on config.
|
||||
|
||||
This is the core function performing the configured processing pipeline:
|
||||
1. Loads the specified clip segment using `soundevent.audio.load_clip`.
|
||||
2. Selects the first audio channel.
|
||||
3. Resamples if `config.resample` is configured.
|
||||
4. Centers (DC offset removal) if `config.center` is True.
|
||||
5. Scales (peak normalization) if `config.scale` is True.
|
||||
6. Adjusts duration (crop/pad) if `config.duration` is set.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clip : data.Clip
|
||||
The Clip object defining the audio segment and source recording.
|
||||
config : AudioConfig, optional
|
||||
Audio processing configuration. If None, a default `AudioConfig` is
|
||||
used.
|
||||
audio_dir : PathLike, optional
|
||||
Directory containing the source audio file specified in the clip's
|
||||
recording.
|
||||
dtype : DTypeLike, default=np.float32
|
||||
Target NumPy data type for the processed audio array.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The loaded and preprocessed waveform segment as an xarray DataArray
|
||||
with time coordinates.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the underlying audio file cannot be found.
|
||||
Exception
|
||||
If audio loading or processing fails for other reasons (e.g., invalid
|
||||
format, resampling error).
|
||||
|
||||
Notes
|
||||
-----
|
||||
- **Mono Processing:** This function currently loads and processes only the
|
||||
**first channel** (channel 0) of the audio file. Any other channels
|
||||
are ignored.
|
||||
"""
|
||||
config = config or AudioConfig()
|
||||
|
||||
wav = (
|
||||
audio.load_clip(clip, audio_dir=audio_dir).sel(channel=0).astype(dtype)
|
||||
)
|
||||
|
||||
if config.duration is not None:
|
||||
wav = adjust_audio_duration(wav, duration=config.duration)
|
||||
try:
|
||||
wav = (
|
||||
audio.load_clip(clip, audio_dir=audio_dir)
|
||||
.sel(channel=0)
|
||||
.astype(dtype)
|
||||
)
|
||||
except LibsndfileError as e:
|
||||
raise FileNotFoundError(
|
||||
f"Could not load the recording at path: {clip.recording.path}. "
|
||||
f"Error: {e}"
|
||||
) from e
|
||||
|
||||
if config.resample:
|
||||
wav = resample_audio(
@@ -87,43 +393,126 @@ def load_clip_audio(
    wav = ops.center(wav)

    if config.scale:
        wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
        wav = scale_audio(wav)

    if config.duration is not None:
        wav = adjust_audio_duration(wav, duration=config.duration)

    return wav.astype(dtype)


def scale_audio(
    wave: xr.DataArray,
) -> xr.DataArray:
    """
    Scale the audio waveform to have a maximum absolute value of 1.0.

    This function normalizes the waveform by dividing it by its maximum
    absolute value. If the maximum value is zero, the waveform is returned
    unchanged. Also known as peak normalization, this process ensures that the
    waveform's amplitude is within a standard range, which can be useful for
    audio processing and analysis.

    """
    max_val = np.max(np.abs(wave))

    if max_val == 0:
        return wave

    return ops.scale(wave, 1 / max_val)


def adjust_audio_duration(
    wave: xr.DataArray,
    duration: float,
) -> xr.DataArray:
    """Adjust the duration of an audio waveform array via cropping or padding.

    If the current duration is longer than the target, it crops the array
    from the beginning. If shorter, it pads the array with zeros at the end
    using `soundevent.arrays.extend_dim`.

    Parameters
    ----------
    wave : xr.DataArray
        The input audio waveform with a 'time' dimension and coordinates.
    duration : float
        The target duration in seconds.

    Returns
    -------
    xr.DataArray
        The waveform adjusted to the target duration. Returns the input
        unmodified if duration already matches or if the wave is empty.

    Raises
    ------
    ValueError
        If `duration` is negative.
    """
|
||||
start_time, end_time = arrays.get_dim_range(wave, dim="time")
|
||||
current_duration = end_time - start_time
|
||||
step = arrays.get_dim_step(wave, dim="time")
|
||||
current_duration = end_time - start_time + step
|
||||
|
||||
if current_duration == duration:
|
||||
return wave
|
||||
|
||||
if current_duration > duration:
|
||||
return arrays.crop_dim(
|
||||
with xr.set_options(keep_attrs=True):
|
||||
if current_duration > duration:
|
||||
return arrays.crop_dim(
|
||||
wave,
|
||||
dim="time",
|
||||
start=start_time,
|
||||
stop=start_time + duration - step / 2,
|
||||
right_closed=True,
|
||||
)
|
||||
|
||||
return arrays.extend_dim(
|
||||
wave,
|
||||
dim="time",
|
||||
start=start_time,
|
||||
stop=start_time + duration,
|
||||
stop=start_time + duration - step / 2,
|
||||
eps=0,
|
||||
right_closed=True,
|
||||
)
|
||||
|
||||
return arrays.extend_dim(
|
||||
wave,
|
||||
dim="time",
|
||||
start=start_time,
|
||||
stop=start_time + duration,
|
||||
)
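To make the crop/pad behaviour tangible, a small sketch (the durations and random samples are invented; `convert_to_xr` is the helper defined later in this file):

```python
import numpy as np

sr = 256_000
wav = convert_to_xr(np.random.randn(sr // 2).astype(np.float32), samplerate=sr)  # 0.5 s

padded = adjust_audio_duration(wav, duration=1.0)    # zero-padded at the end
cropped = adjust_audio_duration(wav, duration=0.25)  # keeps only the first 0.25 s

print(padded.sizes["time"], cropped.sizes["time"])   # roughly 256000 and 64000 samples
```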
|
||||
|
||||
|
||||
def resample_audio(
|
||||
wav: xr.DataArray,
|
||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
mode: str = "poly",
|
||||
dtype: DTypeLike = np.float32,
|
||||
method: str = "poly",
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> xr.DataArray:
|
||||
"""Resample an audio waveform DataArray to a target sample rate.
|
||||
|
||||
Updates the 'time' coordinate axis according to the new sample rate and
|
||||
number of samples. Uses either polyphase (`scipy.signal.resample_poly`)
|
||||
or Fourier method (`scipy.signal.resample`) based on the `method`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wav : xr.DataArray
|
||||
Input audio waveform with 'time' dimension and coordinates.
|
||||
samplerate : int, default=TARGET_SAMPLERATE_HZ
|
||||
Target sample rate in Hz.
|
||||
method : str, default="poly"
|
||||
Resampling algorithm: "poly" or "fourier".
|
||||
dtype : DTypeLike, default=np.float32
|
||||
Target data type for the resampled array.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Resampled waveform with updated time coordinates. Returns the input
|
||||
unmodified (but dtype cast) if the sample rate is already correct or
|
||||
if the input array is empty.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `wav` lacks a 'time' dimension, the original sample rate cannot
|
||||
be determined, `samplerate` is non-positive, or `method` is invalid.
|
||||
"""
|
||||
if "time" not in wav.dims:
|
||||
raise ValueError("Audio must have a time dimension")
|
||||
|
||||
@@ -134,14 +523,14 @@ def resample_audio(
|
||||
if original_samplerate == samplerate:
|
||||
return wav.astype(dtype)
|
||||
|
||||
if mode == "poly":
|
||||
if method == "poly":
|
||||
resampled = resample_audio_poly(
|
||||
wav,
|
||||
sr_orig=original_samplerate,
|
||||
sr_new=samplerate,
|
||||
axis=time_axis,
|
||||
)
|
||||
elif mode == "fourier":
|
||||
elif method == "fourier":
|
||||
resampled = resample_audio_fourier(
|
||||
wav,
|
||||
sr_orig=original_samplerate,
|
||||
@@ -149,7 +538,9 @@ def resample_audio(
|
||||
axis=time_axis,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Resampling mode '{mode}' not implemented")
|
||||
raise NotImplementedError(
|
||||
f"Resampling method '{method}' not implemented"
|
||||
)
|
||||
|
||||
start, stop = arrays.get_dim_range(wav, dim="time")
|
||||
times = np.linspace(
|
||||
@@ -170,7 +561,7 @@ def resample_audio(
|
||||
samplerate=samplerate,
|
||||
),
|
||||
},
|
||||
attrs=wav.attrs,
|
||||
attrs={**wav.attrs, "samplerate": samplerate},
|
||||
)
|
||||
|
||||
|
||||
@@ -180,6 +571,33 @@ def resample_audio_poly(
|
||||
sr_new: int,
|
||||
axis: int = -1,
|
||||
) -> np.ndarray:
|
||||
"""Resample a numpy array using `scipy.signal.resample_poly`.
|
||||
|
||||
This method is often preferred for signals when the ratio of new
|
||||
to old sample rates can be expressed as a rational number. It uses
|
||||
polyphase filtering.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
array : np.ndarray
|
||||
The input array to resample.
|
||||
sr_orig : int
|
||||
The original sample rate in Hz.
|
||||
sr_new : int
|
||||
The target sample rate in Hz.
|
||||
axis : int, default=-1
|
||||
The axis of `array` along which to resample.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
The array resampled to the target sample rate.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If sample rates are not positive.
|
||||
"""
|
||||
gcd = np.gcd(sr_orig, sr_new)
|
||||
return resample_poly(
|
||||
array.values,
@@ -195,5 +613,97 @@ def resample_audio_fourier(
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    """Resample a numpy array using `scipy.signal.resample`.

    This method uses FFTs to resample the signal.

    Parameters
    ----------
    array : np.ndarray
        The input array to resample.
    sr_orig : int
        The original sample rate in Hz.
    sr_new : int
        The target sample rate in Hz.
    axis : int, default=-1
        The axis of `array` along which to resample.

    Returns
    -------
    np.ndarray
        The array resampled to the target sample rate along `axis`.

    Raises
    ------
    ValueError
        If the sample rates are not positive.
    """
    ratio = sr_new / sr_orig
    return resample(array, int(array.shape[axis] * ratio), axis=axis)  # type: ignore
    return resample(  # type: ignore
        array,
        int(array.shape[axis] * ratio),
        axis=axis,
    )
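For intuition on the polyphase path, the up/down factors come from the greatest common divisor of the two rates; the numbers below work through a 44.1 kHz recording being brought up to the 256 kHz target (example rates only).

```python
import numpy as np
from scipy.signal import resample_poly

sr_orig, sr_new = 44_100, 256_000
gcd = np.gcd(sr_orig, sr_new)              # 100
up, down = sr_new // gcd, sr_orig // gcd   # 2560 and 441

x = np.random.randn(sr_orig).astype(np.float32)  # 1 s at 44.1 kHz
y = resample_poly(x, up, down)                   # ~256000 samples at 256 kHz
print(len(y))
```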


def convert_to_xr(
    wav: np.ndarray,
    samplerate: int,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> xr.DataArray:
    """Convert a NumPy array to an xarray DataArray with time coordinates.

    Parameters
    ----------
    wav : np.ndarray
        The input waveform array. Expected to be 1D or 2D (with the first
        axis as the channel dimension).
    samplerate : int
        The sample rate in Hz.
    dtype : DTypeLike, default=np.float32
        Target data type for the xarray DataArray.

    Returns
    -------
    xr.DataArray
        The waveform as an xarray DataArray with time coordinates.

    Raises
    ------
    ValueError
        If the input array is not 1D or 2D, or if the sample rate is
        non-positive. If the input array is empty.
    """
    if wav.ndim == 2:
        wav = wav[0, :]

    if wav.ndim != 1:
        raise ValueError(
            "Audio must be 1D array or 2D channel where the first "
            "axis is the channel dimension"
        )

    if wav.size == 0:
        raise ValueError("Audio array is empty")

    if samplerate <= 0:
        raise ValueError("Sample rate must be positive")

    times = np.linspace(
        0,
        wav.shape[0] / samplerate,
        wav.shape[0],
        endpoint=False,
        dtype=dtype,
    )

    return xr.DataArray(
        data=wav.astype(dtype),
        dims=["time"],
        coords={
            "time": arrays.create_time_dim_from_array(
                times,
                samplerate=samplerate,
            ),
        },
        attrs={"samplerate": samplerate},
    )
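A quick sketch of the wrapper's behaviour on a stereo buffer: only channel 0 is kept, and the time coordinate encodes the sample rate (tiny toy values for readability).

```python
import numpy as np

stereo = np.zeros((2, 4), dtype=np.float32)
da = convert_to_xr(stereo, samplerate=4)

print(da.dims)                  # ('time',)
print(da.time.values.tolist())  # [0.0, 0.25, 0.5, 0.75]
```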

[file: batdetect2/preprocess/config.py (deleted)]
@@ -1,31 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.preprocess.audio import (
|
||||
AudioConfig,
|
||||
)
|
||||
from batdetect2.preprocess.spectrogram import (
|
||||
SpectrogramConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PreprocessingConfig",
|
||||
"load_preprocessing_config",
|
||||
]
|
||||
|
||||
|
||||
class PreprocessingConfig(BaseConfig):
|
||||
"""Configuration for preprocessing data."""
|
||||
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
|
||||
|
||||
|
||||
def load_preprocessing_config(
|
||||
path: PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> PreprocessingConfig:
|
||||
return load_config(path, schema=PreprocessingConfig, field=field)

[file: batdetect2/preprocess/spectrogram.py]
@@ -1,7 +1,26 @@
"""Computes spectrograms from audio waveforms with configurable parameters.

This module provides the functionality to convert preprocessed audio waveforms
(typically output from the `batdetect2.preprocessing.audio` module) into
spectrogram representations suitable for input into deep learning models like
BatDetect2.

It offers a configurable pipeline including:
1. Short-Time Fourier Transform (STFT) calculation to get magnitude.
2. Frequency axis cropping to a relevant range.
3. Per-Channel Energy Normalization (PCEN) (optional).
4. Amplitude scaling/representation (dB, power, or linear amplitude).
5. Simple spectral mean subtraction denoising (optional).
6. Resizing to target dimensions (optional).
7. Final peak normalization (optional).

Configuration is managed via the `SpectrogramConfig` class, allowing for
reproducible spectrogram generation consistent between training and inference.
The core computation is performed by `compute_spectrogram`.
"""

from typing import Literal, Optional, Union

import librosa
import librosa.core.spectrum
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
@@ -10,68 +29,302 @@ from soundevent import arrays, audio
from soundevent.arrays import operations as ops

from batdetect2.configs import BaseConfig
from batdetect2.preprocess.audio import convert_to_xr
from batdetect2.preprocess.types import SpectrogramBuilder

__all__ = [
    "STFTConfig",
    "FrequencyConfig",
    "SpecSizeConfig",
    "PcenConfig",
    "SpectrogramConfig",
    "ConfigurableSpectrogramBuilder",
    "build_spectrogram_builder",
    "compute_spectrogram",
    "get_spectrogram_resolution",
    "MIN_FREQ",
    "MAX_FREQ",
]


MIN_FREQ = 10_000
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""

MAX_FREQ = 120_000
"""Default maximum frequency (Hz) for spectrogram frequency cropping."""


class STFTConfig(BaseConfig):
    """Configuration for the Short-Time Fourier Transform (STFT).

    Attributes
    ----------
    window_duration : float, default=0.002
        Duration of the STFT window in seconds (e.g., 0.002 for 2 ms). Must be
        > 0. Determines frequency resolution (longer window = finer frequency
        resolution).
    window_overlap : float, default=0.75
        Fraction of overlap between consecutive STFT windows (e.g., 0.75
        for 75%). Must be >= 0 and < 1. Determines time resolution
        (higher overlap = finer time resolution).
    window_fn : str, default="hann"
        Name of the window function to apply before FFT calculation. Common
        options include "hann", "hamming", "blackman". See
        `scipy.signal.get_window`.
    """

    window_duration: float = Field(default=0.002, gt=0)
    window_overlap: float = Field(default=0.75, ge=0, lt=1)
    window_fn: str = "hann"
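The window settings only turn into an FFT size and hop length once a sample rate is fixed. A worked sketch with the defaults above and the 256 kHz target rate (the arithmetic mirrors what the `stft` helper further down does internally):

```python
samplerate = 256_000      # TARGET_SAMPLERATE_HZ from the audio module
window_duration = 0.002   # seconds
window_overlap = 0.75

nfft = int(window_duration * samplerate)   # 512 samples per window
noverlap = int(window_overlap * nfft)      # 384 overlapping samples
hop_length = nfft - noverlap               # 128 samples -> 0.5 ms per frame
freq_resolution = samplerate / nfft        # 500 Hz per frequency bin
```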
|
||||
|
||||
|
||||
class FrequencyConfig(BaseConfig):
|
||||
max_freq: int = Field(default=120_000, gt=0)
|
||||
min_freq: int = Field(default=10_000, gt=0)
|
||||
"""Configuration for frequency axis parameters.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
max_freq : int, default=120000
|
||||
Maximum frequency in Hz to retain in the spectrogram after STFT.
|
||||
Frequencies above this value will be cropped. Must be > 0.
|
||||
min_freq : int, default=10000
|
||||
Minimum frequency in Hz to retain in the spectrogram after STFT.
|
||||
Frequencies below this value will be cropped. Must be >= 0.
|
||||
"""
|
||||
|
||||
max_freq: int = Field(default=120_000, ge=0)
|
||||
min_freq: int = Field(default=10_000, ge=0)
|
||||
|
||||
|
||||
class SpecSizeConfig(BaseConfig):
|
||||
"""Configuration for the final size and shape of the spectrogram.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
height : int, default=128
|
||||
Target height of the spectrogram in pixels (frequency bins). The
|
||||
frequency axis will be resized (e.g., via interpolation) to match this
|
||||
height after frequency cropping and amplitude scaling. Must be > 0.
|
||||
resize_factor : float, optional
|
||||
Factor by which to resize the spectrogram along the time axis *after*
|
||||
STFT calculation. A value of 0.5 halves the number of time bins,
|
||||
2.0 doubles it. If None (default), no resizing along the time axis is
|
||||
performed relative to the STFT output width. Must be > 0 if provided.
|
||||
"""
|
||||
|
||||
height: int = 128
|
||||
"""Height of the spectrogram in pixels. This value determines the
|
||||
number of frequency bands and corresponds to the vertical dimension
|
||||
of the spectrogram."""
|
||||
|
||||
resize_factor: Optional[float] = 0.5
|
||||
"""Factor by which to resize the spectrogram along the time axis.
|
||||
A value of 0.5 reduces the temporal dimension by half, while a
|
||||
value of 2.0 doubles it. If None, no resizing is performed."""
|
||||
|
||||
|
||||
class LogScaleConfig(BaseConfig):
|
||||
name: Literal["log"] = "log"
|
||||
class PcenConfig(BaseConfig):
|
||||
"""Configuration for Per-Channel Energy Normalization (PCEN).
|
||||
|
||||
PCEN is an adaptive gain control method that can help emphasize transients
|
||||
and suppress stationary noise. Applied after STFT and frequency cropping,
|
||||
but before final amplitude scaling (dB, power, amplitude).
|
||||
|
||||
Attributes
|
||||
----------
|
||||
time_constant : float, default=0.4
|
||||
Time constant (in seconds) for the PCEN smoothing filter. Controls
|
||||
how quickly the normalization adapts to energy changes.
|
||||
gain : float, default=0.98
|
||||
Gain factor (alpha). Controls the adaptive gain component.
|
||||
bias : float, default=2.0
|
||||
Bias factor (delta). Added before the exponentiation.
|
||||
power : float, default=0.5
|
||||
Exponent (r). Controls the compression characteristic.
|
||||
"""
|
||||
|
||||
class PcenScaleConfig(BaseConfig):
|
||||
name: Literal["pcen"] = "pcen"
|
||||
time_constant: float = 0.4
|
||||
hop_length: int = 512
|
||||
gain: float = 0.98
|
||||
bias: float = 2
|
||||
power: float = 0.5
|
||||
|
||||
|
||||
class AmplitudeScaleConfig(BaseConfig):
|
||||
name: Literal["amplitude"] = "amplitude"
|
||||
|
||||
|
||||
Scales = Union[LogScaleConfig, PcenScaleConfig, AmplitudeScaleConfig]
|
||||
|
||||
|
||||
class SpectrogramConfig(BaseConfig):
    """Unified configuration for spectrogram generation pipeline.

    Aggregates settings for all steps involved in converting a preprocessed
    audio waveform into a final spectrogram representation suitable for model
    input.

    Attributes
    ----------
    stft : STFTConfig
        Configuration for the initial Short-Time Fourier Transform.
        Defaults to standard settings via `STFTConfig`.
    frequencies : FrequencyConfig
        Configuration for cropping the frequency range after STFT.
        Defaults to standard settings via `FrequencyConfig`.
    pcen : PcenConfig, optional
        Configuration for applying Per-Channel Energy Normalization (PCEN). If
        provided, PCEN is applied after frequency cropping. If None or omitted
        (default), PCEN is skipped.
    scale : Literal["dB", "amplitude", "power"], default="amplitude"
        Determines the final amplitude representation *after* optional PCEN.
        - "amplitude": Use linear magnitude values (output of STFT or PCEN).
        - "power": Use power values (magnitude squared).
        - "dB": Use logarithmic (decibel-like) scaling applied to the magnitude
          (or PCEN output if enabled). Calculated as `log1p(C * S)`.
    size : SpecSizeConfig, optional
        Configuration for resizing the spectrogram dimensions
        (frequency height, optional time width factor). Applied after PCEN and
        scaling. If None (default), no resizing is performed.
    spectral_mean_substraction : bool, default=True
        If True (default), applies simple spectral mean subtraction denoising
        *after* PCEN and amplitude scaling, but *before* resizing.
    peak_normalize : bool, default=False
        If True, applies a final peak normalization to the spectrogram *after*
        all other steps (including resizing), scaling the overall maximum value
        to 1.0. If False (default), this final normalization is skipped.
    """

    stft: STFTConfig = Field(default_factory=STFTConfig)
    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
    scale: Scales = Field(
        default_factory=PcenScaleConfig,
        discriminator="name",
    )
    pcen: Optional[PcenConfig] = Field(default_factory=PcenConfig)
    scale: Literal["dB", "amplitude", "power"] = "amplitude"
    size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
    denoise: bool = True
    max_scale: bool = False
    spectral_mean_substraction: bool = True
    peak_normalize: bool = False
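As an illustration of the new flat layout (the discriminated `scale` union above is the removed side of the diff), a hypothetical configuration that keeps PCEN, switches to a dB representation, and halves the time axis:

```python
spec_config = SpectrogramConfig(
    stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
    frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
    pcen=PcenConfig(time_constant=0.4),
    scale="dB",
    size=SpecSizeConfig(height=128, resize_factor=0.5),
    spectral_mean_substraction=True,
    peak_normalize=False,
)
builder = build_spectrogram_builder(spec_config)
```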
|
||||
|
||||
|
||||
class ConfigurableSpectrogramBuilder(SpectrogramBuilder):
|
||||
"""Implementation of `SpectrogramBuilder` driven by `SpectrogramConfig`.
|
||||
|
||||
This class computes spectrograms according to the parameters specified in a
|
||||
`SpectrogramConfig` object provided during initialization. It handles both
|
||||
numpy array and xarray DataArray inputs for the waveform.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: SpectrogramConfig,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> None:
|
||||
"""Initialize the ConfigurableSpectrogramBuilder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : SpectrogramConfig
|
||||
The configuration object specifying all spectrogram parameters.
|
||||
dtype : DTypeLike, default=np.float32
|
||||
The target NumPy data type for the computed spectrogram array.
|
||||
"""
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
wav: Union[np.ndarray, xr.DataArray],
|
||||
samplerate: Optional[int] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Generate a spectrogram from an audio waveform using the config.
|
||||
|
||||
Implements the `SpectrogramBuilder` protocol. If the input `wav` is
|
||||
a numpy array, `samplerate` must be provided; the array will be
|
||||
converted to an xarray DataArray internally. If `wav` is already an
|
||||
xarray DataArray with time coordinates, `samplerate` is ignored.
|
||||
Delegates the main computation to `compute_spectrogram`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wav : Union[np.ndarray, xr.DataArray]
|
||||
The input audio waveform.
|
||||
samplerate : int, optional
|
||||
The sample rate in Hz (required only if `wav` is np.ndarray).
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The computed spectrogram.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `wav` is np.ndarray and `samplerate` is None.
|
||||
"""
|
||||
if isinstance(wav, np.ndarray):
|
||||
if samplerate is None:
|
||||
raise ValueError(
|
||||
"Samplerate must be provided when passing a numpy array."
|
||||
)
|
||||
wav = convert_to_xr(
|
||||
wav,
|
||||
samplerate=samplerate,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
return compute_spectrogram(
|
||||
wav,
|
||||
config=self.config,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
|
||||
def build_spectrogram_builder(
|
||||
config: SpectrogramConfig,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> SpectrogramBuilder:
|
||||
"""Factory function to create a SpectrogramBuilder based on configuration.
|
||||
|
||||
Instantiates and returns a `ConfigurableSpectrogramBuilder` initialized
|
||||
with the provided `SpectrogramConfig`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : SpectrogramConfig
|
||||
The configuration object specifying spectrogram parameters.
|
||||
dtype : DTypeLike, default=np.float32
|
||||
The target NumPy data type for the computed spectrogram array.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SpectrogramBuilder
|
||||
An instance of `ConfigurableSpectrogramBuilder` ready to compute
|
||||
spectrograms according to the configuration.
|
||||
"""
|
||||
return ConfigurableSpectrogramBuilder(config=config, dtype=dtype)
|
||||
|
||||
|
||||
def compute_spectrogram(
|
||||
wav: xr.DataArray,
|
||||
config: Optional[SpectrogramConfig] = None,
|
||||
dtype: DTypeLike = np.float32,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> xr.DataArray:
|
||||
"""Compute a spectrogram from a waveform using specified configurations.
|
||||
|
||||
Applies a sequence of operations based on the `config`:
|
||||
1. Compute STFT magnitude (`stft`).
|
||||
2. Crop frequency axis (`crop_spectrogram_frequencies`).
|
||||
3. Apply PCEN if configured (`apply_pcen`).
|
||||
4. Apply final amplitude scaling (dB, power, amplitude)
|
||||
(`scale_spectrogram`).
|
||||
5. Apply spectral mean subtraction denoising if enabled.
|
||||
6. Resize dimensions if specified (`resize_spectrogram`).
|
||||
7. Apply final peak normalization if enabled.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wav : xr.DataArray
|
||||
Input audio waveform with a 'time' dimension and coordinates from
|
||||
which the sample rate can be inferred.
|
||||
config : SpectrogramConfig, optional
|
||||
Configuration object specifying spectrogram parameters. If None,
|
||||
default settings from `SpectrogramConfig` are used.
|
||||
dtype : DTypeLike, default=np.float32
|
||||
Target NumPy data type for the final spectrogram array.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The computed and processed spectrogram with 'time' and 'frequency'
|
||||
coordinates.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `wav` lacks necessary 'time' coordinates or dimensions.
|
||||
"""
|
||||
config = config or SpectrogramConfig()
|
||||
|
||||
spec = stft(
|
||||
@@ -79,7 +332,6 @@ def compute_spectrogram(
|
||||
window_duration=config.stft.window_duration,
|
||||
window_overlap=config.stft.window_overlap,
|
||||
window_fn=config.stft.window_fn,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
spec = crop_spectrogram_frequencies(
|
||||
@@ -88,10 +340,19 @@ def compute_spectrogram(
|
||||
max_freq=config.frequencies.max_freq,
|
||||
)
|
||||
|
||||
if config.pcen:
|
||||
spec = apply_pcen(
|
||||
spec,
|
||||
time_constant=config.pcen.time_constant,
|
||||
gain=config.pcen.gain,
|
||||
power=config.pcen.power,
|
||||
bias=config.pcen.bias,
|
||||
)
|
||||
|
||||
spec = scale_spectrogram(spec, scale=config.scale)
|
||||
|
||||
if config.denoise:
|
||||
spec = denoise_spectrogram(spec)
|
||||
if config.spectral_mean_substraction:
|
||||
spec = remove_spectral_mean(spec)
|
||||
|
||||
if config.size:
|
||||
spec = resize_spectrogram(
|
||||
@@ -100,7 +361,7 @@ def compute_spectrogram(
|
||||
resize_factor=config.size.resize_factor,
|
||||
)
|
||||
|
||||
if config.max_scale:
|
||||
if config.peak_normalize:
|
||||
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
|
||||
|
||||
return spec.astype(dtype)
|
||||
@@ -111,11 +372,32 @@ def crop_spectrogram_frequencies(
|
||||
min_freq: int = 10_000,
|
||||
max_freq: int = 120_000,
|
||||
) -> xr.DataArray:
|
||||
"""Crop the frequency axis of a spectrogram to a specified range.
|
||||
|
||||
Uses `soundevent.arrays.crop_dim` to select the frequency bins
|
||||
corresponding to the range [`min_freq`, `max_freq`].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : xr.DataArray
|
||||
Input spectrogram with 'frequency' dimension and coordinates.
|
||||
min_freq : int, default=MIN_FREQ
|
||||
Minimum frequency (Hz) to keep.
|
||||
max_freq : int, default=MAX_FREQ
|
||||
Maximum frequency (Hz) to keep.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Spectrogram cropped along the frequency axis. Preserves dtype.
|
||||
"""
|
||||
start_freq, end_freq = arrays.get_dim_range(spec, dim="frequency")
|
||||
|
||||
return arrays.crop_dim(
|
||||
spec,
|
||||
dim="frequency",
|
||||
start=min_freq,
|
||||
stop=max_freq,
|
||||
start=min_freq if start_freq < min_freq else None,
|
||||
stop=max_freq if end_freq > max_freq else None,
|
||||
).astype(spec.dtype)
|
||||
|
||||
|
||||
@@ -124,61 +406,61 @@ def stft(
|
||||
window_duration: float,
|
||||
window_overlap: float,
|
||||
window_fn: str = "hann",
|
||||
dtype: DTypeLike = np.float32,
|
||||
) -> xr.DataArray:
|
||||
start_time, end_time = arrays.get_dim_range(wave, dim="time")
|
||||
step = arrays.get_dim_step(wave, dim="time")
|
||||
sampling_rate = 1 / step
|
||||
"""Compute the Short-Time Fourier Transform (STFT) magnitude spectrogram.
|
||||
|
||||
nfft = int(window_duration * sampling_rate)
|
||||
noverlap = int(window_overlap * nfft)
|
||||
hop_len = nfft - noverlap
|
||||
hop_duration = hop_len / sampling_rate
|
||||
Calculates STFT parameters (N-FFT, hop length) based on the window
|
||||
duration, overlap, and waveform sample rate. Returns an xarray DataArray
|
||||
with correctly calculated 'time' and 'frequency' coordinates.
|
||||
|
||||
spec, _ = librosa.core.spectrum._spectrogram(
|
||||
y=wave.data.astype(dtype),
|
||||
power=1,
|
||||
n_fft=nfft,
|
||||
hop_length=nfft - noverlap,
|
||||
center=False,
|
||||
window=window_fn,
|
||||
)
|
||||
Parameters
|
||||
----------
|
||||
wave : xr.DataArray
|
||||
Input audio waveform with 'time' coordinates.
|
||||
window_duration : float
|
||||
Duration of the STFT window in seconds.
|
||||
window_overlap : float
|
||||
Fractional overlap between consecutive windows.
|
||||
window_fn : str, default="hann"
|
||||
Name of the window function (e.g., "hann", "hamming").
|
||||
|
||||
return xr.DataArray(
|
||||
data=spec.astype(dtype),
|
||||
dims=["frequency", "time"],
|
||||
coords={
|
||||
"frequency": arrays.create_frequency_dim_from_array(
|
||||
np.linspace(
|
||||
0,
|
||||
sampling_rate / 2,
|
||||
spec.shape[0],
|
||||
endpoint=False,
|
||||
dtype=dtype,
|
||||
),
|
||||
step=sampling_rate / nfft,
|
||||
),
|
||||
"time": arrays.create_time_dim_from_array(
|
||||
np.linspace(
|
||||
start_time,
|
||||
end_time - (window_duration - hop_duration),
|
||||
spec.shape[1],
|
||||
endpoint=False,
|
||||
dtype=dtype,
|
||||
),
|
||||
step=hop_duration,
|
||||
),
|
||||
},
|
||||
attrs={
|
||||
**wave.attrs,
|
||||
"original_samplerate": sampling_rate,
|
||||
"nfft": nfft,
|
||||
"noverlap": noverlap,
|
||||
},
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Magnitude spectrogram with 'time' and 'frequency' dimensions and
|
||||
coordinates. STFT parameters are stored in the `attrs`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If sample rate cannot be determined from `wave` coordinates.
|
||||
"""
|
||||
return audio.compute_spectrogram(
|
||||
wave,
|
||||
window_size=window_duration,
|
||||
hop_size=(1 - window_overlap) * window_duration,
|
||||
window_type=window_fn,
|
||||
scale="amplitude",
|
||||
sort_dims=False,
|
||||
)
|
||||
|
||||
|
||||
def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
|
||||
def remove_spectral_mean(spec: xr.DataArray) -> xr.DataArray:
|
||||
"""Apply simple spectral mean subtraction for denoising.
|
||||
|
||||
Subtracts the mean value of each frequency bin (calculated across time)
|
||||
from that bin, then clips negative values to zero.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : xr.DataArray
|
||||
Input spectrogram with 'time' and 'frequency' dimensions.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Denoised spectrogram with the same dimensions, coordinates, and dtype.
|
||||
"""
|
||||
return xr.DataArray(
|
||||
data=(spec - spec.mean("time")).clip(0),
|
||||
dims=spec.dims,
|
||||
@@ -189,34 +471,82 @@ def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
|
||||
|
||||
def scale_spectrogram(
|
||||
spec: xr.DataArray,
|
||||
scale: Scales,
|
||||
dtype: DTypeLike = np.float32,
|
||||
scale: Literal["dB", "power", "amplitude"],
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> xr.DataArray:
|
||||
if scale.name == "log":
|
||||
"""Apply final amplitude scaling/representation to the spectrogram.
|
||||
|
||||
Converts the input magnitude spectrogram based on the `scale` type:
|
||||
- "dB": Applies logarithmic scaling `log1p(C * S)`.
|
||||
- "power": Squares the magnitude values `S^2`.
|
||||
- "amplitude": Returns the input magnitude values `S` unchanged.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : xr.DataArray
|
||||
Input magnitude spectrogram (potentially after PCEN).
|
||||
scale : Literal["dB", "power", "amplitude"]
|
||||
The target amplitude representation.
|
||||
dtype : DTypeLike, default=np.float32
|
||||
Target data type for the output scaled spectrogram.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Spectrogram with the specified amplitude scaling applied.
|
||||
"""
|
||||
if scale == "dB":
|
||||
return scale_log(spec, dtype=dtype)
|
||||
|
||||
if scale.name == "pcen":
|
||||
return scale_pcen(
|
||||
spec,
|
||||
time_constant=scale.time_constant,
|
||||
hop_length=scale.hop_length,
|
||||
gain=scale.gain,
|
||||
power=scale.power,
|
||||
bias=scale.bias,
|
||||
)
|
||||
if scale == "power":
|
||||
return spec**2
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
def scale_pcen(
|
||||
def apply_pcen(
|
||||
spec: xr.DataArray,
|
||||
time_constant: float = 0.4,
|
||||
hop_length: int = 512,
|
||||
gain: float = 0.98,
|
||||
bias: float = 2,
|
||||
power: float = 0.5,
|
||||
) -> xr.DataArray:
|
||||
samplerate = spec.attrs["original_samplerate"]
|
||||
"""Apply Per-Channel Energy Normalization (PCEN) to a spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : xr.DataArray
|
||||
Input magnitude spectrogram with required attributes like
|
||||
'processing_original_samplerate'.
|
||||
time_constant : float, default=0.4
|
||||
PCEN time constant in seconds.
|
||||
gain : float, default=0.98
|
||||
Gain factor (alpha).
|
||||
bias : float, default=2.0
|
||||
Bias factor (delta).
|
||||
power : float, default=0.5
|
||||
Exponent (r).
|
||||
dtype : DTypeLike, default=np.float32
|
||||
Target data type for the output spectrogram.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
PCEN-scaled spectrogram.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- The input spectrogram magnitude `spec` is multiplied by `2**31` before
|
||||
being passed to `audio.pcen`. This suggests the underlying implementation
|
||||
might expect values in a range typical of 16-bit or 32-bit signed integers,
|
||||
even though the input here might be float. This scaling factor should be
|
||||
verified against the specific `soundevent.audio.pcen` implementation
|
||||
details.
|
||||
"""
|
||||
samplerate = spec.attrs["samplerate"]
|
||||
hop_size = spec.attrs["hop_size"]
|
||||
|
||||
hop_length = int(hop_size * samplerate)
|
||||
t_frames = time_constant * samplerate / (float(hop_length) * 10)
|
||||
smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
|
||||
return audio.pcen(
|
||||
@@ -230,8 +560,34 @@ def scale_pcen(
|
||||
|
||||
def scale_log(
|
||||
spec: xr.DataArray,
|
||||
dtype: DTypeLike = np.float32,
|
||||
dtype: DTypeLike = np.float32, # type: ignore
|
||||
) -> xr.DataArray:
|
||||
"""Apply logarithmic scaling to a magnitude spectrogram.
|
||||
|
||||
Calculates `log(1 + C * S)`, where S is the input magnitude spectrogram
|
||||
and C is a scaling factor derived from the original STFT parameters
|
||||
(sample rate, N-FFT, window function) stored in `spec.attrs`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : xr.DataArray
|
||||
Input magnitude spectrogram with required attributes like
|
||||
'processing_original_samplerate', 'processing_nfft'.
|
||||
dtype : DTypeLike, default=np.float32
|
||||
Target data type for the output spectrogram.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Log-scaled spectrogram.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If required attributes are missing from `spec.attrs`.
|
||||
ValueError
|
||||
If attributes are non-numeric or window function is invalid.
|
||||
"""
|
||||
samplerate = spec.attrs["original_samplerate"]
|
||||
nfft = spec.attrs["nfft"]
|
||||
log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
|
||||
@@ -248,6 +604,28 @@ def resize_spectrogram(
|
||||
height: int = 128,
|
||||
resize_factor: Optional[float] = 0.5,
|
||||
) -> xr.DataArray:
|
||||
"""Resize a spectrogram to target dimensions using interpolation.
|
||||
|
||||
Resizes the frequency axis to `height` bins and optionally resizes the
|
||||
time axis by `resize_factor`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec : xr.DataArray
|
||||
Input spectrogram with 'time' and 'frequency' dimensions.
|
||||
height : int, default=128
|
||||
Target number of frequency bins (vertical dimension).
|
||||
resize_factor : float, optional
|
||||
Factor to resize the time dimension. If 1.0 or None, time dimension
|
||||
is unchanged. If 0.5, time dimension is halved, etc.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
Resized spectrogram. Coordinates are typically adjusted by the
|
||||
underlying resize operation if implemented in `ops.resize`.
|
||||
The dtype is currently hardcoded to float32 by the `ops.resize` call.
|
||||
"""
|
||||
resize_factor = resize_factor or 1
|
||||
current_width = spec.sizes["time"]
|
||||
return ops.resize(
|
||||
@@ -258,61 +636,41 @@ def resize_spectrogram(
|
||||
)
|
||||
|
||||
|
||||
def adjust_spectrogram_width(
|
||||
spec: xr.DataArray,
|
||||
divide_factor: int = 32,
|
||||
time_period: float = 0.001,
|
||||
) -> xr.DataArray:
|
||||
time_width = spec.sizes["time"]
|
||||
|
||||
if time_width % divide_factor == 0:
|
||||
return spec
|
||||
|
||||
target_size = int(
|
||||
np.ceil(spec.sizes["time"] / divide_factor) * divide_factor
|
||||
)
|
||||
extra_duration = (target_size - time_width) * time_period
|
||||
_, stop = arrays.get_dim_range(spec, dim="time")
|
||||
resized = ops.extend_dim(
|
||||
spec,
|
||||
dim="time",
|
||||
stop=stop + extra_duration,
|
||||
)
|
||||
return resized
|
||||
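# Illustrative sketch (hypothetical values): with divide_factor = 32 and
# time_period = 0.001 s, a spectrogram that is 100 bins wide is extended to
# ceil(100 / 32) * 32 = 128 bins, i.e. padded by 28 * 0.001 = 0.028 s at the
# end of its time axis.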
|
||||
|
||||
def duration_to_spec_width(
|
||||
duration: float,
|
||||
samplerate: int,
|
||||
window_duration: float,
|
||||
window_overlap: float,
|
||||
) -> int:
|
||||
samples = int(duration * samplerate)
|
||||
fft_len = int(window_duration * samplerate)
|
||||
fft_overlap = int(window_overlap * fft_len)
|
||||
hop_len = fft_len - fft_overlap
|
||||
width = (samples - fft_len + hop_len) / hop_len
|
||||
return int(np.floor(width))
|
||||
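# Illustrative worked example (hypothetical values): duration = 1.0 s,
# samplerate = 256_000, window_duration = 0.002 s, window_overlap = 0.75 give
#   fft_len = 512, fft_overlap = 384, hop_len = 128
#   width   = floor((256_000 - 512 + 128) / 128) = 1997 frames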
|
||||
|
||||
def spec_width_to_samples(
|
||||
width: int,
|
||||
samplerate: int,
|
||||
window_duration: float,
|
||||
window_overlap: float,
|
||||
) -> int:
|
||||
fft_len = int(window_duration * samplerate)
|
||||
fft_overlap = int(window_overlap * fft_len)
|
||||
hop_len = fft_len - fft_overlap
|
||||
return width * hop_len + fft_len - hop_len
|
||||
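# Inverse of the example above (same hypothetical parameters): a width of
# 1997 frames maps back to 1997 * 128 + 512 - 128 = 256_000 samples, so the
# two helpers round-trip a 1 s clip at 256 kHz.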
|
||||
|
||||
def get_spectrogram_resolution(
|
||||
config: SpectrogramConfig,
|
||||
) -> tuple[float, float]:
|
||||
"""Calculate the approximate resolution of the final spectrogram.
|
||||
|
||||
Computes the width of each frequency bin (Hz/bin) and the duration
|
||||
of each time bin (seconds/bin) based on the configuration parameters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : SpectrogramConfig
|
||||
The spectrogram configuration object.
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[float, float]
|
||||
A tuple containing:
|
||||
- frequency_resolution (float): Approximate Hz per frequency bin.
|
||||
- time_resolution (float): Approximate seconds per time bin.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If required configuration fields (like `config.size`) are missing
|
||||
or invalid.
|
||||
"""
|
||||
max_freq = config.frequencies.max_freq
|
||||
min_freq = config.frequencies.min_freq
|
||||
|
||||
|
||||
if config.size is None:
|
||||
raise ValueError("Spectrogram size configuration is required.")
|
||||
|
||||
spec_height = config.size.height
|
||||
resize_factor = config.size.resize_factor or 1
|
||||
|
batdetect2/preprocess/types.py (new file, 381 lines)
@@ -0,0 +1,381 @@
|
||||
"""Defines common interfaces (Protocols) for preprocessing components.
|
||||
|
||||
This module centralizes the Protocol definitions used throughout the
|
||||
`batdetect2.preprocess` package. Protocols define expected methods and
|
||||
signatures, allowing for flexible and interchangeable implementations of
|
||||
components like audio loaders and spectrogram builders.
|
||||
|
||||
Using these protocols ensures that different parts of the preprocessing
|
||||
pipeline can interact consistently, regardless of the specific underlying
|
||||
implementation (e.g., different libraries or custom configurations).
|
||||
"""
|
||||
|
||||
from typing import Optional, Protocol, Union
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
|
||||
class AudioLoader(Protocol):
|
||||
"""Defines the interface for an audio loading and processing component.
|
||||
|
||||
An AudioLoader is responsible for retrieving audio data corresponding to
|
||||
different soundevent objects (files, Recordings, Clips) and applying a
|
||||
configured set of initial preprocessing steps. Adhering to this protocol
|
||||
allows for different loading strategies or implementations.
|
||||
"""
|
||||
|
||||
def load_file(
|
||||
self,
|
||||
path: data.PathLike,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Load and preprocess audio directly from a file path.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the audio file.
|
||||
audio_dir : PathLike, optional
|
||||
A directory prefix to prepend to the path if `path` is relative.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The loaded and preprocessed audio waveform as an xarray DataArray
|
||||
with time coordinates. Typically loads only the first channel.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the audio file cannot be found.
|
||||
Exception
|
||||
If the audio file cannot be loaded or processed.
|
||||
"""
|
||||
...
|
||||
|
||||
def load_recording(
|
||||
self,
|
||||
recording: data.Recording,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Load and preprocess the entire audio for a Recording object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recording : data.Recording
|
||||
The Recording object containing metadata about the audio file.
|
||||
audio_dir : PathLike, optional
|
||||
A directory where the audio file associated with the recording
|
||||
can be found, especially if the path in the recording is relative.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The loaded and preprocessed audio waveform. Typically loads only
|
||||
the first channel.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the audio file associated with the recording cannot be found.
|
||||
Exception
|
||||
If the audio file cannot be loaded or processed.
|
||||
"""
|
||||
...
|
||||
|
||||
def load_clip(
|
||||
self,
|
||||
clip: data.Clip,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Load and preprocess the audio segment defined by a Clip object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clip : data.Clip
|
||||
The Clip object specifying the recording and the start/end times
|
||||
of the segment to load.
|
||||
audio_dir : PathLike, optional
|
||||
A directory where the audio file associated with the clip's
|
||||
recording can be found.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The loaded and preprocessed audio waveform for the specified clip
|
||||
duration. Typically loads only the first channel.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the audio file associated with the clip cannot be found.
|
||||
Exception
|
||||
If the audio file cannot be loaded or processed.
|
||||
"""
|
||||
...
|
||||
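# Hedged usage sketch: any object implementing the three methods above
# satisfies this protocol. `ConfiguredAudioLoader` is a hypothetical
# implementation name, not defined in this module:
#
#     loader: AudioLoader = ConfiguredAudioLoader(config)
#     wav = loader.load_file("recording.wav", audio_dir="data/audio")
#     clip_wav = loader.load_clip(clip)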
|
||||
|
||||
class SpectrogramBuilder(Protocol):
|
||||
"""Defines the interface for a spectrogram generation component.
|
||||
|
||||
A SpectrogramBuilder takes a waveform (as numpy array or xarray DataArray)
|
||||
and produces a spectrogram (as an xarray DataArray) based on its internal
|
||||
configuration or implementation.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
wav: Union[np.ndarray, xr.DataArray],
|
||||
samplerate: Optional[int] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Generate a spectrogram from an audio waveform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wav : Union[np.ndarray, xr.DataArray]
|
||||
The input audio waveform. If a numpy array, `samplerate` must
|
||||
also be provided. If an xarray DataArray, it must have a 'time'
|
||||
coordinate from which the sample rate can be inferred.
|
||||
samplerate : int, optional
|
||||
The sample rate of the audio in Hz. Required if `wav` is a
|
||||
numpy array. If `wav` is an xarray DataArray, this parameter is
|
||||
ignored as the sample rate is derived from the coordinates.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The computed spectrogram as an xarray DataArray with 'time' and
|
||||
'frequency' coordinates.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `wav` is a numpy array and `samplerate` is not provided, or
|
||||
if `wav` is an xarray DataArray without a valid 'time' coordinate.
|
||||
"""
|
||||
...
|
||||
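# Hedged usage sketch: a builder is simply called on a waveform; `builder`
# stands for any object satisfying this protocol:
#
#     spec = builder(wav_array, samplerate=256_000)  # numpy input
#     spec = builder(wav_dataarray)                   # xarray input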
|
||||
|
||||
class PreprocessorProtocol(Protocol):
|
||||
"""Defines a high-level interface for the complete preprocessing pipeline.
|
||||
|
||||
A Preprocessor combines audio loading and spectrogram generation steps.
|
||||
It provides methods to go directly from source descriptions (file paths,
|
||||
Recording objects, Clip objects) to the final spectrogram representation
|
||||
needed by the model. It may also expose intermediate steps like audio
|
||||
loading or spectrogram computation from a waveform.
|
||||
"""
|
||||
|
||||
max_freq: float
|
||||
|
||||
min_freq: float
|
||||
|
||||
def preprocess_file(
|
||||
self,
|
||||
path: data.PathLike,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Load audio from a file and compute the final processed spectrogram.
|
||||
|
||||
Performs the full pipeline:
|
||||
|
||||
Load -> Preprocess Audio -> Compute Spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the audio file.
|
||||
audio_dir : PathLike, optional
|
||||
A directory prefix if `path` is relative.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The final processed spectrogram.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the audio file cannot be found.
|
||||
Exception
|
||||
If any step in the loading or preprocessing fails.
|
||||
"""
|
||||
...
|
||||
|
||||
def preprocess_recording(
|
||||
self,
|
||||
recording: data.Recording,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Load audio for a Recording and compute the processed spectrogram.
|
||||
|
||||
Performs the full pipeline for the entire duration of the recording.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recording : data.Recording
|
||||
The Recording object.
|
||||
audio_dir : PathLike, optional
|
||||
Directory containing the audio file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The final processed spectrogram.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the audio file cannot be found.
|
||||
Exception
|
||||
If any step in the loading or preprocessing fails.
|
||||
"""
|
||||
...
|
||||
|
||||
def preprocess_clip(
|
||||
self,
|
||||
clip: data.Clip,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Load audio for a Clip and compute the final processed spectrogram.
|
||||
|
||||
Performs the full pipeline for the specified clip segment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clip : data.Clip
|
||||
The Clip object defining the audio segment.
|
||||
audio_dir : PathLike, optional
|
||||
Directory containing the audio file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The final processed spectrogram.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the audio file cannot be found.
|
||||
Exception
|
||||
If any step in the loading or preprocessing fails.
|
||||
"""
|
||||
...
|
||||
|
||||
def load_file_audio(
|
||||
self,
|
||||
path: data.PathLike,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Load and preprocess *only* the audio waveform from a file path.
|
||||
|
||||
Performs the initial audio loading and waveform processing steps
|
||||
(like resampling, scaling), but stops *before* spectrogram generation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the audio file.
|
||||
audio_dir : PathLike, optional
|
||||
A directory prefix if `path` is relative.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The loaded and preprocessed audio waveform.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError, Exception
|
||||
If audio loading/preprocessing fails.
|
||||
"""
|
||||
...
|
||||
|
||||
def load_recording_audio(
|
||||
self,
|
||||
recording: data.Recording,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Load and preprocess *only* the audio waveform for a Recording.
|
||||
|
||||
Performs the initial audio loading and waveform processing steps
|
||||
for the entire recording duration.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recording : data.Recording
|
||||
The Recording object.
|
||||
audio_dir : PathLike, optional
|
||||
Directory containing the audio file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The loaded and preprocessed audio waveform.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError, Exception
|
||||
If audio loading/preprocessing fails.
|
||||
"""
|
||||
...
|
||||
|
||||
def load_clip_audio(
|
||||
self,
|
||||
clip: data.Clip,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.DataArray:
|
||||
"""Load and preprocess *only* the audio waveform for a Clip.
|
||||
|
||||
Performs the initial audio loading and waveform processing steps
|
||||
for the specified clip segment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clip : data.Clip
|
||||
The Clip object defining the segment.
|
||||
audio_dir : PathLike, optional
|
||||
Directory containing the audio file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The loaded and preprocessed audio waveform segment.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError, Exception
|
||||
If audio loading/preprocessing fails.
|
||||
"""
|
||||
...
|
||||
|
||||
def compute_spectrogram(
|
||||
self,
|
||||
wav: Union[xr.DataArray, np.ndarray],
|
||||
) -> xr.DataArray:
|
||||
"""Compute the spectrogram from a pre-loaded audio waveform.
|
||||
|
||||
Applies the spectrogram generation steps (STFT, scaling, etc.) defined
|
||||
by the `SpectrogramBuilder` component of the preprocessor to an
|
||||
already loaded (and potentially preprocessed) waveform.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
wav : Union[xr.DataArray, np.ndarray]
|
||||
The input audio waveform, either an xarray DataArray with a 'time'
coordinate or a raw numpy array.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
The computed spectrogram.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError, Exception
|
||||
If waveform input is invalid or spectrogram computation fails.
|
||||
"""
|
||||
...
|
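# Hedged usage sketch: a preprocessor goes from a source description to a
# model-ready spectrogram; `preprocessor` stands for any implementation of
# this protocol:
#
#     spec = preprocessor.preprocess_file("recording.wav")
#     wav = preprocessor.load_clip_audio(clip)
#     spec = preprocessor.compute_spectrogram(wav)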
batdetect2/targets/__init__.py (new file, 666 lines)
@@ -0,0 +1,666 @@
|
||||
"""Main entry point for the BatDetect2 Target Definition subsystem.
|
||||
|
||||
This package (`batdetect2.targets`) provides the tools and configurations
|
||||
necessary to define precisely what the BatDetect2 model should learn to detect,
|
||||
classify, and localize from audio data. It involves several conceptual steps,
|
||||
managed through configuration files and culminating in an executable pipeline:
|
||||
|
||||
1. **Terms (`.terms`)**: Defining vocabulary for annotation tags.
|
||||
2. **Filtering (`.filtering`)**: Selecting relevant sound event annotations.
|
||||
3. **Transformation (`.transform`)**: Modifying tags (standardization,
|
||||
derivation).
|
||||
4. **ROI Mapping (`.roi`)**: Defining how annotation geometry (ROIs) maps to
|
||||
target position and size representations, and back.
|
||||
5. **Class Definition (`.classes`)**: Mapping tags to target class names
|
||||
(encoding) and mapping predicted names back to tags (decoding).
|
||||
|
||||
This module exposes the key components for users to configure and utilize this
|
||||
target definition pipeline, primarily through the `TargetConfig` data structure
|
||||
and the `Targets` class (implementing `TargetProtocol`), which encapsulates the
|
||||
configured processing steps. The main way to create a functional `Targets`
|
||||
object is via the `build_targets` or `load_targets` functions.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.classes import (
|
||||
ClassesConfig,
|
||||
SoundEventDecoder,
|
||||
SoundEventEncoder,
|
||||
TargetClass,
|
||||
build_generic_class_tags,
|
||||
build_sound_event_decoder,
|
||||
build_sound_event_encoder,
|
||||
get_class_names_from_config,
|
||||
load_classes_config,
|
||||
load_decoder_from_config,
|
||||
load_encoder_from_config,
|
||||
)
|
||||
from batdetect2.targets.filtering import (
|
||||
FilterConfig,
|
||||
FilterRule,
|
||||
SoundEventFilter,
|
||||
build_sound_event_filter,
|
||||
load_filter_config,
|
||||
load_filter_from_config,
|
||||
)
|
||||
from batdetect2.targets.rois import (
|
||||
ROIConfig,
|
||||
ROITargetMapper,
|
||||
build_roi_mapper,
|
||||
)
|
||||
from batdetect2.targets.terms import (
|
||||
TagInfo,
|
||||
TermInfo,
|
||||
TermRegistry,
|
||||
call_type,
|
||||
get_tag_from_info,
|
||||
get_term_from_key,
|
||||
individual,
|
||||
register_term,
|
||||
term_registry,
|
||||
)
|
||||
from batdetect2.targets.transform import (
|
||||
DerivationRegistry,
|
||||
DeriveTagRule,
|
||||
MapValueRule,
|
||||
ReplaceRule,
|
||||
SoundEventTransformation,
|
||||
TransformConfig,
|
||||
build_transformation_from_config,
|
||||
derivation_registry,
|
||||
get_derivation,
|
||||
load_transformation_config,
|
||||
load_transformation_from_config,
|
||||
register_derivation,
|
||||
)
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"ClassesConfig",
|
||||
"DEFAULT_TARGET_CONFIG",
|
||||
"DeriveTagRule",
|
||||
"FilterConfig",
|
||||
"FilterRule",
|
||||
"MapValueRule",
|
||||
"ROIConfig",
|
||||
"ROITargetMapper",
|
||||
"ReplaceRule",
|
||||
"SoundEventDecoder",
|
||||
"SoundEventEncoder",
|
||||
"SoundEventFilter",
|
||||
"SoundEventTransformation",
|
||||
"TagInfo",
|
||||
"TargetClass",
|
||||
"TargetConfig",
|
||||
"TargetProtocol",
|
||||
"Targets",
|
||||
"TermInfo",
|
||||
"TransformConfig",
|
||||
"build_generic_class_tags",
|
||||
"build_roi_mapper",
|
||||
"build_sound_event_decoder",
|
||||
"build_sound_event_encoder",
|
||||
"build_sound_event_filter",
|
||||
"build_transformation_from_config",
|
||||
"call_type",
|
||||
"get_class_names_from_config",
|
||||
"get_derivation",
|
||||
"get_tag_from_info",
|
||||
"get_term_from_key",
|
||||
"individual",
|
||||
"load_classes_config",
|
||||
"load_decoder_from_config",
|
||||
"load_encoder_from_config",
|
||||
"load_filter_config",
|
||||
"load_filter_from_config",
|
||||
"load_target_config",
|
||||
"load_transformation_config",
|
||||
"load_transformation_from_config",
|
||||
"register_derivation",
|
||||
"register_term",
|
||||
]
|
||||
|
||||
|
||||
class TargetConfig(BaseConfig):
|
||||
"""Unified configuration for the entire target definition pipeline.
|
||||
|
||||
This model aggregates the configurations for semantic processing (filtering,
|
||||
transformation, class definition) and geometric processing (ROI mapping).
|
||||
It serves as the primary input for building a complete `Targets` object
|
||||
via `build_targets` or `load_targets`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
filtering : FilterConfig, optional
|
||||
Configuration for filtering sound event annotations based on tags.
|
||||
If None or omitted, no filtering is applied.
|
||||
transforms : TransformConfig, optional
|
||||
Configuration for transforming annotation tags
|
||||
(mapping, derivation, etc.). If None or omitted, no tag transformations
|
||||
are applied.
|
||||
classes : ClassesConfig
|
||||
Configuration defining the specific target classes, their tag matching
|
||||
rules for encoding, their representative tags for decoding
|
||||
(`output_tags`), and the definition of the generic class tags.
|
||||
This section is mandatory.
|
||||
roi : ROIConfig, optional
|
||||
Configuration defining how geometric ROIs (e.g., bounding boxes) are
|
||||
mapped to target representations (reference point, scaled size).
|
||||
Controls `position`, `time_scale`, `frequency_scale`. If None or
|
||||
omitted, default ROI mapping settings are used.
|
||||
"""
|
||||
|
||||
filtering: Optional[FilterConfig] = None
|
||||
transforms: Optional[TransformConfig] = None
|
||||
classes: ClassesConfig
|
||||
roi: Optional[ROIConfig] = None
|
||||
|
||||
|
||||
def load_target_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> TargetConfig:
|
||||
"""Load the unified target configuration from a file.
|
||||
|
||||
Reads a configuration file (typically YAML) and validates it against the
|
||||
`TargetConfig` schema, potentially extracting data from a nested field.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file.
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing the
|
||||
target configuration. If None, the entire file content is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
TargetConfig
|
||||
The loaded and validated unified target configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
yaml.YAMLError
|
||||
If the file content is not valid YAML.
|
||||
pydantic.ValidationError
|
||||
If the loaded configuration data does not conform to the
|
||||
`TargetConfig` schema (including validation within nested configs
|
||||
like `ClassesConfig`).
|
||||
KeyError, TypeError
|
||||
If `field` specifies an invalid path within the loaded data.
|
||||
"""
|
||||
return load_config(path=path, schema=TargetConfig, field=field)
|
||||
|
||||
|
||||
class Targets(TargetProtocol):
|
||||
"""Encapsulates the complete configured target definition pipeline.
|
||||
|
||||
This class implements the `TargetProtocol`, holding the configured
|
||||
functions for filtering, transforming, encoding (tags to class name),
|
||||
decoding (class name to tags), and mapping ROIs (geometry to position/size
|
||||
and back). It provides a high-level interface to apply these steps and
|
||||
access relevant metadata like class names and dimension names.
|
||||
|
||||
Instances are typically created using the `build_targets` factory function
|
||||
or the `load_targets` convenience loader.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
class_names : List[str]
|
||||
An ordered list of the unique names of the specific target classes
|
||||
defined in the configuration.
|
||||
generic_class_tags : List[data.Tag]
|
||||
A list of `soundevent.data.Tag` objects representing the configured
|
||||
generic class category (used when no specific class matches).
|
||||
dimension_names : List[str]
|
||||
The names of the size dimensions handled by the ROI mapper
|
||||
(e.g., ['width', 'height']).
|
||||
"""
|
||||
|
||||
class_names: List[str]
|
||||
generic_class_tags: List[data.Tag]
|
||||
dimension_names: List[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encode_fn: SoundEventEncoder,
|
||||
decode_fn: SoundEventDecoder,
|
||||
roi_mapper: ROITargetMapper,
|
||||
class_names: list[str],
|
||||
generic_class_tags: List[data.Tag],
|
||||
filter_fn: Optional[SoundEventFilter] = None,
|
||||
transform_fn: Optional[SoundEventTransformation] = None,
|
||||
):
|
||||
"""Initialize the Targets object.
|
||||
|
||||
Note: This constructor is typically called internally by the
|
||||
`build_targets` factory function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
encode_fn : SoundEventEncoder
|
||||
Configured function to encode annotations to class names.
|
||||
decode_fn : SoundEventDecoder
|
||||
Configured function to decode class names to tags.
|
||||
roi_mapper : ROITargetMapper
|
||||
Configured object for mapping geometry to/from position/size.
|
||||
class_names : list[str]
|
||||
Ordered list of specific target class names.
|
||||
generic_class_tags : List[data.Tag]
|
||||
List of tags representing the generic class.
|
||||
filter_fn : SoundEventFilter, optional
|
||||
Configured function to filter annotations. Defaults to None.
|
||||
transform_fn : SoundEventTransformation, optional
|
||||
Configured function to transform annotation tags. Defaults to None.
|
||||
"""
|
||||
self.class_names = class_names
|
||||
self.generic_class_tags = generic_class_tags
|
||||
self.dimension_names = roi_mapper.dimension_names
|
||||
|
||||
self._roi_mapper = roi_mapper
|
||||
self._filter_fn = filter_fn
|
||||
self._encode_fn = encode_fn
|
||||
self._decode_fn = decode_fn
|
||||
self._transform_fn = transform_fn
|
||||
|
||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||
"""Apply the configured filter to a sound event annotation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation to filter.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation should be kept (passes the filter),
|
||||
False otherwise. If no filter was configured, always returns True.
|
||||
"""
|
||||
if not self._filter_fn:
|
||||
return True
|
||||
return self._filter_fn(sound_event)
|
||||
|
||||
def encode(self, sound_event: data.SoundEventAnnotation) -> Optional[str]:
|
||||
"""Encode a sound event annotation to its target class name.
|
||||
|
||||
Applies the configured class definition rules (including priority)
|
||||
to determine the specific class name for the annotation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation to encode. Note: This should typically be called
|
||||
*after* applying any transformations via the `transform` method.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
The name of the matched target class, or None if the annotation
|
||||
does not match any specific class rule (i.e., it belongs to the
|
||||
generic category).
|
||||
"""
|
||||
return self._encode_fn(sound_event)
|
||||
|
||||
def decode(self, class_label: str) -> List[data.Tag]:
|
||||
"""Decode a predicted class name back into representative tags.
|
||||
|
||||
Uses the configured mapping (based on `TargetClass.output_tags` or
|
||||
`TargetClass.tags`) to convert a class name string into a list of
|
||||
`soundevent.data.Tag` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
class_label : str
|
||||
The class name to decode.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.Tag]
|
||||
The list of tags corresponding to the input class name.
|
||||
"""
|
||||
return self._decode_fn(class_label)
|
||||
|
||||
def transform(
|
||||
self, sound_event: data.SoundEventAnnotation
|
||||
) -> data.SoundEventAnnotation:
|
||||
"""Apply the configured tag transformations to an annotation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation whose tags should be transformed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.SoundEventAnnotation
|
||||
A new annotation object with the transformed tags. If no
|
||||
transformations were configured, the original annotation object is
|
||||
returned.
|
||||
"""
|
||||
if self._transform_fn:
|
||||
return self._transform_fn(sound_event)
|
||||
return sound_event
|
||||
|
||||
def get_position(
|
||||
self, sound_event: data.SoundEventAnnotation
|
||||
) -> tuple[float, float]:
|
||||
"""Extract the target reference position from the annotation's roi.
|
||||
|
||||
Delegates to the internal ROI mapper's `get_roi_position` method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation containing the geometry (ROI).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[float, float]
|
||||
The reference position `(time, frequency)`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the annotation lacks geometry.
|
||||
"""
|
||||
geom = sound_event.sound_event.geometry
|
||||
|
||||
if geom is None:
|
||||
raise ValueError(
|
||||
"Sound event has no geometry, cannot get its position."
|
||||
)
|
||||
|
||||
return self._roi_mapper.get_roi_position(geom)
|
||||
|
||||
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
|
||||
"""Calculate the target size dimensions from the annotation's geometry.
|
||||
|
||||
Delegates to the internal ROI mapper's `get_roi_size` method, which
|
||||
applies configured scaling factors.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation containing the geometry (ROI).
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
NumPy array containing the size dimensions, matching the
|
||||
order in `self.dimension_names` (e.g., `[width, height]`).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the annotation lacks geometry.
|
||||
"""
|
||||
geom = sound_event.sound_event.geometry
|
||||
|
||||
if geom is None:
|
||||
raise ValueError(
|
||||
"Sound event has no geometry, cannot get its size."
|
||||
)
|
||||
|
||||
return self._roi_mapper.get_roi_size(geom)
|
||||
|
||||
def recover_roi(
|
||||
self,
|
||||
pos: tuple[float, float],
|
||||
dims: np.ndarray,
|
||||
) -> data.Geometry:
|
||||
"""Recover an approximate geometric ROI from a position and dimensions.
|
||||
|
||||
Delegates to the internal ROI mapper's `recover_roi` method, which
|
||||
un-scales the dimensions and reconstructs the geometry (typically a
|
||||
`BoundingBox`).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pos : Tuple[float, float]
|
||||
The reference position `(time, frequency)`.
|
||||
dims : np.ndarray
|
||||
NumPy array with size dimensions (e.g., from model prediction),
|
||||
matching the order in `self.dimension_names`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.Geometry
|
||||
The reconstructed geometry (typically `BoundingBox`).
|
||||
"""
|
||||
return self._roi_mapper.recover_roi(pos, dims)
|
||||
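# Hedged round-trip sketch: the three methods above map an annotation's
# geometry to model targets and back (variable names are illustrative):
#
#     pos = targets.get_position(annotation)   # (time, frequency)
#     dims = targets.get_size(annotation)       # e.g. array([width, height])
#     roi = targets.recover_roi(pos, dims)      # approximate BoundingBox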
|
||||
|
||||
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||
filtering=FilterConfig(
|
||||
rules=[
|
||||
FilterRule(
|
||||
match_type="all",
|
||||
tags=[TagInfo(key="event", value="Echolocation")],
|
||||
),
|
||||
FilterRule(
|
||||
match_type="exclude",
|
||||
tags=[
|
||||
TagInfo(key="event", value="Feeding"),
|
||||
TagInfo(key="event", value="Unknown"),
|
||||
TagInfo(key="event", value="Not Bat"),
|
||||
],
|
||||
),
|
||||
]
|
||||
),
|
||||
classes=ClassesConfig(
|
||||
classes=[
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Myotis mystacinus")],
|
||||
name="myomys",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Myotis alcathoe")],
|
||||
name="myoalc",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Eptesicus serotinus")],
|
||||
name="eptser",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Pipistrellus nathusii")],
|
||||
name="pipnat",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Barbastellus barbastellus")],
|
||||
name="barbar",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Myotis nattereri")],
|
||||
name="myonat",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Myotis daubentonii")],
|
||||
name="myodau",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Myotis brandtii")],
|
||||
name="myobra",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Pipistrellus pipistrellus")],
|
||||
name="pippip",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Myotis bechsteinii")],
|
||||
name="myobec",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Pipistrellus pygmaeus")],
|
||||
name="pippyg",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Rhinolophus hipposideros")],
|
||||
name="rhihip",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Nyctalus leisleri")],
|
||||
name="nyclei",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
|
||||
name="rhifer",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Plecotus auritus")],
|
||||
name="pleaur",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Nyctalus noctula")],
|
||||
name="nycnoc",
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Plecotus austriacus")],
|
||||
name="pleaus",
|
||||
),
|
||||
],
|
||||
generic_class=[TagInfo(value="Bat")],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def build_targets(
|
||||
config: Optional[TargetConfig] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
) -> Targets:
|
||||
"""Build a Targets object from a loaded TargetConfig.
|
||||
|
||||
This factory function takes the unified `TargetConfig` and constructs all
|
||||
necessary functional components (filter, transform, encoder,
|
||||
decoder, ROI mapper) by calling their respective builder functions. It also
|
||||
extracts metadata (class names, generic tags, dimension names) to create
|
||||
and return a fully initialized `Targets` instance, ready to process
|
||||
annotations.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : TargetConfig
|
||||
The loaded and validated unified target configuration object.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to use for resolving term keys. Defaults
|
||||
to the global `batdetect2.targets.terms.term_registry`.
|
||||
derivation_registry : DerivationRegistry, optional
|
||||
The DerivationRegistry instance to use for resolving derivation
|
||||
function names. Defaults to the global
|
||||
`batdetect2.targets.transform.derivation_registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Targets
|
||||
An initialized `Targets` object ready for use.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If term keys or derivation function keys specified in the `config`
|
||||
are not found in their respective registries.
|
||||
ImportError, AttributeError, TypeError
|
||||
If dynamic import of a derivation function fails (when configured).
|
||||
"""
|
||||
config = config or DEFAULT_TARGET_CONFIG
|
||||
|
||||
filter_fn = (
|
||||
build_sound_event_filter(
|
||||
config.filtering,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
if config.filtering
|
||||
else None
|
||||
)
|
||||
encode_fn = build_sound_event_encoder(
|
||||
config.classes,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
decode_fn = build_sound_event_decoder(
|
||||
config.classes,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
transform_fn = (
|
||||
build_transformation_from_config(
|
||||
config.transforms,
|
||||
term_registry=term_registry,
|
||||
derivation_registry=derivation_registry,
|
||||
)
|
||||
if config.transforms
|
||||
else None
|
||||
)
|
||||
roi_mapper = build_roi_mapper(config.roi or ROIConfig())
|
||||
class_names = get_class_names_from_config(config.classes)
|
||||
generic_class_tags = build_generic_class_tags(
|
||||
config.classes,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
|
||||
return Targets(
|
||||
filter_fn=filter_fn,
|
||||
encode_fn=encode_fn,
|
||||
decode_fn=decode_fn,
|
||||
class_names=class_names,
|
||||
roi_mapper=roi_mapper,
|
||||
generic_class_tags=generic_class_tags,
|
||||
transform_fn=transform_fn,
|
||||
)
|
||||
|
||||
|
||||
def load_targets(
|
||||
config_path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
) -> Targets:
|
||||
"""Load a Targets object directly from a configuration file.
|
||||
|
||||
This convenience factory function loads the `TargetConfig` from the
specified file path and then calls `build_targets` to build
|
||||
the fully initialized `Targets` object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config_path : data.PathLike
|
||||
Path to the configuration file (e.g., YAML).
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing
|
||||
the target configuration. If None, the entire file content is used.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to use. Defaults to the global default.
|
||||
derivation_registry : DerivationRegistry, optional
|
||||
The DerivationRegistry instance to use. Defaults to the global
|
||||
default.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Targets
|
||||
An initialized `Targets` object ready for use.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
|
||||
TypeError
|
||||
Errors raised during file loading, validation, or extraction via
|
||||
`load_target_config`.
|
||||
KeyError, ImportError, AttributeError, TypeError
|
||||
Errors raised during the build process by `build_targets`
|
||||
(e.g., missing keys in registries, failed imports).
|
||||
"""
|
||||
config = load_target_config(
|
||||
config_path,
|
||||
field=field,
|
||||
)
|
||||
return build_targets(
|
||||
config,
|
||||
term_registry=term_registry,
|
||||
derivation_registry=derivation_registry,
|
||||
)
|
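# Hedged usage sketch: typical end-to-end use of this module (the YAML path
# and annotation variables are hypothetical):
#
#     targets = load_targets("targets.yaml")
#     for annotation in clip_annotation.sound_events:
#         if not targets.filter(annotation):
#             continue
#         annotation = targets.transform(annotation)
#         class_name = targets.encode(annotation)  # e.g. "pippip" or None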
batdetect2/targets/classes.py (new file, 618 lines)
@@ -0,0 +1,618 @@
|
||||
from collections import Counter
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.terms import (
|
||||
GENERIC_CLASS_KEY,
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
get_tag_from_info,
|
||||
term_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SoundEventEncoder",
|
||||
"SoundEventDecoder",
|
||||
"TargetClass",
|
||||
"ClassesConfig",
|
||||
"load_classes_config",
|
||||
"load_encoder_from_config",
|
||||
"load_decoder_from_config",
|
||||
"build_sound_event_encoder",
|
||||
"build_sound_event_decoder",
|
||||
"build_generic_class_tags",
|
||||
"get_class_names_from_config",
|
||||
"DEFAULT_SPECIES_LIST",
|
||||
]
|
||||
|
||||
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||
"""Type alias for a sound event class encoder function.
|
||||
|
||||
An encoder function takes a sound event annotation and returns the string name
|
||||
of the target class it belongs to, based on a predefined set of rules.
|
||||
If the annotation does not match any defined target class according to the
|
||||
rules, the function returns None.
|
||||
"""
|
||||
|
||||
|
||||
SoundEventDecoder = Callable[[str], List[data.Tag]]
|
||||
"""Type alias for a sound event class decoder function.
|
||||
|
||||
A decoder function takes a class name string (as predicted by the model or
|
||||
assigned during encoding) and returns a list of `soundevent.data.Tag` objects
|
||||
that represent that class according to the configuration. This is used to
|
||||
translate model outputs back into meaningful annotations.
|
||||
"""
|
||||
|
||||
DEFAULT_SPECIES_LIST = [
|
||||
"Barbastella barbastellus",
|
||||
"Eptesicus serotinus",
|
||||
"Myotis alcathoe",
|
||||
"Myotis bechsteinii",
|
||||
"Myotis brandtii",
|
||||
"Myotis daubentonii",
|
||||
"Myotis mystacinus",
|
||||
"Myotis nattereri",
|
||||
"Nyctalus leisleri",
|
||||
"Nyctalus noctula",
|
||||
"Pipistrellus nathusii",
|
||||
"Pipistrellus pipistrellus",
|
||||
"Pipistrellus pygmaeus",
|
||||
"Plecotus auritus",
|
||||
"Plecotus austriacus",
|
||||
"Rhinolophus ferrumequinum",
|
||||
"Rhinolophus hipposideros",
|
||||
]
|
||||
"""A default list of common bat species names found in the UK."""
|
||||
|
||||
|
||||
class TargetClass(BaseConfig):
|
||||
"""Defines criteria for encoding annotations and decoding predictions.
|
||||
|
||||
Each instance represents one potential output class for the classification
|
||||
model. It specifies:
|
||||
1. A unique `name` for the class.
|
||||
2. The tag conditions (`tags` and `match_type`) an annotation must meet to
|
||||
be assigned this class name during training data preparation (encoding).
|
||||
3. An optional, alternative set of tags (`output_tags`) to be used when
|
||||
converting a model's prediction of this class name back into annotation
|
||||
tags (decoding).
|
||||
|
||||
Attributes
|
||||
----------
|
||||
name : str
|
||||
The unique name assigned to this target class (e.g., 'pippip',
|
||||
'myodau', 'noise'). This name is used as the label during model
|
||||
training and is the expected output from the model's prediction.
|
||||
Should be unique across all TargetClass definitions in a configuration.
|
||||
tags : List[TagInfo]
|
||||
A list of one or more tags (defined using `TagInfo`) used to identify
|
||||
if an existing annotation belongs to this class during encoding (data
|
||||
preparation for training). The `match_type` attribute determines how
|
||||
these tags are evaluated.
|
||||
match_type : Literal["all", "any"], default="all"
|
||||
Determines how the `tags` list is evaluated during encoding:
|
||||
- "all": The annotation must have *all* the tags listed to match.
|
||||
- "any": The annotation must have *at least one* of the tags listed
|
||||
to match.
|
||||
output_tags: Optional[List[TagInfo]], default=None
|
||||
An optional list of tags (defined using `TagInfo`) to be assigned to a
|
||||
new annotation when the model predicts this class `name`. If `None`
|
||||
(default), the tags listed in the `tags` field will be used for
|
||||
decoding. If provided, this list overrides the `tags` field for the
|
||||
purpose of decoding predictions back into meaningful annotation tags.
|
||||
This allows, for example, training on broader categories but decoding
|
||||
to more specific representative tags.
|
||||
"""
|
||||
|
||||
name: str
|
||||
tags: List[TagInfo] = Field(min_length=1)
|
||||
match_type: Literal["all", "any"] = Field(default="all")
|
||||
output_tags: Optional[List[TagInfo]] = None
|
||||
|
||||
|
||||
def _get_default_classes() -> List[TargetClass]:
|
||||
"""Generate a list of default target classes.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[TargetClass]
|
||||
A list of TargetClass objects, one for each species in
|
||||
DEFAULT_SPECIES_LIST. The class names are simplified versions of the
|
||||
species names.
|
||||
"""
|
||||
return [
|
||||
TargetClass(
|
||||
name=_get_default_class_name(value),
|
||||
tags=[TagInfo(key=GENERIC_CLASS_KEY, value=value)],
|
||||
)
|
||||
for value in DEFAULT_SPECIES_LIST
|
||||
]
|
||||
|
||||
|
||||
def _get_default_class_name(species: str) -> str:
|
||||
"""Generate a default class name from a species name.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
species : str
|
||||
The species name (e.g., "Myotis daubentonii").
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
A simplified class name (e.g., "myodau").
|
||||
The genus and species names are converted to lowercase,
|
||||
the first three letters of each are taken, and concatenated.
|
||||
"""
|
||||
genus, species = species.strip().split(" ")
|
||||
return f"{genus.lower()[:3]}{species.lower()[:3]}"
|
||||
|
||||
|
||||
def _get_default_generic_class() -> List[TagInfo]:
|
||||
"""Generate the default list of TagInfo objects for the generic class.
|
||||
|
||||
Provides a default set of tags used to represent the generic "Bat" category
|
||||
when decoding predictions that didn't match a specific class.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[TagInfo]
|
||||
A list containing default TagInfo objects, typically representing
|
||||
`call_type: Echolocation` and `order: Chiroptera`.
|
||||
"""
|
||||
return [
|
||||
TagInfo(key="call_type", value="Echolocation"),
|
||||
TagInfo(key="order", value="Chiroptera"),
|
||||
]
|
||||
|
||||
|
||||
class ClassesConfig(BaseConfig):
|
||||
"""Configuration defining target classes and the generic fallback category.
|
||||
|
||||
Holds the ordered list of specific target class definitions (`TargetClass`)
|
||||
and defines the tags representing the generic category for sounds that pass
|
||||
filtering but do not match any specific class.
|
||||
|
||||
The order of `TargetClass` objects in the `classes` list defines the
|
||||
priority for classification during encoding. The system checks annotations
|
||||
against these definitions sequentially and assigns the name of the *first*
|
||||
matching class.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
classes : List[TargetClass]
|
||||
An ordered list of specific target class definitions. The order
|
||||
determines matching priority (first match wins). Defaults to a
|
||||
standard set of classes via `_get_default_classes`.
|
||||
generic_class : List[TagInfo]
|
||||
A list of tags defining the "generic" or "unclassified but relevant"
|
||||
category (e.g., representing a generic 'Bat' call that wasn't
|
||||
assigned to a specific species). These tags are typically assigned
|
||||
during decoding when a sound event was detected and passed filtering
|
||||
but did not match any specific class rule defined in the `classes` list.
|
||||
Defaults to a standard set of tags via `_get_default_generic_class`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If validation fails (e.g., non-unique class names in the `classes`
|
||||
list).
|
||||
|
||||
Notes
|
||||
-----
|
||||
- It is crucial that the `name` attribute of each `TargetClass` in the
|
||||
`classes` list is unique. This configuration includes a validator to
|
||||
enforce this uniqueness.
|
||||
- The `generic_class` tags provide a baseline identity for relevant sounds
|
||||
that don't fit into more specific defined categories.
|
||||
"""
|
||||
|
||||
classes: List[TargetClass] = Field(default_factory=_get_default_classes)
|
||||
|
||||
generic_class: List[TagInfo] = Field(
|
||||
default_factory=_get_default_generic_class
|
||||
)
|
||||
|
||||
@field_validator("classes")
|
||||
def check_unique_class_names(cls, v: List[TargetClass]):
|
||||
"""Ensure all defined class names are unique."""
|
||||
names = [c.name for c in v]
|
||||
|
||||
if len(names) != len(set(names)):
|
||||
name_counts = Counter(names)
|
||||
duplicates = [
|
||||
name for name, count in name_counts.items() if count > 1
|
||||
]
|
||||
raise ValueError(
|
||||
"Class names must be unique. Found duplicates: "
|
||||
f"{', '.join(duplicates)}"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
def _is_target_class(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
match_all: bool = True,
|
||||
) -> bool:
|
||||
"""Check if a sound event annotation matches a set of required tags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
tags : Set[data.Tag]
|
||||
A set of `soundevent.data.Tag` objects that define the class criteria.
|
||||
match_all : bool, default=True
|
||||
If True, checks if *all* `tags` are present in the
|
||||
annotation's tags (subset check). If False, checks if *at least one*
|
||||
of the `tags` is present (intersection check).
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation meets the tag criteria, False otherwise.
|
||||
"""
|
||||
annotation_tags = set(sound_event_annotation.tags)
|
||||
|
||||
if match_all:
|
||||
return tags <= annotation_tags
|
||||
|
||||
return bool(tags & annotation_tags)
|
||||
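# Illustrative example: with tags = {species: Myotis daubentonii} and
# match_all=True the annotation must carry every tag in the set (subset
# check); with match_all=False sharing any single tag from the set is enough
# (intersection check).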
|
||||
|
||||
def get_class_names_from_config(config: ClassesConfig) -> List[str]:
|
||||
"""Extract the list of class names from a ClassesConfig object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : ClassesConfig
|
||||
The loaded classes configuration object.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
An ordered list of unique class names defined in the configuration.
|
||||
"""
|
||||
return [class_info.name for class_info in config.classes]
|
||||
|
||||
|
||||
def _encode_with_multiple_classifiers(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
classifiers: List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]],
|
||||
) -> Optional[str]:
|
||||
"""Encode an annotation by checking against a list of classifiers.
|
||||
|
||||
Internal helper function used by the `SoundEventEncoder`. It iterates
|
||||
through the provided list of (class_name, classifier_function) pairs.
|
||||
Returns the name associated with the first classifier function that
|
||||
returns True for the given annotation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to encode.
|
||||
classifiers : List[Tuple[str, Callable[[data.SoundEventAnnotation], bool]]]
|
||||
An ordered list where each tuple contains a class name and a function
|
||||
that returns True if the annotation matches that class. The order
|
||||
determines priority.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
The name of the first matching class, or None if no classifier matches.
|
||||
"""
|
||||
for class_name, classifier in classifiers:
|
||||
if classifier(sound_event_annotation):
|
||||
return class_name
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def build_sound_event_encoder(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> SoundEventEncoder:
|
||||
"""Build a sound event encoder function from the classes configuration.
|
||||
|
||||
The returned encoder function iterates through the class definitions in the
|
||||
order specified in the config. It assigns an annotation the name of the
|
||||
first class definition it matches.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : ClassesConfig
|
||||
The loaded and validated classes configuration object.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance used to look up term keys specified in the
|
||||
`TagInfo` objects within the configuration. Defaults to the global
|
||||
`batdetect2.targets.terms.registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventEncoder
|
||||
A callable function that takes a `SoundEventAnnotation` and returns
|
||||
an optional string representing the matched class name, or None if no
|
||||
class matches.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a term key specified in the configuration is not found in the
|
||||
provided `term_registry`.
|
||||
"""
|
||||
binary_classifiers = [
|
||||
(
|
||||
class_info.name,
|
||||
partial(
|
||||
_is_target_class,
|
||||
tags={
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in class_info.tags
|
||||
},
|
||||
match_all=class_info.match_type == "all",
|
||||
),
|
||||
)
|
||||
for class_info in config.classes
|
||||
]
|
||||
|
||||
return partial(
|
||||
_encode_with_multiple_classifiers,
|
||||
classifiers=binary_classifiers,
|
||||
)
|
||||
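# Hedged usage sketch (the config and annotation objects are assumed to
# already exist):
#
#     encoder = build_sound_event_encoder(classes_config)
#     class_name = encoder(sound_event_annotation)  # e.g. "myodau" or None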
|
||||
|
||||
def _decode_class(
|
||||
name: str,
|
||||
mapping: Dict[str, List[data.Tag]],
|
||||
raise_on_error: bool = True,
|
||||
) -> List[data.Tag]:
|
||||
"""Decode a class name into a list of representative tags using a mapping.
|
||||
|
||||
Internal helper function used by the `SoundEventDecoder`. Looks up the
|
||||
provided class `name` in the `mapping` dictionary.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The class name to decode.
|
||||
mapping : Dict[str, List[data.Tag]]
|
||||
A dictionary mapping class names to lists of `soundevent.data.Tag`
|
||||
objects.
|
||||
raise_on_error : bool, default=True
|
||||
If True, raises a ValueError if the `name` is not found in the
|
||||
`mapping`. If False, returns an empty list if the `name` is not found.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.Tag]
|
||||
The list of tags associated with the class name, or an empty list if
|
||||
not found and `raise_on_error` is False.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `name` is not found in `mapping` and `raise_on_error` is True.
|
||||
"""
|
||||
if name not in mapping and raise_on_error:
|
||||
raise ValueError(f"Class {name} not found in mapping.")
|
||||
|
||||
if name not in mapping:
|
||||
return []
|
||||
|
||||
return mapping[name]
|
||||
|
||||
|
||||
def build_sound_event_decoder(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
raise_on_unmapped: bool = False,
|
||||
) -> SoundEventDecoder:
|
||||
"""Build a sound event decoder function from the classes configuration.
|
||||
|
||||
Creates a callable `SoundEventDecoder` that maps a class name string
|
||||
back to a list of representative `soundevent.data.Tag` objects based on
|
||||
the `ClassesConfig`. It uses the `output_tags` field if provided in a
|
||||
`TargetClass`, otherwise falls back to the `tags` field.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : ClassesConfig
|
||||
The loaded and validated classes configuration object.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance used to look up term keys. Defaults to the
|
||||
global `batdetect2.targets.terms.registry`.
|
||||
raise_on_unmapped : bool, default=False
|
||||
If True, the returned decoder function will raise a ValueError if asked
|
||||
to decode a class name that is not in the configuration. If False, it
|
||||
will return an empty list for unmapped names.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventDecoder
|
||||
A callable function that takes a class name string and returns a list
|
||||
of `soundevent.data.Tag` objects.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a term key specified in the configuration (`output_tags`, `tags`, or
|
||||
`generic_class`) is not found in the provided `term_registry`.
|
||||
"""
|
||||
mapping = {}
|
||||
for class_info in config.classes:
|
||||
tags_to_use = (
|
||||
class_info.output_tags
|
||||
if class_info.output_tags is not None
|
||||
else class_info.tags
|
||||
)
|
||||
mapping[class_info.name] = [
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in tags_to_use
|
||||
]
|
||||
|
||||
return partial(
|
||||
_decode_class,
|
||||
mapping=mapping,
|
||||
raise_on_error=raise_on_unmapped,
|
||||
)
|
||||
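# Hedged usage sketch: decoding reverses the encoder above; the exact tags
# returned depend on each class's configured `output_tags`/`tags`:
#
#     decoder = build_sound_event_decoder(classes_config)
#     tags = decoder("myodau")  # -> list of soundevent.data.Tag objects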
|
||||
|
||||
def build_generic_class_tags(
|
||||
config: ClassesConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> List[data.Tag]:
|
||||
"""Extract and build the list of tags for the generic class from config.
|
||||
|
||||
Converts the list of `TagInfo` objects defined in `config.generic_class`
|
||||
into a list of `soundevent.data.Tag` objects using the term registry.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : ClassesConfig
|
||||
The loaded classes configuration object.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance for term lookups. Defaults to the global
|
||||
`batdetect2.targets.terms.registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.Tag]
|
||||
The list of fully constructed tags representing the generic class.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a term key specified in `config.generic_class` is not found in the
|
||||
provided `term_registry`.
|
||||
"""
|
||||
return [
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in config.generic_class
|
||||
]
|
||||
|
||||
|
||||
def load_classes_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> ClassesConfig:
|
||||
"""Load the target classes configuration from a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (YAML).
|
||||
field : str, optional
|
||||
If the classes configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ClassesConfig
|
||||
The loaded and validated classes configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the ClassesConfig schema
|
||||
or if class names are not unique.
|
||||
"""
|
||||
return load_config(path, schema=ClassesConfig, field=field)
|
||||
|
||||
|
||||
def load_encoder_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> SoundEventEncoder:
|
||||
"""Load a class encoder function directly from a configuration file.
|
||||
|
||||
This is a convenience function that combines loading the `ClassesConfig`
|
||||
from a file and building the final `SoundEventEncoder` function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (e.g., YAML).
|
||||
field : str, optional
|
||||
If the classes configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance used for term lookups. Defaults to the
|
||||
global `batdetect2.targets.terms.registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventEncoder
|
||||
The final encoder function ready to classify annotations.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the ClassesConfig schema
|
||||
or if class names are not unique.
|
||||
KeyError
|
||||
If a term key specified in the configuration is not found in the
|
||||
provided `term_registry` during the build process.
|
||||
"""
|
||||
config = load_classes_config(path, field=field)
|
||||
return build_sound_event_encoder(config, term_registry=term_registry)
|
||||
|
||||
|
||||
def load_decoder_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
raise_on_unmapped: bool = False,
|
||||
) -> SoundEventDecoder:
|
||||
"""Load a class decoder function directly from a configuration file.
|
||||
|
||||
This is a convenience function that combines loading the `ClassesConfig`
|
||||
from a file and building the final `SoundEventDecoder` function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (e.g., YAML).
|
||||
field : str, optional
|
||||
If the classes configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance used for term lookups. Defaults to the
|
||||
global `batdetect2.targets.terms.registry`.
|
||||
raise_on_unmapped : bool, default=False
|
||||
If True, the returned decoder function will raise a ValueError if asked
|
||||
to decode a class name that is not in the configuration. If False, it
|
||||
will return an empty list for unmapped names.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventDecoder
|
||||
The final decoder function ready to convert class names back into tags.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the ClassesConfig schema
|
||||
or if class names are not unique.
|
||||
KeyError
|
||||
If a term key specified in the configuration is not found in the
|
||||
provided `term_registry` during the build process.
|
||||
"""
|
||||
config = load_classes_config(path, field=field)
|
||||
return build_sound_event_decoder(
|
||||
config,
|
||||
term_registry=term_registry,
|
||||
raise_on_unmapped=raise_on_unmapped,
|
||||
)
|
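# --- Editorial sketch (illustration only, not part of the original diff) ---
# A hypothetical YAML layout for the classes configuration consumed by the
# loaders above. The file name, term keys ("species", "class") and species
# values are assumptions; the top-level keys mirror the `classes` and
# `generic_class` attributes used by the encoder/decoder builders, and any
# additional fields of `ClassesConfig` are omitted here.
_EXAMPLE_CLASSES_YAML = """
classes:
  - name: pippip
    tags:
      - key: species
        value: Pipistrellus pipistrellus
  - name: myomyo
    tags:
      - key: species
        value: Myotis myotis
generic_class:
  - key: class
    value: Bat
"""
# encoder = load_encoder_from_config("targets.yaml")
# decoder = load_decoder_from_config("targets.yaml")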
315
batdetect2/targets/filtering.py
Normal file
@ -0,0 +1,315 @@
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Callable, List, Literal, Optional, Set
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.terms import (
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
get_tag_from_info,
|
||||
term_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FilterConfig",
|
||||
"FilterRule",
|
||||
"SoundEventFilter",
|
||||
"build_sound_event_filter",
|
||||
"build_filter_from_rule",
|
||||
"load_filter_config",
|
||||
"load_filter_from_config",
|
||||
]
|
||||
|
||||
|
||||
SoundEventFilter = Callable[[data.SoundEventAnnotation], bool]
|
||||
"""Type alias for a filter function.
|
||||
|
||||
A filter function accepts a soundevent.data.SoundEventAnnotation object
|
||||
and returns True if the annotation should be kept based on the filter's
|
||||
criteria, or False if it should be discarded.
|
||||
"""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FilterRule(BaseConfig):
|
||||
"""Defines a single rule for filtering sound event annotations.
|
||||
|
||||
Based on the `match_type`, this rule checks if the tags associated with a
|
||||
sound event annotation meet certain criteria relative to the `tags` list
|
||||
defined in this rule.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
match_type : Literal["any", "all", "exclude", "equal"]
|
||||
Determines how the `tags` list is used:
|
||||
- "any": Pass if the annotation has at least one tag from the list.
|
||||
- "all": Pass if the annotation has all tags from the list (it can
|
||||
have others too).
|
||||
- "exclude": Pass if the annotation has none of the tags from the list.
|
||||
- "equal": Pass if the annotation's tags are exactly the same set as
|
||||
provided in the list.
|
||||
tags : List[TagInfo]
|
||||
A list of tags (defined using TagInfo for configuration) that this
|
||||
rule operates on.
|
||||
"""
|
||||
|
||||
match_type: Literal["any", "all", "exclude", "equal"]
|
||||
tags: List[TagInfo]
|
||||
|
||||
|
||||
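# --- Editorial sketch (illustration only, not part of the original diff) ---
# Hypothetical YAML for a filtering section that keeps echolocation calls and
# drops noise. The "event" term key comes from the default registry defined
# in terms.py; the tag values are assumptions about the annotation data.
_EXAMPLE_FILTER_YAML = """
rules:
  - match_type: any
    tags:
      - key: event
        value: Echolocation
  - match_type: exclude
    tags:
      - key: event
        value: Noise
"""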
def has_any_tag(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
) -> bool:
|
||||
"""Check if the annotation has at least one of the specified tags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
tags : Set[data.Tag]
|
||||
The set of tags to look for.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation has one or more tags from the specified set,
|
||||
False otherwise.
|
||||
"""
|
||||
sound_event_tags = set(sound_event_annotation.tags)
|
||||
return bool(tags & sound_event_tags)
|
||||
|
||||
|
||||
def contains_tags(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
) -> bool:
|
||||
"""Check if the annotation contains all of the specified tags.
|
||||
|
||||
The annotation may have additional tags beyond those specified.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
tags : Set[data.Tag]
|
||||
The set of tags that must all be present in the annotation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation's tags are a superset of the specified tags,
|
||||
False otherwise.
|
||||
"""
|
||||
sound_event_tags = set(sound_event_annotation.tags)
|
||||
# Non-strict subset check: an exact match also counts as having "all" tags.
return tags <= sound_event_tags
|
||||
|
||||
|
||||
def does_not_have_tags(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
) -> bool:
|
||||
"""Check if the annotation has none of the specified tags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
tags : Set[data.Tag]
|
||||
The set of tags that must *not* be present in the annotation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation has zero tags in common with the specified set,
|
||||
False otherwise.
|
||||
"""
|
||||
return not has_any_tag(sound_event_annotation, tags)
|
||||
|
||||
|
||||
def equal_tags(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
) -> bool:
|
||||
"""Check if the annotation's tags are exactly equal to the specified set.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
tags : Set[data.Tag]
|
||||
The exact set of tags the annotation must have.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation's tags set is identical to the specified set,
|
||||
False otherwise.
|
||||
"""
|
||||
sound_event_tags = set(sound_event_annotation.tags)
|
||||
return tags == sound_event_tags
|
||||
|
||||
|
||||
def build_filter_from_rule(
|
||||
rule: FilterRule,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> SoundEventFilter:
|
||||
"""Creates a callable filter function from a single FilterRule.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rule : FilterRule
|
||||
The filter rule configuration object.
term_registry : TermRegistry, optional
The TermRegistry instance used to resolve the rule's tag term keys.
Defaults to the module-level default registry.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventFilter
|
||||
A function that takes a SoundEventAnnotation and returns True if it
|
||||
passes the rule, False otherwise.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the rule contains an invalid `match_type`.
|
||||
"""
|
||||
tag_set = {
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in rule.tags
|
||||
}
|
||||
|
||||
if rule.match_type == "any":
|
||||
return partial(has_any_tag, tags=tag_set)
|
||||
|
||||
if rule.match_type == "all":
|
||||
return partial(contains_tags, tags=tag_set)
|
||||
|
||||
if rule.match_type == "exclude":
|
||||
return partial(does_not_have_tags, tags=tag_set)
|
||||
|
||||
if rule.match_type == "equal":
|
||||
return partial(equal_tags, tags=tag_set)
|
||||
|
||||
raise ValueError(
|
||||
f"Invalid match type {rule.match_type}. Valid types "
|
||||
"are: 'any', 'all', 'exclude' and 'equal'"
|
||||
)
|
||||
|
||||
|
||||
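# --- Editorial sketch (illustration only, not part of the original diff) ---
# Building a single-rule filter programmatically. The tag value is an
# assumption; `annotation` stands for any SoundEventAnnotation.
def _example_single_rule_filter():
    rule = FilterRule(
        match_type="exclude",
        tags=[TagInfo(key="event", value="Noise")],
    )
    noise_filter = build_filter_from_rule(rule)
    # noise_filter(annotation) is False when the annotation carries the
    # event=Noise tag and True otherwise.
    return noise_filter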
def _passes_all_filters(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
filters: List[SoundEventFilter],
|
||||
) -> bool:
|
||||
"""Check if the annotation passes all provided filters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to check.
|
||||
filters : List[SoundEventFilter]
|
||||
A list of filter functions to apply.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation passes all filters, False otherwise.
|
||||
"""
|
||||
for filter_fn in filters:
|
||||
if not filter_fn(sound_event_annotation):
|
||||
logger.debug(
|
||||
f"Sound event annotation {sound_event_annotation.uuid} "
|
||||
f"excluded due to rule {filter_fn}",
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class FilterConfig(BaseConfig):
|
||||
"""Configuration model for defining a list of filter rules.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
rules : List[FilterRule]
|
||||
A list of FilterRule objects. An annotation must pass all rules in
|
||||
this list to be considered valid by the filter built from this config.
|
||||
"""
|
||||
|
||||
rules: List[FilterRule] = Field(default_factory=list)
|
||||
|
||||
|
||||
def build_sound_event_filter(
|
||||
config: FilterConfig,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> SoundEventFilter:
|
||||
"""Builds a merged filter function from a FilterConfig object.
|
||||
|
||||
Creates individual filter functions for each rule in the configuration
|
||||
and merges them using AND logic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : FilterConfig
|
||||
The configuration object containing the list of filter rules.
term_registry : TermRegistry, optional
The TermRegistry instance used to resolve tag term keys when building
each rule. Defaults to the module-level default registry.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventFilter
|
||||
A single callable filter function that applies all defined rules.
|
||||
"""
|
||||
filters = [
|
||||
build_filter_from_rule(rule, term_registry=term_registry)
|
||||
for rule in config.rules
|
||||
]
|
||||
return partial(_passes_all_filters, filters=filters)
|
||||
|
||||
|
||||
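# --- Editorial sketch (illustration only, not part of the original diff) ---
# The merged filter applies every rule with AND logic: an annotation is kept
# only if it passes each of them. Term key and values are assumptions.
def _example_combined_filter():
    config = FilterConfig(
        rules=[
            FilterRule(
                match_type="any",
                tags=[TagInfo(key="event", value="Echolocation")],
            ),
            FilterRule(
                match_type="exclude",
                tags=[TagInfo(key="event", value="Noise")],
            ),
        ]
    )
    keep = build_sound_event_filter(config)
    # keep(annotation) is True only if the annotation has the Echolocation
    # tag AND does not have the Noise tag.
    return keep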
def load_filter_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
) -> FilterConfig:
|
||||
"""Loads the filter configuration from a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (YAML).
|
||||
field : Optional[str], optional
|
||||
If the filter configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
FilterConfig
|
||||
The loaded and validated filter configuration object.
|
||||
"""
|
||||
return load_config(path, schema=FilterConfig, field=field)
|
||||
|
||||
|
||||
def load_filter_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> SoundEventFilter:
|
||||
"""Loads filter configuration from a file and builds the filter function.
|
||||
|
||||
This is a convenience function that combines loading the configuration
|
||||
and building the final callable filter function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (YAML).
|
||||
field : Optional[str], optional
|
||||
If the filter configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventFilter
|
||||
The final merged filter function ready to be used.
|
||||
"""
|
||||
config = load_filter_config(path=path, field=field)
|
||||
return build_sound_event_filter(config, term_registry=term_registry)
|
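# --- Editorial sketch (illustration only, not part of the original diff) ---
# End-to-end usage, assuming the filter rules live under a "filtering" key in
# a hypothetical "targets.yaml" and `annotations` is any iterable of
# SoundEventAnnotation objects.
def _example_apply_filter(annotations):
    keep = load_filter_from_config("targets.yaml", field="filtering")
    return [annotation for annotation in annotations if keep(annotation)]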
462
batdetect2/targets/rois.py
Normal file
@ -0,0 +1,462 @@
|
||||
"""Handles mapping between geometric ROIs and target representations.
|
||||
|
||||
This module defines the interface and provides implementation for converting
|
||||
a sound event's Region of Interest (ROI), typically represented by a
|
||||
`soundevent.data.Geometry` object like a `BoundingBox`, into a format
|
||||
suitable for use as a machine learning target. This usually involves:
|
||||
|
||||
1. Extracting a single reference point (time, frequency) from the geometry.
|
||||
2. Calculating relevant size dimensions (e.g., duration/width,
|
||||
bandwidth/height) and applying scaling factors.
|
||||
|
||||
It also provides the inverse operation: recovering an approximate geometric ROI
|
||||
(like a `BoundingBox`) from a predicted reference point and predicted size
|
||||
dimensions.
|
||||
|
||||
This logic is encapsulated within components adhering to the `ROITargetMapper`
|
||||
protocol. Configuration for this mapping (e.g., which reference point to use,
|
||||
scaling factors) is managed by the `ROIConfig`. This module separates the
|
||||
*geometric* aspect of target definition from the *semantic* classification
|
||||
handled in `batdetect2.targets.classes`.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Protocol, Tuple
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data, geometry
|
||||
from soundevent.geometry.operations import Positions
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
|
||||
__all__ = [
|
||||
"ROITargetMapper",
|
||||
"ROIConfig",
|
||||
"BBoxEncoder",
|
||||
"build_roi_mapper",
|
||||
"load_roi_mapper",
|
||||
"DEFAULT_POSITION",
|
||||
"SIZE_WIDTH",
|
||||
"SIZE_HEIGHT",
|
||||
"SIZE_ORDER",
|
||||
"DEFAULT_TIME_SCALE",
|
||||
"DEFAULT_FREQUENCY_SCALE",
|
||||
]
|
||||
|
||||
SIZE_WIDTH = "width"
|
||||
"""Standard name for the width/time dimension component ('width')."""
|
||||
|
||||
SIZE_HEIGHT = "height"
|
||||
"""Standard name for the height/frequency dimension component ('height')."""
|
||||
|
||||
SIZE_ORDER = (SIZE_WIDTH, SIZE_HEIGHT)
|
||||
"""Standard order of dimensions for size arrays ([width, height])."""
|
||||
|
||||
DEFAULT_TIME_SCALE = 1000.0
|
||||
"""Default scaling factor for time duration."""
|
||||
|
||||
DEFAULT_FREQUENCY_SCALE = 1 / 859.375
|
||||
"""Default scaling factor for frequency bandwidth."""
|
||||
|
||||
|
||||
DEFAULT_POSITION = "bottom-left"
|
||||
"""Default reference position within the geometry ('bottom-left' corner)."""
|
||||
|
||||
|
||||
class ROITargetMapper(Protocol):
|
||||
"""Protocol defining the interface for ROI-to-target mapping.
|
||||
|
||||
Specifies the methods required for converting a geometric region of interest
|
||||
(`soundevent.data.Geometry`) into a target representation (reference point
|
||||
and scaled dimensions) and for recovering an approximate ROI from that
|
||||
representation.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
dimension_names : List[str]
|
||||
A list containing the names of the dimensions returned by
|
||||
`get_roi_size` and expected by `recover_roi`
|
||||
(e.g., ['width', 'height']).
|
||||
"""
|
||||
|
||||
dimension_names: List[str]
|
||||
|
||||
def get_roi_position(self, geom: data.Geometry) -> tuple[float, float]:
|
||||
"""Extract the reference position from a geometry.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
geom : soundevent.data.Geometry
|
||||
The input geometry (e.g., BoundingBox, Polygon).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[float, float]
|
||||
The calculated reference position as (time, frequency) coordinates,
|
||||
based on the implementing class's configuration (e.g., "center",
|
||||
"bottom-left").
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the position cannot be calculated for the given geometry type
|
||||
or configured reference point.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
|
||||
"""Calculate the scaled target dimensions from a geometry.
|
||||
|
||||
Computes the relevant size measures.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
geom : soundevent.data.Geometry
|
||||
The input geometry.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
A NumPy array containing the scaled dimensions corresponding to
|
||||
`dimension_names`. For bounding boxes, typically contains
|
||||
`[scaled_width, scaled_height]`.
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError, ValueError
|
||||
If the size cannot be computed for the given geometry type.
|
||||
"""
|
||||
...
|
||||
|
||||
def recover_roi(
|
||||
self, pos: tuple[float, float], dims: np.ndarray
|
||||
) -> data.Geometry:
|
||||
"""Recover an approximate ROI from a position and target dimensions.
|
||||
|
||||
Performs the inverse mapping: takes a reference position and the
|
||||
predicted dimensions and reconstructs a geometric representation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pos : Tuple[float, float]
|
||||
The reference position (time, frequency).
|
||||
dims : np.ndarray
|
||||
The NumPy array containing the dimensions, matching the order
|
||||
specified by `dimension_names`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Geometry
|
||||
The reconstructed geometry.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the number of provided dimensions `dims` does not match
|
||||
`dimension_names` or if reconstruction fails.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ROIConfig(BaseConfig):
|
||||
"""Configuration for mapping Regions of Interest (ROIs).
|
||||
|
||||
Defines parameters controlling how geometric ROIs are converted into
|
||||
target representations (reference points and scaled sizes).
|
||||
|
||||
Attributes
|
||||
----------
|
||||
position : Positions, default="bottom-left"
|
||||
Specifies the reference point within the geometry (e.g., bounding box)
|
||||
to use as the target location (e.g., "center", "bottom-left").
|
||||
See `soundevent.geometry.operations.Positions`.
|
||||
time_scale : float, default=1000.0
|
||||
Scaling factor applied to the time duration (width) of the ROI
|
||||
when calculating the target size representation. Must match model
|
||||
expectations.
|
||||
frequency_scale : float, default=1/859.375
|
||||
Scaling factor applied to the frequency bandwidth (height) of the ROI
|
||||
when calculating the target size representation. Must match model
|
||||
expectations.
|
||||
"""
|
||||
|
||||
position: Positions = DEFAULT_POSITION
|
||||
time_scale: float = DEFAULT_TIME_SCALE
|
||||
frequency_scale: float = DEFAULT_FREQUENCY_SCALE
|
||||
|
||||
|
||||
class BBoxEncoder(ROITargetMapper):
|
||||
"""Concrete implementation of `ROITargetMapper` focused on Bounding Boxes.
|
||||
|
||||
This class implements the ROI mapping protocol primarily for
|
||||
`soundevent.data.BoundingBox` geometry. It extracts reference points,
|
||||
calculates scaled width/height, and recovers bounding boxes based on
|
||||
configured position and scaling factors.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
dimension_names : List[str]
|
||||
Specifies the output dimension names as ['width', 'height'].
|
||||
position : Positions
|
||||
The configured reference point type (e.g., "center", "bottom-left").
|
||||
time_scale : float
|
||||
The configured scaling factor for the time dimension (width).
|
||||
frequency_scale : float
|
||||
The configured scaling factor for the frequency dimension (height).
|
||||
"""
|
||||
|
||||
dimension_names = [SIZE_WIDTH, SIZE_HEIGHT]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
position: Positions = DEFAULT_POSITION,
|
||||
time_scale: float = DEFAULT_TIME_SCALE,
|
||||
frequency_scale: float = DEFAULT_FREQUENCY_SCALE,
|
||||
):
|
||||
"""Initialize the BBoxEncoder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
position : Positions, default="bottom-left"
|
||||
Reference point type within the bounding box.
|
||||
time_scale : float, default=1000.0
|
||||
Scaling factor for time duration (width).
|
||||
frequency_scale : float, default=1/859.375
|
||||
Scaling factor for frequency bandwidth (height).
|
||||
"""
|
||||
self.position: Positions = position
|
||||
self.time_scale = time_scale
|
||||
self.frequency_scale = frequency_scale
|
||||
|
||||
def get_roi_position(self, geom: data.Geometry) -> Tuple[float, float]:
|
||||
"""Extract the configured reference position from the geometry.
|
||||
|
||||
Uses `soundevent.geometry.get_geometry_point`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
geom : soundevent.data.Geometry
|
||||
Input geometry (e.g., BoundingBox).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[float, float]
|
||||
Reference position (time, frequency).
|
||||
"""
|
||||
return geometry.get_geometry_point(geom, position=self.position)
|
||||
|
||||
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
|
||||
"""Calculate the scaled [width, height] from the geometry's bounds.
|
||||
|
||||
Computes the bounding box, extracts duration and bandwidth, and applies
|
||||
the configured `time_scale` and `frequency_scale`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
geom : soundevent.data.Geometry
|
||||
Input geometry.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
A 1D NumPy array: `[scaled_width, scaled_height]`.
|
||||
"""
|
||||
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
|
||||
geom
|
||||
)
|
||||
return np.array(
|
||||
[
|
||||
(end_time - start_time) * self.time_scale,
|
||||
(high_freq - low_freq) * self.frequency_scale,
|
||||
]
|
||||
)
|
||||
|
||||
def recover_roi(
|
||||
self,
|
||||
pos: tuple[float, float],
|
||||
dims: np.ndarray,
|
||||
) -> data.Geometry:
|
||||
"""Recover a BoundingBox from a position and scaled dimensions.
|
||||
|
||||
Un-scales the input dimensions using the configured factors and
|
||||
reconstructs a `soundevent.data.BoundingBox` centered or anchored at
|
||||
the given reference `pos` according to the configured `position` type.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pos : Tuple[float, float]
|
||||
Reference position (time, frequency).
|
||||
dims : np.ndarray
|
||||
NumPy array containing the *scaled* dimensions, expected order is
|
||||
[scaled_width, scaled_height].
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.BoundingBox
|
||||
The reconstructed bounding box.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `dims` does not have the expected shape (length 2).
|
||||
"""
|
||||
if dims.ndim != 1 or dims.shape[0] != 2:
|
||||
raise ValueError(
|
||||
"Dimension array does not have the expected shape. "
|
||||
f"({dims.shape = }) != ([2])"
|
||||
)
|
||||
|
||||
width, height = dims
|
||||
return _build_bounding_box(
|
||||
pos,
|
||||
duration=width / self.time_scale,
|
||||
bandwidth=height / self.frequency_scale,
|
||||
position=self.position,
|
||||
)
|
||||
|
||||
|
||||
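# --- Editorial sketch (illustration only, not part of the original diff) ---
# Round-tripping a bounding box with the default settings. The coordinates
# are made up: a 0.02 s by 30 kHz box maps to a scaled size of roughly
# [20.0, 34.9] and is recovered exactly from the bottom-left reference point.
def _example_bbox_round_trip():
    encoder = BBoxEncoder()
    box = data.BoundingBox(coordinates=[0.10, 30_000, 0.12, 60_000])
    pos = encoder.get_roi_position(box)  # (0.10, 30000.0) for "bottom-left"
    size = encoder.get_roi_size(box)  # approx. array([20.0, 34.9])
    recovered = encoder.recover_roi(pos, size)
    return recovered  # BoundingBox with the original coordinates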
def build_roi_mapper(config: ROIConfig) -> ROITargetMapper:
|
||||
"""Factory function to create an ROITargetMapper from configuration.
|
||||
|
||||
Currently creates a `BBoxEncoder` instance based on the provided
|
||||
`ROIConfig`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : ROIConfig
|
||||
Configuration object specifying ROI mapping parameters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ROITargetMapper
|
||||
An initialized `BBoxEncoder` instance configured with the settings
|
||||
from `config`.
|
||||
"""
|
||||
return BBoxEncoder(
|
||||
position=config.position,
|
||||
time_scale=config.time_scale,
|
||||
frequency_scale=config.frequency_scale,
|
||||
)
|
||||
|
||||
|
||||
def load_roi_mapper(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
) -> ROITargetMapper:
|
||||
"""Load ROI mapping configuration from a file and build the mapper.
|
||||
|
||||
Convenience function that loads an `ROIConfig` from the specified file
|
||||
(and optional field) and then uses `build_roi_mapper` to create the
|
||||
corresponding `ROITargetMapper` instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the configuration file (e.g., YAML).
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing the
|
||||
ROI configuration. If None, the entire file content is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ROITargetMapper
|
||||
An initialized ROI mapper instance based on the configuration file.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
|
||||
TypeError
|
||||
If the configuration file cannot be found, parsed, validated, or if
|
||||
the specified `field` is invalid.
|
||||
"""
|
||||
config = load_config(path=path, schema=ROIConfig, field=field)
|
||||
return build_roi_mapper(config)
|
||||
|
||||
|
||||
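# --- Editorial sketch (illustration only, not part of the original diff) ---
# Hypothetical YAML for the ROI section; keys mirror the ROIConfig attributes
# and omitted keys fall back to the defaults defined above.
_EXAMPLE_ROI_YAML = """
roi:
  position: center
  time_scale: 1000.0
"""
# mapper = load_roi_mapper("targets.yaml", field="roi")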
VALID_POSITIONS = [
|
||||
"bottom-left",
|
||||
"bottom-right",
|
||||
"top-left",
|
||||
"top-right",
|
||||
"center-left",
|
||||
"center-right",
|
||||
"top-center",
|
||||
"bottom-center",
|
||||
"center",
|
||||
"centroid",
|
||||
"point_on_surface",
|
||||
]
|
||||
|
||||
|
||||
def _build_bounding_box(
|
||||
pos: tuple[float, float],
|
||||
duration: float,
|
||||
bandwidth: float,
|
||||
position: Positions = DEFAULT_POSITION,
|
||||
) -> data.BoundingBox:
|
||||
"""Construct a BoundingBox from a reference point, size, and position type.
|
||||
|
||||
Internal helper for `BBoxEncoder.recover_roi`. Calculates the box
|
||||
coordinates [start_time, low_freq, end_time, high_freq] based on where
|
||||
the input `pos` (time, freq) is located relative to the box (e.g.,
|
||||
center, corner).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pos : Tuple[float, float]
|
||||
Reference position (time, frequency).
|
||||
duration : float
|
||||
The required *unscaled* duration (width) of the bounding box.
|
||||
bandwidth : float
|
||||
The required *unscaled* frequency bandwidth (height) of the bounding
|
||||
box.
|
||||
position : Positions, default="bottom-left"
|
||||
Specifies which part of the bounding box the input `pos` corresponds to.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.BoundingBox
|
||||
The constructed bounding box object.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `position` is not a recognized value or format.
|
||||
"""
|
||||
time, freq = pos
|
||||
if position in ["center", "centroid", "point_on_surface"]:
|
||||
return data.BoundingBox(
|
||||
coordinates=[
|
||||
time - duration / 2,
|
||||
freq - bandwidth / 2,
|
||||
time + duration / 2,
|
||||
freq + bandwidth / 2,
|
||||
]
|
||||
)
|
||||
|
||||
if position not in VALID_POSITIONS:
|
||||
raise ValueError(
|
||||
f"Invalid position: {position}. "
|
||||
f"Valid options are: {VALID_POSITIONS}"
|
||||
)
|
||||
|
||||
y, x = position.split("-")
|
||||
|
||||
start_time = {
|
||||
"left": time,
|
||||
"center": time - duration / 2,
|
||||
"right": time - duration,
|
||||
}[x]
|
||||
|
||||
low_freq = {
|
||||
"bottom": freq,
|
||||
"center": freq - bandwidth / 2,
|
||||
"top": freq - bandwidth,
|
||||
}[y]
|
||||
|
||||
return data.BoundingBox(
|
||||
coordinates=[
|
||||
start_time,
|
||||
low_freq,
|
||||
start_time + duration,
|
||||
low_freq + bandwidth,
|
||||
]
|
||||
)
|
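# --- Editorial sketch (illustration only, not part of the original diff) ---
# Worked example of the corner arithmetic above: with position="top-right"
# the reference point is the upper-right corner, so the box extends one
# duration to the left and one bandwidth downward from it.
def _example_top_right_box():
    box = _build_bounding_box(
        (0.12, 60_000), duration=0.02, bandwidth=30_000, position="top-right"
    )
    # box.coordinates == [0.10, 30000.0, 0.12, 60000.0]
    return box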
495
batdetect2/targets/terms.py
Normal file
@ -0,0 +1,495 @@
|
||||
"""Manages the vocabulary (Terms and Tags) for defining training targets.
|
||||
|
||||
This module provides the necessary tools to declare, register, and manage the
|
||||
set of `soundevent.data.Term` objects used throughout the `batdetect2.targets`
|
||||
sub-package. It establishes a consistent vocabulary for filtering,
|
||||
transforming, and classifying sound events based on their annotations (Tags).
|
||||
|
||||
The core component is the `TermRegistry`, which maps unique string keys
|
||||
(aliases) to specific `Term` definitions. This allows users to refer to complex
|
||||
terms using simple, consistent keys in configuration files and code.
|
||||
|
||||
Terms can be pre-defined, loaded from the `soundevent.terms` library, defined
|
||||
programmatically, or loaded from external configuration files (e.g., YAML).
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from inspect import getmembers
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.configs import load_config
|
||||
|
||||
__all__ = [
|
||||
"call_type",
|
||||
"individual",
|
||||
"data_source",
|
||||
"get_tag_from_info",
|
||||
"TermInfo",
|
||||
"TagInfo",
|
||||
]
|
||||
|
||||
# The default key used to reference the 'generic_class' term.
|
||||
# Often used implicitly when defining classification targets.
|
||||
GENERIC_CLASS_KEY = "class"
|
||||
|
||||
|
||||
data_source = data.Term(
|
||||
name="soundevent:data_source",
|
||||
label="Data Source",
|
||||
definition=(
|
||||
"A unique identifier for the source of the data, typically "
|
||||
"representing the project, site, or deployment context."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
call_type = data.Term(
|
||||
name="soundevent:call_type",
|
||||
label="Call Type",
|
||||
definition=(
|
||||
"A broad categorization of animal vocalizations based on their "
|
||||
"intended function or purpose (e.g., social, distress, mating, "
|
||||
"territorial, echolocation)."
|
||||
),
|
||||
)
|
||||
"""Term representing the broad functional category of a vocalization."""
|
||||
|
||||
individual = data.Term(
|
||||
name="soundevent:individual",
|
||||
label="Individual",
|
||||
definition=(
|
||||
"An id for an individual animal. In the context of bioacoustic "
|
||||
"annotation, this term is used to label vocalizations that are "
|
||||
"attributed to a specific individual."
|
||||
),
|
||||
)
|
||||
"""Term used for tags identifying a specific individual animal."""
|
||||
|
||||
generic_class = data.Term(
|
||||
name="soundevent:class",
|
||||
label="Class",
|
||||
definition=(
|
||||
"A generic term representing the name of a class within a "
|
||||
"classification model. Its specific meaning is determined by "
|
||||
"the model's application."
|
||||
),
|
||||
)
|
||||
"""Generic term representing a classification model's output class label."""
|
||||
|
||||
|
||||
class TermRegistry(Mapping[str, data.Term]):
|
||||
"""Manages a registry mapping unique keys to Term definitions.
|
||||
|
||||
This class acts as the central repository for the vocabulary of terms
|
||||
used within the target definition process. It allows registering terms
|
||||
with simple string keys and retrieving them consistently.
|
||||
"""
|
||||
|
||||
def __init__(self, terms: Optional[Dict[str, data.Term]] = None):
|
||||
"""Initializes the TermRegistry.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
terms : dict[str, soundevent.data.Term], optional
|
||||
An optional dictionary of initial key-to-Term mappings
|
||||
to populate the registry with. Defaults to an empty registry.
|
||||
"""
|
||||
self._terms: Dict[str, data.Term] = terms or {}
|
||||
|
||||
def __getitem__(self, key: str) -> data.Term:
|
||||
return self._terms[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._terms)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._terms)
|
||||
|
||||
def add_term(self, key: str, term: data.Term) -> None:
|
||||
"""Adds a Term object to the registry with the specified key.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique string key to associate with the term.
|
||||
term : soundevent.data.Term
|
||||
The soundevent.data.Term object to register.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a term with the provided key already exists in the
|
||||
registry.
|
||||
"""
|
||||
if key in self._terms:
|
||||
raise KeyError("A term with the provided key already exists.")
|
||||
|
||||
self._terms[key] = term
|
||||
|
||||
def get_term(self, key: str) -> data.Term:
|
||||
"""Retrieves a registered term by its unique key.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique string key of the term to retrieve.
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Term
|
||||
The corresponding soundevent.data.Term object.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If no term with the specified key is found, with a
|
||||
helpful message suggesting listing available keys.
|
||||
"""
|
||||
try:
|
||||
return self._terms[key]
|
||||
except KeyError as err:
|
||||
raise KeyError(
|
||||
"No term found for key "
|
||||
f"'{key}'. Ensure it is registered or loaded. "
|
||||
f"Available keys: {', '.join(self.get_keys())}"
|
||||
) from err
|
||||
|
||||
def add_custom_term(
|
||||
self,
|
||||
key: str,
|
||||
name: Optional[str] = None,
|
||||
uri: Optional[str] = None,
|
||||
label: Optional[str] = None,
|
||||
definition: Optional[str] = None,
|
||||
) -> data.Term:
|
||||
"""Creates a new Term from attributes and adds it to the registry.
|
||||
|
||||
This is useful for defining terms directly in code or when loading
|
||||
from configuration files where only attributes are provided.
|
||||
|
||||
If optional fields (`name`, `label`, `definition`) are not provided,
|
||||
reasonable defaults are used (`key` for name/label, "Unknown" for
|
||||
definition).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique string key for the new term.
|
||||
name : str, optional
|
||||
The name for the new term (defaults to `key`).
|
||||
uri : str, optional
|
||||
The URI for the new term (optional).
|
||||
label : str, optional
|
||||
The display label for the new term (defaults to `key`).
|
||||
definition : str, optional
|
||||
The definition for the new term (defaults to "Unknown").
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Term
|
||||
The newly created and registered soundevent.data.Term object.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a term with the provided key already exists.
|
||||
"""
|
||||
term = data.Term(
|
||||
name=name or key,
|
||||
label=label or key,
|
||||
uri=uri,
|
||||
definition=definition or "Unknown",
|
||||
)
|
||||
self.add_term(key, term)
|
||||
return term
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
"""Returns a list of all keys currently registered.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
A list of strings representing the keys of all registered terms.
|
||||
"""
|
||||
return list(self._terms.keys())
|
||||
|
||||
def get_terms(self) -> List[data.Term]:
|
||||
"""Returns a list of all registered terms.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[soundevent.data.Term]
|
||||
A list containing all registered Term objects.
|
||||
"""
|
||||
return list(self._terms.values())
|
||||
|
||||
def remove_key(self, key: str) -> None:
|
||||
"""Remove the term registered under the given key from the registry."""
del self._terms[key]
|
||||
|
||||
|
||||
term_registry = TermRegistry(
|
||||
terms=dict(
|
||||
[
|
||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
||||
("event", call_type),
|
||||
("individual", individual),
|
||||
("data_source", data_source),
|
||||
(GENERIC_CLASS_KEY, generic_class),
|
||||
]
|
||||
)
|
||||
)
|
||||
"""The default, globally accessible TermRegistry instance.
|
||||
|
||||
It is pre-populated with standard terms from `soundevent.terms` and common
|
||||
terms defined in this module (`call_type`, `individual`, `generic_class`).
|
||||
Functions in this module use this registry by default unless another instance
|
||||
is explicitly passed.
|
||||
"""
|
||||
|
||||
|
||||
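# --- Editorial sketch (illustration only, not part of the original diff) ---
# Registering a project-specific term on the default registry and retrieving
# it again. The key, label and definition are assumptions.
def _example_register_project_term():
    term = term_registry.add_custom_term(
        "recording_night",
        label="Recording Night",
        definition="The night on which the recording session took place.",
    )
    assert get_term_from_key("recording_night") is term
    return term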
def get_term_from_key(
|
||||
key: str,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> data.Term:
|
||||
"""Convenience function to retrieve a term by key from a registry.
|
||||
|
||||
Uses the global default registry unless a specific `term_registry`
|
||||
instance is provided.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique key of the term to retrieve.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to search in. Defaults to the global
|
||||
`registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Term
|
||||
The corresponding soundevent.data.Term object.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If the key is not found in the specified registry.
|
||||
"""
|
||||
return term_registry.get_term(key)
|
||||
|
||||
|
||||
def get_term_keys(term_registry: TermRegistry = term_registry) -> List[str]:
|
||||
"""Convenience function to get all registered keys from a registry.
|
||||
|
||||
Uses the global default registry unless a specific `term_registry`
|
||||
instance is provided.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to query. Defaults to the global `registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[str]
|
||||
A list of strings representing the keys of all registered terms.
|
||||
"""
|
||||
return term_registry.get_keys()
|
||||
|
||||
|
||||
def get_terms(term_registry: TermRegistry = term_registry) -> List[data.Term]:
|
||||
"""Convenience function to get all registered terms from a registry.
|
||||
|
||||
Uses the global default registry unless a specific `term_registry`
|
||||
instance is provided.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to query. Defaults to the global `registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[soundevent.data.Term]
|
||||
A list containing all registered Term objects.
|
||||
"""
|
||||
return term_registry.get_terms()
|
||||
|
||||
|
||||
class TagInfo(BaseModel):
|
||||
"""Represents information needed to define a specific Tag.
|
||||
|
||||
This model is typically used in configuration files (e.g., YAML) to
|
||||
specify tags used for filtering, target class definition, or associating
|
||||
tags with output classes. It links a tag value to a term definition
|
||||
via the term's registry key.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
value : str
|
||||
The value of the tag (e.g., "Myotis myotis", "Echolocation").
|
||||
key : str, default="class"
|
||||
The key (alias) of the term associated with this tag, as
|
||||
registered in the TermRegistry. Defaults to "class", implying
|
||||
it represents a classification target label by default.
|
||||
"""
|
||||
|
||||
value: str
|
||||
key: str = GENERIC_CLASS_KEY
|
||||
|
||||
|
||||
def get_tag_from_info(
|
||||
tag_info: TagInfo,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> data.Tag:
|
||||
"""Creates a soundevent.data.Tag object from TagInfo data.
|
||||
|
||||
Looks up the term using the key in the provided `tag_info` from the
|
||||
specified registry and constructs a Tag object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tag_info : TagInfo
|
||||
The TagInfo object containing the value and term key.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to use for term lookup. Defaults to the
|
||||
global `registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Tag
|
||||
A soundevent.data.Tag object corresponding to the input info.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If the term key specified in `tag_info.key` is not found
|
||||
in the registry.
|
||||
"""
|
||||
term = get_term_from_key(tag_info.key, term_registry=term_registry)
|
||||
return data.Tag(term=term, value=tag_info.value)
|
||||
|
||||
|
||||
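# --- Editorial sketch (illustration only, not part of the original diff) ---
# Turning a configuration-level TagInfo into a concrete Tag. The "event" key
# maps to the `call_type` term in the default registry defined above.
def _example_tag_from_info():
    tag = get_tag_from_info(TagInfo(key="event", value="Echolocation"))
    # tag.term is call_type and tag.value == "Echolocation"
    return tag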
class TermInfo(BaseModel):
|
||||
"""Represents the definition of a Term within a configuration file.
|
||||
|
||||
This model allows users to define custom terms directly in configuration
|
||||
files (e.g., YAML) which can then be loaded into the TermRegistry.
|
||||
It mirrors the parameters of `TermRegistry.add_custom_term`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
key : str
|
||||
The unique key (alias) that will be used to register and
|
||||
reference this term.
|
||||
label : str, optional
|
||||
The optional display label for the term. Defaults to `key`
|
||||
if not provided during registration.
|
||||
name : str, optional
|
||||
The optional formal name for the term. Defaults to `key`
|
||||
if not provided during registration.
|
||||
uri : str, optional
|
||||
The optional URI identifying the term (e.g., from a standard
|
||||
vocabulary).
|
||||
definition : str, optional
|
||||
The optional textual definition of the term. Defaults to
|
||||
"Unknown" if not provided during registration.
|
||||
"""
|
||||
|
||||
key: str
|
||||
label: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
uri: Optional[str] = None
|
||||
definition: Optional[str] = None
|
||||
|
||||
|
||||
class TermConfig(BaseModel):
|
||||
"""Pydantic schema for loading a list of term definitions from config.
|
||||
|
||||
This model typically corresponds to a section in a configuration file
|
||||
(e.g., YAML) containing a list of term definitions to be registered.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
terms : list[TermInfo]
|
||||
A list of TermInfo objects, each defining a term to be
|
||||
registered. Defaults to an empty list.
|
||||
|
||||
Examples
|
||||
--------
|
||||
Example YAML structure:
|
||||
|
||||
```yaml
|
||||
terms:
|
||||
- key: species
|
||||
uri: dwc:scientificName
|
||||
label: Scientific Name
|
||||
- key: my_custom_term
|
||||
name: My Custom Term
|
||||
definition: Describes a specific project attribute.
|
||||
# ... more TermInfo definitions
|
||||
```
|
||||
"""
|
||||
|
||||
terms: List[TermInfo] = Field(default_factory=list)
|
||||
|
||||
|
||||
def load_terms_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
term_registry: TermRegistry = term_registry,
|
||||
) -> Dict[str, data.Term]:
|
||||
"""Loads term definitions from a configuration file and registers them.
|
||||
|
||||
Parses a configuration file (e.g., YAML) using the TermConfig schema,
|
||||
extracts the list of TermInfo definitions, and adds each one as a
|
||||
custom term to the specified TermRegistry instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
The path to the configuration file.
|
||||
field : str, optional
|
||||
Optional key indicating a specific section within the config
|
||||
file where the 'terms' list is located. If None, expects the
|
||||
list directly at the top level or within a structure matching
|
||||
TermConfig schema.
|
||||
term_registry : TermRegistry, optional
|
||||
The TermRegistry instance to add the loaded terms to. Defaults to
|
||||
the global `registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, soundevent.data.Term]
|
||||
A dictionary mapping the keys of the newly added terms to their
|
||||
corresponding Term objects.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the TermConfig schema.
|
||||
KeyError
|
||||
If a term key loaded from the config conflicts with a key
|
||||
already present in the registry.
|
||||
"""
|
||||
# Use a distinct name so the `soundevent.data` import is not shadowed.
config = load_config(path, schema=TermConfig, field=field)
|
||||
return {
|
||||
info.key: term_registry.add_custom_term(
|
||||
info.key,
|
||||
name=info.name,
|
||||
uri=info.uri,
|
||||
label=info.label,
|
||||
definition=info.definition,
|
||||
)
|
||||
for info in config.terms
|
||||
}
|
||||
|
||||
|
||||
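# --- Editorial sketch (illustration only, not part of the original diff) ---
# Loading term definitions like the YAML shown in the TermConfig docstring
# from a hypothetical "terms.yaml" file. Afterwards the new keys (e.g.
# "species") can be referenced from TagInfo entries in other config sections.
def _example_load_terms():
    new_terms = load_terms_from_config("terms.yaml")
    return new_terms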
def register_term(
|
||||
key: str, term: data.Term, registry: TermRegistry = term_registry
|
||||
) -> None:
|
||||
"""Register `term` under `key` in the given registry (global by default)."""
registry.add_term(key, term)
|
699
batdetect2/targets/transform.py
Normal file
@ -0,0 +1,699 @@
|
||||
import importlib
|
||||
from functools import partial
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.terms import (
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
get_tag_from_info,
|
||||
get_term_from_key,
|
||||
)
|
||||
from batdetect2.targets.terms import (
|
||||
term_registry as default_term_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DerivationRegistry",
|
||||
"DeriveTagRule",
|
||||
"MapValueRule",
|
||||
"ReplaceRule",
|
||||
"SoundEventTransformation",
|
||||
"TransformConfig",
|
||||
"build_transform_from_rule",
|
||||
"build_transformation_from_config",
|
||||
"derivation_registry",
|
||||
"get_derivation",
|
||||
"load_transformation_config",
|
||||
"load_transformation_from_config",
|
||||
"register_derivation",
|
||||
]
|
||||
|
||||
SoundEventTransformation = Callable[
|
||||
[data.SoundEventAnnotation], data.SoundEventAnnotation
|
||||
]
|
||||
"""Type alias for a sound event transformation function.
|
||||
|
||||
A function that accepts a sound event annotation object and returns a
|
||||
(potentially) modified sound event annotation object. Transformations
|
||||
should generally return a copy of the annotation rather than modifying
|
||||
it in place.
|
||||
"""
|
||||
|
||||
|
||||
Derivation = Callable[[str], str]
|
||||
"""Type alias for a derivation function.
|
||||
|
||||
A function that accepts a single string (typically a tag value) and returns
|
||||
a new string (the derived value).
|
||||
"""
|
||||
|
||||
|
||||
class MapValueRule(BaseConfig):
|
||||
"""Configuration for mapping specific values of a source term.
|
||||
|
||||
This rule replaces tags matching a specific term and one of the
|
||||
original values with a new tag (potentially having a different term)
|
||||
containing the corresponding replacement value. Useful for standardizing
|
||||
or grouping tag values.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
rule_type : Literal["map_value"]
|
||||
Discriminator field identifying this rule type.
|
||||
source_term_key : str
|
||||
The key (registered in `TermRegistry`) of the term whose tags' values
|
||||
should be checked against the `value_mapping`.
|
||||
value_mapping : Dict[str, str]
|
||||
A dictionary mapping original string values to replacement string
|
||||
values. Only tags whose value is a key in this dictionary will be
|
||||
affected.
|
||||
target_term_key : str, optional
|
||||
The key (registered in `TermRegistry`) for the term of the *output*
|
||||
tag. If None (default), the output tag uses the same term as the
|
||||
source (`source_term_key`). If provided, the term of the affected
|
||||
tag is changed to this target term upon replacement.
|
||||
"""
|
||||
|
||||
rule_type: Literal["map_value"] = "map_value"
|
||||
source_term_key: str
|
||||
value_mapping: Dict[str, str]
|
||||
target_term_key: Optional[str] = None
|
||||
|
||||
|
||||
def map_value_transform(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
source_term: data.Term,
|
||||
target_term: data.Term,
|
||||
mapping: Dict[str, str],
|
||||
) -> data.SoundEventAnnotation:
|
||||
"""Apply a value mapping transformation to an annotation's tags.
|
||||
|
||||
Iterates through the annotation's tags. If a tag matches the `source_term`
|
||||
and its value is found in the `mapping`, it is replaced by a new tag with
|
||||
the `target_term` and the mapped value. Other tags are kept unchanged.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to transform.
|
||||
source_term : data.Term
|
||||
The term of tags whose values should be mapped.
|
||||
target_term : data.Term
|
||||
The term to use for the newly created tags after mapping.
|
||||
mapping : Dict[str, str]
|
||||
The dictionary mapping original values to new values.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.SoundEventAnnotation
|
||||
A new annotation object with the transformed tags.
|
||||
"""
|
||||
tags = []
|
||||
|
||||
for tag in sound_event_annotation.tags:
|
||||
if tag.term != source_term or tag.value not in mapping:
|
||||
tags.append(tag)
|
||||
continue
|
||||
|
||||
new_value = mapping[tag.value]
|
||||
tags.append(data.Tag(term=target_term, value=new_value))
|
||||
|
||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||
|
||||
|
||||
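# --- Editorial sketch (illustration only, not part of the original diff) ---
# Hypothetical YAML for a map_value rule that collapses two species labels
# into a single genus-level value. It assumes that "species" and "genus"
# term keys have been registered beforehand.
_EXAMPLE_MAP_VALUE_YAML = """
rules:
  - rule_type: map_value
    source_term_key: species
    target_term_key: genus
    value_mapping:
      Myotis myotis: Myotis
      Myotis blythii: Myotis
"""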
class DeriveTagRule(BaseConfig):
|
||||
"""Configuration for deriving a new tag from an existing tag's value.
|
||||
|
||||
This rule applies a specified function (`derivation_function`) to the
|
||||
value of tags matching the `source_term_key`. It then adds a new tag
|
||||
with the `target_term_key` and the derived value.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
rule_type : Literal["derive_tag"]
|
||||
Discriminator field identifying this rule type.
|
||||
source_term_key : str
|
||||
The key (registered in `TermRegistry`) of the term whose tag values
|
||||
will be used as input to the derivation function.
|
||||
derivation_function : str
|
||||
The name/key identifying the derivation function to use. This can be
|
||||
a key registered in the `DerivationRegistry` or, if
|
||||
`import_derivation` is True, a full Python path like
|
||||
`'my_module.my_submodule.my_function'`.
|
||||
target_term_key : str, optional
|
||||
The key (registered in `TermRegistry`) for the term of the new tag
|
||||
that will be created with the derived value. If None (default), the
|
||||
derived tag uses the same term as the source (`source_term_key`),
|
||||
effectively performing an in-place value transformation.
|
||||
import_derivation : bool, default=False
|
||||
If True, treat `derivation_function` as a Python import path and
|
||||
attempt to dynamically import it if not found in the registry.
|
||||
Requires the function to be accessible in the Python environment.
|
||||
keep_source : bool, default=True
|
||||
If True, the original source tag (whose value was used for derivation)
|
||||
is kept in the annotation's tag list alongside the newly derived tag.
|
||||
If False, the original source tag is removed.
|
||||
"""
|
||||
|
||||
rule_type: Literal["derive_tag"] = "derive_tag"
|
||||
source_term_key: str
|
||||
derivation_function: str
|
||||
target_term_key: Optional[str] = None
|
||||
import_derivation: bool = False
|
||||
keep_source: bool = True
|
||||
|
||||
|
||||
def derivation_tag_transform(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
source_term: data.Term,
|
||||
target_term: data.Term,
|
||||
derivation: Derivation,
|
||||
keep_source: bool = True,
|
||||
) -> data.SoundEventAnnotation:
|
||||
"""Apply a derivation transformation to an annotation's tags.
|
||||
|
||||
Iterates through the annotation's tags. For each tag matching the
|
||||
`source_term`, its value is passed to the `derivation` function.
|
||||
A new tag is created with the `target_term` and the derived value,
|
||||
and added to the output tag list. The original source tag is kept
|
||||
or discarded based on `keep_source`. Other tags are kept unchanged.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to transform.
|
||||
source_term : data.Term
|
||||
The term of tags whose values serve as input for the derivation.
|
||||
target_term : data.Term
|
||||
The term to use for the newly created derived tags.
|
||||
derivation : Derivation
|
||||
The function to apply to the source tag's value.
|
||||
keep_source : bool, default=True
|
||||
Whether to keep the original source tag in the output.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.SoundEventAnnotation
|
||||
A new annotation object with the transformed tags (including derived
|
||||
ones).
|
||||
"""
|
||||
tags = []
|
||||
|
||||
for tag in sound_event_annotation.tags:
|
||||
if tag.term != source_term:
|
||||
tags.append(tag)
|
||||
continue
|
||||
|
||||
if keep_source:
|
||||
tags.append(tag)
|
||||
|
||||
new_value = derivation(tag.value)
|
||||
tags.append(data.Tag(term=target_term, value=new_value))
|
||||
|
||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||
|
||||
|
||||
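# --- Editorial sketch (illustration only, not part of the original diff) ---
# A derivation is just a str -> str callable. A minimal genus-from-species
# derivation could look like this; it could be registered on the
# `derivation_registry` defined further below, or referenced by import path
# with `import_derivation: true` in a DeriveTagRule.
def _example_genus_derivation(value: str) -> str:
    return value.split(" ")[0]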
class ReplaceRule(BaseConfig):
|
||||
"""Configuration for exactly replacing one specific tag with another.
|
||||
|
||||
This rule looks for an exact match of the `original` tag (both term and
|
||||
value) and replaces it with the specified `replacement` tag.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
rule_type : Literal["replace"]
|
||||
Discriminator field identifying this rule type.
|
||||
original : TagInfo
|
||||
The exact tag to search for, defined using its value and term key.
|
||||
replacement : TagInfo
|
||||
The tag to substitute in place of the original tag, defined using
|
||||
its value and term key.
|
||||
"""
|
||||
|
||||
rule_type: Literal["replace"] = "replace"
|
||||
original: TagInfo
|
||||
replacement: TagInfo
|
||||
|
||||
|
||||
def replace_tag_transform(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
source: data.Tag,
|
||||
target: data.Tag,
|
||||
) -> data.SoundEventAnnotation:
|
||||
"""Apply an exact tag replacement transformation.
|
||||
|
||||
Iterates through the annotation's tags. If a tag exactly matches the
|
||||
`source` tag, it is replaced by the `target` tag. Other tags are kept
|
||||
unchanged.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event_annotation : data.SoundEventAnnotation
|
||||
The annotation to transform.
|
||||
source : data.Tag
|
||||
The exact tag to find and replace.
|
||||
target : data.Tag
|
||||
The tag to replace the source tag with.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.SoundEventAnnotation
|
||||
A new annotation object with the replaced tag (if found).
|
||||
"""
|
||||
tags = []
|
||||
|
||||
for tag in sound_event_annotation.tags:
|
||||
if tag == source:
|
||||
tags.append(target)
|
||||
else:
|
||||
tags.append(tag)
|
||||
|
||||
return sound_event_annotation.model_copy(update=dict(tags=tags))
|
||||
|
||||
|
||||
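# --- Editorial sketch (illustration only, not part of the original diff) ---
# Hypothetical YAML for a replace rule fixing a misspelled species value.
# Both tags reference an assumed "species" term key.
_EXAMPLE_REPLACE_YAML = """
rules:
  - rule_type: replace
    original:
      key: species
      value: Myotis myotys
    replacement:
      key: species
      value: Myotis myotis
"""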
class TransformConfig(BaseConfig):
|
||||
"""Configuration model for defining a sequence of transformation rules.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
rules : List[Union[ReplaceRule, MapValueRule, DeriveTagRule]]
|
||||
A list of transformation rules to apply. The rules are applied
|
||||
sequentially in the order they appear in the list. The output of
|
||||
one rule becomes the input for the next. The `rule_type` field
|
||||
discriminates between the different rule models.
|
||||
"""
|
||||
|
||||
rules: List[
|
||||
Annotated[
|
||||
Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
||||
Field(discriminator="rule_type"),
|
||||
]
|
||||
] = Field(
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
|
||||
class DerivationRegistry(Mapping[str, Derivation]):
|
||||
"""A registry for managing named derivation functions.
|
||||
|
||||
Derivation functions are callables that take a string value and return
|
||||
a transformed string value, used by `DeriveTagRule`. This registry
|
||||
allows functions to be registered with a key and retrieved later.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize an empty DerivationRegistry."""
|
||||
self._derivations: Dict[str, Derivation] = {}
|
||||
|
||||
def __getitem__(self, key: str) -> Derivation:
|
||||
"""Retrieve a derivation function by key."""
|
||||
return self._derivations[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of registered derivation functions."""
|
||||
return len(self._derivations)
|
||||
|
||||
def __iter__(self):
|
||||
"""Return an iterator over the keys of registered functions."""
|
||||
return iter(self._derivations)
|
||||
|
||||
def register(self, key: str, derivation: Derivation) -> None:
|
||||
"""Register a derivation function with a unique key.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique key to associate with the derivation function.
|
||||
derivation : Derivation
|
||||
The callable derivation function (takes str, returns str).
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a derivation function with the same key is already registered.
|
||||
"""
|
||||
if key in self._derivations:
|
||||
raise KeyError(
|
||||
f"A derivation with the provided key {key} already exists"
|
||||
)
|
||||
|
||||
self._derivations[key] = derivation
|
||||
|
||||
def get_derivation(self, key: str) -> Derivation:
|
||||
"""Retrieve a derivation function by its registered key.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The key of the derivation function to retrieve.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Derivation
|
||||
The requested derivation function.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If no derivation function with the specified key is registered.
|
||||
"""
|
||||
try:
|
||||
return self._derivations[key]
|
||||
except KeyError as err:
|
||||
raise KeyError(
|
||||
f"No derivation with key {key} is registered."
|
||||
) from err
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
"""Get a list of all registered derivation function keys.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
The keys of all registered functions.
|
||||
"""
|
||||
return list(self._derivations.keys())
|
||||
|
||||
def get_derivations(self) -> List[Derivation]:
|
||||
"""Get a list of all registered derivation functions.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Derivation]
|
||||
The registered derivation function objects.
|
||||
"""
|
||||
return list(self._derivations.values())
|
||||
|
||||
|
||||
derivation_registry = DerivationRegistry()
|
||||
"""Global instance of the DerivationRegistry.
|
||||
|
||||
Register custom derivation functions here to make them available by key
|
||||
in `DeriveTagRule` configuration.
|
||||
"""
|
||||
|
||||
|
||||
def get_derivation(
|
||||
key: str,
|
||||
import_derivation: bool = False,
|
||||
registry: DerivationRegistry = derivation_registry,
|
||||
) -> Derivation:
|
||||
"""Retrieve a derivation function by key, optionally importing it.
|
||||
|
||||
First attempts to find the function in the provided `registry`.
|
||||
If not found and `import_derivation` is True, attempts to dynamically
|
||||
import the function using the `key` as a full Python path
|
||||
(e.g., 'my_module.submodule.my_func').
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The key or Python path of the derivation function.
|
||||
import_derivation : bool, default=False
|
||||
If True, attempt dynamic import if key is not in the registry.
|
||||
registry : DerivationRegistry, optional
|
||||
The registry instance to check first. Defaults to the global
|
||||
`derivation_registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Derivation
|
||||
The requested derivation function.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If the key is not found in the registry and either
|
||||
`import_derivation` is False or the dynamic import fails.
|
||||
AttributeError
If the module is imported successfully but does not define a
function with the given name.
|
||||
"""
|
||||
if not import_derivation or key in registry:
|
||||
return registry.get_derivation(key)
|
||||
|
||||
try:
|
||||
module_path, func_name = key.rsplit(".", 1)
|
||||
module = importlib.import_module(module_path)
|
||||
func = getattr(module, func_name)
|
||||
return func
|
||||
except ImportError as err:
|
||||
raise KeyError(
|
||||
f"Unable to load derivation '{key}'. Check the path and ensure "
|
||||
"it points to a valid callable function in an importable module."
|
||||
) from err
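As a quick illustration of the dynamic-import path, any importable callable that maps a string to a string can be referenced by its dotted path; for example, the standard-library function `string.capwords`:

# "string.capwords" is not registered, so it is resolved by importing the
# module "string" and fetching the attribute "capwords" from it.
capitalise = get_derivation("string.capwords", import_derivation=True)
assert capitalise("myotis daubentonii") == "Myotis Daubentonii"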
|
||||
|
||||
|
||||
def build_transform_from_rule(
|
||||
rule: Union[ReplaceRule, MapValueRule, DeriveTagRule],
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventTransformation:
|
||||
"""Build a specific SoundEventTransformation function from a rule config.
|
||||
|
||||
Selects the appropriate transformation logic based on the rule's
|
||||
`rule_type`, fetches necessary terms and derivation functions, and
|
||||
returns a partially applied function ready to transform an annotation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rule : Union[ReplaceRule, MapValueRule, DeriveTagRule]
|
||||
The configuration object for a single transformation rule.
|
||||
derivation_registry : DerivationRegistry, optional
The derivation registry to use for `DeriveTagRule`. Defaults to the
global `derivation_registry`.
term_registry : TermRegistry, optional
The registry used to resolve term keys referenced by the rule.
Defaults to the global `default_term_registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventTransformation
|
||||
A callable that applies the specified rule to a SoundEventAnnotation.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If required term keys or derivation keys are not found.
|
||||
ValueError
|
||||
If the rule has an unknown `rule_type`.
|
||||
ImportError, AttributeError, TypeError
|
||||
If dynamic import of a derivation function fails.
|
||||
"""
|
||||
if rule.rule_type == "replace":
|
||||
source = get_tag_from_info(
|
||||
rule.original,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
target = get_tag_from_info(
|
||||
rule.replacement,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
return partial(replace_tag_transform, source=source, target=target)
|
||||
|
||||
if rule.rule_type == "derive_tag":
|
||||
source_term = get_term_from_key(
|
||||
rule.source_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
target_term = (
|
||||
get_term_from_key(
|
||||
rule.target_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
if rule.target_term_key
|
||||
else source_term
|
||||
)
|
||||
derivation = get_derivation(
|
||||
key=rule.derivation_function,
|
||||
import_derivation=rule.import_derivation,
|
||||
registry=derivation_registry,
|
||||
)
|
||||
return partial(
|
||||
derivation_tag_transform,
|
||||
source_term=source_term,
|
||||
target_term=target_term,
|
||||
derivation=derivation,
|
||||
keep_source=rule.keep_source,
|
||||
)
|
||||
|
||||
if rule.rule_type == "map_value":
|
||||
source_term = get_term_from_key(
|
||||
rule.source_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
target_term = (
|
||||
get_term_from_key(
|
||||
rule.target_term_key,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
if rule.target_term_key
|
||||
else source_term
|
||||
)
|
||||
return partial(
|
||||
map_value_transform,
|
||||
source_term=source_term,
|
||||
target_term=target_term,
|
||||
mapping=rule.value_mapping,
|
||||
)
|
||||
|
||||
# Handle unknown rule type
|
||||
valid_options = ["replace", "derive_tag", "map_value"]
|
||||
# Should be caught by Pydantic validation, but good practice
|
||||
raise ValueError(
|
||||
f"Invalid transform rule type '{getattr(rule, 'rule_type', 'N/A')}'. "
|
||||
f"Valid options are: {valid_options}"
|
||||
)
|
||||
|
||||
|
||||
def build_transformation_from_config(
|
||||
config: TransformConfig,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventTransformation:
|
||||
"""Build a composite transformation function from a TransformConfig.
|
||||
|
||||
Creates a sequence of individual transformation functions based on the
|
||||
rules defined in the configuration. Returns a single function that
|
||||
applies these transformations sequentially to an annotation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : TransformConfig
|
||||
The configuration object containing the list of transformation rules.
|
||||
derivation_registry : DerivationRegistry, optional
The derivation registry to use when building `DeriveTagRule`
transformations. Defaults to the global `derivation_registry`.
term_registry : TermRegistry, optional
The registry used to resolve term keys referenced by the rules.
Defaults to the global `default_term_registry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventTransformation
|
||||
A single function that applies all configured transformations in order.
|
||||
"""
|
||||
transforms = [
|
||||
build_transform_from_rule(
|
||||
rule,
|
||||
derivation_registry=derivation_registry,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
for rule in config.rules
|
||||
]
|
||||
|
||||
def transformation(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
) -> data.SoundEventAnnotation:
|
||||
for transform in transforms:
|
||||
sound_event_annotation = transform(sound_event_annotation)
|
||||
return sound_event_annotation
|
||||
|
||||
return transformation
|
||||
|
||||
|
||||
def load_transformation_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
) -> TransformConfig:
|
||||
"""Load the transformation configuration from a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (YAML).
|
||||
field : str, optional
|
||||
If the transformation configuration is nested under a specific key
|
||||
in the file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
TransformConfig
|
||||
The loaded and validated transformation configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the TransformConfig schema.
|
||||
"""
|
||||
return load_config(path=path, schema=TransformConfig, field=field)
|
||||
|
||||
|
||||
def load_transformation_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
term_registry: TermRegistry = default_term_registry,
|
||||
) -> SoundEventTransformation:
|
||||
"""Load transformation config from a file and build the final function.
|
||||
|
||||
This is a convenience function that combines loading the configuration
|
||||
and building the final callable transformation function that applies
|
||||
all rules sequentially.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (YAML).
|
||||
field : str, optional
|
||||
If the transformation configuration is nested under a specific key
|
||||
in the file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SoundEventTransformation
|
||||
The final composite transformation function ready to be used.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the TransformConfig schema.
|
||||
KeyError
|
||||
If required term keys or derivation keys specified in the config
|
||||
are not found during the build process.
|
||||
ImportError, AttributeError, TypeError
|
||||
If dynamic import of a derivation function specified in the config
|
||||
fails.
|
||||
"""
|
||||
config = load_transformation_config(path=path, field=field)
|
||||
return build_transformation_from_config(
|
||||
config,
|
||||
derivation_registry=derivation_registry,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
|
||||
|
||||
def register_derivation(
|
||||
key: str,
|
||||
derivation: Derivation,
|
||||
derivation_registry: DerivationRegistry = derivation_registry,
|
||||
) -> None:
|
||||
"""Register a new derivation function in the global registry.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
key : str
|
||||
The unique key to associate with the derivation function.
|
||||
derivation : Derivation
|
||||
The callable derivation function (takes str, returns str).
|
||||
derivation_registry : DerivationRegistry, optional
|
||||
The registry instance to register the derivation function with.
|
||||
Defaults to the global `derivation_registry`.
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError
|
||||
If a derivation function with the same key is already registered.
|
||||
"""
|
||||
derivation_registry.register(key, derivation)
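To close the loop, here is a minimal sketch of registering a custom derivation and referencing it from a `derive_tag` rule. The rule field names follow the attributes accessed in `build_transform_from_rule` above; the term keys "species" and "genus" are hypothetical and must exist in the term registry being used.

def genus_from_species(value: str) -> str:
    """Derive the genus (first word) from a full species name."""
    return value.split(" ")[0]

register_derivation("genus_from_species", genus_from_species)

config = TransformConfig(
    rules=[
        DeriveTagRule(
            source_term_key="species",   # hypothetical term keys
            target_term_key="genus",
            derivation_function="genus_from_species",
            keep_source=True,
        )
    ]
)
transformation = build_transformation_from_config(config)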
|
233
batdetect2/targets/types.py
Normal file
@ -0,0 +1,233 @@
|
||||
"""Defines the core interface (Protocol) for the target definition pipeline.
|
||||
|
||||
This module specifies the standard structure, attributes, and methods expected
|
||||
from an object that encapsulates the complete configured logic for processing
|
||||
sound event annotations within the `batdetect2.targets` system.
|
||||
|
||||
The main component defined here is the `TargetProtocol`. This protocol acts as
|
||||
a contract for the entire target definition process, covering semantic aspects
|
||||
(filtering, tag transformation, class encoding/decoding) as well as geometric
|
||||
aspects (mapping regions of interest to target positions and sizes). It ensures
|
||||
that components responsible for these tasks can be interacted with consistently
|
||||
throughout BatDetect2.
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Protocol
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
"TargetProtocol",
|
||||
]
|
||||
|
||||
|
||||
class TargetProtocol(Protocol):
|
||||
"""Protocol defining the interface for the target definition pipeline.
|
||||
|
||||
This protocol outlines the standard attributes and methods for an object
|
||||
that encapsulates the complete, configured process for handling sound event
|
||||
annotations (both tags and geometry). It defines how to:
|
||||
- Filter relevant annotations.
|
||||
- Transform annotation tags.
|
||||
- Encode an annotation into a specific target class name.
|
||||
- Decode a class name back into representative tags.
|
||||
- Extract a target reference position from an annotation's geometry (ROI).
|
||||
- Calculate target size dimensions from an annotation's geometry.
|
||||
- Recover an approximate geometry (ROI) from a position and size
|
||||
dimensions.
|
||||
|
||||
Implementations of this protocol bundle all configured logic for these
|
||||
steps.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
class_names : List[str]
|
||||
An ordered list of the unique names of the specific target classes
|
||||
defined by the configuration.
|
||||
generic_class_tags : List[data.Tag]
|
||||
A list of `soundevent.data.Tag` objects representing the configured
|
||||
generic class category (e.g., used when no specific class matches).
|
||||
dimension_names : List[str]
|
||||
A list containing the names of the size dimensions returned by
|
||||
`get_size` and expected by `recover_roi` (e.g., ['width', 'height']).
|
||||
"""
|
||||
|
||||
class_names: List[str]
|
||||
"""Ordered list of unique names for the specific target classes."""
|
||||
|
||||
generic_class_tags: List[data.Tag]
|
||||
"""List of tags representing the generic (unclassified) category."""
|
||||
|
||||
dimension_names: List[str]
|
||||
"""Names of the size dimensions (e.g., ['width', 'height'])."""
|
||||
|
||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||
"""Apply the filter to a sound event annotation.
|
||||
|
||||
Determines if the annotation should be included in further processing
|
||||
and training based on the configured filtering rules.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation to filter.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the annotation should be kept (passes the filter),
|
||||
False otherwise. Implementations should return True if no
|
||||
filtering is configured.
|
||||
"""
|
||||
...
|
||||
|
||||
def transform(
|
||||
self,
|
||||
sound_event: data.SoundEventAnnotation,
|
||||
) -> data.SoundEventAnnotation:
|
||||
"""Apply tag transformations to an annotation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation whose tags should be transformed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.SoundEventAnnotation
|
||||
A new annotation object with the transformed tags. Implementations
|
||||
should return the original annotation object if no transformations
|
||||
were configured.
|
||||
"""
|
||||
...
|
||||
|
||||
def encode(
|
||||
self,
|
||||
sound_event: data.SoundEventAnnotation,
|
||||
) -> Optional[str]:
|
||||
"""Encode a sound event annotation to its target class name.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The (potentially filtered and transformed) annotation to encode.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str or None
|
||||
The string name of the matched target class if the annotation
|
||||
matches a specific class definition. Returns None if the annotation
|
||||
does not match any specific class rule (indicating it may belong
|
||||
to a generic category or should be handled differently downstream).
|
||||
"""
|
||||
...
|
||||
|
||||
def decode(self, class_label: str) -> List[data.Tag]:
|
||||
"""Decode a predicted class name back into representative tags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
class_label : str
|
||||
The class name string (e.g., predicted by a model) to decode.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.Tag]
|
||||
The list of tags corresponding to the input class name according
|
||||
to the configuration. May return an empty list or raise an error
|
||||
for unmapped labels, depending on the implementation's configuration
|
||||
(e.g., `raise_on_unmapped` flag during building).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError, KeyError
|
||||
Implementations might raise an error if the `class_label` is not
|
||||
found in the configured mapping and error raising is enabled.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_position(
|
||||
self, sound_event: data.SoundEventAnnotation
|
||||
) -> tuple[float, float]:
|
||||
"""Extract the target reference position from the annotation's geometry.
|
||||
|
||||
Calculates the `(time, frequency)` coordinate representing the primary
|
||||
location of the sound event.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation containing the geometry (ROI) to process.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[float, float]
|
||||
The calculated reference position `(time, frequency)`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the annotation lacks geometry or if the position cannot be
|
||||
calculated for the geometry type or configured reference point.
|
||||
"""
|
||||
...
|
||||
|
||||
def get_size(self, sound_event: data.SoundEventAnnotation) -> np.ndarray:
|
||||
"""Calculate the target size dimensions from the annotation's geometry.
|
||||
|
||||
Computes the relevant physical size (e.g., duration/width,
|
||||
bandwidth/height from a bounding box) to produce
|
||||
the numerical target values expected by the model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation containing the geometry (ROI) to process.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
A NumPy array containing the size dimensions, matching the
|
||||
order specified by the `dimension_names` attribute (e.g.,
|
||||
`[width, height]`).
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the annotation lacks geometry or if the size cannot be computed.
|
||||
TypeError
|
||||
If geometry type is unsupported.
|
||||
"""
|
||||
...
|
||||
|
||||
def recover_roi(
|
||||
self, pos: tuple[float, float], dims: np.ndarray
|
||||
) -> data.Geometry:
|
||||
"""Recover the ROI geometry from a position and dimensions.
|
||||
|
||||
Performs the inverse mapping of `get_position` and `get_size`. It takes
|
||||
a reference position `(time, frequency)` and an array of size
|
||||
dimensions and reconstructs an approximate geometric representation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pos : Tuple[float, float]
|
||||
The reference position `(time, frequency)`.
|
||||
dims : np.ndarray
|
||||
The NumPy array containing the dimensions (e.g., predicted
|
||||
by the model), corresponding to the order in `dimension_names`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
soundevent.data.Geometry
|
||||
The reconstructed geometry.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the number of provided `dims` does not match `dimension_names`,
|
||||
if dimensions are invalid (e.g., negative after unscaling), or
|
||||
if reconstruction fails based on the configured position type.
|
||||
"""
|
||||
...
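To make the intended call sequence concrete, the following illustrative (hypothetical) helper runs a single annotation through an object implementing this protocol, mirroring the order used during training data generation: filter, transform, then encode and extract the geometric targets.

from typing import Optional, Tuple

import numpy as np
from soundevent import data

from batdetect2.targets.types import TargetProtocol


def describe_annotation(
    targets: TargetProtocol,
    annotation: data.SoundEventAnnotation,
) -> Optional[Tuple[str, Tuple[float, float], np.ndarray]]:
    """Return (class name, position, size) for one annotation, or None."""
    if not targets.filter(annotation):
        return None
    annotation = targets.transform(annotation)
    class_name = targets.encode(annotation) or "generic"
    position = targets.get_position(annotation)
    size = targets.get_size(annotation)
    return class_name, position, size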
|
@ -1,88 +0,0 @@
|
||||
from inspect import getmembers
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from soundevent import data, terms
|
||||
|
||||
__all__ = [
|
||||
"call_type",
|
||||
"individual",
|
||||
"get_term_from_info",
|
||||
"get_tag_from_info",
|
||||
"TermInfo",
|
||||
"TagInfo",
|
||||
]
|
||||
|
||||
|
||||
class TermInfo(BaseModel):
|
||||
label: Optional[str]
|
||||
name: Optional[str]
|
||||
uri: Optional[str]
|
||||
|
||||
|
||||
class TagInfo(BaseModel):
|
||||
value: str
|
||||
term: Optional[TermInfo] = None
|
||||
key: Optional[str] = None
|
||||
label: Optional[str] = None
|
||||
|
||||
|
||||
call_type = data.Term(
|
||||
name="soundevent:call_type",
|
||||
label="Call Type",
|
||||
definition="A broad categorization of animal vocalizations based on their intended function or purpose (e.g., social, distress, mating, territorial, echolocation).",
|
||||
)
|
||||
|
||||
individual = data.Term(
|
||||
name="soundevent:individual",
|
||||
label="Individual",
|
||||
definition="An id for an individual animal. In the context of bioacoustic annotation, this term is used to label vocalizations that are attributed to a specific individual.",
|
||||
)
|
||||
|
||||
|
||||
ALL_TERMS = [
|
||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
||||
call_type,
|
||||
individual,
|
||||
]
|
||||
|
||||
|
||||
def get_term_from_info(term_info: TermInfo) -> data.Term:
|
||||
for term in ALL_TERMS:
|
||||
if term_info.name and term_info.name == term.name:
|
||||
return term
|
||||
|
||||
if term_info.label and term_info.label == term.label:
|
||||
return term
|
||||
|
||||
if term_info.uri and term_info.uri == term.uri:
|
||||
return term
|
||||
|
||||
if term_info.name is None:
|
||||
if term_info.label is None:
|
||||
raise ValueError("At least one of name or label must be provided.")
|
||||
|
||||
term_info.name = (
|
||||
f"soundevent:{term_info.label.lower().replace(' ', '_')}"
|
||||
)
|
||||
|
||||
if term_info.label is None:
|
||||
term_info.label = term_info.name
|
||||
|
||||
return data.Term(
|
||||
name=term_info.name,
|
||||
label=term_info.label,
|
||||
uri=term_info.uri,
|
||||
definition="Unknown",
|
||||
)
|
||||
|
||||
|
||||
def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
|
||||
if tag_info.term:
|
||||
term = get_term_from_info(tag_info.term)
|
||||
elif tag_info.key:
|
||||
term = data.term_from_key(tag_info.key)
|
||||
else:
|
||||
raise ValueError("Either term or key must be provided in tag info.")
|
||||
|
||||
return data.Tag(term=term, value=tag_info.value)
|
@ -1,53 +1,54 @@
|
||||
from batdetect2.train.augmentations import (
|
||||
AugmentationsConfig,
|
||||
EchoAugmentationConfig,
|
||||
FrequencyMaskAugmentationConfig,
|
||||
TimeMaskAugmentationConfig,
|
||||
VolumeAugmentationConfig,
|
||||
WarpAugmentationConfig,
|
||||
add_echo,
|
||||
augment_example,
|
||||
load_agumentation_config,
|
||||
build_augmentations,
|
||||
mask_frequency,
|
||||
mask_time,
|
||||
mix_examples,
|
||||
scale_volume,
|
||||
select_subclip,
|
||||
warp_spectrogram,
|
||||
)
|
||||
from batdetect2.train.clips import build_clipper, select_subclip
|
||||
from batdetect2.train.config import TrainingConfig, load_train_config
|
||||
from batdetect2.train.dataset import (
|
||||
LabeledDataset,
|
||||
SubclipConfig,
|
||||
RandomExampleSource,
|
||||
TrainExample,
|
||||
get_preprocessed_files,
|
||||
list_preprocessed_files,
|
||||
)
|
||||
from batdetect2.train.labels import LabelConfig, load_label_config
|
||||
from batdetect2.train.labels import load_label_config
|
||||
from batdetect2.train.losses import LossFunction, build_loss
|
||||
from batdetect2.train.preprocess import (
|
||||
generate_train_example,
|
||||
preprocess_annotations,
|
||||
)
|
||||
from batdetect2.train.targets import (
|
||||
TagInfo,
|
||||
TargetConfig,
|
||||
build_target_encoder,
|
||||
load_target_config,
|
||||
)
|
||||
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
|
||||
|
||||
__all__ = [
|
||||
"AugmentationsConfig",
|
||||
"LabelConfig",
|
||||
"EchoAugmentationConfig",
|
||||
"FrequencyMaskAugmentationConfig",
|
||||
"LabeledDataset",
|
||||
"SubclipConfig",
|
||||
"TagInfo",
|
||||
"TargetConfig",
|
||||
"LossFunction",
|
||||
"RandomExampleSource",
|
||||
"TimeMaskAugmentationConfig",
|
||||
"TrainExample",
|
||||
"TrainerConfig",
|
||||
"TrainingConfig",
|
||||
"VolumeAugmentationConfig",
|
||||
"WarpAugmentationConfig",
|
||||
"add_echo",
|
||||
"augment_example",
|
||||
"build_target_encoder",
|
||||
"build_augmentations",
|
||||
"build_clipper",
|
||||
"build_loss",
|
||||
"generate_train_example",
|
||||
"get_preprocessed_files",
|
||||
"load_agumentation_config",
|
||||
"list_preprocessed_files",
|
||||
"load_label_config",
|
||||
"load_target_config",
|
||||
"load_train_config",
|
||||
"load_trainer_config",
|
||||
"mask_frequency",
|
||||
|
File diff suppressed because it is too large
55
batdetect2/train/callbacks.py
Normal file
@ -0,0 +1,55 @@
|
||||
from lightning import LightningModule, Trainer
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.postprocess import PostprocessorProtocol
|
||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||
from batdetect2.types import ModelOutput
|
||||
|
||||
|
||||
class ValidationMetrics(Callback):
|
||||
def __init__(self, postprocessor: PostprocessorProtocol):
|
||||
super().__init__()
|
||||
self.postprocessor = postprocessor
|
||||
self.predictions = []
|
||||
|
||||
def on_validation_epoch_start(
|
||||
self,
|
||||
trainer: Trainer,
|
||||
pl_module: LightningModule,
|
||||
) -> None:
|
||||
self.predictions = []
|
||||
return super().on_validation_epoch_start(trainer, pl_module)
|
||||
|
||||
def on_validation_batch_end( # type: ignore
|
||||
self,
|
||||
trainer: Trainer,
|
||||
pl_module: LightningModule,
|
||||
outputs: ModelOutput,
|
||||
batch: TrainExample,
|
||||
batch_idx: int,
|
||||
dataloader_idx: int = 0,
|
||||
) -> None:
|
||||
dataloaders = trainer.val_dataloaders
|
||||
assert isinstance(dataloaders, DataLoader)
|
||||
dataset = dataloaders.dataset
|
||||
assert isinstance(dataset, LabeledDataset)
|
||||
clip_annotation = dataset.get_clip_annotation(batch_idx)
|
||||
|
||||
# clip_prediction = postprocess_model_outputs(
|
||||
# outputs,
|
||||
# clips=[clip_annotation.clip],
|
||||
# classes=self.class_names,
|
||||
# decoder=self.decoder,
|
||||
# config=self.config.postprocessing,
|
||||
# )[0]
|
||||
#
|
||||
# matches = match_predictions_and_annotations(
|
||||
# clip_annotation,
|
||||
# clip_prediction,
|
||||
# )
|
||||
#
|
||||
# self.validation_predictions.extend(matches)
|
||||
# return super().on_validation_batch_end(
|
||||
# trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
|
||||
# )
|
184
batdetect2/train/clips.py
Normal file
@ -0,0 +1,184 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import arrays
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.train.types import ClipperProtocol
|
||||
|
||||
DEFAULT_TRAIN_CLIP_DURATION = 0.513
|
||||
DEFAULT_MAX_EMPTY_CLIP = 0.1
|
||||
|
||||
|
||||
class ClipingConfig(BaseConfig):
|
||||
duration: float = DEFAULT_TRAIN_CLIP_DURATION
|
||||
random: bool = True
|
||||
max_empty: float = DEFAULT_MAX_EMPTY_CLIP
|
||||
|
||||
|
||||
class Clipper(ClipperProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
duration: float = 0.5,
|
||||
max_empty: float = 0.2,
|
||||
random: bool = True,
|
||||
):
|
||||
self.duration = duration
|
||||
self.random = random
|
||||
self.max_empty = max_empty
|
||||
|
||||
def extract_clip(
|
||||
self, example: xr.Dataset
|
||||
) -> Tuple[xr.Dataset, float, float]:
|
||||
step = arrays.get_dim_step(
|
||||
example.spectrogram,
|
||||
dim=arrays.Dimensions.time.value,
|
||||
)
|
||||
duration = (
|
||||
arrays.get_dim_width(
|
||||
example.spectrogram,
|
||||
dim=arrays.Dimensions.time.value,
|
||||
)
|
||||
+ step
|
||||
)
|
||||
|
||||
start_time = 0
|
||||
if self.random:
|
||||
start_time = np.random.uniform(
|
||||
-self.max_empty,
|
||||
duration - self.duration + self.max_empty,
|
||||
)
|
||||
|
||||
subclip = select_subclip(
|
||||
example,
|
||||
start=start_time,
|
||||
span=self.duration,
|
||||
dim="time",
|
||||
)
|
||||
|
||||
return (
|
||||
select_subclip(
|
||||
subclip,
|
||||
start=start_time,
|
||||
span=self.duration,
|
||||
dim="audio_time",
|
||||
),
|
||||
start_time,
|
||||
start_time + self.duration,
|
||||
)
|
||||
|
||||
|
||||
def build_clipper(config: Optional[ClipingConfig] = None) -> ClipperProtocol:
|
||||
config = config or ClipingConfig()
|
||||
return Clipper(
|
||||
duration=config.duration,
|
||||
max_empty=config.max_empty,
|
||||
random=config.random,
|
||||
)
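For example, a training run that draws random 513 ms clips and allows at most 100 ms of padding past the edges of an example could be configured as follows (values are illustrative):

# Hypothetical values; `example` would be an xr.Dataset produced by the
# training preprocessing pipeline (with "time" and "audio_time" dims).
config = ClipingConfig(duration=0.513, random=True, max_empty=0.1)
clipper = build_clipper(config)
# subclip, start, end = clipper.extract_clip(example)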
|
||||
|
||||
|
||||
def select_subclip(
|
||||
dataset: xr.Dataset,
|
||||
span: float,
|
||||
start: float,
|
||||
fill_value: float = 0,
|
||||
dim: str = "time",
|
||||
) -> xr.Dataset:
|
||||
width = _compute_expected_width(
|
||||
dataset, # type: ignore
|
||||
span,
|
||||
dim=dim,
|
||||
)
|
||||
|
||||
coord = dataset.coords[dim]
|
||||
|
||||
if len(coord) == width:
|
||||
return dataset
|
||||
|
||||
new_coords, start_pad, end_pad, dim_slice = _extract_coordinate(
|
||||
coord, start, span
|
||||
)
|
||||
|
||||
data_vars = {}
|
||||
for name, data_array in dataset.data_vars.items():
|
||||
if dim not in data_array.dims:
|
||||
data_vars[name] = data_array
|
||||
continue
|
||||
|
||||
if width == data_array.sizes[dim]:
|
||||
data_vars[name] = data_array
|
||||
continue
|
||||
|
||||
sliced = data_array.isel({dim: dim_slice}).data
|
||||
|
||||
if start_pad > 0 or end_pad > 0:
|
||||
padding = [
|
||||
[0, 0] if other_dim != dim else [start_pad, end_pad]
|
||||
for other_dim in data_array.dims
|
||||
]
|
||||
sliced = np.pad(sliced, padding, constant_values=fill_value)
|
||||
|
||||
data_vars[name] = xr.DataArray(
|
||||
data=sliced,
|
||||
dims=data_array.dims,
|
||||
coords={**data_array.coords, dim: new_coords},
|
||||
attrs=data_array.attrs,
|
||||
)
|
||||
|
||||
return xr.Dataset(data_vars=data_vars, attrs=dataset.attrs)
|
||||
|
||||
|
||||
def _extract_coordinate(
|
||||
coord: xr.DataArray,
|
||||
start: float,
|
||||
span: float,
|
||||
) -> Tuple[xr.Variable, int, int, slice]:
|
||||
step = arrays.get_dim_step(coord, str(coord.name))
|
||||
|
||||
current_width = len(coord)
|
||||
expected_width = int(np.floor(span / step))
|
||||
|
||||
coord_start = float(coord[0])
|
||||
offset = start - coord_start
|
||||
|
||||
start_index = int(np.floor(offset / step))
|
||||
end_index = start_index + expected_width
|
||||
|
||||
if start_index > current_width:
|
||||
raise ValueError("Requested span does not overlap with current range")
|
||||
|
||||
if end_index < 0:
|
||||
raise ValueError("Requested span does not overlap with current range")
|
||||
|
||||
corrected_start = float(start_index * step)
|
||||
corrected_end = float(end_index * step)
|
||||
|
||||
start_index_offset = max(0, -start_index)
|
||||
end_index_offset = max(0, end_index - current_width)
|
||||
|
||||
sl = slice(
|
||||
start_index if start_index >= 0 else None,
|
||||
end_index if end_index < current_width else None,
|
||||
)
|
||||
|
||||
return (
|
||||
arrays.create_range_dim(
|
||||
str(coord.name),
|
||||
start=corrected_start,
|
||||
stop=corrected_end,
|
||||
step=step,
|
||||
),
|
||||
start_index_offset,
|
||||
end_index_offset,
|
||||
sl,
|
||||
)
|
||||
|
||||
|
||||
def _compute_expected_width(
|
||||
array: Union[xr.DataArray, xr.Dataset],
|
||||
duration: float,
|
||||
dim: str,
|
||||
) -> int:
|
||||
step = arrays.get_dim_step(array, dim) # type: ignore
|
||||
return int(np.floor(duration / step))
|
@ -4,6 +4,11 @@ from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.train.augmentations import (
|
||||
DEFAULT_AUGMENTATION_CONFIG,
|
||||
AugmentationsConfig,
|
||||
)
|
||||
from batdetect2.train.clips import ClipingConfig
|
||||
from batdetect2.train.losses import LossConfig
|
||||
|
||||
__all__ = [
|
||||
@ -20,9 +25,17 @@ class OptimizerConfig(BaseConfig):
|
||||
|
||||
class TrainingConfig(BaseConfig):
|
||||
batch_size: int = 32
|
||||
|
||||
loss: LossConfig = Field(default_factory=LossConfig)
|
||||
|
||||
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
||||
|
||||
augmentations: AugmentationsConfig = Field(
|
||||
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
||||
)
|
||||
|
||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||
|
||||
|
||||
def load_train_config(
|
||||
path: PathLike,
|
||||
|
@ -1,87 +1,40 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, Optional, Sequence, Union
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess.tensors import adjust_width
|
||||
from batdetect2.train.augmentations import (
|
||||
AugmentationsConfig,
|
||||
augment_example,
|
||||
select_subclip,
|
||||
)
|
||||
from batdetect2.train.preprocess import PreprocessingConfig
|
||||
from batdetect2.train.augmentations import Augmentation
|
||||
from batdetect2.train.types import ClipperProtocol, TrainExample
|
||||
|
||||
__all__ = [
|
||||
"TrainExample",
|
||||
"LabeledDataset",
|
||||
]
|
||||
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
|
||||
|
||||
class TrainExample(NamedTuple):
|
||||
spec: torch.Tensor
|
||||
detection_heatmap: torch.Tensor
|
||||
class_heatmap: torch.Tensor
|
||||
size_heatmap: torch.Tensor
|
||||
idx: torch.Tensor
|
||||
|
||||
|
||||
class SubclipConfig(BaseConfig):
|
||||
duration: Optional[float] = None
|
||||
width: int = 512
|
||||
random: bool = False
|
||||
|
||||
|
||||
class DatasetConfig(BaseConfig):
|
||||
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
|
||||
augmentation: AugmentationsConfig = Field(
|
||||
default_factory=AugmentationsConfig
|
||||
)
|
||||
|
||||
|
||||
class LabeledDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
filenames: Sequence[PathLike],
|
||||
subclip: Optional[SubclipConfig] = None,
|
||||
augmentation: Optional[AugmentationsConfig] = None,
|
||||
preprocessing: Optional[PreprocessingConfig] = None,
|
||||
filenames: Sequence[data.PathLike],
|
||||
clipper: ClipperProtocol,
|
||||
augmentation: Optional[Augmentation] = None,
|
||||
):
|
||||
self.filenames = filenames
|
||||
self.subclip = subclip
|
||||
self.clipper = clipper
|
||||
self.augmentation = augmentation
|
||||
self.preprocessing = preprocessing or PreprocessingConfig()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.filenames)
|
||||
|
||||
def __getitem__(self, idx) -> TrainExample:
|
||||
dataset = self.get_dataset(idx)
|
||||
|
||||
if self.subclip:
|
||||
dataset = select_subclip(
|
||||
dataset,
|
||||
duration=self.subclip.duration,
|
||||
width=self.subclip.width,
|
||||
random=self.subclip.random,
|
||||
)
|
||||
dataset, start_time, end_time = self.clipper.extract_clip(dataset)
|
||||
|
||||
if self.augmentation:
|
||||
dataset = augment_example(
|
||||
dataset,
|
||||
self.augmentation,
|
||||
preprocessing_config=self.preprocessing,
|
||||
others=self.get_random_example,
|
||||
)
|
||||
dataset = self.augmentation(dataset)
|
||||
|
||||
return TrainExample(
|
||||
spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
|
||||
@ -89,37 +42,31 @@ class LabeledDataset(Dataset):
|
||||
class_heatmap=self.to_tensor(dataset["class"]),
|
||||
size_heatmap=self.to_tensor(dataset["size"]),
|
||||
idx=torch.tensor(idx),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_directory(
|
||||
cls,
|
||||
directory: PathLike,
|
||||
directory: data.PathLike,
|
||||
clipper: ClipperProtocol,
|
||||
extension: str = ".nc",
|
||||
subclip: Optional[SubclipConfig] = None,
|
||||
augmentation: Optional[AugmentationsConfig] = None,
|
||||
preprocessing: Optional[PreprocessingConfig] = None,
|
||||
augmentation: Optional[Augmentation] = None,
|
||||
):
|
||||
return cls(
|
||||
get_preprocessed_files(directory, extension),
|
||||
subclip=subclip,
|
||||
filenames=list_preprocessed_files(directory, extension),
|
||||
clipper=clipper,
|
||||
augmentation=augmentation,
|
||||
preprocessing=preprocessing,
|
||||
)
|
||||
|
||||
def get_random_example(self) -> xr.Dataset:
|
||||
def get_random_example(self) -> Tuple[xr.Dataset, float, float]:
|
||||
idx = np.random.randint(0, len(self))
|
||||
dataset = self.get_dataset(idx)
|
||||
|
||||
if self.subclip:
|
||||
dataset = select_subclip(
|
||||
dataset,
|
||||
duration=self.subclip.duration,
|
||||
width=self.subclip.width,
|
||||
random=self.subclip.random,
|
||||
)
|
||||
dataset, start_time, end_time = self.clipper.extract_clip(dataset)
|
||||
|
||||
return dataset
|
||||
return dataset, start_time, end_time
|
||||
|
||||
def get_dataset(self, idx) -> xr.Dataset:
|
||||
return xr.open_dataset(self.filenames[idx])
|
||||
@ -134,16 +81,23 @@ class LabeledDataset(Dataset):
|
||||
array: xr.DataArray,
|
||||
dtype=np.float32,
|
||||
) -> torch.Tensor:
|
||||
tensor = torch.tensor(array.values.astype(dtype))
|
||||
|
||||
if not self.subclip:
|
||||
return tensor
|
||||
|
||||
width = self.subclip.width
|
||||
return adjust_width(tensor, width)
|
||||
return torch.tensor(array.values.astype(dtype))
|
||||
|
||||
|
||||
def get_preprocessed_files(
|
||||
directory: PathLike, extension: str = ".nc"
|
||||
def list_preprocessed_files(
|
||||
directory: data.PathLike, extension: str = ".nc"
|
||||
) -> Sequence[Path]:
|
||||
return list(Path(directory).glob(f"*{extension}"))
|
||||
|
||||
|
||||
class RandomExampleSource:
|
||||
def __init__(self, filenames: List[str], clipper: ClipperProtocol):
|
||||
self.filenames = filenames
|
||||
self.clipper = clipper
|
||||
|
||||
def __call__(self):
|
||||
index = int(np.random.randint(len(self.filenames)))
|
||||
filename = self.filenames[index]
|
||||
dataset = xr.open_dataset(filename)
|
||||
example, _, _ = self.clipper.extract_clip(dataset)
|
||||
return example
|
||||
|
@ -1,45 +1,216 @@
|
||||
"""Generate heatmap training targets for BatDetect2 models.
|
||||
|
||||
This module is responsible for creating the target labels used for training
|
||||
BatDetect2 models. It converts sound event annotations for an audio clip into
|
||||
the specific multi-channel heatmap formats required by the neural network.
|
||||
|
||||
It uses a pre-configured object adhering to the `TargetProtocol` (from
|
||||
`batdetect2.targets`) which encapsulates all the logic for filtering
|
||||
annotations, transforming tags, encoding class names, and mapping annotation
|
||||
geometry (ROIs) to target positions and sizes. This module then focuses on
|
||||
rendering this information onto the heatmap grids.
|
||||
|
||||
The pipeline generates three core outputs for a given spectrogram:
|
||||
1. **Detection Heatmap**: Indicates presence/location of relevant sound events.
|
||||
2. **Class Heatmap**: Indicates location and class identity for specifically
|
||||
classified events.
|
||||
3. **Size Heatmap**: Encodes the target dimensions (width, height) of events.
|
||||
|
||||
The primary function generated by this module is a `ClipLabeller` (defined in
|
||||
`.types`), which takes a `ClipAnnotation` object and its corresponding
|
||||
spectrogram and returns the calculated `Heatmaps` tuple. The main configurable
|
||||
parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
|
||||
defined in `LabelConfig`.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, List, Optional, Sequence, Tuple
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from scipy.ndimage import gaussian_filter
|
||||
from soundevent import arrays, data, geometry
|
||||
from soundevent.geometry.operations import Positions
|
||||
from soundevent import arrays, data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.types import (
|
||||
ClipLabeller,
|
||||
Heatmaps,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"HeatmapsConfig",
|
||||
"LabelConfig",
|
||||
"build_clip_labeler",
|
||||
"generate_clip_label",
|
||||
"generate_heatmaps",
|
||||
"load_label_config",
|
||||
]
|
||||
|
||||
|
||||
class HeatmapsConfig(BaseConfig):
|
||||
position: Positions = "bottom-left"
|
||||
sigma: float = 3.0
|
||||
time_scale: float = 1000.0
|
||||
frequency_scale: float = 1 / 859.375
|
||||
SIZE_DIMENSION = "dimension"
|
||||
"""Dimension name for the size heatmap."""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LabelConfig(BaseConfig):
|
||||
heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
|
||||
"""Configuration parameters for heatmap generation.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
sigma : float, default=3.0
|
||||
The standard deviation (in pixels/bins) of the Gaussian kernel applied
|
||||
to smooth the detection and class heatmaps. Larger values create more
|
||||
diffuse targets.
|
||||
"""
|
||||
|
||||
sigma: float = 3.0
|
||||
|
||||
|
||||
def build_clip_labeler(
|
||||
targets: TargetProtocol,
|
||||
config: Optional[LabelConfig] = None,
|
||||
) -> ClipLabeller:
|
||||
"""Construct the final clip labelling function.
|
||||
|
||||
This factory function prepares the callable that will perform the
|
||||
end-to-end heatmap generation for a given clip and spectrogram during
|
||||
training data loading. It takes the fully configured `targets` object and
|
||||
the `LabelConfig` and binds them to the `generate_clip_label` function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
An initialized object conforming to the `TargetProtocol`, providing all
|
||||
necessary methods for filtering, transforming, encoding, and ROI
|
||||
mapping.
|
||||
config : LabelConfig
|
||||
Configuration object containing heatmap generation parameters.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ClipLabeller
|
||||
A function that accepts a `data.ClipAnnotation` and `xr.DataArray`
|
||||
(spectrogram) and returns the generated `Heatmaps`.
|
||||
"""
|
||||
return partial(
|
||||
generate_clip_label,
|
||||
targets=targets,
|
||||
config=config or LabelConfig(),
|
||||
)
|
||||
|
||||
|
||||
def generate_clip_label(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
spec: xr.DataArray,
|
||||
targets: TargetProtocol,
|
||||
config: LabelConfig,
|
||||
) -> Heatmaps:
|
||||
"""Generate training heatmaps for a single annotated clip.
|
||||
|
||||
This function orchestrates the target generation process for one clip:
|
||||
1. Filters and transforms sound events using `targets.filter` and
|
||||
`targets.transform`.
|
||||
2. Passes the resulting processed annotations, along with the spectrogram,
|
||||
the `targets` object, and the Gaussian `sigma` from `config`, to the
|
||||
core `generate_heatmaps` function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clip_annotation : data.ClipAnnotation
|
||||
The complete annotation data for the audio clip, including the list
|
||||
of `sound_events` to process.
|
||||
spec : xr.DataArray
|
||||
The spectrogram corresponding to the `clip_annotation`. Must have
|
||||
'time' and 'frequency' dimensions/coordinates.
|
||||
targets : TargetProtocol
|
||||
The fully configured target definition object, providing methods for
|
||||
filtering, transformation, encoding, and ROI mapping.
|
||||
config : LabelConfig
|
||||
Configuration object providing heatmap parameters (primarily `sigma`).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Heatmaps
|
||||
A NamedTuple containing the generated 'detection', 'classes', and 'size'
|
||||
heatmaps for this clip.
|
||||
"""
|
||||
return generate_heatmaps(
|
||||
(
|
||||
targets.transform(sound_event_annotation)
|
||||
for sound_event_annotation in clip_annotation.sound_events
|
||||
if targets.filter(sound_event_annotation)
|
||||
),
|
||||
spec=spec,
|
||||
targets=targets,
|
||||
target_sigma=config.sigma,
|
||||
)
|
||||
|
||||
|
||||
def generate_heatmaps(
|
||||
sound_events: Sequence[data.SoundEventAnnotation],
|
||||
sound_events: Iterable[data.SoundEventAnnotation],
|
||||
spec: xr.DataArray,
|
||||
class_names: List[str],
|
||||
encoder: Callable[[Iterable[data.Tag]], Optional[str]],
|
||||
targets: TargetProtocol,
|
||||
target_sigma: float = 3.0,
|
||||
position: Positions = "bottom-left",
|
||||
time_scale: float = 1000.0,
|
||||
frequency_scale: float = 1 / 859.375,
|
||||
dtype=np.float32,
|
||||
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
|
||||
) -> Heatmaps:
|
||||
"""Generate detection, class, and size heatmaps from sound events.
|
||||
|
||||
Creates heatmap representations from an iterable of sound event
|
||||
annotations. This function relies on the provided `targets` object to get
|
||||
the reference position, scaled size, and class encoding for each
|
||||
annotation.
|
||||
|
||||
Process:
|
||||
1. Initializes empty heatmap arrays based on `spec` shape and `targets`
|
||||
metadata.
|
||||
2. Iterates through `sound_events`.
|
||||
3. For each event:
|
||||
a. Gets geometry. Skips if missing.
|
||||
b. Gets reference position using `targets.get_position()`. Skips if out
|
||||
of bounds.
|
||||
c. Places a peak (1.0) on the detection heatmap at the position.
|
||||
d. Gets scaled size using `targets.get_size()` and places it on the
|
||||
size heatmap.
|
||||
e. Encodes class using `targets.encode()` and places a peak (1.0) on
|
||||
the corresponding class heatmap layer if a specific class is
|
||||
returned.
|
||||
4. Applies Gaussian smoothing (using `target_sigma`) to detection and class
|
||||
heatmaps.
|
||||
5. Normalizes detection and class heatmaps to range [0, 1].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_events : Iterable[data.SoundEventAnnotation]
|
||||
An iterable of sound event annotations to render onto the heatmaps.
|
||||
spec : xr.DataArray
|
||||
The spectrogram array corresponding to the time/frequency range of
|
||||
the annotations. Used for shape and coordinate information. Must have
|
||||
'time' and 'frequency' dimensions/coordinates.
|
||||
targets : TargetProtocol
|
||||
The fully configured target definition object. Used to access class
|
||||
names, dimension names, and the methods `get_position`, `get_size`,
|
||||
`encode`.
|
||||
target_sigma : float, default=3.0
|
||||
Standard deviation (in pixels/bins) of the Gaussian kernel applied to
|
||||
smooth the detection and class heatmaps.
|
||||
dtype : type, default=np.float32
|
||||
The data type for the generated heatmap arrays (e.g., `np.float32`).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Heatmaps
|
||||
A NamedTuple containing the 'detection', 'classes', and 'size'
|
||||
xarray DataArrays, ready for use in model training.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the input spectrogram `spec` does not have both 'time' and
|
||||
'frequency' dimensions, or if `targets.class_names` is empty.
|
||||
"""
|
||||
shape = dict(zip(spec.dims, spec.shape))
|
||||
|
||||
if "time" not in shape or "frequency" not in shape:
|
||||
@ -50,18 +221,18 @@ def generate_heatmaps(
|
||||
# Initialize heatmaps
|
||||
detection_heatmap = xr.zeros_like(spec, dtype=dtype)
|
||||
class_heatmap = xr.DataArray(
|
||||
data=np.zeros((len(class_names), *spec.shape), dtype=dtype),
|
||||
data=np.zeros((len(targets.class_names), *spec.shape), dtype=dtype),
|
||||
dims=["category", *spec.dims],
|
||||
coords={
|
||||
"category": [*class_names],
|
||||
"category": [*targets.class_names],
|
||||
**spec.coords,
|
||||
},
|
||||
)
|
||||
size_heatmap = xr.DataArray(
|
||||
data=np.zeros((2, *spec.shape), dtype=dtype),
|
||||
dims=["dimension", *spec.dims],
|
||||
dims=[SIZE_DIMENSION, *spec.dims],
|
||||
coords={
|
||||
"dimension": ["width", "height"],
|
||||
SIZE_DIMENSION: targets.dimension_names,
|
||||
**spec.coords,
|
||||
},
|
||||
)
|
||||
@ -69,10 +240,14 @@ def generate_heatmaps(
|
||||
for sound_event_annotation in sound_events:
|
||||
geom = sound_event_annotation.sound_event.geometry
|
||||
if geom is None:
|
||||
logger.debug(
|
||||
"Skipping annotation %s: missing geometry.",
|
||||
sound_event_annotation.uuid,
|
||||
)
|
||||
continue
|
||||
|
||||
# Get the position of the sound event
|
||||
time, frequency = geometry.get_geometry_point(geom, position=position)
|
||||
time, frequency = targets.get_position(sound_event_annotation)
|
||||
|
||||
# Set 1.0 at the position of the sound event in the detection heatmap
|
||||
try:
|
||||
@ -84,19 +259,15 @@ def generate_heatmaps(
|
||||
)
|
||||
except KeyError:
|
||||
# Skip the sound event if the position is outside the spectrogram
|
||||
logger.debug(
|
||||
"Skipping annotation %s: position outside spectrogram. "
|
||||
"Pos: %s",
|
||||
sound_event_annotation.uuid,
|
||||
(time, frequency),
|
||||
)
|
||||
continue
|
||||
|
||||
# Set the size of the sound event at the position in the size heatmap
|
||||
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
|
||||
geom
|
||||
)
|
||||
|
||||
size = np.array(
|
||||
[
|
||||
(end_time - start_time) * time_scale,
|
||||
(high_freq - low_freq) * frequency_scale,
|
||||
]
|
||||
)
|
||||
size = targets.get_size(sound_event_annotation)
|
||||
|
||||
size_heatmap = arrays.set_value_at_pos(
|
||||
size_heatmap,
|
||||
@ -106,19 +277,38 @@ def generate_heatmaps(
|
||||
)
|
||||
|
||||
# Get the class name of the sound event
|
||||
class_name = encoder(sound_event_annotation.tags)
|
||||
try:
|
||||
class_name = targets.encode(sound_event_annotation)
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Skipping annotation %s: Unexpected error while encoding "
|
||||
"class name %s",
|
||||
sound_event_annotation.uuid,
|
||||
e,
|
||||
)
|
||||
continue
|
||||
|
||||
if class_name is None:
|
||||
# If the label is None skip the sound event
|
||||
continue
|
||||
|
||||
class_heatmap = arrays.set_value_at_pos(
|
||||
class_heatmap,
|
||||
1.0,
|
||||
time=time,
|
||||
frequency=frequency,
|
||||
category=class_name,
|
||||
)
|
||||
try:
|
||||
class_heatmap = arrays.set_value_at_pos(
|
||||
class_heatmap,
|
||||
1.0,
|
||||
time=time,
|
||||
frequency=frequency,
|
||||
category=class_name,
|
||||
)
|
||||
except KeyError:
|
||||
# Skip the sound event if the position is outside the spectrogram
|
||||
logger.debug(
|
||||
"Skipping annotation %s for class heatmap: "
|
||||
"position outside spectrogram. Pos: %s",
|
||||
sound_event_annotation.uuid,
|
||||
(class_name, time, frequency),
|
||||
)
|
||||
continue
|
||||
|
||||
# Apply gaussian filters
|
||||
detection_heatmap = xr.apply_ufunc(
|
||||
@ -141,10 +331,36 @@ def generate_heatmaps(
|
||||
class_heatmap / class_heatmap.max(dim=["time", "frequency"])
|
||||
).fillna(0.0)
|
||||
|
||||
return detection_heatmap, class_heatmap, size_heatmap
|
||||
return Heatmaps(
|
||||
detection=detection_heatmap,
|
||||
classes=class_heatmap,
|
||||
size=size_heatmap,
|
||||
)
|
||||
|
||||
|
||||
def load_label_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
) -> LabelConfig:
|
||||
"""Load the heatmap label generation configuration from a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (e.g., YAML or JSON).
|
||||
field : str, optional
|
||||
If the label configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LabelConfig
|
||||
The loaded and validated label configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the LabelConfig schema.
|
||||
"""
|
||||
return load_config(path, schema=LabelConfig, field=field)
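As a usage sketch, a labeller with broader Gaussian peaks than the default can be built from any configured target definition object (hypothetical helper, assuming `targets` implements `TargetProtocol`):

def make_wide_labeller(targets: TargetProtocol) -> ClipLabeller:
    """Build a clip labeller with a wider smoothing kernel (sigma=5)."""
    return build_clip_labeler(targets, config=LabelConfig(sigma=5.0))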
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
import torch
|
||||
from soundevent import data
|
||||
@ -6,7 +6,7 @@ from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.models.typing import DetectionModel
|
||||
from batdetect2.models.types import DetectionModel
|
||||
from batdetect2.train.dataset import LabeledDataset
|
||||
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
import sys
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
from typing import Dict, Generator, List, Optional, Tuple
|
||||
|
71
batdetect2/train/lightning.py
Normal file
@ -0,0 +1,71 @@
|
||||
import lightning as L
|
||||
import torch
|
||||
from torch.optim.adam import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from batdetect2.models import (
|
||||
DetectionModel,
|
||||
ModelOutput,
|
||||
)
|
||||
from batdetect2.postprocess.types import PostprocessorProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train import TrainExample
|
||||
from batdetect2.train.types import LossProtocol
|
||||
|
||||
__all__ = [
|
||||
"TrainingModule",
|
||||
]
|
||||
|
||||
|
||||
class TrainingModule(L.LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
detector: DetectionModel,
|
||||
loss: LossProtocol,
|
||||
targets: TargetProtocol,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
learning_rate: float = 0.001,
|
||||
t_max: int = 100,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.loss = loss
|
||||
self.detector = detector
|
||||
self.preprocessor = preprocessor
|
||||
self.targets = targets
|
||||
self.postprocessor = postprocessor
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self.t_max = t_max
|
||||
|
||||
self.save_hyperparameters()
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||
return self.detector(spec)
|
||||
|
||||
def training_step(self, batch: TrainExample):
|
||||
outputs = self.forward(batch.spec)
|
||||
losses = self.loss(outputs, batch)
|
||||
|
||||
self.log("train/loss/total", losses.total, prog_bar=True, logger=True)
|
||||
self.log("train/loss/detection", losses.total, logger=True)
|
||||
self.log("train/loss/size", losses.total, logger=True)
|
||||
self.log("train/loss/classification", losses.total, logger=True)
|
||||
|
||||
return losses.total
|
||||
|
||||
def validation_step(self, batch: TrainExample, batch_idx: int) -> None:
|
||||
outputs = self.forward(batch.spec)
|
||||
losses = self.loss(outputs, batch)
|
||||
|
||||
self.log("val/loss/total", losses.total, prog_bar=True, logger=True)
|
||||
self.log("val/loss/detection", losses.total, logger=True)
|
||||
self.log("val/loss/size", losses.total, logger=True)
|
||||
self.log("val/loss/classification", losses.total, logger=True)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
||||
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
||||
return [optimizer], [scheduler]
|
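The new `TrainingModule` is a standard Lightning module, so once its collaborators exist it can be driven by `lightning.Trainer`. A minimal sketch, assuming `detector`, `targets`, `preprocessor`, `postprocessor` and `train_loader` are built elsewhere (their builders are not part of this diff); `build_loss` is the factory defined in the losses module further down:

```python
import lightning as L

# Assumed to be constructed elsewhere: detector, targets, preprocessor,
# postprocessor (see the respective modules) and a DataLoader of TrainExample
# batches called `train_loader`.
module = TrainingModule(
    detector=detector,
    loss=build_loss(),
    targets=targets,
    preprocessor=preprocessor,
    postprocessor=postprocessor,
    learning_rate=1e-3,
    t_max=100,
)

trainer = L.Trainer(max_epochs=100)
trainer.fit(module, train_dataloaders=train_loader)
```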
@ -1,115 +1,323 @@
|
||||
from typing import NamedTuple, Optional
|
||||
"""Loss functions and configurations for training BatDetect2 models.
|
||||
|
||||
This module defines the loss functions used to train BatDetect2 models,
|
||||
including individual loss components for different prediction tasks (detection,
|
||||
classification, size regression) and a main coordinating loss function that
|
||||
combines them.
|
||||
|
||||
It utilizes common loss types like L1 loss (`BBoxLoss`) for regression and
|
||||
Focal Loss (`FocalLoss`) for handling class imbalance in dense detection and
|
||||
classification tasks. Configuration objects (`LossConfig`, etc.) allow for easy
|
||||
customization of loss parameters and weights via configuration files.
|
||||
|
||||
The primary entry points are:
|
||||
- `LossFunction`: An `nn.Module` that computes the weighted sum of individual
|
||||
loss components given model outputs and ground truth targets.
|
||||
- `build_loss`: A factory function that constructs the `LossFunction` based
|
||||
on a `LossConfig` object.
|
||||
- `LossConfig`: The Pydantic model for configuring loss weights and parameters.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.typing import ModelOutput
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.train.dataset import TrainExample
|
||||
from batdetect2.train.types import Losses, LossProtocol
|
||||
|
||||
__all__ = [
|
||||
"bbox_size_loss",
|
||||
"compute_loss",
|
||||
"focal_loss",
|
||||
"mse_loss",
|
||||
"BBoxLoss",
|
||||
"ClassificationLossConfig",
|
||||
"DetectionLossConfig",
|
||||
"FocalLoss",
|
||||
"FocalLossConfig",
|
||||
"LossConfig",
|
||||
"LossFunction",
|
||||
"MSELoss",
|
||||
"SizeLossConfig",
|
||||
"build_loss",
|
||||
]
|
||||
|
||||
|
||||
class SizeLossConfig(BaseConfig):
|
||||
"""Configuration for the bounding box size loss component.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
weight : float, default=0.1
|
||||
The weighting factor applied to the size loss when combining it with
|
||||
other losses (detection, classification) to form the total training
|
||||
loss.
|
||||
"""
|
||||
|
||||
weight: float = 0.1
|
||||
|
||||
|
||||
def bbox_size_loss(
|
||||
pred_size: torch.Tensor,
|
||||
gt_size: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
class BBoxLoss(nn.Module):
|
||||
"""Computes L1 loss for bounding box size regression.
|
||||
|
||||
Calculates the Mean Absolute Error (MAE or L1 loss) between the predicted
|
||||
size dimensions (`pred`) and the ground truth size dimensions (`gt`).
|
||||
Crucially, the loss is only computed at locations where the ground truth
|
||||
size heatmap (`gt`) contains non-zero values (i.e., at the reference points
|
||||
of actual annotated sound events). This prevents the model from being
|
||||
penalized for size predictions in background regions.
|
||||
|
||||
The loss is summed over all valid locations and normalized by the number
|
||||
of valid locations.
|
||||
"""
|
||||
Bounding box size loss. Only compute loss where there is a bounding box.
|
||||
"""
|
||||
gt_size_mask = (gt_size > 0).float()
|
||||
return F.l1_loss(pred_size * gt_size_mask, gt_size, reduction="sum") / (
|
||||
gt_size_mask.sum() + 1e-5
|
||||
)
|
||||
|
||||
def forward(self, pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
|
||||
"""Calculate masked L1 loss for size prediction.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred : torch.Tensor
|
||||
Predicted size tensor, typically shape `(B, 2, H, W)`, where
|
||||
channels represent scaled width and height.
|
||||
gt : torch.Tensor
|
||||
Ground truth size tensor, same shape as `pred`. Non-zero values
|
||||
indicate locations and target sizes of actual annotations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Scalar tensor representing the calculated masked L1 loss.
|
||||
"""
|
||||
gt_size_mask = (gt > 0).float()
|
||||
masked_pred = pred * gt_size_mask
|
||||
loss = F.l1_loss(masked_pred, gt, reduction="sum")
|
||||
num_pos = gt_size_mask.sum() + 1e-5
|
||||
return loss / num_pos
|
||||
|
||||
|
||||
class FocalLossConfig(BaseConfig):
|
||||
"""Configuration parameters for the Focal Loss function.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
beta : float, default=4
|
||||
Exponent controlling the down-weighting of easy negative examples.
|
||||
Higher values increase down-weighting (focus more on hard negatives).
|
||||
alpha : float, default=2
|
||||
Exponent controlling the down-weighting based on prediction confidence.
|
||||
Higher values focus more on misclassified examples (both positive and
|
||||
negative).
|
||||
"""
|
||||
|
||||
beta: float = 4
|
||||
alpha: float = 2
|
||||
|
||||
|
||||
def focal_loss(
|
||||
pred: torch.Tensor,
|
||||
gt: torch.Tensor,
|
||||
weights: Optional[torch.Tensor] = None,
|
||||
valid_mask: Optional[torch.Tensor] = None,
|
||||
eps: float = 1e-5,
|
||||
beta: float = 4,
|
||||
alpha: float = 2,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
|
||||
pred (batch x c x h x w)
|
||||
gt (batch x c x h x w)
|
||||
class FocalLoss(nn.Module):
|
||||
"""Focal Loss implementation, adapted from CornerNet.
|
||||
|
||||
Addresses class imbalance in dense object detection/classification tasks by
|
||||
down-weighting the loss contribution from easy examples (both positive and
|
||||
negative), allowing the model to focus more on hard-to-classify examples.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
eps : float, default=1e-5
|
||||
Small epsilon value added for numerical stability.
|
||||
beta : float, default=4
|
||||
Exponent focusing on hard negative examples (modulates `(1-gt)^beta`).
|
||||
alpha : float, default=2
|
||||
Exponent focusing on misclassified examples (modulates `(1-p)^alpha`
|
||||
for positives and `p^alpha` for negatives).
|
||||
class_weights : torch.Tensor, optional
|
||||
Optional tensor containing weights for each class (applied to positive
|
||||
loss). Shape should be broadcastable to the channel dimension of the
|
||||
input tensors.
|
||||
mask_zero : bool, default=False
|
||||
If True, ignores loss contributions from spatial locations where the
|
||||
ground truth `gt` tensor is zero across *all* channels. Useful for
|
||||
classification heatmaps where some areas might have no assigned class.
|
||||
|
||||
References
|
||||
----------
|
||||
- Lin, T. Y., et al. "Focal loss for dense object detection." ICCV 2017.
|
||||
- Law, H., & Deng, J. "CornerNet: Detecting Objects as Paired Keypoints."
|
||||
ECCV 2018.
|
||||
"""
|
||||
|
||||
pos_inds = gt.eq(1).float()
|
||||
neg_inds = gt.lt(1).float()
|
||||
def __init__(
|
||||
self,
|
||||
eps: float = 1e-5,
|
||||
beta: float = 4,
|
||||
alpha: float = 2,
|
||||
class_weights: Optional[torch.Tensor] = None,
|
||||
mask_zero: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.class_weights = class_weights
|
||||
self.eps = eps
|
||||
self.beta = beta
|
||||
self.alpha = alpha
|
||||
self.mask_zero = mask_zero
|
||||
|
||||
pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
|
||||
neg_loss = (
|
||||
torch.log(1 - pred + eps)
|
||||
* torch.pow(pred, alpha)
|
||||
* torch.pow(1 - gt, beta)
|
||||
* neg_inds
|
||||
)
|
||||
def forward(
|
||||
self,
|
||||
pred: torch.Tensor,
|
||||
gt: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute the Focal Loss.
|
||||
|
||||
if weights is not None:
|
||||
pos_loss = pos_loss * torch.tensor(weights)
|
||||
# neg_loss = neg_loss*weights
|
||||
Parameters
|
||||
----------
|
||||
pred : torch.Tensor
|
||||
Predicted probabilities or logits (typically sigmoid output for
|
||||
detection, or softmax/sigmoid for classification). Must be in the
|
||||
range [0, 1] after potential activation. Shape `(B, C, H, W)`.
|
||||
gt : torch.Tensor
|
||||
Ground truth heatmap tensor. Shape `(B, C, H, W)`. Values typically
|
||||
represent target probabilities (e.g., Gaussian peaks for detection,
|
||||
one-hot encoding or smoothed labels for classification). For the
|
||||
adapted CornerNet loss, `gt=1` indicates a positive location, and
|
||||
values `< 1` indicate negative locations (with potential Gaussian
|
||||
weighting `(1-gt)^beta` for negatives near positives).
|
||||
|
||||
if valid_mask is not None:
|
||||
pos_loss = pos_loss * valid_mask
|
||||
neg_loss = neg_loss * valid_mask
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Scalar tensor representing the computed focal loss, normalized by
|
||||
the number of positive locations.
|
||||
"""
|
||||
|
||||
pos_loss = pos_loss.sum()
|
||||
neg_loss = neg_loss.sum()
|
||||
pos_inds = gt.eq(1).float()
|
||||
neg_inds = gt.lt(1).float()
|
||||
|
||||
num_pos = pos_inds.float().sum()
|
||||
if num_pos == 0:
|
||||
loss = -neg_loss
|
||||
else:
|
||||
loss = -(pos_loss + neg_loss) / num_pos
|
||||
return loss
|
||||
pos_loss = (
|
||||
torch.log(pred + self.eps)
|
||||
* torch.pow(1 - pred, self.alpha)
|
||||
* pos_inds
|
||||
)
|
||||
neg_loss = (
|
||||
torch.log(1 - pred + self.eps)
|
||||
* torch.pow(pred, self.alpha)
|
||||
* torch.pow(1 - gt, self.beta)
|
||||
* neg_inds
|
||||
)
|
||||
|
||||
if self.class_weights is not None:
|
||||
pos_loss = pos_loss * torch.tensor(self.class_weights)
|
||||
|
||||
if self.mask_zero:
|
||||
valid_mask = gt.any(dim=1, keepdim=True).float()
|
||||
pos_loss = pos_loss * valid_mask
|
||||
neg_loss = neg_loss * valid_mask
|
||||
|
||||
pos_loss = pos_loss.sum()
|
||||
neg_loss = neg_loss.sum()
|
||||
|
||||
num_pos = pos_inds.float().sum()
|
||||
if num_pos == 0:
|
||||
loss = -neg_loss
|
||||
else:
|
||||
loss = -(pos_loss + neg_loss) / num_pos
|
||||
return loss
|
||||
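As a quick sanity check, the new `FocalLoss` module can be exercised on toy tensors (a sketch for illustration, not part of the diff):

```python
import torch

loss_fn = FocalLoss(beta=4, alpha=2)

gt = torch.zeros(1, 1, 4, 4)
gt[0, 0, 2, 2] = 1.0                   # single positive location
pred = torch.full((1, 1, 4, 4), 0.05)  # confident background elsewhere
pred[0, 0, 2, 2] = 0.90                # fairly confident at the positive

print(loss_fn(pred, gt))  # small scalar; it grows as predictions degrade
```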
|
||||
|
||||
def mse_loss(
|
||||
pred: torch.Tensor,
|
||||
gt: torch.Tensor,
|
||||
valid_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
class MSELoss(nn.Module):
|
||||
"""Mean Squared Error (MSE) Loss module.
|
||||
|
||||
Calculates the mean squared difference between predictions and ground
|
||||
truth. Optionally masks contributions where the ground truth is zero across
|
||||
channels.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mask_zero : bool, default=False
|
||||
If True, calculates the loss only over spatial locations (H, W) where
|
||||
at least one channel in the ground truth `gt` tensor is non-zero. The
|
||||
loss is then averaged over these valid locations. If False (default),
|
||||
the standard MSE over all elements is computed.
|
||||
"""
|
||||
Mean squared error loss.
|
||||
"""
|
||||
if valid_mask is None:
|
||||
op = ((gt - pred) ** 2).mean()
|
||||
else:
|
||||
op = (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
|
||||
return op
|
||||
|
||||
def __init__(self, mask_zero: bool = False):
|
||||
super().__init__()
|
||||
self.mask_zero = mask_zero
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred: torch.Tensor,
|
||||
gt: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute the Mean Squared Error loss.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred : torch.Tensor
|
||||
Predicted tensor, shape `(B, C, H, W)`.
|
||||
gt : torch.Tensor
|
||||
Ground truth tensor, shape `(B, C, H, W)`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Scalar tensor representing the calculated MSE loss.
|
||||
"""
|
||||
if not self.mask_zero:
|
||||
return ((gt - pred) ** 2).mean()
|
||||
|
||||
valid_mask = gt.any(dim=1, keepdim=True).float()
|
||||
return (valid_mask * ((gt - pred) ** 2)).sum() / valid_mask.sum()
|
||||
|
||||
|
||||
class DetectionLossConfig(BaseConfig):
|
||||
"""Configuration for the detection loss component.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
weight : float, default=1.0
|
||||
Weighting factor for the detection loss in the combined total loss.
|
||||
focal : FocalLossConfig
|
||||
Configuration for the Focal Loss used for detection. Defaults to
|
||||
standard Focal Loss parameters (`alpha=2`, `beta=4`).
|
||||
"""
|
||||
|
||||
weight: float = 1.0
|
||||
focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
|
||||
|
||||
|
||||
class ClassificationLossConfig(BaseConfig):
|
||||
"""Configuration for the classification loss component.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
weight : float, default=2.0
|
||||
Weighting factor for the classification loss in the combined total loss.
|
||||
focal : FocalLossConfig
|
||||
Configuration for the Focal Loss used for classification. Defaults to
|
||||
standard Focal Loss parameters (`alpha=2`, `beta=4`).
|
||||
"""
|
||||
|
||||
weight: float = 2.0
|
||||
focal: FocalLossConfig = Field(default_factory=FocalLossConfig)
|
||||
class_weights: Optional[list[float]] = None
|
||||
|
||||
|
||||
class LossConfig(BaseConfig):
|
||||
"""Aggregated configuration for all loss components.
|
||||
|
||||
Defines the configuration and weighting for detection, size regression,
|
||||
and classification losses used in the main `LossFunction`.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
detection : DetectionLossConfig
|
||||
Configuration for the detection loss (Focal Loss).
|
||||
size : SizeLossConfig
|
||||
Configuration for the size regression loss (L1 loss).
|
||||
classification : ClassificationLossConfig
|
||||
Configuration for the classification loss (Focal Loss).
|
||||
"""
|
||||
|
||||
detection: DetectionLossConfig = Field(default_factory=DetectionLossConfig)
|
||||
size: SizeLossConfig = Field(default_factory=SizeLossConfig)
|
||||
classification: ClassificationLossConfig = Field(
|
||||
@ -117,50 +325,157 @@ class LossConfig(BaseConfig):
|
||||
)
|
||||
|
||||
|
||||
class Losses(NamedTuple):
|
||||
detection: torch.Tensor
|
||||
size: torch.Tensor
|
||||
classification: torch.Tensor
|
||||
total: torch.Tensor
|
||||
class LossFunction(nn.Module, LossProtocol):
|
||||
"""Computes the combined training loss for the BatDetect2 model.
|
||||
|
||||
Aggregates individual loss functions for detection, size regression, and
|
||||
classification tasks. Calculates each component loss based on model outputs
|
||||
and ground truth targets, applies configured weights, and sums them to get
|
||||
the final total loss used for optimization. Also returns individual
|
||||
components for monitoring.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
size_loss : nn.Module
|
||||
Instantiated loss module for size regression (e.g., `BBoxLoss`).
|
||||
detection_loss : nn.Module
|
||||
Instantiated loss module for detection (e.g., `FocalLoss`).
|
||||
classification_loss : nn.Module
|
||||
Instantiated loss module for classification (e.g., `FocalLoss`).
|
||||
size_weight : float, default=0.1
|
||||
Weighting factor for the size loss component.
|
||||
detection_weight : float, default=1.0
|
||||
Weighting factor for the detection loss component.
|
||||
classification_weight : float, default=2.0
|
||||
Weighting factor for the classification loss component.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
size_loss_fn : nn.Module
|
||||
detection_loss_fn : nn.Module
|
||||
classification_loss_fn : nn.Module
|
||||
size_weight : float
|
||||
detection_weight : float
|
||||
classification_weight : float
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size_loss: nn.Module,
|
||||
detection_loss: nn.Module,
|
||||
classification_loss: nn.Module,
|
||||
size_weight: float = 0.1,
|
||||
detection_weight: float = 1.0,
|
||||
classification_weight: float = 2.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.size_loss_fn = size_loss
|
||||
self.detection_loss_fn = detection_loss
|
||||
self.classification_loss_fn = classification_loss
|
||||
|
||||
self.size_weight = size_weight
|
||||
self.detection_weight = detection_weight
|
||||
self.classification_weight = classification_weight
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred: ModelOutput,
|
||||
gt: TrainExample,
|
||||
) -> Losses:
|
||||
"""Calculate the combined loss and individual components.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred: ModelOutput
|
||||
A NamedTuple containing the model's prediction tensors for the
|
||||
batch: `detection_probs`, `size_preds`, `class_probs`.
|
||||
gt: TrainExample
|
||||
A structure containing the ground truth targets for the batch,
|
||||
expected to have attributes like `detection_heatmap`,
|
||||
`size_heatmap`, and `class_heatmap` (as `torch.Tensor`).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Losses
|
||||
A NamedTuple containing the scalar loss values for detection, size,
|
||||
classification, and the total weighted loss.
|
||||
"""
|
||||
size_loss = self.size_loss_fn(pred.size_preds, gt.size_heatmap)
|
||||
detection_loss = self.detection_loss_fn(
|
||||
pred.detection_probs,
|
||||
gt.detection_heatmap,
|
||||
)
|
||||
classification_loss = self.classification_loss_fn(
|
||||
pred.class_probs,
|
||||
gt.class_heatmap,
|
||||
)
|
||||
total_loss = (
|
||||
size_loss * self.size_weight
|
||||
+ classification_loss * self.classification_weight
|
||||
+ detection_loss * self.detection_weight
|
||||
)
|
||||
return Losses(
|
||||
detection=detection_loss,
|
||||
size=size_loss,
|
||||
classification=classification_loss,
|
||||
total=total_loss,
|
||||
)
|
||||
|
||||
|
||||
def compute_loss(
|
||||
batch: TrainExample,
|
||||
outputs: ModelOutput,
|
||||
conf: LossConfig,
|
||||
class_weights: Optional[torch.Tensor] = None,
|
||||
) -> Losses:
|
||||
detection_loss = focal_loss(
|
||||
outputs.detection_probs,
|
||||
batch.detection_heatmap,
|
||||
beta=conf.detection.focal.beta,
|
||||
alpha=conf.detection.focal.alpha,
|
||||
def build_loss(
|
||||
config: Optional[LossConfig] = None,
|
||||
class_weights: Optional[np.ndarray] = None,
|
||||
) -> nn.Module:
|
||||
"""Factory function to build the main LossFunction from configuration.
|
||||
|
||||
Instantiates the necessary loss components (`BBoxLoss`, `FocalLoss`) based
|
||||
on the provided `LossConfig` (or defaults) and optional `class_weights`,
|
||||
then assembles them into the main `LossFunction` module with the specified
|
||||
component weights.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : LossConfig, optional
|
||||
Configuration object defining weights and parameters (e.g., alpha, beta
|
||||
for Focal Loss) for each loss component. If None, default settings
|
||||
from `LossConfig` and its nested configs are used.
|
||||
class_weights : np.ndarray, optional
|
||||
An array of weights for each specific class, used to adjust the
|
||||
classification loss (typically Focal Loss). If provided, this overrides
|
||||
any `class_weights` specified within `config.classification`. If None,
|
||||
weights from the config (or default of equal weights) are used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LossFunction
|
||||
An initialized `LossFunction` module ready for training.
|
||||
"""
|
||||
config = config or LossConfig()
|
||||
|
||||
class_weights_tensor = (
|
||||
torch.tensor(class_weights) if class_weights is not None else None
|
||||
)
|
||||
|
||||
size_loss = bbox_size_loss(
|
||||
outputs.size_preds,
|
||||
batch.size_heatmap,
|
||||
detection_loss_fn = FocalLoss(
|
||||
beta=config.detection.focal.beta,
|
||||
alpha=config.detection.focal.alpha,
|
||||
mask_zero=False,
|
||||
)
|
||||
|
||||
valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
|
||||
classification_loss = focal_loss(
|
||||
outputs.class_probs,
|
||||
batch.class_heatmap,
|
||||
weights=class_weights,
|
||||
valid_mask=valid_mask,
|
||||
beta=conf.classification.focal.beta,
|
||||
alpha=conf.classification.focal.alpha,
|
||||
classification_loss_fn = FocalLoss(
|
||||
beta=config.classification.focal.beta,
|
||||
alpha=config.classification.focal.alpha,
|
||||
class_weights=class_weights_tensor,
|
||||
mask_zero=True,
|
||||
)
|
||||
|
||||
total = (
|
||||
detection_loss * conf.detection.weight
|
||||
+ size_loss * conf.size.weight
|
||||
+ classification_loss * conf.classification.weight
|
||||
)
|
||||
size_loss_fn = BBoxLoss()
|
||||
|
||||
return Losses(
|
||||
detection=detection_loss,
|
||||
size=size_loss,
|
||||
classification=classification_loss,
|
||||
total=total,
|
||||
return LossFunction(
|
||||
size_loss=size_loss_fn,
|
||||
classification_loss=classification_loss_fn,
|
||||
detection_loss=detection_loss_fn,
|
||||
size_weight=config.size.weight,
|
||||
detection_weight=config.detection.weight,
|
||||
classification_weight=config.classification.weight,
|
||||
)
|
||||
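For example, `build_loss` can be driven entirely from configuration; a sketch that down-weights the size loss and supplies arbitrary per-class weights (the values here are made up for illustration):

```python
import numpy as np

config = LossConfig(
    detection=DetectionLossConfig(weight=1.0),
    size=SizeLossConfig(weight=0.05),
    classification=ClassificationLossConfig(weight=2.0),
)

# Optional per-class weights; length should match the number of target classes.
loss_fn = build_loss(config, class_weights=np.array([1.0, 2.0, 0.5]))

# During training: losses = loss_fn(model_output, train_example), where
# losses.total is backpropagated and the components are logged separately.
```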
|
@ -1,97 +1,102 @@
|
||||
"""Module for preprocessing data for training."""
|
||||
"""Preprocesses datasets for BatDetect2 model training.
|
||||
|
||||
This module provides functions to take a collection of annotated audio clips
|
||||
(`soundevent.data.ClipAnnotation`) and process them into the final format
|
||||
required for training a BatDetect2 model. This typically involves:
|
||||
|
||||
1. Loading the relevant audio segment for each annotation using a configured
|
||||
`PreprocessorProtocol`.
|
||||
2. Generating the corresponding input spectrogram using the
|
||||
`PreprocessorProtocol`.
|
||||
3. Generating the target heatmaps (detection, classification, size) using a
|
||||
configured `ClipLabeller` (which encapsulates the `TargetProtocol` logic).
|
||||
4. Packaging the input spectrogram, target heatmaps, and potentially the
|
||||
processed audio waveform into an `xarray.Dataset`.
|
||||
5. Saving each processed `xarray.Dataset` to a separate file (typically NetCDF)
|
||||
in an output directory.
|
||||
|
||||
This offline preprocessing is often preferred for large datasets as it avoids
|
||||
computationally intensive steps during the actual training loop. The module
|
||||
includes utilities for parallel processing using `multiprocessing`.
|
||||
"""
|
||||
|
||||
import os
|
||||
from functools import partial
|
||||
from multiprocessing import Pool
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Sequence, Union
|
||||
from typing import Callable, Optional, Sequence
|
||||
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess import (
|
||||
PreprocessingConfig,
|
||||
compute_spectrogram,
|
||||
load_clip_audio,
|
||||
)
|
||||
from batdetect2.train.labels import LabelConfig, generate_heatmaps
|
||||
from batdetect2.train.targets import (
|
||||
TargetConfig,
|
||||
build_target_encoder,
|
||||
build_sound_event_filter,
|
||||
get_class_names,
|
||||
)
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
FilenameFn = Callable[[data.ClipAnnotation], str]
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.train.types import ClipLabeller
|
||||
|
||||
__all__ = [
|
||||
"preprocess_annotations",
|
||||
"preprocess_single_annotation",
|
||||
"generate_train_example",
|
||||
"TrainPreprocessingConfig",
|
||||
]
|
||||
|
||||
|
||||
class TrainPreprocessingConfig(BaseConfig):
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
target: TargetConfig = Field(default_factory=TargetConfig)
|
||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||
FilenameFn = Callable[[data.ClipAnnotation], str]
|
||||
"""Type alias for a function that generates an output filename."""
|
||||
|
||||
|
||||
def generate_train_example(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
target_config: Optional[TargetConfig] = None,
|
||||
label_config: Optional[LabelConfig] = None,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
labeller: ClipLabeller,
|
||||
) -> xr.Dataset:
|
||||
"""Generate a training example."""
|
||||
config = TrainPreprocessingConfig(
|
||||
preprocessing=preprocessing_config or PreprocessingConfig(),
|
||||
target=target_config or TargetConfig(),
|
||||
labels=label_config or LabelConfig(),
|
||||
)
|
||||
"""Generate a complete training example for one annotation.
|
||||
|
||||
wave = load_clip_audio(
|
||||
clip_annotation.clip,
|
||||
config=config.preprocessing.audio,
|
||||
)
|
||||
This function takes a single `ClipAnnotation`, applies the configured
|
||||
preprocessing (`PreprocessorProtocol`) to get the processed waveform and
|
||||
input spectrogram, applies the configured target generation
|
||||
(`ClipLabeller`) to get the target heatmaps, and packages them all into a
|
||||
single `xr.Dataset`.
|
||||
|
||||
spectrogram = compute_spectrogram(
|
||||
wave,
|
||||
config=config.preprocessing.spectrogram,
|
||||
)
|
||||
Parameters
|
||||
----------
|
||||
clip_annotation : data.ClipAnnotation
|
||||
The annotated clip to process. Contains the reference to the `Clip`
|
||||
(audio segment) and the associated `SoundEventAnnotation` objects.
|
||||
preprocessor : PreprocessorProtocol
|
||||
An initialized preprocessor object responsible for loading/processing
|
||||
audio and computing the input spectrogram.
|
||||
labeller : ClipLabeller
|
||||
An initialized clip labeller function responsible for generating the
|
||||
target heatmaps (detection, class, size) from the `clip_annotation`
|
||||
and the computed spectrogram.
|
||||
|
||||
filter_fn = build_sound_event_filter(
|
||||
include=config.target.include,
|
||||
exclude=config.target.exclude,
|
||||
)
|
||||
Returns
|
||||
-------
|
||||
xr.Dataset
|
||||
An xarray Dataset containing the following data variables:
|
||||
- `audio`: The preprocessed audio waveform (dims: 'audio_time').
|
||||
- `spectrogram`: The computed input spectrogram
|
||||
(dims: 'time', 'frequency').
|
||||
- `detection`: The target detection heatmap
|
||||
(dims: 'time', 'frequency').
|
||||
- `class`: The target class heatmap
|
||||
(dims: 'category', 'time', 'frequency').
|
||||
- `size`: The target size heatmap
|
||||
(dims: 'dimension', 'time', 'frequency').
|
||||
The Dataset also includes metadata in its attributes.
|
||||
|
||||
selected_events = [
|
||||
event for event in clip_annotation.sound_events if filter_fn(event)
|
||||
]
|
||||
Notes
|
||||
-----
|
||||
- The 'time' dimension of the 'audio' DataArray is renamed to 'audio_time'
|
||||
within the output Dataset to avoid coordinate conflicts with the
|
||||
spectrogram's 'time' dimension when stored together.
|
||||
- The original `ClipAnnotation` metadata is stored as a JSON string in the
|
||||
Dataset's attributes for provenance.
|
||||
"""
|
||||
wave = preprocessor.load_clip_audio(clip_annotation.clip)
|
||||
|
||||
encoder = build_target_encoder(
|
||||
config.target.classes,
|
||||
replacement_rules=config.target.replace,
|
||||
)
|
||||
class_names = get_class_names(config.target.classes)
|
||||
spectrogram = preprocessor.compute_spectrogram(wave)
|
||||
|
||||
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||
selected_events,
|
||||
spectrogram,
|
||||
class_names,
|
||||
encoder,
|
||||
target_sigma=config.labels.heatmaps.sigma,
|
||||
position=config.labels.heatmaps.position,
|
||||
time_scale=config.labels.heatmaps.time_scale,
|
||||
frequency_scale=config.labels.heatmaps.frequency_scale,
|
||||
)
|
||||
heatmaps = labeller(clip_annotation, spectrogram)
|
||||
|
||||
dataset = xr.Dataset(
|
||||
{
|
||||
@ -101,15 +106,14 @@ def generate_train_example(
|
||||
# as the waveform.
|
||||
"audio": wave.rename({"time": "audio_time"}),
|
||||
"spectrogram": spectrogram,
|
||||
"detection": detection_heatmap,
|
||||
"class": class_heatmap,
|
||||
"size": size_heatmap,
|
||||
"detection": heatmaps.detection,
|
||||
"class": heatmaps.classes,
|
||||
"size": heatmaps.size,
|
||||
}
|
||||
)
|
||||
|
||||
return dataset.assign_attrs(
|
||||
title=f"Training example for {clip_annotation.uuid}",
|
||||
config=config.model_dump_json(),
|
||||
clip_annotation=clip_annotation.model_dump_json(
|
||||
exclude_none=True,
|
||||
exclude_defaults=True,
|
||||
@ -118,13 +122,25 @@ def generate_train_example(
|
||||
)
|
||||
|
||||
|
||||
def save_to_file(
|
||||
def _save_xr_dataset_to_file(
|
||||
dataset: xr.Dataset,
|
||||
path: PathLike,
|
||||
path: data.PathLike,
|
||||
) -> None:
|
||||
"""Save an xarray Dataset to a NetCDF file with compression.
|
||||
|
||||
Internal helper function used by `preprocess_single_annotation`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset : xr.Dataset
|
||||
The training example dataset to save.
|
||||
path : PathLike
|
||||
The output file path (e.g., 'output/uuid.nc').
|
||||
"""
|
||||
dataset.to_netcdf(
|
||||
path,
|
||||
encoding={
|
||||
"audio": {"zlib": True},
|
||||
"spectrogram": {"zlib": True},
|
||||
"size": {"zlib": True},
|
||||
"class": {"zlib": True},
|
||||
@ -134,20 +150,60 @@ def save_to_file(
|
||||
|
||||
|
||||
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
|
||||
"""Generate a default output filename based on the annotation UUID."""
|
||||
return f"{clip_annotation.uuid}.nc"
|
||||
|
||||
|
||||
def preprocess_annotations(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
output_dir: PathLike,
|
||||
output_dir: data.PathLike,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
labeller: ClipLabeller,
|
||||
filename_fn: FilenameFn = _get_filename,
|
||||
replace: bool = False,
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
target_config: Optional[TargetConfig] = None,
|
||||
label_config: Optional[LabelConfig] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Preprocess annotations and save to disk."""
|
||||
"""Preprocess a sequence of ClipAnnotations and save results to disk.
|
||||
|
||||
Generates the full training example (spectrogram, heatmaps, etc.) for each
|
||||
`ClipAnnotation` in the input sequence using the provided `preprocessor`
|
||||
and `labeller`. Saves each example as a separate NetCDF file in the
|
||||
`output_dir`. Utilizes multiprocessing for potentially faster processing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clip_annotations : Sequence[data.ClipAnnotation]
|
||||
A sequence (e.g., list) of the clip annotations to preprocess.
|
||||
output_dir : PathLike
|
||||
Path to the directory where the processed NetCDF files will be saved.
|
||||
Will be created if it doesn't exist.
|
||||
preprocessor : PreprocessorProtocol
|
||||
Initialized preprocessor object to generate spectrograms.
|
||||
labeller : ClipLabeller
|
||||
Initialized labeller function to generate target heatmaps.
|
||||
filename_fn : FilenameFn, optional
|
||||
Function to generate the output filename (without extension) for each
|
||||
`ClipAnnotation`. Defaults to using the annotation UUID via
|
||||
`_get_filename`.
|
||||
replace : bool, default=False
|
||||
If True, existing files in `output_dir` with the same generated name
|
||||
will be overwritten. If False (default), existing files are skipped.
|
||||
max_workers : int, optional
|
||||
Maximum number of worker processes to use for parallel processing.
|
||||
If None (default), uses the number of CPUs available (`os.cpu_count()`).
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
This function does not return anything; its side effect is creating
|
||||
files in the `output_dir`.
|
||||
|
||||
Raises
|
||||
------
|
||||
RuntimeError
|
||||
If processing fails for any individual annotation when using
|
||||
multiprocessing. The original exception will be attached as the cause.
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
if not output_dir.is_dir():
|
||||
@ -162,9 +218,8 @@ def preprocess_annotations(
|
||||
output_dir=output_dir,
|
||||
filename_fn=filename_fn,
|
||||
replace=replace,
|
||||
preprocessing_config=preprocessing_config,
|
||||
target_config=target_config,
|
||||
label_config=label_config,
|
||||
preprocessor=preprocessor,
|
||||
labeller=labeller,
|
||||
),
|
||||
clip_annotations,
|
||||
),
|
||||
@ -175,13 +230,34 @@ def preprocess_annotations(
|
||||
|
||||
def preprocess_single_annotation(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
output_dir: PathLike,
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
target_config: Optional[TargetConfig] = None,
|
||||
label_config: Optional[LabelConfig] = None,
|
||||
output_dir: data.PathLike,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
labeller: ClipLabeller,
|
||||
filename_fn: FilenameFn = _get_filename,
|
||||
replace: bool = False,
|
||||
) -> None:
|
||||
"""Process a single ClipAnnotation and save the result to a file.
|
||||
|
||||
Internal function designed to be called by `preprocess_annotations`, often
|
||||
in parallel worker processes. It generates the training example using
|
||||
`generate_train_example` and saves it using `save_to_file`. Handles
|
||||
file existence checks based on the `replace` flag.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clip_annotation : data.ClipAnnotation
|
||||
The single annotation to process.
|
||||
output_dir : Path
|
||||
The directory where the output NetCDF file should be saved.
|
||||
preprocessor : PreprocessorProtocol
|
||||
Initialized preprocessor object.
|
||||
labeller : ClipLabeller
|
||||
Initialized labeller function.
|
||||
filename_fn : FilenameFn, default=_get_filename
|
||||
Function to determine the output filename.
|
||||
replace : bool, default=False
|
||||
Whether to overwrite existing output files.
|
||||
"""
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
filename = filename_fn(clip_annotation)
|
||||
@ -196,13 +272,12 @@ def preprocess_single_annotation(
|
||||
try:
|
||||
sample = generate_train_example(
|
||||
clip_annotation,
|
||||
preprocessing_config=preprocessing_config,
|
||||
target_config=target_config,
|
||||
label_config=label_config,
|
||||
preprocessor=preprocessor,
|
||||
labeller=labeller,
|
||||
)
|
||||
except Exception as error:
|
||||
raise RuntimeError(
|
||||
f"Failed to process annotation: {clip_annotation.uuid}"
|
||||
) from error
|
||||
|
||||
save_to_file(sample, path)
|
||||
_save_xr_dataset_to_file(sample, path)
|
||||
|
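Putting the preprocessing pieces together might look roughly like this (a sketch; loading the annotations and building the `preprocessor`/`labeller` objects is assumed to happen elsewhere):

```python
from pathlib import Path

# `clip_annotations`: a sequence of soundevent ClipAnnotation objects.
# `preprocessor` / `labeller`: configured PreprocessorProtocol / ClipLabeller
# instances built elsewhere in the package.
preprocess_annotations(
    clip_annotations,
    output_dir=Path("example_data/preprocessed"),
    preprocessor=preprocessor,
    labeller=labeller,
    replace=False,   # skip examples that already exist on disk
    max_workers=4,   # parallel worker processes
)
```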
@ -1,181 +0,0 @@
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional, Set
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.terms import TagInfo, get_tag_from_info
|
||||
|
||||
__all__ = [
|
||||
"TargetConfig",
|
||||
"load_target_config",
|
||||
"build_target_encoder",
|
||||
"build_decoder",
|
||||
"filter_sound_event",
|
||||
]
|
||||
|
||||
|
||||
class ReplaceConfig(BaseConfig):
|
||||
"""Configuration for replacing tags."""
|
||||
|
||||
original: TagInfo
|
||||
replacement: TagInfo
|
||||
|
||||
|
||||
class TargetConfig(BaseConfig):
|
||||
"""Configuration for target generation."""
|
||||
|
||||
classes: List[TagInfo] = Field(
|
||||
default_factory=lambda: [
|
||||
TagInfo(key="class", value=value) for value in DEFAULT_SPECIES_LIST
|
||||
]
|
||||
)
|
||||
generic_class: Optional[TagInfo] = Field(
|
||||
default_factory=lambda: TagInfo(key="class", value="Bat")
|
||||
)
|
||||
|
||||
include: Optional[List[TagInfo]] = Field(
|
||||
default_factory=lambda: [TagInfo(key="event", value="Echolocation")]
|
||||
)
|
||||
|
||||
exclude: Optional[List[TagInfo]] = Field(
|
||||
default_factory=lambda: [
|
||||
TagInfo(key="class", value=""),
|
||||
TagInfo(key="class", value=" "),
|
||||
TagInfo(key="class", value="Unknown"),
|
||||
]
|
||||
)
|
||||
|
||||
replace: Optional[List[ReplaceConfig]] = None
|
||||
|
||||
|
||||
def build_sound_event_filter(
|
||||
include: Optional[List[TagInfo]] = None,
|
||||
exclude: Optional[List[TagInfo]] = None,
|
||||
) -> Callable[[data.SoundEventAnnotation], bool]:
|
||||
include_tags = (
|
||||
{get_tag_from_info(tag) for tag in include} if include else None
|
||||
)
|
||||
exclude_tags = (
|
||||
{get_tag_from_info(tag) for tag in exclude} if exclude else None
|
||||
)
|
||||
return partial(
|
||||
filter_sound_event,
|
||||
include=include_tags,
|
||||
exclude=exclude_tags,
|
||||
)
|
||||
|
||||
|
||||
def get_tag_label(tag_info: TagInfo) -> str:
|
||||
return tag_info.label if tag_info.label else tag_info.value
|
||||
|
||||
|
||||
def get_class_names(classes: List[TagInfo]) -> List[str]:
|
||||
return sorted({get_tag_label(tag) for tag in classes})
|
||||
|
||||
|
||||
def build_replacer(
|
||||
rules: List[ReplaceConfig],
|
||||
) -> Callable[[data.Tag], data.Tag]:
|
||||
mapping = {
|
||||
get_tag_from_info(rule.original): get_tag_from_info(rule.replacement)
|
||||
for rule in rules
|
||||
}
|
||||
|
||||
def replacer(tag: data.Tag) -> data.Tag:
|
||||
return mapping.get(tag, tag)
|
||||
|
||||
return replacer
|
||||
|
||||
|
||||
def build_target_encoder(
|
||||
classes: List[TagInfo],
|
||||
replacement_rules: Optional[List[ReplaceConfig]] = None,
|
||||
) -> Callable[[Iterable[data.Tag]], Optional[str]]:
|
||||
target_tags = set([get_tag_from_info(tag) for tag in classes])
|
||||
|
||||
tag_mapping = {
|
||||
tag: get_tag_label(tag_info)
|
||||
for tag, tag_info in zip(target_tags, classes)
|
||||
}
|
||||
|
||||
replacer = (
|
||||
build_replacer(replacement_rules) if replacement_rules else lambda x: x
|
||||
)
|
||||
|
||||
def encoder(
|
||||
tags: Iterable[data.Tag],
|
||||
) -> Optional[str]:
|
||||
sanitized_tags = {replacer(tag) for tag in tags}
|
||||
|
||||
intersection = sanitized_tags & target_tags
|
||||
|
||||
if not intersection:
|
||||
return None
|
||||
|
||||
first = intersection.pop()
|
||||
return tag_mapping[first]
|
||||
|
||||
return encoder
|
||||
|
||||
|
||||
def build_decoder(
|
||||
classes: List[TagInfo],
|
||||
) -> Callable[[str], List[data.Tag]]:
|
||||
target_tags = set([get_tag_from_info(tag) for tag in classes])
|
||||
tag_mapping = {
|
||||
get_tag_label(tag_info): tag
|
||||
for tag, tag_info in zip(target_tags, classes)
|
||||
}
|
||||
|
||||
def decoder(label: str) -> List[data.Tag]:
|
||||
tag = tag_mapping.get(label)
|
||||
return [tag] if tag else []
|
||||
|
||||
return decoder
|
||||
|
||||
|
||||
def filter_sound_event(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
include: Optional[Set[data.Tag]] = None,
|
||||
exclude: Optional[Set[data.Tag]] = None,
|
||||
) -> bool:
|
||||
tags = set(sound_event_annotation.tags)
|
||||
|
||||
if include is not None and not tags & include:
|
||||
return False
|
||||
|
||||
if exclude is not None and tags & exclude:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def load_target_config(
|
||||
path: Path, field: Optional[str] = None
|
||||
) -> TargetConfig:
|
||||
return load_config(path, schema=TargetConfig, field=field)
|
||||
|
||||
|
||||
DEFAULT_SPECIES_LIST = [
|
||||
"Barbastellus barbastellus",
|
||||
"Eptesicus serotinus",
|
||||
"Myotis alcathoe",
|
||||
"Myotis bechsteinii",
|
||||
"Myotis brandtii",
|
||||
"Myotis daubentonii",
|
||||
"Myotis mystacinus",
|
||||
"Myotis nattereri",
|
||||
"Nyctalus leisleri",
|
||||
"Nyctalus noctula",
|
||||
"Pipistrellus nathusii",
|
||||
"Pipistrellus pipistrellus",
|
||||
"Pipistrellus pygmaeus",
|
||||
"Plecotus auritus",
|
||||
"Plecotus austriacus",
|
||||
"Rhinolophus ferrumequinum",
|
||||
"Rhinolophus hipposideros",
|
||||
]
|
batdetect2/train/types.py (new file, 100 lines)
@ -0,0 +1,100 @@
|
||||
from typing import Callable, NamedTuple, Protocol, Tuple
|
||||
|
||||
import torch
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.models import ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"Heatmaps",
|
||||
"ClipLabeller",
|
||||
"Augmentation",
|
||||
"LossProtocol",
|
||||
"TrainExample",
|
||||
]
|
||||
|
||||
|
||||
class Heatmaps(NamedTuple):
|
||||
"""Structure holding the generated heatmap targets.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
detection : xr.DataArray
|
||||
Heatmap indicating the probability of sound event presence. Typically
|
||||
smoothed with a Gaussian kernel centered on event reference points.
|
||||
Shape matches the input spectrogram. Values normalized [0, 1].
|
||||
classes : xr.DataArray
|
||||
Heatmap indicating the probability of specific class presence. Has an
|
||||
additional 'category' dimension corresponding to the target class
|
||||
names. Each category slice is typically smoothed with a Gaussian
|
||||
kernel. Values normalized [0, 1] per category.
|
||||
size : xr.DataArray
|
||||
Heatmap encoding the size (width, height) of detected events. Has an
|
||||
additional 'dimension' coordinate ('width', 'height'). Values represent
|
||||
scaled dimensions placed at the event reference points.
|
||||
"""
|
||||
|
||||
detection: xr.DataArray
|
||||
classes: xr.DataArray
|
||||
size: xr.DataArray
|
||||
|
||||
|
||||
ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
|
||||
"""Type alias for the final clip labelling function.
|
||||
|
||||
This function takes the complete annotations for a clip and the corresponding
|
||||
spectrogram, applies all configured filtering, transformation, and encoding
|
||||
steps, and returns the final `Heatmaps` used for model training.
|
||||
"""
|
||||
|
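Any callable with this signature can act as a labeller. Purely for illustration, a stub that returns all-zero heatmaps shaped like the input spectrogram (the class and dimension names below are placeholders):

```python
import xarray as xr
from soundevent import data


def zero_labeller(
    clip_annotation: data.ClipAnnotation,
    spec: xr.DataArray,
) -> Heatmaps:
    # All-zero targets on the same time/frequency grid as the spectrogram.
    zeros = xr.zeros_like(spec)
    classes = xr.concat([zeros, zeros], dim="category").assign_coords(
        category=["Myotis daubentonii", "Pipistrellus pipistrellus"]
    )
    size = xr.concat([zeros, zeros], dim="dimension").assign_coords(
        dimension=["width", "height"]
    )
    return Heatmaps(detection=zeros, classes=classes, size=size)
```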
||||
Augmentation = Callable[[xr.Dataset], xr.Dataset]
|
||||
|
||||
|
||||
class TrainExample(NamedTuple):
|
||||
spec: torch.Tensor
|
||||
detection_heatmap: torch.Tensor
|
||||
class_heatmap: torch.Tensor
|
||||
size_heatmap: torch.Tensor
|
||||
idx: torch.Tensor
|
||||
start_time: float
|
||||
end_time: float
|
||||
|
||||
|
||||
class Losses(NamedTuple):
|
||||
"""Structure to hold the computed loss values.
|
||||
|
||||
Allows returning individual loss components along with the total weighted
|
||||
loss for monitoring and analysis during training.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
detection : torch.Tensor
|
||||
Scalar tensor representing the calculated detection loss component
|
||||
(before weighting).
|
||||
size : torch.Tensor
|
||||
Scalar tensor representing the calculated size regression loss component
|
||||
(before weighting).
|
||||
classification : torch.Tensor
|
||||
Scalar tensor representing the calculated classification loss component
|
||||
(before weighting).
|
||||
total : torch.Tensor
|
||||
Scalar tensor representing the final combined loss, computed as the
|
||||
weighted sum of the detection, size, and classification components.
|
||||
This is the value typically used for backpropagation.
|
||||
"""
|
||||
|
||||
detection: torch.Tensor
|
||||
size: torch.Tensor
|
||||
classification: torch.Tensor
|
||||
total: torch.Tensor
|
||||
|
||||
|
||||
class LossProtocol(Protocol):
|
||||
def __call__(self, pred: ModelOutput, gt: TrainExample) -> Losses: ...
|
||||
|
||||
|
||||
class ClipperProtocol(Protocol):
|
||||
def extract_clip(
|
||||
self, example: xr.Dataset
|
||||
) -> Tuple[xr.Dataset, float, float]: ...
|
@ -1,14 +1,10 @@
|
||||
"""Types used in the code base."""
|
||||
|
||||
from typing import Any, List, NamedTuple, Optional
|
||||
|
||||
from typing import Any, List, NamedTuple, Optional, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
try:
|
||||
from typing import Protocol
|
||||
except ImportError:
|
||||
|
docs/Makefile (new file, 20 lines)
@ -0,0 +1,20 @@
|
||||
# Minimal makefile for Sphinx documentation
|
||||
#
|
||||
|
||||
# You can set these variables from the command line, and also
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= sphinx-build
|
||||
SOURCEDIR = source
|
||||
BUILDDIR = build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
docs/make.bat (new file, 35 lines)
@ -0,0 +1,35 @@
|
||||
@ECHO OFF
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=source
|
||||
set BUILDDIR=build
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.https://www.sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
|
||||
:end
|
||||
popd
|
docs/source/conf.py (new file, 61 lines)
@ -0,0 +1,61 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
|
||||
project = "batdetect2"
|
||||
copyright = "2025, Oisin Mac Aodha, Santiago Martinez Balvanera"
|
||||
author = "Oisin Mac Aodha, Santiago Martinez Balvanera"
|
||||
release = "1.1.1"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
extensions = [
|
||||
"sphinx.ext.autodoc",
|
||||
"sphinx.ext.autosummary",
|
||||
"sphinx.ext.intersphinx",
|
||||
"sphinxcontrib.autodoc_pydantic",
|
||||
"numpydoc",
|
||||
"myst_parser",
|
||||
"sphinx_autodoc_typehints",
|
||||
]
|
||||
|
||||
templates_path = ["_templates"]
|
||||
exclude_patterns = []
|
||||
|
||||
source_suffix = {
|
||||
".rst": "restructuredtext",
|
||||
".txt": "markdown",
|
||||
".md": "markdown",
|
||||
}
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
html_theme = "sphinx_book_theme"
|
||||
html_static_path = ["_static"]
|
||||
|
||||
intersphinx_mapping = {
|
||||
"python": ("https://docs.python.org/3", None),
|
||||
"soundevent": ("https://mbsantiago.github.io/soundevent/", None),
|
||||
"pydantic": ("https://docs.pydantic.dev/latest/", None),
|
||||
"xarray": ("https://docs.xarray.dev/en/stable/", None),
|
||||
}
|
||||
|
||||
# -- Options for autodoc ------------------------------------------------------
|
||||
autosummary_generate = True
|
||||
autosummary_imported_members = True
|
||||
|
||||
autodoc_default_options = {
|
||||
"members": True,
|
||||
"undoc-members": False,
|
||||
"private-members": False,
|
||||
"special-members": False,
|
||||
"inherited-members": False,
|
||||
"show-inheritance": True,
|
||||
"module-first": True,
|
||||
}
|
docs/source/data/aoef.md (new file, 106 lines)
@ -0,0 +1,106 @@
|
||||
# Using AOEF / Soundevent Data Sources
|
||||
|
||||
## Introduction
|
||||
|
||||
The **AOEF (Acoustic Open Event Format)**, stored as `.json` files, is the annotation format used by the underlying `soundevent` library and is compatible with annotation tools like **Whombat**.
|
||||
BatDetect2 can directly load annotation data stored in this format.
|
||||
|
||||
This format can represent two main types of annotation collections:
|
||||
|
||||
1. `AnnotationSet`: A straightforward collection of annotations for various audio clips.
|
||||
2. `AnnotationProject`: A more structured format often exported by annotation tools (like Whombat).
|
||||
It includes not only the annotations but also information about annotation _tasks_ (work assigned to annotators) and their status (e.g., in-progress, completed, verified, rejected).
|
||||
|
||||
This section explains how to configure a data source in your `DatasetConfig` to load data from either type of AOEF file.
|
||||
|
||||
## Configuration
|
||||
|
||||
To define a data source using the AOEF format, you add an entry to the `sources` list in your main `DatasetConfig` (usually within your primary YAML configuration file) and set the `format` field to `"aoef"`.
|
||||
|
||||
Here are the key fields you need to specify for an AOEF source:
|
||||
|
||||
- `format: "aoef"`: **(Required)** Tells BatDetect2 to use the AOEF loader for this source.
|
||||
- `name: your_source_name`: **(Required)** A unique name you choose for this data source (e.g., `"whombat_project_export"`, `"final_annotations"`).
|
||||
- `audio_dir: path/to/audio/files`: **(Required)** The path to the directory where the actual audio `.wav` files referenced in the annotations are located.
|
||||
- `annotations_path: path/to/your/annotations.aoef`: **(Required)** The path to the single `.aoef` or `.json` file containing the annotation data (either an `AnnotationSet` or an `AnnotationProject`).
|
||||
- `description: "Details about this source..."`: (Optional) A brief description of the data source.
|
||||
- `filter: ...`: **(Optional)** Specific settings used _only if_ the `annotations_path` file contains an `AnnotationProject`.
|
||||
See details below.
|
||||
|
||||
## Filtering Annotation Projects (Optional)
|
||||
|
||||
When working with annotation projects, especially collaborative ones or those still in progress (like exports from Whombat), you often want to train only on annotations that are considered complete and reliable.
|
||||
The optional `filter:` section allows you to specify criteria based on the status of the annotation _tasks_ within the project.
|
||||
|
||||
**If `annotations_path` points to a simple `AnnotationSet` file, the `filter:` section is ignored.**
|
||||
|
||||
If `annotations_path` points to an `AnnotationProject`, you can add a `filter:` block with the following options:
|
||||
|
||||
- `only_completed: <true_or_false>`:
|
||||
- `true` (Default): Only include annotations from tasks that have been marked as "completed".
|
||||
- `false`: Include annotations regardless of task completion status.
|
||||
- `only_verified: <true_or_false>`:
|
||||
- `false` (Default): Verification status is not considered.
|
||||
- `true`: Only include annotations from tasks that have _also_ been marked as "verified" (typically meaning they passed a review step).
|
||||
- `exclude_issues: <true_or_false>`:
|
||||
- `true` (Default): Exclude annotations from any task that has been marked as "rejected" or flagged with issues.
|
||||
- `false`: Include annotations even if their task was marked as having issues (use with caution).
|
||||
|
||||
**Default Filtering:** If you include the `filter:` block but omit some options, or if you _omit the entire `filter:` block_, the default settings are applied to `AnnotationProject` files: `only_completed: true`, `only_verified: false`, `exclude_issues: true`.
|
||||
This common default selects annotations from completed tasks that haven't been rejected, without requiring separate verification.
|
||||
|
||||
**Disabling Filtering:** If you want to load _all_ annotations from an `AnnotationProject` regardless of task status, you can explicitly disable filtering by setting `filter: null` in your YAML configuration.
|
||||
|
||||
## YAML Configuration Examples
|
||||
|
||||
**Example 1: Loading a standard AnnotationSet (or a Project with default filtering)**
|
||||
|
||||
```yaml
|
||||
# In your main DatasetConfig YAML file
|
||||
|
||||
sources:
|
||||
- name: "MyFinishedAnnotations"
|
||||
format: "aoef" # Specifies the loader
|
||||
audio_dir: "/path/to/my/audio/"
|
||||
annotations_path: "/path/to/my/dataset.soundevent.json" # Path to the AOEF file
|
||||
description: "Finalized annotations set."
|
||||
# No 'filter:' block means default filtering applied IF it's an AnnotationProject,
|
||||
# or no filtering applied if it's an AnnotationSet.
|
||||
```
|
||||
|
||||
**Example 2: Loading an AnnotationProject, requiring verification**
|
||||
|
||||
```yaml
|
||||
# In your main DatasetConfig YAML file
|
||||
|
||||
sources:
|
||||
- name: "WhombatVerifiedExport"
|
||||
format: "aoef"
|
||||
audio_dir: "relative/path/to/audio/" # Relative to where BatDetect2 runs or a base_dir
|
||||
annotations_path: "exports/whombat_project.aoef" # Path to the project file
|
||||
description: "Annotations from Whombat project, only using verified tasks."
|
||||
filter: # Customize the filter
|
||||
only_completed: true # Still require completion
|
||||
only_verified: true # *Also* require verification
|
||||
exclude_issues: true # Still exclude rejected tasks
|
||||
```
|
||||
|
||||
**Example 3: Loading an AnnotationProject, disabling all filtering**
|
||||
|
||||
```yaml
|
||||
# In your main DatasetConfig YAML file
|
||||
|
||||
sources:
|
||||
- name: "WhombatRawExport"
|
||||
format: "aoef"
|
||||
audio_dir: "data/audio_pool/"
|
||||
annotations_path: "exports/whombat_project_all.aoef"
|
||||
description: "All annotations from Whombat, regardless of task status."
|
||||
filter: null # Explicitly disable task filtering
|
||||
```
|
||||
|
||||
## Summary
|
||||
|
||||
To load standard `soundevent` annotations (including Whombat exports), set `format: "aoef"` for your data source in the `DatasetConfig`.
|
||||
Provide the `audio_dir` and the path to the single `annotations_path` file.
|
||||
If dealing with `AnnotationProject` files, you can optionally use the `filter:` block to select annotations based on task completion, verification, or issue status.
|
docs/source/data/index.md (new file, 9 lines)
@ -0,0 +1,9 @@
|
||||
# Loading Data
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
:caption: Loading Data
|
||||
|
||||
aoef
|
||||
legacy
|
||||
```
|
docs/source/data/legacy.md (new file, 122 lines)
@ -0,0 +1,122 @@
|
||||
# Using Legacy BatDetect2 Annotation Formats
|
||||
|
||||
## Introduction
|
||||
|
||||
If you have annotation data created using older BatDetect2 annotation tools, BatDetect2 provides tools to load these datasets.
|
||||
These older formats typically use JSON files to store annotation information, including bounding boxes and labels for sound events within recordings.
|
||||
|
||||
There are two main variations of this legacy format that BatDetect2 can load:
|
||||
|
||||
1. **Directory-Based (`format: "batdetect2"`):** Annotations for each audio recording are stored in a _separate_ JSON file within a dedicated directory.
|
||||
There's a naming convention linking the JSON file to its corresponding audio file (e.g., `my_recording.wav` annotations are stored in `my_recording.wav.json`).
|
||||
2. **Single Merged File (`format: "batdetect2_file"`):** Annotations for _multiple_ recordings are aggregated into a _single_ JSON file.
|
||||
This file contains a list, where each item represents the annotations for one recording, following the same internal structure as the directory-based format.
|
||||
|
||||
When you configure BatDetect2 to use these formats, it will read the legacy data and convert it internally into the standard `soundevent` data structures used by the rest of the pipeline.
|
||||
|
||||
## Configuration
|
||||
|
||||
You specify which legacy format to use within the `sources` list of your main `DatasetConfig` (usually in your primary YAML configuration file).
|
||||
|
||||
### Format 1: Directory-Based
|
||||
|
||||
Use this when you have a folder containing many individual JSON annotation files, one for each audio file.
|
||||
|
||||
**Configuration Fields:**
|
||||
|
||||
- `format: "batdetect2"`: **(Required)** Identifies this specific legacy format loader.
|
||||
- `name: your_source_name`: **(Required)** A unique name for this data source.
|
||||
- `audio_dir: path/to/audio/files`: **(Required)** Path to the directory containing the `.wav` audio files.
|
||||
- `annotations_dir: path/to/annotation/jsons`: **(Required)** Path to the directory containing the individual `.json` annotation files.
|
||||
- `description: "Details..."`: (Optional) Description of this source.
|
||||
- `filter: ...`: (Optional) Settings to filter which JSON files are processed based on flags within them (see "Filtering Legacy Annotations" below).
|
||||
|
||||
**YAML Example:**
|
||||
|
||||
```yaml
|
||||
# In your main DatasetConfig YAML file
|
||||
sources:
|
||||
- name: "OldProject_SiteA_Files"
|
||||
format: "batdetect2" # Use the directory-based loader
|
||||
audio_dir: "/data/SiteA/Audio/"
|
||||
annotations_dir: "/data/SiteA/Annotations_JSON/"
|
||||
description: "Legacy annotations stored as individual JSONs per recording."
|
||||
# filter: ... # Optional filter settings can be added here
|
||||
```
|
||||
|
||||
### Format 2: Single Merged File
|
||||
|
||||
Use this when you have a single JSON file that contains a list of annotations for multiple recordings.
|
||||
|
||||
**Configuration Fields:**
|
||||
|
||||
- `format: "batdetect2_file"`: **(Required)** Identifies this specific legacy format loader.
|
||||
- `name: your_source_name`: **(Required)** A unique name for this data source.
|
||||
- `audio_dir: path/to/audio/files`: **(Required)** Path to the directory containing the `.wav` audio files referenced _within_ the merged JSON file.
|
||||
- `annotations_path: path/to/your/merged_annotations.json`: **(Required)** Path to the single `.json` file containing the list of annotations.
|
||||
- `description: "Details..."`: (Optional) Description of this source.
|
||||
- `filter: ...`: (Optional) Settings to filter which records _within_ the merged file are processed (see "Filtering Legacy Annotations" below).
|
||||
|
||||
**YAML Example:**
|
||||
|
||||
```yaml
|
||||
# In your main DatasetConfig YAML file
|
||||
sources:
|
||||
- name: "OldProject_Merged"
|
||||
format: "batdetect2_file" # Use the merged file loader
|
||||
audio_dir: "/data/AllAudio/"
|
||||
annotations_path: "/data/CombinedAnnotations/old_project_merged.json"
|
||||
description: "Legacy annotations aggregated into a single JSON file."
|
||||
# filter: ... # Optional filter settings can be added here
|
||||
```
|
||||
|
||||
## Filtering Legacy Annotations
|
||||
|
||||
The legacy JSON annotation structure (for both formats) included boolean flags indicating the status of the annotation work for each recording:
|
||||
|
||||
- `annotated`: Typically `true` if a human had reviewed or created annotations for the file.
|
||||
- `issues`: Typically `true` if problems were noted during annotation or review.
|
||||
|
||||
You can optionally filter the data based on these flags using a `filter:` block within the source configuration.
|
||||
This applies whether you use `"batdetect2"` or `"batdetect2_file"`.
|
||||
|
||||
**Filter Options:**
|
||||
|
||||
- `only_annotated: <true_or_false>`:
|
||||
- `true` (**Default**): Only process entries where the `annotated` flag in the JSON is `true`.
|
||||
- `false`: Process entries regardless of the `annotated` flag.
|
||||
- `exclude_issues: <true_or_false>`:
|
||||
- `true` (**Default**): Skip processing entries where the `issues` flag in the JSON is `true`.
|
||||
- `false`: Process entries even if they are flagged with `issues`.
|
||||
|
||||
**Default Filtering:** If you **omit** the `filter:` block entirely, the default settings (`only_annotated: true`, `exclude_issues: true`) are applied automatically.
|
||||
This means only entries marked as annotated and not having issues will be loaded.
|
||||
|
||||
**Disabling Filtering:** To load _all_ entries from the legacy source regardless of the `annotated` or `issues` flags, explicitly disable the filter:
|
||||
|
||||
```yaml
|
||||
filter: null
|
||||
```
|
||||
|
||||
**YAML Example (Custom Filter):** Only load entries marked as annotated, but _include_ those with issues.
|
||||
|
||||
```yaml
|
||||
sources:
|
||||
- name: "LegacyData_WithIssues"
|
||||
format: "batdetect2" # Or "batdetect2_file"
|
||||
audio_dir: "path/to/audio"
|
||||
annotations_dir: "path/to/annotations" # Or annotations_path for merged
|
||||
filter:
|
||||
only_annotated: true
|
||||
exclude_issues: false # Include entries even if issues flag is true
|
||||
```
|
||||
|
||||
## Summary
|
||||
|
||||
BatDetect2 allows you to incorporate datasets stored in older "BatDetect2" JSON formats.
|
||||
|
||||
- Use `format: "batdetect2"` and provide `annotations_dir` if you have one JSON file per recording in a directory.
|
||||
- Use `format: "batdetect2_file"` and provide `annotations_path` if you have a single JSON file containing annotations for multiple recordings.
|
||||
- Optionally use the `filter:` block with `only_annotated` and `exclude_issues` to select data based on flags present in the legacy JSON structure.
|
||||
|
||||
The system will handle loading, filtering (if configured), and converting this legacy data into the standard `soundevent` format used internally.
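
For reference, both legacy formats can be combined as separate entries in the same `sources` list. The following is a minimal sketch; the source names and paths are placeholders, not part of any required layout:

```yaml
sources:
  # One JSON annotation file per recording, stored in a directory
  - name: "LegacyPerRecordingJSONs"
    format: "batdetect2"
    audio_dir: "data/site_a/audio/"
    annotations_dir: "data/site_a/annotations_json/"

  # A single merged JSON file covering many recordings
  - name: "LegacyMergedJSON"
    format: "batdetect2_file"
    audio_dir: "data/site_b/audio/"
    annotations_path: "data/site_b/merged_annotations.json"
    # filter omitted: defaults apply (only_annotated: true, exclude_issues: true)
```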
|
14
docs/source/index.md
Normal file
14
docs/source/index.md
Normal file
@ -0,0 +1,14 @@
|
||||
# batdetect2 documentation
|
||||
|
||||
Hi!
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
:caption: Contents:
|
||||
|
||||
data/index
|
||||
preprocessing/index
|
||||
postprocessing
|
||||
targets/index
|
||||
reference/index
|
||||
```
|
126
docs/source/postprocessing.md
Normal file
126
docs/source/postprocessing.md
Normal file
@ -0,0 +1,126 @@
|
||||
# Postprocessing: From Model Output to Predictions
|
||||
|
||||
## What is Postprocessing?
|
||||
|
||||
After the BatDetect2 neural network analyzes a spectrogram, it doesn't directly output a neat list of bat calls.
|
||||
Instead, it produces raw numerical data, usually in the form of multi-dimensional arrays or "heatmaps".
|
||||
These arrays contain information like:
|
||||
|
||||
- The probability of a sound event being present at each time-frequency location.
|
||||
- The probability of each possible target class (e.g., species) at each location.
|
||||
- Predicted size characteristics (like duration and bandwidth) at each location.
|
||||
- Internal learned features at each location.
|
||||
|
||||
**Postprocessing** is the sequence of steps that takes these numerical model outputs and translates them into a structured list of detected sound events, complete with predicted tags, bounding boxes, and confidence scores.
|
||||
The {py:mod}`batdetect2.postprocess` module handles this entire workflow.
|
||||
|
||||
## Why is Postprocessing Necessary?
|
||||
|
||||
1. **Interpretation:** Raw heatmap outputs need interpretation to identify distinct sound events (detections).
|
||||
A high probability score might spread across several adjacent time-frequency bins, all related to the same call.
|
||||
2. **Refinement:** Model outputs can be noisy or contain redundancies.
|
||||
Postprocessing steps like Non-Maximum Suppression (NMS) clean this up, ensuring (ideally) only one detection is reported for each actual sound event.
|
||||
3. **Contextualization:** Raw outputs lack real-world units.
|
||||
Postprocessing adds back time (seconds) and frequency (Hz) coordinates, converts predicted sizes to physical units using configured scales, and decodes predicted class indices back into meaningful tags based on your target definitions.
|
||||
4. **User Control:** Postprocessing includes tunable parameters, most importantly **thresholds**.
|
||||
By adjusting these, you can control the trade-off between finding more potential calls (sensitivity) versus reducing false positives (specificity) _without needing to retrain the model_.
|
||||
|
||||
## The Postprocessing Pipeline
|
||||
|
||||
BatDetect2 applies a series of steps to convert the raw model output into final predictions.
|
||||
Understanding these steps helps interpret the results and configure the process effectively:
|
||||
|
||||
1. **Non-Maximum Suppression (NMS):**
|
||||
|
||||
- **Goal:** Reduce redundant detections.
|
||||
If the model outputs high scores for several nearby points corresponding to the same call, NMS selects the single highest peak in a local neighbourhood and suppresses the others (sets their score to zero).
|
||||
- **Configurable:** The size of the neighbourhood (`nms_kernel_size`) can be adjusted.
|
||||
|
||||
2. **Coordinate Remapping:**
|
||||
|
||||
- **Goal:** Add coordinate (time/frequency) information.
|
||||
This step takes the grid-based model outputs (which just have row/column indices) and associates them with actual time (seconds) and frequency (Hz) coordinates based on the input spectrogram's properties.
|
||||
The result is coordinate-aware arrays (using {py:class}`xarray.DataArray`).
|
||||
|
||||
3. **Detection Extraction:**
|
||||
|
||||
- **Goal:** Identify the specific points representing detected events.
|
||||
- **Process:** Looks for peaks in the NMS-processed detection heatmap that are above a certain confidence level (`detection_threshold`).
|
||||
It also often limits the maximum number of detections based on a rate (`top_k_per_sec`) to avoid excessive outputs in very busy files.
|
||||
- **Configurable:** `detection_threshold`, `top_k_per_sec`.
|
||||
|
||||
4. **Data Extraction:**
|
||||
|
||||
- **Goal:** Gather all relevant information for each detected point.
|
||||
- **Process:** For each time-frequency location identified in Step 3, this step looks up the corresponding values in the _other_ remapped model output arrays (class probabilities, predicted sizes, internal features).
|
||||
- **Intermediate Output 1:** The result of this stage (containing aligned scores, positions, sizes, class probabilities, and features for all detections in a clip) is often accessible programmatically as an {py:class}`xarray.Dataset`.
|
||||
This can be useful for advanced users needing direct access to the numerical outputs.
|
||||
|
||||
5. **Decoding & Formatting:**
|
||||
|
||||
- **Goal:** Convert the extracted numerical data into interpretable, standard formats.
|
||||
- **Process:**
|
||||
- **ROI Recovery:** Uses the predicted position and size values, along with the ROI mapping configuration defined in the `targets` module, to reconstruct an estimated bounding box ({py:class}`soundevent.data.BoundingBox`).
|
||||
- **Class Decoding:** Translates the numerical class probability vector into meaningful {py:class}`soundevent.data.PredictedTag` objects.
|
||||
This involves:
|
||||
- Applying the `classification_threshold` to ignore low-confidence class scores.
|
||||
- Using the class decoding rules (from the `targets` module) to map the name(s) of the high-scoring class(es) back to standard tags (like `species: Myotis daubentonii`).
|
||||
- Optionally selecting only the top-scoring class or multiple classes above the threshold.
|
||||
- Including the generic "Bat" tags if no specific class meets the threshold.
|
||||
- **Feature Conversion:** Converts raw feature vectors into {py:class}`soundevent.data.Feature` objects.
|
||||
- **Intermediate Output 2:** This step might internally create a list of simplified `RawPrediction` objects containing the bounding box, scores, and features.
|
||||
This intermediate list might also be accessible programmatically for users who prefer a simpler structure than the final {py:mod}`soundevent` objects.
|
||||
|
||||
6. **Final Output (`ClipPrediction`):**
|
||||
- **Goal:** Package everything into a standard format.
|
||||
- **Process:** Collects all the fully processed `SoundEventPrediction` objects (each containing a sound event with geometry, features, overall score, and predicted tags with scores) for a given audio clip into a final {py:class}`soundevent.data.ClipPrediction` object.
|
||||
This is the standard output format representing the model's findings for that clip.
|
||||
|
||||
## Configuring Postprocessing
|
||||
|
||||
You can control key aspects of this pipeline, especially the thresholds and NMS settings, via a `postprocess:` section in your main configuration YAML file.
|
||||
Adjusting these **allows you to fine-tune the detection results without retraining the model**.
|
||||
|
||||
**Key Configurable Parameters:**
|
||||
|
||||
- `detection_threshold`: (Number >= 0, e.g., `0.1`) Minimum score for a peak to be considered a detection.
|
||||
**Lowering this increases sensitivity (more detections, potentially more false positives); raising it increases specificity (fewer detections, potentially missing faint calls).**
|
||||
- `classification_threshold`: (Number >= 0, e.g., `0.3`) Minimum score for a _specific class_ prediction to be assigned as a tag.
|
||||
Affects how confidently the model must identify the class.
|
||||
- `top_k_per_sec`: (Integer > 0, e.g., `200`) Limits the maximum density of detections reported per second.
|
||||
Helps manage extremely dense recordings.
|
||||
- `nms_kernel_size`: (Integer > 0, e.g., `9`) Size of the NMS window in pixels/bins.
|
||||
Affects how close two distinct peaks can be before one suppresses the other.
|
||||
|
||||
**Example YAML Configuration:**
|
||||
|
||||
```yaml
|
||||
# Inside your main configuration file (e.g., config.yaml)
|
||||
|
||||
postprocess:
|
||||
nms_kernel_size: 9
|
||||
detection_threshold: 0.1 # Lower threshold -> more sensitive
|
||||
classification_threshold: 0.3 # Higher threshold -> more confident classifications
|
||||
top_k_per_sec: 200
|
||||
# ... other sections preprocessing, targets ...
|
||||
```
|
||||
|
||||
**Note:** These parameters can often also be adjusted via Command Line Interface (CLI) arguments when running predictions, or through function arguments if using the Python API, providing flexibility for experimentation.
|
||||
|
||||
## Accessing Intermediate Results
|
||||
|
||||
While the final `ClipPrediction` objects are the standard output, the `Postprocessor` object used internally provides methods to access results from intermediate stages (like the `xr.Dataset` after Step 4, or the list of `RawPrediction` objects after Step 5).
|
||||
|
||||
This can be valuable for:
|
||||
|
||||
- Debugging the pipeline.
|
||||
- Performing custom analyses on the numerical outputs before final decoding.
|
||||
- **Transfer Learning / Feature Extraction:** Directly accessing the extracted `features` (from Step 4 or Step 5) associated with detected events can be highly useful for training other models or further analysis.
|
||||
|
||||
Consult the API documentation for details on how to access these intermediate results programmatically if needed.
|
||||
|
||||
## Summary
|
||||
|
||||
Postprocessing is the conversion between neural network outputs and meaningful, interpretable sound event detections.
|
||||
BatDetect2 provides a configurable pipeline including NMS, coordinate remapping, peak detection with thresholding, data extraction, and class/geometry decoding.
|
||||
Researchers can easily tune key parameters like thresholds via configuration files or arguments to adjust the final set of predictions without altering the trained model itself, and advanced users can access intermediate results for custom analyses or feature reuse.
|
92
docs/source/preprocessing/audio.md
Normal file
92
docs/source/preprocessing/audio.md
Normal file
@ -0,0 +1,92 @@
|
||||
# Audio Loading and Preprocessing
|
||||
|
||||
## Purpose
|
||||
|
||||
Before BatDetect2 can analyze the sounds in your recordings, the raw audio data needs to be loaded from the file and prepared.
|
||||
This initial preparation involves several standard waveform processing steps.
|
||||
This `audio` module handles this first stage of preprocessing.
|
||||
|
||||
It's crucial to understand that the _exact same_ preprocessing steps must be applied both when **training** a model and when **using** that trained model later to make predictions (inference).
|
||||
Consistent preprocessing ensures the model receives data in the format it expects.
|
||||
|
||||
BatDetect2 allows you to control these audio preprocessing steps through settings in your main configuration file.
|
||||
|
||||
## The Audio Processing Pipeline
|
||||
|
||||
When BatDetect2 needs to process an audio segment (either a full recording or a specific clip), it follows a defined sequence of steps:
|
||||
|
||||
1. **Load Audio Segment:** The system first reads the specified time segment from the audio file.
|
||||
- **Note:** BatDetect2 typically works with **mono** audio.
|
||||
By default, if your file has multiple channels (e.g., stereo), only the **first channel** is loaded and used for subsequent processing.
|
||||
2. **Adjust Duration (Optional):** If you've specified a target duration in your configuration, the loaded audio segment is either shortened (by cropping from the start) or lengthened (by adding silence, i.e., zeros, at the end) to match that exact duration.
|
||||
This is sometimes required by specific model architectures that expect fixed-size inputs.
|
||||
By default, this step is **off**, and the original clip duration is used.
|
||||
3. **Resample (Optional):** If configured (and usually **on** by default), the audio's sample rate is changed to a specific target value (e.g., 256,000 Hz).
|
||||
This is vital for standardizing the data, as different recording devices capture audio at different rates.
|
||||
The model needs to be trained and run on data with a consistent sample rate.
|
||||
4. **Center Waveform (Optional):** If configured (and typically **on** by default), the system removes any constant shift away from zero in the waveform (known as DC offset).
|
||||
This is a standard practice that can sometimes improve the quality of later signal processing steps.
|
||||
5. **Scale Amplitude (Optional):** If configured (typically **off** by default), the waveform's amplitude (loudness) is adjusted.
|
||||
The standard method used here is "peak normalization," which scales the entire clip so that the loudest point has an absolute value of 1.0.
|
||||
This can help standardize volume levels across different recordings, although it's not always necessary or desirable depending on your analysis goals.
|
||||
|
||||
## Configuring Audio Processing
|
||||
|
||||
You can control these steps via settings in your main configuration file (e.g., `config.yaml`), usually within a dedicated `audio:` section (which might itself be under a broader `preprocessing:` section).
|
||||
|
||||
Here are the key options you can set:
|
||||
|
||||
- **Resampling (`resample`)**:
|
||||
|
||||
- To enable resampling (recommended and usually default), include a `resample:` block.
|
||||
To disable it completely, you might set `resample: null` or omit the block.
|
||||
- `samplerate`: (Number) The target sample rate in Hertz (Hz) that all audio will be converted to.
|
||||
This **must** match the sample rate expected by the BatDetect2 model you are using or training (e.g., `samplerate: 256000`).
|
||||
- `mode`: (Text, `"poly"` or `"fourier"`) The underlying algorithm used for resampling.
|
||||
The default `"poly"` is generally a good choice.
|
||||
You typically don't need to change this unless you have specific reasons.
|
||||
|
||||
- **Duration (`duration`)**:
|
||||
|
||||
- (Number or `null`) Sets a fixed duration for all audio clips in **seconds**.
|
||||
If set (e.g., `duration: 4.0`), shorter clips are padded with silence, and longer clips are cropped.
|
||||
If `null` (default), the original clip duration is used.
|
||||
|
||||
- **Centering (`center`)**:
|
||||
|
||||
- (Boolean, `true` or `false`) Controls DC offset removal.
|
||||
Default is usually `true`.
|
||||
Set to `false` to disable.
|
||||
|
||||
- **Scaling (`scale`)**:
|
||||
- (Boolean, `true` or `false`) Controls peak amplitude normalization.
|
||||
Default is usually `false`.
|
||||
Set to `true` to enable scaling so the maximum absolute amplitude becomes 1.0.
|
||||
|
||||
**Example YAML Configuration:**
|
||||
|
||||
```yaml
|
||||
# Inside your main configuration file (e.g., training_config.yaml)
|
||||
|
||||
preprocessing: # Or this might be at the top level
|
||||
audio:
|
||||
# --- Resampling Settings ---
|
||||
resample: # Settings block to control resampling
|
||||
samplerate: 256000 # Target sample rate in Hz (Required if resampling)
|
||||
mode: poly # Algorithm ('poly' or 'fourier', optional, defaults to 'poly')
|
||||
# To disable resampling entirely, you might use:
|
||||
# resample: null
|
||||
|
||||
# --- Other Settings ---
|
||||
duration: null # Keep original clip duration (e.g., use 4.0 for 4 seconds)
|
||||
center: true # Remove DC offset (default is often true)
|
||||
scale: false # Do not normalize peak amplitude (default is often false)
|
||||
|
||||
# ... other configuration sections (like model, dataset, targets) ...
|
||||
```
|
||||
|
||||
## Outcome
|
||||
|
||||
After these steps, the output is a standardized audio waveform (represented as a numerical array with time information).
|
||||
This processed waveform is now ready for the next stage of preprocessing, which typically involves calculating the spectrogram (covered in the next module/section).
|
||||
Ensuring these audio preprocessing settings are consistent is fundamental for achieving reliable results in both training and inference.
|
46
docs/source/preprocessing/index.md
Normal file
46
docs/source/preprocessing/index.md
Normal file
@ -0,0 +1,46 @@
|
||||
# Preprocessing Audio for BatDetect2
|
||||
|
||||
## What is Preprocessing?
|
||||
|
||||
Preprocessing refers to the steps taken to transform your raw audio recordings into a standardized format suitable for analysis by the BatDetect2 deep learning model.
|
||||
This module (`batdetect2.preprocessing`) provides the tools to perform these transformations.
|
||||
|
||||
## Why is Preprocessing Important?
|
||||
|
||||
Applying a consistent preprocessing pipeline is important for several reasons:
|
||||
|
||||
1. **Standardization:** Audio recordings vary significantly depending on the equipment used, recording conditions, and settings (e.g., different sample rates, varying loudness levels, background noise).
|
||||
Preprocessing helps standardize these aspects, making the data more uniform and allowing the model to learn relevant patterns more effectively.
|
||||
2. **Model Requirements:** Deep learning models, particularly those like BatDetect2 that analyze 2D patterns in spectrograms, are designed to work with specific input characteristics.
|
||||
They often expect spectrograms of a certain size (time x frequency bins), with values represented on a particular scale (e.g., logarithmic/dB), and within a defined frequency range.
|
||||
Preprocessing ensures the data meets these requirements.
|
||||
3. **Consistency is Key:** Perhaps most importantly, the **exact same preprocessing steps** must be applied both when _training_ the model and when _using the trained model to make predictions_ (inference) on new data.
|
||||
Any discrepancy between the preprocessing used during training and inference can significantly degrade the model's performance and lead to unreliable results.
|
||||
BatDetect2's configurable pipeline ensures this consistency.
|
||||
|
||||
## How Preprocessing is Done in BatDetect2
|
||||
|
||||
BatDetect2 handles preprocessing through a configurable, two-stage pipeline:
|
||||
|
||||
1. **Audio Loading & Preparation:** This first stage deals with the raw audio waveform.
|
||||
It involves loading the specified audio segment (from a file or clip), selecting a single channel (mono), optionally resampling it to a consistent sample rate, optionally adjusting its duration, and applying basic waveform conditioning like centering (DC offset removal) and amplitude scaling.
|
||||
(Details in the {doc}`audio` section).
|
||||
2. **Spectrogram Generation:** The prepared audio waveform is then converted into a spectrogram.
|
||||
This involves calculating the Short-Time Fourier Transform (STFT) and then applying a series of configurable steps like cropping the frequency range, applying amplitude representations (like dB scale or PCEN), optional denoising, optional resizing to the model's required dimensions, and optional final normalization.
|
||||
(Details in the {doc}`spectrogram` section).
|
||||
|
||||
The entire pipeline is controlled via settings in your main configuration file (typically a YAML file), usually grouped under a `preprocessing:` section which contains subsections like `audio:` and `spectrogram:`.
|
||||
This allows you to easily define, share, and reproduce the exact preprocessing used for a specific model or experiment.
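
As an orientation, the overall shape of that configuration block is sketched below; the values shown are only illustrative (taken from the examples in the following pages), not requirements:

```yaml
preprocessing:
  audio:
    resample:
      samplerate: 256000 # target sample rate in Hz
    center: true # remove DC offset
    scale: false # no peak amplitude normalization
  spectrogram:
    stft:
      window_duration: 0.002 # seconds
      window_overlap: 0.75
    frequencies:
      min_freq: 10000 # Hz
      max_freq: 120000 # Hz
    scale: dB # final amplitude representation
```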
|
||||
|
||||
## Next Steps
|
||||
|
||||
Explore the following sections for detailed explanations of how to configure each stage of the preprocessing pipeline and how to use the resulting preprocessor:
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
:caption: Preprocessing Steps:
|
||||
|
||||
audio
|
||||
spectrogram
|
||||
usage
|
||||
```
|
141
docs/source/preprocessing/spectrogram.md
Normal file
141
docs/source/preprocessing/spectrogram.md
Normal file
@ -0,0 +1,141 @@
|
||||
# Spectrogram Generation
|
||||
|
||||
## Purpose
|
||||
|
||||
After loading and performing initial processing on the audio waveform (as described in the Audio Loading section), the next crucial step in the `preprocessing` pipeline is to convert that waveform into a **spectrogram**.
|
||||
A spectrogram is a visual representation of sound, showing frequency content over time, and it's the primary input format for many deep learning models, including BatDetect2.
|
||||
|
||||
This module handles the calculation and subsequent processing of the spectrogram.
|
||||
Just like the audio processing, these steps need to be applied **consistently** during both model training and later use (inference) to ensure the model performs reliably.
|
||||
You control this entire process through the configuration file.
|
||||
|
||||
## The Spectrogram Generation Pipeline
|
||||
|
||||
Once BatDetect2 has a prepared audio waveform, it follows these steps to create the final spectrogram input for the model:
|
||||
|
||||
1. **Calculate STFT (Short-Time Fourier Transform):** This is the fundamental step that converts the 1D audio waveform into a 2D time-frequency representation.
|
||||
It calculates the frequency content within short, overlapping time windows.
|
||||
The output is typically a **magnitude spectrogram**, showing the intensity (amplitude) of different frequencies at different times.
|
||||
Key parameters here are the `window_duration` and `window_overlap`, which affect the trade-off between time and frequency resolution.
|
||||
2. **Crop Frequencies:** The STFT often produces frequency information over a very wide range (e.g., 0 Hz up to half the sample rate).
|
||||
This step crops the spectrogram to focus only on the frequency range relevant to your target sounds (e.g., 10 kHz to 120 kHz for typical bat echolocation).
|
||||
3. **Apply PCEN (Optional):** If configured, Per-Channel Energy Normalization is applied.
|
||||
PCEN is an adaptive technique that adjusts the gain (loudness) in each frequency channel based on its recent history.
|
||||
It can help suppress stationary background noise and enhance the prominence of transient sounds like echolocation pulses.
|
||||
This step is optional.
|
||||
4. **Set Amplitude Scale / Representation:** The values in the spectrogram (either raw magnitude or post-PCEN values) need to be represented on a suitable scale.
|
||||
You choose one of the following:
|
||||
- `"amplitude"`: Use the linear magnitude values directly.
|
||||
(Default)
|
||||
- `"power"`: Use the squared magnitude values (representing energy).
|
||||
- `"dB"`: Apply a logarithmic transformation (specifically `log(1 + C*Magnitude)`).
|
||||
This compresses the range of values, often making variations in quieter sounds more apparent, similar to how humans perceive loudness.
|
||||
5. **Denoise (Optional):** If configured (and usually **on** by default), a simple noise reduction technique is applied.
|
||||
This method subtracts the average value of each frequency bin (calculated across time) from that bin, assuming the average represents steady background noise.
|
||||
Negative values after subtraction are clipped to zero.
|
||||
6. **Resize (Optional):** If configured, the dimensions (height/frequency bins and width/time bins) of the spectrogram are adjusted using interpolation to match the exact input size expected by the neural network architecture.
|
||||
7. **Peak Normalize (Optional):** If configured (typically **off** by default), the entire final spectrogram is scaled so that its highest value is exactly 1.0.
|
||||
This ensures all spectrograms fed to the model have a consistent maximum value, which can sometimes aid training stability.
|
||||
|
||||
## Configuring Spectrogram Generation
|
||||
|
||||
You control all these steps via settings in your main configuration file (e.g., `config.yaml`), within the `spectrogram:` section (usually located under the main `preprocessing:` section).
|
||||
|
||||
Here are the key configuration options:
|
||||
|
||||
- **STFT Settings (`stft`)**:
|
||||
|
||||
- `window_duration`: (Number, seconds, e.g., `0.002`) Length of the analysis window.
|
||||
- `window_overlap`: (Number, 0.0 to <1.0, e.g., `0.75`) Fractional overlap between windows.
|
||||
- `window_fn`: (Text, e.g., `"hann"`) Name of the windowing function.
|
||||
|
||||
- **Frequency Cropping (`frequencies`)**:
|
||||
|
||||
- `min_freq`: (Integer, Hz, e.g., `10000`) Minimum frequency to keep.
|
||||
- `max_freq`: (Integer, Hz, e.g., `120000`) Maximum frequency to keep.
|
||||
|
||||
- **PCEN (`pcen`)**:
|
||||
|
||||
- This entire section is **optional**.
|
||||
Include it only if you want to apply PCEN.
|
||||
If omitted or set to `null`, PCEN is skipped.
|
||||
- `time_constant`: (Number, seconds, e.g., `0.4`) Controls adaptation speed.
|
||||
- `gain`: (Number, e.g., `0.98`) Gain factor.
|
||||
- `bias`: (Number, e.g., `2.0`) Bias factor.
|
||||
- `power`: (Number, e.g., `0.5`) Compression exponent.
|
||||
|
||||
- **Amplitude Scale (`scale`)**:
|
||||
|
||||
- (Text: `"dB"`, `"power"`, or `"amplitude"`) Selects the final representation of the spectrogram values.
|
||||
Default is `"amplitude"`.
|
||||
|
||||
- **Denoising (`spectral_mean_substraction`)**:
|
||||
|
||||
- (Boolean: `true` or `false`) Enables/disables the spectral mean subtraction denoising step.
|
||||
Default is usually `true`.
|
||||
|
||||
- **Resizing (`size`)**:
|
||||
|
||||
- This entire section is **optional**.
|
||||
Include it only if you need to resize the spectrogram to specific dimensions required by the model.
|
||||
If omitted or set to `null`, no resizing occurs after frequency cropping.
|
||||
- `height`: (Integer, e.g., `128`) Target number of frequency bins.
|
||||
- `resize_factor`: (Number or `null`, e.g., `0.5`) Factor to scale the time dimension by.
|
||||
`0.5` halves the width, `null` or `1.0` keeps the original width.
|
||||
|
||||
- **Peak Normalization (`peak_normalize`)**:
|
||||
- (Boolean: `true` or `false`) Enables/disables final scaling of the entire spectrogram so the maximum value is 1.0.
|
||||
Default is usually `false`.
|
||||
|
||||
**Example YAML Configuration:**
|
||||
|
||||
```yaml
|
||||
# Inside your main configuration file
|
||||
|
||||
preprocessing:
|
||||
audio:
|
||||
# ... (your audio configuration settings) ...
|
||||
resample:
|
||||
samplerate: 256000 # Ensure this matches model needs
|
||||
|
||||
spectrogram:
|
||||
# --- STFT Parameters ---
|
||||
stft:
|
||||
window_duration: 0.002 # 2ms window
|
||||
window_overlap: 0.75 # 75% overlap
|
||||
window_fn: hann
|
||||
|
||||
# --- Frequency Range ---
|
||||
frequencies:
|
||||
min_freq: 10000 # 10 kHz
|
||||
max_freq: 120000 # 120 kHz
|
||||
|
||||
# --- PCEN (Optional) ---
|
||||
# Include this block to enable PCEN, omit or set to null to disable.
|
||||
pcen:
|
||||
time_constant: 0.4
|
||||
gain: 0.98
|
||||
bias: 2.0
|
||||
power: 0.5
|
||||
|
||||
# --- Final Amplitude Representation ---
|
||||
scale: dB # Choose 'dB', 'power', or 'amplitude'
|
||||
|
||||
# --- Denoising ---
|
||||
spectral_mean_substraction: true # Enable spectral mean subtraction
|
||||
|
||||
# --- Resizing (Optional) ---
|
||||
# Include this block to resize, omit or set to null to disable.
|
||||
size:
|
||||
height: 128 # Target height in frequency bins
|
||||
resize_factor: 0.5 # Halve the number of time bins
|
||||
|
||||
# --- Final Normalization ---
|
||||
peak_normalize: false # Do not scale max value to 1.0
|
||||
```
|
||||
|
||||
## Outcome
|
||||
|
||||
The output of this module is the final, processed spectrogram (as a 2D numerical array with time and frequency information).
|
||||
This spectrogram is now in the precise format expected by the BatDetect2 neural network, ready to be used for training the model or for making predictions on new data.
|
||||
Remember, using the exact same `spectrogram` configuration settings during training and inference is essential for correct model performance.
|
175
docs/source/preprocessing/usage.md
Normal file
175
docs/source/preprocessing/usage.md
Normal file
@ -0,0 +1,175 @@
|
||||
# Using Preprocessors in BatDetect2
|
||||
|
||||
## Overview
|
||||
|
||||
In the previous sections ({doc}`audio` and {doc}`spectrogram`), we discussed the individual steps involved in converting raw audio into a processed spectrogram suitable for BatDetect2 models, and how to configure these steps using YAML files (specifically the `audio:` and `spectrogram:` sections within a main `preprocessing:` configuration block).
|
||||
|
||||
This page focuses on how this configured pipeline is represented and used within BatDetect2, primarily through the concept of a **`Preprocessor`** object.
|
||||
This object bundles together your chosen audio loading settings and spectrogram generation settings into a single component that can perform the end-to-end processing.
|
||||
|
||||
## Do I Need to Interact with Preprocessors Directly?
|
||||
|
||||
**Usually, no.** For standard model training or running inference with BatDetect2 using the provided scripts, the system will automatically:
|
||||
|
||||
1. Read your main configuration file (e.g., `config.yaml`).
|
||||
2. Find the `preprocessing:` section (containing `audio:` and `spectrogram:` settings).
|
||||
3. Build the appropriate `Preprocessor` object internally based on your settings.
|
||||
4. Use that internal `Preprocessor` object automatically whenever audio needs to be loaded and converted to a spectrogram.
|
||||
|
||||
**However**, understanding the `Preprocessor` object is useful if you want to:
|
||||
|
||||
- **Customize:** Go beyond the standard configuration options by interacting with parts of the pipeline programmatically.
|
||||
- **Integrate:** Use BatDetect2's preprocessing steps within your own custom Python analysis scripts.
|
||||
- **Inspect/Debug:** Manually apply preprocessing steps to specific files or clips to examine intermediate outputs (like the processed waveform) or the final spectrogram.
|
||||
|
||||
## Getting a Preprocessor Object
|
||||
|
||||
If you _do_ want to work with the preprocessor programmatically, you first need to get an instance of it.
|
||||
This is typically done based on a configuration:
|
||||
|
||||
1. **Define Configuration:** Create your `preprocessing:` configuration, usually in a YAML file (let's call it `preprocess_config.yaml`), detailing your desired `audio` and `spectrogram` settings.
|
||||
|
||||
```yaml
|
||||
# preprocess_config.yaml
|
||||
audio:
|
||||
resample:
|
||||
samplerate: 256000
|
||||
# ... other audio settings ...
|
||||
spectrogram:
|
||||
frequencies:
|
||||
min_freq: 15000
|
||||
max_freq: 120000
|
||||
scale: dB
|
||||
# ... other spectrogram settings ...
|
||||
```
|
||||
|
||||
2. **Load Configuration & Build Preprocessor (in Python):**
|
||||
|
||||
```python
|
||||
from batdetect2.preprocessing import load_preprocessing_config, build_preprocessor
|
||||
from batdetect2.preprocess.types import Preprocessor # Import the type
|
||||
|
||||
# Load the configuration from the file
|
||||
config_path = "path/to/your/preprocess_config.yaml"
|
||||
preprocessing_config = load_preprocessing_config(config_path)
|
||||
|
||||
# Build the actual preprocessor object using the loaded config
|
||||
preprocessor: Preprocessor = build_preprocessor(preprocessing_config)
|
||||
|
||||
# 'preprocessor' is now ready to use!
|
||||
```
|
||||
|
||||
3. **Using Defaults:** If you just want the standard BatDetect2 default preprocessing settings, you can build one without loading a config file:
|
||||
|
||||
```python
|
||||
from batdetect2.preprocessing import build_preprocessor
|
||||
from batdetect2.preprocess.types import Preprocessor
|
||||
|
||||
# Build with default settings
|
||||
default_preprocessor: Preprocessor = build_preprocessor()
|
||||
```
|
||||
|
||||
## Applying Preprocessing
|
||||
|
||||
Once you have a `preprocessor` object, you can use its methods to process audio data:
|
||||
|
||||
**1. End-to-End Processing (Common Use Case):**
|
||||
|
||||
These methods take an audio source identifier (file path, Recording object, or Clip object) and return the final, processed spectrogram.
|
||||
|
||||
- `preprocessor.preprocess_file(path)`: Processes an entire audio file.
|
||||
- `preprocessor.preprocess_recording(recording_obj)`: Processes the entire audio associated with a `soundevent.data.Recording` object.
|
||||
- `preprocessor.preprocess_clip(clip_obj)`: Processes only the specific time segment defined by a `soundevent.data.Clip` object.
|
||||
- **Efficiency Note:** Using `preprocess_clip` is **highly recommended** when you are only interested in analyzing a small portion of a potentially long recording.
|
||||
It avoids loading the entire audio file into memory, making it much more efficient.
|
||||
|
||||
```python
|
||||
from soundevent import data
|
||||
|
||||
# Assume 'preprocessor' is built as shown before
|
||||
# Assume 'my_clip' is a soundevent.data.Clip object defining a segment
|
||||
|
||||
# Process an entire file
|
||||
spectrogram_from_file = preprocessor.preprocess_file("my_recording.wav")
|
||||
|
||||
# Process only a specific clip (more efficient for segments)
|
||||
spectrogram_from_clip = preprocessor.preprocess_clip(my_clip)
|
||||
|
||||
# The results (spectrogram_from_file, spectrogram_from_clip) are xr.DataArrays
|
||||
print(type(spectrogram_from_clip))
|
||||
# Output: <class 'xarray.core.dataarray.DataArray'>
|
||||
```
|
||||
|
||||
**2. Intermediate Steps (Advanced Use Cases):**
|
||||
|
||||
The preprocessor also allows access to intermediate stages if needed:
|
||||
|
||||
- `preprocessor.load_clip_audio(clip_obj)` (and similar for file/recording): Loads the audio and applies _only_ the waveform processing steps (resampling, centering, etc.) defined in the `audio` config.
|
||||
Returns the processed waveform as an `xr.DataArray`.
|
||||
This is useful if you want to analyze or manipulate the waveform itself before spectrogram generation.
|
||||
- `preprocessor.compute_spectrogram(waveform)`: Takes an _already loaded_ waveform (either `np.ndarray` or `xr.DataArray`) and applies _only_ the spectrogram generation steps defined in the `spectrogram` config.
|
||||
- If you provide an `xr.DataArray` (e.g., from `load_clip_audio`), it uses the sample rate from the array's coordinates.
|
||||
- If you provide a raw `np.ndarray`, the preprocessor **must assume a sample rate**.
|
||||
It uses the `default_samplerate` that was determined when the `preprocessor` was built (based on your `audio` config's resample settings or the global default).
|
||||
Be cautious when using NumPy arrays to ensure the sample rate assumption is correct for your data!
|
||||
|
||||
```python
|
||||
# Example: Get waveform first, then spectrogram
|
||||
waveform = preprocessor.load_clip_audio(my_clip)
|
||||
# waveform is an xr.DataArray
|
||||
|
||||
# ...potentially do other things with the waveform...
|
||||
|
||||
# Compute spectrogram from the loaded waveform
|
||||
spectrogram = preprocessor.compute_spectrogram(waveform)
|
||||
|
||||
# Example: Process external numpy array (use with caution re: sample rate)
|
||||
# import soundfile as sf # Requires installing soundfile
|
||||
# numpy_waveform, original_sr = sf.read("some_other_audio.wav")
|
||||
# # MUST ensure numpy_waveform's actual sample rate matches
|
||||
# # preprocessor.default_samplerate for correct results here!
|
||||
# spec_from_numpy = preprocessor.compute_spectrogram(numpy_waveform)
|
||||
```
|
||||
|
||||
## Understanding the Output: `xarray.DataArray`
|
||||
|
||||
All preprocessing methods return the final spectrogram (or the intermediate waveform) as an **`xarray.DataArray`**.
|
||||
|
||||
**What is it?** Think of it like a standard NumPy array (holding the numerical data of the spectrogram) but with added "superpowers":
|
||||
|
||||
- **Labeled Dimensions:** Instead of just having axis 0 and axis 1, the dimensions have names, typically `"frequency"` and `"time"`.
|
||||
- **Coordinates:** It stores the actual frequency values (e.g., in Hz) corresponding to each row and the actual time values (e.g., in seconds) corresponding to each column along the dimensions.
|
||||
|
||||
**Why is it used?**
|
||||
|
||||
- **Clarity:** The data is self-describing.
|
||||
You don't need to remember which axis is time and which is frequency, or what the units are – it's stored with the data.
|
||||
- **Convenience:** You can select, slice, or plot data using the real-world coordinate values (times, frequencies) instead of just numerical indices.
|
||||
This makes analysis code easier to write and less prone to errors.
|
||||
- **Metadata:** It can hold additional metadata about the processing steps in its `attrs` (attributes) dictionary.
|
||||
|
||||
**Using the Output:**
|
||||
|
||||
- **Input to Model:** For standard training or inference, you typically pass this `xr.DataArray` spectrogram directly to the BatDetect2 model functions.
|
||||
- **Inspection/Analysis:** If you're working programmatically, you can use xarray's powerful features.
|
||||
For example (these are just illustrations of xarray):
|
||||
|
||||
```python
|
||||
# Get the shape (frequency_bins, time_bins)
|
||||
# print(spectrogram.shape)
|
||||
|
||||
# Get the frequency coordinate values
|
||||
# print(spectrogram['frequency'].values)
|
||||
|
||||
# Select data near a specific time and frequency
|
||||
# value_at_point = spectrogram.sel(time=0.5, frequency=50000, method="nearest")
|
||||
# print(value_at_point)
|
||||
|
||||
# Select a time slice between 0.2 and 0.3 seconds
|
||||
# time_slice = spectrogram.sel(time=slice(0.2, 0.3))
|
||||
# print(time_slice.shape)
|
||||
```
|
||||
|
||||
In summary, while BatDetect2 often handles preprocessing automatically based on your configuration, the underlying `Preprocessor` object provides a flexible interface for applying these steps programmatically if needed, returning results in the convenient and informative `xarray.DataArray` format.
|
7
docs/source/reference/configs.md
Normal file
7
docs/source/reference/configs.md
Normal file
@ -0,0 +1,7 @@
|
||||
# Config Reference
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: batdetect2.configs
|
||||
:members:
|
||||
:inherited-members: pydantic.BaseModel
|
||||
```
|
10
docs/source/reference/index.md
Normal file
10
docs/source/reference/index.md
Normal file
@ -0,0 +1,10 @@
|
||||
# Reference documentation
|
||||
|
||||
```{eval-rst}
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
:caption: Contents:
|
||||
|
||||
configs
|
||||
targets
|
||||
```
|
6
docs/source/reference/targets.md
Normal file
6
docs/source/reference/targets.md
Normal file
@ -0,0 +1,6 @@
|
||||
# Targets Reference
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: batdetect2.targets
|
||||
:members:
|
||||
```
|
141
docs/source/targets/classes.md
Normal file
141
docs/source/targets/classes.md
Normal file
@ -0,0 +1,141 @@
|
||||
# Step 4: Defining Target Classes and Decoding Rules
|
||||
|
||||
## Purpose and Context
|
||||
|
||||
You've prepared your data by defining your annotation vocabulary (Step 1: Terms), removing irrelevant sounds (Step 2: Filtering), and potentially cleaning up or modifying tags (Step 3: Transforming Tags).
|
||||
Now, it's time for a crucial step with two related goals:
|
||||
|
||||
1. Telling `batdetect2` **exactly what categories (classes) your model should learn to identify** by defining rules that map annotation tags to class names (like `pippip`, `myodau`, or `noise`).
|
||||
This process is often called **encoding**.
|
||||
2. Defining how the model's predictions (those same class names) should be translated back into meaningful, structured **annotation tags** when you use the trained model.
|
||||
This is often called **decoding**.
|
||||
|
||||
These definitions are essential for both training the model correctly and interpreting its output later.
|
||||
|
||||
## How it Works: Defining Classes with Rules
|
||||
|
||||
You define your target classes and their corresponding decoding rules in your main configuration file (e.g., your `.yaml` training config), typically under a section named `classes`.
|
||||
This section contains:
|
||||
|
||||
1. A **list** of specific class definitions.
|
||||
2. A definition for the **generic class** tags.
|
||||
|
||||
Each item in the `classes` list defines one specific class your model should learn.
|
||||
|
||||
## Defining a Single Class
|
||||
|
||||
Each specific class definition rule requires the following information:
|
||||
|
||||
1. `name`: **(Required)** This is the unique, simple name for this class (e.g., `pipistrellus_pipistrellus`, `myotis_daubentonii`, `noise`).
|
||||
This label is used during training and is what the model predicts.
|
||||
Choose clear, distinct names.
|
||||
**Each class name must be unique.**
|
||||
2. `tags`: **(Required)** This list contains one or more specific tags (using `key` and `value`) used to identify if an _existing_ annotation belongs to this class during the _encoding_ phase (preparing training data).
|
||||
3. `match_type`: **(Optional, defaults to `"all"`)** Determines how the `tags` list is evaluated during _encoding_:
|
||||
- `"all"`: The annotation must have **ALL** listed tags to match.
|
||||
(Default).
|
||||
- `"any"`: The annotation needs **AT LEAST ONE** listed tag to match.
|
||||
4. `output_tags`: **(Optional)** This list specifies the tags that should be assigned to an annotation when the model _predicts_ this class `name`.
|
||||
This is used during the _decoding_ phase (interpreting model output).
|
||||
- **If you omit `output_tags` (or set it to `null`/~), the system will default to using the same tags listed in the `tags` field for decoding.** This is often what you want.
|
||||
- Providing `output_tags` allows you to specify a different, potentially more canonical or detailed, set of tags to represent the class upon prediction.
|
||||
For example, you could match based on simplified tags but output standardized tags.
|
||||
|
||||
**Example: Defining Species Classes (Encoding & Default Decoding)**
|
||||
|
||||
Here, the `tags` used for matching during encoding will also be used for decoding, as `output_tags` is omitted.
|
||||
|
||||
```yaml
|
||||
# In your main configuration file
|
||||
classes:
|
||||
# Definition for the first class
|
||||
- name: pippip # Simple name for Pipistrellus pipistrellus
|
||||
tags: # Used for BOTH encoding match and decoding output
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
# match_type defaults to "all"
|
||||
# output_tags is omitted, defaults to using 'tags' above
|
||||
|
||||
# Definition for the second class
|
||||
- name: myodau # Simple name for Myotis daubentonii
|
||||
tags: # Used for BOTH encoding match and decoding output
|
||||
- key: species
|
||||
value: Myotis daubentonii
|
||||
```
|
||||
|
||||
**Example: Defining a Class with Separate Encoding and Decoding Tags**
|
||||
|
||||
Here, we match based on _either_ of two tags (`match_type: any`), but when the model predicts `'pipistrelle'`, we decode it _only_ to the specific `Pipistrellus pipistrellus` tag plus a genus tag.
|
||||
|
||||
```yaml
|
||||
classes:
|
||||
- name: pipistrelle # Name for a Pipistrellus group
|
||||
match_type: any # Match if EITHER tag below is present during encoding
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
- key: species
|
||||
value: Pipistrellus pygmaeus # Match pygmaeus too
|
||||
output_tags: # BUT, when decoding 'pipistrelle', assign THESE tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus # Canonical species
|
||||
- key: genus # Assumes 'genus' key exists
|
||||
value: Pipistrellus # Add genus tag
|
||||
```
|
||||
|
||||
## Handling Overlap During Encoding: Priority Order Matters
|
||||
|
||||
As before, when preparing training data (encoding), if an annotation matches the `tags` and `match_type` rules for multiple class definitions, the **order of the class definitions in the configuration list determines the priority**.
|
||||
|
||||
- The system checks rules from the **top** of the `classes` list down.
|
||||
- The annotation gets assigned the `name` of the **first class rule it matches**.
|
||||
- **Place more specific rules before more general rules.**
|
||||
|
||||
For example, see the illustrative sketch below, which lists a species-level rule before a broader genus-level rule.
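
This is a minimal sketch only; the class names (`myodau`, `myotis`) and the `genus` tag key are placeholders rather than required vocabulary:

```yaml
classes:
  # Checked first: specific species-level class
  - name: myodau
    tags:
      - key: species
        value: Myotis daubentonii

  # Checked second: broader genus-level class.
  # An annotation tagged with both `species: Myotis daubentonii` and
  # `genus: Myotis` matches both rules, but is assigned "myodau"
  # because that rule appears earlier in the list.
  - name: myotis
    tags:
      - key: genus
        value: Myotis
```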
|
||||
|
||||
## Handling Non-Matches & Decoding the Generic Class
|
||||
|
||||
What happens if an annotation passes filtering/transformation but doesn't match any specific class rule during encoding?
|
||||
|
||||
- **Encoding:** As explained previously, these annotations are **not ignored**.
|
||||
They are typically assigned to a generic "relevant sound" category, often called the **"Bat"** class in BatDetect2, intended for all relevant bat calls not specifically classified.
|
||||
- **Decoding:** When the model predicts this generic "Bat" category (or when processing sounds that weren't assigned a specific class during encoding), we need a way to represent this generic status with tags.
|
||||
This is defined by the `generic_class` list directly within the main `classes` configuration section.
|
||||
|
||||
**Defining the Generic Class Tags:**
|
||||
|
||||
You specify the tags for the generic class like this:
|
||||
|
||||
```yaml
|
||||
# In your main configuration file
|
||||
classes: # Main configuration section for classes
|
||||
# --- List of specific class definitions ---
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
# ... other specific classes ...
|
||||
|
||||
# --- Definition of the generic class tags ---
|
||||
generic_class: # Define tags for the generic 'Bat' category
|
||||
- key: call_type
|
||||
value: Echolocation
|
||||
- key: order
|
||||
value: Chiroptera
|
||||
# These tags will be assigned when decoding the generic category
|
||||
```
|
||||
|
||||
This `generic_class` list provides the standard tags assigned when a sound is identified as relevant (passed filtering) but doesn't belong to one of the specific target classes you defined.
|
||||
Like the specific classes, sensible defaults are often provided if you don't explicitly define `generic_class`.
|
||||
|
||||
**Crucially:** Remember, if sounds should be **completely excluded** from training (not even considered "generic"), use **Filtering rules (Step 2)**.
|
||||
|
||||
### Outcome
|
||||
|
||||
By defining this list of prioritized class rules (including their `name`, matching `tags`, `match_type`, and optional `output_tags`) and the `generic_class` tags, you provide `batdetect2` with:
|
||||
|
||||
1. A clear procedure to assign a target label (`name`) to each relevant annotation for training.
|
||||
2. A clear mapping to convert predicted class names (including the generic case) back into meaningful annotation tags.
|
||||
|
||||
This complete definition prepares your data for the final heatmap generation (Step 5) and enables interpretation of the model's results.
|
141
docs/source/targets/filtering.md
Normal file
141
docs/source/targets/filtering.md
Normal file
@ -0,0 +1,141 @@
|
||||
# Step 2: Filtering Sound Events
|
||||
|
||||
## Purpose
|
||||
|
||||
When preparing your annotated audio data for training a `batdetect2` model, you often want to select only specific sound events.
|
||||
For example, you might want to:
|
||||
|
||||
- Focus only on echolocation calls and ignore social calls or noise.
|
||||
- Exclude annotations that were marked as low quality.
|
||||
- Train only on specific species or groups of species.
|
||||
|
||||
This filtering module allows you to define rules based on the **tags** associated with each sound event annotation.
|
||||
Only the events that pass _all_ your defined rules will be kept for further processing and training.
|
||||
|
||||
## How it Works: Rules
|
||||
|
||||
Filtering is controlled by a list of **rules**.
|
||||
Each rule defines a condition based on the tags attached to a sound event.
|
||||
An event must satisfy **all** the rules you define in your configuration to be included.
|
||||
If an event fails even one rule, it is discarded.
|
||||
|
||||
## Defining Rules in Configuration
|
||||
|
||||
You define these rules within your main configuration file (usually a `.yaml` file) under a specific section (the exact name might depend on the main training config, but let's assume it's called `filtering`).
|
||||
|
||||
The configuration consists of a list named `rules`.
|
||||
Each item in this list is a single filter rule.
|
||||
|
||||
Each **rule** has two parts:
|
||||
|
||||
1. `match_type`: Specifies the _kind_ of check to perform.
|
||||
2. `tags`: A list of specific tags (each with a `key` and `value`) that the rule applies to.
|
||||
|
||||
```yaml
|
||||
# Example structure in your configuration file
|
||||
filtering:
|
||||
rules:
|
||||
- match_type: <TYPE_OF_CHECK_1>
|
||||
tags:
|
||||
- key: <tag_key_1a>
|
||||
value: <tag_value_1a>
|
||||
- key: <tag_key_1b>
|
||||
value: <tag_value_1b>
|
||||
- match_type: <TYPE_OF_CHECK_2>
|
||||
tags:
|
||||
- key: <tag_key_2a>
|
||||
value: <tag_value_2a>
|
||||
# ... add more rules as needed
|
||||
```
|
||||
|
||||
## Understanding `match_type`
|
||||
|
||||
This determines _how_ the list of `tags` in the rule is used to check a sound event.
|
||||
There are four types:
|
||||
|
||||
1. **`any`**: (Keep if _at least one_ tag matches)
|
||||
|
||||
- The sound event **passes** this rule if it has **at least one** of the tags listed in the `tags` section of the rule.
|
||||
- Think of it as an **OR** condition.
|
||||
- _Example Use Case:_ Keep events if they are tagged as `Species: Pip Pip` OR `Species: Pip Pyg`.
|
||||
|
||||
2. **`all`**: (Keep only if _all_ tags match)
|
||||
|
||||
- The sound event **passes** this rule only if it has **all** of the tags listed in the `tags` section.
|
||||
The event can have _other_ tags as well, but it must contain _all_ the ones specified here.
|
||||
- Think of it as an **AND** condition.
|
||||
- _Example Use Case:_ Keep events only if they are tagged with `Sound Type: Echolocation` AND `Quality: Good`.
|
||||
|
||||
3. **`exclude`**: (Discard if _any_ tag matches)
|
||||
|
||||
- The sound event **passes** this rule only if it does **not** have **any** of the tags listed in the `tags` section.
|
||||
If it matches even one tag in the list, the event is discarded.
|
||||
- _Example Use Case:_ Discard events if they are tagged `Quality: Poor` OR `Noise Source: Insect`.
|
||||
|
||||
4. **`equal`**: (Keep only if tags match _exactly_)
|
||||
- The sound event **passes** this rule only if its set of tags is _exactly identical_ to the list of `tags` provided in the rule (no more, no less).
|
||||
- _Note:_ This is very strict and usually less useful than `all` or `any`.
|
||||
|
||||
## Combining Rules
|
||||
|
||||
Remember: A sound event must **pass every single rule** defined in the `rules` list to be kept.
|
||||
The rules are checked one by one, and if an event fails any rule, it's immediately excluded from further consideration.
|
||||
|
||||
## Examples
|
||||
|
||||
**Example 1: Keep good quality echolocation calls**
|
||||
|
||||
```yaml
|
||||
filtering:
|
||||
rules:
|
||||
# Rule 1: Must have the 'Echolocation' tag
|
||||
- match_type: any # Could also use 'all' if 'Sound Type' is the only tag expected
|
||||
tags:
|
||||
- key: Sound Type
|
||||
value: Echolocation
|
||||
# Rule 2: Must NOT have the 'Poor' quality tag
|
||||
- match_type: exclude
|
||||
tags:
|
||||
- key: Quality
|
||||
value: Poor
|
||||
```
|
||||
|
||||
_Explanation:_ An event is kept only if it passes BOTH rules.
|
||||
It must have the `Sound Type: Echolocation` tag AND it must NOT have the `Quality: Poor` tag.
|
||||
|
||||
**Example 2: Keep calls from Pipistrellus species recorded in a specific project, excluding uncertain IDs**
|
||||
|
||||
```yaml
|
||||
filtering:
|
||||
rules:
|
||||
# Rule 1: Must be either Pip pip or Pip pyg
|
||||
- match_type: any
|
||||
tags:
|
||||
- key: Species
|
||||
value: Pipistrellus pipistrellus
|
||||
- key: Species
|
||||
value: Pipistrellus pygmaeus
|
||||
# Rule 2: Must belong to 'Project Alpha'
|
||||
- match_type: any # Using 'any' as it likely only has one project tag
|
||||
tags:
|
||||
- key: Project ID
|
||||
value: Project Alpha
|
||||
# Rule 3: Exclude if ID Certainty is 'Low' or 'Maybe'
|
||||
- match_type: exclude
|
||||
tags:
|
||||
- key: ID Certainty
|
||||
value: Low
|
||||
- key: ID Certainty
|
||||
value: Maybe
|
||||
```
|
||||
|
||||
_Explanation:_ An event is kept only if it passes ALL three rules:
|
||||
|
||||
1. It has a `Species` tag that is _either_ `Pipistrellus pipistrellus` OR `Pipistrellus pygmaeus`.
|
||||
2. It has the `Project ID: Project Alpha` tag.
|
||||
3. It does _not_ have an `ID Certainty: Low` tag AND it does _not_ have an `ID Certainty: Maybe` tag.
|
||||
|
||||
## Usage
|
||||
|
||||
You will typically specify the path to the configuration file containing these `filtering` rules when you set up your data processing or training pipeline in `batdetect2`.
|
||||
The tool will then automatically load these rules and apply them to your annotated sound events.
|
78
docs/source/targets/index.md
Normal file
78
docs/source/targets/index.md
Normal file
@ -0,0 +1,78 @@
|
||||
# Defining Training Targets
|
||||
|
||||
A crucial aspect of training any supervised machine learning model, including BatDetect2, is clearly defining the **training targets**.
|
||||
This process determines precisely what the model should learn to detect, localize, classify, and characterize from the input data (in this case, spectrograms).
|
||||
The choices made here directly influence the model's focus, its performance, and how its predictions should be interpreted.
|
||||
|
||||
For BatDetect2, defining targets involves specifying:
|
||||
|
||||
- Which sounds in your annotated dataset are relevant for training.
|
||||
- How these sounds should be categorized into distinct **classes** (e.g., different species).
|
||||
- How the geometric **Region of Interest (ROI)** (e.g., bounding box) of each sound maps to the specific **position** and **size** targets the model predicts.
|
||||
- How these classes and geometric properties relate back to the detailed information stored in your annotation **tags** (using a consistent **vocabulary/terms**).
|
||||
- How the model's output (predicted class names, positions, sizes) should be translated back into meaningful tags and geometries.
|
||||
|
||||
## Sound Event Annotations: The Starting Point
|
||||
|
||||
BatDetect2 assumes your training data consists of audio recordings where relevant sound events have been **annotated**.
|
||||
A typical annotation for a single sound event provides two key pieces of information:
|
||||
|
||||
1. **Location & Extent:** Information defining _where_ the sound occurs in time and frequency, usually represented as a **bounding box** (the ROI) drawn on a spectrogram.
|
||||
2. **Description (Tags):** Information _about_ the sound event, provided as a set of descriptive **tags** (key-value pairs).
|
||||
|
||||
For example, an annotation might have a bounding box and tags like:
|
||||
|
||||
- `species: Myotis daubentonii`
|
||||
- `quality: Good`
|
||||
- `call_type: Echolocation`
|
||||
|
||||
A single sound event can have **multiple tags**, allowing for rich descriptions.
|
||||
This richness requires a structured process to translate the annotation (both tags and geometry) into the precise targets needed for model training.
|
||||
The **target definition process** provides clear rules to:
|
||||
|
||||
- Interpret the meaning of different tag keys (**Terms**).
|
||||
- Select only the relevant annotations (**Filtering**).
|
||||
- Potentially standardize or modify the tags (**Transforming**).
|
||||
- Map the geometric ROI to specific position and size targets (**ROI Mapping**).
|
||||
- Map the final set of tags on each selected annotation to a single, definitive **target class** label (**Classes**).
|
||||
|
||||
## Configuration-Driven Workflow
|
||||
|
||||
BatDetect2 is designed so that researchers can configure this entire target definition process primarily through **configuration files** (typically written in YAML format), minimizing the need for direct programming for standard workflows.
|
||||
These settings are usually grouped under a main `targets:` key within your overall training configuration file.
|
||||
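As a rough sketch, that part of the file has the following overall shape (the placeholders are illustrative; each step's page listed below documents the actual options):

```yaml
# Illustrative outline only - see the pages listed below for the real options.
targets:
  # ... vocabulary (terms) settings ...          # Step 1
  # ... filtering settings ...                   # Step 2
  # ... tag transform settings ...               # Step 3
  # ... class definition settings ...            # Step 4
  # ... roi (position & size) settings ...       # Step 5
```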
|
||||
## The Target Definition Steps
|
||||
|
||||
Defining the targets involves several sequential steps, each configurable and building upon the previous one:
|
||||
|
||||
1. **Defining Vocabulary (Terms & Tags):** Understand how annotations use tags (key-value pairs).
|
||||
This step involves defining the meaning (**Terms**) behind the tag keys (e.g., `species`, `call_type`).
|
||||
Often, default terms are sufficient, but understanding this is key to using tags in later steps.
|
||||
(See: {doc}`tags_and_terms`)
|
||||
2. **Filtering Sound Events:** Select only the relevant sound event annotations based on their tags (e.g., keeping only high-quality calls).
|
||||
(See: {doc}`filtering`)
|
||||
3. **Transforming Tags (Optional):** Modify tags on selected annotations for standardization, correction, grouping (e.g., species to genus), or deriving new tags.
|
||||
(See: {doc}`transform`)
|
||||
4. **Defining Classes & Decoding Rules:** Map the final tags to specific target **class names** (like `pippip` or `myodau`).
|
||||
Define priorities for overlap and specify how predicted names map back to tags (decoding).
|
||||
(See: {doc}`classes`)
|
||||
5. **Mapping ROIs (Position & Size):** Define how the geometric ROI (e.g., bounding box) of each sound event maps to the specific reference **point** (e.g., center, corner) and scaled **size** values (width, height) used as targets by the model.
|
||||
(See: {doc}`rois`})
|
||||
6. **The `Targets` Object:** Understand the outcome of configuring steps 1-5 – a functional object used internally by BatDetect2 that encapsulates all your defined rules for filtering, transforming, ROI mapping, encoding, and decoding.
|
||||
(See: {doc}`use`)
|
||||
|
||||
The result of this configuration process is a clear set of instructions that BatDetect2 uses during training data preparation to determine the correct "answer" (the ground truth label and geometry representation) for each relevant sound event.
|
||||
|
||||
Explore the detailed steps using the links below:
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 1
|
||||
:caption: Target Definition Steps:
|
||||
|
||||
tags_and_terms
|
||||
filtering
|
||||
transform
|
||||
classes
|
||||
rois
|
||||
use
|
||||
```
|
76
docs/source/targets/labels.md
Normal file
76
docs/source/targets/labels.md
Normal file
@ -0,0 +1,76 @@
|
||||
# Step 5: Generating Training Targets
|
||||
|
||||
## Purpose and Context
|
||||
|
||||
Following the previous steps of defining terms, filtering events, transforming tags, and defining specific class rules, this final stage focuses on **generating the ground truth data** used directly for training the BatDetect2 model.
|
||||
This involves converting the refined annotation information for each audio clip into specific **heatmap formats** required by the underlying neural network architecture.
|
||||
|
||||
This step essentially translates your structured annotations into the precise "answer key" the model learns to replicate during training.
|
||||
|
||||
## What are Heatmaps?
|
||||
|
||||
Heatmaps, in this context, are multi-dimensional arrays, often visualized as images aligned with the input spectrogram, where the values at different time-frequency coordinates represent specific information about the sound events.
|
||||
For BatDetect2 training, three primary heatmaps are generated:
|
||||
|
||||
1. **Detection Heatmap:**
|
||||
|
||||
- **Represents:** The presence or likelihood of relevant sound events across the spectrogram.
|
||||
- **Structure:** A 2D array matching the spectrogram's time-frequency dimensions.
|
||||
Peaks (typically smoothed) are generated at the reference locations of all sound events that passed the filtering stage (including both specifically classified events and those falling into the generic "Bat" category).
|
||||
|
||||
2. **Class Heatmap:**
|
||||
|
||||
- **Represents:** The location and class identity for sounds belonging to the _specific_ target classes you defined in Step 4.
|
||||
- **Structure:** A 3D array with dimensions for category, time, and frequency.
|
||||
It contains a separate 2D layer (channel) for each target class name (e.g., 'pippip', 'myodau').
|
||||
A smoothed peak appears on a specific class layer only if a sound event assigned to that class exists at that location.
|
||||
Events assigned only to the generic class do not produce peaks here.
|
||||
|
||||
3. **Size Heatmap:**
|
||||
- **Represents:** The target dimensions (duration/width and bandwidth/height) of detected sound events.
|
||||
- **Structure:** A 3D array with dimensions for size-dimension ('width', 'height'), time, and frequency.
|
||||
At the reference location of each detected sound event, this heatmap stores two numerical values corresponding to the scaled width and height derived from the event's bounding box.
|
||||
|
||||
## How Heatmaps are Created
|
||||
|
||||
The generation of these heatmaps is an automated process within `batdetect2`, driven by your configurations from all previous steps.
|
||||
For each audio clip and its corresponding spectrogram in the training dataset:
|
||||
|
||||
1. The system retrieves the associated sound event annotations.
|
||||
2. Configured **filtering rules** (Step 2) are applied to select relevant annotations.
|
||||
3. Configured **tag transformation rules** (Step 3) are applied to modify the tags of the selected annotations.
|
||||
4. Configured **class definition rules** (Step 4) are used to assign a specific class name or determine generic "Bat" status for each processed annotation.
|
||||
5. These final annotations are then mapped onto initialized heatmap arrays:
|
||||
- A signal (initially a single point) is placed on the **Detection Heatmap** at the reference location for each relevant annotation.
|
||||
- The scaled width and height values are placed on the **Size Heatmap** at the reference location.
|
||||
- If an annotation received a specific class name, a signal is placed on the corresponding layer of the **Class Heatmap** at the reference location.
|
||||
6. Finally, Gaussian smoothing (a blurring effect) is typically applied to the Detection and Class heatmaps to create spatially smoother targets, which often aids model training stability and performance.
|
||||
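As a rough illustration of steps 5 and 6, the snippet below places a single hypothetical annotation onto freshly initialized arrays and then smooths them. The array shapes, bin coordinates, and the use of `scipy.ndimage.gaussian_filter` are assumptions made for this sketch; it is not the actual `batdetect2` labelling code.

```python
# Simplified illustration of heatmap construction (not the actual batdetect2 code).
import numpy as np
from scipy.ndimage import gaussian_filter  # assumed here for the smoothing step

class_names = ["pippip", "myodau"]
n_freq_bins, n_time_bins = 128, 512

detection = np.zeros((n_freq_bins, n_time_bins))
classes = np.zeros((len(class_names), n_freq_bins, n_time_bins))
size = np.zeros((2, n_freq_bins, n_time_bins))  # channel 0: width, channel 1: height

# One hypothetical annotation, already converted to spectrogram bin coordinates.
t_bin, f_bin = 100, 40                    # reference point (e.g. bottom-left of the box)
scaled_width, scaled_height = 5.5, 40.0   # duration * time_scale, bandwidth * frequency_scale
class_name = "pippip"                     # None would mean "generic bat" only

detection[f_bin, t_bin] = 1.0             # step 5: detection signal
size[0, f_bin, t_bin] = scaled_width      # step 5: size values at the same point
size[1, f_bin, t_bin] = scaled_height
if class_name is not None:                # step 5: class signal on the matching layer
    classes[class_names.index(class_name), f_bin, t_bin] = 1.0

# Step 6: Gaussian smoothing of the detection and class heatmaps (sigma in bins).
sigma = 3.0
detection = gaussian_filter(detection, sigma=sigma)
classes = np.stack([gaussian_filter(layer, sigma=sigma) for layer in classes])
```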
|
||||
## Configurable Settings for Heatmap Generation
|
||||
|
||||
While the content of the heatmaps is primarily determined by the previous configuration steps, a few parameters specific to the heatmap drawing process itself can be adjusted.
|
||||
These are usually set in your main configuration file under a section like `labelling`:
|
||||
|
||||
- `sigma`: (Number, e.g., `3.0`) Defines the standard deviation, in pixels or bins, of the Gaussian kernel used for smoothing the Detection and Class heatmaps.
|
||||
Larger values result in more diffused heatmap peaks.
|
||||
- `position`: (Text, e.g., `"bottom-left"`, `"center"`) Specifies the geometric reference point within each sound event's bounding box that anchors its representation on the heatmaps.
|
||||
- `time_scale` & `frequency_scale`: (Numbers) These crucial scaling factors convert the physical duration (in seconds) and frequency bandwidth (in Hz) of annotation bounding boxes into the numerical values stored in the 'width' and 'height' channels of the Size Heatmap.
|
||||
- **Important Note:** The appropriate values for these scales are dictated by the requirements of the specific BatDetect2 model architecture being trained.
|
||||
They ensure the size information is presented in the units or relative scale the model expects.
|
||||
**Consult the documentation or tutorials for your specific model to determine the correct `time_scale` and `frequency_scale` values.** Mismatched scales can hinder the model's ability to learn size regression accurately.
|
||||
|
||||
**Example YAML Configuration for Labelling Settings:**
|
||||
|
||||
```yaml
|
||||
# In your main configuration file
|
||||
labelling:
|
||||
sigma: 3.0 # Std. dev. for Gaussian smoothing (pixels/bins)
|
||||
position: "bottom-left" # Bounding box reference point
|
||||
time_scale: 1000.0 # Example: Scales seconds to milliseconds
|
||||
frequency_scale: 0.00116 # Example: Scales Hz relative to ~860 Hz (model specific!)
|
||||
```
|
||||
|
||||
## Outcome: Final Training Targets
|
||||
|
||||
Executing this step for all training data yields the complete set of target heatmaps (Detection, Class, Size) for each corresponding input spectrogram.
|
||||
These arrays constitute the ground truth data that the BatDetect2 model directly compares its predictions against during the training phase, guiding its learning process.
|
85
docs/source/targets/rois.md
Normal file
85
docs/source/targets/rois.md
Normal file
@ -0,0 +1,85 @@
|
||||
## Defining Target Geometry: Mapping Sound Event Regions
|
||||
|
||||
### Introduction
|
||||
|
||||
In the previous steps of defining targets, we focused on determining _which_ sound events are relevant (`filtering`), _what_ descriptive tags they should have (`transform`), and _which category_ they belong to (`classes`).
|
||||
However, for the model to learn effectively, it also needs to know **where** in the spectrogram each sound event is located and approximately **how large** it is.
|
||||
|
||||
Your annotations typically define the location and extent of a sound event using a **Region of Interest (ROI)**, most commonly a **bounding box** drawn around the call on the spectrogram.
|
||||
This ROI contains detailed spatial information (start/end time, low/high frequency).
|
||||
|
||||
This section explains how BatDetect2 converts the geometric ROI from your annotations into the specific positional and size information used as targets during model training.
|
||||
|
||||
### From ROI to Model Targets: Position & Size
|
||||
|
||||
BatDetect2 does not directly predict a full bounding box.
|
||||
Instead, it is trained to predict:
|
||||
|
||||
1. **A Reference Point:** A single point `(time, frequency)` that represents the primary location of the detected sound event within the spectrogram.
|
||||
2. **Size Dimensions:** Numerical values representing the event's size relative to that reference point, typically its `width` (duration in time) and `height` (bandwidth in frequency).
|
||||
|
||||
This step defines _how_ BatDetect2 calculates this specific reference point and these numerical size values from the original annotation's bounding box.
|
||||
It also handles the reverse process – converting predicted positions and sizes back into bounding boxes for visualization or analysis.
|
||||
|
||||
### Configuring the ROI Mapping
|
||||
|
||||
You can control how this conversion happens through settings in your configuration file (e.g., your main `.yaml` file).
|
||||
These settings are usually placed within the main `targets:` configuration block, under a specific `roi:` key.
|
||||
|
||||
Here are the key settings:
|
||||
|
||||
- **`position`**:
|
||||
|
||||
- **What it does:** Determines which specific point on the annotation's bounding box is used as the single **Reference Point** for training (e.g., `"center"`, `"bottom-left"`).
|
||||
- **Why configure it?** This affects where the peak signal appears in the target heatmaps used for training.
|
||||
Different choices might slightly influence model learning.
|
||||
The default (`"bottom-left"`) is often a good starting point.
|
||||
- **Example Value:** `position: "center"`
|
||||
|
||||
- **`time_scale`**:
|
||||
|
||||
- **What it does:** This is a numerical scaling factor that converts the _actual duration_ (width, measured in seconds) of the bounding box into the numerical 'width' value the model learns to predict (and which is stored in the Size Heatmap).
|
||||
- **Why configure it?** The model predicts raw numbers for size; this scale gives those numbers meaning.
|
||||
For example, setting `time_scale: 1000.0` means the model will be trained to predict the duration in **milliseconds** instead of seconds.
|
||||
- **Important Considerations:**
|
||||
- You can often set this value based on the units you prefer the model to work with internally.
|
||||
However, having target numerical values roughly centered around 1 (e.g., typically between 0.1 and 10) can sometimes improve numerical stability during model training.
|
||||
- The default value in BatDetect2 (e.g., `1000.0`) has been chosen to scale the duration relative to the spectrogram width under default STFT settings.
|
||||
Be aware that if you significantly change STFT parameters (window size or overlap), the relationship between the default scale and the spectrogram dimensions might change.
|
||||
- Crucially, whatever scale you use during training **must** be used when decoding the model's predictions back into real-world time units (seconds).
|
||||
BatDetect2 generally handles this consistency for you automatically when using the full pipeline.
|
||||
- **Example Value:** `time_scale: 1000.0`
|
||||
|
||||
- **`frequency_scale`**:
|
||||
- **What it does:** Similar to `time_scale`, this numerical scaling factor converts the _actual frequency bandwidth_ (height, typically measured in Hz or kHz) of the bounding box into the numerical 'height' value the model learns to predict.
|
||||
- **Why configure it?** It gives physical meaning to the model's raw numerical prediction for bandwidth and allows you to choose the internal units or scale.
|
||||
- **Important Considerations:**
|
||||
- Same as for `time_scale`.
|
||||
- **Example Value:** `frequency_scale: 0.00116`
|
||||
|
||||
**Example YAML Configuration:**
|
||||
|
||||
```yaml
|
||||
# Inside your main configuration file (e.g., training_config.yaml)
|
||||
|
||||
targets: # Top-level key for target definition
|
||||
# ... filtering settings ...
|
||||
# ... transforms settings ...
|
||||
# ... classes settings ...
|
||||
|
||||
# --- ROI Mapping Settings ---
|
||||
roi:
|
||||
position: "bottom-left" # Reference point (e.g., "center", "bottom-left")
|
||||
time_scale: 1000.0 # e.g., Model predicts width in ms
|
||||
frequency_scale: 0.00116 # e.g., Model predicts height relative to ~860Hz (or other model-specific scaling)
|
||||
```
|
||||
|
||||
### Decoding Size Predictions
|
||||
|
||||
These scaling factors (`time_scale`, `frequency_scale`) are also essential for interpreting the model's output correctly.
|
||||
When the model predicts numerical values for width and height, BatDetect2 uses these same scales (in reverse) to convert those numbers back into physically meaningful durations (seconds) and bandwidths (Hz/kHz) when reconstructing bounding boxes from predictions.
|
||||
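As a concrete illustration, the round trip with the example scales from the configuration above works out as follows (plain arithmetic, not actual `batdetect2` code):

```python
# Illustrative round trip for the size values.
time_scale = 1000.0        # from the example configuration above
frequency_scale = 0.00116

# Encoding (annotation geometry -> training target values):
duration_s = 0.005                               # a 5 ms call
bandwidth_hz = 30_000                            # 30 kHz of bandwidth
width_target = duration_s * time_scale           # 5.0
height_target = bandwidth_hz * frequency_scale   # 34.8

# Decoding (model prediction -> real-world units) applies the same scales in reverse:
recovered_duration_s = width_target / time_scale            # 0.005 s
recovered_bandwidth_hz = height_target / frequency_scale    # 30000.0 Hz
```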
|
||||
### Outcome
|
||||
|
||||
By configuring the `roi` settings, you ensure that BatDetect2 consistently translates the geometric information from your annotations into the specific reference points and scaled size values required for training the model.
|
||||
Using consistent scales that are appropriate for your data and potentially beneficial for training stability allows the model to effectively learn not just _what_ sound is present, but also _where_ it is located and _how large_ it is, and enables meaningful interpretation of the model's spatial and size predictions.
|
166
docs/source/targets/tags_and_terms.md
Normal file
166
docs/source/targets/tags_and_terms.md
Normal file
@ -0,0 +1,166 @@
|
||||
# Step 1: Managing Annotation Vocabulary
|
||||
|
||||
## Purpose
|
||||
|
||||
To train `batdetect2`, you will need sound events that have been carefully annotated. We annotate sound events using **tags**. A tag is simply a piece of information attached to an annotation, often describing what the sound is or its characteristics. Common examples include `Species: Myotis daubentonii` or `Quality: Good`.
|
||||
|
||||
Each tag fundamentally has two parts:
|
||||
|
||||
* **Value:** The specific information (e.g., "Myotis daubentonii", "Good").
|
||||
* **Term:** The *type* of information (e.g., "Species", "Quality"). This defines the context or meaning of the value.
|
||||
|
||||
We use this flexible **Term: Value** approach because it allows you to annotate your data with any kind of information relevant to your project, while still providing a structure that makes the meaning clear.
|
||||
|
||||
While simple terms like "Species" are easy to understand, sometimes the underlying definition needs to be more precise to ensure everyone interprets it the same way (e.g., using a standard scientific definition for "Species" or clarifying what "Call Type" specifically refers to).
|
||||
|
||||
This `terms` module is designed to help manage these definitions effectively:
|
||||
|
||||
1. It provides **standard definitions** for common terms used in bioacoustics, ensuring consistency.
|
||||
2. It lets you **define your own custom terms** if you need concepts specific to your project.
|
||||
3. Crucially, it allows you to use simple **"keys"** (like shortcuts) in your configuration files to refer to these potentially complex term definitions, making configuration much easier and less error-prone.
|
||||
|
||||
## The Problem: Why We Need Defined Terms
|
||||
|
||||
Imagine you have a tag that simply says `"Myomyo"`.
|
||||
If you created this tag, you might know it's a shortcut for the species _Myotis myotis_.
|
||||
But what about someone else using your data or model later? Does `"Myomyo"` refer to the species? Or maybe it's the name of an individual bat, or even the location where it was recorded? Simple tags like this can be ambiguous.
|
||||
|
||||
To make things clearer, it's good practice to provide context.
|
||||
We can do this by pairing the specific information (the **Value**) with the _type_ of information (the **Term**).
|
||||
For example, writing the tag as `species: Myomyo` is much less ambiguous.
|
||||
Here, `species` is the **Term**, explaining that `Myomyo` is a **Value** representing a species.
|
||||
|
||||
However, another challenge often comes up when sharing data or collaborating.
|
||||
You might use the term `species`, while a colleague uses `Species`, and someone else uses the more formal `Scientific Name`.
|
||||
Even though you all mean the same thing, these inconsistencies make it hard to combine data or reuse analysis pipelines automatically.
|
||||
|
||||
This is where standardized **Terms** become very helpful.
|
||||
Several groups work to create standard definitions for common concepts.
|
||||
For instance, the Darwin Core standard provides widely accepted terms for biological data, like `dwc:scientificName` for a species name.
|
||||
Using standard Terms whenever possible makes your data clearer, easier for others (and machines!) to understand correctly, and much more reusable across different projects.
|
||||
|
||||
**But here's the practical problem:** While using standard, well-defined Terms is important for clarity and reusability, writing out full definitions or long standard names (like `dwc:scientificName` or "Scientific Name according to Darwin Core standard") every single time you need to refer to a species tag in a configuration file would be extremely tedious and prone to typing errors.
|
||||
|
||||
## The Solution: Keys (Shortcuts) and the Registry
|
||||
|
||||
This module uses a central **Registry** that stores the full definitions of various Terms.
|
||||
Each Term in the registry is assigned a unique, short **key** (a simple string).
|
||||
|
||||
Think of the **key** as a shortcut.
|
||||
|
||||
Instead of using the full Term definition in your configuration files, you just use its **key**.
|
||||
The system automatically looks up the full definition in the registry using the key when needed.
|
||||
|
||||
**Example:**
|
||||
|
||||
- **Full Term Definition:** Represents the scientific name of the organism.
|
||||
- **Key:** `species`
|
||||
- **In Config:** You just write `species`.
|
||||
|
||||
## Available Keys
|
||||
|
||||
The registry comes pre-loaded with keys for many standard terms used in bioacoustics, including those from the `soundevent` package and some specific to `batdetect2`. This means you can often use these common concepts without needing to define them yourself.
|
||||
|
||||
Common examples of pre-defined keys might include:
|
||||
|
||||
* `species`: For scientific species names (e.g., *Myotis daubentonii*).
|
||||
* `common_name`: For the common name of a species (e.g., "Daubenton's bat").
|
||||
* `genus`, `family`, `order`: For higher levels of biological taxonomy.
|
||||
* `call_type`: For functional call types (e.g., 'Echolocation', 'Social').
|
||||
* `individual`: For identifying specific individuals if tracked.
|
||||
* `class`: **(Special Key)** This key is often used **by default** in configurations when defining the target classes for your model (e.g., the different species you want the model to classify). If you are specifying a tag that represents a target class label, you often only need to provide the `value`, and the system assumes the `key` is `class`.
|
||||
|
||||
This is not an exhaustive list. To discover all the term keys currently available in the registry (including any standard ones loaded automatically and any custom ones you've added in your configuration), you can:
|
||||
|
||||
1. Use the function `batdetect2.terms.get_term_keys()` if you are working directly with Python code.
|
||||
2. Refer to the main `batdetect2` API documentation for a list of commonly included standard terms.
|
||||
|
||||
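For instance, a short snippet like the one below (a sketch, assuming `batdetect2` is installed in your Python environment) prints every registered key:

```python
# Sketch: list all term keys currently available in the registry.
from batdetect2.terms import get_term_keys

for key in get_term_keys():
    print(key)  # e.g. 'species', 'call_type', 'class', plus any custom keys you defined
```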
|
||||
|
||||
## Defining Your Own Terms
|
||||
|
||||
While many common terms have pre-defined keys, you might need a term specific to your project or data that isn't already available (e.g., "Recording Setup", "Weather Condition", "Project Phase", "Noise Source"). You can easily define these custom terms directly within a configuration file (usually your main `.yaml` file).
|
||||
|
||||
Typically, you define custom terms under a dedicated section (often named `terms`). Inside this section, you create a list, where each item in the list defines one new term using the following fields:
|
||||
|
||||
* `key`: **(Required)** This is the unique shortcut key or nickname you will use to refer to this term throughout your configuration (e.g., `weather`, `setup_id`, `noise_src`). Choose something short and memorable.
|
||||
* `label`: (Optional) A user-friendly label for the term, which might be used in reports or visualizations (e.g., "Weather Condition", "Setup ID"). If you don't provide one, it defaults to using the `key`.
|
||||
* `name`: (Optional) A more formal or technical name for the term.
|
||||
* It's good practice, especially if defining terms that might overlap with standard vocabularies, to use a **namespaced format** like `<namespace>:<term_name>`. The `namespace` part helps avoid clashes with terms defined elsewhere. For example, the standard Darwin Core term for scientific name is `dwc:scientificName`, where `dwc` is the namespace for Darwin Core. Using namespaces makes your custom terms more specific and reduces potential confusion.
|
||||
* If you don't provide a `name`, it defaults to using the `key`.
|
||||
* `definition`: (Optional) A brief text description explaining what this term represents (e.g., "The primary source of background noise identified", "General weather conditions during recording"). If omitted, it defaults to "Unknown".
|
||||
* `uri`: (Optional) If your term definition comes directly from a standard online vocabulary (like Darwin Core), you can include its unique web identifier (URI) here.
|
||||
|
||||
**Example YAML Configuration for Custom Terms:**
|
||||
|
||||
```yaml
|
||||
# In your main configuration file
|
||||
|
||||
# (Optional section to define custom terms)
|
||||
terms:
|
||||
- key: weather # Your chosen shortcut
|
||||
label: Weather Condition
|
||||
name: myproj:weather # Formal namespaced name
|
||||
definition: General weather conditions during recording (e.g., Clear, Rain, Fog).
|
||||
|
||||
- key: setup_id # Another shortcut
|
||||
label: Recording Setup ID
|
||||
name: myproj:setupID # Formal namespaced name
|
||||
definition: The unique identifier for the specific hardware setup used.
|
||||
|
||||
- key: species # Defining a term with a standard URI
|
||||
label: Scientific Name
|
||||
name: dwc:scientificName
|
||||
uri: http://rs.tdwg.org/dwc/terms/scientificName # Example URI
|
||||
definition: The full scientific name according to Darwin Core.
|
||||
|
||||
# ... other configuration sections ...
|
||||
```
|
||||
|
||||
When `batdetect2` loads your configuration, it reads this `terms` section and adds your custom definitions (linked to their unique keys) to the central registry. These keys (`weather`, `setup_id`, etc.) are then ready to be used in other parts of your configuration, like defining filters or target classes.
|
||||
|
||||
## Using Keys to Specify Tags (in Filters, Class Definitions, etc.)
|
||||
|
||||
Now that you have keys for all the terms you need (both pre-defined and custom), you can easily refer to specific **tags** in other parts of your configuration, such as:
|
||||
|
||||
- Filtering rules (as seen in the `filtering` module documentation).
|
||||
- Defining which tags represent your target classes.
|
||||
- Associating extra information with your classes.
|
||||
|
||||
When you need to specify a tag, you typically use a structure with two fields:
|
||||
|
||||
- `key`: The **key** (shortcut) for the _Term_ part of the tag (e.g., `species`, `quality`, `weather`).
|
||||
**It defaults to `class`** if you omit it, which is common when defining the main target classes.
|
||||
- `value`: The specific _value_ of the tag (e.g., `Myotis daubentonii`, `Good`, `Rain`).
|
||||
|
||||
**Example YAML Configuration using TagInfo (e.g., inside a filter rule):**
|
||||
|
||||
```yaml
|
||||
# ... inside a filtering configuration section ...
|
||||
rules:
|
||||
# Rule: Exclude events recorded in 'Rain'
|
||||
- match_type: exclude
|
||||
tags:
|
||||
- key: weather # Use the custom term key defined earlier
|
||||
value: Rain
|
||||
# Rule: Keep only 'Myotis daubentonii' (using the default 'class' key implicitly)
|
||||
- match_type: any # Or 'all' depending on logic
|
||||
tags:
|
||||
- value: Myotis daubentonii # 'key: class' is assumed by default here
|
||||
# key: class # Explicitly writing this is also fine
|
||||
# Rule: Keep only 'Good' quality events
|
||||
- match_type: any # Or 'all' depending on logic
|
||||
tags:
|
||||
- key: quality # Use a likely pre-defined key
|
||||
value: Good
|
||||
```
|
||||
|
||||
## Summary
|
||||
|
||||
- Annotations have **tags** (Term + Value).
|
||||
- This module uses short **keys** as shortcuts for Term definitions, stored in a **registry**.
|
||||
- Many **common keys are pre-defined**.
|
||||
- You can define **custom terms and keys** in your configuration file (using `key`, `label`, `definition`).
|
||||
- You use these **keys** along with specific **values** to refer to tags in other configuration sections (like filters or class definitions), often defaulting to the `class` key.
|
||||
|
||||
This system makes your configurations cleaner, more readable, and less prone to errors by avoiding repetition of complex term definitions.
|
118
docs/source/targets/transform.md
Normal file
118
docs/source/targets/transform.md
Normal file
@ -0,0 +1,118 @@
|
||||
# Step 3: Transforming Annotation Tags (Optional)
|
||||
|
||||
## Purpose and Context
|
||||
|
||||
After defining your vocabulary (Step 1: Terms) and filtering out irrelevant sound events (Step 2: Filtering), you have a dataset of annotations ready for the next stages.
|
||||
Before you select the final target classes for training (Step 4), you might want or need to **modify the tags** associated with your annotations.
|
||||
This optional step allows you to clean up, standardize, or derive new information from your existing tags.
|
||||
|
||||
**Why transform tags?**
|
||||
|
||||
- **Correcting Mistakes:** Fix typos or incorrect values in specific tags (e.g., changing an incorrect species label).
|
||||
- **Standardizing Labels:** Ensure consistency if the same information was tagged using slightly different values (e.g., mapping "echolocation", "Echoloc.", and "Echolocation Call" all to a single standard value: "Echolocation").
|
||||
- **Grouping Related Concepts:** Combine different specific tags into a broader category (e.g., mapping several different species tags like _Myotis daubentonii_ and _Myotis nattereri_ to a single `genus: Myotis` tag).
|
||||
- **Deriving New Information:** Automatically create new tags based on existing ones (e.g., automatically generating a `genus: Myotis` tag whenever a `species: Myotis daubentonii` tag is present).
|
||||
|
||||
This step uses the `batdetect2.targets.transform` module to apply these changes based on rules you define.
|
||||
|
||||
## How it Works: Transformation Rules
|
||||
|
||||
You control how tags are transformed by defining a list of **rules** in your configuration file (e.g., your main `.yaml` file, often under a section named `transform`).
|
||||
|
||||
Each rule specifies a particular type of transformation to perform.
|
||||
Importantly, the rules are applied **sequentially**, in the exact order they appear in your configuration list.
|
||||
The output annotation from one rule becomes the input for the next rule in the list.
|
||||
This means the order can matter!
|
||||
|
||||
## Types of Transformation Rules
|
||||
|
||||
Here are the main types of rules you can define:
|
||||
|
||||
1. **Replace an Exact Tag (`replace`)**
|
||||
|
||||
- **Use Case:** Fixing a specific, known incorrect tag.
|
||||
- **How it works:** You specify the _exact_ original tag (both its term key and value) and the _exact_ tag you want to replace it with.
|
||||
- **Example Config:** Replace the informal tag `species: Pip pip` with the correct scientific name tag.
|
||||
```yaml
|
||||
transform:
|
||||
rules:
|
||||
- rule_type: replace
|
||||
original:
|
||||
key: species # Term key of the tag to find
|
||||
value: "Pip pip" # Value of the tag to find
|
||||
replacement:
|
||||
key: species # Term key of the replacement tag
|
||||
value: "Pipistrellus pipistrellus" # Value of the replacement tag
|
||||
```
|
||||
|
||||
2. **Map Values (`map_value`)**
|
||||
|
||||
- **Use Case:** Standardizing different values used for the same concept, or grouping multiple specific values into one category.
|
||||
- **How it works:** You specify a `source_term_key` (the type of tag to look at, e.g., `call_type`).
|
||||
Then you provide a `value_mapping` dictionary listing original values and the new values they should be mapped to.
|
||||
Only tags matching the `source_term_key` and having a value listed in the mapping will be changed.
|
||||
You can optionally specify a `target_term_key` if you want to change the term type as well (e.g., mapping species to a genus).
|
||||
- **Example Config:** Standardize different ways "Echolocation" might have been written for the `call_type` term.
|
||||
```yaml
|
||||
transform:
|
||||
rules:
|
||||
- rule_type: map_value
|
||||
source_term_key: call_type # Look at 'call_type' tags
|
||||
# target_term_key is not specified, so the term stays 'call_type'
|
||||
value_mapping:
|
||||
echolocation: Echolocation
|
||||
Echolocation Call: Echolocation
|
||||
Echoloc.: Echolocation
|
||||
# Add mappings for other values like 'Social' if needed
|
||||
```
|
||||
- **Example Config (Grouping):** Map specific Pipistrellus species tags to a single `genus: Pipistrellus` tag.
|
||||
```yaml
|
||||
transform:
|
||||
rules:
|
||||
- rule_type: map_value
|
||||
source_term_key: species # Look at 'species' tags
|
||||
target_term_key: genus # Change the term to 'genus'
|
||||
value_mapping:
|
||||
"Pipistrellus pipistrellus": Pipistrellus
|
||||
"Pipistrellus pygmaeus": Pipistrellus
|
||||
"Pipistrellus nathusii": Pipistrellus
|
||||
```
|
||||
|
||||
3. **Derive a New Tag (`derive_tag`)**
|
||||
- **Use Case:** Automatically creating new information based on existing tags, like getting the genus from a species name.
|
||||
- **How it works:** You specify a `source_term_key` (e.g., `species`).
|
||||
You provide a `target_term_key` for the new tag to be created (e.g., `genus`).
|
||||
You also provide the name of a `derivation_function` (e.g., `"extract_genus"`) that knows how to perform the calculation (e.g., take "Myotis daubentonii" and return "Myotis").
|
||||
`batdetect2` has some built-in functions, or you can potentially define your own (see advanced documentation).
|
||||
You can also choose whether to keep the original source tag (`keep_source: true`).
|
||||
- **Example Config:** Create a `genus` tag from the existing `species` tag, keeping the species tag.
|
||||
```yaml
|
||||
transform:
|
||||
rules:
|
||||
- rule_type: derive_tag
|
||||
source_term_key: species # Use the value from the 'species' tag
|
||||
target_term_key: genus # Create a tag with the 'genus' term
|
||||
derivation_function: extract_genus # Use the built-in function for this
|
||||
keep_source: true # Keep the original 'species' tag
|
||||
```
|
||||
- **Another Example:** Convert species names to uppercase (modifying the value of the _same_ term).
|
||||
```yaml
|
||||
transform:
|
||||
rules:
|
||||
- rule_type: derive_tag
|
||||
source_term_key: species # Use the value from the 'species' tag
|
||||
# target_term_key is not specified, so the term stays 'species'
|
||||
derivation_function: to_upper_case # Assume this function exists
|
||||
keep_source: false # Replace the original species tag
|
||||
```
|
||||
|
||||
## Rule Order Matters
|
||||
|
||||
Remember that rules are applied one after another.
|
||||
If you have multiple rules, make sure they are ordered correctly to achieve the desired outcome.
|
||||
For instance, you might want to standardize species names _before_ deriving the genus from them.
|
||||
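For instance, combining the rule types shown above, a configuration that first standardizes species values and only then derives the genus could look like this (the order matters, because the second rule relies on the standardized species names):

```yaml
transform:
  rules:
    # 1. Standardize informal species values first.
    - rule_type: map_value
      source_term_key: species
      value_mapping:
        "Pip pip": "Pipistrellus pipistrellus"
        "Pip pyg": "Pipistrellus pygmaeus"
    # 2. Then derive the genus from the (now standardized) species tag.
    - rule_type: derive_tag
      source_term_key: species
      target_term_key: genus
      derivation_function: extract_genus
      keep_source: true
```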
|
||||
## Outcome
|
||||
|
||||
After applying all the transformation rules you've defined, the annotations will proceed to the next step (Step 4: Select Target Tags & Define Classes) with their tags potentially cleaned, standardized, or augmented based on your configuration.
|
||||
If you don't define any rules, the tags simply pass through this step unchanged.
|
91
docs/source/targets/use.md
Normal file
91
docs/source/targets/use.md
Normal file
@ -0,0 +1,91 @@
|
||||
## Bringing It All Together: The `Targets` Object
|
||||
|
||||
### Recap: Defining Your Target Strategy
|
||||
|
||||
In the previous sections, we covered the sequential steps to precisely define what your BatDetect2 model should learn, specified within your configuration file:
|
||||
|
||||
1. **Terms:** Establishing the vocabulary for annotation tags.
|
||||
2. **Filtering:** Selecting relevant sound event annotations.
|
||||
3. **Transforming:** Optionally modifying tags.
|
||||
4. **Classes:** Defining target categories, setting priorities, and specifying tag decoding rules.
|
||||
5. **ROI Mapping:** Defining how annotation geometry maps to target position and size values.
|
||||
|
||||
You define all these aspects within your configuration file (e.g., YAML), which holds the complete specification for your target definition strategy, typically under a main `targets:` key.
|
||||
|
||||
### What is the `Targets` Object?
|
||||
|
||||
While the configuration file specifies _what_ you want to happen, BatDetect2 needs an active component to actually _perform_ these steps.
|
||||
This is the role of the `Targets` object.
|
||||
|
||||
The `Targets` object is an organized container that holds all the specific functions and settings derived from your configuration file (`TargetConfig`).
|
||||
It's created directly from your configuration and provides methods to apply the **filtering**, **transformation**, **ROI mapping** (geometry to position/size and back), **class encoding**, and **class decoding** steps you defined.
|
||||
It effectively bundles together all the target definition logic determined by your settings into a single, usable object.
|
||||
|
||||
### How is it Created and Used?
|
||||
|
||||
For most standard training workflows, you typically won't need to create or interact with the `Targets` object directly in Python code.
|
||||
BatDetect2 usually handles its creation automatically when you provide your main configuration file during training setup.
|
||||
|
||||
Conceptually, here's what happens behind the scenes:
|
||||
|
||||
1. You provide the path to your configuration file (e.g., `my_training_config.yaml`).
|
||||
2. BatDetect2 reads this file and finds your `targets:` configuration section.
|
||||
3. It uses this configuration to build an instance of the `Targets` object using a dedicated function (like `load_targets`), loading it with the appropriate logic based on your settings.
|
||||
|
||||
```python
|
||||
# Conceptual Example: How BatDetect2 might use your configuration
|
||||
from batdetect2.targets import load_targets # The function to load/build the object
|
||||
from batdetect2.targets.types import TargetProtocol # The type/interface
|
||||
|
||||
# You provide this path, usually as part of the main training setup
|
||||
target_config_file = "path/to/your/target_config.yaml"
|
||||
|
||||
# --- BatDetect2 Internally Does Something Like This: ---
|
||||
# Loads your config and builds the Targets object using the loader function
|
||||
# The resulting object adheres to the TargetProtocol interface
|
||||
targets_processor: TargetProtocol = load_targets(target_config_file)
|
||||
# ---------------------------------------------------------
|
||||
|
||||
# Now, 'targets_processor' holds all your configured logic and is ready
|
||||
# to be used internally by the training pipeline or for prediction processing.
|
||||
```
|
||||
|
||||
### What Does the `Targets` Object Do? (Its Role)
|
||||
|
||||
Once created, the `targets_processor` object plays several vital roles within the BatDetect2 system:
|
||||
|
||||
1. **Preparing Training Data:** During the data loading and label generation phase of training, BatDetect2 uses this object to process each annotation from your dataset _before_ the final training format (e.g., heatmaps) is generated.
|
||||
For each annotation, it internally applies the logic:
|
||||
- `targets_processor.filter(...)`: To decide whether to keep the annotation.
|
||||
- `targets_processor.transform(...)`: To apply any tag modifications.
|
||||
- `targets_processor.encode(...)`: To get the final class name (e.g., `'pippip'`, `'myodau'`, or `None` for the generic class).
|
||||
- `targets_processor.get_position(...)`: To determine the reference `(time, frequency)` point from the annotation's geometry.
|
||||
- `targets_processor.get_size(...)`: To calculate the _scaled_ width and height target values from the annotation's geometry.
|
||||
2. **Interpreting Model Predictions:** When you use a trained model, its raw outputs (like predicted class names, positions, and sizes) need to be translated back into meaningful results.
|
||||
This object provides the necessary decoding logic:
|
||||
- `targets_processor.decode(...)`: Converts a predicted class name back into representative annotation tags.
|
||||
- `targets_processor.recover_roi(...)`: Converts a predicted position and _scaled_ size values back into an estimated geometric bounding box in real-world coordinates (seconds, Hz).
|
||||
- `targets_processor.generic_class_tags`: Provides the tags for sounds classified into the generic category.
|
||||
3. **Providing Metadata:** It conveniently holds useful information derived from your configuration:
|
||||
- `targets_processor.class_names`: The final list of specific target class names.
|
||||
- `targets_processor.generic_class_tags`: The tags representing the generic class.
|
||||
- `targets_processor.dimension_names`: The names used for the size dimensions (e.g., `['width', 'height']`).
|
||||
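Conceptually, and simplifying heavily, the label-preparation loop uses the object roughly as sketched below. Only the attribute and method names come from the interface described above; the loop structure and variable names are illustrative:

```python
# Conceptual sketch of how the configured Targets object is applied to
# annotations during label preparation (not the actual internal code).
for annotation in clip_annotations:  # hypothetical list of sound event annotations
    if not targets_processor.filter(annotation):
        continue  # drop annotations that fail the filtering rules

    annotation = targets_processor.transform(annotation)  # tag clean-up / derivation

    class_name = targets_processor.encode(annotation)      # e.g. 'pippip', or None for generic
    position = targets_processor.get_position(annotation)  # (time, frequency) reference point
    dims = targets_processor.get_size(annotation)          # scaled width / height values

    # ... place these values onto the training heatmaps ...
```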
|
||||
### Why is Understanding This Important?
|
||||
|
||||
As a researcher using BatDetect2, your primary interaction is typically through the **configuration file**.
|
||||
The `Targets` object is the component that materializes your configurations.
|
||||
|
||||
Understanding its role can be important:
|
||||
|
||||
- It helps connect the settings in your configuration file (covering terms, filtering, transforms, classes, and ROIs) to the actual behavior observed during training or when interpreting model outputs.
|
||||
If the results aren't as expected (e.g., wrong classifications, incorrect bounding box predictions), reviewing the relevant sections of your `TargetConfig` is the first step in debugging.
|
||||
- Furthermore, understanding this structure is beneficial if you plan to create custom Python scripts.
|
||||
While standard training runs handle this object internally, the underlying functions for filtering, transforming, encoding, decoding, and ROI mapping are accessible or can be built individually.
|
||||
This modular design provides the **flexibility to use or customize specific parts of the target definition workflow programmatically** for advanced analyses, integration tasks, or specialized data processing pipelines, should you need to go beyond the standard configuration-driven approach.
|
||||
|
||||
### Summary
|
||||
|
||||
The `Targets` object encapsulates the entire configured target definition logic specified in your `TargetConfig` file.
|
||||
It acts as the central component within BatDetect2 for applying filtering, tag transformation, ROI mapping (geometry to/from position/size), class encoding (for training preparation), and class/ROI decoding (for interpreting predictions).
|
||||
It bridges the gap between your declarative configuration and the functional steps needed for training and using BatDetect2 models effectively, while also offering components for more advanced, scripted workflows.
|
10
example_conf.yaml
Normal file
10
example_conf.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
datasets:
|
||||
train:
|
||||
name: example dataset
|
||||
description: Only for demonstration purposes
|
||||
sources:
|
||||
- format: batdetect2
|
||||
name: Example Data
|
||||
description: Examples included for testing batdetect2
|
||||
annotations_dir: example_data/anns
|
||||
audio_dir: example_data/audio
|
177
example_data/anns/20170701_213954-MYOMYS-LR_0_0.5.wav.json
Normal file
177
example_data/anns/20170701_213954-MYOMYS-LR_0_0.5.wav.json
Normal file
@ -0,0 +1,177 @@
|
||||
{
|
||||
"annotated": true,
|
||||
"annotation": [
|
||||
{
|
||||
"class": "Myotis mystacinus",
|
||||
"class_prob": 0.55,
|
||||
"det_prob": 0.658,
|
||||
"end_time": 0.028,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 107492,
|
||||
"individual": "-1",
|
||||
"low_freq": 33203,
|
||||
"start_time": 0.0225
|
||||
},
|
||||
{
|
||||
"class": "Myotis mystacinus",
|
||||
"class_prob": 0.679,
|
||||
"det_prob": 0.742,
|
||||
"end_time": 0.0583,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 113192,
|
||||
"individual": "-1",
|
||||
"low_freq": 28046,
|
||||
"start_time": 0.0525
|
||||
},
|
||||
{
|
||||
"class": "Myotis mystacinus",
|
||||
"class_prob": 0.488,
|
||||
"det_prob": 0.585,
|
||||
"end_time": 0.1211,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 107008,
|
||||
"individual": "-1",
|
||||
"low_freq": 33203,
|
||||
"start_time": 0.1155
|
||||
},
|
||||
{
|
||||
"class": "Pipistrellus pipistrellus",
|
||||
"class_prob": 0.46,
|
||||
"det_prob": 0.503,
|
||||
"end_time": 0.145,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 59621,
|
||||
"individual": "-1",
|
||||
"low_freq": 48671,
|
||||
"start_time": 0.1385
|
||||
},
|
||||
{
|
||||
"class": "Myotis mystacinus",
|
||||
"class_prob": 0.656,
|
||||
"det_prob": 0.704,
|
||||
"end_time": 0.1513,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 113493,
|
||||
"individual": "-1",
|
||||
"low_freq": 27187,
|
||||
"start_time": 0.1445
|
||||
},
|
||||
{
|
||||
"class": "Myotis mystacinus",
|
||||
"class_prob": 0.549,
|
||||
"det_prob": 0.63,
|
||||
"end_time": 0.2076,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 108573,
|
||||
"individual": "-1",
|
||||
"low_freq": 34062,
|
||||
"start_time": 0.2025
|
||||
},
|
||||
{
|
||||
"class": "Pipistrellus pipistrellus",
|
||||
"class_prob": 0.503,
|
||||
"det_prob": 0.528,
|
||||
"end_time": 0.224,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 57361,
|
||||
"individual": "-1",
|
||||
"low_freq": 48671,
|
||||
"start_time": 0.2195
|
||||
},
|
||||
{
|
||||
"class": "Myotis mystacinus",
|
||||
"class_prob": 0.672,
|
||||
"det_prob": 0.737,
|
||||
"end_time": 0.2374,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 116415,
|
||||
"individual": "-1",
|
||||
"low_freq": 27187,
|
||||
"start_time": 0.2315
|
||||
},
|
||||
{
|
||||
"class": "Pipistrellus pipistrellus",
|
||||
"class_prob": 0.65,
|
||||
"det_prob": 0.736,
|
||||
"end_time": 0.3058,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 56624,
|
||||
"individual": "-1",
|
||||
"low_freq": 48671,
|
||||
"start_time": 0.2995
|
||||
},
|
||||
{
|
||||
"class": "Myotis mystacinus",
|
||||
"class_prob": 0.687,
|
||||
"det_prob": 0.724,
|
||||
"end_time": 0.3312,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 116522,
|
||||
"individual": "-1",
|
||||
"low_freq": 27187,
|
||||
"start_time": 0.3245
|
||||
},
|
||||
{
|
||||
"class": "Myotis mystacinus",
|
||||
"class_prob": 0.547,
|
||||
"det_prob": 0.599,
|
||||
"end_time": 0.3762,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 108530,
|
||||
"individual": "-1",
|
||||
"low_freq": 34062,
|
||||
"start_time": 0.3705
|
||||
},
|
||||
{
|
||||
"class": "Myotis mystacinus",
|
||||
"class_prob": 0.664,
|
||||
"det_prob": 0.711,
|
||||
"end_time": 0.4184,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 115775,
|
||||
"individual": "-1",
|
||||
"low_freq": 28906,
|
||||
"start_time": 0.4125
|
||||
},
|
||||
{
|
||||
"class": "Myotis mystacinus",
|
||||
"class_prob": 0.544,
|
||||
"det_prob": 0.598,
|
||||
"end_time": 0.4423,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 104197,
|
||||
"individual": "-1",
|
||||
"low_freq": 36640,
|
||||
"start_time": 0.4365
|
||||
},
|
||||
{
|
||||
"class": "Pipistrellus pipistrellus",
|
||||
"class_prob": 0.73,
|
||||
"det_prob": 0.78,
|
||||
"end_time": 0.4803,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 58290,
|
||||
"individual": "-1",
|
||||
"low_freq": 48671,
|
||||
"start_time": 0.4745
|
||||
},
|
||||
{
|
||||
"class": "Myotis mystacinus",
|
||||
"class_prob": 0.404,
|
||||
"det_prob": 0.449,
|
||||
"end_time": 0.4947,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 111336,
|
||||
"individual": "-1",
|
||||
"low_freq": 36640,
|
||||
"start_time": 0.4895
|
||||
}
|
||||
],
|
||||
"class_name": "Myotis mystacinus",
|
||||
"duration": 0.5,
|
||||
"id": "20170701_213954-MYOMYS-LR_0_0.5.wav",
|
||||
"issues": false,
|
||||
"notes": "Automatically generated. Example data do not assume correct!",
|
||||
"time_exp": 1
|
||||
}
|
||||
|
231
example_data/anns/20180530_213516-EPTSER-LR_0_0.5.wav.json
Normal file
231
example_data/anns/20180530_213516-EPTSER-LR_0_0.5.wav.json
Normal file
@ -0,0 +1,231 @@
|
||||
{
|
||||
"annotated": true,
|
||||
"annotation": [
|
||||
{
|
||||
"class": "Eptesicus serotinus",
|
||||
"class_prob": 0.744,
|
||||
"det_prob": 0.77,
|
||||
"end_time": 0.0162,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 65592,
|
||||
"individual": "-1",
|
||||
"low_freq": 27187,
|
||||
"start_time": 0.0085
|
||||
},
|
||||
{
|
||||
"class": "Pipistrellus pipistrellus",
|
||||
"class_prob": 0.453,
|
||||
"det_prob": 0.459,
|
||||
"end_time": 0.0255,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 59730,
|
||||
"individual": "-1",
|
||||
"low_freq": 46093,
|
||||
"start_time": 0.0205
|
||||
},
|
||||
{
|
||||
"class": "Pipistrellus pipistrellus",
|
||||
"class_prob": 0.668,
|
||||
"det_prob": 0.68,
|
||||
"end_time": 0.0499,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 57080,
|
||||
"individual": "-1",
|
||||
"low_freq": 46953,
|
||||
"start_time": 0.0445
|
||||
},
|
||||
{
|
||||
"class": "Pipistrellus pipistrellus",
|
||||
"class_prob": 0.729,
|
||||
"det_prob": 0.739,
|
||||
"end_time": 0.109,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 62808,
|
||||
"individual": "-1",
|
||||
"low_freq": 44375,
|
||||
"start_time": 0.1025
|
||||
},
|
||||
{
|
||||
"class": "Pipistrellus pipistrellus",
|
||||
"class_prob": 0.591,
|
||||
"det_prob": 0.602,
|
||||
"end_time": 0.1311,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 56848,
|
||||
"individual": "-1",
|
||||
"low_freq": 46953,
|
||||
"start_time": 0.1255
|
||||
},
|
||||
{
|
||||
"class": "Eptesicus serotinus",
|
||||
"class_prob": 0.696,
|
||||
"det_prob": 0.735,
|
||||
"end_time": 0.1694,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 67238,
|
||||
"individual": "-1",
|
||||
"low_freq": 28046,
|
||||
"start_time": 0.1625
|
||||
},
|
||||
{
|
||||
"class": "Pipistrellus pipistrellus",
|
||||
"class_prob": 0.617,
|
||||
"det_prob": 0.643,
|
||||
"end_time": 0.2031,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 57047,
|
||||
"individual": "-1",
|
||||
"low_freq": 46093,
|
||||
"start_time": 0.1975
|
||||
},
|
||||
{
|
||||
"class": "Pipistrellus pipistrellus",
|
||||
"class_prob": 0.507,
|
||||
"det_prob": 0.515,
|
||||
"end_time": 0.2222,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 58214,
|
||||
"individual": "-1",
|
||||
"low_freq": 47812,
|
||||
"start_time": 0.2175
|
||||
},
|
||||
{
|
||||
"class": "Eptesicus serotinus",
|
||||
"class_prob": 0.201,
|
||||
"det_prob": 0.372,
|
||||
"end_time": 0.2839,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 55667,
|
||||
"individual": "-1",
|
||||
"low_freq": 33203,
|
||||
"start_time": 0.2775
|
||||
},
|
||||
{
|
||||
"class": "Pipistrellus pipistrellus",
|
||||
"class_prob": 0.749,
|
||||
"det_prob": 0.78,
|
||||
"end_time": 0.2918,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 60611,
|
||||
"individual": "-1",
|
||||
"low_freq": 45234,
|
||||
"start_time": 0.2855
|
||||
},
|
||||
{
|
||||
"class": "Eptesicus serotinus",
|
||||
"class_prob": 0.239,
|
||||
"det_prob": 0.325,
|
||||
"end_time": 0.3148,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 54100,
|
||||
"individual": "-1",
|
||||
"low_freq": 30625,
|
||||
"start_time": 0.3085
|
||||
},
|
||||
{
|
||||
"class": "Eptesicus serotinus",
|
||||
"class_prob": 0.621,
|
||||
"det_prob": 0.652,
|
||||
"end_time": 0.3227,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 63504,
|
||||
"individual": "-1",
|
||||
"low_freq": 27187,
|
||||
"start_time": 0.3155
|
||||
},
|
||||
{
|
||||
"class": "Eptesicus serotinus",
|
||||
"class_prob": 0.32,
|
||||
"det_prob": 0.414,
|
||||
"end_time": 0.3546,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 37589,
|
||||
"individual": "-1",
|
||||
"low_freq": 27187,
|
||||
"start_time": 0.3455
|
||||
},
|
||||
{
|
||||
"class": "Pipistrellus pipistrellus",
|
||||
"class_prob": 0.69,
|
||||
"det_prob": 0.697,
|
||||
"end_time": 0.3776,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 57262,
|
||||
"individual": "-1",
|
||||
"low_freq": 46093,
|
||||
"start_time": 0.3735
|
||||
},
|
||||
{
|
||||
"class": "Eptesicus serotinus",
|
||||
"class_prob": 0.34,
|
||||
"det_prob": 0.415,
|
||||
"end_time": 0.4069,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 52025,
|
||||
"individual": "-1",
|
||||
"low_freq": 31484,
|
||||
"start_time": 0.4005
|
||||
},
|
||||
{
|
||||
"class": "Eptesicus serotinus",
|
||||
"class_prob": 0.386,
|
||||
"det_prob": 0.445,
|
||||
"end_time": 0.4178,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 53951,
|
||||
"individual": "-1",
|
||||
"low_freq": 27187,
|
||||
"start_time": 0.4115
|
||||
},
|
||||
{
|
||||
"class": "Eptesicus serotinus",
|
||||
"class_prob": 0.393,
|
||||
"det_prob": 0.517,
|
||||
"end_time": 0.4359,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 51724,
|
||||
"individual": "-1",
|
||||
"low_freq": 30625,
|
||||
"start_time": 0.4305
|
||||
},
|
||||
{
|
||||
"class": "Eptesicus serotinus",
|
||||
"class_prob": 0.332,
|
||||
"det_prob": 0.396,
|
||||
"end_time": 0.4502,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 58310,
|
||||
"individual": "-1",
|
||||
"low_freq": 27187,
|
||||
"start_time": 0.4435
|
||||
},
|
||||
{
|
||||
"class": "Pipistrellus pipistrellus",
|
||||
"class_prob": 0.45,
|
||||
"det_prob": 0.456,
|
||||
"end_time": 0.4638,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 55714,
|
||||
"individual": "-1",
|
||||
"low_freq": 46093,
|
||||
"start_time": 0.4575
|
||||
},
|
||||
{
|
||||
"class": "Eptesicus serotinus",
|
||||
"class_prob": 0.719,
|
||||
"det_prob": 0.766,
|
||||
"end_time": 0.4824,
|
||||
"event": "Echolocation",
|
||||
"high_freq": 66101,
|
||||
"individual": "-1",
|
||||
"low_freq": 28046,
|
||||
"start_time": 0.4755
|
||||
}
|
||||
],
|
||||
"class_name": "Pipistrellus pipistrellus",
|
||||
"duration": 0.5,
|
||||
"id": "20180530_213516-EPTSER-LR_0_0.5.wav",
|
||||
"issues": false,
|
||||
"notes": "Automatically generated. Example data do not assume correct!",
|
||||
"time_exp": 1
|
||||
}
|
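The example annotation files added in this diff are plain JSON: one object per audio clip, with clip-level metadata ("id", "duration", "class_name", "notes") and an "annotation" list in which each detection carries a class label, detection and class probabilities, a start/end time in seconds, and a low/high frequency in Hz. The snippet below is a minimal sketch (not part of the repository) of reading one of these files with the standard library; the field names follow the structure shown above and the path is only illustrative.

import json
from pathlib import Path

# Illustrative path; all example annotation files share the same layout.
ann_path = Path("example_data/anns/20180530_213516-EPTSER-LR_0_0.5.wav.json")

with ann_path.open() as fp:
    clip = json.load(fp)

print(clip["id"], clip["class_name"], clip["duration"])
for det in clip["annotation"]:
    # Each detection: class label, probabilities, time interval (s), frequency band (Hz).
    print(
        f"{det['start_time']:.4f}-{det['end_time']:.4f} s  "
        f"{det['low_freq']}-{det['high_freq']} Hz  "
        f"{det['class']} (det={det['det_prob']}, class={det['class_prob']})"
    )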
111 example_data/anns/20180627_215323-RHIFER-LR_0_0.5.wav.json Normal file
@@ -0,0 +1,111 @@
{
    "annotated": true,
    "annotation": [
        {
            "class": "Rhinolophus ferrumequinum",
            "class_prob": 0.407,
            "det_prob": 0.407,
            "end_time": 0.066,
            "event": "Echolocation",
            "high_freq": 84254,
            "individual": "-1",
            "low_freq": 68437,
            "start_time": 0.0245
        },
        {
            "class": "Rhinolophus ferrumequinum",
            "class_prob": 0.759,
            "det_prob": 0.76,
            "end_time": 0.1576,
            "event": "Echolocation",
            "high_freq": 84048,
            "individual": "-1",
            "low_freq": 68437,
            "start_time": 0.0955
        },
        {
            "class": "Rhinolophus ferrumequinum",
            "class_prob": 0.754,
            "det_prob": 0.755,
            "end_time": 0.269,
            "event": "Echolocation",
            "high_freq": 83768,
            "individual": "-1",
            "low_freq": 68437,
            "start_time": 0.2095
        },
        {
            "class": "Rhinolophus ferrumequinum",
            "class_prob": 0.495,
            "det_prob": 0.495,
            "end_time": 0.2869,
            "event": "Echolocation",
            "high_freq": 84055,
            "individual": "-1",
            "low_freq": 68437,
            "start_time": 0.2425
        },
        {
            "class": "Rhinolophus ferrumequinum",
            "class_prob": 0.73,
            "det_prob": 0.73,
            "end_time": 0.3631,
            "event": "Echolocation",
            "high_freq": 84280,
            "individual": "-1",
            "low_freq": 68437,
            "start_time": 0.3055
        },
        {
            "class": "Rhinolophus ferrumequinum",
            "class_prob": 0.648,
            "det_prob": 0.649,
            "end_time": 0.3798,
            "event": "Echolocation",
            "high_freq": 83030,
            "individual": "-1",
            "low_freq": 68437,
            "start_time": 0.3215
        },
        {
            "class": "Rhinolophus ferrumequinum",
            "class_prob": 0.678,
            "det_prob": 0.678,
            "end_time": 0.4611,
            "event": "Echolocation",
            "high_freq": 84020,
            "individual": "-1",
            "low_freq": 68437,
            "start_time": 0.4065
        },
        {
            "class": "Rhinolophus ferrumequinum",
            "class_prob": 0.717,
            "det_prob": 0.718,
            "end_time": 0.4987,
            "event": "Echolocation",
            "high_freq": 83603,
            "individual": "-1",
            "low_freq": 68437,
            "start_time": 0.4365
        },
        {
            "class": "Rhinolophus ferrumequinum",
            "class_prob": 0.662,
            "det_prob": 0.662,
            "end_time": 0.5503,
            "event": "Echolocation",
            "high_freq": 83710,
            "individual": "-1",
            "low_freq": 68437,
            "start_time": 0.4975
        }
    ],
    "class_name": "Rhinolophus ferrumequinum",
    "duration": 0.5,
    "id": "20180627_215323-RHIFER-LR_0_0.5.wav",
    "issues": false,
    "notes": "Automatically generated. Example data do not assume correct!",
    "time_exp": 1
}
@@ -17,7 +17,7 @@ dependencies = [
    "torch>=1.13.1,<2.5.0",
    "torchaudio>=1.13.1,<2.5.0",
    "torchvision>=0.14.0",
    "soundevent[audio,geometry,plot]>=2.3",
    "soundevent[audio,geometry,plot]>=2.5.0",
    "click>=8.1.7",
    "netcdf4>=1.6.5",
    "tqdm>=4.66.2",
@@ -30,6 +30,7 @@ dependencies = [
    "pyyaml>=6.0.2",
    "hydra-core>=1.3.2",
    "numba>=0.60",
    "loguru>=0.7.3",
]
requires-python = ">=3.9,<3.13"
readme = "README.md"
@@ -72,7 +73,14 @@ dev-dependencies = [
    "ruff>=0.7.3",
    "ipykernel>=6.29.4",
    "setuptools>=69.5.1",
    "basedpyright>=1.28.4",
    "pyright>=1.1.399",
    "myst-parser>=3.0.1",
    "sphinx-autobuild>=2024.10.3",
    "numpydoc>=1.8.0",
    "sphinx-autodoc-typehints>=2.3.0",
    "sphinx-book-theme>=1.1.4",
    "autodoc-pydantic>=2.2.0",
    "pytest-cov>=6.1.1",
]

[tool.ruff]
@@ -91,7 +99,24 @@ convention = "numpy"

[tool.pyright]
include = ["batdetect2", "tests"]
venvPath = "."
venv = ".venv"
pythonVersion = "3.9"
pythonPlatform = "All"
exclude = [
    "batdetect2/detector/",
    "batdetect2/finetune",
    "batdetect2/utils",
    "batdetect2/plotting",
    "batdetect2/plot",
    "batdetect2/api",
    "batdetect2/evaluate/legacy",
    "batdetect2/train/legacy",
]

[dependency-groups]
jupyter = [
    "ipywidgets>=8.1.5",
    "jupyter>=1.1.1",
]
marimo = [
    "marimo>=0.12.2",
]
@@ -5,7 +5,22 @@ from typing import Callable, List, Optional
import numpy as np
import pytest
import soundfile as sf
from soundevent import data
from soundevent import data, terms

from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import (
    TargetConfig,
    TermRegistry,
    build_targets,
    call_type,
)
from batdetect2.targets.classes import ClassesConfig, TargetClass
from batdetect2.targets.filtering import FilterConfig, FilterRule
from batdetect2.targets.terms import TagInfo
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.types import ClipLabeller


@pytest.fixture
@@ -84,8 +99,8 @@ def wav_factory(tmp_path: Path):


@pytest.fixture
def recording_factory(wav_factory: Callable[..., Path]):
    def _recording_factory(
def create_recording(wav_factory: Callable[..., Path]):
    def factory(
        tags: Optional[list[data.Tag]] = None,
        path: Optional[Path] = None,
        recording_id: Optional[uuid.UUID] = None,
@@ -94,7 +109,8 @@ def recording_factory(wav_factory: Callable[..., Path]):
        samplerate: int = 256_000,
        time_expansion: float = 1,
    ) -> data.Recording:
        path = path or wav_factory(
        path = wav_factory(
            path=path,
            duration=duration,
            channels=channels,
            samplerate=samplerate,
@@ -106,4 +122,264 @@ def recording_factory(wav_factory: Callable[..., Path]):
            tags=tags or [],
        )

    return _recording_factory
    return factory


@pytest.fixture
def recording(
    create_recording: Callable[..., data.Recording],
) -> data.Recording:
    return create_recording()


@pytest.fixture
def create_clip():
    def factory(
        recording: data.Recording,
        start_time: float = 0,
        end_time: float = 0.5,
    ) -> data.Clip:
        return data.Clip(
            recording=recording,
            start_time=start_time,
            end_time=end_time,
        )

    return factory


@pytest.fixture
def clip(recording: data.Recording) -> data.Clip:
    return data.Clip(recording=recording, start_time=0, end_time=0.5)


@pytest.fixture
def create_sound_event():
    def factory(
        recording: data.Recording,
        coords: Optional[List[float]] = None,
    ) -> data.SoundEvent:
        coords = coords or [0.2, 60_000, 0.3, 70_000]

        return data.SoundEvent(
            geometry=data.BoundingBox(coordinates=coords),
            recording=recording,
        )

    return factory


@pytest.fixture
def sound_event(recording: data.Recording) -> data.SoundEvent:
    return data.SoundEvent(
        geometry=data.BoundingBox(coordinates=[0.1, 67_000, 0.11, 73_000]),
        recording=recording,
    )


@pytest.fixture
def create_sound_event_annotation():
    def factory(
        sound_event: data.SoundEvent,
        tags: Optional[List[data.Tag]] = None,
    ) -> data.SoundEventAnnotation:
        return data.SoundEventAnnotation(
            sound_event=sound_event,
            tags=tags or [],
        )

    return factory


@pytest.fixture
def echolocation_call(recording: data.Recording) -> data.SoundEventAnnotation:
    return data.SoundEventAnnotation(
        sound_event=data.SoundEvent(
            geometry=data.BoundingBox(coordinates=[0.1, 67_000, 0.11, 73_000]),
            recording=recording,
        ),
        tags=[
            data.Tag(term=terms.scientific_name, value="Myotis myotis"),
            data.Tag(term=call_type, value="Echolocation"),
        ],
    )


@pytest.fixture
def generic_call(recording: data.Recording) -> data.SoundEventAnnotation:
    return data.SoundEventAnnotation(
        sound_event=data.SoundEvent(
            geometry=data.BoundingBox(
                coordinates=[0.34, 35_000, 0.348, 62_000]
            ),
            recording=recording,
        ),
        tags=[
            data.Tag(term=terms.order, value="Chiroptera"),
            data.Tag(term=call_type, value="Echolocation"),
        ],
    )


@pytest.fixture
def non_relevant_sound_event(
    recording: data.Recording,
) -> data.SoundEventAnnotation:
    return data.SoundEventAnnotation(
        sound_event=data.SoundEvent(
            geometry=data.BoundingBox(
                coordinates=[0.22, 50_000, 0.24, 58_000]
            ),
            recording=recording,
        ),
        tags=[
            data.Tag(
                term=terms.scientific_name,
                value="Muscardinus avellanarius",
            ),
        ],
    )


@pytest.fixture
def create_clip_annotation():
    def factory(
        clip: data.Clip,
        clip_tags: Optional[List[data.Tag]] = None,
        sound_events: Optional[List[data.SoundEventAnnotation]] = None,
    ) -> data.ClipAnnotation:
        return data.ClipAnnotation(
            clip=clip,
            tags=clip_tags or [],
            sound_events=sound_events or [],
        )

    return factory


@pytest.fixture
def clip_annotation(
    clip: data.Clip,
    echolocation_call: data.SoundEventAnnotation,
    generic_call: data.SoundEventAnnotation,
    non_relevant_sound_event: data.SoundEventAnnotation,
) -> data.ClipAnnotation:
    return data.ClipAnnotation(
        clip=clip,
        sound_events=[
            echolocation_call,
            generic_call,
            non_relevant_sound_event,
        ],
    )


@pytest.fixture
def create_annotation_set():
    def factory(
        name: str = "test",
        description: str = "Test annotation set",
        annotations: Optional[List[data.ClipAnnotation]] = None,
    ) -> data.AnnotationSet:
        return data.AnnotationSet(
            name=name,
            description=description,
            clip_annotations=annotations or [],
        )

    return factory


@pytest.fixture
def create_annotation_project():
    def factory(
        name: str = "test_project",
        description: str = "Test Annotation Project",
        tasks: Optional[List[data.AnnotationTask]] = None,
        annotations: Optional[List[data.ClipAnnotation]] = None,
    ) -> data.AnnotationProject:
        return data.AnnotationProject(
            name=name,
            description=description,
            tasks=tasks or [],
            clip_annotations=annotations or [],
        )

    return factory


@pytest.fixture
def sample_term_registry() -> TermRegistry:
    """Fixture for a sample TermRegistry."""
    registry = TermRegistry()
    registry.add_custom_term("class")
    registry.add_custom_term("order")
    registry.add_custom_term("species")
    registry.add_custom_term("call_type")
    registry.add_custom_term("quality")
    return registry


@pytest.fixture
def sample_preprocessor() -> PreprocessorProtocol:
    return build_preprocessor()


@pytest.fixture
def bat_tag() -> TagInfo:
    return TagInfo(key="class", value="bat")


@pytest.fixture
def noise_tag() -> TagInfo:
    return TagInfo(key="class", value="noise")


@pytest.fixture
def myomyo_tag() -> TagInfo:
    return TagInfo(key="species", value="Myotis myotis")


@pytest.fixture
def pippip_tag() -> TagInfo:
    return TagInfo(key="species", value="Pipistrellus pipistrellus")


@pytest.fixture
def sample_target_config(
    sample_term_registry: TermRegistry,
    bat_tag: TagInfo,
    noise_tag: TagInfo,
    myomyo_tag: TagInfo,
    pippip_tag: TagInfo,
) -> TargetConfig:
    return TargetConfig(
        filtering=FilterConfig(
            rules=[FilterRule(match_type="exclude", tags=[noise_tag])]
        ),
        classes=ClassesConfig(
            classes=[
                TargetClass(name="pippip", tags=[pippip_tag]),
                TargetClass(name="myomyo", tags=[myomyo_tag]),
            ],
            generic_class=[bat_tag],
        ),
    )


@pytest.fixture
def sample_targets(
    sample_target_config: TargetConfig,
    sample_term_registry: TermRegistry,
) -> TargetProtocol:
    return build_targets(
        sample_target_config,
        term_registry=sample_term_registry,
    )


@pytest.fixture
def sample_labeller(
    sample_targets: TargetProtocol,
) -> ClipLabeller:
    return build_clip_labeler(sample_targets)
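The fixture changes above replace the old recording_factory fixture with factory-style fixtures (create_recording, create_clip, create_sound_event, ...) plus ready-made instances (recording, clip, clip_annotation) and target/labelling helpers (sample_target_config, sample_targets, sample_labeller). The following is a hypothetical test, not part of this diff, sketching how the factory fixtures compose; the argument values are arbitrary.

from soundevent import data


def test_clip_spans_expected_interval(create_recording, create_clip):
    # `create_recording` and `create_clip` are the factory fixtures defined above.
    recording = create_recording(duration=1.0)
    clip = create_clip(recording, start_time=0.0, end_time=0.5)

    assert isinstance(clip, data.Clip)
    assert clip.end_time - clip.start_time == 0.5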
@@ -8,7 +8,7 @@ from batdetect2.detector import parameters
from batdetect2.utils import audio_utils, detector_utils


@given(duration=st.floats(min_value=0.1, max_value=2))
@given(duration=st.floats(min_value=0.1, max_value=1))
def test_can_compute_correct_spectrogram_width(duration: float):
    samplerate = parameters.TARGET_SAMPLERATE_HZ
    params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
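This test is property-based: Hypothesis draws many durations from the given range and runs the test body once per draw, and the change above appears to narrow the sampled range from (0.1, 2) to (0.1, 1) seconds. Below is a standalone sketch of the same decorator pattern; it is an illustration only, not code from the repository.

from hypothesis import given
from hypothesis import strategies as st


@given(duration=st.floats(min_value=0.1, max_value=1))
def test_duration_stays_in_range(duration: float):
    # Hypothesis calls this function repeatedly with generated values.
    assert 0.1 <= duration <= 1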
Some files were not shown because too many files have changed in this diff.