Compare commits

...

81 Commits

Author SHA1 Message Date
mbsantiago
2d796394f6 Store anns and preds instead of evals in EvaluatorModule 2025-09-30 19:10:07 +01:00
mbsantiago
49ec1916ce Fix API after module name change 2025-09-30 18:24:06 +01:00
mbsantiago
8727edf466 Save checkpoint on max class mAP 2025-09-30 18:20:31 +01:00
mbsantiago
2f48c58de1 api_v2 2025-09-30 13:56:25 +01:00
mbsantiago
981e37c346 Writing batch inference code 2025-09-30 13:22:03 +01:00
mbsantiago
30159d64a9 Update example config 2025-09-28 16:22:21 +01:00
mbsantiago
c9f0c5c431 Added bbox iou affinity function 2025-09-28 16:08:21 +01:00
mbsantiago
10865ee600 Re-org gallery example plots 2025-09-28 15:45:48 +01:00
mbsantiago
87ed44c8f7 Plotting reorganised 2025-09-27 23:58:06 +01:00
mbsantiago
df2abff654 Task/Metrics restructure 2025-09-26 15:23:38 +01:00
mbsantiago
d6ddc4514c Better evaluation organisation 2025-09-25 17:48:29 +01:00
mbsantiago
4cd983a2c2 Better train cli arg names 2025-09-18 09:44:27 +01:00
mbsantiago
e65df81db2 Evaluate using Lightning too to handle device changes 2025-09-18 09:28:21 +01:00
mbsantiago
6c25787123 Logging is not just for training 2025-09-18 09:27:40 +01:00
mbsantiago
8c80402f08 Move clips and audio to dedicated module 2025-09-18 09:27:24 +01:00
mbsantiago
b81a882b58 Add metrics and plots 2025-09-17 10:30:24 +01:00
mbsantiago
6e217380f2 Moved example target config to independent file 2025-09-17 10:29:30 +01:00
mbsantiago
957c0735d2 Starting new API 2025-09-16 19:39:30 +01:00
mbsantiago
bbb96b33a2 Config restructuring 2025-09-16 18:57:56 +01:00
mbsantiago
7d6cba5465 Restructuring 2025-09-16 13:38:38 +01:00
mbsantiago
60e922d565 Use pascal voc map computation by default 2025-09-16 10:56:37 +01:00
mbsantiago
704b28292b Cleaning train module 2025-09-15 16:50:08 +01:00
mbsantiago
e752e96b93 Restructure eval metrics and plotting 2025-09-15 16:01:15 +01:00
mbsantiago
ec1c0ff020 Better matching module, remove generic from classification evaluations 2025-09-14 18:16:59 +01:00
mbsantiago
8628133fd7 Compute mAP 2025-09-14 10:08:51 +01:00
mbsantiago
d80377981e Fix plotting 2025-09-14 09:38:45 +01:00
mbsantiago
ad5293e0d0 Ad FileNotFoundError to plotting 2025-09-13 20:05:42 +01:00
mbsantiago
01e7a5df25 Add ignore at ends when evaluating 2025-09-13 19:03:40 +01:00
mbsantiago
6d70140bc9 Default to normal anchor 2025-09-13 13:56:47 +01:00
mbsantiago
4fd2e84773 Fix clip missalignment in validation dataset 2025-09-11 11:09:59 +01:00
mbsantiago
74c419f674 Update default config 2025-09-10 21:49:57 +01:00
mbsantiago
e65d5a6846 Added more clipping options for validation 2025-09-10 21:09:51 +01:00
mbsantiago
615c811bb4 Add detection_class_name to targets protocol 2025-09-09 20:30:20 +01:00
mbsantiago
41b18c3f0a Better order for checkpoints 2025-09-09 15:56:46 +01:00
mbsantiago
16a0fa7b75 Add targets to train cli 2025-09-09 15:45:00 +01:00
mbsantiago
115084fd2b Updat lightning version 2025-09-09 15:31:40 +01:00
mbsantiago
951dc59718 Add seed option to train 2025-09-09 13:23:56 +01:00
mbsantiago
3376be06a4 Add experiment name 2025-09-09 09:02:25 +01:00
mbsantiago
cd4955d4f3 Eval 2025-09-08 22:04:30 +01:00
mbsantiago
c73984b213 Small fixes 2025-09-08 18:35:02 +01:00
mbsantiago
d8d2e5a2c2 Remove preprocessing modules 2025-09-08 18:11:58 +01:00
mbsantiago
b056d7d28d Make sure training is still working 2025-09-08 18:03:56 +01:00
mbsantiago
95a884ea16 Update tests 2025-09-08 18:00:17 +01:00
mbsantiago
b7ae526071 Big changes in data module 2025-09-08 17:50:25 +01:00
mbsantiago
cf6d0d1ccc Remove stale tests 2025-09-07 11:03:46 +01:00
mbsantiago
709b6355c2 torch.multiprocessing didn't work, returning to serial processing 2025-09-01 11:27:23 +01:00
mbsantiago
db2ad11743 Make matching in parallel for speedup 2025-09-01 11:19:02 +01:00
mbsantiago
e0ecc3c3d1 Clear evaluation callback after epoch ends 2025-09-01 08:56:38 +01:00
mbsantiago
71c2301c21 Independent preprocessor for generating validation plots 2025-08-31 23:04:16 +01:00
mbsantiago
d3d2a28130 Move detections array to cpu 2025-08-31 22:59:06 +01:00
mbsantiago
5b9a5a968f Refactor eval code 2025-08-31 22:57:02 +01:00
mbsantiago
356be57f62 Update to latests soundevent 2025-08-31 20:17:48 +01:00
mbsantiago
8d093c3ca2 Use custom AugmentationSequence instead of nn.Sequential 2025-08-31 19:27:15 +01:00
mbsantiago
2f4edeffff Create separate preprocessor for the train/val datasets 2025-08-31 19:22:45 +01:00
mbsantiago
cca1d82d63 Remove blocking .to(device) 2025-08-31 19:10:06 +01:00
mbsantiago
55f473c9ca Fix logging issue 2025-08-31 19:06:09 +01:00
mbsantiago
40f6b64611 Remove train preprocessing 2025-08-31 18:28:52 +01:00
mbsantiago
1cec332dd5 Change default train duration to 0.256 instead of 0.512 2025-08-30 14:08:00 +01:00
mbsantiago
93e89ecc46 LR Scheduler takes num of total batches 2025-08-28 08:52:11 +01:00
mbsantiago
34ef9e92a1 Make sure preprocessing is batchable 2025-08-27 23:58:38 +01:00
mbsantiago
0b5ac96fe8 Update model config 2025-08-27 23:58:07 +01:00
mbsantiago
dba6d2d918 Updating configs 2025-08-27 23:44:49 +01:00
mbsantiago
ff754a1269 Tweaks of augmentation config 2025-08-27 18:23:38 +01:00
mbsantiago
ed76ec24b6 Plot anchor points 2025-08-27 18:13:40 +01:00
mbsantiago
d25efdad10 Fix plotting after update 2025-08-26 11:48:06 +01:00
mbsantiago
3043230d4f Device fixing #5 2025-08-25 23:20:30 +01:00
mbsantiago
67e37227f5 Device fixing #4 2025-08-25 23:11:33 +01:00
mbsantiago
9d4a9fc35c Device fixing #3 2025-08-25 23:08:49 +01:00
mbsantiago
d0bab60bf3 Device fixing #2 2025-08-25 23:07:09 +01:00
mbsantiago
a267db290c Device fixing 2025-08-25 23:04:13 +01:00
mbsantiago
441ccb3382 Remove user warning from plotting function 2025-08-25 22:49:48 +01:00
mbsantiago
281c4dcb8a Remove xr from postprocess 2025-08-25 22:46:21 +01:00
mbsantiago
cc9e47b022 Fix plotting after changes 2025-08-25 19:07:12 +01:00
mbsantiago
1f26103f42 Cleanup train preprocessing 2025-08-25 18:37:46 +01:00
mbsantiago
c80078feee Removing stale tests 2025-08-25 17:23:27 +01:00
mbsantiago
0bb0caddea Update augmentations 2025-08-25 17:06:17 +01:00
mbsantiago
76dda0a0e9 Fix train preprocessing 2025-08-25 14:01:31 +01:00
mbsantiago
c36ef3ecb5 Labels to torch 2025-08-25 12:43:34 +01:00
mbsantiago
667b18a54d Preprocessing in pytorch 2025-08-25 11:41:55 +01:00
mbsantiago
61115d562c Moved types to dedicated module 2025-08-24 10:55:48 +01:00
mbsantiago
02adc19070 Better structure for training module 2025-08-23 18:23:45 +01:00
155 changed files with 12499 additions and 13946 deletions

View File

@@ -133,7 +133,7 @@ When you need to specify a tag, you typically use a structure with two fields:
**It defaults to `class`** if you omit it, which is common when defining the main target classes.
- `value`: The specific _value_ of the tag (e.g., `Myotis daubentonii`, `Good`, `Rain`).
**Example YAML Configuration using TagInfo (e.g., inside a filter rule):**
**Example YAML Configuration (e.g., inside a filter rule):**
```yaml
# ... inside a filtering configuration section ...
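# Illustrative sketch, not part of the original example: a tag entry using the
# two TagInfo fields described above (the key/value pairs shown are assumptions).
tags:
  - key: event
    value: Echolocation
  - value: Myotis daubentonii  # `key` omitted, so it defaults to `class`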

View File

@@ -1,119 +1,125 @@
datasets:
train:
name: example dataset
description: Only for demonstration purposes
sources:
- format: batdetect2
name: Example Data
description: Examples included for testing batdetect2
annotations_dir: example_data/anns
audio_dir: example_data/audio
targets:
classes:
classes:
- name: myomys
tags:
- value: Myotis mystacinus
- name: pippip
tags:
- value: Pipistrellus pipistrellus
- name: eptser
tags:
- value: Eptesicus serotinus
- name: rhifer
tags:
- value: Rhinolophus ferrumequinum
generic_class:
- key: class
value: Bat
filtering:
rules:
- match_type: all
tags:
- key: event
value: Echolocation
- match_type: exclude
tags:
- key: class
value: Unknown
audio:
samplerate: 256000
resample:
enabled: True
method: "poly"
preprocess:
audio:
resample:
samplerate: 256000
method: "poly"
scale: false
center: true
duration: null
spectrogram:
stft:
window_duration: 0.002
window_overlap: 0.75
window_fn: hann
frequencies:
max_freq: 120000
min_freq: 10000
pcen:
stft:
window_duration: 0.002
window_overlap: 0.75
window_fn: hann
frequencies:
max_freq: 120000
min_freq: 10000
size:
height: 128
resize_factor: 0.5
spectrogram_transforms:
- name: pcen
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
scale: "amplitude"
size:
height: 128
resize_factor: 0.5
spectral_mean_substraction: true
peak_normalize: false
- name: spectral_mean_substraction
postprocess:
nms_kernel_size: 9
detection_threshold: 0.01
min_freq: 10000
max_freq: 120000
top_k_per_sec: 200
labels:
sigma: 3
model:
input_height: 128
in_channels: 1
out_channels: 32
encoder:
layers:
- block_type: FreqCoordConvDown
- name: FreqCoordConvDown
out_channels: 32
- block_type: FreqCoordConvDown
- name: FreqCoordConvDown
out_channels: 64
- block_type: LayerGroup
- name: LayerGroup
layers:
- block_type: FreqCoordConvDown
- name: FreqCoordConvDown
out_channels: 128
- block_type: ConvBlock
- name: ConvBlock
out_channels: 256
bottleneck:
channels: 256
self_attention: true
layers:
- name: SelfAttention
attention_channels: 256
decoder:
layers:
- block_type: FreqCoordConvUp
- name: FreqCoordConvUp
out_channels: 64
- block_type: FreqCoordConvUp
- name: FreqCoordConvUp
out_channels: 32
- block_type: LayerGroup
- name: LayerGroup
layers:
- block_type: FreqCoordConvUp
- name: FreqCoordConvUp
out_channels: 32
- block_type: ConvBlock
- name: ConvBlock
out_channels: 32
train:
batch_size: 8
learning_rate: 0.001
t_max: 100
optimizer:
learning_rate: 0.001
t_max: 100
labels:
sigma: 3
trainer:
max_epochs: 10
check_val_every_n_epoch: 5
train_loader:
batch_size: 8
num_workers: 2
shuffle: True
clipping_strategy:
name: random_subclip
duration: 0.256
augmentations:
enabled: true
audio:
- name: mix_audio
probability: 0.2
min_weight: 0.3
max_weight: 0.7
- name: add_echo
probability: 0.2
max_delay: 0.005
min_weight: 0.0
max_weight: 1.0
spectrogram:
- name: scale_volume
probability: 0.2
min_scaling: 0.0
max_scaling: 2.0
- name: warp
probability: 0.2
delta: 0.04
- name: mask_time
probability: 0.2
max_perc: 0.05
max_masks: 3
- name: mask_freq
probability: 0.2
max_perc: 0.10
max_masks: 3
val_loader:
num_workers: 2
clipping_strategy:
name: whole_audio_padded
chunk_size: 0.256
loss:
detection:
weight: 1.0
@@ -127,37 +133,54 @@ train:
alpha: 2
size:
weight: 0.1
logger:
logger_type: mlflow
experiment_name: batdetect2
tracking_uri: http://localhost:5000
log_model: true
save_dir: outputs/log/
artifact_location: outputs/artifacts/
checkpoint_path_prefix: outputs/checkpoints/
augmentations:
steps:
- augmentation_type: mix_audio
probability: 0.2
min_weight: 0.3
max_weight: 0.7
- augmentation_type: add_echo
probability: 0.2
max_delay: 0.005
min_weight: 0.0
max_weight: 1.0
- augmentation_type: scale_volume
probability: 0.2
min_scaling: 0.0
max_scaling: 2.0
- augmentation_type: warp
probability: 0.2
delta: 0.04
- augmentation_type: mask_time
probability: 0.2
max_perc: 0.05
max_masks: 3
- augmentation_type: mask_freq
probability: 0.2
max_perc: 0.10
max_masks: 3
name: csv
validation:
tasks:
- name: sound_event_detection
metrics:
- name: average_precision
- name: sound_event_classification
metrics:
- name: average_precision
evaluation:
tasks:
- name: sound_event_detection
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: score_distribution
- name: example_detection
- name: sound_event_classification
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: top_class_detection
metrics:
- name: average_precision
plots:
- name: pr_curve
- name: confusion_matrix
- name: example_classification
- name: clip_detection
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: roc_curve
- name: score_distribution
- name: clip_classification
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: roc_curve
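
The example configuration above now gathers audio, preprocessing, model, training, and evaluation settings in one file (the block interleaves removed and added diff lines, e.g. both `block_type:` and `name:` entries, so it is not valid YAML as displayed). A minimal sketch of how such a file might be consumed, assuming the `load_full_config` helper and `BatDetect2API.from_config` constructor added later in this changeset:

```python
# Hedged sketch: load the example configuration and build the new API from it.
# The path is illustrative.
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import load_full_config

config = load_full_config("example_data/config.yaml")
api = BatDetect2API.from_config(config)
```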

View File

@@ -0,0 +1,8 @@
name: example dataset
description: Only for demonstration purposes
sources:
- format: batdetect2
name: Example Data
description: Examples included for testing batdetect2
annotations_dir: example_data/anns
audio_dir: example_data/audio

example_data/targets.yaml Normal file
View File

@@ -0,0 +1,36 @@
detection_target:
name: bat
match_if:
name: all_of
conditions:
- name: has_tag
tag: { key: event, value: Echolocation }
- name: not
condition:
name: has_tag
tag: { key: class, value: Unknown }
assign_tags:
- key: class
value: Bat
classification_targets:
- name: myomys
tags:
- key: class
value: Myotis mystacinus
- name: pippip
tags:
- key: class
value: Pipistrellus pipistrellus
- name: eptser
tags:
- key: class
value: Eptesicus serotinus
- name: rhifer
tags:
- key: class
value: Rhinolophus ferrumequinum
roi:
name: anchor_bbox
anchor: top-left
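
The new standalone targets file defines the detection target, the classification targets, and the ROI mapping. A minimal sketch of loading it separately from the main config, assuming the `load_target_config` and `build_targets` helpers imported elsewhere in this changeset:

```python
# Hedged sketch: build target definitions from the standalone targets file.
# The path is illustrative; the exact return types are assumptions.
from batdetect2.targets import build_targets, load_target_config

targets_config = load_target_config("example_data/targets.yaml")
targets = build_targets(config=targets_config)
```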

View File

@@ -92,19 +92,11 @@ clean-build:
clean: clean-build clean-pyc clean-test clean-docs
# Examples
# Preprocess example data.
example-preprocess OPTIONS="":
batdetect2 preprocess \
--base-dir . \
--dataset-field datasets.train \
--config example_data/config.yaml \
{{OPTIONS}} \
example_data/datasets.yaml example_data/preprocessed
# Train on example data.
example-train OPTIONS="":
batdetect2 train \
--val-dir example_data/preprocessed \
--val-dataset example_data/dataset.yaml \
--config example_data/config.yaml \
{{OPTIONS}} \
example_data/preprocessed
example_data/dataset.yaml

View File

@@ -17,13 +17,13 @@ dependencies = [
"torch>=1.13.1,<2.5.0",
"torchaudio>=1.13.1,<2.5.0",
"torchvision>=0.14.0",
"soundevent[audio,geometry,plot]>=2.7.0",
"soundevent[audio,geometry,plot]>=2.9.1",
"click>=8.1.7",
"netcdf4>=1.6.5",
"tqdm>=4.66.2",
"cf-xarray>=0.9.0",
"onnx>=1.16.0",
"lightning[extra]>=2.2.2",
"lightning[extra]==2.5.0",
"tensorboard>=2.16.2",
"omegaconf>=2.3.0",
"pyyaml>=6.0.2",

src/batdetect2/api_v2.py Normal file
View File

@@ -0,0 +1,272 @@
from pathlib import Path
from typing import List, Optional, Sequence
import numpy as np
import torch
from soundevent import data
from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR,
load_model_from_checkpoint,
train,
)
from batdetect2.typing import (
AudioLoader,
BatDetect2Prediction,
EvaluatorProtocol,
PostprocessorProtocol,
PreprocessorProtocol,
RawPrediction,
TargetProtocol,
)
class BatDetect2API:
def __init__(
self,
config: BatDetect2Config,
targets: TargetProtocol,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
evaluator: EvaluatorProtocol,
model: Model,
):
self.config = config
self.targets = targets
self.audio_loader = audio_loader
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.evaluator = evaluator
self.model = model
self.model.eval()
def train(
self,
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
train_workers: Optional[int] = None,
val_workers: Optional[int] = None,
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
seed: Optional[int] = None,
):
train(
train_annotations=train_annotations,
val_annotations=val_annotations,
targets=self.targets,
config=self.config,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
train_workers=train_workers,
val_workers=val_workers,
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
seed=seed,
)
return self
def evaluate(
self,
test_annotations: Sequence[data.ClipAnnotation],
num_workers: Optional[int] = None,
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
):
return evaluate(
self.model,
test_annotations,
targets=self.targets,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
config=self.config,
num_workers=num_workers,
output_dir=output_dir,
experiment_name=experiment_name,
run_name=run_name,
)
def load_audio(self, path: data.PathLike) -> np.ndarray:
return self.audio_loader.load_file(path)
def load_clip(self, clip: data.Clip) -> np.ndarray:
return self.audio_loader.load_clip(clip)
def generate_spectrogram(
self,
audio: np.ndarray,
) -> torch.Tensor:
tensor = torch.tensor(audio).unsqueeze(0)
return self.preprocessor(tensor)
def process_file(self, audio_file: str) -> BatDetect2Prediction:
recording = data.Recording.from_file(audio_file, compute_hash=False)
wav = self.audio_loader.load_recording(recording)
detections = self.process_audio(wav)
return BatDetect2Prediction(
clip=data.Clip(
uuid=recording.uuid,
recording=recording,
start_time=0,
end_time=recording.duration,
),
predictions=detections,
)
def process_audio(
self,
audio: np.ndarray,
) -> List[RawPrediction]:
spec = self.generate_spectrogram(audio)
return self.process_spectrogram(spec)
def process_spectrogram(
self,
spec: torch.Tensor,
start_time: float = 0,
) -> List[RawPrediction]:
if spec.ndim == 4 and spec.shape[0] > 1:
raise ValueError("Batched spectrograms not supported.")
if spec.ndim == 3:
spec = spec.unsqueeze(0)
outputs = self.model.detector(spec)
detections = self.model.postprocessor(
outputs,
start_times=[start_time],
)[0]
return to_raw_predictions(detections.numpy(), targets=self.targets)
def process_directory(
self,
audio_dir: data.PathLike,
) -> List[BatDetect2Prediction]:
files = list(get_audio_files(audio_dir))
return self.process_files(files)
def process_files(
self,
audio_files: Sequence[data.PathLike],
num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
return process_file_list(
self.model,
audio_files,
config=self.config,
targets=self.targets,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
num_workers=num_workers,
)
def process_clips(
self,
clips: Sequence[data.Clip],
) -> List[BatDetect2Prediction]:
return run_batch_inference(
self.model,
clips,
targets=self.targets,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
config=self.config,
)
@classmethod
def from_config(cls, config: BatDetect2Config):
targets = build_targets(config=config.targets)
audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
)
postprocessor = build_postprocessor(
preprocessor,
config=config.postprocess,
)
evaluator = build_evaluator(config=config.evaluation, targets=targets)
# NOTE: It is better to give the model its own preprocessor and
# postprocessor instances, as these may be moved to another device.
model = build_model(
config=config.model,
targets=targets,
preprocessor=build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
),
postprocessor=build_postprocessor(
preprocessor,
config=config.postprocess,
),
)
return cls(
config=config,
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,
postprocessor=postprocessor,
evaluator=evaluator,
model=model,
)
@classmethod
def from_checkpoint(
cls,
path: data.PathLike,
config: Optional[BatDetect2Config] = None,
):
model, stored_config = load_model_from_checkpoint(path)
config = config or stored_config
targets = build_targets(config=config.targets)
audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
)
postprocessor = build_postprocessor(
preprocessor,
config=config.postprocess,
)
evaluator = build_evaluator(config=config.evaluation, targets=targets)
return cls(
config=config,
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,
postprocessor=postprocessor,
evaluator=evaluator,
model=model,
)
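
The new `BatDetect2API` class bundles the config, targets, audio loader, pre and post-processors, evaluator, and model behind a single object. A minimal usage sketch; the checkpoint and audio paths are illustrative and not taken from this diff:

```python
# Hedged sketch: restore the API from a checkpoint and run it over recordings.
from batdetect2.api_v2 import BatDetect2API

api = BatDetect2API.from_checkpoint("outputs/checkpoints/best.ckpt")

# Single file: returns a BatDetect2Prediction with the clip and raw predictions.
prediction = api.process_file("example_data/audio/recording.wav")
print(len(prediction.predictions))

# Whole directory of audio files, processed in batches.
results = api.process_directory("example_data/audio")
```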

View File

@@ -0,0 +1,16 @@
from batdetect2.audio.clips import ClipConfig, build_clipper
from batdetect2.audio.loader import (
TARGET_SAMPLERATE_HZ,
AudioConfig,
SoundEventAudioLoader,
build_audio_loader,
)
__all__ = [
"TARGET_SAMPLERATE_HZ",
"AudioConfig",
"SoundEventAudioLoader",
"build_audio_loader",
"ClipConfig",
"build_clipper",
]

View File

@@ -0,0 +1,264 @@
from typing import Annotated, List, Literal, Optional, Union
import numpy as np
from loguru import logger
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds, intervals_overlap
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipperProtocol
DEFAULT_TRAIN_CLIP_DURATION = 0.256
DEFAULT_MAX_EMPTY_CLIP = 0.1
__all__ = [
"build_clipper",
"ClipConfig",
]
clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")
class RandomClipConfig(BaseConfig):
name: Literal["random_subclip"] = "random_subclip"
duration: float = DEFAULT_TRAIN_CLIP_DURATION
random: bool = True
max_empty: float = DEFAULT_MAX_EMPTY_CLIP
min_sound_event_overlap: float = 0
class RandomClip:
def __init__(
self,
duration: float = 0.5,
max_empty: float = 0.2,
random: bool = True,
min_sound_event_overlap: float = 0,
):
super().__init__()
self.duration = duration
self.random = random
self.max_empty = max_empty
self.min_sound_event_overlap = min_sound_event_overlap
def __call__(
self,
clip_annotation: data.ClipAnnotation,
) -> data.ClipAnnotation:
subclip = self.get_subclip(clip_annotation.clip)
sound_events = select_sound_event_annotations(
clip_annotation,
subclip,
min_overlap=self.min_sound_event_overlap,
)
return clip_annotation.model_copy(
update=dict(
clip=subclip,
sound_events=sound_events,
)
)
def get_subclip(self, clip: data.Clip) -> data.Clip:
return select_random_subclip(
clip,
random=self.random,
duration=self.duration,
max_empty=self.max_empty,
)
@clipper_registry.register(RandomClipConfig)
@staticmethod
def from_config(config: RandomClipConfig):
return RandomClip(
duration=config.duration,
max_empty=config.max_empty,
min_sound_event_overlap=config.min_sound_event_overlap,
)
def get_subclip_annotation(
clip_annotation: data.ClipAnnotation,
random: bool = True,
duration: float = 0.5,
max_empty: float = 0.2,
min_sound_event_overlap: float = 0,
) -> data.ClipAnnotation:
clip = clip_annotation.clip
subclip = select_random_subclip(
clip,
random=random,
duration=duration,
max_empty=max_empty,
)
sound_events = select_sound_event_annotations(
clip_annotation,
subclip,
min_overlap=min_sound_event_overlap,
)
return clip_annotation.model_copy(
update=dict(
clip=subclip,
sound_events=sound_events,
)
)
def select_random_subclip(
clip: data.Clip,
random: bool = True,
duration: float = 0.5,
max_empty: float = 0.2,
) -> data.Clip:
start_time = clip.start_time
end_time = clip.end_time
if duration > clip.duration + max_empty or not random:
return clip.model_copy(
update=dict(
start_time=start_time,
end_time=start_time + duration,
)
)
random_start_time = np.random.uniform(
low=start_time,
high=end_time + max_empty - duration,
)
return clip.model_copy(
update=dict(
start_time=random_start_time,
end_time=random_start_time + duration,
)
)
def select_sound_event_annotations(
clip_annotation: data.ClipAnnotation,
subclip: data.Clip,
min_overlap: float = 0,
) -> List[data.SoundEventAnnotation]:
selected = []
start_time = subclip.start_time
end_time = subclip.end_time
for sound_event_annotation in clip_annotation.sound_events:
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
continue
geom_start_time, _, geom_end_time, _ = compute_bounds(geometry)
if not intervals_overlap(
(start_time, end_time),
(geom_start_time, geom_end_time),
min_absolute_overlap=min_overlap,
):
continue
selected.append(sound_event_annotation)
return selected
class PaddedClipConfig(BaseConfig):
name: Literal["whole_audio_padded"] = "whole_audio_padded"
chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION
class PaddedClip:
def __init__(self, chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION):
self.chunk_size = chunk_size
def __call__(
self,
clip_annotation: data.ClipAnnotation,
) -> data.ClipAnnotation:
clip = clip_annotation.clip
clip = self.get_subclip(clip)
return clip_annotation.model_copy(update=dict(clip=clip))
def get_subclip(self, clip: data.Clip) -> data.Clip:
duration = clip.duration
target_duration = float(
self.chunk_size * np.ceil(duration / self.chunk_size)
)
clip = clip.model_copy(
update=dict(
end_time=clip.start_time + target_duration,
)
)
return clip
@clipper_registry.register(PaddedClipConfig)
@staticmethod
def from_config(config: PaddedClipConfig):
return PaddedClip(chunk_size=config.chunk_size)
class FixedDurationClipConfig(BaseConfig):
name: Literal["fixed_duration"] = "fixed_duration"
duration: float = DEFAULT_TRAIN_CLIP_DURATION
class FixedDurationClip:
def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
self.duration = duration
def __call__(
self,
clip_annotation: data.ClipAnnotation,
) -> data.ClipAnnotation:
clip = self.get_subclip(clip_annotation.clip)
sound_events = select_sound_event_annotations(
clip_annotation,
clip,
min_overlap=0,
)
return clip_annotation.model_copy(
update=dict(
clip=clip,
sound_events=sound_events,
)
)
def get_subclip(self, clip: data.Clip) -> data.Clip:
return clip.model_copy(
update=dict(
end_time=clip.start_time + self.duration,
)
)
@clipper_registry.register(FixedDurationClipConfig)
@staticmethod
def from_config(config: FixedDurationClipConfig):
return FixedDurationClip(duration=config.duration)
ClipConfig = Annotated[
Union[
RandomClipConfig,
PaddedClipConfig,
FixedDurationClipConfig,
],
Field(discriminator="name"),
]
def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
config = config or RandomClipConfig()
logger.opt(lazy=True).debug(
"Building clipper with config: \n{}",
lambda: config.to_yaml_string(),
)
return clipper_registry.build(config)
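
The clipper registry above maps the `random_subclip`, `whole_audio_padded`, and `fixed_duration` configs to clip-selection strategies. A small sketch of building the same random-subclip strategy that the example training config uses:

```python
# Hedged sketch: build a random-subclip clipper (duration taken from the
# example config; the ClipAnnotation input is assumed to come from a dataset).
from batdetect2.audio.clips import RandomClipConfig, build_clipper

clipper = build_clipper(RandomClipConfig(duration=0.256))
# subclip_annotation = clipper(clip_annotation)  # clip_annotation: data.ClipAnnotation
```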

View File

@@ -0,0 +1,295 @@
from typing import Optional
import numpy as np
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError
from batdetect2.core import BaseConfig
from batdetect2.typing import AudioLoader
__all__ = [
"SoundEventAudioLoader",
"build_audio_loader",
"load_file_audio",
"load_recording_audio",
"load_clip_audio",
"resample_audio",
]
TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""
class ResampleConfig(BaseConfig):
"""Configuration for audio resampling.
Attributes
----------
enabled : bool, default=True
Whether to resample the loaded audio to the loader's target sample rate.
method : str, default="poly"
The resampling algorithm to use. Options:
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
Generally fast.
- "fourier": Resampling via Fourier method using
`scipy.signal.resample`. May handle non-integer
resampling factors differently.
"""
enabled: bool = True
method: str = "poly"
class AudioConfig(BaseConfig):
"""Configuration for loading and initial audio preprocessing."""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
resample: ResampleConfig = Field(default_factory=ResampleConfig)
def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
"""Factory function to create an AudioLoader based on configuration."""
config = config or AudioConfig()
return SoundEventAudioLoader(
samplerate=config.samplerate,
config=config.resample,
)
class SoundEventAudioLoader(AudioLoader):
"""Concrete implementation of the `AudioLoader`."""
def __init__(
self,
samplerate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ResampleConfig] = None,
):
self.samplerate = samplerate
self.config = config or ResampleConfig()
def load_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess audio directly from a file path."""
return load_file_audio(
path,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object."""
return load_recording_audio(
recording,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object."""
return load_clip_audio(
clip,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_file_audio(
path: data.PathLike,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess audio from a file path using specified config."""
try:
recording = data.Recording.from_file(path)
except LibsndfileError as e:
raise FileNotFoundError(
f"Could not load the recording at path: {path}. Error: {e}"
) from e
return load_recording_audio(
recording,
samplerate=samplerate,
config=config,
dtype=dtype,
audio_dir=audio_dir,
)
def load_recording_audio(
recording: data.Recording,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess the entire audio content of a recording using config."""
clip = data.Clip(
recording=recording,
start_time=0,
end_time=recording.duration,
)
return load_clip_audio(
clip,
samplerate=samplerate,
config=config,
dtype=dtype,
audio_dir=audio_dir,
)
def load_clip_audio(
clip: data.Clip,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess a specific audio clip segment based on config."""
try:
wav = (
audio.load_clip(clip, audio_dir=audio_dir)
.sel(channel=0)
.astype(dtype)
)
except LibsndfileError as e:
raise FileNotFoundError(
f"Could not load the recording at path: {clip.recording.path}. "
f"Error: {e}"
) from e
if not config or not config.enabled or samplerate is None:
return wav.data.astype(dtype)
sr = int(1 / wav.time.attrs["step"])
return resample_audio(
wav.data,
sr=sr,
samplerate=samplerate,
method=config.method,
)
def resample_audio(
wav: np.ndarray,
sr: int,
samplerate: int = TARGET_SAMPLERATE_HZ,
method: str = "poly",
) -> np.ndarray:
"""Resample an audio waveform DataArray to a target sample rate."""
if sr == samplerate:
return wav
if method == "poly":
return resample_audio_poly(
wav,
sr_orig=sr,
sr_new=samplerate,
)
elif method == "fourier":
return resample_audio_fourier(
wav,
sr_orig=sr,
sr_new=samplerate,
)
else:
raise NotImplementedError(
f"Resampling method '{method}' not implemented"
)
def resample_audio_poly(
array: np.ndarray,
sr_orig: int,
sr_new: int,
axis: int = -1,
) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample_poly`.
This method is often preferred for signals when the ratio of new
to old sample rates can be expressed as a rational number. It uses
polyphase filtering.
Parameters
----------
array : np.ndarray
The input array to resample.
sr_orig : int
The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1
The axis of `array` along which to resample.
Returns
-------
np.ndarray
The array resampled to the target sample rate.
Raises
------
ValueError
If sample rates are not positive.
"""
gcd = np.gcd(sr_orig, sr_new)
return resample_poly(
array,
sr_new // gcd,
sr_orig // gcd,
axis=axis,
)
def resample_audio_fourier(
array: np.ndarray,
sr_orig: int,
sr_new: int,
axis: int = -1,
) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample`.
This method uses FFTs to resample the signal.
Parameters
----------
array : np.ndarray
The input array to resample.
sr_orig : int
The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1
The axis of `array` along which to resample.
Returns
-------
np.ndarray
The array resampled to the target sample rate.
Raises
------
ValueError
If the computed number of output samples is negative.
"""
ratio = sr_new / sr_orig
return resample( # type: ignore
array,
int(array.shape[axis] * ratio),
axis=axis,
)
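
The loader module wraps soundevent's audio loading with optional resampling. A short sketch of constructing a loader and reading a file as a mono waveform; the path is illustrative:

```python
# Hedged sketch: build an audio loader at the default 256 kHz target rate.
from batdetect2.audio import AudioConfig, build_audio_loader

loader = build_audio_loader(AudioConfig(samplerate=256_000))
wav = loader.load_file("example_data/audio/recording.wav")
print(wav.shape, wav.dtype)
```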

View File

@@ -1,7 +1,7 @@
from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect
from batdetect2.cli.data import data
from batdetect2.cli.preprocess import preprocess
from batdetect2.cli.evaluate import evaluate_command
from batdetect2.cli.train import train_command
__all__ = [
@@ -9,7 +9,7 @@ __all__ = [
"detect",
"data",
"train_command",
"preprocess",
"evaluate_command",
]

View File

@@ -1,10 +1,15 @@
import os
import click
from batdetect2 import api
from batdetect2.cli.base import cli
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import ProcessingConfiguration
from batdetect2.utils.detector_utils import save_results_to_file
DEFAULT_MODEL_PATH = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"models",
"checkpoints",
"Net2DFast_UK_same.pth.tar",
)
@cli.command()
@@ -74,6 +79,9 @@ def detect(
Input files should be short in duration e.g. < 30 seconds.
"""
from batdetect2 import api
from batdetect2.utils.detector_utils import save_results_to_file
click.echo(f"Loading model: {args['model_path']}")
model, params = api.load_model(args["model_path"])
@@ -123,7 +131,7 @@ def detect(
click.echo(f" {err}")
def print_config(config: ProcessingConfiguration):
def print_config(config):
"""Print the processing configuration."""
click.echo("\nProcessing Configuration:")
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")

View File

@@ -4,7 +4,6 @@ from typing import Optional
import click
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
__all__ = ["data"]
@@ -33,6 +32,8 @@ def summary(
field: Optional[str] = None,
base_dir: Optional[Path] = None,
):
from batdetect2.data import load_dataset_from_config
base_dir = base_dir or Path.cwd()
dataset = load_dataset_from_config(
dataset_config,

View File

@@ -0,0 +1,78 @@
import sys
from pathlib import Path
from typing import Optional
import click
from loguru import logger
from batdetect2.cli.base import cli
__all__ = ["evaluate_command"]
DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--config", "config_path", type=click.Path())
@click.option("--base-dir", type=click.Path(), default=Path.cwd())
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
@click.option("--experiment-name", type=str)
@click.option("--run-name", type=str)
@click.option("--workers", "num_workers", type=int)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def evaluate_command(
model_path: Path,
test_dataset: Path,
base_dir: Path,
config_path: Optional[Path],
output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: Optional[int] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
verbose: int = 0,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import load_full_config
from batdetect2.data import load_dataset_from_config
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating evaluation process...")
test_annotations = load_dataset_from_config(
test_dataset,
base_dir=base_dir,
)
logger.debug(
"Loaded {num_annotations} test examples",
num_annotations=len(test_annotations),
)
config = None
if config_path is not None:
config = load_full_config(config_path)
api = BatDetect2API.from_checkpoint(model_path, config=config)
api.evaluate(
test_annotations,
num_workers=num_workers,
output_dir=output_dir,
experiment_name=experiment_name,
run_name=run_name,
)
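
The new `evaluate` command is a thin wrapper around the API's `evaluate` method. The programmatic equivalent, sketched with illustrative paths:

```python
# Hedged sketch: programmatic equivalent of `batdetect2 evaluate`.
from batdetect2.api_v2 import BatDetect2API
from batdetect2.data import load_dataset_from_config

test_annotations = load_dataset_from_config("test_dataset.yaml")
api = BatDetect2API.from_checkpoint("model.ckpt")
api.evaluate(test_annotations, output_dir="outputs/evaluation")
```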

View File

@@ -1,154 +0,0 @@
import sys
from pathlib import Path
from typing import Optional
import click
import yaml
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.train.preprocess import (
TrainPreprocessConfig,
load_train_preprocessing_config,
preprocess_dataset,
)
__all__ = ["preprocess"]
@cli.command()
@click.argument(
"dataset_config",
type=click.Path(exists=True),
)
@click.argument(
"output",
type=click.Path(),
)
@click.option(
"--dataset-field",
type=str,
help=(
"Specifies the key to access the dataset information within the "
"dataset configuration file, if the information is nested inside a "
"dictionary. If the dataset information is at the top level of the "
"config file, you don't need to specify this."
),
)
@click.option(
"--base-dir",
type=click.Path(exists=True),
help=(
"The main directory where your audio recordings and annotation "
"files are stored. This helps the program find your data, "
"especially if the paths in your dataset configuration file "
"are relative."
),
)
@click.option(
"--config",
type=click.Path(exists=True),
help=(
"Path to the configuration file. This file tells "
"the program how to prepare your audio data before training, such "
"as resampling or applying filters."
),
)
@click.option(
"--config-field",
type=str,
help=(
"If the preprocessing settings are inside a nested dictionary "
"within the preprocessing configuration file, specify the key "
"here to access them. If the preprocessing settings are at the "
"top level, you don't need to specify this."
),
)
@click.option(
"--force",
is_flag=True,
help=(
"If a preprocessed file already exists, this option tells the "
"program to overwrite it with the new preprocessed data. Use "
"this if you want to re-do the preprocessing even if the files "
"already exist."
),
)
@click.option(
"--num-workers",
type=int,
help=(
"The maximum number of computer cores to use when processing "
"your audio data. Using more cores can speed up the preprocessing, "
"but don't use more than your computer has available. By default, "
"the program will use all available cores."
),
)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def preprocess(
dataset_config: Path,
output: Path,
base_dir: Optional[Path] = None,
config: Optional[Path] = None,
config_field: Optional[str] = None,
force: bool = False,
num_workers: Optional[int] = None,
dataset_field: Optional[str] = None,
verbose: int = 0,
):
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Starting preprocessing.")
output = Path(output)
logger.info("Will save outputs to {output}", output=output)
base_dir = base_dir or Path.cwd()
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
if config:
logger.info(
"Loading preprocessing config from: {config}", config=config
)
conf = (
load_train_preprocessing_config(config, field=config_field)
if config is not None
else TrainPreprocessConfig()
)
logger.debug(
"Preprocessing config:\n{conf}",
conf=yaml.dump(conf.model_dump()),
)
dataset = load_dataset_from_config(
dataset_config,
field=dataset_field,
base_dir=base_dir,
)
logger.info(
"Loaded {num_examples} annotated clips from the configured dataset",
num_examples=len(dataset),
)
preprocess_dataset(
dataset,
conf,
output=output,
force=force,
max_workers=num_workers,
)

View File

@@ -6,24 +6,24 @@ import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.train import (
FullTrainingConfig,
load_full_training_config,
train,
)
from batdetect2.train.dataset import list_preprocessed_files
__all__ = ["train_command"]
@cli.command(name="train")
@click.argument("train_dir", type=click.Path(exists=True))
@click.option("--val-dir", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True))
@click.argument("train_dataset", type=click.Path(exists=True))
@click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model", "model_path", type=click.Path(exists=True))
@click.option("--targets", "targets_config", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int)
@click.option("--val-workers", type=int)
@click.option("--experiment-name", type=str)
@click.option("--run-name", type=str)
@click.option("--seed", type=int)
@click.option(
"-v",
"--verbose",
@@ -31,15 +31,29 @@ __all__ = ["train_command"]
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def train_command(
train_dir: Path,
val_dir: Optional[Path] = None,
train_dataset: Path,
val_dataset: Optional[Path] = None,
model_path: Optional[Path] = None,
ckpt_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
config: Optional[Path] = None,
targets_config: Optional[Path] = None,
config_field: Optional[str] = None,
seed: Optional[int] = None,
train_workers: int = 0,
val_workers: int = 0,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
verbose: int = 0,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import (
BatDetect2Config,
load_full_config,
)
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
logger.remove()
if verbose == 0:
log_level = "WARNING"
@@ -48,41 +62,53 @@ def train_command(
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating training process...")
logger.info("Loading training configuration...")
logger.info("Loading configuration...")
conf = (
load_full_training_config(config, field=config_field)
load_full_config(config, field=config_field)
if config is not None
else FullTrainingConfig()
else BatDetect2Config()
)
logger.info("Scanning for training and validation data...")
train_examples = list_preprocessed_files(train_dir)
if targets_config is not None:
logger.info("Loading targets configuration...")
conf = conf.model_copy(
update=dict(targets=load_target_config(targets_config))
)
logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config(train_dataset)
logger.debug(
"Found {num_files} training examples in {path}",
num_files=len(train_examples),
path=train_dir,
"Loaded {num_annotations} training examples",
num_annotations=len(train_annotations),
)
val_examples = None
if val_dir is not None:
val_examples = list_preprocessed_files(val_dir)
val_annotations = None
if val_dataset is not None:
val_annotations = load_dataset_from_config(val_dataset)
logger.debug(
"Found {num_files} validation examples in {path}",
num_files=len(val_examples),
path=val_dir,
"Loaded {num_annotations} validation examples",
num_annotations=len(val_annotations),
)
else:
logger.debug("No validation directory provided.")
logger.info("Configuration and data loaded. Starting training...")
train(
train_examples=train_examples,
val_examples=val_examples,
config=conf,
model_path=model_path,
if model_path is None:
api = BatDetect2API.from_config(conf)
else:
api = BatDetect2API.from_checkpoint(model_path)
return api.train(
train_annotations=train_annotations,
val_annotations=val_annotations,
train_workers=train_workers,
val_workers=val_workers,
checkpoint_dir=ckpt_dir,
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
seed=seed,
)
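
Like the evaluate command, the reworked train command now delegates to the API object. A sketch of the same flow outside the CLI, with illustrative dataset paths and a default configuration:

```python
# Hedged sketch: programmatic equivalent of `batdetect2 train`.
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import BatDetect2Config
from batdetect2.data import load_dataset_from_config

train_annotations = load_dataset_from_config("example_data/dataset.yaml")
api = BatDetect2API.from_config(BatDetect2Config())
api.train(train_annotations=train_annotations, val_annotations=None)
```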

View File

@@ -11,7 +11,6 @@ from soundevent import data
from soundevent.geometry import compute_bounds
from soundevent.types import ClassMapper
from batdetect2.targets.terms import get_term_from_key
from batdetect2.types import (
Annotation,
AudioLoaderAnnotationGroup,
@@ -173,18 +172,9 @@ def annotation_to_sound_event_annotation(
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
sound_event=sound_event,
tags=[
data.Tag(
term=get_term_from_key(label_key),
value=annotation["class"],
),
data.Tag(
term=get_term_from_key(event_key),
value=annotation["event"],
),
data.Tag(
term=get_term_from_key(individual_key),
value=str(annotation["individual"]),
),
data.Tag(key=label_key, value=annotation["class"]),
data.Tag(key=event_key, value=annotation["event"]),
data.Tag(key=individual_key, value=str(annotation["individual"])),
],
)
@@ -219,17 +209,11 @@ def annotation_to_sound_event_prediction(
tags=[
data.PredictedTag(
score=annotation["class_prob"],
tag=data.Tag(
term=get_term_from_key(label_key),
value=annotation["class"],
),
tag=data.Tag(key=label_key, value=annotation["class"]),
),
data.PredictedTag(
score=annotation["det_prob"],
tag=data.Tag(
term=get_term_from_key(event_key),
value=annotation["event"],
),
tag=data.Tag(key=event_key, value=annotation["event"]),
),
],
)

src/batdetect2/config.py Normal file
View File

@@ -0,0 +1,42 @@
from typing import Literal, Optional
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.audio import AudioConfig
from batdetect2.core import BaseConfig
from batdetect2.core.configs import load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.inference.config import InferenceConfig
from batdetect2.models.config import BackboneConfig
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.targets.config import TargetConfig
from batdetect2.train.config import TrainingConfig
__all__ = [
"BatDetect2Config",
"load_full_config",
]
class BatDetect2Config(BaseConfig):
config_version: Literal["v1"] = "v1"
train: TrainingConfig = Field(default_factory=TrainingConfig)
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
audio: AudioConfig = Field(default_factory=AudioConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig)
def load_full_config(
path: PathLike,
field: Optional[str] = None,
) -> BatDetect2Config:
return load_config(path, schema=BatDetect2Config, field=field)

View File

@@ -0,0 +1,8 @@
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry
__all__ = [
"BaseConfig",
"load_config",
"Registry",
]

View File

@@ -0,0 +1,95 @@
from typing import Optional
import numpy as np
import torch
import xarray as xr
def spec_to_xarray(
spec: np.ndarray,
start_time: float,
end_time: float,
min_freq: float,
max_freq: float,
) -> xr.DataArray:
if spec.ndim != 2:
raise ValueError(
"Input numpy spectrogram array should be 2-dimensional"
)
height, width = spec.shape
return xr.DataArray(
data=spec,
dims=["frequency", "time"],
coords={
"frequency": np.linspace(
min_freq,
max_freq,
height,
endpoint=False,
),
"time": np.linspace(
start_time,
end_time,
width,
endpoint=False,
),
},
)
def extend_width(
tensor: torch.Tensor,
extra: int,
axis: int = -1,
value: float = 0,
) -> torch.Tensor:
dims = len(tensor.shape)
axis = dims - axis % dims - 1
pad = [0 for _ in range(2 * dims)]
pad[2 * axis + 1] = extra
return torch.nn.functional.pad(
tensor,
pad,
mode="constant",
value=value,
)
def adjust_width(
tensor: torch.Tensor,
width: int,
axis: int = -1,
value: float = 0,
) -> torch.Tensor:
dims = len(tensor.shape)
axis = axis % dims
current_width = tensor.shape[axis]
if current_width == width:
return tensor
if current_width < width:
return extend_width(
tensor,
extra=width - current_width,
axis=axis,
value=value,
)
slices = [
slice(None, None) if index != axis else slice(None, width)
for index in range(dims)
]
return tensor[tuple(slices)]
def slice_tensor(
tensor: torch.Tensor,
start: Optional[int] = None,
end: Optional[int] = None,
dim: int = -1,
) -> torch.Tensor:
slices = [slice(None)] * tensor.ndim
slices[dim] = slice(start, end)
return tensor[tuple(slices)]
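
A quick sketch of the `adjust_width` helper above, which pads or crops a tensor's last (time) axis to a fixed number of bins. The import path is an assumption, as the diff does not show this file's name:

```python
# Hedged sketch of adjust_width; the module path below is not confirmed by the diff.
import torch

from batdetect2.core.arrays import adjust_width

spec = torch.rand(1, 128, 100)           # (channels, freq bins, time bins)
padded = adjust_width(spec, width=128)   # zero-pads the time axis to 128 bins
cropped = adjust_width(spec, width=64)   # crops the time axis to 64 bins
assert padded.shape[-1] == 128 and cropped.shape[-1] == 64
```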

View File

@@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
and serialization capabilities.
"""
model_config = ConfigDict(extra="ignore")
model_config = ConfigDict(extra="forbid")
def to_yaml_string(
self,
@@ -53,6 +53,7 @@ class BaseConfig(BaseModel):
"""
return yaml.dump(
self.model_dump(
mode="json",
exclude_none=exclude_none,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,

View File

@@ -0,0 +1,98 @@
import sys
from typing import Callable, Dict, Generic, Tuple, Type, TypeVar
from pydantic import BaseModel
if sys.version_info >= (3, 10):
from typing import Concatenate, ParamSpec
else:
from typing_extensions import Concatenate, ParamSpec
__all__ = [
"Registry",
"SimpleRegistry",
]
T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
T_Type = TypeVar("T_Type", covariant=True)
P_Type = ParamSpec("P_Type")
T = TypeVar("T")
class SimpleRegistry(Generic[T]):
def __init__(self, name: str):
self._name = name
self._registry = {}
def register(self, name: str):
def decorator(obj: T) -> T:
self._registry[name] = obj
return obj
return decorator
def get(self, name: str) -> T:
return self._registry[name]
def has(self, name: str) -> bool:
return name in self._registry
class Registry(Generic[T_Type, P_Type]):
"""A generic class to create and manage a registry of items."""
def __init__(self, name: str):
self._name = name
self._registry: Dict[
str, Callable[Concatenate[..., P_Type], T_Type]
] = {}
self._config_types: Dict[str, Type[BaseModel]] = {}
def register(
self,
config_cls: Type[T_Config],
):
fields = config_cls.model_fields
if "name" not in fields:
raise ValueError("Configuration object must have a 'name' field.")
name = fields["name"].default
self._config_types[name] = config_cls
if not isinstance(name, str):
raise ValueError("'name' field must be a string literal.")
def decorator(
func: Callable[Concatenate[T_Config, P_Type], T_Type],
):
self._registry[name] = func
return func
return decorator
def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
return tuple(self._config_types.values())
def build(
self,
config: BaseModel,
*args: P_Type.args,
**kwargs: P_Type.kwargs,
) -> T_Type:
"""Builds a logic instance from a config object."""
name = getattr(config, "name") # noqa: B009
if name is None:
raise ValueError("Config does not have a name field")
if name not in self._registry:
raise NotImplementedError(
f"No {self._name} with name '{name}' is registered."
)
return self._registry[name](config, *args, **kwargs)
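
The `Registry` above is the pattern used throughout this changeset (clippers, sound-event conditions): a pydantic config class with a literal `name` field is registered alongside a builder, and instances are then built from config objects. A self-contained sketch of the pattern with made-up names:

```python
# Hedged sketch of the Registry pattern; GreeterConfig/Greeter are illustrative.
from typing import Literal

from pydantic import BaseModel

from batdetect2.core import Registry


class GreeterConfig(BaseModel):
    name: Literal["greeter"] = "greeter"
    greeting: str = "hello"


class Greeter:
    def __init__(self, greeting: str):
        self.greeting = greeting


registry = Registry("greeter")


@registry.register(GreeterConfig)
def build_greeter(config: GreeterConfig) -> Greeter:
    return Greeter(greeting=config.greeting)


greeter = registry.build(GreeterConfig(greeting="hi"))
assert greeter.greeting == "hi"
```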

View File

@@ -14,8 +14,9 @@ format-specific loading function to retrieve the annotations as a standard
"""
from pathlib import Path
from typing import Optional, Union
from typing import Annotated, Optional, Union
from pydantic import Field
from soundevent import data
from batdetect2.data.annotations.aoef import (
@@ -42,10 +43,13 @@ __all__ = [
]
AnnotationFormats = Union[
BatDetect2MergedAnnotations,
BatDetect2FilesAnnotations,
AOEFAnnotations,
AnnotationFormats = Annotated[
Union[
BatDetect2MergedAnnotations,
BatDetect2FilesAnnotations,
AOEFAnnotations,
],
Field(discriminator="format"),
]
"""Type Alias representing all supported data source configurations.

View File

@@ -18,7 +18,7 @@ from uuid import uuid5
from pydantic import Field
from soundevent import data, io
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [

View File

@@ -33,7 +33,7 @@ from loguru import logger
from pydantic import Field, ValidationError
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.legacy import (
FileAnnotation,
file_annotation_to_clip,
@@ -301,7 +301,8 @@ def load_batdetect2_merged_annotated_dataset(
for ann in content:
try:
ann = FileAnnotation.model_validate(ann)
except ValueError:
except ValueError as err:
logger.warning(f"Invalid annotation file: {err}")
continue
if (
@@ -309,14 +310,17 @@
and dataset.filter.only_annotated
and not ann.annotated
):
logger.debug(f"Skipping incomplete annotation {ann.id}")
continue
if dataset.filter and dataset.filter.exclude_issues and ann.issues:
logger.debug(f"Skipping annotation with issues {ann.id}")
continue
try:
clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
except FileNotFoundError:
except FileNotFoundError as err:
logger.warning(f"Error loading annotations: {err}")
continue
annotations.append(file_annotation_to_clip_annotation(ann, clip))

View File

@@ -8,8 +8,6 @@ from typing import Callable, List, Optional, Union
from pydantic import BaseModel, Field
from soundevent import data
from batdetect2.targets import get_term_from_key
PathLike = Union[Path, str, os.PathLike]
__all__ = []
@@ -91,18 +89,9 @@ def annotation_to_sound_event(
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
sound_event=sound_event,
tags=[
data.Tag(
term=get_term_from_key(label_key),
value=annotation.label,
),
data.Tag(
term=get_term_from_key(event_key),
value=annotation.event,
),
data.Tag(
term=get_term_from_key(individual_key),
value=str(annotation.individual),
),
data.Tag(key=label_key, value=annotation.label),
data.Tag(key=event_key, value=annotation.event),
data.Tag(key=individual_key, value=str(annotation.individual)),
],
)
@@ -123,12 +112,7 @@ def file_annotation_to_clip(
recording = data.Recording.from_file(
full_path,
time_expansion=file_annotation.time_exp,
tags=[
data.Tag(
term=get_term_from_key(label_key),
value=file_annotation.label,
)
],
tags=[data.Tag(key=label_key, value=file_annotation.label)],
)
return data.Clip(
@@ -155,11 +139,7 @@ def file_annotation_to_clip_annotation(
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
clip=clip,
notes=notes,
tags=[
data.Tag(
term=get_term_from_key(label_key), value=file_annotation.label
)
],
tags=[data.Tag(key=label_key, value=file_annotation.label)],
sound_events=[
annotation_to_sound_event(
annotation,

View File

@@ -1,6 +1,6 @@
from pathlib import Path
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
__all__ = [
"AnnotatedDataset",

View File

@@ -0,0 +1,287 @@
from collections.abc import Callable
from typing import Annotated, List, Literal, Sequence, Union
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
conditions: Registry[SoundEventCondition, []] = Registry("condition")
class HasTagConfig(BaseConfig):
name: Literal["has_tag"] = "has_tag"
tag: data.Tag
class HasTag:
def __init__(self, tag: data.Tag):
self.tag = tag
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return self.tag in sound_event_annotation.tags
@conditions.register(HasTagConfig)
@staticmethod
def from_config(config: HasTagConfig):
return HasTag(tag=config.tag)
class HasAllTagsConfig(BaseConfig):
name: Literal["has_all_tags"] = "has_all_tags"
tags: List[data.Tag]
class HasAllTags:
def __init__(self, tags: List[data.Tag]):
if not tags:
raise ValueError("Need to specify at least one tag")
self.tags = set(tags)
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return self.tags.issubset(sound_event_annotation.tags)
@conditions.register(HasAllTagsConfig)
@staticmethod
def from_config(config: HasAllTagsConfig):
return HasAllTags(tags=config.tags)
class HasAnyTagConfig(BaseConfig):
name: Literal["has_any_tag"] = "has_any_tag"
tags: List[data.Tag]
class HasAnyTag:
def __init__(self, tags: List[data.Tag]):
if not tags:
raise ValueError("Need to specify at least one tag")
self.tags = set(tags)
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return bool(self.tags.intersection(sound_event_annotation.tags))
@conditions.register(HasAnyTagConfig)
@staticmethod
def from_config(config: HasAnyTagConfig):
return HasAnyTag(tags=config.tags)
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
class DurationConfig(BaseConfig):
name: Literal["duration"] = "duration"
operator: Operator
seconds: float
def _build_comparator(
operator: Operator, value: float
) -> Callable[[float], bool]:
if operator == "gt":
return lambda x: x > value
if operator == "gte":
return lambda x: x >= value
if operator == "lt":
return lambda x: x < value
if operator == "lte":
return lambda x: x <= value
if operator == "eq":
return lambda x: x == value
raise ValueError(f"Invalid operator {operator}")
class Duration:
def __init__(self, operator: Operator, seconds: float):
self.operator = operator
self.seconds = seconds
self._comparator = _build_comparator(self.operator, self.seconds)
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> bool:
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
start_time, _, end_time, _ = compute_bounds(geometry)
duration = end_time - start_time
return self._comparator(duration)
@conditions.register(DurationConfig)
@staticmethod
def from_config(config: DurationConfig):
return Duration(operator=config.operator, seconds=config.seconds)
class FrequencyConfig(BaseConfig):
name: Literal["frequency"] = "frequency"
boundary: Literal["low", "high"]
operator: Operator
hertz: float
class Frequency:
def __init__(
self,
operator: Operator,
boundary: Literal["low", "high"],
hertz: float,
):
self.operator = operator
self.hertz = hertz
self.boundary = boundary
self._comparator = _build_comparator(self.operator, self.hertz)
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> bool:
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
# Automatically false if geometry does not have a frequency range
if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
return False
_, low_freq, _, high_freq = compute_bounds(geometry)
if self.boundary == "low":
return self._comparator(low_freq)
return self._comparator(high_freq)
@conditions.register(FrequencyConfig)
@staticmethod
def from_config(config: FrequencyConfig):
return Frequency(
operator=config.operator,
boundary=config.boundary,
hertz=config.hertz,
)
class AllOfConfig(BaseConfig):
name: Literal["all_of"] = "all_of"
conditions: Sequence["SoundEventConditionConfig"]
class AllOf:
def __init__(self, conditions: List[SoundEventCondition]):
self.conditions = conditions
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return all(c(sound_event_annotation) for c in self.conditions)
@conditions.register(AllOfConfig)
@staticmethod
def from_config(config: AllOfConfig):
conditions = [
build_sound_event_condition(cond) for cond in config.conditions
]
return AllOf(conditions)
class AnyOfConfig(BaseConfig):
name: Literal["any_of"] = "any_of"
conditions: List["SoundEventConditionConfig"]
class AnyOf:
def __init__(self, conditions: List[SoundEventCondition]):
self.conditions = conditions
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return any(c(sound_event_annotation) for c in self.conditions)
@conditions.register(AnyOfConfig)
@staticmethod
def from_config(config: AnyOfConfig):
conditions = [
build_sound_event_condition(cond) for cond in config.conditions
]
return AnyOf(conditions)
class NotConfig(BaseConfig):
name: Literal["not"] = "not"
condition: "SoundEventConditionConfig"
class Not:
def __init__(self, condition: SoundEventCondition):
self.condition = condition
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> bool:
return not self.condition(sound_event_annotation)
@conditions.register(NotConfig)
@staticmethod
def from_config(config: NotConfig):
condition = build_sound_event_condition(config.condition)
return Not(condition)
SoundEventConditionConfig = Annotated[
Union[
HasTagConfig,
HasAllTagsConfig,
HasAnyTagConfig,
DurationConfig,
FrequencyConfig,
AllOfConfig,
AnyOfConfig,
NotConfig,
],
Field(discriminator="name"),
]
def build_sound_event_condition(
config: SoundEventConditionConfig,
) -> SoundEventCondition:
return conditions.build(config)
def filter_clip_annotation(
clip_annotation: data.ClipAnnotation,
condition: SoundEventCondition,
) -> data.ClipAnnotation:
return clip_annotation.model_copy(
update=dict(
sound_events=[
sound_event
for sound_event in clip_annotation.sound_events
if condition(sound_event)
]
)
)
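A minimal usage sketch of the new condition registry (not part of the diff); the tag keys/values and the clip_annotation object are illustrative assumptions:
from soundevent import data
from batdetect2.data.conditions import (
    AnyOfConfig,
    HasTagConfig,
    build_sound_event_condition,
    filter_clip_annotation,
)
# Keep only sound events tagged with either species (hypothetical tag values).
config = AnyOfConfig(
    conditions=[
        HasTagConfig(tag=data.Tag(key="species", value="Myotis daubentonii")),
        HasTagConfig(tag=data.Tag(key="species", value="Myotis mystacinus")),
    ]
)
condition = build_sound_event_condition(config)
filtered = filter_clip_annotation(clip_annotation, condition)  # clip_annotation: data.ClipAnnotation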

View File

@ -19,18 +19,29 @@ The core components are:
"""
from pathlib import Path
from typing import Annotated, List, Optional
from typing import List, Optional
from loguru import logger
from pydantic import Field
from soundevent import data, io
from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.data.annotations import (
AnnotatedDataset,
AnnotationFormats,
load_annotated_dataset,
)
from batdetect2.data.conditions import (
SoundEventConditionConfig,
build_sound_event_condition,
filter_clip_annotation,
)
from batdetect2.data.transforms import (
ApplyAll,
SoundEventTransformConfig,
build_sound_event_transform,
transform_clip_annotation,
)
from batdetect2.targets.terms import data_source
__all__ = [
@ -52,79 +63,68 @@ sources.
class DatasetConfig(BaseConfig):
"""Configuration model defining the structure of a BatDetect2 dataset.
This class is typically loaded from a YAML file and describes the components
of the dataset, including metadata and a list of data sources.
Attributes
----------
name : str
A descriptive name for the dataset (e.g., "UK_Bats_Project_2024").
description : str
A longer description of the dataset's contents, origin, purpose, etc.
sources : List[AnnotationFormats]
A list defining the different data sources contributing to this
dataset. Each item in the list must conform to one of the Pydantic
models defined in the `AnnotationFormats` type union. The specific
model used for each source is determined by the mandatory `format`
field within the source's configuration, allowing BatDetect2 to use the
correct parser for different annotation styles.
"""
"""Configuration model defining the structure of a BatDetect2 dataset."""
name: str
description: str
sources: List[
Annotated[AnnotationFormats, Field(..., discriminator="format")]
]
sources: List[AnnotationFormats]
sound_event_filter: Optional[SoundEventConditionConfig] = None
sound_event_transforms: List[SoundEventTransformConfig] = Field(
default_factory=list
)
def load_dataset(
dataset: DatasetConfig,
config: DatasetConfig,
base_dir: Optional[Path] = None,
) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig.
Iterates through each data source specified in the `dataset_config`,
delegates the loading and parsing of that source's annotations to
`batdetect2.data.annotations.load_annotated_dataset` (which handles
different data formats), and aggregates all resulting `ClipAnnotation`
objects into a single flat list.
Parameters
----------
dataset_config : DatasetConfig
The configuration object describing the dataset and its sources.
base_dir : Path, optional
An optional base directory path. If provided, relative paths for
metadata files or data directories within the `dataset_config`'s
sources might be resolved relative to this directory. Defaults to None.
Returns
-------
Dataset (List[data.ClipAnnotation])
A flat list containing all loaded `ClipAnnotation` metadata objects
from all specified sources.
Raises
------
Exception
Can raise various exceptions during the delegated loading process
(`load_annotated_dataset`) if files are not found, cannot be parsed
according to the specified format, or other I/O errors occur.
"""
"""Load all clip annotations from the sources defined in a DatasetConfig."""
clip_annotations = []
for source in dataset.sources:
condition = (
build_sound_event_condition(config.sound_event_filter)
if config.sound_event_filter is not None
else None
)
transform = (
ApplyAll(
[
build_sound_event_transform(step)
for step in config.sound_event_transforms
]
)
if config.sound_event_transforms
else None
)
for source in config.sources:
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
logger.debug(
"Loaded {num_examples} from dataset source '{source_name}'",
num_examples=len(annotated_source.clip_annotations),
source_name=source.name,
)
clip_annotations.extend(
insert_source_tag(clip_annotation, source)
for clip_annotation in annotated_source.clip_annotations
)
for clip_annotation in annotated_source.clip_annotations:
clip_annotation = insert_source_tag(clip_annotation, source)
if condition is not None:
clip_annotation = filter_clip_annotation(
clip_annotation,
condition,
)
if transform is not None:
clip_annotation = transform_clip_annotation(
clip_annotation,
transform,
)
clip_annotations.append(clip_annotation)
return clip_annotations
@ -161,7 +161,6 @@ def insert_source_tag(
)
# TODO: add documentation
def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
return load_config(path=path, schema=DatasetConfig, field=field)
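A hedged sketch of the new sound_event_filter / sound_event_transforms hooks on load_dataset; the YAML path, tag values, and the batdetect2.data.datasets module path are assumptions:
from pathlib import Path
from soundevent import data
from batdetect2.data.conditions import HasTagConfig
from batdetect2.data.datasets import load_dataset, load_dataset_config
from batdetect2.data.transforms import SetFrequencyBoundConfig
# Load the dataset description, then drop non-echolocation events and clamp
# low frequencies while loading (tag key/value chosen for illustration).
config = load_dataset_config("dataset.yaml")
config = config.model_copy(
    update=dict(
        sound_event_filter=HasTagConfig(
            tag=data.Tag(key="event", value="Echolocation")
        ),
        sound_event_transforms=[
            SetFrequencyBoundConfig(boundary="low", hertz=10_000)
        ],
    )
)
clip_annotations = load_dataset(config, base_dir=Path("data"))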

View File

@ -4,22 +4,14 @@ from typing import Optional, Tuple
from soundevent import data
from batdetect2.data.datasets import Dataset
from batdetect2.targets.types import TargetProtocol
from batdetect2.typing.targets import TargetProtocol
def iterate_over_sound_events(
dataset: Dataset,
targets: TargetProtocol,
apply_filter: bool = True,
apply_transform: bool = True,
exclude_generic: bool = True,
) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
"""Iterate over sound events in a dataset, applying filtering and
transformations.
This generator function processes sound event annotations from a given
dataset, allowing for optional filtering, transformation, and exclusion of
unclassifiable (generic) events based on the provided target definitions.
"""Iterate over sound events in a dataset.
Parameters
----------
@ -29,18 +21,6 @@ def iterate_over_sound_events(
targets : TargetProtocol
An object implementing the `TargetProtocol`, which provides methods
for filtering, transforming, and encoding sound events.
apply_filter : bool, optional
If True, sound events will be filtered using `targets.filter()`.
Only events for which `targets.filter()` returns True will be yielded.
Defaults to True.
apply_transform : bool, optional
If True, sound events will be transformed using `targets.transform()`
before being yielded. Defaults to True.
exclude_generic : bool, optional
If True, sound events that result in a `None` class name after
`targets.encode()` will be excluded. This is typically used to
filter out events that cannot be mapped to a specific target class.
Defaults to True.
Yields
------
@ -63,17 +43,9 @@ def iterate_over_sound_events(
"""
for clip_annotation in dataset:
for sound_event_annotation in clip_annotation.sound_events:
if apply_filter:
if not targets.filter(sound_event_annotation):
continue
if apply_transform:
sound_event_annotation = targets.transform(
sound_event_annotation
)
class_name = targets.encode_class(sound_event_annotation)
if class_name is None and exclude_generic:
if not targets.filter(sound_event_annotation):
continue
class_name = targets.encode_class(sound_event_annotation)
yield class_name, sound_event_annotation

View File

@ -7,7 +7,7 @@ from batdetect2.data.summary import (
extract_recordings_df,
extract_sound_events_df,
)
from batdetect2.targets.types import TargetProtocol
from batdetect2.typing.targets import TargetProtocol
def split_dataset_by_recordings(

View File

@ -2,7 +2,7 @@ import pandas as pd
from soundevent.geometry import compute_bounds
from batdetect2.data.datasets import Dataset
from batdetect2.targets.types import TargetProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"extract_recordings_df",
@ -100,8 +100,11 @@ def extract_sound_events_df(
class_name = targets.encode_class(sound_event)
if class_name is None and exclude_generic:
continue
if class_name is None:
if exclude_generic:
continue
else:
class_name = targets.detection_class_name
start_time, low_freq, end_time, high_freq = compute_bounds(
sound_event.sound_event.geometry
@ -153,7 +156,7 @@ def compute_class_summary(
sound_events = extract_sound_events_df(
dataset,
targets,
exclude_generic=True,
exclude_generic=False,
exclude_non_target=True,
)
recordings = extract_recordings_df(dataset)

View File

@ -0,0 +1,252 @@
from collections.abc import Callable
from typing import Annotated, Dict, List, Literal, Optional, Union
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.data.conditions import (
SoundEventCondition,
SoundEventConditionConfig,
build_sound_event_condition,
)
SoundEventTransform = Callable[
[data.SoundEventAnnotation],
data.SoundEventAnnotation,
]
transforms: Registry[SoundEventTransform, []] = Registry("transform")
class SetFrequencyBoundConfig(BaseConfig):
name: Literal["set_frequency"] = "set_frequency"
boundary: Literal["low", "high"] = "low"
hertz: float
class SetFrequencyBound:
def __init__(self, hertz: float, boundary: Literal["low", "high"] = "low"):
self.hertz = hertz
self.boundary = boundary
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
sound_event = sound_event_annotation.sound_event
geometry = sound_event.geometry
if geometry is None:
return sound_event_annotation
if not isinstance(geometry, data.BoundingBox):
return sound_event_annotation
start_time, low_freq, end_time, high_freq = geometry.coordinates
if self.boundary == "low":
low_freq = self.hertz
high_freq = max(high_freq, low_freq)
elif self.boundary == "high":
high_freq = self.hertz
low_freq = min(high_freq, low_freq)
geometry = data.BoundingBox(
coordinates=[start_time, low_freq, end_time, high_freq],
)
sound_event = sound_event.model_copy(update=dict(geometry=geometry))
return sound_event_annotation.model_copy(
update=dict(sound_event=sound_event)
)
@transforms.register(SetFrequencyBoundConfig)
@staticmethod
def from_config(config: SetFrequencyBoundConfig):
return SetFrequencyBound(hertz=config.hertz, boundary=config.boundary)
class ApplyIfConfig(BaseConfig):
name: Literal["apply_if"] = "apply_if"
transform: "SoundEventTransformConfig"
condition: SoundEventConditionConfig
class ApplyIf:
def __init__(
self,
condition: SoundEventCondition,
transform: SoundEventTransform,
):
self.condition = condition
self.transform = transform
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
if not self.condition(sound_event_annotation):
return sound_event_annotation
return self.transform(sound_event_annotation)
@transforms.register(ApplyIfConfig)
@staticmethod
def from_config(config: ApplyIfConfig):
transform = build_sound_event_transform(config.transform)
condition = build_sound_event_condition(config.condition)
return ApplyIf(condition=condition, transform=transform)
class ReplaceTagConfig(BaseConfig):
name: Literal["replace_tag"] = "replace_tag"
original: data.Tag
replacement: data.Tag
class ReplaceTag:
def __init__(
self,
original: data.Tag,
replacement: data.Tag,
):
self.original = original
self.replacement = replacement
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
tags = []
for tag in sound_event_annotation.tags:
if tag == self.original:
tags.append(self.replacement)
else:
tags.append(tag)
return sound_event_annotation.model_copy(update=dict(tags=tags))
@transforms.register(ReplaceTagConfig)
@staticmethod
def from_config(config: ReplaceTagConfig):
return ReplaceTag(
original=config.original, replacement=config.replacement
)
class MapTagValueConfig(BaseConfig):
name: Literal["map_tag_value"] = "map_tag_value"
tag_key: str
value_mapping: Dict[str, str]
target_key: Optional[str] = None
class MapTagValue:
def __init__(
self,
tag_key: str,
value_mapping: Dict[str, str],
target_key: Optional[str] = None,
):
self.tag_key = tag_key
self.value_mapping = value_mapping
self.target_key = target_key
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
tags = []
for tag in sound_event_annotation.tags:
if tag.key != self.tag_key:
tags.append(tag)
continue
value = self.value_mapping.get(tag.value)
if value is None:
tags.append(tag)
continue
if self.target_key is None:
tags.append(tag.model_copy(update=dict(value=value)))
else:
tags.append(
data.Tag(
key=self.target_key, # type: ignore
value=value,
)
)
return sound_event_annotation.model_copy(update=dict(tags=tags))
@transforms.register(MapTagValueConfig)
@staticmethod
def from_config(config: MapTagValueConfig):
return MapTagValue(
tag_key=config.tag_key,
value_mapping=config.value_mapping,
target_key=config.target_key,
)
class ApplyAllConfig(BaseConfig):
name: Literal["apply_all"] = "apply_all"
steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
class ApplyAll:
def __init__(self, steps: List[SoundEventTransform]):
self.steps = steps
def __call__(
self,
sound_event_annotation: data.SoundEventAnnotation,
) -> data.SoundEventAnnotation:
for step in self.steps:
sound_event_annotation = step(sound_event_annotation)
return sound_event_annotation
@transforms.register(ApplyAllConfig)
@staticmethod
def from_config(config: ApplyAllConfig):
steps = [build_sound_event_transform(step) for step in config.steps]
return ApplyAll(steps)
SoundEventTransformConfig = Annotated[
Union[
SetFrequencyBoundConfig,
ReplaceTagConfig,
MapTagValueConfig,
ApplyIfConfig,
ApplyAllConfig,
],
Field(discriminator="name"),
]
def build_sound_event_transform(
config: SoundEventTransformConfig,
) -> SoundEventTransform:
return transforms.build(config)
def transform_clip_annotation(
clip_annotation: data.ClipAnnotation,
transform: SoundEventTransform,
) -> data.ClipAnnotation:
return clip_annotation.model_copy(
update=dict(
sound_events=[
transform(sound_event)
for sound_event in clip_annotation.sound_events
]
)
)
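A minimal sketch of composing transforms from config (illustrative tag values; clip_annotation assumed to be a data.ClipAnnotation):
from soundevent import data
from batdetect2.data.transforms import (
    ApplyAllConfig,
    ReplaceTagConfig,
    SetFrequencyBoundConfig,
    build_sound_event_transform,
    transform_clip_annotation,
)
# Chain two steps: normalise a tag value, then clamp the low frequency bound.
config = ApplyAllConfig(
    steps=[
        ReplaceTagConfig(
            original=data.Tag(key="species", value="Pippip"),
            replacement=data.Tag(key="species", value="Pipistrellus pipistrellus"),
        ),
        SetFrequencyBoundConfig(boundary="low", hertz=15_000),
    ]
)
transform = build_sound_event_transform(config)
updated = transform_clip_annotation(clip_annotation, transform)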

View File

@ -1,15 +1,15 @@
from batdetect2.evaluate.config import (
EvaluationConfig,
load_evaluation_config,
)
from batdetect2.evaluate.match import (
match_predictions_and_annotations,
match_sound_events_and_raw_predictions,
)
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.evaluate.tasks import TaskConfig, build_task
__all__ = [
"EvaluationConfig",
"Evaluator",
"TaskConfig",
"build_evaluator",
"build_task",
"evaluate",
"load_evaluation_config",
"match_predictions_and_annotations",
"match_sound_events_and_raw_predictions",
"DEFAULT_EVAL_DIR",
]

View File

@ -0,0 +1,230 @@
from typing import Annotated, Literal, Optional, Union
from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.geometry import compute_interval_overlap
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.typing.evaluate import AffinityFunction
affinity_functions: Registry[AffinityFunction, []] = Registry(
"matching_strategy"
)
class TimeAffinityConfig(BaseConfig):
name: Literal["time_affinity"] = "time_affinity"
time_buffer: float = 0.01
class TimeAffinity(AffinityFunction):
def __init__(self, time_buffer: float):
self.time_buffer = time_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
return compute_timestamp_affinity(
geometry1, geometry2, time_buffer=self.time_buffer
)
@affinity_functions.register(TimeAffinityConfig)
@staticmethod
def from_config(config: TimeAffinityConfig):
return TimeAffinity(time_buffer=config.time_buffer)
def compute_timestamp_affinity(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
) -> float:
assert isinstance(geometry1, data.TimeStamp)
assert isinstance(geometry2, data.TimeStamp)
start_time1 = geometry1.coordinates
start_time2 = geometry2.coordinates
a = min(start_time1, start_time2)
b = max(start_time1, start_time2)
if b - a >= 2 * time_buffer:
return 0
intersection = a - b + 2 * time_buffer
union = b - a + 2 * time_buffer
return intersection / union
class IntervalIOUConfig(BaseConfig):
name: Literal["interval_iou"] = "interval_iou"
time_buffer: float = 0.01
class IntervalIOU(AffinityFunction):
def __init__(self, time_buffer: float):
self.time_buffer = time_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
return compute_interval_iou(
geometry1,
geometry2,
time_buffer=self.time_buffer,
)
@affinity_functions.register(IntervalIOUConfig)
@staticmethod
def from_config(config: IntervalIOUConfig):
return IntervalIOU(time_buffer=config.time_buffer)
def compute_interval_iou(
geometry1: data.Geometry,
geometry2: data.Geometry,
time_buffer: float = 0.01,
) -> float:
assert isinstance(geometry1, data.TimeInterval)
assert isinstance(geometry2, data.TimeInterval)
start_time1, end_time1 = geometry1.coordinates
start_time2, end_time2 = geometry2.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
)
union = (
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
)
if union == 0:
return 0
return intersection / union
class BBoxIOUConfig(BaseConfig):
name: Literal["bbox_iou"] = "bbox_iou"
time_buffer: float = 0.01
freq_buffer: float = 1000
class BBoxIOU(AffinityFunction):
def __init__(self, time_buffer: float, freq_buffer: float):
self.time_buffer = time_buffer
self.freq_buffer = freq_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
if not isinstance(geometry1, data.BoundingBox):
raise TypeError(
f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
)
if not isinstance(geometry2, data.BoundingBox):
raise TypeError(
f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
)
return bbox_iou(
geometry1,
geometry2,
time_buffer=self.time_buffer,
freq_buffer=self.freq_buffer,
)
@affinity_functions.register(BBoxIOUConfig)
@staticmethod
def from_config(config: BBoxIOUConfig):
return BBoxIOU(
time_buffer=config.time_buffer,
freq_buffer=config.freq_buffer,
)
def bbox_iou(
geometry1: data.BoundingBox,
geometry2: data.BoundingBox,
time_buffer: float = 0.01,
freq_buffer: float = 1000,
) -> float:
start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates
start_time1 -= time_buffer
start_time2 -= time_buffer
end_time1 += time_buffer
end_time2 += time_buffer
low_freq1 -= freq_buffer
low_freq2 -= freq_buffer
high_freq1 += freq_buffer
high_freq2 += freq_buffer
time_intersection = compute_interval_overlap(
(start_time1, end_time1),
(start_time2, end_time2),
)
freq_intersection = max(
0,
min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
)
intersection = time_intersection * freq_intersection
if intersection == 0:
return 0
union = (
(end_time1 - start_time1) * (high_freq1 - low_freq1)
+ (end_time2 - start_time2) * (high_freq2 - low_freq2)
- intersection
)
return intersection / union
class GeometricIOUConfig(BaseConfig):
name: Literal["geometric_iou"] = "geometric_iou"
time_buffer: float = 0.01
freq_buffer: float = 1000
class GeometricIOU(AffinityFunction):
def __init__(self, time_buffer: float, freq_buffer: float):
self.time_buffer = time_buffer
self.freq_buffer = freq_buffer
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
return compute_affinity(
geometry1,
geometry2,
time_buffer=self.time_buffer,
freq_buffer=self.freq_buffer,
)
@affinity_functions.register(GeometricIOUConfig)
@staticmethod
def from_config(config: GeometricIOUConfig):
return GeometricIOU(
time_buffer=config.time_buffer,
freq_buffer=config.freq_buffer,
)
AffinityConfig = Annotated[
Union[
TimeAffinityConfig,
IntervalIOUConfig,
BBoxIOUConfig,
GeometricIOUConfig,
],
Field(discriminator="name"),
]
def build_affinity_function(
config: Optional[AffinityConfig] = None,
) -> AffinityFunction:
config = config or GeometricIOUConfig()
return affinity_functions.build(config)
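A small worked example of the new bounding-box IoU affinity (coordinate values chosen purely for illustration):
from soundevent import data
from batdetect2.evaluate.affinity import BBoxIOUConfig, build_affinity_function
# Two overlapping boxes: time overlap 0.04 s, frequency overlap 30 kHz, so
# with zero buffers IoU = 1200 / (2000 + 1500 - 1200) ≈ 0.52.
iou = build_affinity_function(BBoxIOUConfig(time_buffer=0.0, freq_buffer=0.0))
box1 = data.BoundingBox(coordinates=[0.10, 20_000, 0.15, 60_000])
box2 = data.BoundingBox(coordinates=[0.11, 25_000, 0.16, 55_000])
print(iou(box1, box2))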

View File

@ -1,10 +1,15 @@
from typing import Optional
from typing import List, Optional
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.match import MatchConfig
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.tasks import (
TaskConfig,
)
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.logging import CSVLoggerConfig, LoggerConfig
__all__ = [
"EvaluationConfig",
@ -13,7 +18,13 @@ __all__ = [
class EvaluationConfig(BaseConfig):
match: MatchConfig = Field(default_factory=MatchConfig)
tasks: List[TaskConfig] = Field(
default_factory=lambda: [
DetectionTaskConfig(),
ClassificationTaskConfig(),
]
)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
def load_evaluation_config(

View File

@ -0,0 +1,144 @@
from typing import List, NamedTuple, Optional, Sequence
import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import (
AudioLoader,
ClipperProtocol,
PreprocessorProtocol,
)
__all__ = [
"TestDataset",
"build_test_dataset",
"build_test_loader",
]
class TestExample(NamedTuple):
spec: torch.Tensor
idx: torch.Tensor
start_time: torch.Tensor
end_time: torch.Tensor
class TestDataset(Dataset[TestExample]):
clip_annotations: List[data.ClipAnnotation]
def __init__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
clipper: Optional[ClipperProtocol] = None,
audio_dir: Optional[data.PathLike] = None,
):
self.clip_annotations = list(clip_annotations)
self.clipper = clipper
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.audio_dir = audio_dir
def __len__(self):
return len(self.clip_annotations)
def __getitem__(self, idx: int) -> TestExample:
clip_annotation = self.clip_annotations[idx]
if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation)
clip = clip_annotation.clip
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
wav_tensor = torch.tensor(wav).unsqueeze(0)
spectrogram = self.preprocessor(wav_tensor)
return TestExample(
spec=spectrogram,
idx=torch.tensor(idx),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)
class TestLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
def build_test_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TestLoaderConfig] = None,
num_workers: Optional[int] = None,
) -> DataLoader[TestExample]:
logger.info("Building test data loader...")
config = config or TestLoaderConfig()
logger.opt(lazy=True).debug(
"Test data loader config: \n{config}",
config=lambda: config.to_yaml_string(exclude_none=True),
)
test_dataset = build_test_dataset(
clip_annotations,
audio_loader=audio_loader,
preprocessor=preprocessor,
config=config,
)
num_workers = num_workers if num_workers is not None else config.num_workers
return DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=num_workers,
collate_fn=_collate_fn,
)
def build_test_dataset(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TestLoaderConfig] = None,
) -> TestDataset:
logger.info("Building training dataset...")
config = config or TestLoaderConfig()
clipper = build_clipper(config=config.clipping_strategy)
if audio_loader is None:
audio_loader = build_audio_loader()
if preprocessor is None:
preprocessor = build_preprocessor()
return TestDataset(
clip_annotations,
audio_loader=audio_loader,
clipper=clipper,
preprocessor=preprocessor,
)
def _collate_fn(batch: List[TestExample]) -> TestExample:
max_width = max(item.spec.shape[-1] for item in batch)
return TestExample(
spec=torch.stack(
[adjust_width(item.spec, max_width) for item in batch]
),
idx=torch.stack([item.idx for item in batch]),
start_time=torch.stack([item.start_time for item in batch]),
end_time=torch.stack([item.end_time for item in batch]),
)
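A short usage sketch, assuming test_annotations is a sequence of data.ClipAnnotation objects:
from batdetect2.evaluate.dataset import TestLoaderConfig, build_test_loader
# Build a batch-size-1 loader over held-out clips; the audio loader and
# preprocessor fall back to their defaults when not given.
loader = build_test_loader(
    test_annotations,
    config=TestLoaderConfig(num_workers=4),
)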

View File

@ -0,0 +1,69 @@
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Sequence
from lightning import Trainer
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger
from batdetect2.models import Model
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
PreprocessorProtocol,
TargetProtocol,
)
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
def evaluate(
model: Model,
test_annotations: Sequence[data.ClipAnnotation],
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
num_workers: Optional[int] = None,
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
):
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config()
audio_loader = audio_loader or build_audio_loader(config=config.audio)
preprocessor = preprocessor or build_preprocessor(
config=config.preprocess,
input_samplerate=audio_loader.samplerate,
)
targets = targets or build_targets(config=config.targets)
loader = build_test_loader(
test_annotations,
audio_loader=audio_loader,
preprocessor=preprocessor,
num_workers=num_workers,
)
evaluator = build_evaluator(config=config.evaluation, targets=targets)
logger = build_logger(
config.evaluation.logger,
log_dir=Path(output_dir),
experiment_name=experiment_name,
run_name=run_name,
)
module = EvaluationModule(model, evaluator)
trainer = Trainer(logger=logger, enable_checkpointing=False)
return trainer.test(module, loader)
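A hedged usage sketch of the evaluate entry point; model (a trained batdetect2.models.Model) and test_annotations are assumed to already exist:
from batdetect2.evaluate.evaluate import evaluate
results = evaluate(
    model,
    test_annotations,
    output_dir="outputs/evaluations",
    experiment_name="baseline",
    run_name="test-split",
)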

View File

@ -0,0 +1,67 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from matplotlib.figure import Figure
from soundevent import data
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol
__all__ = [
"Evaluator",
"build_evaluator",
]
class Evaluator:
def __init__(
self,
targets: TargetProtocol,
tasks: Sequence[EvaluatorProtocol],
):
self.targets = targets
self.tasks = tasks
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> List[Any]:
return [
task.evaluate(clip_annotations, predictions) for task in self.tasks
]
def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
results = {}
for task, outputs in zip(self.tasks, eval_outputs):
results.update(task.compute_metrics(outputs))
return results
def generate_plots(
self,
eval_outputs: List[Any],
) -> Iterable[Tuple[str, Figure]]:
for task, outputs in zip(self.tasks, eval_outputs):
for name, fig in task.generate_plots(outputs):
yield name, fig
def build_evaluator(
config: Optional[Union[EvaluationConfig, dict]] = None,
targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
targets = targets or build_targets()
if config is None:
config = EvaluationConfig()
if not isinstance(config, EvaluationConfig):
config = EvaluationConfig.model_validate(config)
return Evaluator(
targets=targets,
tasks=[build_task(task, targets=targets) for task in config.tasks],
)

View File

@ -0,0 +1,82 @@
from typing import Any, List
from lightning import LightningModule
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import RawPrediction
class EvaluationModule(LightningModule):
def __init__(
self,
model: Model,
evaluator: EvaluatorProtocol,
):
super().__init__()
self.model = model
self.evaluator = evaluator
self.clip_annotations: List[data.ClipAnnotation] = []
self.predictions: List[List[RawPrediction]] = []
def test_step(self, batch: TestExample, batch_idx: int):
dataset = self.get_dataset()
clip_annotations = [
dataset.clip_annotations[int(example_idx)]
for example_idx in batch.idx
]
outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(
outputs,
start_times=[ca.clip.start_time for ca in clip_annotations],
)
predictions = [
to_raw_predictions(
clip_dets.numpy(),
targets=self.evaluator.targets,
)
for clip_dets in clip_detections
]
self.clip_annotations.extend(clip_annotations)
self.predictions.extend(predictions)
def on_test_epoch_start(self):
self.clip_annotations = []
self.predictions = []
def on_test_epoch_end(self):
clip_evals = self.evaluator.evaluate(
self.clip_annotations,
self.predictions,
)
self.log_metrics(clip_evals)
self.generate_plots(clip_evals)
def generate_plots(self, evaluated_clips: Any):
plotter = get_image_logger(self.logger) # type: ignore
if plotter is None:
return
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
plotter(figure_name, fig, self.global_step)
def log_metrics(self, evaluated_clips: Any):
metrics = self.evaluator.compute_metrics(evaluated_clips)
self.log_dict(metrics)
def get_dataset(self) -> TestDataset:
dataloaders = self.trainer.test_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, TestDataset)
return dataset

View File

@ -1,50 +1,211 @@
from collections.abc import Callable, Iterable, Mapping
from typing import List, Literal, Optional, Tuple
from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
import numpy as np
from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.evaluation import (
match_geometries as optimal_match,
)
from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol
MatchingStrategy = Literal["greedy", "optimal"]
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.affinity import (
AffinityConfig,
GeometricIOUConfig,
build_affinity_function,
)
from batdetect2.targets import build_targets
from batdetect2.typing import (
AffinityFunction,
MatcherProtocol,
MatchEvaluation,
RawPrediction,
TargetProtocol,
)
from batdetect2.typing.evaluate import ClipMatches
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching."""
matching_strategies = Registry("matching_strategy")
class MatchConfig(BaseConfig):
"""Configuration for matching geometries.
Attributes
----------
strategy : MatchingStrategy, default="greedy"
The matching algorithm to use. 'greedy' prioritizes high-confidence
predictions, while 'optimal' finds the globally best set of matches.
geometry : MatchingGeometry, default="timestamp"
The geometric representation to use when computing affinity.
affinity_threshold : float, default=0.0
The minimum affinity score (e.g., IoU) required for a valid match.
time_buffer : float, default=0.005
Time tolerance in seconds used in affinity calculations.
frequency_buffer : float, default=1000
Frequency tolerance in Hertz used in affinity calculations.
"""
def match(
sound_event_annotations: Sequence[data.SoundEventAnnotation],
raw_predictions: Sequence[RawPrediction],
clip: data.Clip,
scores: Optional[Sequence[float]] = None,
targets: Optional[TargetProtocol] = None,
matcher: Optional[MatcherProtocol] = None,
) -> ClipMatches:
if matcher is None:
matcher = build_matcher()
strategy: MatchingStrategy = "greedy"
geometry: MatchingGeometry = "timestamp"
affinity_threshold: float = 0.0
time_buffer: float = 0.005
frequency_buffer: float = 1_000
if targets is None:
targets = build_targets()
target_geometries: List[data.Geometry] = [ # type: ignore
sound_event_annotation.sound_event.geometry
for sound_event_annotation in sound_event_annotations
]
predicted_geometries = [
raw_prediction.geometry for raw_prediction in raw_predictions
]
if scores is None:
scores = [
raw_prediction.detection_score
for raw_prediction in raw_predictions
]
matches = []
for source_idx, target_idx, affinity in matcher(
ground_truth=target_geometries,
predictions=predicted_geometries,
scores=scores,
):
target = (
sound_event_annotations[target_idx]
if target_idx is not None
else None
)
prediction = (
raw_predictions[source_idx] if source_idx is not None else None
)
gt_det = target_idx is not None
gt_class = targets.encode_class(target) if target is not None else None
gt_geometry = (
target_geometries[target_idx] if target_idx is not None else None
)
pred_score = float(prediction.detection_score) if prediction else 0
pred_geometry = (
predicted_geometries[source_idx]
if source_idx is not None
else None
)
class_scores = (
{
class_name: score
for class_name, score in zip(
targets.class_names,
prediction.class_scores,
)
}
if prediction is not None
else {}
)
matches.append(
MatchEvaluation(
clip=clip,
sound_event_annotation=target,
gt_det=gt_det,
gt_class=gt_class,
gt_geometry=gt_geometry,
pred_score=pred_score,
pred_class_scores=class_scores,
pred_geometry=pred_geometry,
affinity=affinity,
)
)
return ClipMatches(clip=clip, matches=matches)
class StartTimeMatchConfig(BaseConfig):
name: Literal["start_time_match"] = "start_time_match"
distance_threshold: float = 0.01
class StartTimeMatcher(MatcherProtocol):
def __init__(self, distance_threshold: float):
self.distance_threshold = distance_threshold
def __call__(
self,
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
):
return match_start_times(
ground_truth,
predictions,
scores,
distance_threshold=self.distance_threshold,
)
@matching_strategies.register(StartTimeMatchConfig)
@staticmethod
def from_config(config: StartTimeMatchConfig):
return StartTimeMatcher(distance_threshold=config.distance_threshold)
def match_start_times(
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
distance_threshold: float = 0.01,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
if not ground_truth:
for index in range(len(predictions)):
yield index, None, 0
return
if not predictions:
for index in range(len(ground_truth)):
yield None, index, 0
return
gt_times = np.array([compute_bounds(geom)[0] for geom in ground_truth])
pred_times = np.array([compute_bounds(geom)[0] for geom in predictions])
scores = np.array(scores)
sort_args = np.argsort(scores)[::-1]
distances = np.abs(gt_times[None, :] - pred_times[:, None])
closests = np.argmin(distances, axis=-1)
unmatched_gt = set(range(len(gt_times)))
for pred_index in sort_args:
# Get the closest ground truth
gt_closest_index = closests[pred_index]
if gt_closest_index not in unmatched_gt:
# Does not match if closest has been assigned
yield pred_index, None, 0
continue
# Get the actual distance
distance = distances[pred_index, gt_closest_index]
if distance > distance_threshold:
# Does not match if too far from closest
yield pred_index, None, 0
continue
# Affinity value: linearly interpolated so that a zero distance maps to an
# affinity of 1 and a distance at the threshold maps to 0.
affinity = np.interp(
distance,
[0, distance_threshold],
[1, 0],
left=1,
right=0,
)
unmatched_gt.remove(gt_closest_index)
yield pred_index, gt_closest_index, affinity
for missing_index in unmatched_gt:
yield None, missing_index, 0
def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
@ -73,45 +234,58 @@ _geometry_cast_functions: Mapping[
}
def match_geometries(
source: List[data.Geometry],
target: List[data.Geometry],
config: MatchConfig,
scores: Optional[List[float]] = None,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
geometry_cast = _geometry_cast_functions[config.geometry]
if config.strategy == "optimal":
return optimal_match(
source=[geometry_cast(geom) for geom in source],
target=[geometry_cast(geom) for geom in target],
time_buffer=config.time_buffer,
freq_buffer=config.frequency_buffer,
affinity_threshold=config.affinity_threshold,
)
if config.strategy == "greedy":
return greedy_match(
source=[geometry_cast(geom) for geom in source],
target=[geometry_cast(geom) for geom in target],
time_buffer=config.time_buffer,
freq_buffer=config.frequency_buffer,
affinity_threshold=config.affinity_threshold,
scores=scores,
)
raise NotImplementedError(
f"Matching strategy not implemented {config.strategy}"
class GreedyMatchConfig(BaseConfig):
name: Literal["greedy_match"] = "greedy_match"
geometry: MatchingGeometry = "timestamp"
affinity_threshold: float = 0.5
affinity_function: AffinityConfig = Field(
default_factory=GeometricIOUConfig
)
class GreedyMatcher(MatcherProtocol):
def __init__(
self,
geometry: MatchingGeometry,
affinity_threshold: float,
affinity_function: AffinityFunction,
):
self.geometry = geometry
self.affinity_function = affinity_function
self.affinity_threshold = affinity_threshold
self.cast_geometry = _geometry_cast_functions[self.geometry]
def __call__(
self,
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
):
return greedy_match(
ground_truth=[self.cast_geometry(geom) for geom in ground_truth],
predictions=[self.cast_geometry(geom) for geom in predictions],
scores=scores,
affinity_function=self.affinity_function,
affinity_threshold=self.affinity_threshold,
)
@matching_strategies.register(GreedyMatchConfig)
@staticmethod
def from_config(config: GreedyMatchConfig):
affinity_function = build_affinity_function(config.affinity_function)
return GreedyMatcher(
geometry=config.geometry,
affinity_threshold=config.affinity_threshold,
affinity_function=affinity_function,
)
def greedy_match(
source: List[data.Geometry],
target: List[data.Geometry],
scores: Optional[List[float]] = None,
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
affinity_threshold: float = 0.5,
time_buffer: float = 0.001,
freq_buffer: float = 1000,
affinity_function: AffinityFunction = compute_affinity,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
"""Performs a greedy, one-to-one matching of source to target geometries.
@ -129,10 +303,6 @@ def greedy_match(
Confidence scores for each source geometry for prioritization.
affinity_threshold
The minimum affinity score required for a valid match.
time_buffer
Time tolerance in seconds for affinity calculation.
freq_buffer
Frequency tolerance in Hertz for affinity calculation.
Yields
------
@ -143,37 +313,29 @@ def greedy_match(
- Unmatched Source (False Positive): `(source_idx, None, 0)`
- Unmatched Target (False Negative): `(None, target_idx, 0)`
"""
assigned = set()
unassigned_gt = set(range(len(ground_truth)))
if not source:
for target_idx in range(len(target)):
yield None, target_idx, 0
if not predictions:
for gt_idx in range(len(ground_truth)):
yield None, gt_idx, 0
return
if not target:
for source_idx in range(len(source)):
yield source_idx, None, 0
if not ground_truth:
for pred_idx in range(len(predictions)):
yield pred_idx, None, 0
return
if scores is None:
indices = np.arange(len(source))
else:
indices = np.argsort(scores)[::-1]
indices = np.argsort(scores)[::-1]
for source_idx in indices:
source_geometry = source[source_idx]
for pred_idx in indices:
source_geometry = predictions[pred_idx]
affinities = np.array(
[
compute_affinity(
source_geometry,
target_geometry,
time_buffer=time_buffer,
freq_buffer=freq_buffer,
)
for target_geometry in target
affinity_function(source_geometry, target_geometry)
for target_geometry in ground_truth
]
)
@ -181,162 +343,72 @@ def greedy_match(
affinity = affinities[closest_target]
if affinities[closest_target] <= affinity_threshold:
yield source_idx, None, 0
yield pred_idx, None, 0
continue
if closest_target in assigned:
yield source_idx, None, 0
if closest_target not in unassigned_gt:
yield pred_idx, None, 0
continue
assigned.add(closest_target)
yield source_idx, closest_target, affinity
unassigned_gt.remove(closest_target)
yield pred_idx, closest_target, affinity
missed_ground_truth = set(range(len(target))) - assigned
for target_idx in missed_ground_truth:
yield None, target_idx, 0
for gt_idx in unassigned_gt:
yield None, gt_idx, 0
def match_sound_events_and_raw_predictions(
clip_annotation: data.ClipAnnotation,
raw_predictions: List[BatDetect2Prediction],
targets: TargetProtocol,
config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
config = config or MatchConfig()
class OptimalMatchConfig(BaseConfig):
name: Literal["optimal_match"] = "optimal_match"
affinity_threshold: float = 0.5
time_buffer: float = 0.005
frequency_buffer: float = 1_000
target_sound_events = [
targets.transform(sound_event_annotation)
for sound_event_annotation in clip_annotation.sound_events
if targets.filter(sound_event_annotation)
and sound_event_annotation.sound_event.geometry is not None
]
target_geometries: List[data.Geometry] = [ # type: ignore
sound_event_annotation.sound_event.geometry
for sound_event_annotation in target_sound_events
if sound_event_annotation.sound_event.geometry is not None
]
predicted_geometries = [
raw_prediction.raw.geometry for raw_prediction in raw_predictions
]
scores = [
raw_prediction.raw.detection_score
for raw_prediction in raw_predictions
]
matches = []
for source_idx, target_idx, affinity in match_geometries(
source=predicted_geometries,
target=target_geometries,
config=config,
scores=scores,
class OptimalMatcher(MatcherProtocol):
def __init__(
self,
affinity_threshold: float,
time_buffer: float,
frequency_buffer: float,
):
target = (
target_sound_events[target_idx] if target_idx is not None else None
)
prediction = (
raw_predictions[source_idx] if source_idx is not None else None
)
self.affinity_threshold = affinity_threshold
self.time_buffer = time_buffer
self.frequency_buffer = frequency_buffer
gt_det = target is not None
gt_class = targets.encode_class(target) if target is not None else None
pred_score = float(prediction.raw.detection_score) if prediction else 0
class_scores = (
{
str(class_name): float(score)
for class_name, score in zip(
targets.class_names,
prediction.raw.class_scores,
)
}
if prediction is not None
else {}
)
matches.append(
MatchEvaluation(
match=data.Match(
source=None
if prediction is None
else prediction.sound_event_prediction,
target=target,
affinity=affinity,
),
gt_det=gt_det,
gt_class=gt_class,
pred_score=pred_score,
pred_class_scores=class_scores,
)
)
return matches
def match_predictions_and_annotations(
clip_annotation: data.ClipAnnotation,
clip_prediction: data.ClipPrediction,
config: Optional[MatchConfig] = None,
) -> List[data.Match]:
config = config or MatchConfig()
annotated_sound_events = [
sound_event_annotation
for sound_event_annotation in clip_annotation.sound_events
if sound_event_annotation.sound_event.geometry is not None
]
predicted_sound_events = [
sound_event_prediction
for sound_event_prediction in clip_prediction.sound_events
if sound_event_prediction.sound_event.geometry is not None
]
annotated_geometries: List[data.Geometry] = [
sound_event.sound_event.geometry
for sound_event in annotated_sound_events
if sound_event.sound_event.geometry is not None
]
predicted_geometries: List[data.Geometry] = [
sound_event.sound_event.geometry
for sound_event in predicted_sound_events
if sound_event.sound_event.geometry is not None
]
scores = [
sound_event.score
for sound_event in predicted_sound_events
if sound_event.sound_event.geometry is not None
]
matches = []
for source_idx, target_idx, affinity in match_geometries(
source=predicted_geometries,
target=annotated_geometries,
config=config,
scores=scores,
def __call__(
self,
ground_truth: Sequence[data.Geometry],
predictions: Sequence[data.Geometry],
scores: Sequence[float],
):
target = (
annotated_sound_events[target_idx]
if target_idx is not None
else None
)
source = (
predicted_sound_events[source_idx]
if source_idx is not None
else None
)
matches.append(
data.Match(
source=source,
target=target,
affinity=affinity,
)
return optimal_match(
source=predictions,
target=ground_truth,
time_buffer=self.time_buffer,
freq_buffer=self.frequency_buffer,
affinity_threshold=self.affinity_threshold,
)
return matches
@matching_strategies.register(OptimalMatchConfig)
@staticmethod
def from_config(config: OptimalMatchConfig):
return OptimalMatcher(
affinity_threshold=config.affinity_threshold,
time_buffer=config.time_buffer,
frequency_buffer=config.frequency_buffer,
)
MatchConfig = Annotated[
Union[
GreedyMatchConfig,
StartTimeMatchConfig,
OptimalMatchConfig,
],
Field(discriminator="name"),
]
def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
config = config or StartTimeMatchConfig()
return matching_strategies.build(config)
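A minimal sketch of the matcher protocol in use (the geometry lists and scores are assumed to exist):
from batdetect2.evaluate.match import GreedyMatchConfig, build_matcher
# One-to-one greedy matching on bounding boxes; each yielded tuple is
# (pred_idx, gt_idx, affinity), with None marking false positives / negatives.
matcher = build_matcher(GreedyMatchConfig(geometry="bbox", affinity_threshold=0.3))
pairs = list(
    matcher(
        ground_truth=gt_geometries,   # Sequence[data.Geometry]
        predictions=pred_geometries,  # Sequence[data.Geometry]
        scores=pred_scores,           # Sequence[float]
    )
)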

View File

@ -1,97 +0,0 @@
from typing import Dict, List
import pandas as pd
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
__all__ = ["DetectionAveragePrecision"]
class DetectionAveragePrecision(MetricsProtocol):
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true, y_score = zip(
*[(match.gt_det, match.pred_score) for match in matches]
)
score = float(metrics.average_precision_score(y_true, y_score))
return {"detection_AP": score}
class ClassificationMeanAveragePrecision(MetricsProtocol):
def __init__(self, class_names: List[str], per_class: bool = True):
self.class_names = class_names
self.per_class = per_class
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true = label_binarize(
[
match.gt_class if match.gt_class is not None else "__NONE__"
for match in matches
],
classes=self.class_names,
)
y_pred = pd.DataFrame(
[
{
name: match.pred_class_scores.get(name, 0)
for name in self.class_names
}
for match in matches
]
).fillna(0)
mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])
ret = {
"classification_mAP": float(mAP),
}
if not self.per_class:
return ret
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[class_name]
class_ap = metrics.average_precision_score(
y_true_class,
y_pred_class,
)
ret[f"classification_AP/{class_name}"] = float(class_ap)
return ret
class ClassificationAccuracy(MetricsProtocol):
def __init__(self, class_names: List[str]):
self.class_names = class_names
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true = [
match.gt_class if match.gt_class is not None else "__NONE__"
for match in matches
]
y_pred = pd.DataFrame(
[
{
name: match.pred_class_scores.get(name, 0)
for name in self.class_names
}
for match in matches
]
).fillna(0)
y_pred = y_pred.apply(
lambda row: row.idxmax()
if row.max() >= (1 - row.sum())
else "__NONE__",
axis=1,
)
accuracy = metrics.balanced_accuracy_score(
y_true,
y_pred,
)
return {
"classification_acc": float(accuracy),
}

View File

@ -0,0 +1,267 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Mapping,
Optional,
Sequence,
Union,
)
import numpy as np
from pydantic import Field
from sklearn import metrics
from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction, TargetProtocol
__all__ = [
"ClassificationMetric",
"ClassificationMetricConfig",
"build_classification_metric",
]
@dataclass
class MatchEval:
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
is_prediction: bool
is_ground_truth: bool
is_generic: bool
true_class: Optional[str]
score: float
@dataclass
class ClipEval:
clip: data.Clip
matches: Mapping[str, List[MatchEval]]
ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
Registry("classification_metric")
)
class BaseClassificationConfig(BaseConfig):
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class BaseClassificationMetric:
def __init__(
self,
targets: TargetProtocol,
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.targets = targets
self.include = include
self.exclude = exclude
def include_class(self, class_name: str) -> bool:
if self.include is not None:
return class_name in self.include
if self.exclude is not None:
return class_name not in self.exclude
return True
class ClassificationAveragePrecisionConfig(BaseClassificationConfig):
name: Literal["average_precision"] = "average_precision"
ignore_non_predictions: bool = True
ignore_generic: bool = True
label: str = "average_precision"
class ClassificationAveragePrecision(BaseClassificationMetric):
def __init__(
self,
targets: TargetProtocol,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
label: str = "average_precision",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
super().__init__(include=include, exclude=exclude, targets=targets)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.label = label
def __call__(
self, clip_evaluations: Sequence[ClipEval]
) -> Dict[str, float]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
class_scores = {
class_name: average_precision(
y_true[class_name],
y_score[class_name],
num_positives=num_positives[class_name],
)
for class_name in self.targets.class_names
}
mean_score = float(
np.mean([v for v in class_scores.values() if not np.isnan(v)])
)
return {
f"mean_{self.label}": mean_score,
**{
f"{self.label}/{class_name}": score
for class_name, score in class_scores.items()
if self.include_class(class_name)
},
}
@classification_metrics.register(ClassificationAveragePrecisionConfig)
@staticmethod
def from_config(
config: ClassificationAveragePrecisionConfig,
targets: TargetProtocol,
):
return ClassificationAveragePrecision(
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
label=config.label,
include=config.include,
exclude=config.exclude,
)
class ClassificationROCAUCConfig(BaseClassificationConfig):
name: Literal["roc_auc"] = "roc_auc"
label: str = "roc_auc"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ClassificationROCAUC(BaseClassificationMetric):
def __init__(
self,
targets: TargetProtocol,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
label: str = "roc_auc",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
super().__init__(include=include, exclude=exclude, targets=targets)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.label = label
def __call__(
self, clip_evaluations: Sequence[ClipEval]
) -> Dict[str, float]:
y_true, y_score, _ = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
class_scores = {
class_name: float(
metrics.roc_auc_score(
y_true[class_name],
y_score[class_name],
)
)
for class_name in self.targets.class_names
}
mean_score = float(
np.mean([v for v in class_scores.values() if not np.isnan(v)])
)
return {
f"mean_{self.label}": mean_score,
**{
f"{self.label}/{class_name}": score
for class_name, score in class_scores.items()
if self.include_class(class_name)
},
}
@classification_metrics.register(ClassificationROCAUCConfig)
@staticmethod
def from_config(
config: ClassificationROCAUCConfig, targets: TargetProtocol
):
return ClassificationROCAUC(
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
label=config.label,
include=config.include,
exclude=config.exclude,
)
ClassificationMetricConfig = Annotated[
Union[
ClassificationAveragePrecisionConfig,
ClassificationROCAUCConfig,
],
Field(discriminator="name"),
]
def build_classification_metric(
config: ClassificationMetricConfig,
targets: TargetProtocol,
) -> ClassificationMetric:
return classification_metrics.build(config, targets)
def _extract_per_class_metric_data(
clip_evaluations: Sequence[ClipEval],
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
):
y_true = defaultdict(list)
y_score = defaultdict(list)
num_positives = defaultdict(lambda: 0)
for clip_eval in clip_evaluations:
for class_name, matches in clip_eval.matches.items():
for m in matches:
# Exclude matches with ground truth sounds where the class
# is unknown
if m.is_generic and ignore_generic:
continue
is_class = m.true_class == class_name
if is_class:
num_positives[class_name] += 1
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and ignore_non_predictions:
continue
y_true[class_name].append(is_class)
y_score[class_name].append(m.score)
return y_true, y_score, num_positives
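
For orientation, a minimal sketch of building and calling one of the metrics above through the registry. It uses the module path from this diff, and assumes that `targets` (a TargetProtocol) and `clip_evaluations` (the Sequence[ClipEval] produced by the classification evaluation task) already exist, so it is illustrative rather than a complete script:

from batdetect2.evaluate.metrics.classification import (
    ClassificationAveragePrecisionConfig,
    build_classification_metric,
)

# `targets` and `clip_evaluations` are assumed to come from the surrounding
# evaluation pipeline and are not constructed here.
config = ClassificationAveragePrecisionConfig()
metric = build_classification_metric(config, targets=targets)
scores = metric(clip_evaluations)
# scores holds "mean_average_precision" plus one "average_precision/<class>"
# entry for every class name admitted by include/exclude.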


@@ -0,0 +1,135 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union
import numpy as np
from pydantic import Field
from sklearn import metrics
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.metrics.common import average_precision
@dataclass
class ClipEval:
true_classes: Set[str]
class_scores: Dict[str, float]
ClipClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
"clip_classification_metric"
)
class ClipClassificationAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
class ClipClassificationAveragePrecision:
def __init__(self, label: str = "average_precision"):
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = defaultdict(list)
y_score = defaultdict(list)
for clip_eval in clip_evaluations:
for class_name, score in clip_eval.class_scores.items():
y_true[class_name].append(class_name in clip_eval.true_classes)
y_score[class_name].append(score)
class_scores = {
class_name: float(
average_precision(
y_true=y_true[class_name],
y_score=y_score[class_name],
)
)
for class_name in y_true
}
mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])
return {
f"mean_{self.label}": float(mean),
**{
f"{self.label}/{class_name}": score
for class_name, score in class_scores.items()
if not np.isnan(score)
},
}
@clip_classification_metrics.register(
ClipClassificationAveragePrecisionConfig
)
@staticmethod
def from_config(config: ClipClassificationAveragePrecisionConfig):
return ClipClassificationAveragePrecision(label=config.label)
class ClipClassificationROCAUCConfig(BaseConfig):
name: Literal["roc_auc"] = "roc_auc"
label: str = "roc_auc"
class ClipClassificationROCAUC:
def __init__(self, label: str = "roc_auc"):
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = defaultdict(list)
y_score = defaultdict(list)
for clip_eval in clip_evaluations:
for class_name, score in clip_eval.class_scores.items():
y_true[class_name].append(class_name in clip_eval.true_classes)
y_score[class_name].append(score)
class_scores = {
class_name: float(
metrics.roc_auc_score(
y_true=y_true[class_name],
y_score=y_score[class_name],
)
)
for class_name in y_true
}
mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])
return {
f"mean_{self.label}": float(mean),
**{
f"{self.label}/{class_name}": score
for class_name, score in class_scores.items()
if not np.isnan(score)
},
}
@clip_classification_metrics.register(ClipClassificationROCAUCConfig)
@staticmethod
def from_config(config: ClipClassificationROCAUCConfig):
return ClipClassificationROCAUC(label=config.label)
ClipClassificationMetricConfig = Annotated[
Union[
ClipClassificationAveragePrecisionConfig,
ClipClassificationROCAUCConfig,
],
Field(discriminator="name"),
]
def build_clip_metric(config: ClipClassificationMetricConfig):
return clip_classification_metrics.build(config)
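
The clip-level classification metrics above are easy to exercise in isolation; a small self-contained sketch with made-up scores (the class names are purely illustrative):

from batdetect2.evaluate.metrics.clip_classification import (
    ClipClassificationAveragePrecision,
    ClipEval,
)

clip_evals = [
    ClipEval(
        true_classes={"Pipistrellus pipistrellus"},
        class_scores={"Pipistrellus pipistrellus": 0.9, "Myotis daubentonii": 0.2},
    ),
    ClipEval(
        true_classes={"Myotis daubentonii"},
        class_scores={"Pipistrellus pipistrellus": 0.3, "Myotis daubentonii": 0.7},
    ),
]

metric = ClipClassificationAveragePrecision()
print(metric(clip_evals))
# {'mean_average_precision': 1.0,
#  'average_precision/Pipistrellus pipistrellus': 1.0,
#  'average_precision/Myotis daubentonii': 1.0}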


@@ -0,0 +1,173 @@
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Union
import numpy as np
from pydantic import Field
from sklearn import metrics
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.metrics.common import average_precision
@dataclass
class ClipEval:
gt_det: bool
score: float
ClipDetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
"clip_detection_metric"
)
class ClipDetectionAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
class ClipDetectionAveragePrecision:
def __init__(self, label: str = "average_precision"):
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
y_true.append(clip_eval.gt_det)
y_score.append(clip_eval.score)
score = average_precision(y_true=y_true, y_score=y_score)
return {self.label: score}
@clip_detection_metrics.register(ClipDetectionAveragePrecisionConfig)
@staticmethod
def from_config(config: ClipDetectionAveragePrecisionConfig):
return ClipDetectionAveragePrecision(label=config.label)
class ClipDetectionROCAUCConfig(BaseConfig):
name: Literal["roc_auc"] = "roc_auc"
label: str = "roc_auc"
class ClipDetectionROCAUC:
def __init__(self, label: str = "roc_auc"):
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
y_true.append(clip_eval.gt_det)
y_score.append(clip_eval.score)
score = float(metrics.roc_auc_score(y_true=y_true, y_score=y_score))
return {self.label: score}
@clip_detection_metrics.register(ClipDetectionROCAUCConfig)
@staticmethod
def from_config(config: ClipDetectionROCAUCConfig):
return ClipDetectionROCAUC(label=config.label)
class ClipDetectionRecallConfig(BaseConfig):
name: Literal["recall"] = "recall"
threshold: float = 0.5
label: str = "recall"
class ClipDetectionRecall:
def __init__(self, threshold: float, label: str = "recall"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_positives = 0
true_positives = 0
for clip_eval in clip_evaluations:
if clip_eval.gt_det:
num_positives += 1
if clip_eval.score >= self.threshold and clip_eval.gt_det:
true_positives += 1
if num_positives == 0:
return {self.label: np.nan}
score = true_positives / num_positives
return {self.label: score}
@clip_detection_metrics.register(ClipDetectionRecallConfig)
@staticmethod
def from_config(config: ClipDetectionRecallConfig):
return ClipDetectionRecall(
threshold=config.threshold, label=config.label
)
class ClipDetectionPrecisionConfig(BaseConfig):
name: Literal["precision"] = "precision"
threshold: float = 0.5
label: str = "precision"
class ClipDetectionPrecision:
def __init__(self, threshold: float, label: str = "precision"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_detections = 0
true_positives = 0
for clip_eval in clip_evaluations:
if clip_eval.score >= self.threshold:
num_detections += 1
if clip_eval.score >= self.threshold and clip_eval.gt_det:
true_positives += 1
if num_detections == 0:
return {self.label: np.nan}
score = true_positives / num_detections
return {self.label: score}
@clip_detection_metrics.register(ClipDetectionPrecisionConfig)
@staticmethod
def from_config(config: ClipDetectionPrecisionConfig):
return ClipDetectionPrecision(
threshold=config.threshold, label=config.label
)
ClipDetectionMetricConfig = Annotated[
Union[
ClipDetectionAveragePrecisionConfig,
ClipDetectionROCAUCConfig,
ClipDetectionRecallConfig,
ClipDetectionPrecisionConfig,
],
Field(discriminator="name"),
]
def build_clip_metric(config: ClipDetectionMetricConfig):
return clip_detection_metrics.build(config)
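
A minimal self-contained illustration of the threshold-based clip detection metrics above, with hand-picked scores:

from batdetect2.evaluate.metrics.clip_detection import (
    ClipDetectionPrecision,
    ClipDetectionRecall,
    ClipEval,
)

clip_evals = [
    ClipEval(gt_det=True, score=0.9),   # detected positive
    ClipEval(gt_det=True, score=0.3),   # missed positive at threshold 0.5
    ClipEval(gt_det=False, score=0.7),  # false positive
    ClipEval(gt_det=False, score=0.1),  # correctly rejected
]

print(ClipDetectionRecall(threshold=0.5)(clip_evals))     # {'recall': 0.5}
print(ClipDetectionPrecision(threshold=0.5)(clip_evals))  # {'precision': 0.5}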


@@ -0,0 +1,63 @@
from typing import Optional, Tuple
import numpy as np
__all__ = [
"compute_precision_recall",
"average_precision",
]
def compute_precision_recall(
y_true,
y_score,
num_positives: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
y_true = np.array(y_true)
y_score = np.array(y_score)
if num_positives is None:
num_positives = y_true.sum()
# Sort by score
sort_ind = np.argsort(y_score)[::-1]
y_true_sorted = y_true[sort_ind]
y_score_sorted = y_score[sort_ind]
false_pos_c = np.cumsum(1 - y_true_sorted)
true_pos_c = np.cumsum(y_true_sorted)
recall = true_pos_c / num_positives
precision = true_pos_c / np.maximum(
true_pos_c + false_pos_c,
np.finfo(np.float64).eps,
)
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
return precision, recall, y_score_sorted
def average_precision(
y_true,
y_score,
num_positives: Optional[int] = None,
) -> float:
if num_positives == 0:
return np.nan
precision, recall, _ = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
)
# pascal 12 way
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for ii in range(mprec.shape[0] - 2, -1, -1):
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
return ave_prec
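
The helper above computes PASCAL VOC 2012 style interpolated average precision; passing num_positives lets ground truths that were never predicted still count against recall. A small worked example with hand-picked values:

from batdetect2.evaluate.metrics.common import (
    average_precision,
    compute_precision_recall,
)

# Three scored predictions, two of which are true positives; a third ground
# truth was never predicted at all, hence num_positives=3.
y_true = [True, False, True]
y_score = [0.9, 0.8, 0.6]

precision, recall, _ = compute_precision_recall(y_true, y_score, num_positives=3)
# precision ~ [1.0, 0.5, 0.67], recall ~ [0.33, 0.33, 0.67]

print(average_precision(y_true, y_score, num_positives=3))
# ~0.56: recall is capped at 2/3 because of the missed ground truth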


@@ -0,0 +1,226 @@
from dataclasses import dataclass
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np
from pydantic import Field
from sklearn import metrics
from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction
__all__ = [
"DetectionMetricConfig",
"DetectionMetric",
"build_detection_metric",
]
@dataclass
class MatchEval:
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
is_prediction: bool
is_ground_truth: bool
score: float
@dataclass
class ClipEval:
clip: data.Clip
matches: List[MatchEval]
DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")
class DetectionAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
ignore_non_predictions: bool = True
class DetectionAveragePrecision:
def __init__(self, label: str, ignore_non_predictions: bool = True):
self.ignore_non_predictions = ignore_non_predictions
self.label = label
def __call__(
self,
clip_evals: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evals:
for m in clip_eval.matches:
num_positives += int(m.is_ground_truth)
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and self.ignore_non_predictions:
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
ap = average_precision(y_true, y_score, num_positives=num_positives)
return {self.label: ap}
@detection_metrics.register(DetectionAveragePrecisionConfig)
@staticmethod
def from_config(config: DetectionAveragePrecisionConfig):
return DetectionAveragePrecision(
label=config.label,
ignore_non_predictions=config.ignore_non_predictions,
)
class DetectionROCAUCConfig(BaseConfig):
name: Literal["roc_auc"] = "roc_auc"
label: str = "roc_auc"
ignore_non_predictions: bool = True
class DetectionROCAUC:
def __init__(
self,
label: str = "roc_auc",
ignore_non_predictions: bool = True,
):
self.label = label
self.ignore_non_predictions = ignore_non_predictions
def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
y_true: List[bool] = []
y_score: List[float] = []
for clip_eval in clip_evals:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
score = float(metrics.roc_auc_score(y_true, y_score))
return {self.label: score}
@detection_metrics.register(DetectionROCAUCConfig)
@staticmethod
def from_config(config: DetectionROCAUCConfig):
return DetectionROCAUC(
label=config.label,
ignore_non_predictions=config.ignore_non_predictions,
)
class DetectionRecallConfig(BaseConfig):
name: Literal["recall"] = "recall"
label: str = "recall"
threshold: float = 0.5
class DetectionRecall:
def __init__(self, threshold: float, label: str = "recall"):
self.label = label
self.threshold = threshold
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_positives = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_ground_truth:
num_positives += 1
if m.score >= self.threshold and m.is_ground_truth:
true_positives += 1
if num_positives == 0:
return {self.label: np.nan}
score = true_positives / num_positives
return {self.label: score}
@detection_metrics.register(DetectionRecallConfig)
@staticmethod
def from_config(config: DetectionRecallConfig):
return DetectionRecall(threshold=config.threshold, label=config.label)
class DetectionPrecisionConfig(BaseConfig):
name: Literal["precision"] = "precision"
label: str = "precision"
threshold: float = 0.5
class DetectionPrecision:
def __init__(self, threshold: float, label: str = "precision"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_detections = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
is_detection = m.score >= self.threshold
if is_detection:
num_detections += 1
if is_detection and m.is_ground_truth:
true_positives += 1
if num_detections == 0:
return {self.label: np.nan}
score = true_positives / num_detections
return {self.label: score}
@detection_metrics.register(DetectionPrecisionConfig)
@staticmethod
def from_config(config: DetectionPrecisionConfig):
return DetectionPrecision(
threshold=config.threshold,
label=config.label,
)
DetectionMetricConfig = Annotated[
Union[
DetectionAveragePrecisionConfig,
DetectionROCAUCConfig,
DetectionRecallConfig,
DetectionPrecisionConfig,
],
Field(discriminator="name"),
]
def build_detection_metric(config: DetectionMetricConfig):
return detection_metrics.build(config)
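
A short sketch of instantiating one of the detection metrics above through the registry; the Sequence[ClipEval] it consumes is produced by the detection evaluation task and is omitted here:

from batdetect2.evaluate.metrics.detection import (
    DetectionRecallConfig,
    build_detection_metric,
)

config = DetectionRecallConfig(threshold=0.3, label="recall_at_0.3")
metric = build_detection_metric(config)
# metric(clip_evals) -> {"recall_at_0.3": ...} over the matched detections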


@@ -0,0 +1,314 @@
from dataclasses import dataclass
from typing import (
Annotated,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np
from pydantic import Field
from sklearn import metrics, preprocessing
from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction
__all__ = [
"TopClassMetricConfig",
"TopClassMetric",
"build_top_class_metric",
]
@dataclass
class MatchEval:
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
is_ground_truth: bool
is_generic: bool
is_prediction: bool
pred_class: Optional[str]
true_class: Optional[str]
score: float
@dataclass
class ClipEval:
clip: data.Clip
matches: List[MatchEval]
TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")
class TopClassAveragePrecisionConfig(BaseConfig):
name: Literal["average_precision"] = "average_precision"
label: str = "average_precision"
ignore_generic: bool = True
ignore_non_predictions: bool = True
class TopClassAveragePrecision:
def __init__(
self,
ignore_generic: bool = True,
ignore_non_predictions: bool = True,
label: str = "average_precision",
):
self.ignore_generic = ignore_generic
self.ignore_non_predictions = ignore_non_predictions
self.label = label
def __call__(
self,
clip_evals: Sequence[ClipEval],
) -> Dict[str, float]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evals:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
num_positives += int(m.is_ground_truth)
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
score = average_precision(y_true, y_score, num_positives=num_positives)
return {self.label: score}
@top_class_metrics.register(TopClassAveragePrecisionConfig)
@staticmethod
def from_config(config: TopClassAveragePrecisionConfig):
return TopClassAveragePrecision(
ignore_generic=config.ignore_generic,
ignore_non_predictions=config.ignore_non_predictions,
label=config.label,
)
class TopClassROCAUCConfig(BaseConfig):
name: Literal["roc_auc"] = "roc_auc"
ignore_generic: bool = True
ignore_non_predictions: bool = True
label: str = "roc_auc"
class TopClassROCAUC:
def __init__(
self,
ignore_generic: bool = True,
ignore_non_predictions: bool = True,
label: str = "roc_auc",
):
self.ignore_generic = ignore_generic
self.ignore_non_predictions = ignore_non_predictions
self.label = label
def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
y_true: List[bool] = []
y_score: List[float] = []
for clip_eval in clip_evals:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
score = float(metrics.roc_auc_score(y_true, y_score))
return {self.label: score}
@top_class_metrics.register(TopClassROCAUCConfig)
@staticmethod
def from_config(config: TopClassROCAUCConfig):
return TopClassROCAUC(
ignore_generic=config.ignore_generic,
ignore_non_predictions=config.ignore_non_predictions,
label=config.label,
)
class TopClassRecallConfig(BaseConfig):
name: Literal["recall"] = "recall"
threshold: float = 0.5
label: str = "recall"
class TopClassRecall:
def __init__(self, threshold: float, label: str = "recall"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_positives = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_ground_truth:
num_positives += 1
if m.score >= self.threshold and m.pred_class == m.true_class:
true_positives += 1
if num_positives == 0:
return {self.label: np.nan}
score = true_positives / num_positives
return {self.label: score}
@top_class_metrics.register(TopClassRecallConfig)
@staticmethod
def from_config(config: TopClassRecallConfig):
return TopClassRecall(
threshold=config.threshold,
label=config.label,
)
class TopClassPrecisionConfig(BaseConfig):
name: Literal["precision"] = "precision"
threshold: float = 0.5
label: str = "precision"
class TopClassPrecision:
def __init__(self, threshold: float, label: str = "precision"):
self.threshold = threshold
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
num_detections = 0
true_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
is_detection = m.score >= self.threshold
if is_detection:
num_detections += 1
if is_detection and m.pred_class == m.true_class:
true_positives += 1
if num_detections == 0:
return {self.label: np.nan}
score = true_positives / num_detections
return {self.label: score}
@top_class_metrics.register(TopClassPrecisionConfig)
@staticmethod
def from_config(config: TopClassPrecisionConfig):
return TopClassPrecision(
threshold=config.threshold,
label=config.label,
)
class BalancedAccuracyConfig(BaseConfig):
name: Literal["balanced_accuracy"] = "balanced_accuracy"
label: str = "balanced_accuracy"
exclude_noise: bool = False
noise_class: str = "noise"
class BalancedAccuracy:
def __init__(
self,
exclude_noise: bool = True,
noise_class: str = "noise",
label: str = "balanced_accuracy",
):
self.exclude_noise = exclude_noise
self.noise_class = noise_class
self.label = label
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Dict[str, float]:
y_true: List[str] = []
y_pred: List[str] = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic:
# Ignore matches that correspond to a sound event
# with unknown class
continue
if not m.is_ground_truth and self.exclude_noise:
# Ignore predictions that were not matched to a
# ground truth
continue
if m.pred_class is None and self.exclude_noise:
# Ignore non-predictions
continue
y_true.append(m.true_class or self.noise_class)
y_pred.append(m.pred_class or self.noise_class)
encoder = preprocessing.LabelEncoder()
encoder.fit(list(set(y_true) | set(y_pred)))
y_true_enc = encoder.transform(y_true)
y_pred_enc = encoder.transform(y_pred)
score = float(metrics.balanced_accuracy_score(y_true_enc, y_pred_enc))
return {self.label: score}
@top_class_metrics.register(BalancedAccuracyConfig)
@staticmethod
def from_config(config: BalancedAccuracyConfig):
return BalancedAccuracy(
exclude_noise=config.exclude_noise,
noise_class=config.noise_class,
label=config.label,
)
TopClassMetricConfig = Annotated[
Union[
TopClassAveragePrecisionConfig,
TopClassROCAUCConfig,
TopClassRecallConfig,
TopClassPrecisionConfig,
BalancedAccuracyConfig,
],
Field(discriminator="name"),
]
def build_top_class_metric(config: TopClassMetricConfig):
return top_class_metrics.build(config)
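
Likewise for the top-class metrics, a sketch of assembling several of them from their configs (the ClipEval inputs again come from the evaluation task and are not shown):

from batdetect2.evaluate.metrics.top_class import (
    BalancedAccuracyConfig,
    TopClassRecallConfig,
    build_top_class_metric,
)

metric_fns = [
    build_top_class_metric(TopClassRecallConfig(threshold=0.3)),
    build_top_class_metric(BalancedAccuracyConfig(exclude_noise=True)),
]
# Each callable maps a Sequence[ClipEval] to a {label: value} dict, so the
# results of several metrics can simply be merged into one report.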


@@ -0,0 +1,54 @@
from typing import Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from batdetect2.core import BaseConfig
from batdetect2.typing import TargetProtocol
class BasePlotConfig(BaseConfig):
label: str = "plot"
theme: str = "default"
title: Optional[str] = None
figsize: tuple[int, int] = (10, 10)
dpi: int = 100
class BasePlot:
def __init__(
self,
targets: TargetProtocol,
label: str = "plot",
figsize: tuple[int, int] = (10, 10),
title: Optional[str] = None,
dpi: int = 100,
theme: str = "default",
):
self.targets = targets
self.label = label
self.figsize = figsize
self.dpi = dpi
self.theme = theme
self.title = title
def create_figure(self) -> Figure:
plt.style.use(self.theme)
fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
if self.title is not None:
fig.suptitle(self.title)
return fig
@classmethod
def build(cls, config: BasePlotConfig, targets: TargetProtocol, **kwargs):
return cls(
targets=targets,
figsize=config.figsize,
dpi=config.dpi,
theme=config.theme,
label=config.label,
title=config.title,
**kwargs,
)
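
To illustrate how BasePlot is meant to be extended, a hypothetical minimal subclass; the ScoreHistogram name and its config are invented for this sketch, but the concrete plot classes in the following files follow the same pattern:

from typing import Iterable, Literal, Sequence, Tuple

from matplotlib.figure import Figure

from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig


class ScoreHistogramConfig(BasePlotConfig):
    name: Literal["score_histogram"] = "score_histogram"
    label: str = "score_histogram"


class ScoreHistogram(BasePlot):
    def __call__(self, scores: Sequence[float]) -> Iterable[Tuple[str, Figure]]:
        fig = self.create_figure()  # applies theme, figsize, dpi and title
        ax = fig.subplots()
        ax.hist(scores, bins=20, range=(0, 1))
        yield self.label, fig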


@@ -0,0 +1,370 @@
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.classification import (
ClipEval,
_extract_per_class_metric_data,
)
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
plot_pr_curve,
plot_pr_curves,
plot_roc_curve,
plot_roc_curves,
plot_threshold_precision_curve,
plot_threshold_precision_curves,
plot_threshold_recall_curve,
plot_threshold_recall_curves,
)
from batdetect2.typing import TargetProtocol
ClassificationPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]
classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
Registry("classification_plot")
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Classification Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives=num_positives[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (precision, recall, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
class ThresholdPrecisionCurveConfig(BasePlotConfig):
name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
label: str = "threshold_precision_curve"
title: Optional[str] = "Classification Threshold-Precision Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class ThresholdPrecisionCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_threshold_precision_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (precision, _, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_threshold_precision_curve(
thresholds,
precision,
ax=ax,
)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(ThresholdPrecisionCurveConfig)
@staticmethod
def from_config(
config: ThresholdPrecisionCurveConfig, targets: TargetProtocol
):
return ThresholdPrecisionCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
class ThresholdRecallCurveConfig(BasePlotConfig):
name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
label: str = "threshold_recall_curve"
title: Optional[str] = "Classification Threshold-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class ThresholdRecallCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_threshold_recall_curves(data, ax=ax, add_legend=True)
yield self.label, fig
return
for class_name, (_, recall, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_threshold_recall_curve(
thresholds,
recall,
ax=ax,
)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(ThresholdRecallCurveConfig)
@staticmethod
def from_config(
config: ThresholdRecallCurveConfig, targets: TargetProtocol
):
return ThresholdRecallCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Classification ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, _ = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: metrics.roc_curve(
y_true[class_name],
y_score[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (fpr, tpr, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
ClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ThresholdPrecisionCurveConfig,
ThresholdRecallCurveConfig,
],
Field(discriminator="name"),
]
def build_classification_plotter(
config: ClassificationPlotConfig,
targets: TargetProtocol,
) -> ClassificationPlotter:
return classification_plots.build(config, targets)


@@ -0,0 +1,189 @@
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_classification import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
plot_pr_curve,
plot_pr_curves,
plot_roc_curve,
plot_roc_curves,
)
from batdetect2.typing import TargetProtocol
__all__ = [
"ClipClassificationPlotConfig",
"ClipClassificationPlotter",
"build_clip_classification_plotter",
]
ClipClassificationPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]
clip_classification_plots: Registry[
ClipClassificationPlotter, [TargetProtocol]
] = Registry("clip_classification_plot")
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Clip Classification Precision-Recall Curve"
separate_figures: bool = False
class PRCurve(BasePlot):
def __init__(
self,
*args,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
data = {}
for class_name in self.targets.class_names:
y_true = [class_name in c.true_classes for c in clip_evaluations]
y_score = [
c.class_scores.get(class_name, 0) for c in clip_evaluations
]
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
)
data[class_name] = (precision, recall, thresholds)
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (precision, recall, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@clip_classification_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
separate_figures=config.separate_figures,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Clip Classification ROC Curve"
separate_figures: bool = False
class ROCCurve(BasePlot):
def __init__(
self,
*args,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
data = {}
for class_name in self.targets.class_names:
y_true = [class_name in c.true_classes for c in clip_evaluations]
y_score = [
c.class_scores.get(class_name, 0) for c in clip_evaluations
]
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
data[class_name] = (fpr, tpr, thresholds)
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (fpr, tpr, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@clip_classification_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
separate_figures=config.separate_figures,
)
ClipClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
],
Field(discriminator="name"),
]
def build_clip_classification_plotter(
config: ClipClassificationPlotConfig,
targets: TargetProtocol,
) -> ClipClassificationPlotter:
return clip_classification_plots.build(config, targets)


@@ -0,0 +1,163 @@
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_detection import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol
__all__ = [
"ClipDetectionPlotConfig",
"ClipDetectionPlotter",
"build_clip_detection_plotter",
]
ClipDetectionPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]
clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
Registry("clip_detection_plot")
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Clip Detection Precision-Recall Curve"
class PRCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
)
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
yield self.label, fig
@clip_detection_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Clip Detection ROC Curve"
class ROCCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
yield self.label, fig
@clip_detection_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
)
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Clip Detection Score Distribution"
class ScoreDistributionPlot(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
fig = self.create_figure()
ax = fig.subplots()
df = pd.DataFrame({"is_true": y_true, "score": y_score})
sns.histplot(
data=df,
x="score",
binwidth=0.025,
binrange=(0, 1),
hue="is_true",
ax=ax,
stat="probability",
common_norm=False,
)
yield self.label, fig
@clip_detection_plots.register(ScoreDistributionPlotConfig)
@staticmethod
def from_config(
config: ScoreDistributionPlotConfig, targets: TargetProtocol
):
return ScoreDistributionPlot.build(
config=config,
targets=targets,
)
ClipDetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
],
Field(discriminator="name"),
]
def build_clip_detection_plotter(
config: ClipDetectionPlotConfig,
targets: TargetProtocol,
) -> ClipDetectionPlotter:
return clip_detection_plots.build(config, targets)


@@ -0,0 +1,309 @@
import random
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.detections import plot_clip_detections
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
name="detection_plot"
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Detection Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evals: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evals:
for m in clip_eval.matches:
num_positives += int(m.is_ground_truth)
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and self.ignore_non_predictions:
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
)
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
yield self.label, fig
@detection_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Detection ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
yield self.label, fig
@detection_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Detection Score Distribution"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ScoreDistributionPlot(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
df = pd.DataFrame({"is_true": y_true, "score": y_score})
fig = self.create_figure()
ax = fig.subplots()
sns.histplot(
data=df,
x="score",
binwidth=0.025,
binrange=(0, 1),
hue="is_true",
ax=ax,
stat="probability",
common_norm=False,
)
yield self.label, fig
@detection_plots.register(ScoreDistributionPlotConfig)
@staticmethod
def from_config(
config: ScoreDistributionPlotConfig, targets: TargetProtocol
):
return ScoreDistributionPlot.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ExampleDetectionPlotConfig(BasePlotConfig):
name: Literal["example_detection"] = "example_detection"
label: str = "example_detection"
title: Optional[str] = "Example Detection"
figsize: tuple[int, int] = (10, 4)
num_examples: int = 5
threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class ExampleDetectionPlot(BasePlot):
def __init__(
self,
*args,
num_examples: int = 5,
threshold: float = 0.2,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_examples = num_examples
self.audio_loader = audio_loader
self.threshold = threshold
self.preprocessor = preprocessor
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
sample = clip_evaluations
if self.num_examples < len(sample):
sample = random.sample(sample, self.num_examples)
for num_example, clip_eval in enumerate(sample):
fig = self.create_figure()
ax = fig.subplots()
plot_clip_detections(
clip_eval,
ax=ax,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
)
yield f"{self.label}/example_{num_example}", fig
plt.close(fig)
@detection_plots.register(ExampleDetectionPlotConfig)
@staticmethod
def from_config(
config: ExampleDetectionPlotConfig,
targets: TargetProtocol,
):
return ExampleDetectionPlot.build(
config=config,
targets=targets,
num_examples=config.num_examples,
threshold=config.threshold,
audio_loader=build_audio_loader(config.audio),
preprocessor=build_preprocessor(config.preprocessing),
)
DetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
ExampleDetectionPlotConfig,
],
Field(discriminator="name"),
]
def build_detection_plotter(
config: DetectionPlotConfig,
targets: TargetProtocol,
) -> DetectionPlotter:
return detection_plots.build(config, targets)


@@ -0,0 +1,444 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (
Annotated,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
name="top_class_plot"
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Top Class Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
num_positives += int(m.is_ground_truth)
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
)
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
yield self.label, fig
@top_class_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Top Class ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
yield self.label, fig
@top_class_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ConfusionMatrixConfig(BasePlotConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
title: Optional[str] = "Top Class Confusion Matrix"
figsize: tuple[int, int] = (10, 10)
label: str = "confusion_matrix"
exclude_generic: bool = True
exclude_noise: bool = False
noise_class: str = "noise"
normalize: Literal["true", "pred", "all", "none"] = "true"
threshold: float = 0.2
add_colorbar: bool = True
cmap: str = "Blues"
class ConfusionMatrix(BasePlot):
def __init__(
self,
*args,
exclude_generic: bool = True,
exclude_noise: bool = False,
noise_class: str = "noise",
add_colorbar: bool = True,
normalize: Literal["true", "pred", "all", "none"] = "true",
cmap: str = "Blues",
threshold: float = 0.2,
**kwargs,
):
super().__init__(*args, **kwargs)
self.exclude_generic = exclude_generic
self.exclude_noise = exclude_noise
self.noise_class = noise_class
self.normalize = normalize
self.add_colorbar = add_colorbar
self.threshold = threshold
self.cmap = cmap
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true: List[str] = []
y_pred: List[str] = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
true_class = m.true_class
pred_class = m.pred_class
if not m.is_prediction and self.exclude_noise:
# Ignore matches that don't correspond to a prediction
continue
if not m.is_ground_truth and self.exclude_noise:
# Ignore matches that don't correspond to a ground truth
continue
if m.score < self.threshold:
if self.exclude_noise:
continue
pred_class = self.noise_class
if m.is_generic:
if self.exclude_generic:
# Ignore gt sounds with unknown class
continue
true_class = self.targets.detection_class_name
y_true.append(true_class or self.noise_class)
y_pred.append(pred_class or self.noise_class)
fig = self.create_figure()
ax = fig.subplots()
class_names = [*self.targets.class_names]
if not self.exclude_generic:
class_names.append(self.targets.detection_class_name)
if not self.exclude_noise:
class_names.append(self.noise_class)
metrics.ConfusionMatrixDisplay.from_predictions(
y_true,
y_pred,
labels=class_names,
ax=ax,
xticks_rotation="vertical",
cmap=self.cmap,
colorbar=self.add_colorbar,
normalize=self.normalize if self.normalize != "none" else None,
values_format=".2f",
)
yield self.label, fig
@top_class_plots.register(ConfusionMatrixConfig)
@staticmethod
def from_config(config: ConfusionMatrixConfig, targets: TargetProtocol):
return ConfusionMatrix.build(
config=config,
targets=targets,
exclude_generic=config.exclude_generic,
exclude_noise=config.exclude_noise,
noise_class=config.noise_class,
add_colorbar=config.add_colorbar,
normalize=config.normalize,
cmap=config.cmap,
threshold=config.threshold,
)
class ExampleClassificationPlotConfig(BasePlotConfig):
name: Literal["example_classification"] = "example_classification"
label: str = "example_classification"
title: Optional[str] = "Example Classification"
num_examples: int = 4
threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class ExampleClassificationPlot(BasePlot):
def __init__(
self,
*args,
num_examples: int = 4,
threshold: float = 0.2,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_examples = num_examples
self.audio_loader = audio_loader
self.threshold = threshold
self.preprocessor = preprocessor
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
grouped = group_matches(clip_evaluations, threshold=self.threshold)
for class_name, matches in grouped.items():
true_positives: List[MatchEval] = get_binned_sample(
matches.true_positives,
n_examples=self.num_examples,
)
false_positives: List[MatchEval] = get_binned_sample(
matches.false_positives,
n_examples=self.num_examples,
)
false_negatives: List[MatchEval] = random.sample(
matches.false_negatives,
k=min(self.num_examples, len(matches.false_negatives)),
)
cross_triggers: List[MatchEval] = get_binned_sample(
matches.cross_triggers, n_examples=self.num_examples
)
fig = self.create_figure()
fig = plot_match_gallery(
true_positives,
false_positives,
false_negatives,
cross_triggers,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
n_examples=self.num_examples,
fig=fig,
)
if self.title is not None:
fig.suptitle(f"{self.title}: {class_name}")
else:
fig.suptitle(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@top_class_plots.register(ExampleClassificationPlotConfig)
@staticmethod
def from_config(
config: ExampleClassificationPlotConfig,
targets: TargetProtocol,
):
return ExampleClassificationPlot.build(
config=config,
targets=targets,
num_examples=config.num_examples,
threshold=config.threshold,
audio_loader=build_audio_loader(config.audio),
preprocessor=build_preprocessor(config.preprocessing),
)
TopClassPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ConfusionMatrixConfig,
ExampleClassificationPlotConfig,
],
Field(discriminator="name"),
]
def build_top_class_plotter(
config: TopClassPlotConfig,
targets: TargetProtocol,
) -> TopClassPlotter:
return top_class_plots.build(config, targets)
@dataclass
class ClassMatches:
false_positives: List[MatchEval] = field(default_factory=list)
false_negatives: List[MatchEval] = field(default_factory=list)
true_positives: List[MatchEval] = field(default_factory=list)
cross_triggers: List[MatchEval] = field(default_factory=list)
def group_matches(
clip_evals: Sequence[ClipEval],
threshold: float = 0.2,
) -> Dict[str, ClassMatches]:
class_examples = defaultdict(ClassMatches)
for clip_eval in clip_evals:
for match in clip_eval.matches:
gt_class = match.true_class
pred_class = match.pred_class
is_pred = match.score >= threshold
if not is_pred and gt_class is not None:
class_examples[gt_class].false_negatives.append(match)
continue
if not is_pred:
continue
if gt_class is None:
class_examples[pred_class].false_positives.append(match)
continue
if gt_class != pred_class:
class_examples[pred_class].cross_triggers.append(match)
continue
class_examples[gt_class].true_positives.append(match)
return class_examples
def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
if len(matches) < n_examples:
return matches
indices, pred_scores = zip(
*[(index, match.score) for index, match in enumerate(matches)]
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
df = pd.DataFrame({"indices": indices, "bins": bins})
sample = df.groupby("bins").sample(1)
return [matches[ind] for ind in sample["indices"]]
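
To make the two helpers above concrete: group_matches buckets every match into true positives, false positives, false negatives, or cross triggers per class, and get_binned_sample then draws one example per score quantile so the gallery spans low- to high-confidence cases. A toy illustration of the sampling, using get_binned_sample as defined above and a stand-in object (the real MatchEval also carries clips and geometries, which the binning ignores):

from dataclasses import dataclass


@dataclass
class FakeMatch:
    # Stand-in exposing the only attribute get_binned_sample reads.
    score: float


matches = [FakeMatch(score=i / 100) for i in range(100)]
sample = get_binned_sample(matches, n_examples=5)
# One match drawn from each of the five score quantiles.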


@@ -0,0 +1,106 @@
from typing import Annotated, Callable, Literal, Sequence, Union
import pandas as pd
from pydantic import Field
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipMatches
EvaluationTableGenerator = Callable[[Sequence[ClipMatches]], pd.DataFrame]
tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
"evaluation_table"
)
class FullEvaluationTableConfig(BaseConfig):
name: Literal["full_evaluation"] = "full_evaluation"
class FullEvaluationTable:
def __call__(
self, clip_evaluations: Sequence[ClipMatches]
) -> pd.DataFrame:
return extract_matches_dataframe(clip_evaluations)
@tables_registry.register(FullEvaluationTableConfig)
@staticmethod
def from_config(config: FullEvaluationTableConfig):
return FullEvaluationTable()
def extract_matches_dataframe(
clip_evaluations: Sequence[ClipMatches],
) -> pd.DataFrame:
data = []
for clip_evaluation in clip_evaluations:
for match in clip_evaluation.matches:
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
pred_start_time = pred_low_freq = pred_end_time = (
pred_high_freq
) = None
sound_event_annotation = match.sound_event_annotation
if sound_event_annotation is not None:
geometry = sound_event_annotation.sound_event.geometry
assert geometry is not None
gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
compute_bounds(geometry)
)
if match.pred_geometry is not None:
(
pred_start_time,
pred_low_freq,
pred_end_time,
pred_high_freq,
) = compute_bounds(match.pred_geometry)
data.append(
{
("recording", "uuid"): match.clip.recording.uuid,
("clip", "uuid"): match.clip.uuid,
("clip", "start_time"): match.clip.start_time,
("clip", "end_time"): match.clip.end_time,
("gt", "uuid"): match.sound_event_annotation.uuid
if match.sound_event_annotation is not None
else None,
("gt", "class"): match.gt_class,
("gt", "det"): match.gt_det,
("gt", "start_time"): gt_start_time,
("gt", "end_time"): gt_end_time,
("gt", "low_freq"): gt_low_freq,
("gt", "high_freq"): gt_high_freq,
("pred", "score"): match.pred_score,
("pred", "class"): match.top_class,
("pred", "class_score"): match.top_class_score,
("pred", "start_time"): pred_start_time,
("pred", "end_time"): pred_end_time,
("pred", "low_freq"): pred_low_freq,
("pred", "high_freq"): pred_high_freq,
("match", "affinity"): match.affinity,
**{
("pred_class_score", key): value
for key, value in match.pred_class_scores.items()
},
}
)
df = pd.DataFrame(data)
df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore
return df
EvaluationTableConfig = Annotated[
Union[FullEvaluationTableConfig,], Field(discriminator="name")
]
def build_table_generator(
config: EvaluationTableConfig,
) -> EvaluationTableGenerator:
return tables_registry.build(config)
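
# Illustrative usage sketch, assuming `clip_matches` is the Sequence[ClipMatches]
# produced by the evaluator. The output path is a placeholder.
table_generator = build_table_generator(FullEvaluationTableConfig())
df = table_generator(clip_matches)
df.to_csv("full_evaluation.csv", index=False)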

View File

@ -0,0 +1,39 @@
from typing import Annotated, Optional, Union
from pydantic import Field
from batdetect2.evaluate.tasks.base import tasks_registry
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.clip_classification import (
ClipClassificationTaskConfig,
)
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
__all__ = [
"TaskConfig",
"build_task",
]
TaskConfig = Annotated[
Union[
ClassificationTaskConfig,
DetectionTaskConfig,
ClipDetectionTaskConfig,
ClipClassificationTaskConfig,
TopClassDetectionTaskConfig,
],
Field(discriminator="name"),
]
def build_task(
config: TaskConfig,
targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
targets = targets or build_targets()
return tasks_registry.build(config, targets)
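
# Illustrative usage sketch: building the default sound event detection task.
# `clip_annotations` and `predictions` are assumed to come from the evaluation
# pipeline; metric keys end up prefixed, e.g. "detection/...".
task = build_task(DetectionTaskConfig())
outputs = task.evaluate(clip_annotations, predictions)
metrics = task.compute_metrics(outputs)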

View File

@ -0,0 +1,175 @@
from typing import (
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.match import (
MatchConfig,
StartTimeMatchConfig,
build_matcher,
)
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"BaseTaskConfig",
"BaseTask",
]
tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
"tasks"
)
T_Output = TypeVar("T_Output")
class BaseTaskConfig(BaseConfig):
prefix: str
ignore_start_end: float = 0.01
matching_strategy: MatchConfig = Field(
default_factory=StartTimeMatchConfig
)
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
targets: TargetProtocol
matcher: MatcherProtocol
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
ignore_start_end: float
prefix: str
def __init__(
self,
matcher: MatcherProtocol,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
prefix: str,
ignore_start_end: float = 0.01,
plots: Optional[
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
):
self.matcher = matcher
self.metrics = metrics
self.plots = plots or []
self.targets = targets
self.prefix = prefix
self.ignore_start_end = ignore_start_end
def compute_metrics(
self,
eval_outputs: List[T_Output],
) -> Dict[str, float]:
scores = [metric(eval_outputs) for metric in self.metrics]
return {
f"{self.prefix}/{name}": score
for metric_output in scores
for name, score in metric_output.items()
}
def generate_plots(
self, eval_outputs: List[T_Output]
) -> Iterable[Tuple[str, Figure]]:
for plot in self.plots:
for name, fig in plot(eval_outputs):
yield f"{self.prefix}/{name}", fig
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> List[T_Output]:
return [
self.evaluate_clip(clip_annotation, preds)
for clip_annotation, preds in zip(clip_annotations, predictions)
]
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> T_Output: ...
def include_sound_event_annotation(
self,
sound_event_annotation: data.SoundEventAnnotation,
clip: data.Clip,
) -> bool:
if not self.targets.filter(sound_event_annotation):
return False
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
return False
return is_in_bounds(
geometry,
clip,
self.ignore_start_end,
)
def include_prediction(
self,
prediction: RawPrediction,
clip: data.Clip,
) -> bool:
return is_in_bounds(
prediction.geometry,
clip,
self.ignore_start_end,
)
@classmethod
def build(
cls,
config: BaseTaskConfig,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
plots: Optional[
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
**kwargs,
):
matcher = build_matcher(config.matching_strategy)
return cls(
matcher=matcher,
targets=targets,
metrics=metrics,
plots=plots,
prefix=config.prefix,
ignore_start_end=config.ignore_start_end,
**kwargs,
)
def is_in_bounds(
geometry: data.Geometry,
clip: data.Clip,
buffer: float,
) -> bool:
start_time = compute_bounds(geometry)[0]
return (start_time >= clip.start_time + buffer) and (
start_time <= clip.end_time - buffer
)

View File

@ -0,0 +1,149 @@
from typing import (
List,
Literal,
Sequence,
)
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.classification import (
ClassificationAveragePrecisionConfig,
ClassificationMetricConfig,
ClipEval,
MatchEval,
build_classification_metric,
)
from batdetect2.evaluate.plots.classification import (
ClassificationPlotConfig,
build_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class ClassificationTaskConfig(BaseTaskConfig):
name: Literal["sound_event_classification"] = "sound_event_classification"
prefix: str = "classification"
metrics: List[ClassificationMetricConfig] = Field(
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
)
plots: List[ClassificationPlotConfig] = Field(default_factory=list)
include_generics: bool = True
class ClassificationTask(BaseTask[ClipEval]):
def __init__(
self,
*args,
include_generics: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.include_generics = include_generics
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
preds = [
pred for pred in predictions if self.include_prediction(pred, clip)
]
all_gts = [
sound_event
for sound_event in clip_annotation.sound_events
if self.include_sound_event_annotation(sound_event, clip)
]
per_class_matches = {}
for class_name in self.targets.class_names:
class_idx = self.targets.class_names.index(class_name)
# Only match to targets of the given class
gts = [
sound_event
for sound_event in all_gts
if self.is_class(sound_event, class_name)
]
scores = [float(pred.class_scores[class_idx]) for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
true_class = (
self.targets.encode_class(gt) if gt is not None else None
)
score = (
float(pred.class_scores[class_idx])
if pred is not None
else 0
)
matches.append(
MatchEval(
clip=clip,
gt=gt,
pred=pred,
is_prediction=pred is not None,
is_ground_truth=gt is not None,
is_generic=gt is not None and true_class is None,
true_class=true_class,
score=score,
)
)
per_class_matches[class_name] = matches
return ClipEval(clip=clip, matches=per_class_matches)
def is_class(
self,
sound_event: data.SoundEventAnnotation,
class_name: str,
) -> bool:
sound_event_class = self.targets.encode_class(sound_event)
if sound_event_class is None and self.include_generics:
# Sound events that are generic could be of the given
# class
return True
return sound_event_class == class_name
@tasks_registry.register(ClassificationTaskConfig)
@staticmethod
def from_config(
config: ClassificationTaskConfig,
targets: TargetProtocol,
):
metrics = [
build_classification_metric(metric, targets)
for metric in config.metrics
]
plots = [
build_classification_plotter(plot, targets)
for plot in config.plots
]
return ClassificationTask.build(
config=config,
plots=plots,
targets=targets,
metrics=metrics,
)

View File

@ -0,0 +1,85 @@
from collections import defaultdict
from typing import List, Literal, Sequence
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.clip_classification import (
ClipClassificationAveragePrecisionConfig,
ClipClassificationMetricConfig,
ClipEval,
build_clip_metric,
)
from batdetect2.evaluate.plots.clip_classification import (
ClipClassificationPlotConfig,
build_clip_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class ClipClassificationTaskConfig(BaseTaskConfig):
name: Literal["clip_classification"] = "clip_classification"
prefix: str = "clip_classification"
metrics: List[ClipClassificationMetricConfig] = Field(
default_factory=lambda: [
ClipClassificationAveragePrecisionConfig(),
]
)
plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)
class ClipClassificationTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
gt_classes = set()
for sound_event in clip_annotation.sound_events:
if not self.include_sound_event_annotation(sound_event, clip):
continue
class_name = self.targets.encode_class(sound_event)
if class_name is None:
continue
gt_classes.add(class_name)
pred_scores = defaultdict(float)
for pred in predictions:
if not self.include_prediction(pred, clip):
continue
for class_idx, class_name in enumerate(self.targets.class_names):
pred_scores[class_name] = max(
float(pred.class_scores[class_idx]),
pred_scores[class_name],
)
return ClipEval(true_classes=gt_classes, class_scores=pred_scores)
@tasks_registry.register(ClipClassificationTaskConfig)
@staticmethod
def from_config(
config: ClipClassificationTaskConfig,
targets: TargetProtocol,
):
metrics = [build_clip_metric(metric) for metric in config.metrics]
plots = [
build_clip_classification_plotter(plot, targets)
for plot in config.plots
]
return ClipClassificationTask.build(
config=config,
plots=plots,
metrics=metrics,
targets=targets,
)

View File

@ -0,0 +1,76 @@
from typing import List, Literal, Sequence
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.clip_detection import (
ClipDetectionAveragePrecisionConfig,
ClipDetectionMetricConfig,
ClipEval,
build_clip_metric,
)
from batdetect2.evaluate.plots.clip_detection import (
ClipDetectionPlotConfig,
build_clip_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class ClipDetectionTaskConfig(BaseTaskConfig):
name: Literal["clip_detection"] = "clip_detection"
prefix: str = "clip_detection"
metrics: List[ClipDetectionMetricConfig] = Field(
default_factory=lambda: [
ClipDetectionAveragePrecisionConfig(),
]
)
plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)
class ClipDetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
gt_det = any(
self.include_sound_event_annotation(sound_event, clip)
for sound_event in clip_annotation.sound_events
)
pred_score = 0
for pred in predictions:
if not self.include_prediction(pred, clip):
continue
pred_score = max(pred_score, pred.detection_score)
return ClipEval(
gt_det=gt_det,
score=pred_score,
)
@tasks_registry.register(ClipDetectionTaskConfig)
@staticmethod
def from_config(
config: ClipDetectionTaskConfig,
targets: TargetProtocol,
):
metrics = [build_clip_metric(metric) for metric in config.metrics]
plots = [
build_clip_detection_plotter(plot, targets)
for plot in config.plots
]
return ClipDetectionTask.build(
config=config,
metrics=metrics,
targets=targets,
plots=plots,
)

View File

@ -0,0 +1,88 @@
from typing import List, Literal, Sequence
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.detection import (
ClipEval,
DetectionAveragePrecisionConfig,
DetectionMetricConfig,
MatchEval,
build_detection_metric,
)
from batdetect2.evaluate.plots.detection import (
DetectionPlotConfig,
build_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class DetectionTaskConfig(BaseTaskConfig):
name: Literal["sound_event_detection"] = "sound_event_detection"
prefix: str = "detection"
metrics: List[DetectionMetricConfig] = Field(
default_factory=lambda: [DetectionAveragePrecisionConfig()]
)
plots: List[DetectionPlotConfig] = Field(default_factory=list)
class DetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
gts = [
sound_event
for sound_event in clip_annotation.sound_events
if self.include_sound_event_annotation(sound_event, clip)
]
preds = [
pred for pred in predictions if self.include_prediction(pred, clip)
]
scores = [pred.detection_score for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
matches.append(
MatchEval(
gt=gt,
pred=pred,
is_prediction=pred is not None,
is_ground_truth=gt is not None,
score=pred.detection_score if pred is not None else 0,
)
)
return ClipEval(clip=clip, matches=matches)
@tasks_registry.register(DetectionTaskConfig)
@staticmethod
def from_config(
config: DetectionTaskConfig,
targets: TargetProtocol,
):
metrics = [build_detection_metric(metric) for metric in config.metrics]
plots = [
build_detection_plotter(plot, targets) for plot in config.plots
]
return DetectionTask.build(
config=config,
metrics=metrics,
targets=targets,
plots=plots,
)

View File

@ -0,0 +1,111 @@
from typing import List, Literal, Sequence
from pydantic import Field
from soundevent import data
from batdetect2.evaluate.metrics.top_class import (
ClipEval,
MatchEval,
TopClassAveragePrecisionConfig,
TopClassMetricConfig,
build_top_class_metric,
)
from batdetect2.evaluate.plots.top_class import (
TopClassPlotConfig,
build_top_class_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
class TopClassDetectionTaskConfig(BaseTaskConfig):
name: Literal["top_class_detection"] = "top_class_detection"
prefix: str = "top_class"
metrics: List[TopClassMetricConfig] = Field(
default_factory=lambda: [TopClassAveragePrecisionConfig()]
)
plots: List[TopClassPlotConfig] = Field(default_factory=list)
class TopClassDetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
) -> ClipEval:
clip = clip_annotation.clip
gts = [
sound_event
for sound_event in clip_annotation.sound_events
if self.include_sound_event_annotation(sound_event, clip)
]
preds = [
pred for pred in predictions if self.include_prediction(pred, clip)
]
# Take the highest score for each prediction
scores = [pred.class_scores.max() for pred in preds]
matches = []
for pred_idx, gt_idx, _ in self.matcher(
ground_truth=[se.sound_event.geometry for se in gts], # type: ignore
predictions=[pred.geometry for pred in preds],
scores=scores,
):
gt = gts[gt_idx] if gt_idx is not None else None
pred = preds[pred_idx] if pred_idx is not None else None
true_class = (
self.targets.encode_class(gt) if gt is not None else None
)
class_idx = (
pred.class_scores.argmax() if pred is not None else None
)
score = (
float(pred.class_scores[class_idx]) if pred is not None else 0
)
pred_class = (
self.targets.class_names[class_idx]
if class_idx is not None
else None
)
matches.append(
MatchEval(
clip=clip,
gt=gt,
pred=pred,
is_ground_truth=gt is not None,
is_prediction=pred is not None,
true_class=true_class,
is_generic=gt is not None and true_class is None,
pred_class=pred_class,
score=score,
)
)
return ClipEval(clip=clip, matches=matches)
@tasks_registry.register(TopClassDetectionTaskConfig)
@staticmethod
def from_config(
config: TopClassDetectionTaskConfig,
targets: TargetProtocol,
):
metrics = [build_top_class_metric(metric) for metric in config.metrics]
plots = [
build_top_class_plotter(plot, targets) for plot in config.plots
]
return TopClassDetectionTask.build(
config=config,
plots=plots,
metrics=metrics,
targets=targets,
)

View File

@ -1,40 +0,0 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol
from soundevent import data
__all__ = [
"MetricsProtocol",
"MatchEvaluation",
]
@dataclass
class MatchEvaluation:
match: data.Match
gt_det: bool
gt_class: Optional[str]
pred_score: float
pred_class_scores: Dict[str, float]
@property
def pred_class(self) -> Optional[str]:
if not self.pred_class_scores:
return None
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
@property
def pred_class_score(self) -> float:
pred_class = self.pred_class
if pred_class is None:
return 0
return self.pred_class_scores[pred_class]
class MetricsProtocol(Protocol):
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...

View File

@ -0,0 +1,10 @@
from batdetect2.inference.batch import process_file_list, run_batch_inference
from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.config import InferenceConfig
__all__ = [
"process_file_list",
"run_batch_inference",
"InferenceConfig",
"get_clips_from_files",
]

View File

@ -0,0 +1,88 @@
from typing import TYPE_CHECKING, List, Optional, Sequence
from lightning import Trainer
from soundevent import data
from batdetect2.audio.loader import build_audio_loader
from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model
from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.typing.postprocess import BatDetect2Prediction
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
PreprocessorProtocol,
TargetProtocol,
)
def run_batch_inference(
model,
clips: Sequence[data.Clip],
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config()
audio_loader = audio_loader or build_audio_loader()
preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate,
)
targets = targets or build_targets()
loader = build_inference_loader(
clips,
audio_loader=audio_loader,
preprocessor=preprocessor,
config=config.inference.loader,
num_workers=num_workers,
)
module = InferenceModule(model)
trainer = Trainer(enable_checkpointing=False, logger=False)
outputs = trainer.predict(module, loader)
return [
clip_prediction
for clip_predictions in outputs # type: ignore
for clip_prediction in clip_predictions
]
def process_file_list(
model: Model,
paths: Sequence[data.PathLike],
config: "BatDetect2Config",
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
clip_config = config.inference.clipping
clips = get_clips_from_files(
paths,
duration=clip_config.duration,
overlap=clip_config.overlap,
max_empty=clip_config.max_empty,
discard_empty=clip_config.discard_empty,
)
return run_batch_inference(
model,
clips,
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,
config=config,
num_workers=num_workers,
)
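
# Illustrative usage sketch, assuming `model` is an already built Model and
# that BatDetect2Config is default-constructible with the inference settings
# shown in the config module. Paths are placeholders.
from pathlib import Path

from batdetect2.config import BatDetect2Config

paths = sorted(Path("recordings").glob("*.wav"))
predictions = process_file_list(
    model,
    paths,
    config=BatDetect2Config(),
    num_workers=4,
)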

View File

@ -0,0 +1,75 @@
from typing import List, Sequence
from uuid import uuid5
import numpy as np
from soundevent import data
def get_clips_from_files(
paths: Sequence[data.PathLike],
duration: float,
overlap: float = 0.0,
max_empty: float = 0.0,
discard_empty: bool = True,
compute_hash: bool = False,
) -> List[data.Clip]:
clips: List[data.Clip] = []
for path in paths:
recording = data.Recording.from_file(path, compute_hash=compute_hash)
clips.extend(
get_recording_clips(
recording,
duration,
overlap=overlap,
max_empty=max_empty,
discard_empty=discard_empty,
)
)
return clips
def get_recording_clips(
recording: data.Recording,
duration: float,
overlap: float = 0.0,
max_empty: float = 0.0,
discard_empty: bool = True,
) -> Sequence[data.Clip]:
start_time = 0
    recording_duration = recording.duration
    hop = duration * (1 - overlap)
    num_clips = int(np.ceil(recording_duration / hop))
    if num_clips == 0:
        # This should only happen if the recording's duration is zero,
        # which should never happen in practice, but just in case...
        return []
    clips = []
    for i in range(num_clips):
        start = start_time + i * hop
        end = start + duration
        if end > recording_duration:
            empty_duration = end - recording_duration
if empty_duration > max_empty and discard_empty:
# Discard clips that contain too much empty space
continue
clips.append(
data.Clip(
uuid=uuid5(recording.uuid, f"{start}_{end}"),
recording=recording,
start_time=start,
end_time=end,
)
)
if discard_empty:
clips = [clip for clip in clips if clip.duration > max_empty]
return clips
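
# Illustrative usage sketch with placeholder file names: split recordings into
# half-second clips with 25% overlap.
clips = get_clips_from_files(
    ["night_01.wav", "night_02.wav"],
    duration=0.5,
    overlap=0.25,
)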

View File

@ -0,0 +1,21 @@
from pydantic import Field
from batdetect2.core.configs import BaseConfig
from batdetect2.inference.dataset import InferenceLoaderConfig
__all__ = ["InferenceConfig"]
class ClipingConfig(BaseConfig):
enabled: bool = True
duration: float = 0.5
overlap: float = 0.0
max_empty: float = 0.0
discard_empty: bool = True
class InferenceConfig(BaseConfig):
loader: InferenceLoaderConfig = Field(
default_factory=InferenceLoaderConfig
)
clipping: ClipingConfig = Field(default_factory=ClipingConfig)
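
# Illustrative sketch of overriding the defaults; field names follow the config
# classes above, values are placeholders.
config = InferenceConfig(
    clipping=ClipingConfig(duration=1.0, overlap=0.5),
    loader=InferenceLoaderConfig(batch_size=16),
)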

View File

@ -0,0 +1,120 @@
from typing import List, NamedTuple, Optional, Sequence
import torch
from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import build_audio_loader
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol
__all__ = [
"InferenceDataset",
"build_inference_dataset",
"build_inference_loader",
]
DEFAULT_INFERENCE_CLIP_DURATION = 0.512
class DatasetItem(NamedTuple):
spec: torch.Tensor
idx: torch.Tensor
start_time: torch.Tensor
end_time: torch.Tensor
class InferenceDataset(Dataset[DatasetItem]):
clips: List[data.Clip]
def __init__(
self,
clips: Sequence[data.Clip],
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
audio_dir: Optional[data.PathLike] = None,
):
self.clips = list(clips)
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.audio_dir = audio_dir
def __len__(self):
return len(self.clips)
def __getitem__(self, idx: int) -> DatasetItem:
clip = self.clips[idx]
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
wav_tensor = torch.tensor(wav).unsqueeze(0)
spectrogram = self.preprocessor(wav_tensor)
return DatasetItem(
spec=spectrogram,
idx=torch.tensor(idx),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)
class InferenceLoaderConfig(BaseConfig):
num_workers: int = 0
batch_size: int = 8
def build_inference_loader(
clips: Sequence[data.Clip],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[InferenceLoaderConfig] = None,
num_workers: Optional[int] = None,
) -> DataLoader[DatasetItem]:
logger.info("Building inference data loader...")
config = config or InferenceLoaderConfig()
inference_dataset = build_inference_dataset(
clips,
audio_loader=audio_loader,
preprocessor=preprocessor,
)
num_workers = num_workers or config.num_workers
return DataLoader(
inference_dataset,
batch_size=config.batch_size,
shuffle=False,
        num_workers=num_workers,
collate_fn=_collate_fn,
)
def build_inference_dataset(
clips: Sequence[data.Clip],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
) -> InferenceDataset:
if audio_loader is None:
audio_loader = build_audio_loader()
if preprocessor is None:
preprocessor = build_preprocessor()
return InferenceDataset(
clips,
audio_loader=audio_loader,
preprocessor=preprocessor,
)
def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
max_width = max(item.spec.shape[-1] for item in batch)
return DatasetItem(
spec=torch.stack(
[adjust_width(item.spec, max_width) for item in batch]
),
idx=torch.stack([item.idx for item in batch]),
start_time=torch.stack([item.start_time for item in batch]),
end_time=torch.stack([item.end_time for item in batch]),
)
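
# Illustrative usage sketch, assuming `clips` is a list of data.Clip objects:
# the collate function pads every spectrogram to the widest item in the batch,
# so clips of slightly different lengths can share a batch.
loader = build_inference_loader(clips, config=InferenceLoaderConfig(batch_size=4))
batch = next(iter(loader))
print(batch.spec.shape)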

View File

@ -0,0 +1,52 @@
from typing import Sequence
from lightning import LightningModule
from torch.utils.data import DataLoader
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing.postprocess import BatDetect2Prediction
class InferenceModule(LightningModule):
def __init__(self, model: Model):
super().__init__()
self.model = model
def predict_step(
self,
batch: DatasetItem,
batch_idx: int,
dataloader_idx: int = 0,
) -> Sequence[BatDetect2Prediction]:
dataset = self.get_dataset()
clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(
outputs,
start_times=[clip.start_time for clip in clips],
)
predictions = [
BatDetect2Prediction(
clip=clip,
predictions=to_raw_predictions(
clip_dets.numpy(),
targets=self.model.targets,
),
)
for clip, clip_dets in zip(clips, clip_detections)
]
return predictions
def get_dataset(self) -> InferenceDataset:
dataloaders = self.trainer.predict_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, InferenceDataset)
return dataset

src/batdetect2/logging.py
View File

@ -0,0 +1,314 @@
import io
from collections.abc import Callable
from functools import partial
from pathlib import Path
from typing import (
Annotated,
Any,
Dict,
Generic,
Literal,
Optional,
Protocol,
TypeVar,
Union,
)
import numpy as np
import pandas as pd
from lightning.pytorch.loggers import (
CSVLogger,
Logger,
MLFlowLogger,
TensorBoardLogger,
)
from loguru import logger
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
class BaseLoggerConfig(BaseConfig):
log_dir: Path = DEFAULT_LOGS_DIR
experiment_name: Optional[str] = None
run_name: Optional[str] = None
class DVCLiveConfig(BaseLoggerConfig):
name: Literal["dvclive"] = "dvclive"
prefix: str = ""
log_model: Union[bool, Literal["all"]] = False
monitor_system: bool = False
class CSVLoggerConfig(BaseLoggerConfig):
name: Literal["csv"] = "csv"
flush_logs_every_n_steps: int = 100
class TensorBoardLoggerConfig(BaseLoggerConfig):
name: Literal["tensorboard"] = "tensorboard"
log_graph: bool = False
class MLFlowLoggerConfig(BaseLoggerConfig):
name: Literal["mlflow"] = "mlflow"
tracking_uri: Optional[str] = "http://localhost:5000"
tags: Optional[dict[str, Any]] = None
log_model: bool = False
LoggerConfig = Annotated[
Union[
DVCLiveConfig,
CSVLoggerConfig,
TensorBoardLoggerConfig,
MLFlowLoggerConfig,
],
Field(discriminator="name"),
]
T = TypeVar("T", bound=LoggerConfig, contravariant=True)
class LoggerBuilder(Protocol, Generic[T]):
def __call__(
self,
config: T,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Logger: ...
def create_dvclive_logger(
config: DVCLiveConfig,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Logger:
try:
from dvclive.lightning import DVCLiveLogger # type: ignore
except ImportError as error:
raise ValueError(
"DVCLive is not installed and cannot be used for logging"
"Make sure you have it installed by running `pip install dvclive`"
"or `uv add dvclive`"
) from error
return DVCLiveLogger(
dir=log_dir if log_dir is not None else config.log_dir,
run_name=run_name if run_name is not None else config.run_name,
experiment=experiment_name
if experiment_name is not None
else config.experiment_name,
prefix=config.prefix,
log_model=config.log_model,
monitor_system=config.monitor_system,
)
def create_csv_logger(
config: CSVLoggerConfig,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Logger:
from lightning.pytorch.loggers import CSVLogger
if log_dir is None:
log_dir = Path(config.log_dir)
if run_name is None:
run_name = config.run_name
if experiment_name is None:
experiment_name = config.experiment_name
    name = run_name
    if name is None:
        name = experiment_name
    if run_name is not None and experiment_name is not None:
        name = str(Path(experiment_name) / run_name)
return CSVLogger(
save_dir=str(log_dir),
name=name,
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
)
def create_tensorboard_logger(
config: TensorBoardLoggerConfig,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Logger:
from lightning.pytorch.loggers import TensorBoardLogger
if log_dir is None:
log_dir = Path(config.log_dir)
if run_name is None:
run_name = config.run_name
if experiment_name is None:
experiment_name = config.experiment_name
name = run_name
if name is None:
name = experiment_name
if run_name is not None and experiment_name is not None:
name = str(Path(experiment_name) / run_name)
return TensorBoardLogger(
save_dir=str(log_dir),
name=name,
log_graph=config.log_graph,
)
def create_mlflow_logger(
config: MLFlowLoggerConfig,
log_dir: Optional[data.PathLike] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Logger:
try:
from lightning.pytorch.loggers import MLFlowLogger
except ImportError as error:
raise ValueError(
"MLFlow is not installed and cannot be used for logging. "
"Make sure you have it installed by running `pip install mlflow` "
"or `uv add mlflow`"
) from error
if experiment_name is None:
experiment_name = config.experiment_name or "Default"
if log_dir is None:
log_dir = config.log_dir
return MLFlowLogger(
experiment_name=experiment_name
if experiment_name is not None
else config.experiment_name,
run_name=run_name if run_name is not None else config.run_name,
save_dir=str(log_dir),
tracking_uri=config.tracking_uri,
tags=config.tags,
log_model=config.log_model,
)
LOGGER_FACTORY: Dict[str, LoggerBuilder] = {
"dvclive": create_dvclive_logger,
"csv": create_csv_logger,
"tensorboard": create_tensorboard_logger,
"mlflow": create_mlflow_logger,
}
def build_logger(
config: LoggerConfig,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Logger:
logger.opt(lazy=True).debug(
"Building logger with config: \n{}",
lambda: config.to_yaml_string(),
)
logger_type = config.name
if logger_type not in LOGGER_FACTORY:
raise ValueError(f"Unknown logger type: {logger_type}")
creation_func = LOGGER_FACTORY[logger_type]
return creation_func(
config,
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
)
PlotLogger = Callable[[str, Figure, int], None]
def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
if isinstance(logger, TensorBoardLogger):
return logger.experiment.add_figure
if isinstance(logger, MLFlowLogger):
def plot_figure(name, figure, step):
image = _convert_figure_to_array(figure)
name = name.replace("/", "_")
return logger.experiment.log_image(
logger.run_id,
image,
key=name,
step=step,
)
return plot_figure
if isinstance(logger, CSVLogger):
return partial(save_figure, dir=Path(logger.log_dir))
TableLogger = Callable[[str, pd.DataFrame, int], None]
def get_table_logger(logger: Logger) -> Optional[TableLogger]:
if isinstance(logger, TensorBoardLogger):
return partial(save_table, dir=Path(logger.log_dir))
if isinstance(logger, MLFlowLogger):
        def log_table(name: str, df: pd.DataFrame, step: int):
return logger.experiment.log_table(
logger.run_id,
data=df,
artifact_file=f"{name}_step_{step}.json",
)
        return log_table
if isinstance(logger, CSVLogger):
return partial(save_table, dir=Path(logger.log_dir))
def save_table(name: str, df: pd.DataFrame, step: int, dir: Path) -> None:
path = dir / "tables" / f"{name}_step_{step}.csv"
if not path.parent.exists():
path.parent.mkdir(parents=True)
df.to_csv(path, index=False)
def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
path = dir / "plots" / f"{name}_step_{step}.png"
if not path.parent.exists():
path.parent.mkdir(parents=True)
fig.savefig(path, transparent=True, bbox_inches="tight")
def _convert_figure_to_array(figure: Figure) -> np.ndarray:
with io.BytesIO() as buff:
figure.savefig(buff, format="raw")
buff.seek(0)
data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
w, h = figure.canvas.get_width_height()
im = data.reshape((int(h), int(w), -1))
return im
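
# Illustrative usage sketch with placeholder names: build a CSV logger and
# route a matplotlib figure through the generic plot-logging helper.
csv_logger = build_logger(CSVLoggerConfig(experiment_name="demo", run_name="run-1"))
plot_logger = get_image_logger(csv_logger)
if plot_logger is not None:
    fig = Figure()
    plot_logger("spectrogram", fig, 0)  # written to <log_dir>/plots/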

View File

@ -26,15 +26,13 @@ for creating a standard BatDetect2 model instance is the `build_model` function
provided here.
"""
from typing import Optional
from typing import List, Optional
from loguru import logger
import torch
from batdetect2.models.backbones import (
Backbone,
BackboneConfig,
build_backbone,
load_backbone_config,
)
from batdetect2.models.blocks import (
ConvConfig,
@ -48,29 +46,37 @@ from batdetect2.models.bottleneck import (
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.config import (
BackboneConfig,
load_backbone_config,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
build_decoder,
)
from batdetect2.models.detectors import (
Detector,
build_detector,
)
from batdetect2.models.detectors import Detector, build_detector
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import (
ClipDetectionsTensor,
PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"BBoxHead",
"Backbone",
"BackboneConfig",
"BackboneModel",
"BackboneModel",
"Bottleneck",
"BottleneckConfig",
"ClassifierHead",
@ -78,65 +84,68 @@ __all__ = [
"DEFAULT_DECODER_CONFIG",
"DEFAULT_ENCODER_CONFIG",
"DecoderConfig",
"DetectionModel",
"Detector",
"DetectorHead",
"EncoderConfig",
"FreqCoordConvDownConfig",
"FreqCoordConvUpConfig",
"ModelOutput",
"StandardConvDownConfig",
"StandardConvUpConfig",
"build_backbone",
"build_bottleneck",
"build_decoder",
"build_detector",
"build_encoder",
"build_model",
"build_detector",
"load_backbone_config",
"Model",
"build_model",
]
class Model(torch.nn.Module):
detector: DetectionModel
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
targets: TargetProtocol
def __init__(
self,
detector: DetectionModel,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
):
super().__init__()
self.detector = detector
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.targets = targets
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
spec = self.preprocessor(wav)
outputs = self.detector(spec)
return self.postprocessor(outputs)
def build_model(
num_classes: int,
config: Optional[BackboneConfig] = None,
) -> DetectionModel:
"""Build the complete BatDetect2 detection model.
This high-level factory function constructs the standard BatDetect2 model
architecture. It first builds the feature extraction backbone (typically an
encoder-bottleneck-decoder structure) based on the provided
`BackboneConfig` (or defaults if None), and then attaches the standard
prediction heads (`DetectorHead`, `ClassifierHead`, `BBoxHead`) using the
`build_detector` function.
Parameters
----------
num_classes : int
The number of specific target classes the model should predict
(required for the `ClassifierHead`). Must be positive.
config : BackboneConfig, optional
Configuration object specifying the architecture of the backbone
(encoder, bottleneck, decoder). If None, default configurations defined
within the respective builder functions (`build_encoder`, etc.) will be
used to construct a default backbone architecture.
Returns
-------
DetectionModel
An initialized `Detector` model instance.
Raises
------
ValueError
If `num_classes` is not positive, or if errors occur during the
construction of the backbone or detector components (e.g., incompatible
configurations, invalid parameters).
"""
targets: Optional[TargetProtocol] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
postprocessor: Optional[PostprocessorProtocol] = None,
):
config = config or BackboneConfig()
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(),
targets = targets or build_targets()
preprocessor = preprocessor or build_preprocessor()
postprocessor = postprocessor or build_postprocessor(
preprocessor=preprocessor,
)
detector = build_detector(
num_classes=len(targets.class_names),
config=config,
)
return Model(
detector=detector,
postprocessor=postprocessor,
preprocessor=preprocessor,
targets=targets,
)
backbone = build_backbone(config)
return build_detector(num_classes, backbone)

View File

@ -18,37 +18,20 @@ automatic padding to handle input sizes not perfectly divisible by the
network's total downsampling factor.
"""
from typing import Optional, Tuple
from typing import Tuple
import torch
import torch.nn.functional as F
from soundevent import data
from torch import nn
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
Decoder,
DecoderConfig,
build_decoder,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
Encoder,
EncoderConfig,
build_encoder,
)
from batdetect2.models.types import BackboneModel
from batdetect2.models.bottleneck import build_bottleneck
from batdetect2.models.config import BackboneConfig
from batdetect2.models.decoder import Decoder, build_decoder
from batdetect2.models.encoder import Encoder, build_encoder
from batdetect2.typing.models import BackboneModel
__all__ = [
"Backbone",
"BackboneConfig",
"load_backbone_config",
"build_backbone",
]
@ -161,82 +144,6 @@ class Backbone(BackboneModel):
return x
class BackboneConfig(BaseConfig):
"""Configuration for the Encoder-Decoder Backbone network.
Aggregates configurations for the encoder, bottleneck, and decoder
components, along with defining the input and final output dimensions
for the complete backbone.
Attributes
----------
input_height : int, default=128
Expected height (frequency bins) of the input spectrograms to the
backbone. Must be positive.
in_channels : int, default=1
Expected number of channels in the input spectrograms (e.g., 1 for
mono). Must be positive.
encoder : EncoderConfig, optional
Configuration for the encoder. If None or omitted,
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
encoder module) will be used.
bottleneck : BottleneckConfig, optional
Configuration for the bottleneck layer connecting encoder and decoder.
If None or omitted, the default bottleneck configuration will be used.
decoder : DecoderConfig, optional
Configuration for the decoder. If None or omitted,
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
decoder module) will be used.
out_channels : int, default=32
Desired number of channels in the final feature map output by the
backbone. Must be positive.
"""
input_height: int = 128
in_channels: int = 1
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32
def load_backbone_config(
path: data.PathLike,
field: Optional[str] = None,
) -> BackboneConfig:
"""Load the backbone configuration from a file.
Reads a configuration file (YAML) and validates it against the
`BackboneConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
backbone configuration (e.g., "model.backbone"). If None, the entire
file content is used.
Returns
-------
BackboneConfig
The loaded and validated backbone configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to `BackboneConfig`.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=BackboneConfig, field=field)
def build_backbone(config: BackboneConfig) -> BackboneModel:
"""Factory function to build a Backbone from configuration.

View File

@ -34,7 +34,7 @@ import torch.nn.functional as F
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
__all__ = [
"ConvBlock",
@ -55,6 +55,12 @@ __all__ = [
]
class SelfAttentionConfig(BaseConfig):
name: Literal["SelfAttention"] = "SelfAttention"
attention_channels: int
temperature: float = 1
class SelfAttention(nn.Module):
"""Self-Attention mechanism operating along the time dimension.
@ -115,6 +121,7 @@ class SelfAttention(nn.Module):
# Note, does not encode position information (absolute or relative)
self.temperature = temperature
self.att_dim = attention_channels
self.key_fun = nn.Linear(in_channels, attention_channels)
self.value_fun = nn.Linear(in_channels, attention_channels)
self.query_fun = nn.Linear(in_channels, attention_channels)
@ -171,7 +178,7 @@ class SelfAttention(nn.Module):
class ConvConfig(BaseConfig):
"""Configuration for a basic ConvBlock."""
block_type: Literal["ConvBlock"] = "ConvBlock"
name: Literal["ConvBlock"] = "ConvBlock"
"""Discriminator field indicating the block type."""
out_channels: int
@ -218,7 +225,7 @@ class ConvBlock(nn.Module):
kernel_size=kernel_size,
padding=pad_size,
)
self.conv_bn = nn.BatchNorm2d(out_channels)
self.batch_norm = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Conv -> BN -> ReLU.
@ -233,7 +240,7 @@ class ConvBlock(nn.Module):
torch.Tensor
Output tensor, shape `(B, C_out, H, W)`.
"""
return F.relu_(self.conv_bn(self.conv(x)))
return F.relu_(self.batch_norm(self.conv(x)))
class VerticalConv(nn.Module):
@ -293,7 +300,7 @@ class VerticalConv(nn.Module):
class FreqCoordConvDownConfig(BaseConfig):
"""Configuration for a FreqCoordConvDownBlock."""
block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
name: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
"""Discriminator field indicating the block type."""
out_channels: int
@ -357,7 +364,7 @@ class FreqCoordConvDownBlock(nn.Module):
padding=pad_size,
stride=1,
)
self.conv_bn = nn.BatchNorm2d(out_channels)
self.batch_norm = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply CoordF -> Conv -> MaxPool -> BN -> ReLU.
@ -376,14 +383,14 @@ class FreqCoordConvDownBlock(nn.Module):
freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
x = torch.cat((x, freq_info), 1)
x = F.max_pool2d(self.conv(x), 2, 2)
x = F.relu(self.conv_bn(x), inplace=True)
x = F.relu(self.batch_norm(x), inplace=True)
return x
class StandardConvDownConfig(BaseConfig):
"""Configuration for a StandardConvDownBlock."""
block_type: Literal["StandardConvDown"] = "StandardConvDown"
name: Literal["StandardConvDown"] = "StandardConvDown"
"""Discriminator field indicating the block type."""
out_channels: int
@ -431,7 +438,7 @@ class StandardConvDownBlock(nn.Module):
padding=pad_size,
stride=1,
)
self.conv_bn = nn.BatchNorm2d(out_channels)
self.batch_norm = nn.BatchNorm2d(out_channels)
def forward(self, x):
"""Apply Conv -> MaxPool -> BN -> ReLU.
@ -447,13 +454,13 @@ class StandardConvDownBlock(nn.Module):
Output tensor, shape `(B, C_out, H/2, W/2)`.
"""
x = F.max_pool2d(self.conv(x), 2, 2)
return F.relu(self.conv_bn(x), inplace=True)
return F.relu(self.batch_norm(x), inplace=True)
class FreqCoordConvUpConfig(BaseConfig):
"""Configuration for a FreqCoordConvUpBlock."""
block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
name: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
"""Discriminator field indicating the block type."""
out_channels: int
@ -527,7 +534,7 @@ class FreqCoordConvUpBlock(nn.Module):
kernel_size=kernel_size,
padding=pad_size,
)
self.conv_bn = nn.BatchNorm2d(out_channels)
self.batch_norm = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Interpolate -> Concat Coords -> Conv -> BN -> ReLU.
@ -555,14 +562,14 @@ class FreqCoordConvUpBlock(nn.Module):
freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
op = torch.cat((op, freq_info), 1)
op = self.conv(op)
op = F.relu(self.conv_bn(op), inplace=True)
op = F.relu(self.batch_norm(op), inplace=True)
return op
class StandardConvUpConfig(BaseConfig):
"""Configuration for a StandardConvUpBlock."""
block_type: Literal["StandardConvUp"] = "StandardConvUp"
name: Literal["StandardConvUp"] = "StandardConvUp"
"""Discriminator field indicating the block type."""
out_channels: int
@ -618,7 +625,7 @@ class StandardConvUpBlock(nn.Module):
kernel_size=kernel_size,
padding=pad_size,
)
self.conv_bn = nn.BatchNorm2d(out_channels)
self.batch_norm = nn.BatchNorm2d(out_channels)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply Interpolate -> Conv -> BN -> ReLU.
@ -643,7 +650,7 @@ class StandardConvUpBlock(nn.Module):
align_corners=False,
)
op = self.conv(op)
op = F.relu(self.conv_bn(op), inplace=True)
op = F.relu(self.batch_norm(op), inplace=True)
return op
@ -654,15 +661,16 @@ LayerConfig = Annotated[
StandardConvDownConfig,
FreqCoordConvUpConfig,
StandardConvUpConfig,
SelfAttentionConfig,
"LayerGroupConfig",
],
Field(discriminator="block_type"),
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configuration models."""
class LayerGroupConfig(BaseConfig):
block_type: Literal["LayerGroup"] = "LayerGroup"
name: Literal["LayerGroup"] = "LayerGroup"
layers: List[LayerConfig]
@ -678,7 +686,7 @@ def build_layer_from_config(
parameters derived from the config and the current pipeline state
(`input_height`, `in_channels`).
It uses the `block_type` field within the `config` object to determine
It uses the `name` field within the `config` object to determine
which block class to instantiate.
Parameters
@ -690,7 +698,7 @@ def build_layer_from_config(
config : LayerConfig
A Pydantic configuration object for the desired block (e.g., an
instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
by its `block_type` field.
by its `name` field.
Returns
-------
@ -703,11 +711,11 @@ def build_layer_from_config(
Raises
------
NotImplementedError
If the `config.block_type` does not correspond to a known block type.
If the `config.name` does not correspond to a known block type.
ValueError
If parameters derived from the config are invalid for the block.
"""
if config.block_type == "ConvBlock":
if config.name == "ConvBlock":
return (
ConvBlock(
in_channels=in_channels,
@ -719,7 +727,7 @@ def build_layer_from_config(
input_height,
)
if config.block_type == "FreqCoordConvDown":
if config.name == "FreqCoordConvDown":
return (
FreqCoordConvDownBlock(
in_channels=in_channels,
@ -732,7 +740,7 @@ def build_layer_from_config(
input_height // 2,
)
if config.block_type == "StandardConvDown":
if config.name == "StandardConvDown":
return (
StandardConvDownBlock(
in_channels=in_channels,
@ -744,7 +752,7 @@ def build_layer_from_config(
input_height // 2,
)
if config.block_type == "FreqCoordConvUp":
if config.name == "FreqCoordConvUp":
return (
FreqCoordConvUpBlock(
in_channels=in_channels,
@ -757,7 +765,7 @@ def build_layer_from_config(
input_height * 2,
)
if config.block_type == "StandardConvUp":
if config.name == "StandardConvUp":
return (
StandardConvUpBlock(
in_channels=in_channels,
@ -769,7 +777,18 @@ def build_layer_from_config(
input_height * 2,
)
if config.block_type == "LayerGroup":
if config.name == "SelfAttention":
return (
SelfAttention(
in_channels=in_channels,
attention_channels=config.attention_channels,
temperature=config.temperature,
),
config.attention_channels,
input_height,
)
if config.name == "LayerGroup":
current_channels = in_channels
current_height = input_height
@ -785,4 +804,4 @@ def build_layer_from_config(
return nn.Sequential(*blocks), current_channels, current_height
raise NotImplementedError(f"Unknown block type {config.block_type}")
raise NotImplementedError(f"Unknown block type {config.name}")

View File

@ -14,47 +14,26 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
module based on the provided configuration.
"""
from typing import Optional
from typing import Annotated, List, Optional, Union
import torch
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import SelfAttention, VerticalConv
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
SelfAttentionConfig,
VerticalConv,
build_layer_from_config,
)
__all__ = [
"BottleneckConfig",
"Bottleneck",
"BottleneckAttn",
"build_bottleneck",
]
class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
Attributes
----------
channels : int
The number of output channels produced by the main convolutional layer
within the bottleneck. This often matches the number of channels coming
from the last encoder stage, but can be different. Must be positive.
This also defines the channel dimensions used within the optional
`SelfAttention` layer.
self_attention : bool
If True, includes a `SelfAttention` layer operating on the time
dimension after an initial `VerticalConv` layer within the bottleneck.
If False, only the initial `VerticalConv` (and height repetition) is
performed.
"""
channels: int
self_attention: bool
class Bottleneck(nn.Module):
"""Base Bottleneck module for Encoder-Decoder architectures.
@ -99,16 +78,24 @@ class Bottleneck(nn.Module):
input_height: int,
in_channels: int,
out_channels: int,
bottleneck_channels: Optional[int] = None,
layers: Optional[List[torch.nn.Module]] = None,
) -> None:
"""Initialize the base Bottleneck layer."""
super().__init__()
self.in_channels = in_channels
self.input_height = input_height
self.out_channels = out_channels
self.bottleneck_channels = (
bottleneck_channels
if bottleneck_channels is not None
else out_channels
)
self.layers = nn.ModuleList(layers or [])
self.conv_vert = VerticalConv(
in_channels=in_channels,
out_channels=out_channels,
out_channels=self.bottleneck_channels,
input_height=input_height,
)
@ -132,73 +119,52 @@ class Bottleneck(nn.Module):
convolution.
"""
x = self.conv_vert(x)
for layer in self.layers:
x = layer(x)
return x.repeat([1, 1, self.input_height, 1])
class BottleneckAttn(Bottleneck):
"""Bottleneck module including a Self-Attention layer.
BottleneckLayerConfig = Annotated[
Union[SelfAttentionConfig,],
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
Extends the base `Bottleneck` by inserting a `SelfAttention` layer after
the initial `VerticalConv`. This allows the bottleneck to capture global
temporal dependencies in the summarized frequency features before passing
them to the decoder.
Sequence: VerticalConv -> SelfAttention -> Repeat Height.
class BottleneckConfig(BaseConfig):
"""Configuration for the bottleneck layer(s).
Parameters
Defines the number of channels within the bottleneck and whether to include
a self-attention mechanism.
Attributes
----------
input_height : int
Height (frequency bins) of the input tensor from the encoder.
in_channels : int
Number of channels in the input tensor from the encoder.
out_channels : int
Number of output channels produced by the `VerticalConv` and
subsequently processed and output by this bottleneck. Also determines
the input/output channels of the internal `SelfAttention` layer.
attention : nn.Module
An initialized `SelfAttention` module instance.
Raises
------
ValueError
If `input_height`, `in_channels`, or `out_channels` are not positive.
channels : int
The number of output channels produced by the main convolutional layer
within the bottleneck. This often matches the number of channels coming
from the last encoder stage, but can be different. Must be positive.
This also defines the channel dimensions used within the optional
`SelfAttention` layer.
self_attention : bool
If True, includes a `SelfAttention` layer operating on the time
dimension after an initial `VerticalConv` layer within the bottleneck.
If False, only the initial `VerticalConv` (and height repetition) is
performed.
"""
def __init__(
self,
input_height: int,
in_channels: int,
out_channels: int,
attention: nn.Module,
) -> None:
"""Initialize the Bottleneck with Self-Attention."""
super().__init__(input_height, in_channels, out_channels)
self.attention = attention
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Process input tensor.
Parameters
----------
x : torch.Tensor
Input tensor from the encoder bottleneck, shape
`(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
`H_in` must match `self.input_height`.
Returns
-------
torch.Tensor
Output tensor, shape `(B, C_out, H_in, W)`, after applying attention
and repeating the height dimension.
"""
x = self.conv_vert(x)
x = self.attention(x)
return x.repeat([1, 1, self.input_height, 1])
channels: int
layers: List[BottleneckLayerConfig] = Field(
default_factory=list,
)
DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
channels=256,
self_attention=True,
layers=[
SelfAttentionConfig(attention_channels=256),
],
)
@ -234,21 +200,25 @@ def build_bottleneck(
"""
config = config or DEFAULT_BOTTLENECK_CONFIG
if config.self_attention:
attention = SelfAttention(
in_channels=config.channels,
attention_channels=config.channels,
)
current_channels = in_channels
current_height = input_height
return BottleneckAttn(
input_height=input_height,
in_channels=in_channels,
out_channels=config.channels,
attention=attention,
layers = []
for layer_config in config.layers:
layer, current_channels, current_height = build_layer_from_config(
input_height=current_height,
in_channels=current_channels,
config=layer_config,
)
assert current_height == input_height, (
"Bottleneck layers should not change the spectrogram height"
)
layers.append(layer)
return Bottleneck(
input_height=input_height,
in_channels=in_channels,
out_channels=config.channels,
layers=layers,
)
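As a point of reference, here is a minimal usage sketch (not part of this diff) of the layer-based bottleneck API introduced above; the import locations and the `in_channels`/`input_height` values are assumptions for illustration.

from batdetect2.models.blocks import SelfAttentionConfig  # assumed location
from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck

# Mirror DEFAULT_BOTTLENECK_CONFIG: a single self-attention layer in the bottleneck.
config = BottleneckConfig(
    channels=256,
    layers=[SelfAttentionConfig(attention_channels=256)],
)

# in_channels/input_height describe the feature map leaving the encoder.
bottleneck = build_bottleneck(
    in_channels=256,
    input_height=8,
    config=config,
)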

View File

@ -0,0 +1,98 @@
from typing import Optional
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
)
__all__ = [
"BackboneConfig",
"load_backbone_config",
]
class BackboneConfig(BaseConfig):
"""Configuration for the Encoder-Decoder Backbone network.
Aggregates configurations for the encoder, bottleneck, and decoder
components, along with defining the input and final output dimensions
for the complete backbone.
Attributes
----------
input_height : int, default=128
Expected height (frequency bins) of the input spectrograms to the
backbone. Must be positive.
in_channels : int, default=1
Expected number of channels in the input spectrograms (e.g., 1 for
mono). Must be positive.
encoder : EncoderConfig, optional
Configuration for the encoder. If None or omitted,
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
encoder module) will be used.
bottleneck : BottleneckConfig, optional
Configuration for the bottleneck layer connecting encoder and decoder.
If None or omitted, the default bottleneck configuration will be used.
decoder : DecoderConfig, optional
Configuration for the decoder. If None or omitted,
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
decoder module) will be used.
out_channels : int, default=32
Desired number of channels in the final feature map output by the
backbone. Must be positive.
"""
input_height: int = 128
in_channels: int = 1
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32
def load_backbone_config(
path: data.PathLike,
field: Optional[str] = None,
) -> BackboneConfig:
"""Load the backbone configuration from a file.
Reads a configuration file (YAML) and validates it against the
`BackboneConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
backbone configuration (e.g., "model.backbone"). If None, the entire
file content is used.
Returns
-------
BackboneConfig
The loaded and validated backbone configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to `BackboneConfig`.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=BackboneConfig, field=field)
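A short, hedged example of how this loader might be used with a nested YAML section; the file layout below is hypothetical and only mirrors the attributes documented above, and the import path is assumed.

# Hypothetical "config.yaml":
#
# model:
#   backbone:
#     input_height: 128
#     in_channels: 1
#     out_channels: 32
from batdetect2.models.backbones import load_backbone_config  # assumed path

backbone_config = load_backbone_config("config.yaml", field="model.backbone")
print(backbone_config.out_channels)  # 32; unspecified fields fall back to defaults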

View File

@ -24,7 +24,7 @@ import torch
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
ConvConfig,
FreqCoordConvUpConfig,
@ -47,7 +47,7 @@ DecoderLayerConfig = Annotated[
StandardConvUpConfig,
LayerGroupConfig,
],
Field(discriminator="block_type"),
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""
@ -63,7 +63,7 @@ class DecoderConfig(BaseConfig):
layers : List[DecoderLayerConfig]
An ordered list of configuration objects, each defining one layer or
block in the decoder sequence. Each item must be a valid block
config including a `block_type` field and necessary parameters like
config including a `name` field and necessary parameters like
`out_channels`. Input channels for each layer are inferred sequentially.
The list must contain at least one layer.
"""
@ -249,9 +249,9 @@ def build_decoder(
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `block_type`).
configuration is invalid (e.g., empty list, unknown `name`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `block_type`.
If `build_layer_from_config` encounters an unknown `name`.
"""
config = config or DEFAULT_DECODER_CONFIG

View File

@ -8,18 +8,25 @@ classifying them.
The primary components are:
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
- `build_detector`: A factory function to conveniently construct a standard
`Detector` instance given a backbone and the number of target classes.
This module focuses purely on the neural network architecture definition. The
logic for preprocessing inputs and postprocessing/decoding outputs resides in
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
"""
from typing import Optional
import torch
from loguru import logger
from batdetect2.models.backbones import BackboneConfig, build_backbone
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
__all__ = [
"Detector",
"build_detector",
]
class Detector(DetectionModel):
@ -119,36 +126,41 @@ class Detector(DetectionModel):
)
def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
"""Factory function to build a standard Detector model instance.
Creates the standard prediction heads (`ClassifierHead`, `DetectorHead`,
`BBoxHead`) configured appropriately based on the output channels of the
provided `backbone` and the specified `num_classes`. It then assembles
these components into a `Detector` model.
def build_detector(
num_classes: int, config: Optional[BackboneConfig] = None
) -> DetectionModel:
"""Build the complete BatDetect2 detection model.
Parameters
----------
num_classes : int
The number of specific target classes for the classification head
(excluding any implicit background class). Must be positive.
backbone : BackboneModel
An initialized feature extraction backbone module instance. The number
of output channels from this backbone (`backbone.out_channels`) is used
to configure the input channels for the prediction heads.
The number of specific target classes the model should predict
(required for the `ClassifierHead`). Must be positive.
config : BackboneConfig, optional
Configuration object specifying the architecture of the backbone
(encoder, bottleneck, decoder). If None, default configurations defined
within the respective builder functions (`build_encoder`, etc.) will be
used to construct a default backbone architecture.
Returns
-------
Detector
DetectionModel
An initialized `Detector` model instance.
Raises
------
ValueError
If `num_classes` is not positive.
AttributeError
If `backbone` does not have the required `out_channels` attribute.
If `num_classes` is not positive, or if errors occur during the
construction of the backbone or detector components (e.g., incompatible
configurations, invalid parameters).
"""
config = config or BackboneConfig()
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(),
)
backbone = build_backbone(config=config)
classifier_head = ClassifierHead(
num_classes=num_classes,
in_channels=backbone.out_channels,
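For illustration, a hedged sketch of calling the reworked factory documented above; the module path and the class count are assumptions, not part of this change.

from batdetect2.models.backbones import BackboneConfig
from batdetect2.models.detector import build_detector  # assumed module path

# Default backbone architecture (encoder, bottleneck, decoder defaults).
model = build_detector(num_classes=18)

# Or with an explicit backbone configuration.
model = build_detector(num_classes=18, config=BackboneConfig(out_channels=64))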

View File

@ -26,7 +26,7 @@ import torch
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
ConvConfig,
FreqCoordConvDownConfig,
@ -49,7 +49,7 @@ EncoderLayerConfig = Annotated[
StandardConvDownConfig,
LayerGroupConfig,
],
Field(discriminator="block_type"),
Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Encoder."""
@ -66,7 +66,7 @@ class EncoderConfig(BaseConfig):
An ordered list of configuration objects, each defining one layer or
block in the encoder sequence. Each item must be a valid block config
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
`StandardConvDownConfig`) including a `block_type` field and necessary
`StandardConvDownConfig`) including a `name` field and necessary
parameters like `out_channels`. Input channels for each layer are
inferred sequentially. The list must contain at least one layer.
"""
@ -287,9 +287,9 @@ def build_encoder(
------
ValueError
If `in_channels` or `input_height` are not positive, or if the layer
configuration is invalid (e.g., empty list, unknown `block_type`).
configuration is invalid (e.g., empty list, unknown `name`).
NotImplementedError
If `build_layer_from_config` encounters an unknown `block_type`.
If `build_layer_from_config` encounters an unknown `name`.
"""
if in_channels <= 0 or input_height <= 0:
raise ValueError("in_channels and input_height must be positive.")

View File

@ -1,11 +1,16 @@
from batdetect2.plotting.clip_annotations import plot_clip_annotation
from batdetect2.plotting.clip_predictions import plot_clip_prediction
from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.common import plot_spectrogram
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.heatmaps import (
plot_classification_heatmap,
plot_detection_heatmap,
)
from batdetect2.plotting.matches import (
plot_cross_trigger_match,
plot_false_negative_match,
plot_false_positive_match,
plot_matches,
plot_true_positive_match,
)
@ -13,9 +18,12 @@ __all__ = [
"plot_clip",
"plot_clip_annotation",
"plot_clip_prediction",
"plot_matches",
"plot_false_positive_match",
"plot_true_positive_match",
"plot_false_negative_match",
"plot_cross_trigger_match",
"plot_false_negative_match",
"plot_false_positive_match",
"plot_spectrogram",
"plot_true_positive_match",
"plot_detection_heatmap",
"plot_classification_heatmap",
"plot_match_gallery",
]

View File

@ -4,7 +4,9 @@ from matplotlib.axes import Axes
from soundevent import data, plot
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.plotting.common import create_ax
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"plot_clip_annotation",
@ -17,8 +19,6 @@ def plot_clip_annotation(
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
cmap: str = "gray",
alpha: float = 1,
@ -31,8 +31,6 @@ def plot_clip_annotation(
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=cmap,
)
@ -47,3 +45,29 @@ def plot_clip_annotation(
facecolor="none" if not fill else None,
)
return ax
def plot_anchor_points(
clip_annotation: data.ClipAnnotation,
targets: TargetProtocol,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
size: int = 1,
color: str = "red",
marker: str = "x",
alpha: float = 1,
) -> Axes:
"""Scatter the anchor point of each target sound event over the given axes."""
ax = create_ax(ax=ax, figsize=figsize)
positions = []
for sound_event in clip_annotation.sound_events:
# Skip sound events excluded by the target definition.
if not targets.filter(sound_event):
continue
# encode_roi returns the anchor position first; it is plotted as (x, y).
position, _ = targets.encode_roi(sound_event)
positions.append(position)
if not positions:
# No sound event passed the filter, so there is nothing to scatter.
return ax
X, Y = zip(*positions)
ax.scatter(X, Y, s=size, c=color, marker=marker, alpha=alpha)
return ax

View File

@ -8,7 +8,7 @@ from soundevent.plot.geometries import plot_geometry
from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol
__all__ = [
"plot_clip_prediction",
@ -21,8 +21,6 @@ def plot_clip_prediction(
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
add_colorbar: bool = False,
add_labels: bool = False,
add_legend: bool = False,
spec_cmap: str = "gray",
linewidth: float = 1,
@ -34,8 +32,6 @@ def plot_clip_prediction(
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)

View File

@ -1,13 +1,14 @@
from typing import Optional, Tuple
import matplotlib.pyplot as plt
import torch
from matplotlib.axes import Axes
from soundevent import data
from batdetect2.preprocess import (
PreprocessorProtocol,
get_default_preprocessor,
)
from batdetect2.audio import build_audio_loader
from batdetect2.plotting.common import plot_spectrogram
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol
__all__ = [
"plot_clip",
@ -16,29 +17,33 @@ __all__ = [
def plot_clip(
clip: data.Clip,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
add_colorbar: bool = False,
add_labels: bool = False,
spec_cmap: str = "gray",
) -> Axes:
if ax is None:
_, ax = plt.subplots(figsize=figsize)
if preprocessor is None:
preprocessor = get_default_preprocessor()
preprocessor = build_preprocessor()
spec = preprocessor.preprocess_clip(clip, audio_dir=audio_dir)
if audio_loader is None:
audio_loader = build_audio_loader()
spec.plot( # type: ignore
wav = torch.tensor(audio_loader.load_clip(clip, audio_dir=audio_dir))
spec = preprocessor(wav)
plot_spectrogram(
spec,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
ax=ax,
add_colorbar=add_colorbar,
cmap=spec_cmap,
add_labels=add_labels,
vmin=spec.min().item(),
vmax=spec.max().item(),
)
return ax

View File

@ -1,8 +1,10 @@
"""General plotting utilities."""
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import axes
__all__ = [
@ -12,11 +14,62 @@ __all__ = [
def create_ax(
ax: Optional[axes.Axes] = None,
figsize: Tuple[int, int] = (10, 10),
figsize: Optional[Tuple[int, int]] = None,
**kwargs,
) -> axes.Axes:
"""Create a new axis if none is provided"""
if ax is None:
_, ax = plt.subplots(figsize=figsize, **kwargs) # type: ignore
_, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs) # type: ignore
return ax # type: ignore
def plot_spectrogram(
spec: Union[torch.Tensor, np.ndarray],
start_time: Optional[float] = None,
end_time: Optional[float] = None,
min_freq: Optional[float] = None,
max_freq: Optional[float] = None,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_colorbar: bool = False,
colorbar_kwargs: Optional[dict] = None,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
cmap="gray",
) -> axes.Axes:
if isinstance(spec, torch.Tensor):
spec = spec.numpy()
spec = spec.squeeze()
ax = create_ax(ax=ax, figsize=figsize)
if start_time is None:
start_time = 0
if end_time is None:
end_time = spec.shape[-1]
if min_freq is None:
min_freq = 0
if max_freq is None:
max_freq = spec.shape[-2]
mappable = ax.pcolormesh(
np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True),
np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True),
spec,
cmap=cmap,
vmin=vmin,
vmax=vmax,
)
ax.set_xlim(start_time, end_time)
ax.set_ylim(min_freq, max_freq)
if add_colorbar:
plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {}))
return ax
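A self-contained sketch of the new plot_spectrogram helper using synthetic data; the time and frequency extents are arbitrary example values.

import numpy as np
from batdetect2.plotting.common import plot_spectrogram

spec = np.random.rand(128, 512)  # 128 frequency bins x 512 time frames
ax = plot_spectrogram(
    spec,
    start_time=0.0,
    end_time=0.5,
    min_freq=10_000,
    max_freq=120_000,
    add_colorbar=True,
)
ax.figure.savefig("spectrogram.png")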

View File

@ -0,0 +1,113 @@
from typing import Optional
from matplotlib import axes, patches
from soundevent.plot import plot_geometry
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.plotting.clips import (
AudioLoader,
PreprocessorProtocol,
plot_clip,
)
from batdetect2.plotting.common import create_ax
__all__ = [
"plot_clip_detections",
]
def plot_clip_detections(
clip_eval: ClipEval,
figsize: tuple[int, int] = (10, 10),
ax: Optional[axes.Axes] = None,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
threshold: float = 0.2,
add_legend: bool = True,
add_title: bool = True,
fill: bool = False,
linewidth: float = 1.0,
gt_color: str = "green",
gt_linestyle: str = "-",
true_pred_color: str = "yellow",
true_pred_linestyle: str = "--",
false_pred_color: str = "blue",
false_pred_linestyle: str = "-",
missed_gt_color: str = "red",
missed_gt_linestyle: str = "-",
) -> axes.Axes:
ax = create_ax(figsize=figsize, ax=ax)
plot_clip(
clip_eval.clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
ax=ax,
)
for m in clip_eval.matches:
is_match = (
m.pred is not None and m.gt is not None and m.score >= threshold
)
if m.pred is not None:
color = true_pred_color if is_match else false_pred_color
plot_geometry(
m.pred.geometry,
ax=ax,
add_points=False,
facecolor="none" if not fill else color,
alpha=m.pred.detection_score,
linewidth=linewidth,
linestyle=true_pred_linestyle
if is_match
else missed_gt_linestyle,
color=color,
)
if m.gt is not None:
color = gt_color if is_match else missed_gt_color
plot_geometry(
m.gt.sound_event.geometry, # type: ignore
ax=ax,
add_points=False,
linewidth=linewidth,
facecolor="none" if not fill else color,
linestyle=gt_linestyle if is_match else false_pred_linestyle,
color=color,
)
if add_title:
ax.set_title(clip_eval.clip.recording.path.name)
if add_legend:
ax.legend(
handles=[
patches.Patch(
label="found GT",
edgecolor=gt_color,
facecolor="none" if not fill else gt_color,
linestyle=gt_linestyle,
),
patches.Patch(
label="missed GT",
edgecolor=missed_gt_color,
facecolor="none" if not fill else missed_gt_color,
linestyle=missed_gt_linestyle,
),
patches.Patch(
label="true Det",
edgecolor=true_pred_color,
facecolor="none" if not fill else true_pred_color,
linestyle=true_pred_linestyle,
),
patches.Patch(
label="false Det",
edgecolor=false_pred_color,
facecolor="none" if not fill else false_pred_color,
linestyle=false_pred_linestyle,
),
]
)
return ax

View File

@ -1,160 +0,0 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List
import matplotlib.pyplot as plt
import pandas as pd
from batdetect2 import plotting
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.preprocess.types import PreprocessorProtocol
@dataclass
class ClassExamples:
false_positives: List[MatchEvaluation] = field(default_factory=list)
false_negatives: List[MatchEvaluation] = field(default_factory=list)
true_positives: List[MatchEvaluation] = field(default_factory=list)
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def plot_example_gallery(
matches: List[MatchEvaluation],
preprocessor: PreprocessorProtocol,
n_examples: int = 5,
):
class_examples = defaultdict(ClassExamples)
for match in matches:
gt_class = match.gt_class
pred_class = match.pred_class
if pred_class is None:
class_examples[gt_class].false_negatives.append(match)
continue
if gt_class is None:
class_examples[pred_class].false_positives.append(match)
continue
if gt_class != pred_class:
class_examples[gt_class].cross_triggers.append(match)
class_examples[pred_class].cross_triggers.append(match)
continue
class_examples[gt_class].true_positives.append(match)
for class_name, examples in class_examples.items():
true_positives = get_binned_sample(
examples.true_positives,
n_examples=n_examples,
)
false_positives = get_binned_sample(
examples.false_positives,
n_examples=n_examples,
)
false_negatives = random.sample(
examples.false_negatives,
k=min(n_examples, len(examples.false_negatives)),
)
cross_triggers = get_binned_sample(
examples.cross_triggers,
n_examples=n_examples,
)
fig = plot_class_examples(
true_positives,
false_positives,
false_negatives,
cross_triggers,
preprocessor=preprocessor,
n_examples=n_examples,
)
yield class_name, fig
plt.close(fig)
def plot_class_examples(
true_positives: List[MatchEvaluation],
false_positives: List[MatchEvaluation],
false_negatives: List[MatchEvaluation],
cross_triggers: List[MatchEvaluation],
preprocessor: PreprocessorProtocol,
n_examples: int = 5,
duration: float = 0.1,
):
fig = plt.figure(figsize=(20, 20))
for index, match in enumerate(true_positives[:n_examples]):
ax = plt.subplot(4, n_examples, index + 1)
try:
plotting.plot_true_positive_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError):
continue
for index, match in enumerate(false_positives[:n_examples]):
ax = plt.subplot(4, n_examples, n_examples + index + 1)
try:
plotting.plot_false_positive_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError):
continue
for index, match in enumerate(false_negatives[:n_examples]):
ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
try:
plotting.plot_false_negative_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError):
continue
for index, match in enumerate(cross_triggers[:n_examples]):
ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
try:
plotting.plot_cross_trigger_match(
match,
ax=ax,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError):
continue
return fig
def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
if len(matches) < n_examples:
return matches
indices, pred_scores = zip(
*[
(index, match.pred_class_scores[pred_class])
for index, match in enumerate(matches)
if (pred_class := match.pred_class) is not None
]
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
df = pd.DataFrame({"indices": indices, "bins": bins})
sample = df.groupby("bins").apply(lambda x: x.sample(1))
return [matches[ind] for ind in sample["indices"]]

View File

@ -0,0 +1,109 @@
from typing import Optional, Sequence
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from batdetect2.plotting.matches import (
MatchProtocol,
plot_cross_trigger_match,
plot_false_negative_match,
plot_false_positive_match,
plot_true_positive_match,
)
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
__all__ = ["plot_match_gallery"]
def plot_match_gallery(
true_positives: Sequence[MatchProtocol],
false_positives: Sequence[MatchProtocol],
false_negatives: Sequence[MatchProtocol],
cross_triggers: Sequence[MatchProtocol],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
n_examples: int = 5,
duration: float = 0.1,
fig: Optional[Figure] = None,
):
if fig is None:
fig = plt.figure(figsize=(20, 20))
axes = fig.subplots(
nrows=4,
ncols=n_examples,
sharex="none",
sharey="row",
)
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]):
try:
plot_true_positive_match(
tp_match,
ax=tp_ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (
ValueError,
AssertionError,
RuntimeError,
FileNotFoundError,
):
continue
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]):
try:
plot_false_positive_match(
fp_match,
ax=fp_ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (
ValueError,
AssertionError,
RuntimeError,
FileNotFoundError,
):
continue
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]):
try:
plot_false_negative_match(
fn_match,
ax=fn_ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (
ValueError,
AssertionError,
RuntimeError,
FileNotFoundError,
):
continue
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]):
try:
plot_cross_trigger_match(
ct_match,
ax=ct_ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (
ValueError,
AssertionError,
RuntimeError,
FileNotFoundError,
):
continue
fig.tight_layout()
return fig

View File

@ -1,26 +1,117 @@
"""Plot heatmaps"""
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union
import xarray as xr
from matplotlib import axes
import numpy as np
import torch
from matplotlib import axes, patches
from matplotlib.cm import get_cmap
from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba
from batdetect2.plotting.common import create_ax
def plot_heatmap(
heatmap: xr.DataArray,
def plot_detection_heatmap(
heatmap: Union[torch.Tensor, np.ndarray],
ax: Optional[axes.Axes] = None,
figsize: Tuple[int, int] = (10, 10),
threshold: Optional[float] = None,
alpha: float = 1,
cmap: Union[str, Colormap] = "jet",
color: Optional[str] = None,
) -> axes.Axes:
ax = create_ax(ax, figsize=figsize)
if isinstance(heatmap, torch.Tensor):
heatmap = heatmap.numpy()
heatmap = heatmap.squeeze()
if threshold is not None:
heatmap = np.ma.masked_where(
heatmap < threshold,
heatmap,
)
if color is not None:
cmap = create_colormap(color)
ax.pcolormesh(
heatmap,
vmax=1,
vmin=0,
cmap=cmap,
alpha=alpha,
)
return ax
def plot_classification_heatmap(
heatmap: Union[torch.Tensor, np.ndarray],
ax: Optional[axes.Axes] = None,
figsize: Tuple[int, int] = (10, 10),
class_names: Optional[List[str]] = None,
threshold: Optional[float] = 0.1,
alpha: float = 1,
cmap: Union[str, Colormap] = "tab20",
):
ax = create_ax(ax, figsize=figsize)
if isinstance(heatmap, torch.Tensor):
heatmap = heatmap.numpy()
if heatmap.ndim == 4:
heatmap = heatmap[0]
if heatmap.ndim != 3:
raise ValueError("Expecting a 3-dimensional array")
num_classes = heatmap.shape[0]
if class_names is None:
class_names = [f"class_{i}" for i in range(num_classes)]
if len(class_names) != num_classes:
raise ValueError("Inconsistent number of class names")
if not isinstance(cmap, Colormap):
cmap = get_cmap(cmap)
handles = []
for index, class_heatmap in enumerate(heatmap):
class_name = class_names[index]
color = cmap(index / num_classes)
peak = class_heatmap.max()
if peak == 0:
continue
if threshold is not None:
class_heatmap = np.ma.masked_where(
class_heatmap < threshold,
class_heatmap,
)
ax.pcolormesh(
class_heatmap,
vmax=1,
vmin=0,
cmap=create_colormap(color), # type: ignore
alpha=alpha,
)
handles.append(patches.Patch(color=color, label=class_name))
ax.legend(handles=handles)
return ax
def create_colormap(color: str) -> Colormap:
(r, g, b, a) = to_rgba(color)
return LinearSegmentedColormap.from_list(
"cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
)
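A hedged sketch exercising the two heatmap helpers above with random data; the class names and thresholds are illustrative only.

import numpy as np
from batdetect2.plotting.heatmaps import (
    plot_classification_heatmap,
    plot_detection_heatmap,
)

detection = np.random.rand(128, 512)  # (freq, time) detection scores
classes = np.random.rand(3, 128, 512)  # (class, freq, time) class scores

# Mask everything below 0.5 and render the detections in a single colour.
ax = plot_detection_heatmap(detection, threshold=0.5, color="red")

# One colour per class, with a legend built from the supplied names.
ax = plot_classification_heatmap(
    classes,
    class_names=["class_a", "class_b", "class_c"],
    threshold=0.5,
)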

View File

@ -1,21 +1,17 @@
from typing import List, Optional, Tuple, Union
from typing import Optional, Protocol, Tuple, Union
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from soundevent import data, plot
from soundevent.geometry import compute_bounds
from soundevent.plot.tags import TagColorMapper
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.plotting.clip_predictions import plot_prediction
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import (
from batdetect2.typing import (
AudioLoader,
PreprocessorProtocol,
get_default_preprocessor,
RawPrediction,
)
__all__ = [
"plot_matches",
"plot_false_positive_match",
"plot_true_positive_match",
"plot_false_negative_match",
@ -23,6 +19,14 @@ __all__ = [
]
class MatchProtocol(Protocol):
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
score: float
true_class: Optional[str]
DEFAULT_DURATION = 0.05
DEFAULT_FALSE_POSITIVE_COLOR = "orange"
DEFAULT_FALSE_NEGATIVE_COLOR = "red"
@ -32,278 +36,191 @@ DEFAULT_ANNOTATION_LINE_STYLE = "-"
DEFAULT_PREDICTION_LINE_STYLE = "--"
def plot_matches(
matches: List[data.Match],
clip: data.Clip,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
color_mapper: Optional[TagColorMapper] = None,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
if preprocessor is None:
preprocessor = get_default_preprocessor()
ax = plot_clip(
clip,
ax=ax,
figsize=figsize,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
if color_mapper is None:
color_mapper = TagColorMapper()
for match in matches:
if match.source is None and match.target is not None:
plot.plot_annotation(
annotation=match.target,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=false_negative_color,
color_mapper=color_mapper,
linestyle=annotation_linestyle,
)
elif match.target is None and match.source is not None:
plot_prediction(
prediction=match.source,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=false_positive_color,
color_mapper=color_mapper,
linestyle=prediction_linestyle,
)
elif match.target is not None and match.source is not None:
plot.plot_annotation(
annotation=match.target,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=true_positive_color,
color_mapper=color_mapper,
linestyle=annotation_linestyle,
)
plot_prediction(
prediction=match.source,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=true_positive_color,
color_mapper=color_mapper,
linestyle=prediction_linestyle,
)
else:
continue
return ax
def plot_false_positive_match(
match: MatchEvaluation,
match: MatchProtocol,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_colorbar: bool = False,
add_labels: bool = False,
use_score: bool = True,
add_spectrogram: bool = True,
add_text: bool = True,
add_points: bool = False,
add_title: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
time_offset: float = 0,
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small",
) -> Axes:
assert match.match.source is not None
assert match.match.target is None
sound_event = match.match.source.sound_event
geometry = sound_event.geometry
assert geometry is not None
assert match.pred is not None
start_time, _, _, high_freq = compute_bounds(geometry)
start_time, _, _, high_freq = compute_bounds(match.pred.geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
start_time=max(
start_time - duration / 2,
0,
),
end_time=min(
start_time + duration / 2,
sound_event.recording.duration,
match.clip.recording.duration,
),
recording=sound_event.recording,
recording=match.clip.recording,
)
ax = plot_clip(
clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
if add_spectrogram:
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
plot_prediction(
match.match.source,
ax = plot.plot_geometry(
match.pred.geometry,
ax=ax,
time_offset=time_offset,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
alpha=match.score if use_score else 1,
color=color,
)
plt.text(
start_time,
high_freq,
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_text:
ax.text(
start_time,
high_freq,
f"score={match.score:.2f}",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_title:
ax.set_title("False Positive")
return ax
def plot_false_negative_match(
match: MatchEvaluation,
match: MatchProtocol,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_colorbar: bool = False,
add_labels: bool = False,
add_spectrogram: bool = True,
add_points: bool = False,
add_title: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
fontsize: Union[float, str] = "small",
) -> Axes:
assert match.match.source is None
assert match.match.target is not None
sound_event = match.match.target.sound_event
geometry = sound_event.geometry
assert match.gt is not None
geometry = match.gt.sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
start_time = compute_bounds(geometry)[0]
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2, sound_event.recording.duration
start_time=max(
start_time - duration / 2,
0,
),
recording=sound_event.recording,
end_time=min(
start_time + duration / 2,
match.clip.recording.duration,
),
recording=match.clip.recording,
)
ax = plot_clip(
clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
if add_spectrogram:
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
match.match.target,
ax = plot.plot_geometry(
geometry,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
)
plt.text(
start_time,
high_freq,
f"False Negative \nClass: {match.gt_class} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_title:
ax.set_title("False Negative")
return ax
def plot_true_positive_match(
match: MatchEvaluation,
match: MatchProtocol,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_colorbar: bool = False,
add_labels: bool = False,
use_score: bool = True,
add_spectrogram: bool = True,
add_points: bool = False,
add_text: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small",
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
add_title: bool = True,
) -> Axes:
assert match.match.source is not None
assert match.match.target is not None
sound_event = match.match.target.sound_event
geometry = sound_event.geometry
assert match.gt is not None
assert match.pred is not None
geometry = match.gt.sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2, sound_event.recording.duration
start_time=max(
start_time - duration / 2,
0,
),
recording=sound_event.recording,
end_time=min(
start_time + duration / 2,
match.clip.recording.duration,
),
recording=match.clip.recording,
)
ax = plot_clip(
clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
if add_spectrogram:
ax = plot_clip(
clip,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
match.match.target,
ax = plot.plot_geometry(
geometry,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
@ -311,41 +228,46 @@ def plot_true_positive_match(
linestyle=annotation_linestyle,
)
plot_prediction(
match.match.source,
plot.plot_geometry(
match.pred.geometry,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
alpha=match.score if use_score else 1,
color=color,
linestyle=prediction_linestyle,
)
plt.text(
start_time,
high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_text:
ax.text(
start_time,
high_freq,
f"score={match.score:.2f}",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_title:
ax.set_title("True Positive")
return ax
def plot_cross_trigger_match(
match: MatchEvaluation,
match: MatchProtocol,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_colorbar: bool = False,
add_labels: bool = False,
use_score: bool = True,
add_spectrogram: bool = True,
add_points: bool = False,
add_text: bool = True,
add_title: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
@ -353,38 +275,40 @@ def plot_cross_trigger_match(
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
assert match.match.source is not None
assert match.match.target is not None
sound_event = match.match.source.sound_event
geometry = sound_event.geometry
assert match.gt is not None
assert match.pred is not None
geometry = match.gt.sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2, sound_event.recording.duration
start_time=max(
start_time - duration / 2,
0,
),
recording=sound_event.recording,
end_time=min(
start_time + duration / 2,
match.clip.recording.duration,
),
recording=match.clip.recording,
)
ax = plot_clip(
clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
if add_spectrogram:
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
match.match.target,
ax = plot.plot_geometry(
geometry,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
@ -392,26 +316,29 @@ def plot_cross_trigger_match(
linestyle=annotation_linestyle,
)
plot_prediction(
match.match.source,
ax = plot.plot_geometry(
match.pred.geometry,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
alpha=match.score if use_score else 1,
color=color,
linestyle=prediction_linestyle,
)
plt.text(
start_time,
high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_text:
ax.text(
start_time,
high_freq,
f"score={match.score:.2f}\nclass={match.true_class}",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_title:
ax.set_title("Cross Trigger")
return ax

View File

@ -0,0 +1,286 @@
from typing import Dict, Optional, Tuple
import numpy as np
import seaborn as sns
from cycler import cycler
from matplotlib import axes
from batdetect2.plotting.common import create_ax
def set_default_styler(ax: axes.Axes) -> axes.Axes:
color_cycler = cycler(color=sns.color_palette("muted"))
style_cycler = cycler(linestyle=["-", "--", ":"]) * cycler(
marker=["o", "s", "^"]
)
custom_cycler = color_cycler * len(style_cycler) + style_cycler * len(
color_cycler
)
ax.set_prop_cycle(custom_cycler)
return ax
def set_default_style(ax: axes.Axes) -> axes.Axes:
ax = set_default_styler(ax)
ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)
return ax
def plot_pr_curve(
precision: np.ndarray,
recall: np.ndarray,
thresholds: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
add_legend: bool = False,
label: str = "PR Curve",
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(
recall,
precision,
label=label,
marker="o",
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_legend:
ax.legend()
if add_labels:
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
return ax
def plot_pr_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (precision, recall, thresholds) in data.items():
ax.plot(
recall,
precision,
label=name,
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
return ax
def plot_threshold_precision_curve(
threshold: np.ndarray,
precision: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(threshold, precision, markevery=_get_marker_positions(threshold))
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Precision")
return ax
def plot_threshold_precision_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (precision, _, thresholds) in data.items():
ax.plot(
thresholds,
precision,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Precision")
return ax
def plot_threshold_recall_curve(
threshold: np.ndarray,
recall: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(threshold, recall, markevery=_get_marker_positions(threshold))
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Recall")
return ax
def plot_threshold_recall_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (_, recall, thresholds) in data.items():
ax.plot(
thresholds,
recall,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Recall")
return ax
def plot_roc_curve(
fpr: np.ndarray,
tpr: np.ndarray,
thresholds: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(
fpr,
tpr,
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
return ax
def plot_roc_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (fpr, tpr, thresholds) in data.items():
ax.plot(
fpr,
tpr,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
return ax
def _get_marker_positions(
thresholds: np.ndarray,
n_points: int = 11,
) -> np.ndarray:
size = len(thresholds)
cut_points = np.linspace(0, 1, n_points)
indices = np.searchsorted(thresholds[::-1], cut_points)
return np.clip(size - indices, 0, size - 1) # type: ignore
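A quick sketch of plot_pr_curve with synthetic precision/recall/threshold arrays; the module path is assumed and the values are made up.

import numpy as np
from batdetect2.plotting.curves import plot_pr_curve  # assumed module path

thresholds = np.linspace(1.0, 0.0, 50)  # descending detection scores
recall = np.linspace(0.0, 1.0, 50)
precision = 1.0 - 0.5 * recall

ax = plot_pr_curve(precision, recall, thresholds, add_legend=True)
ax.figure.savefig("pr_curve.png")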

View File

@ -1,598 +1,25 @@
"""Main entry point for the BatDetect2 Postprocessing pipeline.
"""Main entry point for the BatDetect2 Postprocessing pipeline."""
This package (`batdetect2.postprocess`) takes the raw outputs from a trained
BatDetect2 neural network model and transforms them into meaningful, structured
predictions, typically in the form of `soundevent.data.ClipPrediction` objects
containing detected sound events with associated class tags and geometry.
The pipeline involves several configurable steps, implemented in submodules:
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
model output arrays.
3. Detection Extraction (`.detection`): Identifies candidate detection points
(location and score) based on thresholds and score ranking (top-k).
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
class probabilities, features) at the detected locations.
5. Decoding & Formatting (`.decoding`): Converts extracted numerical data and
class predictions into interpretable `soundevent` objects, including
recovering geometry (ROIs) and decoding class names back to standard tags.
This module provides the primary interface:
- `PostprocessConfig`: A configuration object for postprocessing parameters
(thresholds, NMS kernel size, etc.).
- `load_postprocess_config`: Function to load the configuration from a file.
- `Postprocessor`: The main class (implementing `PostprocessorProtocol`) that
holds the configured pipeline logic.
- `build_postprocessor`: A factory function to create a `Postprocessor`
instance, linking it to the necessary target definitions (`TargetProtocol`).
It also re-exports key components from submodules for convenience.
"""
from typing import List, Optional
import xarray as xr
from loguru import logger
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.types import ModelOutput
from batdetect2.postprocess.config import (
PostprocessConfig,
load_postprocess_config,
)
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction,
convert_xr_dataset_to_raw_prediction,
to_raw_predictions,
)
from batdetect2.postprocess.detection import (
DEFAULT_DETECTION_THRESHOLD,
TOP_K_PER_SEC,
extract_detections_from_array,
get_max_detections,
from batdetect2.postprocess.nms import non_max_suppression
from batdetect2.postprocess.postprocessor import (
Postprocessor,
build_postprocessor,
)
from batdetect2.postprocess.extraction import (
extract_detection_xr_dataset,
)
from batdetect2.postprocess.nms import (
NMS_KERNEL_SIZE,
non_max_suppression,
)
from batdetect2.postprocess.remapping import (
classification_to_xarray,
detection_to_xarray,
features_to_xarray,
sizes_to_xarray,
)
from batdetect2.postprocess.types import (
BatDetect2Prediction,
PostprocessorProtocol,
RawPrediction,
)
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets.types import TargetProtocol
__all__ = [
"DEFAULT_CLASSIFICATION_THRESHOLD",
"DEFAULT_DETECTION_THRESHOLD",
"MAX_FREQ",
"MIN_FREQ",
"ModelOutput",
"NMS_KERNEL_SIZE",
"PostprocessConfig",
"Postprocessor",
"PostprocessorProtocol",
"RawPrediction",
"TOP_K_PER_SEC",
"build_postprocessor",
"classification_to_xarray",
"convert_raw_predictions_to_clip_prediction",
"convert_xr_dataset_to_raw_prediction",
"detection_to_xarray",
"extract_detection_xr_dataset",
"extract_detections_from_array",
"features_to_xarray",
"get_max_detections",
"to_raw_predictions",
"load_postprocess_config",
"non_max_suppression",
"sizes_to_xarray",
]
class PostprocessConfig(BaseConfig):
"""Configuration settings for the postprocessing pipeline.
Defines tunable parameters that control how raw model outputs are
converted into final detections.
Attributes
----------
nms_kernel_size : int, default=NMS_KERNEL_SIZE
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
Used to suppress weaker detections near stronger peaks. Must be
positive.
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
Minimum confidence score from the detection heatmap required to
consider a point as a potential detection. Must be >= 0.
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
Minimum confidence score for a specific class prediction to be included
in the decoded tags for a detection. Must be >= 0.
top_k_per_sec : int, default=TOP_K_PER_SEC
Desired maximum number of detections per second of audio. Used by
`get_max_detections` to calculate an absolute limit based on clip
duration before applying `extract_detections_from_array`. Must be
positive.
"""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
detection_threshold: float = Field(
default=DEFAULT_DETECTION_THRESHOLD,
ge=0,
)
classification_threshold: float = Field(
default=DEFAULT_CLASSIFICATION_THRESHOLD,
ge=0,
)
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PostprocessConfig:
"""Load the postprocessing configuration from a file.
Reads a configuration file (YAML) and validates it against the
`PostprocessConfig` schema, potentially extracting data from a nested
field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
postprocessing configuration (e.g., "inference.postprocessing").
If None, the entire file content is used.
Returns
-------
PostprocessConfig
The loaded and validated postprocessing configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data does not conform to the
`PostprocessConfig` schema.
KeyError, TypeError
If `field` specifies an invalid path within the loaded data.
"""
return load_config(path, schema=PostprocessConfig, field=field)
def build_postprocessor(
targets: TargetProtocol,
config: Optional[PostprocessConfig] = None,
max_freq: float = MAX_FREQ,
min_freq: float = MIN_FREQ,
) -> PostprocessorProtocol:
"""Factory function to build the standard postprocessor.
Creates and initializes the `Postprocessor` instance, providing it with the
necessary `targets` object and the `PostprocessConfig`.
Parameters
----------
targets : TargetProtocol
An initialized object conforming to the `TargetProtocol`, providing
methods like `.decode()` and `.recover_roi()`, and attributes like
`.class_names` and `.generic_class_tags`. This links postprocessing
to the defined target semantics and geometry mappings.
config : PostprocessConfig, optional
Configuration object specifying postprocessing parameters (thresholds,
NMS kernel size, etc.). If None, default settings defined in
`PostprocessConfig` will be used.
min_freq : int, default=MIN_FREQ
The minimum frequency (Hz) corresponding to the frequency axis of the
model outputs. Required for coordinate remapping. Consider setting via
`PostprocessConfig` instead for better encapsulation.
max_freq : int, default=MAX_FREQ
The maximum frequency (Hz) corresponding to the frequency axis of the
model outputs. Required for coordinate remapping. Consider setting via
`PostprocessConfig`.
Returns
-------
PostprocessorProtocol
An initialized `Postprocessor` instance ready to process model outputs.
"""
config = config or PostprocessConfig()
logger.opt(lazy=True).debug(
"Building postprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
return Postprocessor(
targets=targets,
config=config,
min_freq=min_freq,
max_freq=max_freq,
)
class Postprocessor(PostprocessorProtocol):
"""Standard implementation of the postprocessing pipeline.
This class orchestrates the steps required to convert raw model outputs
into interpretable `soundevent` predictions. It uses configured parameters
and leverages functions from the `batdetect2.postprocess` submodules for
each stage (NMS, remapping, detection, extraction, decoding).
It requires a `TargetProtocol` object during initialization to access
necessary decoding information (class name to tag mapping,
ROI recovery logic) ensuring consistency with the target definitions used
during training or specified for inference.
Instances are typically created using the `build_postprocessor` factory
function.
Attributes
----------
targets : TargetProtocol
The configured target definition object providing decoding and ROI
recovery.
config : PostprocessConfig
Configuration object holding parameters for NMS, thresholds, etc.
min_freq : float
Minimum frequency (Hz) assumed for the model output's frequency axis.
max_freq : float
Maximum frequency (Hz) assumed for the model output's frequency axis.
"""
targets: TargetProtocol
def __init__(
self,
targets: TargetProtocol,
config: PostprocessConfig,
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
):
"""Initialize the Postprocessor.
Parameters
----------
targets : TargetProtocol
Initialized target definition object.
config : PostprocessConfig
Configuration for postprocessing parameters.
min_freq : int, default=MIN_FREQ
Minimum frequency (Hz) for coordinate remapping.
max_freq : int, default=MAX_FREQ
Maximum frequency (Hz) for coordinate remapping.
"""
self.targets = targets
self.config = config
self.min_freq = min_freq
self.max_freq = max_freq
def get_feature_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Extract and remap raw feature tensors for a batch.
Parameters
----------
output : ModelOutput
Raw model output containing `output.features` tensor for the batch.
clips : List[data.Clip]
List of Clip objects corresponding to the batch items.
Returns
-------
List[xr.DataArray]
List of coordinate-aware feature DataArrays, one per clip.
Raises
------
ValueError
If batch sizes of `output.features` and `clips` do not match.
"""
if len(clips) != len(output.features):
raise ValueError(
"Number of clips and batch size of feature array"
"do not match. "
f"(clips: {len(clips)}, features: {len(output.features)})"
)
return [
features_to_xarray(
feats,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for feats, clip in zip(output.features, clips)
]
def get_detection_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Apply NMS and remap detection heatmaps for a batch.
Parameters
----------
output : ModelOutput
Raw model output containing `output.detection_probs` tensor for the
batch.
clips : List[data.Clip]
List of Clip objects corresponding to the batch items.
Returns
-------
List[xr.DataArray]
List of NMS-applied, coordinate-aware detection heatmaps, one per
clip.
Raises
------
ValueError
If batch sizes of `output.detection_probs` and `clips` do not match.
"""
detections = output.detection_probs
if len(clips) != len(output.detection_probs):
raise ValueError(
"Number of clips and batch size of detection array "
"do not match. "
f"(clips: {len(clips)}, detection: {len(detections)})"
)
detections = non_max_suppression(
detections,
kernel_size=self.config.nms_kernel_size,
)
return [
detection_to_xarray(
dets,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for dets, clip in zip(detections, clips)
]
def get_classification_arrays(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[xr.DataArray]:
"""Extract and remap raw classification tensors for a batch.
Parameters
----------
output : ModelOutput
Raw model output containing `output.class_probs` tensor for the
batch.
clips : List[data.Clip]
List of Clip objects corresponding to the batch items.
Returns
-------
List[xr.DataArray]
List of coordinate-aware class probability maps, one per clip.
Raises
------
ValueError
If batch sizes of `output.class_probs` and `clips` do not match, or
if the number of classes does not match `self.targets.class_names`.
"""
classifications = output.class_probs
if len(clips) != len(classifications):
raise ValueError(
"Number of clips and batch size of classification array "
"do not match. "
f"(clips: {len(clips)}, classification: {len(classifications)})"
)
return [
classification_to_xarray(
class_probs,
start_time=clip.start_time,
end_time=clip.end_time,
class_names=self.targets.class_names,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for class_probs, clip in zip(classifications, clips)
]
def get_sizes_arrays(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[xr.DataArray]:
"""Extract and remap raw size prediction tensors for a batch.
Parameters
----------
output : ModelOutput
Raw model output containing `output.size_preds` tensor for the
batch.
clips : List[data.Clip]
List of Clip objects corresponding to the batch items.
Returns
-------
List[xr.DataArray]
List of coordinate-aware size prediction maps, one per clip.
Raises
------
ValueError
If batch sizes of `output.size_preds` and `clips` do not match.
"""
sizes = output.size_preds
if len(clips) != len(sizes):
raise ValueError(
"Number of clips and batch size of sizes array do not match. "
f"(clips: {len(clips)}, sizes: {len(sizes)})"
)
return [
sizes_to_xarray(
size_preds,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for size_preds, clip in zip(sizes, clips)
]
def get_detection_datasets(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[xr.Dataset]:
"""Perform NMS, remapping, detection, and data extraction for a batch.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[xr.Dataset]
List of xarray Datasets (one per clip). Each Dataset contains
aligned scores, dimensions, class probabilities, and features for
detections found in that clip.
"""
detection_arrays = self.get_detection_arrays(output, clips)
classification_arrays = self.get_classification_arrays(output, clips)
size_arrays = self.get_sizes_arrays(output, clips)
features_arrays = self.get_feature_arrays(output, clips)
datasets = []
for det_array, class_array, sizes_array, feats_array in zip(
detection_arrays,
classification_arrays,
size_arrays,
features_arrays,
):
max_detections = get_max_detections(
det_array,
top_k_per_sec=self.config.top_k_per_sec,
)
positions = extract_detections_from_array(
det_array,
max_detections=max_detections,
threshold=self.config.detection_threshold,
)
datasets.append(
extract_detection_xr_dataset(
positions,
sizes_array,
class_array,
feats_array,
)
)
return datasets
def get_raw_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Processes raw model output through remapping, NMS, detection, data
extraction, and geometry recovery via the configured
`targets.decode_roi`.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[List[RawPrediction]]
List of lists (one inner list per input clip). Each inner list
contains `RawPrediction` objects for detections in that clip.
"""
detection_datasets = self.get_detection_datasets(output, clips)
return [
convert_xr_dataset_to_raw_prediction(
dataset,
self.targets.decode_roi,
)
for dataset in detection_datasets
]
def get_sound_event_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[BatDetect2Prediction]]:
raw_predictions = self.get_raw_predictions(output, clips)
return [
[
BatDetect2Prediction(
raw=raw,
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
targets=self.targets,
classification_threshold=self.config.classification_threshold,
),
)
for raw in predictions
]
for predictions, clip in zip(raw_predictions, clips)
]
def get_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output and corresponding clips, applies the entire
configured chain (NMS, remapping, extraction, geometry recovery, class
decoding), producing final `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[data.ClipPrediction]
List containing one `ClipPrediction` object for each input clip,
populated with `SoundEventPrediction` objects.
"""
raw_predictions = self.get_raw_predictions(output, clips)
return [
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
targets=self.targets,
classification_threshold=self.config.classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)
]
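# Hedged usage sketch (not part of the original module): assumes `postprocessor`
# is an instance of this class (e.g. built via `build_postprocessor`), `output`
# is a ModelOutput for a batch, and `clips` are the matching soundevent Clips
# (all hypothetical names here).
clip_predictions = postprocessor.get_predictions(output, clips)  # List[data.ClipPrediction]
raw_predictions = postprocessor.get_raw_predictions(output, clips)  # List[List[RawPrediction]]
detection_datasets = postprocessor.get_detection_datasets(output, clips)  # List[xr.Dataset]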

View File

@ -0,0 +1,94 @@
from typing import Optional
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE
__all__ = [
"PostprocessConfig",
"load_postprocess_config",
]
DEFAULT_DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 100
class PostprocessConfig(BaseConfig):
"""Configuration settings for the postprocessing pipeline.
Defines tunable parameters that control how raw model outputs are
converted into final detections.
Attributes
----------
nms_kernel_size : int, default=NMS_KERNEL_SIZE
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
Used to suppress weaker detections near stronger peaks. Must be
positive.
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
Minimum confidence score from the detection heatmap required to
consider a point as a potential detection. Must be >= 0.
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
Minimum confidence score for a specific class prediction to be included
in the decoded tags for a detection. Must be >= 0.
top_k_per_sec : int, default=TOP_K_PER_SEC
Desired maximum number of detections per second of audio. Used by
`get_max_detections` to calculate an absolute limit based on clip
duration before applying `extract_detections_from_array`. Must be
positive.
"""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
detection_threshold: float = Field(
default=DEFAULT_DETECTION_THRESHOLD,
ge=0,
)
classification_threshold: float = Field(
default=DEFAULT_CLASSIFICATION_THRESHOLD,
ge=0,
)
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PostprocessConfig:
"""Load the postprocessing configuration from a file.
Reads a configuration file (YAML) and validates it against the
`PostprocessConfig` schema, potentially extracting data from a nested
field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
postprocessing configuration (e.g., "inference.postprocessing").
If None, the entire file content is used.
Returns
-------
PostprocessConfig
The loaded and validated postprocessing configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data does not conform to the
`PostprocessConfig` schema.
KeyError, TypeError
If `field` specifies an invalid path within the loaded data.
"""
return load_config(path, schema=PostprocessConfig, field=field)
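# Hedged usage sketch (not part of the original module): the YAML path and the
# nested field name below are hypothetical.
config = PostprocessConfig(
    nms_kernel_size=9,
    detection_threshold=0.01,
    classification_threshold=0.3,
    top_k_per_sec=200,
)
# Equivalent, loading from a YAML file with a nested "postprocessing" section:
config = load_postprocess_config("config.yaml", field="postprocessing")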

View File

@ -1,42 +1,18 @@
"""Decodes extracted detection data into standard soundevent predictions.
This module handles the final stages of the BatDetect2 postprocessing pipeline.
It takes the structured detection data extracted by the `extraction` module
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
class probabilities, and features for each detection point) and converts it
into standardized prediction objects based on the `soundevent` data model.
The process involves:
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
objects, using a configured geometry builder to recover bounding boxes from
predicted positions and sizes (`convert_xr_dataset_to_raw_prediction`).
2. Converting each `RawPrediction` into a
`soundevent.data.SoundEventPrediction`, which involves:
- Creating the `soundevent.data.SoundEvent` with geometry and features.
- Decoding the predicted class probabilities into representative tags using
a configured class decoder (`SoundEventDecoder`).
- Applying a classification threshold.
- Optionally selecting only the single highest-scoring class (top-1) or
including tags for all classes above the threshold (multi-label).
- Adding generic class tags as a baseline.
- Associating scores with the final prediction and tags.
(`convert_raw_prediction_to_sound_event_prediction`)
3. Grouping the `SoundEventPrediction` objects for a given audio segment into
a `soundevent.data.ClipPrediction`
(`convert_raw_predictions_to_clip_prediction`).
"""
"""Decodes extracted detection data into standard soundevent predictions."""
from typing import List, Optional
import numpy as np
import xarray as xr
from soundevent import data
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.typing.postprocess import (
ClipDetectionsArray,
RawPrediction,
)
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"convert_xr_dataset_to_raw_prediction",
"to_raw_predictions",
"convert_raw_predictions_to_clip_prediction",
"convert_raw_prediction_to_sound_event_prediction",
"DEFAULT_CLASSIFICATION_THRESHOLD",
@ -51,65 +27,29 @@ decoding.
"""
def convert_xr_dataset_to_raw_prediction(
detection_dataset: xr.Dataset,
geometry_decoder: GeometryDecoder,
def to_raw_predictions(
detections: ClipDetectionsArray,
targets: TargetProtocol,
) -> List[RawPrediction]:
"""Convert an xarray.Dataset of detections to RawPrediction objects.
Takes the output of the extraction step (`extract_detection_xr_dataset`)
and transforms each detection entry into an intermediate `RawPrediction`
object. This involves recovering the geometry (e.g., bounding box) from
the predicted position and scaled size dimensions using the provided
`geometry_builder` function.
Parameters
----------
detection_dataset : xr.Dataset
An xarray Dataset containing aligned detection information, typically
output by `extract_detection_xr_dataset`. Expected variables include
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
Must have a 'detection' dimension.
geometry_decoder : GeometryDecoder
A function that takes a position tuple `(time, freq)` and a NumPy array
of dimensions, and returns the corresponding reconstructed
`soundevent.data.Geometry`.
Returns
-------
List[RawPrediction]
A list of `RawPrediction` objects, each containing the detection score,
recovered bounding box coordinates (start/end time, low/high freq),
the vector of class scores, and the feature vector for one detection.
Raises
------
AttributeError, KeyError, ValueError
If `detection_dataset` is missing expected variables ('scores',
'dimensions', 'classes', 'features') or coordinates ('time', 'freq'
associated with 'scores'), or if `geometry_builder` fails.
"""
detections = []
categories = detection_dataset.category.values
predictions = []
for score, class_scores, time, freq, dims, feats in zip(
detection_dataset["scores"].values,
detection_dataset["classes"].values,
detection_dataset["time"].values,
detection_dataset["frequency"].values,
detection_dataset["dimensions"].values,
detection_dataset["features"].values,
detections.scores,
detections.class_scores,
detections.times,
detections.frequencies,
detections.sizes,
detections.features,
):
highest_scoring_class = categories[class_scores.argmax()]
highest_scoring_class = targets.class_names[class_scores.argmax()]
geom = geometry_decoder(
geom = targets.decode_roi(
(time, freq),
dims,
class_name=highest_scoring_class,
)
detections.append(
predictions.append(
RawPrediction(
detection_score=score,
geometry=geom,
@ -118,7 +58,7 @@ def convert_xr_dataset_to_raw_prediction(
)
)
return detections
return predictions
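# Hedged usage sketch (not part of the original module): `detections`, `targets`,
# and `clip` are hypothetical names, assumed to be a ClipDetectionsArray from the
# extraction step, the configured TargetProtocol, and the matching soundevent Clip.
raw = to_raw_predictions(detections, targets)
clip_prediction = convert_raw_predictions_to_clip_prediction(raw, clip, targets=targets)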
def convert_raw_predictions_to_clip_prediction(
@ -128,35 +68,7 @@ def convert_raw_predictions_to_clip_prediction(
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False,
) -> data.ClipPrediction:
"""Convert a list of RawPredictions into a soundevent ClipPrediction.
Iterates through `raw_predictions` (assumed to belong to a single clip),
converts each one into a `soundevent.data.SoundEventPrediction` using
`convert_raw_prediction_to_sound_event_prediction`, and packages them
into a `soundevent.data.ClipPrediction` associated with the original `clip`.
Parameters
----------
raw_predictions : List[RawPrediction]
List of raw prediction objects for a single clip.
clip : data.Clip
The original `soundevent.data.Clip` object these predictions belong to.
sound_event_decoder : SoundEventDecoder
Function to decode class names into representative tags.
generic_class_tags : List[data.Tag]
List of tags representing the generic class category.
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
Threshold applied to class scores during decoding.
top_class_only : bool, default=False
If True, only decode tags for the single highest-scoring class above
the threshold. If False, decode tags for all classes above threshold.
Returns
-------
data.ClipPrediction
A `ClipPrediction` object containing a list of `SoundEventPrediction`
objects corresponding to the input `raw_predictions`.
"""
"""Convert a list of RawPredictions into a soundevent ClipPrediction."""
return data.ClipPrediction(
clip=clip,
sound_events=[
@ -181,68 +93,7 @@ def convert_raw_prediction_to_sound_event_prediction(
] = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False,
):
"""Convert a single RawPrediction into a soundevent SoundEventPrediction.
This function performs the core decoding steps for a single detected event:
1. Creates a `soundevent.data.SoundEvent` containing the geometry
(BoundingBox derived from `raw_prediction` bounds) and any associated
feature vectors.
2. Initializes a list of predicted tags using the provided
`generic_class_tags`, assigning the overall `detection_score` from the
`raw_prediction` to these generic tags.
3. Processes the `class_scores` from the `raw_prediction`:
a. Optionally filters out scores below `classification_threshold`
(if it's not None).
b. Sorts the remaining scores in descending order.
c. Iterates through the sorted, thresholded class scores.
d. For each class, uses the `sound_event_decoder` to get the
representative base tags for that class name.
e. Wraps these base tags in `soundevent.data.PredictedTag`, associating
the specific `score` of that class prediction.
f. Appends these specific predicted tags to the list.
g. If `top_class_only` is True, stops after processing the first
(highest-scoring) class that passed the threshold.
4. Creates and returns the final `soundevent.data.SoundEventPrediction`,
associating the `SoundEvent`, the overall `detection_score`, and the
compiled list of `PredictedTag` objects.
Parameters
----------
raw_prediction : RawPrediction
The raw prediction object containing score, bounds, class scores,
features. Assumes `class_scores` is an `xr.DataArray` with a 'category'
coordinate. Assumes `features` is an `xr.DataArray` with a 'feature'
coordinate.
recording : data.Recording
The recording the sound event belongs to.
sound_event_decoder : SoundEventDecoder
Configured function mapping class names (str) to lists of base
`data.Tag` objects.
generic_class_tags : List[data.Tag]
List of base tags representing the generic category.
classification_threshold : float, optional
The minimum score a class prediction must have to be considered
significant enough to have its tags decoded and added. If None, no
thresholding is applied based on class score (all predicted classes,
or the top one if `top_class_only` is True, will be processed).
Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`.
top_class_only : bool, default=False
If True, only includes tags for the single highest-scoring class that
exceeds the threshold. If False (default), includes tags for all classes
exceeding the threshold.
Returns
-------
data.SoundEventPrediction
The fully formed sound event prediction object.
Raises
------
ValueError
If `raw_prediction.features` has unexpected structure or if
`data.term_from_key` (if used internally) fails.
If `sound_event_decoder` fails for a class name and errors are raised.
"""
"""Convert a single RawPrediction into a soundevent SoundEventPrediction."""
sound_event = data.SoundEvent(
recording=recording,
geometry=raw_prediction.geometry,
@ -252,7 +103,7 @@ def convert_raw_prediction_to_sound_event_prediction(
tags = [
*get_generic_tags(
raw_prediction.detection_score,
generic_class_tags=targets.generic_class_tags,
generic_class_tags=targets.detection_class_tags,
),
*get_class_tags(
raw_prediction.class_scores,
@ -273,25 +124,7 @@ def get_generic_tags(
detection_score: float,
generic_class_tags: List[data.Tag],
) -> List[data.PredictedTag]:
"""Create PredictedTag objects for the generic category.
Takes the base list of generic tags and assigns the overall detection
score to each one, wrapping them in `PredictedTag` objects.
Parameters
----------
detection_score : float
The overall confidence score of the detection event.
generic_class_tags : List[data.Tag]
The list of base `soundevent.data.Tag` objects that define the
generic category (e.g., ['call_type:Echolocation', 'order:Chiroptera']).
Returns
-------
List[data.PredictedTag]
A list of `PredictedTag` objects for the generic category, each
assigned the `detection_score`.
"""
"""Create PredictedTag objects for the generic category."""
return [
data.PredictedTag(tag=tag, score=detection_score)
for tag in generic_class_tags
@ -299,25 +132,7 @@ def get_generic_tags(
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
"""Convert an extracted feature vector DataArray into soundevent Features.
Parameters
----------
features : xr.DataArray
A 1D xarray DataArray containing feature values, indexed by a coordinate
named 'feature' which holds the feature names (e.g., output of selecting
features for one detection from `extract_detection_xr_dataset`).
Returns
-------
List[data.Feature]
A list of `soundevent.data.Feature` objects.
Notes
-----
- This function creates basic `Term` objects using the feature coordinate
names with a "batdetect2:" prefix.
"""
"""Convert an extracted feature vector DataArray into soundevent Features."""
return [
data.Feature(
term=data.Term(

View File

@ -1,162 +0,0 @@
"""Extracts candidate detection points from a model output heatmap.
This module implements Step 3 within the BatDetect2 postprocessing
pipeline. Its primary function is to identify potential sound event locations
by finding peaks (local maxima or high-scoring points) in the detection heatmap
produced by the neural network (usually after Non-Maximum Suppression and
coordinate remapping have been applied).
It provides functionality to:
- Identify the locations (time, frequency) of the highest-scoring points.
- Filter these points based on a minimum confidence score threshold.
- Limit the maximum number of detection points returned (top-k).
The main output is an `xarray.DataArray` containing the scores and
corresponding time/frequency coordinates for the extracted detection points.
This output serves as the input for subsequent postprocessing steps, such as
extracting predicted class probabilities and bounding box sizes at these
specific locations.
"""
from typing import Optional
import numpy as np
import xarray as xr
from soundevent.arrays import Dimensions, get_dim_width
__all__ = [
"extract_detections_from_array",
"get_max_detections",
"DEFAULT_DETECTION_THRESHOLD",
"TOP_K_PER_SEC",
]
DEFAULT_DETECTION_THRESHOLD = 0.01
"""Default confidence score threshold used for filtering detections."""
TOP_K_PER_SEC = 200
"""Default desired maximum number of detections per second of audio."""
def extract_detections_from_array(
detection_array: xr.DataArray,
max_detections: Optional[int] = None,
threshold: Optional[float] = DEFAULT_DETECTION_THRESHOLD,
) -> xr.DataArray:
"""Extract detection locations (time, freq) and scores from a heatmap.
Identifies the pixels with the highest scores in the input detection
heatmap, filters them based on an optional score `threshold`, limits the
number to an optional `max_detections`, and returns their scores along with
their corresponding time and frequency coordinates.
Parameters
----------
detection_array : xr.DataArray
A 2D xarray DataArray representing the detection heatmap. Must have
dimensions and coordinates named 'time' and 'frequency'. Higher values
are assumed to indicate higher detection confidence.
max_detections : int, optional
The absolute maximum number of detections to return. If specified, only
the top `max_detections` highest-scoring detections (passing the
threshold) are returned. If None (default), all detections passing
the threshold are returned, sorted by score.
threshold : float, optional
The minimum confidence score required for a detection peak to be
kept. Detections with scores below this value are discarded.
Defaults to `DEFAULT_DETECTION_THRESHOLD`. If set to None, no
thresholding is applied.
Returns
-------
xr.DataArray
A 1D xarray DataArray named 'score' with a 'detection' dimension.
- The data values are the scores of the extracted detections, sorted
in descending order.
- It has coordinates 'time' and 'frequency' (also indexed by the
'detection' dimension) indicating the location of each detection
peak in the original coordinate system.
- Returns an empty DataArray if no detections pass the criteria.
Raises
------
ValueError
If `max_detections` is not None and not a positive integer, or if
`detection_array` lacks required dimensions/coordinates.
"""
if max_detections is not None:
if max_detections <= 0:
raise ValueError("Max detections must be positive")
values = detection_array.values.flatten()
if max_detections is not None:
top_indices = np.argpartition(-values, max_detections)[:max_detections]
top_sorted_indices = top_indices[np.argsort(-values[top_indices])]
else:
top_sorted_indices = np.argsort(-values)
top_values = values[top_sorted_indices]
if threshold is not None:
mask = top_values > threshold
top_values = top_values[mask]
top_sorted_indices = top_sorted_indices[mask]
freq_indices, time_indices = np.unravel_index(
top_sorted_indices,
detection_array.shape,
)
times = detection_array.coords[Dimensions.time.value].values[time_indices]
freqs = detection_array.coords[Dimensions.frequency.value].values[
freq_indices
]
return xr.DataArray(
data=top_values,
coords={
Dimensions.frequency.value: ("detection", freqs),
Dimensions.time.value: ("detection", times),
},
dims="detection",
name="score",
)
def get_max_detections(
detection_array: xr.DataArray,
top_k_per_sec: int = TOP_K_PER_SEC,
) -> int:
"""Calculate max detections allowed based on duration and rate.
Determines the total maximum number of detections to extract from a
heatmap based on its time duration and a desired rate of detections
per second.
Parameters
----------
detection_array : xr.DataArray
The detection heatmap, requiring 'time' coordinates from which the
total duration can be calculated using
`soundevent.arrays.get_dim_width`.
top_k_per_sec : int, default=TOP_K_PER_SEC
The desired maximum number of detections to allow per second of audio.
Returns
-------
int
The calculated total maximum number of detections allowed for the
entire duration of the `detection_array`.
Raises
------
ValueError
If the duration cannot be calculated from the `detection_array` (e.g.,
missing or invalid 'time' coordinates/dimension).
"""
if top_k_per_sec < 0:
raise ValueError("top_k_per_sec cannot be negative.")
duration = get_dim_width(detection_array, Dimensions.time.value)
return int(duration * top_k_per_sec)

View File

@ -15,108 +15,75 @@ precise time-frequency location of each detection. The final output aggregates
all extracted information into a structured `xarray.Dataset`.
"""
import xarray as xr
from soundevent.arrays import Dimensions
from typing import List, Optional
import torch
from batdetect2.typing.postprocess import ClipDetectionsTensor
__all__ = [
"extract_values_at_positions",
"extract_detection_xr_dataset",
"extract_detection_peaks",
]
def extract_values_at_positions(
array: xr.DataArray,
positions: xr.DataArray,
) -> xr.DataArray:
"""Extract values from an array at specified time-frequency positions.
def extract_detection_peaks(
detection_heatmap: torch.Tensor,
size_heatmap: torch.Tensor,
feature_heatmap: torch.Tensor,
classification_heatmap: torch.Tensor,
max_detections: int = 200,
threshold: Optional[float] = None,
) -> List[ClipDetectionsTensor]:
height = detection_heatmap.shape[-2]
width = detection_heatmap.shape[-1]
Uses coordinate-based indexing to retrieve values from a source `array`
(e.g., class probabilities, size predictions, features) at the time and
frequency coordinates defined in the `positions` array.
Parameters
----------
array : xr.DataArray
The source DataArray from which to extract values. Must have 'time'
and 'frequency' dimensions and coordinates matching the space of
`positions`.
positions : xr.DataArray
A 1D DataArray whose 'time' and 'frequency' coordinates specify the
locations from which to extract values.
Returns
-------
xr.DataArray
A DataArray containing the values extracted from `array` at the given
positions.
Raises
------
ValueError, IndexError, KeyError
If dimensions or coordinates are missing or incompatible between
`array` and `positions`, or if selection fails.
"""
return array.sel(
**{
Dimensions.frequency.value: positions.coords[
Dimensions.frequency.value
],
Dimensions.time.value: positions.coords[Dimensions.time.value],
}
).T
def extract_detection_xr_dataset(
positions: xr.DataArray,
sizes: xr.DataArray,
classes: xr.DataArray,
features: xr.DataArray,
) -> xr.Dataset:
"""Combine extracted detection information into a structured xr.Dataset.
Takes the detection positions/scores and the full model output heatmaps
(sizes, classes, optional features), extracts the relevant data at the
detection positions, and packages everything into a single `xarray.Dataset`
where all variables are indexed by a common 'detection' dimension.
Parameters
----------
positions : xr.DataArray
Output from `extract_detections_from_array`, containing detection
scores as data and 'time', 'frequency' coordinates along the
'detection' dimension.
sizes : xr.DataArray
The full size prediction heatmap from the model, with dimensions like
('dimension', 'time', 'frequency').
classes : xr.DataArray
The full class probability heatmap from the model, with dimensions like
('category', 'time', 'frequency').
features : xr.DataArray
The full feature map from the model, with
dimensions like ('feature', 'time', 'frequency').
Returns
-------
xr.Dataset
An xarray Dataset containing aligned information for each detection:
- 'scores': DataArray from `positions` (score data, time/freq coords).
- 'dimensions': DataArray with extracted size values
(dims: 'detection', 'dimension').
- 'classes': DataArray with extracted class probabilities
(dims: 'detection', 'category').
- 'features': DataArray with extracted feature vectors
(dims: 'detection', 'feature'), if `features` was provided. All
DataArrays share the 'detection' dimension and associated
time/frequency coordinates.
"""
sizes = extract_values_at_positions(sizes, positions)
classes = extract_values_at_positions(classes, positions)
features = extract_values_at_positions(features, positions)
return xr.Dataset(
{
"scores": positions,
"dimensions": sizes,
"classes": classes,
"features": features,
}
freqs, times = torch.meshgrid(
torch.arange(height, dtype=torch.int32),
torch.arange(width, dtype=torch.int32),
indexing="ij",
)
freqs = freqs.flatten().to(detection_heatmap.device)
times = times.flatten().to(detection_heatmap.device)
output_size_preds = size_heatmap.detach()
output_features = feature_heatmap.detach()
output_class_probs = classification_heatmap.detach()
predictions = []
for idx, item in enumerate(detection_heatmap):
item = item.squeeze().flatten() # Remove channel dim
indices = torch.argsort(item, descending=True)[:max_detections]
detection_scores = item.take(indices)
detection_freqs = freqs.take(indices)
detection_times = times.take(indices)
if threshold is not None:
mask = detection_scores >= threshold
detection_scores = detection_scores[mask]
detection_times = detection_times[mask]
detection_freqs = detection_freqs[mask]
sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
features = output_features[idx, :, detection_freqs, detection_times].T
class_scores = output_class_probs[
idx,
:,
detection_freqs,
detection_times,
].T
predictions.append(
ClipDetectionsTensor(
scores=detection_scores,
sizes=sizes,
features=features,
class_scores=class_scores,
times=detection_times.to(torch.float32) / width,
frequencies=(detection_freqs.to(torch.float32) / height),
)
)
return predictions
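# Hedged sketch (not part of the original module): exercises the function on
# random tensors; the batch size, class/size/feature counts, and heatmap shape
# below are arbitrary assumptions.
import torch
from batdetect2.postprocess.extraction import extract_detection_peaks
batch, n_classes, n_dims, n_feats, height, width = 2, 6, 2, 32, 64, 128
peaks = extract_detection_peaks(
    torch.rand(batch, 1, height, width),  # detection heatmap (post-NMS)
    size_heatmap=torch.rand(batch, n_dims, height, width),
    feature_heatmap=torch.rand(batch, n_feats, height, width),
    classification_heatmap=torch.rand(batch, n_classes, height, width),
    max_detections=50,
    threshold=0.1,
)
# One ClipDetectionsTensor per batch item; times and frequencies stay normalised
# to [0, 1] until map_detection_to_clip rescales them.
print(len(peaks), peaks[0].scores.shape, peaks[0].class_scores.shape)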

View File

@ -0,0 +1,100 @@
from typing import List, Optional, Tuple, Union
import torch
from loguru import logger
from batdetect2.postprocess.config import (
PostprocessConfig,
)
from batdetect2.postprocess.extraction import extract_detection_peaks
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
ClipDetectionsTensor,
PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
__all__ = [
"build_postprocessor",
"Postprocessor",
]
def build_postprocessor(
preprocessor: PreprocessorProtocol,
config: Optional[PostprocessConfig] = None,
) -> PostprocessorProtocol:
"""Factory function to build the standard postprocessor."""
config = config or PostprocessConfig()
logger.opt(lazy=True).debug(
"Building postprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
return Postprocessor(
samplerate=preprocessor.output_samplerate,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
top_k_per_sec=config.top_k_per_sec,
detection_threshold=config.detection_threshold,
nms_kernel_size=config.nms_kernel_size,
)
class Postprocessor(torch.nn.Module, PostprocessorProtocol):
"""Standard implementation of the postprocessing pipeline."""
def __init__(
self,
samplerate: float,
min_freq: float,
max_freq: float,
top_k_per_sec: int = 200,
detection_threshold: float = 0.01,
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
):
"""Initialize the Postprocessor."""
super().__init__()
self.output_samplerate = samplerate
self.min_freq = min_freq
self.max_freq = max_freq
self.top_k_per_sec = top_k_per_sec
self.detection_threshold = detection_threshold
self.nms_kernel_size = nms_kernel_size
def forward(
self,
output: ModelOutput,
start_times: Optional[List[float]] = None,
) -> List[ClipDetectionsTensor]:
detection_heatmap = non_max_suppression(
output.detection_probs.detach(),
kernel_size=self.nms_kernel_size,
)
width = output.detection_probs.shape[-1]
duration = width / self.output_samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_detection_peaks(
detection_heatmap,
size_heatmap=output.size_preds,
feature_heatmap=output.features,
classification_heatmap=output.class_probs,
max_detections=max_detections,
threshold=self.detection_threshold,
)
if start_times is None:
start_times = [0.0 for _ in range(len(detections))]
return [
map_detection_to_clip(
detection,
start_time=start_time,
end_time=start_time + duration,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection, start_time in zip(detections, start_times)
]
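# Hedged usage sketch (not part of the original module): `preprocessor` and
# `model_output` are hypothetical names, assumed to be an already-built
# PreprocessorProtocol instance and a ModelOutput for a batch of two clips.
postprocessor = build_postprocessor(preprocessor, config=PostprocessConfig())
detections = postprocessor(model_output, start_times=[0.0, 0.5])
# One ClipDetectionsTensor per clip, with times (s) and frequencies (Hz)
# remapped via map_detection_to_clip using each clip's start time.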

View File

@ -20,6 +20,7 @@ import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing.postprocess import ClipDetectionsTensor
__all__ = [
"features_to_xarray",
@ -29,6 +30,25 @@ __all__ = [
]
def map_detection_to_clip(
detections: ClipDetectionsTensor,
start_time: float,
end_time: float,
min_freq: float,
max_freq: float,
) -> ClipDetectionsTensor:
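"""Rescale normalised detection coordinates to a clip's time/frequency range.
Times and frequencies in `detections` are assumed to lie in [0, 1]; they are
mapped to seconds within [start_time, end_time] and Hz within
[min_freq, max_freq]. Scores, sizes, features, and class scores pass through
unchanged.
"""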
duration = end_time - start_time
bandwidth = max_freq - min_freq
return ClipDetectionsTensor(
scores=detections.scores,
sizes=detections.sizes,
features=detections.features,
class_scores=detections.class_scores,
times=(detections.times * duration + start_time),
frequencies=(detections.frequencies * bandwidth + min_freq),
)
def features_to_xarray(
features: torch.Tensor,
start_time: float,

View File

@ -1,295 +0,0 @@
"""Defines shared interfaces and data structures for postprocessing.
This module centralizes the Protocol definitions and common data structures
used throughout the `batdetect2.postprocess` module.
The main component is the `PostprocessorProtocol`, which outlines the standard
interface for an object responsible for executing the entire postprocessing
pipeline. This pipeline transforms raw neural network outputs into interpretable
detections represented as `soundevent` objects. Using protocols ensures
modularity and consistent interaction between different parts of the BatDetect2
system that deal with model predictions.
"""
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol
import numpy as np
import xarray as xr
from soundevent import data
from batdetect2.models.types import ModelOutput
from batdetect2.targets.types import Position, Size
__all__ = [
"RawPrediction",
"PostprocessorProtocol",
"GeometryDecoder",
]
# TODO: update the docstring
class GeometryDecoder(Protocol):
"""Type alias for a function that recovers geometry from position and size.
This callable takes:
1. A position tuple `(time, frequency)`.
2. A NumPy array of size dimensions (e.g., `[width, height]`).
3. Optionally, the name of the highest-scoring class. This accommodates
different ways of decoding geometry that depend on the predicted class.
It should return the reconstructed `soundevent.data.Geometry` (typically a
`BoundingBox`).
"""
def __call__(
self, position: Position, size: Size, class_name: Optional[str] = None
) -> data.Geometry: ...
class RawPrediction(NamedTuple):
"""Intermediate representation of a single detected sound event.
Holds extracted information about a detection after initial processing
(like peak finding, coordinate remapping, geometry recovery) but before
final class decoding and conversion into a `SoundEventPrediction`. This
can be useful for evaluation or simpler data handling formats.
Attributes
----------
geometry: data.Geometry
The recovered estimated geometry of the detected sound event.
Usually a bounding box.
detection_score : float
The confidence score associated with this detection, typically from
the detection heatmap peak.
class_scores : xr.DataArray
An xarray DataArray containing the predicted probabilities or scores
for each target class at the detection location. Indexed by a
'category' coordinate containing class names.
features : xr.DataArray
An xarray DataArray containing extracted feature vectors at the
detection location. Indexed by a 'feature' coordinate.
"""
geometry: data.Geometry
detection_score: float
class_scores: np.ndarray
features: np.ndarray
@dataclass
class BatDetect2Prediction:
raw: RawPrediction
sound_event_prediction: data.SoundEventPrediction
class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline.
This protocol outlines the standard methods for an object that takes raw
output from a BatDetect2 model and the corresponding input clip metadata,
and processes it through various stages (e.g., coordinate remapping, NMS,
detection extraction, data extraction, decoding) to produce interpretable
results at different levels of completion.
Implementations manage the configured logic for all postprocessing steps.
"""
def get_feature_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap feature tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch, expected
to contain the necessary feature tensors.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects, one for each item in the
processed batch. This list provides the timing, recording, and
other metadata context needed to calculate real-world coordinates
(seconds, Hz) for the output arrays. The length of this list must
correspond to the batch size of the `output`.
Returns
-------
List[xr.DataArray]
A list of xarray DataArrays, one for each input clip in the batch,
in the same order. Each DataArray contains the feature vectors
with dimensions like ('feature', 'time', 'frequency') and
corresponding real-world coordinates.
"""
...
def get_detection_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap detection tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch,
containing detection heatmaps.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing coordinate context. Must match the batch size of
`output`.
Returns
-------
List[xr.DataArray]
A list of 2D xarray DataArrays (one per input clip, in order),
representing the detection heatmap with 'time' and 'frequency'
coordinates. Values typically indicate detection confidence.
"""
...
def get_classification_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap classification tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch,
containing class probability tensors.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing coordinate context. Must match the batch size of
`output`.
Returns
-------
List[xr.DataArray]
A list of 3D xarray DataArrays (one per input clip, in order),
representing class probabilities with 'category', 'time', and
'frequency' dimensions and coordinates.
"""
...
def get_sizes_arrays(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.DataArray]:
"""Remap size prediction tensors to coordinate-aware DataArrays.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch,
containing predicted size tensors (e.g., width and height).
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing coordinate context. Must match the batch size of
`output`.
Returns
-------
List[xr.DataArray]
A list of 3D xarray DataArrays (one per input clip, in order),
representing predicted sizes with 'dimension'
(e.g., ['width', 'height']), 'time', and 'frequency' dimensions and
coordinates. Values represent estimated detection sizes.
"""
...
def get_detection_datasets(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[xr.Dataset]:
"""Perform remapping, NMS, detection, and data extraction for a batch.
Processes the raw model output for a batch to identify detection peaks
and extract all associated information (score, position, size, class
probs, features) at those peak locations, returning a structured
dataset for each input clip in the batch.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[xr.Dataset]
A list of xarray Datasets (one per input clip, in order). Each
Dataset contains multiple DataArrays ('scores', 'dimensions',
'classes', 'features') sharing a common 'detection' dimension,
providing aligned data for each detected event in that clip.
"""
...
def get_raw_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Processes the raw model output for a batch through remapping, NMS,
detection, data extraction, and geometry recovery to produce a list of
`RawPrediction` objects for each corresponding input clip. This provides
a simplified, intermediate representation before final tag decoding.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[List[RawPrediction]]
A list of lists (one inner list per input clip, in order). Each
inner list contains the `RawPrediction` objects extracted for the
corresponding input clip.
"""
...
def get_sound_event_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[BatDetect2Prediction]]: ...
def get_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output for a batch and corresponding clips, applies the
entire postprocessing chain, and returns the final, interpretable
predictions as a list of `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[data.ClipPrediction]
A list containing one `ClipPrediction` object for each input clip
(in the same order), populated with `SoundEventPrediction` objects
representing the final detections with decoded tags and geometry.
"""
...

View File

@ -1,458 +1,19 @@
"""Main entry point for the BatDetect2 Preprocessing subsystem.
"""Main entry point for the BatDetect2 preprocessing subsystem."""
This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline
for converting raw audio input (from files or data objects) into processed
spectrograms suitable for input to BatDetect2 models. This ensures consistent
data handling between model training and inference.
The preprocessing pipeline consists of two main stages, configured via nested
data structures:
1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial
processing like resampling, duration adjustment, centering, and scaling.
Configured via `AudioConfig`.
2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from
the processed waveform using STFT, followed by frequency cropping, optional
PCEN, amplitude scaling (dB, power, linear), optional denoising, optional
resizing, and optional peak normalization. Configured via
`SpectrogramConfig`.
This module provides the primary interface:
- `PreprocessingConfig`: A unified configuration object holding `AudioConfig`
and `SpectrogramConfig`.
- `load_preprocessing_config`: Function to load the unified configuration.
- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline.
- `StandardPreprocessor`: The default implementation of the `Preprocessor`.
- `build_preprocessor`: A factory function to create a `StandardPreprocessor`
instance from a `PreprocessingConfig`.
"""
from typing import Optional, Union
import numpy as np
import xarray as xr
from loguru import logger
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
DEFAULT_DURATION,
SCALE_RAW_AUDIO,
TARGET_SAMPLERATE_HZ,
AudioConfig,
ResampleConfig,
build_audio_loader,
)
from batdetect2.preprocess.spectrogram import (
MAX_FREQ,
MIN_FREQ,
ConfigurableSpectrogramBuilder,
FrequencyConfig,
PcenConfig,
SpecSizeConfig,
SpectrogramConfig,
STFTConfig,
build_spectrogram_builder,
get_spectrogram_resolution,
)
from batdetect2.preprocess.types import (
AudioLoader,
PreprocessorProtocol,
SpectrogramBuilder,
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.preprocess.config import (
PreprocessingConfig,
load_preprocessing_config,
)
from batdetect2.preprocess.preprocessor import Preprocessor, build_preprocessor
from batdetect2.preprocess.spectrogram import MAX_FREQ, MIN_FREQ
__all__ = [
"AudioConfig",
"AudioLoader",
"ConfigurableSpectrogramBuilder",
"DEFAULT_DURATION",
"FrequencyConfig",
"MAX_FREQ",
"MIN_FREQ",
"PcenConfig",
"PreprocessingConfig",
"ResampleConfig",
"SCALE_RAW_AUDIO",
"STFTConfig",
"SpecSizeConfig",
"SpectrogramBuilder",
"SpectrogramConfig",
"StandardPreprocessor",
"Preprocessor",
"TARGET_SAMPLERATE_HZ",
"build_audio_loader",
"build_preprocessor",
"build_spectrogram_builder",
"get_spectrogram_resolution",
"load_preprocessing_config",
"get_default_preprocessor",
]
class PreprocessingConfig(BaseConfig):
"""Unified configuration for the audio preprocessing pipeline.
Aggregates the configuration for both the initial audio processing stage
and the subsequent spectrogram generation stage.
Attributes
----------
audio : AudioConfig
Configuration settings for the audio loading and initial waveform
processing steps (e.g., resampling, duration adjustment, scaling).
Defaults to default `AudioConfig` settings if omitted.
spectrogram : SpectrogramConfig
Configuration settings for the spectrogram generation process
(e.g., STFT parameters, frequency cropping, scaling, denoising,
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
"""
audio: AudioConfig = Field(default_factory=AudioConfig)
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
class StandardPreprocessor(PreprocessorProtocol):
"""Standard implementation of the `Preprocessor` protocol.
Orchestrates the audio loading and spectrogram generation pipeline using
an `AudioLoader` and a `SpectrogramBuilder` internally, which are
configured according to a `PreprocessingConfig`.
This class is typically instantiated using the `build_preprocessor`
factory function.
Attributes
----------
audio_loader : AudioLoader
The configured audio loader instance used for waveform loading and
initial processing.
spectrogram_builder : SpectrogramBuilder
The configured spectrogram builder instance used for generating
spectrograms from waveforms.
default_samplerate : int
The sample rate (in Hz) assumed for input waveforms when they are
provided as raw NumPy arrays without coordinate information (e.g.,
when calling `compute_spectrogram` directly with `np.ndarray`).
This value is derived from the `AudioConfig` (target resample rate
or default if resampling is off) and also serves as documentation
for the pipeline's intended operating sample rate. Note that when
processing `xr.DataArray` inputs that have coordinate information
(the standard internal workflow), the sample rate embedded in the
coordinates takes precedence over this default value during
spectrogram calculation.
"""
audio_loader: AudioLoader
spectrogram_builder: SpectrogramBuilder
default_samplerate: int
max_freq: float
min_freq: float
def __init__(
self,
audio_loader: AudioLoader,
spectrogram_builder: SpectrogramBuilder,
default_samplerate: int,
max_freq: float,
min_freq: float,
) -> None:
"""Initialize the StandardPreprocessor.
Parameters
----------
audio_loader : AudioLoader
An initialized audio loader conforming to the AudioLoader protocol.
spectrogram_builder : SpectrogramBuilder
An initialized spectrogram builder conforming to the
SpectrogramBuilder protocol.
default_samplerate : int
The sample rate to assume for NumPy array inputs and potentially
reflecting the target rate of the audio config.
"""
self.audio_loader = audio_loader
self.spectrogram_builder = spectrogram_builder
self.default_samplerate = default_samplerate
self.max_freq = max_freq
self.min_freq = min_freq
def load_file_audio(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform from a file path.
Delegates to the internal `audio_loader`.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform (typically first
channel).
"""
return self.audio_loader.load_file(
path,
audio_dir=audio_dir,
)
def load_recording_audio(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Recording.
Delegates to the internal `audio_loader`.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform (typically first
channel).
"""
return self.audio_loader.load_recording(
recording,
audio_dir=audio_dir,
)
def load_clip_audio(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load and preprocess *only* the audio waveform for a Clip.
Delegates to the internal `audio_loader`.
Parameters
----------
clip : data.Clip
The Clip object defining the segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The loaded and preprocessed audio waveform segment (typically first
channel).
"""
return self.audio_loader.load_clip(
clip,
audio_dir=audio_dir,
)
def preprocess_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio from a file and compute the final processed spectrogram.
Performs the full pipeline:
Load -> Preprocess Audio -> Compute Spectrogram.
Parameters
----------
path : PathLike
Path to the audio file.
audio_dir : PathLike, optional
A directory prefix if `path` is relative.
Returns
-------
xr.DataArray
The final processed spectrogram.
"""
wav = self.load_file_audio(path, audio_dir=audio_dir)
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def preprocess_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio for a Recording and compute the processed spectrogram.
Performs the full pipeline for the entire duration of the recording.
Parameters
----------
recording : data.Recording
The Recording object.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
"""
wav = self.load_recording_audio(recording, audio_dir=audio_dir)
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def preprocess_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> xr.DataArray:
"""Load audio for a Clip and compute the final processed spectrogram.
Performs the full pipeline for the specified clip segment.
Parameters
----------
clip : data.Clip
The Clip object defining the audio segment.
audio_dir : PathLike, optional
Directory containing the audio file.
Returns
-------
xr.DataArray
The final processed spectrogram.
"""
wav = self.load_clip_audio(clip, audio_dir=audio_dir)
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def compute_spectrogram(
self, wav: Union[xr.DataArray, np.ndarray]
) -> xr.DataArray:
"""Compute the spectrogram from a pre-loaded audio waveform.
Applies the configured spectrogram generation steps
(STFT, scaling, etc.) using the internal `spectrogram_builder`.
If `wav` is a NumPy array, the `default_samplerate` stored in this
preprocessor instance will be used. If `wav` is an xarray DataArray
with time coordinates, the sample rate derived from those coordinates
will take precedence over `default_samplerate`.
Parameters
----------
wav : Union[xr.DataArray, np.ndarray]
The input audio waveform. If numpy array, `default_samplerate`
stored in this object will be assumed.
Returns
-------
xr.DataArray
The computed spectrogram.
"""
return self.spectrogram_builder(
wav,
samplerate=self.default_samplerate,
)
def load_preprocessing_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PreprocessingConfig:
"""Load the unified preprocessing configuration from a file.
Reads a configuration file (YAML) and validates it against the
`PreprocessingConfig` schema, potentially extracting data from a nested
field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
preprocessing configuration (e.g., "train.preprocessing"). If None, the
entire file content is validated as the PreprocessingConfig.
Returns
-------
PreprocessingConfig
Loaded and validated preprocessing configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to PreprocessingConfig.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=PreprocessingConfig, field=field)
def build_preprocessor(
config: Optional[PreprocessingConfig] = None,
) -> PreprocessorProtocol:
"""Factory function to build the standard preprocessor from configuration.
Creates instances of the required `AudioLoader` and `SpectrogramBuilder`
based on the provided `PreprocessingConfig` (or defaults if config is None),
determines the effective default sample rate, and initializes the
`StandardPreprocessor`.
Parameters
----------
config : PreprocessingConfig, optional
The unified preprocessing configuration object. If None, default
configurations for audio and spectrogram processing will be used.
Returns
-------
Preprocessor
An initialized `StandardPreprocessor` instance ready to process audio
according to the configuration.
"""
config = config or PreprocessingConfig()
logger.opt(lazy=True).debug(
"Building preprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
default_samplerate = (
config.audio.resample.samplerate
if config.audio.resample
else TARGET_SAMPLERATE_HZ
)
min_freq = config.spectrogram.frequencies.min_freq
max_freq = config.spectrogram.frequencies.max_freq
return StandardPreprocessor(
audio_loader=build_audio_loader(config.audio),
spectrogram_builder=build_spectrogram_builder(config.spectrogram),
default_samplerate=default_samplerate,
min_freq=min_freq,
max_freq=max_freq,
)
def get_default_preprocessor():
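"""Build a preprocessor using the default audio and spectrogram settings."""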
return build_preprocessor()
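# Hedged usage sketch (not part of the original module), assuming the
# refactored `build_preprocessor` keeps the optional-config signature shown in
# the removed code above; the YAML path and nested field are hypothetical.
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
preprocessor = build_preprocessor()  # default audio + spectrogram settings
config = load_preprocessing_config("config.yaml", field="preprocessing")
preprocessor = build_preprocessor(config)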

Some files were not shown because too many files have changed in this diff