Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-09 16:59:33 +01:00)

Compare commits: d4f249366e ... 2d796394f6 (81 commits)

Commits (newest first):
2d796394f6, 49ec1916ce, 8727edf466, 2f48c58de1, 981e37c346, 30159d64a9,
c9f0c5c431, 10865ee600, 87ed44c8f7, df2abff654, d6ddc4514c, 4cd983a2c2,
e65df81db2, 6c25787123, 8c80402f08, b81a882b58, 6e217380f2, 957c0735d2,
bbb96b33a2, 7d6cba5465, 60e922d565, 704b28292b, e752e96b93, ec1c0ff020,
8628133fd7, d80377981e, ad5293e0d0, 01e7a5df25, 6d70140bc9, 4fd2e84773,
74c419f674, e65d5a6846, 615c811bb4, 41b18c3f0a, 16a0fa7b75, 115084fd2b,
951dc59718, 3376be06a4, cd4955d4f3, c73984b213, d8d2e5a2c2, b056d7d28d,
95a884ea16, b7ae526071, cf6d0d1ccc, 709b6355c2, db2ad11743, e0ecc3c3d1,
71c2301c21, d3d2a28130, 5b9a5a968f, 356be57f62, 8d093c3ca2, 2f4edeffff,
cca1d82d63, 55f473c9ca, 40f6b64611, 1cec332dd5, 93e89ecc46, 34ef9e92a1,
0b5ac96fe8, dba6d2d918, ff754a1269, ed76ec24b6, d25efdad10, 3043230d4f,
67e37227f5, 9d4a9fc35c, d0bab60bf3, a267db290c, 441ccb3382, 281c4dcb8a,
cc9e47b022, 1f26103f42, c80078feee, 0bb0caddea, 76dda0a0e9, c36ef3ecb5,
667b18a54d, 61115d562c, 02adc19070
````diff
@@ -133,7 +133,7 @@ When you need to specify a tag, you typically use a structure with two fields:
   **It defaults to `class`** if you omit it, which is common when defining the main target classes.
 - `value`: The specific _value_ of the tag (e.g., `Myotis daubentonii`, `Good`, `Rain`).
 
-**Example YAML Configuration using TagInfo (e.g., inside a filter rule):**
+**Example YAML Configuration (e.g., inside a filter rule):**
 
 ```yaml
 # ... inside a filtering configuration section ...
````
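For concreteness, here is a minimal sketch of such a filter rule, following the `filtering.rules` layout used in the example config below; the second tag omits `key` to exercise the `class` default:

```yaml
rules:
  - match_type: all
    tags:
      - key: event
        value: Echolocation
      - value: Myotis daubentonii # key omitted, defaults to "class"
```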
```diff
@@ -1,119 +1,125 @@
 datasets:
   train:
     name: example dataset
     description: Only for demonstration purposes
     sources:
       - format: batdetect2
         name: Example Data
         description: Examples included for testing batdetect2
         annotations_dir: example_data/anns
         audio_dir: example_data/audio
 
 targets:
   classes:
     classes:
       - name: myomys
         tags:
           - value: Myotis mystacinus
       - name: pippip
         tags:
           - value: Pipistrellus pipistrellus
       - name: eptser
         tags:
           - value: Eptesicus serotinus
       - name: rhifer
         tags:
           - value: Rhinolophus ferrumequinum
     generic_class:
       - key: class
         value: Bat
 
 filtering:
   rules:
     - match_type: all
       tags:
         - key: event
           value: Echolocation
     - match_type: exclude
       tags:
         - key: class
           value: Unknown
 
+audio:
+  samplerate: 256000
+  resample:
+    enabled: True
+    method: "poly"
+
 preprocess:
-  audio:
-    resample:
-      samplerate: 256000
-      method: "poly"
-    scale: false
-    center: true
-    duration: null
-
   spectrogram:
     stft:
       window_duration: 0.002
       window_overlap: 0.75
       window_fn: hann
     frequencies:
       max_freq: 120000
       min_freq: 10000
-    pcen:
-      stft:
-        window_duration: 0.002
-        window_overlap: 0.75
-        window_fn: hann
-      frequencies:
-        max_freq: 120000
-        min_freq: 10000
+    size:
+      height: 128
+      resize_factor: 0.5
+    spectrogram_transforms:
+      - name: pcen
+        time_constant: 0.1
+        gain: 0.98
+        bias: 2
+        power: 0.5
-    scale: "amplitude"
-    size:
-      height: 128
-      resize_factor: 0.5
-    spectral_mean_substraction: true
-    peak_normalize: false
+      - name: spectral_mean_substraction
 
 postprocess:
   nms_kernel_size: 9
   detection_threshold: 0.01
   min_freq: 10000
   max_freq: 120000
   top_k_per_sec: 200
 
 labels:
   sigma: 3
 
 model:
   input_height: 128
   in_channels: 1
   out_channels: 32
   encoder:
     layers:
-      - block_type: FreqCoordConvDown
+      - name: FreqCoordConvDown
         out_channels: 32
-      - block_type: FreqCoordConvDown
+      - name: FreqCoordConvDown
         out_channels: 64
-      - block_type: LayerGroup
+      - name: LayerGroup
         layers:
-          - block_type: FreqCoordConvDown
+          - name: FreqCoordConvDown
             out_channels: 128
-          - block_type: ConvBlock
+          - name: ConvBlock
             out_channels: 256
   bottleneck:
-    channels: 256
-    self_attention: true
+    layers:
+      - name: SelfAttention
+        attention_channels: 256
   decoder:
     layers:
-      - block_type: FreqCoordConvUp
+      - name: FreqCoordConvUp
         out_channels: 64
-      - block_type: FreqCoordConvUp
+      - name: FreqCoordConvUp
         out_channels: 32
-      - block_type: LayerGroup
+      - name: LayerGroup
         layers:
-          - block_type: FreqCoordConvUp
+          - name: FreqCoordConvUp
             out_channels: 32
-          - block_type: ConvBlock
+          - name: ConvBlock
             out_channels: 32
 
 train:
-  batch_size: 8
-  learning_rate: 0.001
-  t_max: 100
+  optimizer:
+    learning_rate: 0.001
+    t_max: 100
 
-  labels:
-    sigma: 3
+  trainer:
+    max_epochs: 10
+    check_val_every_n_epoch: 5
 
+  train_loader:
+    batch_size: 8
+    num_workers: 2
+    shuffle: True
+    clipping_strategy:
+      name: random_subclip
+      duration: 0.256
+    augmentations:
+      enabled: true
+      audio:
+        - name: mix_audio
+          probability: 0.2
+          min_weight: 0.3
+          max_weight: 0.7
+        - name: add_echo
+          probability: 0.2
+          max_delay: 0.005
+          min_weight: 0.0
+          max_weight: 1.0
+      spectrogram:
+        - name: scale_volume
+          probability: 0.2
+          min_scaling: 0.0
+          max_scaling: 2.0
+        - name: warp
+          probability: 0.2
+          delta: 0.04
+        - name: mask_time
+          probability: 0.2
+          max_perc: 0.05
+          max_masks: 3
+        - name: mask_freq
+          probability: 0.2
+          max_perc: 0.10
+          max_masks: 3
+
+  val_loader:
+    num_workers: 2
+    clipping_strategy:
+      name: whole_audio_padded
+      chunk_size: 0.256
 
   loss:
     detection:
       weight: 1.0
@@ -127,37 +133,54 @@ train:
       alpha: 2
     size:
       weight: 0.1
 
-  augmentations:
-    steps:
-      - augmentation_type: mix_audio
-        probability: 0.2
-        min_weight: 0.3
-        max_weight: 0.7
-      - augmentation_type: add_echo
-        probability: 0.2
-        max_delay: 0.005
-        min_weight: 0.0
-        max_weight: 1.0
-      - augmentation_type: scale_volume
-        probability: 0.2
-        min_scaling: 0.0
-        max_scaling: 2.0
-      - augmentation_type: warp
-        probability: 0.2
-        delta: 0.04
-      - augmentation_type: mask_time
-        probability: 0.2
-        max_perc: 0.05
-        max_masks: 3
-      - augmentation_type: mask_freq
-        probability: 0.2
-        max_perc: 0.10
-        max_masks: 3
-  logger:
-    name: csv
+  logger:
+    logger_type: mlflow
+    experiment_name: batdetect2
+    tracking_uri: http://localhost:5000
+    log_model: true
+    save_dir: outputs/log/
+    artifact_location: outputs/artifacts/
+    checkpoint_path_prefix: outputs/checkpoints/
 
+validation:
+  tasks:
+    - name: sound_event_detection
+      metrics:
+        - name: average_precision
+    - name: sound_event_classification
+      metrics:
+        - name: average_precision
+
+evaluation:
+  tasks:
+    - name: sound_event_detection
+      metrics:
+        - name: average_precision
+        - name: roc_auc
+      plots:
+        - name: pr_curve
+        - name: score_distribution
+        - name: example_detection
+    - name: sound_event_classification
+      metrics:
+        - name: average_precision
+        - name: roc_auc
+      plots:
+        - name: pr_curve
+    - name: top_class_detection
+      metrics:
+        - name: average_precision
+      plots:
+        - name: pr_curve
+        - name: confusion_matrix
+        - name: example_classification
+    - name: clip_detection
+      metrics:
+        - name: average_precision
+        - name: roc_auc
+      plots:
+        - name: pr_curve
+        - name: roc_curve
+        - name: score_distribution
+    - name: clip_classification
+      metrics:
+        - name: average_precision
+        - name: roc_auc
+      plots:
+        - name: pr_curve
+        - name: roc_curve
```
example_data/dataset.yaml (new file, +8):

```yaml
name: example dataset
description: Only for demonstration purposes
sources:
  - format: batdetect2
    name: Example Data
    description: Examples included for testing batdetect2
    annotations_dir: example_data/anns
    audio_dir: example_data/audio
```
example_data/targets.yaml (new file, +36):

```yaml
detection_target:
  name: bat
  match_if:
    name: all_of
    conditions:
      - name: has_tag
        tag: { key: event, value: Echolocation }
      - name: not
        condition:
          name: has_tag
          tag: { key: class, value: Unknown }
  assign_tags:
    - key: class
      value: Bat

classification_targets:
  - name: myomys
    tags:
      - key: class
        value: Myotis mystacinus
  - name: pippip
    tags:
      - key: class
        value: Pipistrellus pipistrellus
  - name: eptser
    tags:
      - key: class
        value: Eptesicus serotinus
  - name: rhifer
    tags:
      - key: class
        value: Rhinolophus ferrumequinum

roi:
  name: anchor_bbox
  anchor: top-left
```
justfile:

```diff
@@ -92,19 +92,11 @@ clean-build:
 clean: clean-build clean-pyc clean-test clean-docs
 
 # Examples
-# Preprocess example data.
-example-preprocess OPTIONS="":
-    batdetect2 preprocess \
-        --base-dir . \
-        --dataset-field datasets.train \
-        --config example_data/config.yaml \
-        {{OPTIONS}} \
-        example_data/datasets.yaml example_data/preprocessed
-
 # Train on example data.
 example-train OPTIONS="":
     batdetect2 train \
-        --val-dir example_data/preprocessed \
+        --val-dataset example_data/dataset.yaml \
         --config example_data/config.yaml \
         {{OPTIONS}} \
-        example_data/preprocessed
+        example_data/dataset.yaml
```
```diff
@@ -17,13 +17,13 @@ dependencies = [
     "torch>=1.13.1,<2.5.0",
     "torchaudio>=1.13.1,<2.5.0",
     "torchvision>=0.14.0",
-    "soundevent[audio,geometry,plot]>=2.7.0",
+    "soundevent[audio,geometry,plot]>=2.9.1",
     "click>=8.1.7",
     "netcdf4>=1.6.5",
     "tqdm>=4.66.2",
     "cf-xarray>=0.9.0",
     "onnx>=1.16.0",
-    "lightning[extra]>=2.2.2",
+    "lightning[extra]==2.5.0",
     "tensorboard>=2.16.2",
     "omegaconf>=2.3.0",
     "pyyaml>=6.0.2",
```
src/batdetect2/api_v2.py (new file, 272 lines):

```python
from pathlib import Path
from typing import List, Optional, Sequence

import numpy as np
import torch
from soundevent import data
from soundevent.audio.files import get_audio_files

from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train import (
    DEFAULT_CHECKPOINT_DIR,
    load_model_from_checkpoint,
    train,
)
from batdetect2.typing import (
    AudioLoader,
    BatDetect2Prediction,
    EvaluatorProtocol,
    PostprocessorProtocol,
    PreprocessorProtocol,
    RawPrediction,
    TargetProtocol,
)


class BatDetect2API:
    def __init__(
        self,
        config: BatDetect2Config,
        targets: TargetProtocol,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        postprocessor: PostprocessorProtocol,
        evaluator: EvaluatorProtocol,
        model: Model,
    ):
        self.config = config
        self.targets = targets
        self.audio_loader = audio_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        self.evaluator = evaluator
        self.model = model

        self.model.eval()

    def train(
        self,
        train_annotations: Sequence[data.ClipAnnotation],
        val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
        train_workers: Optional[int] = None,
        val_workers: Optional[int] = None,
        checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
        log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
        experiment_name: Optional[str] = None,
        run_name: Optional[str] = None,
        seed: Optional[int] = None,
    ):
        train(
            train_annotations=train_annotations,
            val_annotations=val_annotations,
            targets=self.targets,
            config=self.config,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            train_workers=train_workers,
            val_workers=val_workers,
            checkpoint_dir=checkpoint_dir,
            log_dir=log_dir,
            experiment_name=experiment_name,
            run_name=run_name,
            seed=seed,
        )
        return self

    def evaluate(
        self,
        test_annotations: Sequence[data.ClipAnnotation],
        num_workers: Optional[int] = None,
        output_dir: data.PathLike = DEFAULT_EVAL_DIR,
        experiment_name: Optional[str] = None,
        run_name: Optional[str] = None,
    ):
        return evaluate(
            self.model,
            test_annotations,
            targets=self.targets,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            config=self.config,
            num_workers=num_workers,
            output_dir=output_dir,
            experiment_name=experiment_name,
            run_name=run_name,
        )

    def load_audio(self, path: data.PathLike) -> np.ndarray:
        return self.audio_loader.load_file(path)

    def load_clip(self, clip: data.Clip) -> np.ndarray:
        return self.audio_loader.load_clip(clip)

    def generate_spectrogram(
        self,
        audio: np.ndarray,
    ) -> torch.Tensor:
        tensor = torch.tensor(audio).unsqueeze(0)
        return self.preprocessor(tensor)

    def process_file(self, audio_file: str) -> BatDetect2Prediction:
        recording = data.Recording.from_file(audio_file, compute_hash=False)
        wav = self.audio_loader.load_recording(recording)
        detections = self.process_audio(wav)
        return BatDetect2Prediction(
            clip=data.Clip(
                uuid=recording.uuid,
                recording=recording,
                start_time=0,
                end_time=recording.duration,
            ),
            predictions=detections,
        )

    def process_audio(
        self,
        audio: np.ndarray,
    ) -> List[RawPrediction]:
        spec = self.generate_spectrogram(audio)
        return self.process_spectrogram(spec)

    def process_spectrogram(
        self,
        spec: torch.Tensor,
        start_time: float = 0,
    ) -> List[RawPrediction]:
        if spec.ndim == 4 and spec.shape[0] > 1:
            raise ValueError("Batched spectrograms not supported.")

        if spec.ndim == 3:
            spec = spec.unsqueeze(0)

        outputs = self.model.detector(spec)

        detections = self.model.postprocessor(
            outputs,
            start_times=[start_time],
        )[0]

        return to_raw_predictions(detections.numpy(), targets=self.targets)

    def process_directory(
        self,
        audio_dir: data.PathLike,
    ) -> List[BatDetect2Prediction]:
        files = list(get_audio_files(audio_dir))
        return self.process_files(files)

    def process_files(
        self,
        audio_files: Sequence[data.PathLike],
        num_workers: Optional[int] = None,
    ) -> List[BatDetect2Prediction]:
        return process_file_list(
            self.model,
            audio_files,
            config=self.config,
            targets=self.targets,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            num_workers=num_workers,
        )

    def process_clips(
        self,
        clips: Sequence[data.Clip],
    ) -> List[BatDetect2Prediction]:
        return run_batch_inference(
            self.model,
            clips,
            targets=self.targets,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            config=self.config,
        )

    @classmethod
    def from_config(cls, config: BatDetect2Config):
        targets = build_targets(config=config.targets)

        audio_loader = build_audio_loader(config=config.audio)

        preprocessor = build_preprocessor(
            input_samplerate=audio_loader.samplerate,
            config=config.preprocess,
        )

        postprocessor = build_postprocessor(
            preprocessor,
            config=config.postprocess,
        )

        evaluator = build_evaluator(config=config.evaluation, targets=targets)

        # NOTE: Better to have a separate instance of
        # preprocessor and postprocessor as these may be moved
        # to another device.
        model = build_model(
            config=config.model,
            targets=targets,
            preprocessor=build_preprocessor(
                input_samplerate=audio_loader.samplerate,
                config=config.preprocess,
            ),
            postprocessor=build_postprocessor(
                preprocessor,
                config=config.postprocess,
            ),
        )

        return cls(
            config=config,
            targets=targets,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            evaluator=evaluator,
            model=model,
        )

    @classmethod
    def from_checkpoint(
        cls,
        path: data.PathLike,
        config: Optional[BatDetect2Config] = None,
    ):
        model, stored_config = load_model_from_checkpoint(path)

        config = config or stored_config

        targets = build_targets(config=config.targets)

        audio_loader = build_audio_loader(config=config.audio)

        preprocessor = build_preprocessor(
            input_samplerate=audio_loader.samplerate,
            config=config.preprocess,
        )

        postprocessor = build_postprocessor(
            preprocessor,
            config=config.postprocess,
        )

        evaluator = build_evaluator(config=config.evaluation, targets=targets)

        return cls(
            config=config,
            targets=targets,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            evaluator=evaluator,
            model=model,
        )
```
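Taken together, `from_checkpoint` plus the `process_*` methods give a compact inference path. A minimal usage sketch (the checkpoint and audio paths are hypothetical):

```python
from batdetect2.api_v2 import BatDetect2API

# Hypothetical checkpoint path; any checkpoint saved by `train` should work.
api = BatDetect2API.from_checkpoint("outputs/checkpoints/last.ckpt")

# Run detection over a folder of recordings.
for prediction in api.process_directory("example_data/audio"):
    print(prediction.clip.recording.path, len(prediction.predictions))
```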
src/batdetect2/audio/__init__.py (new file, 16 lines):

```python
from batdetect2.audio.clips import ClipConfig, build_clipper
from batdetect2.audio.loader import (
    TARGET_SAMPLERATE_HZ,
    AudioConfig,
    SoundEventAudioLoader,
    build_audio_loader,
)

__all__ = [
    "TARGET_SAMPLERATE_HZ",
    "AudioConfig",
    "SoundEventAudioLoader",
    "build_audio_loader",
    "ClipConfig",
    "build_clipper",
]
```
src/batdetect2/audio/clips.py (new file, 264 lines):

```python
from typing import Annotated, List, Literal, Optional, Union

import numpy as np
from loguru import logger
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds, intervals_overlap

from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipperProtocol

DEFAULT_TRAIN_CLIP_DURATION = 0.256
DEFAULT_MAX_EMPTY_CLIP = 0.1


__all__ = [
    "build_clipper",
    "ClipConfig",
]


clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")


class RandomClipConfig(BaseConfig):
    name: Literal["random_subclip"] = "random_subclip"
    duration: float = DEFAULT_TRAIN_CLIP_DURATION
    random: bool = True
    max_empty: float = DEFAULT_MAX_EMPTY_CLIP
    min_sound_event_overlap: float = 0


class RandomClip:
    def __init__(
        self,
        duration: float = 0.5,
        max_empty: float = 0.2,
        random: bool = True,
        min_sound_event_overlap: float = 0,
    ):
        super().__init__()
        self.duration = duration
        self.random = random
        self.max_empty = max_empty
        self.min_sound_event_overlap = min_sound_event_overlap

    def __call__(
        self,
        clip_annotation: data.ClipAnnotation,
    ) -> data.ClipAnnotation:
        subclip = self.get_subclip(clip_annotation.clip)
        sound_events = select_sound_event_annotations(
            clip_annotation,
            subclip,
            min_overlap=self.min_sound_event_overlap,
        )
        return clip_annotation.model_copy(
            update=dict(
                clip=subclip,
                sound_events=sound_events,
            )
        )

    def get_subclip(self, clip: data.Clip) -> data.Clip:
        return select_random_subclip(
            clip,
            random=self.random,
            duration=self.duration,
            max_empty=self.max_empty,
        )

    @clipper_registry.register(RandomClipConfig)
    @staticmethod
    def from_config(config: RandomClipConfig):
        return RandomClip(
            duration=config.duration,
            max_empty=config.max_empty,
            min_sound_event_overlap=config.min_sound_event_overlap,
        )


def get_subclip_annotation(
    clip_annotation: data.ClipAnnotation,
    random: bool = True,
    duration: float = 0.5,
    max_empty: float = 0.2,
    min_sound_event_overlap: float = 0,
) -> data.ClipAnnotation:
    clip = clip_annotation.clip

    subclip = select_random_subclip(
        clip,
        random=random,
        duration=duration,
        max_empty=max_empty,
    )

    sound_events = select_sound_event_annotations(
        clip_annotation,
        subclip,
        min_overlap=min_sound_event_overlap,
    )

    return clip_annotation.model_copy(
        update=dict(
            clip=subclip,
            sound_events=sound_events,
        )
    )


def select_random_subclip(
    clip: data.Clip,
    random: bool = True,
    duration: float = 0.5,
    max_empty: float = 0.2,
) -> data.Clip:
    start_time = clip.start_time
    end_time = clip.end_time

    if duration > clip.duration + max_empty or not random:
        return clip.model_copy(
            update=dict(
                start_time=start_time,
                end_time=start_time + duration,
            )
        )

    random_start_time = np.random.uniform(
        low=start_time,
        high=end_time + max_empty - duration,
    )

    return clip.model_copy(
        update=dict(
            start_time=random_start_time,
            end_time=random_start_time + duration,
        )
    )


def select_sound_event_annotations(
    clip_annotation: data.ClipAnnotation,
    subclip: data.Clip,
    min_overlap: float = 0,
) -> List[data.SoundEventAnnotation]:
    selected = []

    start_time = subclip.start_time
    end_time = subclip.end_time

    for sound_event_annotation in clip_annotation.sound_events:
        geometry = sound_event_annotation.sound_event.geometry

        if geometry is None:
            continue

        geom_start_time, _, geom_end_time, _ = compute_bounds(geometry)

        if not intervals_overlap(
            (start_time, end_time),
            (geom_start_time, geom_end_time),
            min_absolute_overlap=min_overlap,
        ):
            continue

        selected.append(sound_event_annotation)

    return selected


class PaddedClipConfig(BaseConfig):
    name: Literal["whole_audio_padded"] = "whole_audio_padded"
    chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION


class PaddedClip:
    def __init__(self, chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION):
        self.chunk_size = chunk_size

    def __call__(
        self,
        clip_annotation: data.ClipAnnotation,
    ) -> data.ClipAnnotation:
        clip = clip_annotation.clip
        clip = self.get_subclip(clip)
        return clip_annotation.model_copy(update=dict(clip=clip))

    def get_subclip(self, clip: data.Clip) -> data.Clip:
        duration = clip.duration

        target_duration = float(
            self.chunk_size * np.ceil(duration / self.chunk_size)
        )
        clip = clip.model_copy(
            update=dict(
                end_time=clip.start_time + target_duration,
            )
        )
        return clip

    @clipper_registry.register(PaddedClipConfig)
    @staticmethod
    def from_config(config: PaddedClipConfig):
        return PaddedClip(chunk_size=config.chunk_size)


class FixedDurationClipConfig(BaseConfig):
    name: Literal["fixed_duration"] = "fixed_duration"
    duration: float = DEFAULT_TRAIN_CLIP_DURATION


class FixedDurationClip:
    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
        self.duration = duration

    def __call__(
        self,
        clip_annotation: data.ClipAnnotation,
    ) -> data.ClipAnnotation:
        clip = self.get_subclip(clip_annotation.clip)
        sound_events = select_sound_event_annotations(
            clip_annotation,
            clip,
            min_overlap=0,
        )
        return clip_annotation.model_copy(
            update=dict(
                clip=clip,
                sound_events=sound_events,
            )
        )

    def get_subclip(self, clip: data.Clip) -> data.Clip:
        return clip.model_copy(
            update=dict(
                end_time=clip.start_time + self.duration,
            )
        )

    @clipper_registry.register(FixedDurationClipConfig)
    @staticmethod
    def from_config(config: FixedDurationClipConfig):
        return FixedDurationClip(duration=config.duration)


ClipConfig = Annotated[
    Union[
        RandomClipConfig,
        PaddedClipConfig,
        FixedDurationClipConfig,
    ],
    Field(discriminator="name"),
]


def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
    config = config or RandomClipConfig()

    logger.opt(lazy=True).debug(
        "Building clipper with config: \n{}",
        lambda: config.to_yaml_string(),
    )
    return clipper_registry.build(config)
```
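A short sketch of how a clipper is built and applied; the `clip_annotation` object is assumed to come from a loaded dataset:

```python
from soundevent import data

from batdetect2.audio.clips import RandomClipConfig, build_clipper

# Same defaults as the example config's train_loader.clipping_strategy.
clipper = build_clipper(RandomClipConfig(duration=0.256, max_empty=0.1))

def crop(clip_annotation: data.ClipAnnotation) -> data.ClipAnnotation:
    # Returns a copy whose clip is a random 0.256 s window and whose
    # sound events are filtered to those overlapping that window.
    return clipper(clip_annotation)
```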
src/batdetect2/audio/loader.py (new file, 295 lines):

```python
from typing import Optional

import numpy as np
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError

from batdetect2.core import BaseConfig
from batdetect2.typing import AudioLoader

__all__ = [
    "SoundEventAudioLoader",
    "build_audio_loader",
    "load_file_audio",
    "load_recording_audio",
    "load_clip_audio",
    "resample_audio",
]

TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""


class ResampleConfig(BaseConfig):
    """Configuration for audio resampling.

    Attributes
    ----------
    enabled : bool, default=True
        Whether to resample audio to the target sample rate.
    method : str, default="poly"
        The resampling algorithm to use. Options:
        - "poly": Polyphase resampling using `scipy.signal.resample_poly`.
          Generally fast.
        - "fourier": Resampling via Fourier method using
          `scipy.signal.resample`. May handle non-integer
          resampling factors differently.
    """

    enabled: bool = True
    method: str = "poly"


class AudioConfig(BaseConfig):
    """Configuration for loading and initial audio preprocessing."""

    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
    resample: ResampleConfig = Field(default_factory=ResampleConfig)


def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
    """Factory function to create an AudioLoader based on configuration."""
    config = config or AudioConfig()
    return SoundEventAudioLoader(
        samplerate=config.samplerate,
        config=config.resample,
    )


class SoundEventAudioLoader(AudioLoader):
    """Concrete implementation of the `AudioLoader`."""

    def __init__(
        self,
        samplerate: int = TARGET_SAMPLERATE_HZ,
        config: Optional[ResampleConfig] = None,
    ):
        self.samplerate = samplerate
        self.config = config or ResampleConfig()

    def load_file(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess audio directly from a file path."""
        return load_file_audio(
            path,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )

    def load_recording(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess the entire audio for a Recording object."""
        return load_recording_audio(
            recording,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )

    def load_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess the audio segment defined by a Clip object."""
        return load_clip_audio(
            clip,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )


def load_file_audio(
    path: data.PathLike,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess audio from a file path using specified config."""
    try:
        recording = data.Recording.from_file(path)
    except LibsndfileError as e:
        raise FileNotFoundError(
            f"Could not load the recording at path: {path}. Error: {e}"
        ) from e

    return load_recording_audio(
        recording,
        samplerate=samplerate,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )


def load_recording_audio(
    recording: data.Recording,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess the entire audio content of a recording using config."""
    clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    return load_clip_audio(
        clip,
        samplerate=samplerate,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )


def load_clip_audio(
    clip: data.Clip,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess a specific audio clip segment based on config."""
    try:
        wav = (
            audio.load_clip(clip, audio_dir=audio_dir)
            .sel(channel=0)
            .astype(dtype)
        )
    except LibsndfileError as e:
        raise FileNotFoundError(
            f"Could not load the recording at path: {clip.recording.path}. "
            f"Error: {e}"
        ) from e

    if not config or not config.enabled or samplerate is None:
        return wav.data.astype(dtype)

    sr = int(1 / wav.time.attrs["step"])
    return resample_audio(
        wav.data,
        sr=sr,
        samplerate=samplerate,
        method=config.method,
    )


def resample_audio(
    wav: np.ndarray,
    sr: int,
    samplerate: int = TARGET_SAMPLERATE_HZ,
    method: str = "poly",
) -> np.ndarray:
    """Resample an audio waveform to a target sample rate."""
    if sr == samplerate:
        return wav

    if method == "poly":
        return resample_audio_poly(
            wav,
            sr_orig=sr,
            sr_new=samplerate,
        )
    elif method == "fourier":
        return resample_audio_fourier(
            wav,
            sr_orig=sr,
            sr_new=samplerate,
        )
    else:
        raise NotImplementedError(
            f"Resampling method '{method}' not implemented"
        )


def resample_audio_poly(
    array: np.ndarray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    """Resample a numpy array using `scipy.signal.resample_poly`.

    This method is often preferred for signals when the ratio of new
    to old sample rates can be expressed as a rational number. It uses
    polyphase filtering.

    Parameters
    ----------
    array : np.ndarray
        The input array to resample.
    sr_orig : int
        The original sample rate in Hz.
    sr_new : int
        The target sample rate in Hz.
    axis : int, default=-1
        The axis of `array` along which to resample.

    Returns
    -------
    np.ndarray
        The array resampled to the target sample rate.

    Raises
    ------
    ValueError
        If sample rates are not positive.
    """
    gcd = np.gcd(sr_orig, sr_new)
    return resample_poly(
        array,
        sr_new // gcd,
        sr_orig // gcd,
        axis=axis,
    )


def resample_audio_fourier(
    array: np.ndarray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    """Resample a numpy array using `scipy.signal.resample`.

    This method uses FFTs to resample the signal.

    Parameters
    ----------
    array : np.ndarray
        The input array to resample.
    sr_orig : int
        The original sample rate in Hz.
    sr_new : int
        The target sample rate in Hz.
    axis : int, default=-1
        The axis of `array` along which to resample.

    Returns
    -------
    np.ndarray
        The array resampled to the target sample rate.
    """
    ratio = sr_new / sr_orig
    return resample(  # type: ignore
        array,
        int(array.shape[axis] * ratio),
        axis=axis,
    )
```
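A usage sketch of the loader factory (the audio file path is hypothetical):

```python
from batdetect2.audio.loader import AudioConfig, build_audio_loader

# Loader that resamples everything to the 256 kHz target via polyphase filtering.
loader = build_audio_loader(AudioConfig(samplerate=256_000))
wav = loader.load_file("example_data/audio/recording.wav")  # hypothetical file
print(wav.dtype, wav.shape)  # float32 mono waveform at 256 kHz
```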
```diff
@@ -1,7 +1,7 @@
 from batdetect2.cli.base import cli
 from batdetect2.cli.compat import detect
 from batdetect2.cli.data import data
-from batdetect2.cli.preprocess import preprocess
+from batdetect2.cli.evaluate import evaluate_command
 from batdetect2.cli.train import train_command
 
 __all__ = [
@@ -9,7 +9,7 @@ __all__ = [
     "detect",
     "data",
     "train_command",
-    "preprocess",
+    "evaluate_command",
 ]
```
```diff
@@ -1,10 +1,15 @@
+import os
+
 import click
 
-from batdetect2 import api
 from batdetect2.cli.base import cli
-from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
-from batdetect2.types import ProcessingConfiguration
-from batdetect2.utils.detector_utils import save_results_to_file
+
+DEFAULT_MODEL_PATH = os.path.join(
+    os.path.dirname(os.path.dirname(__file__)),
+    "models",
+    "checkpoints",
+    "Net2DFast_UK_same.pth.tar",
+)
 
 
 @cli.command()
@@ -74,6 +79,9 @@ def detect(
 
     Input files should be short in duration e.g. < 30 seconds.
     """
+    from batdetect2 import api
+    from batdetect2.utils.detector_utils import save_results_to_file
+
     click.echo(f"Loading model: {args['model_path']}")
     model, params = api.load_model(args["model_path"])
 
@@ -123,7 +131,7 @@ def detect(
         click.echo(f"  {err}")
 
 
-def print_config(config: ProcessingConfiguration):
+def print_config(config):
     """Print the processing configuration."""
     click.echo("\nProcessing Configuration:")
     click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
```
```diff
@@ -4,7 +4,6 @@ from typing import Optional
 import click
 
 from batdetect2.cli.base import cli
-from batdetect2.data import load_dataset_from_config
 
 __all__ = ["data"]
 
@@ -33,6 +32,8 @@ def summary(
     field: Optional[str] = None,
     base_dir: Optional[Path] = None,
 ):
+    from batdetect2.data import load_dataset_from_config
+
     base_dir = base_dir or Path.cwd()
     dataset = load_dataset_from_config(
         dataset_config,
```
src/batdetect2/cli/evaluate.py (new file, 78 lines):

```python
import sys
from pathlib import Path
from typing import Optional

import click
from loguru import logger

from batdetect2.cli.base import cli

__all__ = ["evaluate_command"]


DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"


@cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--config", "config_path", type=click.Path())
@click.option("--base-dir", type=click.Path(), default=Path.cwd())
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
@click.option("--experiment-name", type=str)
@click.option("--run-name", type=str)
@click.option("--workers", "num_workers", type=int)
@click.option(
    "-v",
    "--verbose",
    count=True,
    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def evaluate_command(
    model_path: Path,
    test_dataset: Path,
    base_dir: Path,
    config_path: Optional[Path],
    output_dir: Path = DEFAULT_OUTPUT_DIR,
    num_workers: Optional[int] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
    verbose: int = 0,
):
    from batdetect2.api_v2 import BatDetect2API
    from batdetect2.config import load_full_config
    from batdetect2.data import load_dataset_from_config

    logger.remove()
    if verbose == 0:
        log_level = "WARNING"
    elif verbose == 1:
        log_level = "INFO"
    else:
        log_level = "DEBUG"
    logger.add(sys.stderr, level=log_level)

    logger.info("Initiating evaluation process...")

    test_annotations = load_dataset_from_config(
        test_dataset,
        base_dir=base_dir,
    )
    logger.debug(
        "Loaded {num_annotations} test examples",
        num_annotations=len(test_annotations),
    )

    config = None
    if config_path is not None:
        config = load_full_config(config_path)

    api = BatDetect2API.from_checkpoint(model_path, config=config)

    api.evaluate(
        test_annotations,
        num_workers=num_workers,
        output_dir=output_dir,
        experiment_name=experiment_name,
        run_name=run_name,
    )
```
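Based on the arguments and options declared above, an invocation would look like this sketch (paths are hypothetical):

```
batdetect2 evaluate outputs/checkpoints/last.ckpt example_data/dataset.yaml \
    --config example_data/config.yaml \
    --output-dir outputs/evaluation \
    -vv
```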
Deleted file (@@ -1,154 +0,0 @@) — the old standalone `preprocess` command:

```python
import sys
from pathlib import Path
from typing import Optional

import click
import yaml
from loguru import logger

from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.train.preprocess import (
    TrainPreprocessConfig,
    load_train_preprocessing_config,
    preprocess_dataset,
)

__all__ = ["preprocess"]


@cli.command()
@click.argument(
    "dataset_config",
    type=click.Path(exists=True),
)
@click.argument(
    "output",
    type=click.Path(),
)
@click.option(
    "--dataset-field",
    type=str,
    help=(
        "Specifies the key to access the dataset information within the "
        "dataset configuration file, if the information is nested inside a "
        "dictionary. If the dataset information is at the top level of the "
        "config file, you don't need to specify this."
    ),
)
@click.option(
    "--base-dir",
    type=click.Path(exists=True),
    help=(
        "The main directory where your audio recordings and annotation "
        "files are stored. This helps the program find your data, "
        "especially if the paths in your dataset configuration file "
        "are relative."
    ),
)
@click.option(
    "--config",
    type=click.Path(exists=True),
    help=(
        "Path to the configuration file. This file tells "
        "the program how to prepare your audio data before training, such "
        "as resampling or applying filters."
    ),
)
@click.option(
    "--config-field",
    type=str,
    help=(
        "If the preprocessing settings are inside a nested dictionary "
        "within the preprocessing configuration file, specify the key "
        "here to access them. If the preprocessing settings are at the "
        "top level, you don't need to specify this."
    ),
)
@click.option(
    "--force",
    is_flag=True,
    help=(
        "If a preprocessed file already exists, this option tells the "
        "program to overwrite it with the new preprocessed data. Use "
        "this if you want to re-do the preprocessing even if the files "
        "already exist."
    ),
)
@click.option(
    "--num-workers",
    type=int,
    help=(
        "The maximum number of computer cores to use when processing "
        "your audio data. Using more cores can speed up the preprocessing, "
        "but don't use more than your computer has available. By default, "
        "the program will use all available cores."
    ),
)
@click.option(
    "-v",
    "--verbose",
    count=True,
    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def preprocess(
    dataset_config: Path,
    output: Path,
    base_dir: Optional[Path] = None,
    config: Optional[Path] = None,
    config_field: Optional[str] = None,
    force: bool = False,
    num_workers: Optional[int] = None,
    dataset_field: Optional[str] = None,
    verbose: int = 0,
):
    logger.remove()
    if verbose == 0:
        log_level = "WARNING"
    elif verbose == 1:
        log_level = "INFO"
    else:
        log_level = "DEBUG"
    logger.add(sys.stderr, level=log_level)

    logger.info("Starting preprocessing.")

    output = Path(output)
    logger.info("Will save outputs to {output}", output=output)

    base_dir = base_dir or Path.cwd()
    logger.debug("Current working directory: {base_dir}", base_dir=base_dir)

    if config:
        logger.info(
            "Loading preprocessing config from: {config}", config=config
        )

    conf = (
        load_train_preprocessing_config(config, field=config_field)
        if config is not None
        else TrainPreprocessConfig()
    )
    logger.debug(
        "Preprocessing config:\n{conf}",
        conf=yaml.dump(conf.model_dump()),
    )

    dataset = load_dataset_from_config(
        dataset_config,
        field=dataset_field,
        base_dir=base_dir,
    )

    logger.info(
        "Loaded {num_examples} annotated clips from the configured dataset",
        num_examples=len(dataset),
    )

    preprocess_dataset(
        dataset,
        conf,
        output=output,
        force=force,
        max_workers=num_workers,
    )
```
```diff
@@ -6,24 +6,24 @@ import click
 from loguru import logger
 
 from batdetect2.cli.base import cli
-from batdetect2.train import (
-    FullTrainingConfig,
-    load_full_training_config,
-    train,
-)
-from batdetect2.train.dataset import list_preprocessed_files
 
 __all__ = ["train_command"]
 
 
 @cli.command(name="train")
-@click.argument("train_dir", type=click.Path(exists=True))
-@click.option("--val-dir", type=click.Path(exists=True))
-@click.option("--model-path", type=click.Path(exists=True))
+@click.argument("train_dataset", type=click.Path(exists=True))
+@click.option("--val-dataset", type=click.Path(exists=True))
+@click.option("--model", "model_path", type=click.Path(exists=True))
+@click.option("--targets", "targets_config", type=click.Path(exists=True))
 @click.option("--ckpt-dir", type=click.Path(exists=True))
 @click.option("--log-dir", type=click.Path(exists=True))
 @click.option("--config", type=click.Path(exists=True))
 @click.option("--config-field", type=str)
 @click.option("--train-workers", type=int)
 @click.option("--val-workers", type=int)
+@click.option("--experiment-name", type=str)
+@click.option("--run-name", type=str)
+@click.option("--seed", type=int)
 @click.option(
     "-v",
     "--verbose",
@@ -31,15 +31,29 @@ __all__ = ["train_command"]
     help="Increase verbosity. -v for INFO, -vv for DEBUG.",
 )
 def train_command(
-    train_dir: Path,
-    val_dir: Optional[Path] = None,
+    train_dataset: Path,
+    val_dataset: Optional[Path] = None,
     model_path: Optional[Path] = None,
     ckpt_dir: Optional[Path] = None,
     log_dir: Optional[Path] = None,
     config: Optional[Path] = None,
+    targets_config: Optional[Path] = None,
     config_field: Optional[str] = None,
+    seed: Optional[int] = None,
     train_workers: int = 0,
     val_workers: int = 0,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
     verbose: int = 0,
 ):
+    from batdetect2.api_v2 import BatDetect2API
+    from batdetect2.config import (
+        BatDetect2Config,
+        load_full_config,
+    )
+    from batdetect2.data import load_dataset_from_config
+    from batdetect2.targets import load_target_config
+
     logger.remove()
     if verbose == 0:
         log_level = "WARNING"
@@ -48,41 +62,53 @@ def train_command(
     else:
         log_level = "DEBUG"
     logger.add(sys.stderr, level=log_level)
 
     logger.info("Initiating training process...")
 
-    logger.info("Loading training configuration...")
+    logger.info("Loading configuration...")
     conf = (
-        load_full_training_config(config, field=config_field)
+        load_full_config(config, field=config_field)
         if config is not None
-        else FullTrainingConfig()
+        else BatDetect2Config()
     )
 
-    logger.info("Scanning for training and validation data...")
-    train_examples = list_preprocessed_files(train_dir)
+    if targets_config is not None:
+        logger.info("Loading targets configuration...")
+        conf = conf.model_copy(
+            update=dict(targets=load_target_config(targets_config))
+        )
+
+    logger.info("Loading training dataset...")
+    train_annotations = load_dataset_from_config(train_dataset)
     logger.debug(
-        "Found {num_files} training examples in {path}",
-        num_files=len(train_examples),
-        path=train_dir,
+        "Loaded {num_annotations} training examples",
+        num_annotations=len(train_annotations),
     )
 
-    val_examples = None
-    if val_dir is not None:
-        val_examples = list_preprocessed_files(val_dir)
+    val_annotations = None
+    if val_dataset is not None:
+        val_annotations = load_dataset_from_config(val_dataset)
         logger.debug(
-            "Found {num_files} validation examples in {path}",
-            num_files=len(val_examples),
-            path=val_dir,
+            "Loaded {num_annotations} validation examples",
+            num_annotations=len(val_annotations),
         )
     else:
         logger.debug("No validation directory provided.")
 
     logger.info("Configuration and data loaded. Starting training...")
-    train(
-        train_examples=train_examples,
-        val_examples=val_examples,
-        config=conf,
-        model_path=model_path,
+
+    if model_path is None:
+        api = BatDetect2API.from_config(conf)
+    else:
+        api = BatDetect2API.from_checkpoint(model_path)
+
+    return api.train(
+        train_annotations=train_annotations,
+        val_annotations=val_annotations,
         train_workers=train_workers,
         val_workers=val_workers,
         checkpoint_dir=ckpt_dir,
         log_dir=log_dir,
         experiment_name=experiment_name,
         run_name=run_name,
         seed=seed,
     )
```
```diff
@@ -11,7 +11,6 @@ from soundevent import data
 from soundevent.geometry import compute_bounds
 from soundevent.types import ClassMapper
 
-from batdetect2.targets.terms import get_term_from_key
 from batdetect2.types import (
     Annotation,
     AudioLoaderAnnotationGroup,
@@ -173,18 +172,9 @@ def annotation_to_sound_event_annotation(
         uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
         sound_event=sound_event,
         tags=[
-            data.Tag(
-                term=get_term_from_key(label_key),
-                value=annotation["class"],
-            ),
-            data.Tag(
-                term=get_term_from_key(event_key),
-                value=annotation["event"],
-            ),
-            data.Tag(
-                term=get_term_from_key(individual_key),
-                value=str(annotation["individual"]),
-            ),
+            data.Tag(key=label_key, value=annotation["class"]),
+            data.Tag(key=event_key, value=annotation["event"]),
+            data.Tag(key=individual_key, value=str(annotation["individual"])),
         ],
     )
 
@@ -219,17 +209,11 @@
         tags=[
             data.PredictedTag(
                 score=annotation["class_prob"],
-                tag=data.Tag(
-                    term=get_term_from_key(label_key),
-                    value=annotation["class"],
-                ),
+                tag=data.Tag(key=label_key, value=annotation["class"]),
             ),
             data.PredictedTag(
                 score=annotation["det_prob"],
-                tag=data.Tag(
-                    term=get_term_from_key(event_key),
-                    value=annotation["event"],
-                ),
+                tag=data.Tag(key=event_key, value=annotation["event"]),
             ),
         ],
     )
```
src/batdetect2/config.py (new file, 42 lines):

```python
from typing import Literal, Optional

from pydantic import Field
from soundevent.data import PathLike

from batdetect2.audio import AudioConfig
from batdetect2.core import BaseConfig
from batdetect2.core.configs import load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.inference.config import InferenceConfig
from batdetect2.models.config import BackboneConfig
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.targets.config import TargetConfig
from batdetect2.train.config import TrainingConfig

__all__ = [
    "BatDetect2Config",
    "load_full_config",
]


class BatDetect2Config(BaseConfig):
    config_version: Literal["v1"] = "v1"

    train: TrainingConfig = Field(default_factory=TrainingConfig)
    evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
    model: BackboneConfig = Field(default_factory=BackboneConfig)
    preprocess: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
    audio: AudioConfig = Field(default_factory=AudioConfig)
    targets: TargetConfig = Field(default_factory=TargetConfig)
    inference: InferenceConfig = Field(default_factory=InferenceConfig)


def load_full_config(
    path: PathLike,
    field: Optional[str] = None,
) -> BatDetect2Config:
    return load_config(path, schema=BatDetect2Config, field=field)
```
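A sketch of loading the combined configuration (the file path is hypothetical; `field` is only needed when the config sits under a nested key):

```python
from batdetect2.config import load_full_config

config = load_full_config("example_data/config.yaml")
print(config.config_version)    # "v1"
print(config.audio.samplerate)  # 256000 with the example config above
```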
src/batdetect2/core/__init__.py (new file, 8 lines):

```python
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry

__all__ = [
    "BaseConfig",
    "load_config",
    "Registry",
]
```
95
src/batdetect2/core/arrays.py
Normal file
95
src/batdetect2/core/arrays.py
Normal file
@ -0,0 +1,95 @@
|
||||
from typing import Optional

import numpy as np
import torch
import xarray as xr


def spec_to_xarray(
    spec: np.ndarray,
    start_time: float,
    end_time: float,
    min_freq: float,
    max_freq: float,
) -> xr.DataArray:
    if spec.ndim != 2:
        raise ValueError(
            "Input numpy spectrogram array should be 2-dimensional"
        )

    height, width = spec.shape
    return xr.DataArray(
        data=spec,
        dims=["frequency", "time"],
        coords={
            "frequency": np.linspace(
                min_freq,
                max_freq,
                height,
                endpoint=False,
            ),
            "time": np.linspace(
                start_time,
                end_time,
                width,
                endpoint=False,
            ),
        },
    )


def extend_width(
    tensor: torch.Tensor,
    extra: int,
    axis: int = -1,
    value: float = 0,
) -> torch.Tensor:
    dims = len(tensor.shape)
    axis = dims - axis % dims - 1
    pad = [0 for _ in range(2 * dims)]
    pad[2 * axis + 1] = extra
    return torch.nn.functional.pad(
        tensor,
        pad,
        mode="constant",
        value=value,
    )


def adjust_width(
    tensor: torch.Tensor,
    width: int,
    axis: int = -1,
    value: float = 0,
) -> torch.Tensor:
    dims = len(tensor.shape)
    axis = axis % dims
    current_width = tensor.shape[axis]

    if current_width == width:
        return tensor

    if current_width < width:
        return extend_width(
            tensor,
            extra=width - current_width,
            axis=axis,
            value=value,
        )

    slices = [
        slice(None, None) if index != axis else slice(None, width)
        for index in range(dims)
    ]
    return tensor[tuple(slices)]


def slice_tensor(
    tensor: torch.Tensor,
    start: Optional[int] = None,
    end: Optional[int] = None,
    dim: int = -1,
) -> torch.Tensor:
    slices = [slice(None)] * tensor.ndim
    slices[dim] = slice(start, end)
    return tensor[tuple(slices)]
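A short sketch of how the width helpers behave (the tensor shape below is illustrative): `adjust_width` pads on the right of the axis with `value` when the tensor is too narrow, and crops when it is too wide.

```python
import torch

from batdetect2.core.arrays import adjust_width

spec = torch.ones(1, 128, 100)    # (channels, freq_bins, time_bins)
padded = adjust_width(spec, 128)  # zero-padded on the right to width 128
cropped = adjust_width(spec, 64)  # cropped to the first 64 time bins
assert padded.shape[-1] == 128 and cropped.shape[-1] == 64
assert padded[..., 100:].sum() == 0  # the appended columns hold the fill value
```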
@@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
     and serialization capabilities.
     """
 
-    model_config = ConfigDict(extra="ignore")
+    model_config = ConfigDict(extra="forbid")
 
     def to_yaml_string(
         self,
@@ -53,6 +53,7 @@ class BaseConfig(BaseModel):
         """
         return yaml.dump(
             self.model_dump(
                 mode="json",
                 exclude_none=exclude_none,
                 exclude_unset=exclude_unset,
+                exclude_defaults=exclude_defaults,
98  src/batdetect2/core/registries.py  Normal file
@@ -0,0 +1,98 @@
import sys
from typing import Callable, Dict, Generic, Tuple, Type, TypeVar

from pydantic import BaseModel

if sys.version_info >= (3, 10):
    from typing import Concatenate, ParamSpec
else:
    from typing_extensions import Concatenate, ParamSpec

__all__ = [
    "Registry",
    "SimpleRegistry",
]

T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
T_Type = TypeVar("T_Type", covariant=True)
P_Type = ParamSpec("P_Type")


T = TypeVar("T")


class SimpleRegistry(Generic[T]):
    def __init__(self, name: str):
        self._name = name
        self._registry = {}

    def register(self, name: str):
        def decorator(obj: T) -> T:
            self._registry[name] = obj
            return obj

        return decorator

    def get(self, name: str) -> T:
        return self._registry[name]

    def has(self, name: str) -> bool:
        return name in self._registry


class Registry(Generic[T_Type, P_Type]):
    """A generic class to create and manage a registry of items."""

    def __init__(self, name: str):
        self._name = name
        self._registry: Dict[
            str, Callable[Concatenate[..., P_Type], T_Type]
        ] = {}
        self._config_types: Dict[str, Type[BaseModel]] = {}

    def register(
        self,
        config_cls: Type[T_Config],
    ):
        fields = config_cls.model_fields

        if "name" not in fields:
            raise ValueError("Configuration object must have a 'name' field.")

        name = fields["name"].default

        if not isinstance(name, str):
            raise ValueError("'name' field must be a string literal.")

        self._config_types[name] = config_cls

        def decorator(
            func: Callable[Concatenate[T_Config, P_Type], T_Type],
        ):
            self._registry[name] = func
            return func

        return decorator

    def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
        return tuple(self._config_types.values())

    def build(
        self,
        config: BaseModel,
        *args: P_Type.args,
        **kwargs: P_Type.kwargs,
    ) -> T_Type:
        """Builds a logic instance from a config object."""

        name = getattr(config, "name")  # noqa: B009

        if name is None:
            raise ValueError("Config does not have a name field")

        if name not in self._registry:
            raise NotImplementedError(
                f"No {self._name} with name '{name}' is registered."
            )

        return self._registry[name](config, *args, **kwargs)
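A sketch of the intended `Registry` pattern (the config class and builder below are illustrative, not part of the diff): a config model carries a literal `name` field, a builder is registered against it, and `build` dispatches on `config.name`.

```python
from typing import Literal

from batdetect2.core import BaseConfig, Registry


class ScalerConfig(BaseConfig):
    name: Literal["scaler"] = "scaler"
    factor: float = 2.0


scalers: Registry = Registry("scaler")


@scalers.register(ScalerConfig)
def build_scaler(config: ScalerConfig):
    # Builders receive the validated config and return the runtime object.
    return lambda x: x * config.factor


scale = scalers.build(ScalerConfig(factor=3.0))
assert scale(2.0) == 6.0
```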
@@ -14,8 +14,9 @@ format-specific loading function to retrieve the annotations as a standard
 """
 
 from pathlib import Path
-from typing import Optional, Union
+from typing import Annotated, Optional, Union
 
+from pydantic import Field
 from soundevent import data
 
 from batdetect2.data.annotations.aoef import (
@@ -42,10 +43,13 @@ __all__ = [
 ]
 
 
-AnnotationFormats = Union[
-    BatDetect2MergedAnnotations,
-    BatDetect2FilesAnnotations,
-    AOEFAnnotations,
+AnnotationFormats = Annotated[
+    Union[
+        BatDetect2MergedAnnotations,
+        BatDetect2FilesAnnotations,
+        AOEFAnnotations,
+    ],
+    Field(discriminator="format"),
 ]
 """Type Alias representing all supported data source configurations.
@@ -18,7 +18,7 @@ from uuid import uuid5
 from pydantic import Field
 from soundevent import data, io
 
-from batdetect2.configs import BaseConfig
+from batdetect2.core.configs import BaseConfig
 from batdetect2.data.annotations.types import AnnotatedDataset
 
 __all__ = [
@@ -33,7 +33,7 @@
 from loguru import logger
 from pydantic import Field, ValidationError
 from soundevent import data
 
-from batdetect2.configs import BaseConfig
+from batdetect2.core.configs import BaseConfig
 from batdetect2.data.annotations.legacy import (
     FileAnnotation,
     file_annotation_to_clip,
@@ -301,7 +301,8 @@ def load_batdetect2_merged_annotated_dataset(
     for ann in content:
         try:
             ann = FileAnnotation.model_validate(ann)
-        except ValueError:
+        except ValueError as err:
+            logger.warning(f"Invalid annotation file: {err}")
             continue
 
         if (
@@ -309,14 +310,17 @@ def load_batdetect2_merged_annotated_dataset(
             and dataset.filter.only_annotated
             and not ann.annotated
         ):
+            logger.debug(f"Skipping incomplete annotation {ann.id}")
             continue
 
         if dataset.filter and dataset.filter.exclude_issues and ann.issues:
+            logger.debug(f"Skipping annotation with issues {ann.id}")
             continue
 
         try:
             clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
-        except FileNotFoundError:
+        except FileNotFoundError as err:
+            logger.warning(f"Error loading annotations: {err}")
             continue
 
         annotations.append(file_annotation_to_clip_annotation(ann, clip))
@@ -8,8 +8,6 @@ from typing import Callable, List, Optional, Union
 from pydantic import BaseModel, Field
 from soundevent import data
 
-from batdetect2.targets import get_term_from_key
-
 PathLike = Union[Path, str, os.PathLike]
 
 __all__ = []
@@ -91,18 +89,9 @@ def annotation_to_sound_event(
         uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
         sound_event=sound_event,
         tags=[
-            data.Tag(
-                term=get_term_from_key(label_key),
-                value=annotation.label,
-            ),
-            data.Tag(
-                term=get_term_from_key(event_key),
-                value=annotation.event,
-            ),
-            data.Tag(
-                term=get_term_from_key(individual_key),
-                value=str(annotation.individual),
-            ),
+            data.Tag(key=label_key, value=annotation.label),
+            data.Tag(key=event_key, value=annotation.event),
+            data.Tag(key=individual_key, value=str(annotation.individual)),
         ],
     )
 
@@ -123,12 +112,7 @@ def file_annotation_to_clip(
     recording = data.Recording.from_file(
         full_path,
         time_expansion=file_annotation.time_exp,
-        tags=[
-            data.Tag(
-                term=get_term_from_key(label_key),
-                value=file_annotation.label,
-            )
-        ],
+        tags=[data.Tag(key=label_key, value=file_annotation.label)],
     )
 
     return data.Clip(
@@ -155,11 +139,7 @@ def file_annotation_to_clip_annotation(
         uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
         clip=clip,
         notes=notes,
-        tags=[
-            data.Tag(
-                term=get_term_from_key(label_key), value=file_annotation.label
-            )
-        ],
+        tags=[data.Tag(key=label_key, value=file_annotation.label)],
         sound_events=[
             annotation_to_sound_event(
                 annotation,
@@ -1,6 +1,6 @@
 from pathlib import Path
 
-from batdetect2.configs import BaseConfig
+from batdetect2.core.configs import BaseConfig
 
 __all__ = [
     "AnnotatedDataset",
287  src/batdetect2/data/conditions.py  Normal file
@@ -0,0 +1,287 @@
from collections.abc import Callable
from typing import Annotated, List, Literal, Sequence, Union

from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry

SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]

conditions: Registry[SoundEventCondition, []] = Registry("condition")


class HasTagConfig(BaseConfig):
    name: Literal["has_tag"] = "has_tag"
    tag: data.Tag


class HasTag:
    def __init__(self, tag: data.Tag):
        self.tag = tag

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        return self.tag in sound_event_annotation.tags

    @conditions.register(HasTagConfig)
    @staticmethod
    def from_config(config: HasTagConfig):
        return HasTag(tag=config.tag)


class HasAllTagsConfig(BaseConfig):
    name: Literal["has_all_tags"] = "has_all_tags"
    tags: List[data.Tag]


class HasAllTags:
    def __init__(self, tags: List[data.Tag]):
        if not tags:
            raise ValueError("Need to specify at least one tag")

        self.tags = set(tags)

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        return self.tags.issubset(sound_event_annotation.tags)

    @conditions.register(HasAllTagsConfig)
    @staticmethod
    def from_config(config: HasAllTagsConfig):
        return HasAllTags(tags=config.tags)


class HasAnyTagConfig(BaseConfig):
    name: Literal["has_any_tag"] = "has_any_tag"
    tags: List[data.Tag]


class HasAnyTag:
    def __init__(self, tags: List[data.Tag]):
        if not tags:
            raise ValueError("Need to specify at least one tag")

        self.tags = set(tags)

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        return bool(self.tags.intersection(sound_event_annotation.tags))

    @conditions.register(HasAnyTagConfig)
    @staticmethod
    def from_config(config: HasAnyTagConfig):
        return HasAnyTag(tags=config.tags)


Operator = Literal["gt", "gte", "lt", "lte", "eq"]


class DurationConfig(BaseConfig):
    name: Literal["duration"] = "duration"
    operator: Operator
    seconds: float


def _build_comparator(
    operator: Operator, value: float
) -> Callable[[float], bool]:
    if operator == "gt":
        return lambda x: x > value

    if operator == "gte":
        return lambda x: x >= value

    if operator == "lt":
        return lambda x: x < value

    if operator == "lte":
        return lambda x: x <= value

    if operator == "eq":
        return lambda x: x == value

    raise ValueError(f"Invalid operator {operator}")


class Duration:
    def __init__(self, operator: Operator, seconds: float):
        self.operator = operator
        self.seconds = seconds
        self._comparator = _build_comparator(self.operator, self.seconds)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        geometry = sound_event_annotation.sound_event.geometry

        if geometry is None:
            return False

        start_time, _, end_time, _ = compute_bounds(geometry)
        duration = end_time - start_time

        return self._comparator(duration)

    @conditions.register(DurationConfig)
    @staticmethod
    def from_config(config: DurationConfig):
        return Duration(operator=config.operator, seconds=config.seconds)


class FrequencyConfig(BaseConfig):
    name: Literal["frequency"] = "frequency"
    boundary: Literal["low", "high"]
    operator: Operator
    hertz: float


class Frequency:
    def __init__(
        self,
        operator: Operator,
        boundary: Literal["low", "high"],
        hertz: float,
    ):
        self.operator = operator
        self.hertz = hertz
        self.boundary = boundary
        self._comparator = _build_comparator(self.operator, self.hertz)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        geometry = sound_event_annotation.sound_event.geometry

        if geometry is None:
            return False

        # Automatically false if geometry does not have a frequency range
        if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
            return False

        _, low_freq, _, high_freq = compute_bounds(geometry)

        if self.boundary == "low":
            return self._comparator(low_freq)

        return self._comparator(high_freq)

    @conditions.register(FrequencyConfig)
    @staticmethod
    def from_config(config: FrequencyConfig):
        return Frequency(
            operator=config.operator,
            boundary=config.boundary,
            hertz=config.hertz,
        )


class AllOfConfig(BaseConfig):
    name: Literal["all_of"] = "all_of"
    conditions: Sequence["SoundEventConditionConfig"]


class AllOf:
    def __init__(self, conditions: List[SoundEventCondition]):
        self.conditions = conditions

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        return all(c(sound_event_annotation) for c in self.conditions)

    @conditions.register(AllOfConfig)
    @staticmethod
    def from_config(config: AllOfConfig):
        conditions = [
            build_sound_event_condition(cond) for cond in config.conditions
        ]
        return AllOf(conditions)


class AnyOfConfig(BaseConfig):
    name: Literal["any_of"] = "any_of"
    conditions: List["SoundEventConditionConfig"]


class AnyOf:
    def __init__(self, conditions: List[SoundEventCondition]):
        self.conditions = conditions

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        return any(c(sound_event_annotation) for c in self.conditions)

    @conditions.register(AnyOfConfig)
    @staticmethod
    def from_config(config: AnyOfConfig):
        conditions = [
            build_sound_event_condition(cond) for cond in config.conditions
        ]
        return AnyOf(conditions)


class NotConfig(BaseConfig):
    name: Literal["not"] = "not"
    condition: "SoundEventConditionConfig"


class Not:
    def __init__(self, condition: SoundEventCondition):
        self.condition = condition

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        return not self.condition(sound_event_annotation)

    @conditions.register(NotConfig)
    @staticmethod
    def from_config(config: NotConfig):
        condition = build_sound_event_condition(config.condition)
        return Not(condition)


SoundEventConditionConfig = Annotated[
    Union[
        HasTagConfig,
        HasAllTagsConfig,
        HasAnyTagConfig,
        DurationConfig,
        FrequencyConfig,
        AllOfConfig,
        AnyOfConfig,
        NotConfig,
    ],
    Field(discriminator="name"),
]


def build_sound_event_condition(
    config: SoundEventConditionConfig,
) -> SoundEventCondition:
    return conditions.build(config)


def filter_clip_annotation(
    clip_annotation: data.ClipAnnotation,
    condition: SoundEventCondition,
) -> data.ClipAnnotation:
    return clip_annotation.model_copy(
        update=dict(
            sound_events=[
                sound_event
                for sound_event in clip_annotation.sound_events
                if condition(sound_event)
            ]
        )
    )
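A sketch of composing these conditions from config objects (the tag key and values are illustrative):

```python
from soundevent import data

from batdetect2.data.conditions import (
    AnyOfConfig,
    DurationConfig,
    HasTagConfig,
    build_sound_event_condition,
)

# Keep a sound event if it is tagged as echolocation OR is shorter than 100 ms.
config = AnyOfConfig(
    conditions=[
        HasTagConfig(tag=data.Tag(key="event", value="Echolocation")),
        DurationConfig(operator="lt", seconds=0.1),
    ]
)
condition = build_sound_event_condition(config)
# condition(sound_event_annotation) -> bool; see filter_clip_annotation above.
```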
@@ -19,18 +19,29 @@ The core components are:
 """
 
 from pathlib import Path
-from typing import Annotated, List, Optional
+from typing import List, Optional
 
 from loguru import logger
 from pydantic import Field
 from soundevent import data, io
 
-from batdetect2.configs import BaseConfig, load_config
+from batdetect2.core.configs import BaseConfig, load_config
 from batdetect2.data.annotations import (
     AnnotatedDataset,
     AnnotationFormats,
     load_annotated_dataset,
 )
+from batdetect2.data.conditions import (
+    SoundEventConditionConfig,
+    build_sound_event_condition,
+    filter_clip_annotation,
+)
+from batdetect2.data.transforms import (
+    ApplyAll,
+    SoundEventTransformConfig,
+    build_sound_event_transform,
+    transform_clip_annotation,
+)
 from batdetect2.targets.terms import data_source
 
 __all__ = [
@@ -52,79 +63,68 @@ sources.
 
 
 class DatasetConfig(BaseConfig):
-    """Configuration model defining the structure of a BatDetect2 dataset.
-
-    This class is typically loaded from a YAML file and describes the components
-    of the dataset, including metadata and a list of data sources.
-
-    Attributes
-    ----------
-    name : str
-        A descriptive name for the dataset (e.g., "UK_Bats_Project_2024").
-    description : str
-        A longer description of the dataset's contents, origin, purpose, etc.
-    sources : List[AnnotationFormats]
-        A list defining the different data sources contributing to this
-        dataset. Each item in the list must conform to one of the Pydantic
-        models defined in the `AnnotationFormats` type union. The specific
-        model used for each source is determined by the mandatory `format`
-        field within the source's configuration, allowing BatDetect2 to use the
-        correct parser for different annotation styles.
-    """
+    """Configuration model defining the structure of a BatDetect2 dataset."""
 
     name: str
     description: str
-    sources: List[
-        Annotated[AnnotationFormats, Field(..., discriminator="format")]
-    ]
+    sources: List[AnnotationFormats]
+
+    sound_event_filter: Optional[SoundEventConditionConfig] = None
+    sound_event_transforms: List[SoundEventTransformConfig] = Field(
+        default_factory=list
+    )
 
 
 def load_dataset(
-    dataset: DatasetConfig,
+    config: DatasetConfig,
     base_dir: Optional[Path] = None,
 ) -> Dataset:
-    """Load all clip annotations from the sources defined in a DatasetConfig.
-
-    Iterates through each data source specified in the `dataset_config`,
-    delegates the loading and parsing of that source's annotations to
-    `batdetect2.data.annotations.load_annotated_dataset` (which handles
-    different data formats), and aggregates all resulting `ClipAnnotation`
-    objects into a single flat list.
-
-    Parameters
-    ----------
-    dataset_config : DatasetConfig
-        The configuration object describing the dataset and its sources.
-    base_dir : Path, optional
-        An optional base directory path. If provided, relative paths for
-        metadata files or data directories within the `dataset_config`'s
-        sources might be resolved relative to this directory. Defaults to None.
-
-    Returns
-    -------
-    Dataset (List[data.ClipAnnotation])
-        A flat list containing all loaded `ClipAnnotation` metadata objects
-        from all specified sources.
-
-    Raises
-    ------
-    Exception
-        Can raise various exceptions during the delegated loading process
-        (`load_annotated_dataset`) if files are not found, cannot be parsed
-        according to the specified format, or other I/O errors occur.
-    """
+    """Load all clip annotations from the sources defined in a DatasetConfig."""
     clip_annotations = []
-    for source in dataset.sources:
+
+    condition = (
+        build_sound_event_condition(config.sound_event_filter)
+        if config.sound_event_filter is not None
+        else None
+    )
+
+    transform = (
+        ApplyAll(
+            [
+                build_sound_event_transform(step)
+                for step in config.sound_event_transforms
+            ]
+        )
+        if config.sound_event_transforms
+        else None
+    )
+
+    for source in config.sources:
         annotated_source = load_annotated_dataset(source, base_dir=base_dir)
 
         logger.debug(
             "Loaded {num_examples} from dataset source '{source_name}'",
             num_examples=len(annotated_source.clip_annotations),
             source_name=source.name,
         )
-        clip_annotations.extend(
-            insert_source_tag(clip_annotation, source)
-            for clip_annotation in annotated_source.clip_annotations
-        )
+
+        for clip_annotation in annotated_source.clip_annotations:
+            clip_annotation = insert_source_tag(clip_annotation, source)
+
+            if condition is not None:
+                clip_annotation = filter_clip_annotation(
+                    clip_annotation,
+                    condition,
+                )
+
+            if transform is not None:
+                clip_annotation = transform_clip_annotation(
+                    clip_annotation,
+                    transform,
+                )
+
+            clip_annotations.append(clip_annotation)
 
     return clip_annotations
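A sketch of the new per-dataset hooks wired together (the source list is left empty for brevity; a real config needs at least one annotation source):

```python
from soundevent import data

from batdetect2.data.conditions import HasTagConfig
from batdetect2.data.datasets import DatasetConfig, load_dataset
from batdetect2.data.transforms import SetFrequencyBoundConfig

config = DatasetConfig(
    name="example",
    description="demo",
    sources=[],  # illustrative; real configs list annotation sources here
    sound_event_filter=HasTagConfig(
        tag=data.Tag(key="event", value="Echolocation")
    ),
    sound_event_transforms=[
        SetFrequencyBoundConfig(boundary="low", hertz=10_000)
    ],
)
clip_annotations = load_dataset(config)
```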
@@ -161,7 +161,6 @@ def insert_source_tag(
     )
 
 
-# TODO: add documentation
 def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
     return load_config(path=path, schema=DatasetConfig, field=field)
@@ -4,22 +4,14 @@ from typing import Optional, Tuple
 from soundevent import data
 
 from batdetect2.data.datasets import Dataset
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing.targets import TargetProtocol
 
 
 def iterate_over_sound_events(
     dataset: Dataset,
     targets: TargetProtocol,
-    apply_filter: bool = True,
-    apply_transform: bool = True,
-    exclude_generic: bool = True,
 ) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
-    """Iterate over sound events in a dataset, applying filtering and
-    transformations.
-
-    This generator function processes sound event annotations from a given
-    dataset, allowing for optional filtering, transformation, and exclusion of
-    unclassifiable (generic) events based on the provided target definitions.
+    """Iterate over sound events in a dataset.
 
     Parameters
     ----------
@@ -29,18 +21,6 @@ def iterate_over_sound_events(
     targets : TargetProtocol
        An object implementing the `TargetProtocol`, which provides methods
        for filtering, transforming, and encoding sound events.
-    apply_filter : bool, optional
-        If True, sound events will be filtered using `targets.filter()`.
-        Only events for which `targets.filter()` returns True will be yielded.
-        Defaults to True.
-    apply_transform : bool, optional
-        If True, sound events will be transformed using `targets.transform()`
-        before being yielded. Defaults to True.
-    exclude_generic : bool, optional
-        If True, sound events that result in a `None` class name after
-        `targets.encode()` will be excluded. This is typically used to
-        filter out events that cannot be mapped to a specific target class.
-        Defaults to True.
 
     Yields
     ------
@@ -63,17 +43,9 @@ def iterate_over_sound_events(
     """
     for clip_annotation in dataset:
         for sound_event_annotation in clip_annotation.sound_events:
-            if apply_filter:
-                if not targets.filter(sound_event_annotation):
-                    continue
-
-            if apply_transform:
-                sound_event_annotation = targets.transform(
-                    sound_event_annotation
-                )
-
-            class_name = targets.encode_class(sound_event_annotation)
-            if class_name is None and exclude_generic:
+            if not targets.filter(sound_event_annotation):
                 continue
 
             class_name = targets.encode_class(sound_event_annotation)
 
             yield class_name, sound_event_annotation
@@ -7,7 +7,7 @@ from batdetect2.data.summary import (
     extract_recordings_df,
     extract_sound_events_df,
 )
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing.targets import TargetProtocol
 
 
 def split_dataset_by_recordings(
@@ -2,7 +2,7 @@ import pandas as pd
 from soundevent.geometry import compute_bounds
 
 from batdetect2.data.datasets import Dataset
-from batdetect2.targets.types import TargetProtocol
+from batdetect2.typing.targets import TargetProtocol
 
 __all__ = [
     "extract_recordings_df",
@@ -100,8 +100,11 @@ def extract_sound_events_df(
 
         class_name = targets.encode_class(sound_event)
 
-        if class_name is None and exclude_generic:
-            continue
+        if class_name is None:
+            if exclude_generic:
+                continue
+            else:
+                class_name = targets.detection_class_name
 
         start_time, low_freq, end_time, high_freq = compute_bounds(
             sound_event.sound_event.geometry
@@ -153,7 +156,7 @@ def compute_class_summary(
     sound_events = extract_sound_events_df(
         dataset,
         targets,
-        exclude_generic=True,
+        exclude_generic=False,
         exclude_non_target=True,
     )
     recordings = extract_recordings_df(dataset)
252  src/batdetect2/data/transforms.py  Normal file
@@ -0,0 +1,252 @@
from collections.abc import Callable
from typing import Annotated, Dict, List, Literal, Optional, Union

from pydantic import Field
from soundevent import data

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.data.conditions import (
    SoundEventCondition,
    SoundEventConditionConfig,
    build_sound_event_condition,
)

SoundEventTransform = Callable[
    [data.SoundEventAnnotation],
    data.SoundEventAnnotation,
]

transforms: Registry[SoundEventTransform, []] = Registry("transform")


class SetFrequencyBoundConfig(BaseConfig):
    name: Literal["set_frequency"] = "set_frequency"
    boundary: Literal["low", "high"] = "low"
    hertz: float


class SetFrequencyBound:
    def __init__(self, hertz: float, boundary: Literal["low", "high"] = "low"):
        self.hertz = hertz
        self.boundary = boundary

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> data.SoundEventAnnotation:
        sound_event = sound_event_annotation.sound_event
        geometry = sound_event.geometry

        if geometry is None:
            return sound_event_annotation

        if not isinstance(geometry, data.BoundingBox):
            return sound_event_annotation

        start_time, low_freq, end_time, high_freq = geometry.coordinates

        if self.boundary == "low":
            low_freq = self.hertz
            high_freq = max(high_freq, low_freq)

        elif self.boundary == "high":
            high_freq = self.hertz
            low_freq = min(high_freq, low_freq)

        geometry = data.BoundingBox(
            coordinates=[start_time, low_freq, end_time, high_freq],
        )

        sound_event = sound_event.model_copy(update=dict(geometry=geometry))
        return sound_event_annotation.model_copy(
            update=dict(sound_event=sound_event)
        )

    @transforms.register(SetFrequencyBoundConfig)
    @staticmethod
    def from_config(config: SetFrequencyBoundConfig):
        return SetFrequencyBound(hertz=config.hertz, boundary=config.boundary)


class ApplyIfConfig(BaseConfig):
    name: Literal["apply_if"] = "apply_if"
    transform: "SoundEventTransformConfig"
    condition: SoundEventConditionConfig


class ApplyIf:
    def __init__(
        self,
        condition: SoundEventCondition,
        transform: SoundEventTransform,
    ):
        self.condition = condition
        self.transform = transform

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> data.SoundEventAnnotation:
        if not self.condition(sound_event_annotation):
            return sound_event_annotation

        return self.transform(sound_event_annotation)

    @transforms.register(ApplyIfConfig)
    @staticmethod
    def from_config(config: ApplyIfConfig):
        transform = build_sound_event_transform(config.transform)
        condition = build_sound_event_condition(config.condition)
        return ApplyIf(condition=condition, transform=transform)


class ReplaceTagConfig(BaseConfig):
    name: Literal["replace_tag"] = "replace_tag"
    original: data.Tag
    replacement: data.Tag


class ReplaceTag:
    def __init__(
        self,
        original: data.Tag,
        replacement: data.Tag,
    ):
        self.original = original
        self.replacement = replacement

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> data.SoundEventAnnotation:
        tags = []

        for tag in sound_event_annotation.tags:
            if tag == self.original:
                tags.append(self.replacement)
            else:
                tags.append(tag)

        return sound_event_annotation.model_copy(update=dict(tags=tags))

    @transforms.register(ReplaceTagConfig)
    @staticmethod
    def from_config(config: ReplaceTagConfig):
        return ReplaceTag(
            original=config.original, replacement=config.replacement
        )


class MapTagValueConfig(BaseConfig):
    name: Literal["map_tag_value"] = "map_tag_value"
    tag_key: str
    value_mapping: Dict[str, str]
    target_key: Optional[str] = None


class MapTagValue:
    def __init__(
        self,
        tag_key: str,
        value_mapping: Dict[str, str],
        target_key: Optional[str] = None,
    ):
        self.tag_key = tag_key
        self.value_mapping = value_mapping
        self.target_key = target_key

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> data.SoundEventAnnotation:
        tags = []

        for tag in sound_event_annotation.tags:
            if tag.key != self.tag_key:
                tags.append(tag)
                continue

            value = self.value_mapping.get(tag.value)

            if value is None:
                tags.append(tag)
                continue

            if self.target_key is None:
                tags.append(tag.model_copy(update=dict(value=value)))
            else:
                tags.append(
                    data.Tag(
                        key=self.target_key,  # type: ignore
                        value=value,
                    )
                )

        return sound_event_annotation.model_copy(update=dict(tags=tags))

    @transforms.register(MapTagValueConfig)
    @staticmethod
    def from_config(config: MapTagValueConfig):
        return MapTagValue(
            tag_key=config.tag_key,
            value_mapping=config.value_mapping,
            target_key=config.target_key,
        )


class ApplyAllConfig(BaseConfig):
    name: Literal["apply_all"] = "apply_all"
    steps: List["SoundEventTransformConfig"] = Field(default_factory=list)


class ApplyAll:
    def __init__(self, steps: List[SoundEventTransform]):
        self.steps = steps

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> data.SoundEventAnnotation:
        for step in self.steps:
            sound_event_annotation = step(sound_event_annotation)

        return sound_event_annotation

    @transforms.register(ApplyAllConfig)
    @staticmethod
    def from_config(config: ApplyAllConfig):
        steps = [build_sound_event_transform(step) for step in config.steps]
        return ApplyAll(steps)


SoundEventTransformConfig = Annotated[
    Union[
        SetFrequencyBoundConfig,
        ReplaceTagConfig,
        MapTagValueConfig,
        ApplyIfConfig,
        ApplyAllConfig,
    ],
    Field(discriminator="name"),
]


def build_sound_event_transform(
    config: SoundEventTransformConfig,
) -> SoundEventTransform:
    return transforms.build(config)


def transform_clip_annotation(
    clip_annotation: data.ClipAnnotation,
    transform: SoundEventTransform,
) -> data.ClipAnnotation:
    return clip_annotation.model_copy(
        update=dict(
            sound_events=[
                transform(sound_event)
                for sound_event in clip_annotation.sound_events
            ]
        )
    )
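A sketch of conditional tag rewriting with these transforms (keys and values are illustrative): only annotations matching the condition are modified.

```python
from soundevent import data

from batdetect2.data.conditions import HasTagConfig
from batdetect2.data.transforms import (
    ApplyIfConfig,
    MapTagValueConfig,
    build_sound_event_transform,
)

config = ApplyIfConfig(
    condition=HasTagConfig(tag=data.Tag(key="event", value="Echolocation")),
    transform=MapTagValueConfig(
        tag_key="class",
        value_mapping={"Pip 45": "Pipistrellus pipistrellus"},  # hypothetical
    ),
)
transform = build_sound_event_transform(config)
# transform(sound_event_annotation) returns a copy with remapped tag values.
```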
@@ -1,15 +1,15 @@
-from batdetect2.evaluate.config import (
-    EvaluationConfig,
-    load_evaluation_config,
-)
-from batdetect2.evaluate.match import (
-    match_predictions_and_annotations,
-    match_sound_events_and_raw_predictions,
-)
+from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
+from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate
+from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
+from batdetect2.evaluate.tasks import TaskConfig, build_task
 
 __all__ = [
     "EvaluationConfig",
+    "Evaluator",
+    "TaskConfig",
+    "build_evaluator",
+    "build_task",
+    "evaluate",
     "load_evaluation_config",
-    "match_predictions_and_annotations",
-    "match_sound_events_and_raw_predictions",
+    "DEFAULT_EVAL_DIR",
 ]
230  src/batdetect2/evaluate/affinity.py  Normal file
@@ -0,0 +1,230 @@
from typing import Annotated, Literal, Optional, Union

from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.geometry import compute_interval_overlap

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.typing.evaluate import AffinityFunction

affinity_functions: Registry[AffinityFunction, []] = Registry(
    "matching_strategy"
)


class TimeAffinityConfig(BaseConfig):
    name: Literal["time_affinity"] = "time_affinity"
    time_buffer: float = 0.01


class TimeAffinity(AffinityFunction):
    def __init__(self, time_buffer: float):
        self.time_buffer = time_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        return compute_timestamp_affinity(
            geometry1, geometry2, time_buffer=self.time_buffer
        )

    @affinity_functions.register(TimeAffinityConfig)
    @staticmethod
    def from_config(config: TimeAffinityConfig):
        return TimeAffinity(time_buffer=config.time_buffer)


def compute_timestamp_affinity(
    geometry1: data.Geometry,
    geometry2: data.Geometry,
    time_buffer: float = 0.01,
) -> float:
    assert isinstance(geometry1, data.TimeStamp)
    assert isinstance(geometry2, data.TimeStamp)

    start_time1 = geometry1.coordinates
    start_time2 = geometry2.coordinates

    a = min(start_time1, start_time2)
    b = max(start_time1, start_time2)

    if b - a >= 2 * time_buffer:
        return 0

    intersection = a - b + 2 * time_buffer
    union = b - a + 2 * time_buffer
    return intersection / union


class IntervalIOUConfig(BaseConfig):
    name: Literal["interval_iou"] = "interval_iou"
    time_buffer: float = 0.01


class IntervalIOU(AffinityFunction):
    def __init__(self, time_buffer: float):
        self.time_buffer = time_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        return compute_interval_iou(
            geometry1,
            geometry2,
            time_buffer=self.time_buffer,
        )

    @affinity_functions.register(IntervalIOUConfig)
    @staticmethod
    def from_config(config: IntervalIOUConfig):
        return IntervalIOU(time_buffer=config.time_buffer)


def compute_interval_iou(
    geometry1: data.Geometry,
    geometry2: data.Geometry,
    time_buffer: float = 0.01,
) -> float:
    assert isinstance(geometry1, data.TimeInterval)
    assert isinstance(geometry2, data.TimeInterval)

    start_time1, end_time1 = geometry1.coordinates
    start_time2, end_time2 = geometry2.coordinates

    start_time1 -= time_buffer
    start_time2 -= time_buffer
    end_time1 += time_buffer
    end_time2 += time_buffer

    intersection = compute_interval_overlap(
        (start_time1, end_time1),
        (start_time2, end_time2),
    )

    union = (
        (end_time1 - start_time1) + (end_time2 - start_time2) - intersection
    )

    if union == 0:
        return 0

    return intersection / union


class BBoxIOUConfig(BaseConfig):
    name: Literal["bbox_iou"] = "bbox_iou"
    time_buffer: float = 0.01
    freq_buffer: float = 1000


class BBoxIOU(AffinityFunction):
    def __init__(self, time_buffer: float, freq_buffer: float):
        self.time_buffer = time_buffer
        self.freq_buffer = freq_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        if not isinstance(geometry1, data.BoundingBox):
            raise TypeError(
                f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
            )

        if not isinstance(geometry2, data.BoundingBox):
            raise TypeError(
                f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
            )
        return bbox_iou(
            geometry1,
            geometry2,
            time_buffer=self.time_buffer,
            freq_buffer=self.freq_buffer,
        )

    @affinity_functions.register(BBoxIOUConfig)
    @staticmethod
    def from_config(config: BBoxIOUConfig):
        return BBoxIOU(
            time_buffer=config.time_buffer,
            freq_buffer=config.freq_buffer,
        )


def bbox_iou(
    geometry1: data.BoundingBox,
    geometry2: data.BoundingBox,
    time_buffer: float = 0.01,
    freq_buffer: float = 1000,
) -> float:
    start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
    start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates

    start_time1 -= time_buffer
    start_time2 -= time_buffer
    end_time1 += time_buffer
    end_time2 += time_buffer

    low_freq1 -= freq_buffer
    low_freq2 -= freq_buffer
    high_freq1 += freq_buffer
    high_freq2 += freq_buffer

    time_intersection = compute_interval_overlap(
        (start_time1, end_time1),
        (start_time2, end_time2),
    )

    freq_intersection = max(
        0,
        min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
    )

    intersection = time_intersection * freq_intersection

    if intersection == 0:
        return 0

    union = (
        (end_time1 - start_time1) * (high_freq1 - low_freq1)
        + (end_time2 - start_time2) * (high_freq2 - low_freq2)
        - intersection
    )

    return intersection / union


class GeometricIOUConfig(BaseConfig):
    name: Literal["geometric_iou"] = "geometric_iou"
    time_buffer: float = 0.01
    freq_buffer: float = 1000


class GeometricIOU(AffinityFunction):
    def __init__(self, time_buffer: float):
        self.time_buffer = time_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        return compute_affinity(
            geometry1,
            geometry2,
            time_buffer=self.time_buffer,
        )

    @affinity_functions.register(GeometricIOUConfig)
    @staticmethod
    def from_config(config: GeometricIOUConfig):
        return GeometricIOU(time_buffer=config.time_buffer)


AffinityConfig = Annotated[
    Union[
        TimeAffinityConfig,
        IntervalIOUConfig,
        BBoxIOUConfig,
        GeometricIOUConfig,
    ],
    Field(discriminator="name"),
]


def build_affinity_function(
    config: Optional[AffinityConfig] = None,
) -> AffinityFunction:
    config = config or GeometricIOUConfig()
    return affinity_functions.build(config)
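A worked example of `bbox_iou` with the buffers set to zero: two equal-sized boxes that overlap over half their extent give IoU = 0.5 / (1 + 1 - 0.5) = 1/3.

```python
from soundevent import data

from batdetect2.evaluate.affinity import bbox_iou

box1 = data.BoundingBox(coordinates=[0.0, 20_000, 0.01, 40_000])
box2 = data.BoundingBox(coordinates=[0.005, 20_000, 0.015, 40_000])

# Intersection: 0.005 s x 20 kHz; each box area: 0.01 s x 20 kHz.
iou = bbox_iou(box1, box2, time_buffer=0.0, freq_buffer=0.0)
assert abs(iou - 1 / 3) < 1e-9
```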
@@ -1,10 +1,15 @@
-from typing import Optional
+from typing import List, Optional
 
 from pydantic import Field
 from soundevent import data
 
-from batdetect2.configs import BaseConfig, load_config
-from batdetect2.evaluate.match import MatchConfig
+from batdetect2.core.configs import BaseConfig, load_config
+from batdetect2.evaluate.tasks import (
+    TaskConfig,
+)
+from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
+from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
+from batdetect2.logging import CSVLoggerConfig, LoggerConfig
 
 __all__ = [
     "EvaluationConfig",
@@ -13,7 +18,13 @@ __all__ = [
 
 
 class EvaluationConfig(BaseConfig):
-    match: MatchConfig = Field(default_factory=MatchConfig)
+    tasks: List[TaskConfig] = Field(
+        default_factory=lambda: [
+            DetectionTaskConfig(),
+            ClassificationTaskConfig(),
+        ]
+    )
+    logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
 
 
 def load_evaluation_config(
144  src/batdetect2/evaluate/dataset.py  Normal file
@@ -0,0 +1,144 @@
from typing import List, NamedTuple, Optional, Sequence

import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader, Dataset

from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import (
    AudioLoader,
    ClipperProtocol,
    PreprocessorProtocol,
)

__all__ = [
    "TestDataset",
    "build_test_dataset",
    "build_test_loader",
]


class TestExample(NamedTuple):
    spec: torch.Tensor
    idx: torch.Tensor
    start_time: torch.Tensor
    end_time: torch.Tensor


class TestDataset(Dataset[TestExample]):
    clip_annotations: List[data.ClipAnnotation]

    def __init__(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        clipper: Optional[ClipperProtocol] = None,
        audio_dir: Optional[data.PathLike] = None,
    ):
        self.clip_annotations = list(clip_annotations)
        self.clipper = clipper
        self.preprocessor = preprocessor
        self.audio_loader = audio_loader
        self.audio_dir = audio_dir

    def __len__(self):
        return len(self.clip_annotations)

    def __getitem__(self, idx: int) -> TestExample:
        clip_annotation = self.clip_annotations[idx]

        if self.clipper is not None:
            clip_annotation = self.clipper(clip_annotation)

        clip = clip_annotation.clip
        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
        wav_tensor = torch.tensor(wav).unsqueeze(0)
        spectrogram = self.preprocessor(wav_tensor)
        return TestExample(
            spec=spectrogram,
            idx=torch.tensor(idx),
            start_time=torch.tensor(clip.start_time),
            end_time=torch.tensor(clip.end_time),
        )


class TestLoaderConfig(BaseConfig):
    num_workers: int = 0
    clipping_strategy: ClipConfig = Field(
        default_factory=lambda: PaddedClipConfig()
    )


def build_test_loader(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[TestLoaderConfig] = None,
    num_workers: Optional[int] = None,
) -> DataLoader[TestExample]:
    logger.info("Building test data loader...")
    config = config or TestLoaderConfig()
    logger.opt(lazy=True).debug(
        "Test data loader config: \n{config}",
        config=lambda: config.to_yaml_string(exclude_none=True),
    )

    test_dataset = build_test_dataset(
        clip_annotations,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        config=config,
    )

    num_workers = num_workers or config.num_workers
    return DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=_collate_fn,
    )


def build_test_dataset(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[TestLoaderConfig] = None,
) -> TestDataset:
    logger.info("Building test dataset...")
    config = config or TestLoaderConfig()

    clipper = build_clipper(config=config.clipping_strategy)

    if audio_loader is None:
        audio_loader = build_audio_loader()

    if preprocessor is None:
        preprocessor = build_preprocessor()

    return TestDataset(
        clip_annotations,
        audio_loader=audio_loader,
        clipper=clipper,
        preprocessor=preprocessor,
    )


def _collate_fn(batch: List[TestExample]) -> TestExample:
    max_width = max(item.spec.shape[-1] for item in batch)
    return TestExample(
        spec=torch.stack(
            [adjust_width(item.spec, max_width) for item in batch]
        ),
        idx=torch.stack([item.idx for item in batch]),
        start_time=torch.stack([item.start_time for item in batch]),
        end_time=torch.stack([item.end_time for item in batch]),
    )
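A sketch of building the loader on its own (`clip_annotations` is assumed to come from an earlier `load_dataset` call): the batch size is fixed at 1, and the collate function pads spectrograms to a common width before stacking.

```python
from batdetect2.evaluate.dataset import build_test_loader

# Assumes `clip_annotations` was loaded earlier, e.g. via load_dataset(...).
loader = build_test_loader(clip_annotations, num_workers=2)

for batch in loader:
    print(batch.spec.shape)  # (batch, channels, freq_bins, time_bins)
```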
69  src/batdetect2/evaluate/evaluate.py  Normal file
@@ -0,0 +1,69 @@
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Sequence

from lightning import Trainer
from soundevent import data

from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger
from batdetect2.models import Model
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets

if TYPE_CHECKING:
    from batdetect2.config import BatDetect2Config
    from batdetect2.typing import (
        AudioLoader,
        PreprocessorProtocol,
        TargetProtocol,
    )

DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"


def evaluate(
    model: Model,
    test_annotations: Sequence[data.ClipAnnotation],
    targets: Optional["TargetProtocol"] = None,
    audio_loader: Optional["AudioLoader"] = None,
    preprocessor: Optional["PreprocessorProtocol"] = None,
    config: Optional["BatDetect2Config"] = None,
    num_workers: Optional[int] = None,
    output_dir: data.PathLike = DEFAULT_EVAL_DIR,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
):
    from batdetect2.config import BatDetect2Config

    config = config or BatDetect2Config()

    audio_loader = audio_loader or build_audio_loader(config=config.audio)

    preprocessor = preprocessor or build_preprocessor(
        config=config.preprocess,
        input_samplerate=audio_loader.samplerate,
    )

    targets = targets or build_targets(config=config.targets)

    loader = build_test_loader(
        test_annotations,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        num_workers=num_workers,
    )

    evaluator = build_evaluator(config=config.evaluation, targets=targets)

    logger = build_logger(
        config.evaluation.logger,
        log_dir=Path(output_dir),
        experiment_name=experiment_name,
        run_name=run_name,
    )
    module = EvaluationModule(model, evaluator)
    trainer = Trainer(logger=logger, enable_checkpointing=False)
    return trainer.test(module, loader)
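A sketch of the high-level entry point (`model` and `test_annotations` are assumed to exist): metrics are logged through the configured logger under the given experiment name.

```python
from batdetect2.evaluate.evaluate import evaluate

# `model` is a trained batdetect2.models.Model and `test_annotations` a
# sequence of held-out ClipAnnotation objects (both assumed here).
results = evaluate(
    model,
    test_annotations,
    output_dir="outputs/evaluations",
    experiment_name="demo",
)
```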
67  src/batdetect2/evaluate/evaluator.py  Normal file
@@ -0,0 +1,67 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

from matplotlib.figure import Figure
from soundevent import data

from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol

__all__ = [
    "Evaluator",
    "build_evaluator",
]


class Evaluator:
    def __init__(
        self,
        targets: TargetProtocol,
        tasks: Sequence[EvaluatorProtocol],
    ):
        self.targets = targets
        self.tasks = tasks

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[Any]:
        return [
            task.evaluate(clip_annotations, predictions) for task in self.tasks
        ]

    def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
        results = {}

        for task, outputs in zip(self.tasks, eval_outputs):
            results.update(task.compute_metrics(outputs))

        return results

    def generate_plots(
        self,
        eval_outputs: List[Any],
    ) -> Iterable[Tuple[str, Figure]]:
        for task, outputs in zip(self.tasks, eval_outputs):
            for name, fig in task.generate_plots(outputs):
                yield name, fig


def build_evaluator(
    config: Optional[Union[EvaluationConfig, dict]] = None,
    targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
    targets = targets or build_targets()

    if config is None:
        config = EvaluationConfig()

    if not isinstance(config, EvaluationConfig):
        config = EvaluationConfig.model_validate(config)

    return Evaluator(
        targets=targets,
        tasks=[build_task(task, targets=targets) for task in config.tasks],
    )
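A sketch of driving the evaluator directly, outside Lightning (`clip_annotations` and `predictions` are assumed): each task evaluates independently and the metric dictionaries are merged.

```python
from batdetect2.evaluate.evaluator import build_evaluator

evaluator = build_evaluator()  # default detection + classification tasks

# `predictions` is a Sequence[Sequence[RawPrediction]], one list per clip
# (assumed to come from the model's postprocessor).
outputs = evaluator.evaluate(clip_annotations, predictions)
metrics = evaluator.compute_metrics(outputs)
```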
82  src/batdetect2/evaluate/lightning.py  Normal file
@@ -0,0 +1,82 @@
from typing import Any, List

from lightning import LightningModule
from soundevent import data
from torch.utils.data import DataLoader

from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import RawPrediction


class EvaluationModule(LightningModule):
    def __init__(
        self,
        model: Model,
        evaluator: EvaluatorProtocol,
    ):
        super().__init__()

        self.model = model
        self.evaluator = evaluator

        self.clip_annotations: List[data.ClipAnnotation] = []
        self.predictions: List[List[RawPrediction]] = []

    def test_step(self, batch: TestExample, batch_idx: int):
        dataset = self.get_dataset()
        clip_annotations = [
            dataset.clip_annotations[int(example_idx)]
            for example_idx in batch.idx
        ]

        outputs = self.model.detector(batch.spec)
        clip_detections = self.model.postprocessor(
            outputs,
            start_times=[ca.clip.start_time for ca in clip_annotations],
        )
        predictions = [
            to_raw_predictions(
                clip_dets.numpy(),
                targets=self.evaluator.targets,
            )
            for clip_dets in clip_detections
        ]

        self.clip_annotations.extend(clip_annotations)
        self.predictions.extend(predictions)

    def on_test_epoch_start(self):
        self.clip_annotations = []
        self.predictions = []

    def on_test_epoch_end(self):
        clip_evals = self.evaluator.evaluate(
            self.clip_annotations,
            self.predictions,
        )
        self.log_metrics(clip_evals)
        self.generate_plots(clip_evals)

    def generate_plots(self, evaluated_clips: Any):
        plotter = get_image_logger(self.logger)  # type: ignore

        if plotter is None:
            return

        for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
            plotter(figure_name, fig, self.global_step)

    def log_metrics(self, evaluated_clips: Any):
        metrics = self.evaluator.compute_metrics(evaluated_clips)
        self.log_dict(metrics)

    def get_dataset(self) -> TestDataset:
        dataloaders = self.trainer.test_dataloaders
        assert isinstance(dataloaders, DataLoader)
        dataset = dataloaders.dataset
        assert isinstance(dataset, TestDataset)
        return dataset
@ -1,50 +1,211 @@
from collections.abc import Callable, Iterable, Mapping
from typing import List, Literal, Optional, Tuple
from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union

import numpy as np
from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.evaluation import (
    match_geometries as optimal_match,
)
from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds

from batdetect2.configs import BaseConfig
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol

MatchingStrategy = Literal["greedy", "optimal"]
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.affinity import (
    AffinityConfig,
    GeometricIOUConfig,
    build_affinity_function,
)
from batdetect2.targets import build_targets
from batdetect2.typing import (
    AffinityFunction,
    MatcherProtocol,
    MatchEvaluation,
    RawPrediction,
    TargetProtocol,
)
from batdetect2.typing.evaluate import ClipMatches

MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching."""

matching_strategies = Registry("matching_strategy")


class MatchConfig(BaseConfig):
    """Configuration for matching geometries.

    Attributes
    ----------
    strategy : MatchingStrategy, default="greedy"
        The matching algorithm to use. 'greedy' prioritizes high-confidence
        predictions, while 'optimal' finds the globally best set of matches.
    geometry : MatchingGeometry, default="timestamp"
        The geometric representation to use when computing affinity.
    affinity_threshold : float, default=0.0
        The minimum affinity score (e.g., IoU) required for a valid match.
    time_buffer : float, default=0.005
        Time tolerance in seconds used in affinity calculations.
    frequency_buffer : float, default=1000
        Frequency tolerance in Hertz used in affinity calculations.
    """
def match(
    sound_event_annotations: Sequence[data.SoundEventAnnotation],
    raw_predictions: Sequence[RawPrediction],
    clip: data.Clip,
    scores: Optional[Sequence[float]] = None,
    targets: Optional[TargetProtocol] = None,
    matcher: Optional[MatcherProtocol] = None,
) -> ClipMatches:
    if matcher is None:
        matcher = build_matcher()

    strategy: MatchingStrategy = "greedy"
    geometry: MatchingGeometry = "timestamp"
    affinity_threshold: float = 0.0
    time_buffer: float = 0.005
    frequency_buffer: float = 1_000
    if targets is None:
        targets = build_targets()

    target_geometries: List[data.Geometry] = [  # type: ignore
        sound_event_annotation.sound_event.geometry
        for sound_event_annotation in sound_event_annotations
    ]

    predicted_geometries = [
        raw_prediction.geometry for raw_prediction in raw_predictions
    ]

    if scores is None:
        scores = [
            raw_prediction.detection_score
            for raw_prediction in raw_predictions
        ]

    matches = []

    for source_idx, target_idx, affinity in matcher(
        ground_truth=target_geometries,
        predictions=predicted_geometries,
        scores=scores,
    ):
        target = (
            sound_event_annotations[target_idx]
            if target_idx is not None
            else None
        )
        prediction = (
            raw_predictions[source_idx] if source_idx is not None else None
        )

        gt_det = target_idx is not None
        gt_class = targets.encode_class(target) if target is not None else None
        gt_geometry = (
            target_geometries[target_idx] if target_idx is not None else None
        )

        pred_score = float(prediction.detection_score) if prediction else 0
        pred_geometry = (
            predicted_geometries[source_idx]
            if source_idx is not None
            else None
        )

        class_scores = (
            {
                class_name: score
                for class_name, score in zip(
                    targets.class_names,
                    prediction.class_scores,
                )
            }
            if prediction is not None
            else {}
        )

        matches.append(
            MatchEvaluation(
                clip=clip,
                sound_event_annotation=target,
                gt_det=gt_det,
                gt_class=gt_class,
                gt_geometry=gt_geometry,
                pred_score=pred_score,
                pred_class_scores=class_scores,
                pred_geometry=pred_geometry,
                affinity=affinity,
            )
        )

    return ClipMatches(clip=clip, matches=matches)


class StartTimeMatchConfig(BaseConfig):
    name: Literal["start_time_match"] = "start_time_match"
    distance_threshold: float = 0.01


class StartTimeMatcher(MatcherProtocol):
    def __init__(self, distance_threshold: float):
        self.distance_threshold = distance_threshold

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        return match_start_times(
            ground_truth,
            predictions,
            scores,
            distance_threshold=self.distance_threshold,
        )

    @matching_strategies.register(StartTimeMatchConfig)
    @staticmethod
    def from_config(config: StartTimeMatchConfig):
        return StartTimeMatcher(distance_threshold=config.distance_threshold)


def match_start_times(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    scores: Sequence[float],
    distance_threshold: float = 0.01,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    if not ground_truth:
        for index in range(len(predictions)):
            yield index, None, 0

        return

    if not predictions:
        for index in range(len(ground_truth)):
            yield None, index, 0

        return

    gt_times = np.array([compute_bounds(geom)[0] for geom in ground_truth])
    pred_times = np.array([compute_bounds(geom)[0] for geom in predictions])

    scores = np.array(scores)
    sort_args = np.argsort(scores)[::-1]

    distances = np.abs(gt_times[None, :] - pred_times[:, None])
    closests = np.argmin(distances, axis=-1)

    unmatched_gt = set(range(len(gt_times)))

    for pred_index in sort_args:
        # Get the closest ground truth
        gt_closest_index = closests[pred_index]

        if gt_closest_index not in unmatched_gt:
            # Does not match if closest has been assigned
            yield pred_index, None, 0
            continue

        # Get the actual distance
        distance = distances[pred_index, gt_closest_index]

        if distance > distance_threshold:
            # Does not match if too far from closest
            yield pred_index, None, 0
            continue

        # Return affinity value: linear interpolation between 0 and 1, where
        # a distance at the threshold maps to 0 affinity and a zero distance
        # maps to 1.
        affinity = np.interp(
            distance,
            [0, distance_threshold],
            [1, 0],
            left=1,
            right=0,
        )
        unmatched_gt.remove(gt_closest_index)
        yield pred_index, gt_closest_index, affinity

    for missing_index in unmatched_gt:
        yield None, missing_index, 0
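A toy run of `match_start_times`, to make the yielded `(pred_idx, gt_idx, affinity)` triples concrete. The bounding boxes are illustrative and use soundevent's `coordinates=[start_time, low_freq, end_time, high_freq]` convention.

```python
from soundevent import data

gt = [data.BoundingBox(coordinates=[0.100, 20_000, 0.120, 60_000])]
preds = [
    data.BoundingBox(coordinates=[0.105, 21_000, 0.125, 59_000]),  # 5 ms off
    data.BoundingBox(coordinates=[0.500, 20_000, 0.520, 60_000]),  # far away
]

list(match_start_times(gt, preds, scores=[0.9, 0.8], distance_threshold=0.01))
# roughly [(0, 0, 0.5), (1, None, 0)]: 5 ms at a 10 ms threshold interpolates
# to an affinity of 0.5, and the distant prediction is a false positive.
```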
def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
@ -73,45 +234,58 @@ _geometry_cast_functions: Mapping[
}


def match_geometries(
    source: List[data.Geometry],
    target: List[data.Geometry],
    config: MatchConfig,
    scores: Optional[List[float]] = None,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    geometry_cast = _geometry_cast_functions[config.geometry]

    if config.strategy == "optimal":
        return optimal_match(
            source=[geometry_cast(geom) for geom in source],
            target=[geometry_cast(geom) for geom in target],
            time_buffer=config.time_buffer,
            freq_buffer=config.frequency_buffer,
            affinity_threshold=config.affinity_threshold,
        )

    if config.strategy == "greedy":
        return greedy_match(
            source=[geometry_cast(geom) for geom in source],
            target=[geometry_cast(geom) for geom in target],
            time_buffer=config.time_buffer,
            freq_buffer=config.frequency_buffer,
            affinity_threshold=config.affinity_threshold,
            scores=scores,
        )

    raise NotImplementedError(
        f"Matching strategy not implemented {config.strategy}"
class GreedyMatchConfig(BaseConfig):
    name: Literal["greedy_match"] = "greedy_match"
    geometry: MatchingGeometry = "timestamp"
    affinity_threshold: float = 0.5
    affinity_function: AffinityConfig = Field(
        default_factory=GeometricIOUConfig
    )


class GreedyMatcher(MatcherProtocol):
    def __init__(
        self,
        geometry: MatchingGeometry,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
    ):
        self.geometry = geometry
        self.affinity_function = affinity_function
        self.affinity_threshold = affinity_threshold
        self.cast_geometry = _geometry_cast_functions[self.geometry]

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        return greedy_match(
            ground_truth=[self.cast_geometry(geom) for geom in ground_truth],
            predictions=[self.cast_geometry(geom) for geom in predictions],
            scores=scores,
            affinity_function=self.affinity_function,
            affinity_threshold=self.affinity_threshold,
        )

    @matching_strategies.register(GreedyMatchConfig)
    @staticmethod
    def from_config(config: GreedyMatchConfig):
        affinity_function = build_affinity_function(config.affinity_function)
        return GreedyMatcher(
            geometry=config.geometry,
            affinity_threshold=config.affinity_threshold,
            affinity_function=affinity_function,
        )


def greedy_match(
    source: List[data.Geometry],
    target: List[data.Geometry],
    scores: Optional[List[float]] = None,
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    scores: Sequence[float],
    affinity_threshold: float = 0.5,
    time_buffer: float = 0.001,
    freq_buffer: float = 1000,
    affinity_function: AffinityFunction = compute_affinity,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Performs a greedy, one-to-one matching of source to target geometries.

@ -129,10 +303,6 @@ def greedy_match(
        Confidence scores for each source geometry for prioritization.
    affinity_threshold
        The minimum affinity score required for a valid match.
    time_buffer
        Time tolerance in seconds for affinity calculation.
    freq_buffer
        Frequency tolerance in Hertz for affinity calculation.

    Yields
    ------
@ -143,37 +313,29 @@ def greedy_match(
    - Unmatched Source (False Positive): `(source_idx, None, 0)`
    - Unmatched Target (False Negative): `(None, target_idx, 0)`
    """
    assigned = set()
    unassigned_gt = set(range(len(ground_truth)))

    if not source:
        for target_idx in range(len(target)):
            yield None, target_idx, 0
    if not predictions:
        for gt_idx in range(len(ground_truth)):
            yield None, gt_idx, 0

        return

    if not target:
        for source_idx in range(len(source)):
            yield source_idx, None, 0
    if not ground_truth:
        for pred_idx in range(len(predictions)):
            yield pred_idx, None, 0

        return

    if scores is None:
        indices = np.arange(len(source))
    else:
        indices = np.argsort(scores)[::-1]
    indices = np.argsort(scores)[::-1]

    for source_idx in indices:
        source_geometry = source[source_idx]
    for pred_idx in indices:
        source_geometry = predictions[pred_idx]

        affinities = np.array(
            [
                compute_affinity(
                    source_geometry,
                    target_geometry,
                    time_buffer=time_buffer,
                    freq_buffer=freq_buffer,
                )
                for target_geometry in target
                affinity_function(source_geometry, target_geometry)
                for target_geometry in ground_truth
            ]
        )

@ -181,162 +343,72 @@ def greedy_match(
        affinity = affinities[closest_target]

        if affinities[closest_target] <= affinity_threshold:
            yield source_idx, None, 0
            yield pred_idx, None, 0
            continue

        if closest_target in assigned:
            yield source_idx, None, 0
        if closest_target not in unassigned_gt:
            yield pred_idx, None, 0
            continue

        assigned.add(closest_target)
        yield source_idx, closest_target, affinity
        unassigned_gt.remove(closest_target)
        yield pred_idx, closest_target, affinity

    missed_ground_truth = set(range(len(target))) - assigned
    for target_idx in missed_ground_truth:
        yield None, target_idx, 0
    for gt_idx in unassigned_gt:
        yield None, gt_idx, 0
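A matching sketch for the greedy strategy above; the geometries are again illustrative, and the default `affinity_function` (soundevent's `compute_affinity`) is used.

```python
gt = [data.BoundingBox(coordinates=[0.10, 20_000, 0.12, 60_000])]
preds = [
    data.BoundingBox(coordinates=[0.10, 22_000, 0.12, 58_000]),  # overlaps gt
    data.BoundingBox(coordinates=[0.30, 20_000, 0.32, 60_000]),  # no overlap
]

list(greedy_match(ground_truth=gt, predictions=preds, scores=[0.9, 0.4]))
# The top-scoring prediction claims the overlapping ground truth; the other
# yields (1, None, 0), and leftover ground truth would yield (None, gt_idx, 0).
```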
def match_sound_events_and_raw_predictions(
    clip_annotation: data.ClipAnnotation,
    raw_predictions: List[BatDetect2Prediction],
    targets: TargetProtocol,
    config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
    config = config or MatchConfig()
class OptimalMatchConfig(BaseConfig):
    name: Literal["optimal_match"] = "optimal_match"
    affinity_threshold: float = 0.5
    time_buffer: float = 0.005
    frequency_buffer: float = 1_000

    target_sound_events = [
        targets.transform(sound_event_annotation)
        for sound_event_annotation in clip_annotation.sound_events
        if targets.filter(sound_event_annotation)
        and sound_event_annotation.sound_event.geometry is not None
    ]

    target_geometries: List[data.Geometry] = [  # type: ignore
        sound_event_annotation.sound_event.geometry
        for sound_event_annotation in target_sound_events
        if sound_event_annotation.sound_event.geometry is not None
    ]

    predicted_geometries = [
        raw_prediction.raw.geometry for raw_prediction in raw_predictions
    ]

    scores = [
        raw_prediction.raw.detection_score
        for raw_prediction in raw_predictions
    ]

    matches = []

    for source_idx, target_idx, affinity in match_geometries(
        source=predicted_geometries,
        target=target_geometries,
        config=config,
        scores=scores,
class OptimalMatcher(MatcherProtocol):
    def __init__(
        self,
        affinity_threshold: float,
        time_buffer: float,
        frequency_buffer: float,
    ):
        target = (
            target_sound_events[target_idx] if target_idx is not None else None
        )
        prediction = (
            raw_predictions[source_idx] if source_idx is not None else None
        )
        self.affinity_threshold = affinity_threshold
        self.time_buffer = time_buffer
        self.frequency_buffer = frequency_buffer

        gt_det = target is not None
        gt_class = targets.encode_class(target) if target is not None else None

        pred_score = float(prediction.raw.detection_score) if prediction else 0

        class_scores = (
            {
                str(class_name): float(score)
                for class_name, score in zip(
                    targets.class_names,
                    prediction.raw.class_scores,
                )
            }
            if prediction is not None
            else {}
        )

        matches.append(
            MatchEvaluation(
                match=data.Match(
                    source=None
                    if prediction is None
                    else prediction.sound_event_prediction,
                    target=target,
                    affinity=affinity,
                ),
                gt_det=gt_det,
                gt_class=gt_class,
                pred_score=pred_score,
                pred_class_scores=class_scores,
            )
        )

    return matches


def match_predictions_and_annotations(
    clip_annotation: data.ClipAnnotation,
    clip_prediction: data.ClipPrediction,
    config: Optional[MatchConfig] = None,
) -> List[data.Match]:
    config = config or MatchConfig()

    annotated_sound_events = [
        sound_event_annotation
        for sound_event_annotation in clip_annotation.sound_events
        if sound_event_annotation.sound_event.geometry is not None
    ]

    predicted_sound_events = [
        sound_event_prediction
        for sound_event_prediction in clip_prediction.sound_events
        if sound_event_prediction.sound_event.geometry is not None
    ]

    annotated_geometries: List[data.Geometry] = [
        sound_event.sound_event.geometry
        for sound_event in annotated_sound_events
        if sound_event.sound_event.geometry is not None
    ]

    predicted_geometries: List[data.Geometry] = [
        sound_event.sound_event.geometry
        for sound_event in predicted_sound_events
        if sound_event.sound_event.geometry is not None
    ]

    scores = [
        sound_event.score
        for sound_event in predicted_sound_events
        if sound_event.sound_event.geometry is not None
    ]

    matches = []
    for source_idx, target_idx, affinity in match_geometries(
        source=predicted_geometries,
        target=annotated_geometries,
        config=config,
        scores=scores,
    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        target = (
            annotated_sound_events[target_idx]
            if target_idx is not None
            else None
        )
        source = (
            predicted_sound_events[source_idx]
            if source_idx is not None
            else None
        )
        matches.append(
            data.Match(
                source=source,
                target=target,
                affinity=affinity,
            )
        return optimal_match(
            source=predictions,
            target=ground_truth,
            time_buffer=self.time_buffer,
            freq_buffer=self.frequency_buffer,
            affinity_threshold=self.affinity_threshold,
        )

    return matches
    @matching_strategies.register(OptimalMatchConfig)
    @staticmethod
    def from_config(config: OptimalMatchConfig):
        return OptimalMatcher(
            affinity_threshold=config.affinity_threshold,
            time_buffer=config.time_buffer,
            frequency_buffer=config.frequency_buffer,
        )


MatchConfig = Annotated[
    Union[
        GreedyMatchConfig,
        StartTimeMatchConfig,
        OptimalMatchConfig,
    ],
    Field(discriminator="name"),
]


def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
    config = config or StartTimeMatchConfig()
    return matching_strategies.build(config)
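Since `MatchConfig` is now a discriminated union, strategy selection is config-driven. A minimal sketch, using only configs defined above:

```python
# Hedged sketch: pick a matcher via the discriminated union.
matcher = build_matcher(StartTimeMatchConfig(distance_threshold=0.005))
```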
@ -1,97 +0,0 @@
from typing import Dict, List

import pandas as pd
from sklearn import metrics
from sklearn.preprocessing import label_binarize

from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol

__all__ = ["DetectionAveragePrecision"]


class DetectionAveragePrecision(MetricsProtocol):
    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
        y_true, y_score = zip(
            *[(match.gt_det, match.pred_score) for match in matches]
        )
        score = float(metrics.average_precision_score(y_true, y_score))
        return {"detection_AP": score}


class ClassificationMeanAveragePrecision(MetricsProtocol):
    def __init__(self, class_names: List[str], per_class: bool = True):
        self.class_names = class_names
        self.per_class = per_class

    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
        y_true = label_binarize(
            [
                match.gt_class if match.gt_class is not None else "__NONE__"
                for match in matches
            ],
            classes=self.class_names,
        )
        y_pred = pd.DataFrame(
            [
                {
                    name: match.pred_class_scores.get(name, 0)
                    for name in self.class_names
                }
                for match in matches
            ]
        ).fillna(0)
        mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])

        ret = {
            "classification_mAP": float(mAP),
        }

        if not self.per_class:
            return ret

        for class_index, class_name in enumerate(self.class_names):
            y_true_class = y_true[:, class_index]
            y_pred_class = y_pred[class_name]
            class_ap = metrics.average_precision_score(
                y_true_class,
                y_pred_class,
            )
            ret[f"classification_AP/{class_name}"] = float(class_ap)

        return ret


class ClassificationAccuracy(MetricsProtocol):
    def __init__(self, class_names: List[str]):
        self.class_names = class_names

    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
        y_true = [
            match.gt_class if match.gt_class is not None else "__NONE__"
            for match in matches
        ]

        y_pred = pd.DataFrame(
            [
                {
                    name: match.pred_class_scores.get(name, 0)
                    for name in self.class_names
                }
                for match in matches
            ]
        ).fillna(0)
        y_pred = y_pred.apply(
            lambda row: row.idxmax()
            if row.max() >= (1 - row.sum())
            else "__NONE__",
            axis=1,
        )

        accuracy = metrics.balanced_accuracy_score(
            y_true,
            y_pred,
        )

        return {
            "classification_acc": float(accuracy),
        }
267 src/batdetect2/evaluate/metrics/classification.py Normal file
@ -0,0 +1,267 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import (
    Annotated,
    Callable,
    Dict,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Union,
)

import numpy as np
from pydantic import Field
from sklearn import metrics
from soundevent import data

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction, TargetProtocol

__all__ = [
    "ClassificationMetric",
    "ClassificationMetricConfig",
    "build_classification_metric",
]


@dataclass
class MatchEval:
    clip: data.Clip
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]

    is_prediction: bool
    is_ground_truth: bool
    is_generic: bool
    true_class: Optional[str]
    score: float


@dataclass
class ClipEval:
    clip: data.Clip
    matches: Mapping[str, List[MatchEval]]


ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]


classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
    Registry("classification_metric")
)


class BaseClassificationConfig(BaseConfig):
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


class BaseClassificationMetric:
    def __init__(
        self,
        targets: TargetProtocol,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        self.targets = targets
        self.include = include
        self.exclude = exclude

    def include_class(self, class_name: str) -> bool:
        if self.include is not None:
            return class_name in self.include

        if self.exclude is not None:
            return class_name not in self.exclude

        return True


class ClassificationAveragePrecisionConfig(BaseClassificationConfig):
    name: Literal["average_precision"] = "average_precision"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    label: str = "average_precision"


class ClassificationAveragePrecision(BaseClassificationMetric):
    def __init__(
        self,
        targets: TargetProtocol,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        label: str = "average_precision",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        super().__init__(include=include, exclude=exclude, targets=targets)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.label = label

    def __call__(
        self, clip_evaluations: Sequence[ClipEval]
    ) -> Dict[str, float]:
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        class_scores = {
            class_name: average_precision(
                y_true[class_name],
                y_score[class_name],
                num_positives=num_positives[class_name],
            )
            for class_name in self.targets.class_names
        }

        mean_score = float(
            np.mean([v for v in class_scores.values() if not np.isnan(v)])
        )

        return {
            f"mean_{self.label}": mean_score,
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if self.include_class(class_name)
            },
        }

    @classification_metrics.register(ClassificationAveragePrecisionConfig)
    @staticmethod
    def from_config(
        config: ClassificationAveragePrecisionConfig,
        targets: TargetProtocol,
    ):
        return ClassificationAveragePrecision(
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            label=config.label,
            include=config.include,
            exclude=config.exclude,
        )


class ClassificationROCAUCConfig(BaseClassificationConfig):
    name: Literal["roc_auc"] = "roc_auc"
    label: str = "roc_auc"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ClassificationROCAUC(BaseClassificationMetric):
    def __init__(
        self,
        targets: TargetProtocol,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        label: str = "roc_auc",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        self.targets = targets
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.label = label
        self.include = include
        self.exclude = exclude

    def __call__(
        self, clip_evaluations: Sequence[ClipEval]
    ) -> Dict[str, float]:
        y_true, y_score, _ = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        class_scores = {
            class_name: float(
                metrics.roc_auc_score(
                    y_true[class_name],
                    y_score[class_name],
                )
            )
            for class_name in self.targets.class_names
        }

        mean_score = float(
            np.mean([v for v in class_scores.values() if not np.isnan(v)])
        )

        return {
            f"mean_{self.label}": mean_score,
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if self.include_class(class_name)
            },
        }

    @classification_metrics.register(ClassificationROCAUCConfig)
    @staticmethod
    def from_config(
        config: ClassificationROCAUCConfig, targets: TargetProtocol
    ):
        return ClassificationROCAUC(
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            label=config.label,
            include=config.include,
            exclude=config.exclude,
        )


ClassificationMetricConfig = Annotated[
    Union[
        ClassificationAveragePrecisionConfig,
        ClassificationROCAUCConfig,
    ],
    Field(discriminator="name"),
]


def build_classification_metric(
    config: ClassificationMetricConfig,
    targets: TargetProtocol,
) -> ClassificationMetric:
    return classification_metrics.build(config, targets)


def _extract_per_class_metric_data(
    clip_evaluations: Sequence[ClipEval],
    ignore_non_predictions: bool = True,
    ignore_generic: bool = True,
):
    y_true = defaultdict(list)
    y_score = defaultdict(list)
    num_positives = defaultdict(lambda: 0)

    for clip_eval in clip_evaluations:
        for class_name, matches in clip_eval.matches.items():
            for m in matches:
                # Exclude matches with ground truth sounds where the class
                # is unknown
                if m.is_generic and ignore_generic:
                    continue

                is_class = m.true_class == class_name

                if is_class:
                    num_positives[class_name] += 1

                # Ignore matches that don't correspond to a prediction
                if not m.is_prediction and ignore_non_predictions:
                    continue

                y_true[class_name].append(is_class)
                y_score[class_name].append(m.score)

    return y_true, y_score, num_positives
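A short sketch of the intended flow through the registry (hedged; `my_targets` and `clip_evaluations` are assumed to exist):

```python
config = ClassificationAveragePrecisionConfig(exclude=["noise"])
metric = build_classification_metric(config, targets=my_targets)
scores = metric(clip_evaluations)
# {"mean_average_precision": ..., "average_precision/<class>": ...}
```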
135 src/batdetect2/evaluate/metrics/clip_classification.py Normal file
@ -0,0 +1,135 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union

import numpy as np
from pydantic import Field
from sklearn import metrics

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.metrics.common import average_precision


@dataclass
class ClipEval:
    true_classes: Set[str]
    class_scores: Dict[str, float]


ClipClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]

clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
    "clip_classification_metric"
)


class ClipClassificationAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"


class ClipClassificationAveragePrecision:
    def __init__(self, label: str = "average_precision"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = defaultdict(list)
        y_score = defaultdict(list)

        for clip_eval in clip_evaluations:
            for class_name, score in clip_eval.class_scores.items():
                y_true[class_name].append(class_name in clip_eval.true_classes)
                y_score[class_name].append(score)

        class_scores = {
            class_name: float(
                average_precision(
                    y_true=y_true[class_name],
                    y_score=y_score[class_name],
                )
            )
            for class_name in y_true
        }

        mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])

        return {
            f"mean_{self.label}": float(mean),
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if not np.isnan(score)
            },
        }

    @clip_classification_metrics.register(
        ClipClassificationAveragePrecisionConfig
    )
    @staticmethod
    def from_config(config: ClipClassificationAveragePrecisionConfig):
        return ClipClassificationAveragePrecision(label=config.label)


class ClipClassificationROCAUCConfig(BaseConfig):
    name: Literal["roc_auc"] = "roc_auc"
    label: str = "roc_auc"


class ClipClassificationROCAUC:
    def __init__(self, label: str = "roc_auc"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = defaultdict(list)
        y_score = defaultdict(list)

        for clip_eval in clip_evaluations:
            for class_name, score in clip_eval.class_scores.items():
                y_true[class_name].append(class_name in clip_eval.true_classes)
                y_score[class_name].append(score)

        class_scores = {
            class_name: float(
                metrics.roc_auc_score(
                    y_true=y_true[class_name],
                    y_score=y_score[class_name],
                )
            )
            for class_name in y_true
        }

        mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])

        return {
            f"mean_{self.label}": float(mean),
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if not np.isnan(score)
            },
        }

    @clip_classification_metrics.register(ClipClassificationROCAUCConfig)
    @staticmethod
    def from_config(config: ClipClassificationROCAUCConfig):
        return ClipClassificationROCAUC(label=config.label)


ClipClassificationMetricConfig = Annotated[
    Union[
        ClipClassificationAveragePrecisionConfig,
        ClipClassificationROCAUCConfig,
    ],
    Field(discriminator="name"),
]


def build_clip_metric(config: ClipClassificationMetricConfig):
    return clip_classification_metrics.build(config)
173 src/batdetect2/evaluate/metrics/clip_detection.py Normal file
@ -0,0 +1,173 @@
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Union

import numpy as np
from pydantic import Field
from sklearn import metrics

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.metrics.common import average_precision


@dataclass
class ClipEval:
    gt_det: bool
    score: float


ClipDetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]

clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
    "clip_detection_metric"
)


class ClipDetectionAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"


class ClipDetectionAveragePrecision:
    def __init__(self, label: str = "average_precision"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            y_true.append(clip_eval.gt_det)
            y_score.append(clip_eval.score)

        score = average_precision(y_true=y_true, y_score=y_score)
        return {self.label: score}

    @clip_detection_metrics.register(ClipDetectionAveragePrecisionConfig)
    @staticmethod
    def from_config(config: ClipDetectionAveragePrecisionConfig):
        return ClipDetectionAveragePrecision(label=config.label)


class ClipDetectionROCAUCConfig(BaseConfig):
    name: Literal["roc_auc"] = "roc_auc"
    label: str = "roc_auc"


class ClipDetectionROCAUC:
    def __init__(self, label: str = "roc_auc"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            y_true.append(clip_eval.gt_det)
            y_score.append(clip_eval.score)

        score = float(metrics.roc_auc_score(y_true=y_true, y_score=y_score))
        return {self.label: score}

    @clip_detection_metrics.register(ClipDetectionROCAUCConfig)
    @staticmethod
    def from_config(config: ClipDetectionROCAUCConfig):
        return ClipDetectionROCAUC(label=config.label)


class ClipDetectionRecallConfig(BaseConfig):
    name: Literal["recall"] = "recall"
    threshold: float = 0.5
    label: str = "recall"


class ClipDetectionRecall:
    def __init__(self, threshold: float, label: str = "recall"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_positives = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            if clip_eval.gt_det:
                num_positives += 1

            if clip_eval.score >= self.threshold and clip_eval.gt_det:
                true_positives += 1

        if num_positives == 0:
            return {self.label: np.nan}

        score = true_positives / num_positives
        return {self.label: score}

    @clip_detection_metrics.register(ClipDetectionRecallConfig)
    @staticmethod
    def from_config(config: ClipDetectionRecallConfig):
        return ClipDetectionRecall(
            threshold=config.threshold, label=config.label
        )


class ClipDetectionPrecisionConfig(BaseConfig):
    name: Literal["precision"] = "precision"
    threshold: float = 0.5
    label: str = "precision"


class ClipDetectionPrecision:
    def __init__(self, threshold: float, label: str = "precision"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_detections = 0
        true_positives = 0
        for clip_eval in clip_evaluations:
            if clip_eval.score >= self.threshold:
                num_detections += 1

            if clip_eval.score >= self.threshold and clip_eval.gt_det:
                true_positives += 1

        if num_detections == 0:
            return {self.label: np.nan}

        score = true_positives / num_detections
        return {self.label: score}

    @clip_detection_metrics.register(ClipDetectionPrecisionConfig)
    @staticmethod
    def from_config(config: ClipDetectionPrecisionConfig):
        return ClipDetectionPrecision(
            threshold=config.threshold, label=config.label
        )


ClipDetectionMetricConfig = Annotated[
    Union[
        ClipDetectionAveragePrecisionConfig,
        ClipDetectionROCAUCConfig,
        ClipDetectionRecallConfig,
        ClipDetectionPrecisionConfig,
    ],
    Field(discriminator="name"),
]


def build_clip_metric(config: ClipDetectionMetricConfig):
    return clip_detection_metrics.build(config)
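A tiny worked example of the clip-level recall metric above:

```python
evals = [
    ClipEval(gt_det=True, score=0.8),   # detected positive clip
    ClipEval(gt_det=True, score=0.2),   # missed positive clip
    ClipEval(gt_det=False, score=0.6),  # false alarm (ignored by recall)
]
ClipDetectionRecall(threshold=0.5)(evals)  # {'recall': 0.5}
```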
63 src/batdetect2/evaluate/metrics/common.py Normal file
@ -0,0 +1,63 @@
from typing import Optional, Tuple

import numpy as np

__all__ = [
    "compute_precision_recall",
    "average_precision",
]


def compute_precision_recall(
    y_true,
    y_score,
    num_positives: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    y_true = np.array(y_true)
    y_score = np.array(y_score)

    if num_positives is None:
        num_positives = y_true.sum()

    # Sort by score
    sort_ind = np.argsort(y_score)[::-1]
    y_true_sorted = y_true[sort_ind]
    y_score_sorted = y_score[sort_ind]

    false_pos_c = np.cumsum(1 - y_true_sorted)
    true_pos_c = np.cumsum(y_true_sorted)

    recall = true_pos_c / num_positives
    precision = true_pos_c / np.maximum(
        true_pos_c + false_pos_c,
        np.finfo(np.float64).eps,
    )

    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0
    return precision, recall, y_score_sorted


def average_precision(
    y_true,
    y_score,
    num_positives: Optional[int] = None,
) -> float:
    if num_positives == 0:
        return np.nan

    precision, recall, _ = compute_precision_recall(
        y_true,
        y_score,
        num_positives=num_positives,
    )

    # PASCAL VOC 2012-style interpolated average precision
    mprec = np.hstack((0, precision, 0))
    mrec = np.hstack((0, recall, 1))
    for ii in range(mprec.shape[0] - 2, -1, -1):
        mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
    inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
    ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()

    return ave_prec
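A worked toy example of this interpolation:

```python
y_true = [1, 0, 1]
y_score = [0.9, 0.8, 0.7]
average_precision(y_true, y_score)
# = (0.5 - 0.0) * 1.0 + (1.0 - 0.5) * (2 / 3) ≈ 0.833
```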
226 src/batdetect2/evaluate/metrics/detection.py Normal file
@ -0,0 +1,226 @@
from dataclasses import dataclass
from typing import (
    Annotated,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

import numpy as np
from pydantic import Field
from sklearn import metrics
from soundevent import data

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction

__all__ = [
    "DetectionMetricConfig",
    "DetectionMetric",
    "build_detection_metric",
]


@dataclass
class MatchEval:
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]

    is_prediction: bool
    is_ground_truth: bool
    score: float


@dataclass
class ClipEval:
    clip: data.Clip
    matches: List[MatchEval]


DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]


detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")


class DetectionAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"
    ignore_non_predictions: bool = True


class DetectionAveragePrecision:
    def __init__(self, label: str, ignore_non_predictions: bool = True):
        self.ignore_non_predictions = ignore_non_predictions
        self.label = label

    def __call__(
        self,
        clip_evals: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                num_positives += int(m.is_ground_truth)

                # Ignore matches that don't correspond to a prediction
                if not m.is_prediction and self.ignore_non_predictions:
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        ap = average_precision(y_true, y_score, num_positives=num_positives)
        return {self.label: ap}

    @detection_metrics.register(DetectionAveragePrecisionConfig)
    @staticmethod
    def from_config(config: DetectionAveragePrecisionConfig):
        return DetectionAveragePrecision(
            label=config.label,
            ignore_non_predictions=config.ignore_non_predictions,
        )


class DetectionROCAUCConfig(BaseConfig):
    name: Literal["roc_auc"] = "roc_auc"
    label: str = "roc_auc"
    ignore_non_predictions: bool = True


class DetectionROCAUC:
    def __init__(
        self,
        label: str = "roc_auc",
        ignore_non_predictions: bool = True,
    ):
        self.label = label
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
        y_true: List[bool] = []
        y_score: List[float] = []

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        score = float(metrics.roc_auc_score(y_true, y_score))
        return {self.label: score}

    @detection_metrics.register(DetectionROCAUCConfig)
    @staticmethod
    def from_config(config: DetectionROCAUCConfig):
        return DetectionROCAUC(
            label=config.label,
            ignore_non_predictions=config.ignore_non_predictions,
        )


class DetectionRecallConfig(BaseConfig):
    name: Literal["recall"] = "recall"
    label: str = "recall"
    threshold: float = 0.5


class DetectionRecall:
    def __init__(self, threshold: float, label: str = "recall"):
        self.label = label
        self.threshold = threshold

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_positives = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_ground_truth:
                    num_positives += 1

                if m.score >= self.threshold and m.is_ground_truth:
                    true_positives += 1

        if num_positives == 0:
            return {self.label: np.nan}

        score = true_positives / num_positives
        return {self.label: score}

    @detection_metrics.register(DetectionRecallConfig)
    @staticmethod
    def from_config(config: DetectionRecallConfig):
        return DetectionRecall(threshold=config.threshold, label=config.label)


class DetectionPrecisionConfig(BaseConfig):
    name: Literal["precision"] = "precision"
    label: str = "precision"
    threshold: float = 0.5


class DetectionPrecision:
    def __init__(self, threshold: float, label: str = "precision"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_detections = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                is_detection = m.score >= self.threshold

                if is_detection:
                    num_detections += 1

                if is_detection and m.is_ground_truth:
                    true_positives += 1

        if num_detections == 0:
            return {self.label: np.nan}

        score = true_positives / num_detections
        return {self.label: score}

    @detection_metrics.register(DetectionPrecisionConfig)
    @staticmethod
    def from_config(config: DetectionPrecisionConfig):
        return DetectionPrecision(
            threshold=config.threshold,
            label=config.label,
        )


DetectionMetricConfig = Annotated[
    Union[
        DetectionAveragePrecisionConfig,
        DetectionROCAUCConfig,
        DetectionRecallConfig,
        DetectionPrecisionConfig,
    ],
    Field(discriminator="name"),
]


def build_detection_metric(config: DetectionMetricConfig):
    return detection_metrics.build(config)
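Each member of the union builds its metric through the registry; a minimal sketch using only configs defined above:

```python
# Hedged sketch: recall at a custom threshold via the discriminated union.
recall = build_detection_metric(
    DetectionRecallConfig(threshold=0.3, label="recall_at_0.3")
)
```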
314
src/batdetect2/evaluate/metrics/top_class.py
Normal file
314
src/batdetect2/evaluate/metrics/top_class.py
Normal file
@ -0,0 +1,314 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from sklearn import metrics, preprocessing
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing import RawPrediction
|
||||
|
||||
__all__ = [
|
||||
"TopClassMetricConfig",
|
||||
"TopClassMetric",
|
||||
"build_top_class_metric",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchEval:
|
||||
clip: data.Clip
|
||||
gt: Optional[data.SoundEventAnnotation]
|
||||
pred: Optional[RawPrediction]
|
||||
|
||||
is_ground_truth: bool
|
||||
is_generic: bool
|
||||
is_prediction: bool
|
||||
pred_class: Optional[str]
|
||||
true_class: Optional[str]
|
||||
score: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClipEval:
|
||||
clip: data.Clip
|
||||
matches: List[MatchEval]
|
||||
|
||||
|
||||
TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
|
||||
|
||||
|
||||
top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")
|
||||
|
||||
|
||||
class TopClassAveragePrecisionConfig(BaseConfig):
|
||||
name: Literal["average_precision"] = "average_precision"
|
||||
label: str = "average_precision"
|
||||
ignore_generic: bool = True
|
||||
ignore_non_predictions: bool = True
|
||||
|
||||
|
||||
class TopClassAveragePrecision:
|
||||
def __init__(
|
||||
self,
|
||||
ignore_generic: bool = True,
|
||||
ignore_non_predictions: bool = True,
|
||||
label: str = "average_precision",
|
||||
):
|
||||
self.ignore_generic = ignore_generic
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
self.label = label
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evals: Sequence[ClipEval],
|
||||
) -> Dict[str, float]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
num_positives = 0
|
||||
|
||||
for clip_eval in clip_evals:
|
||||
for m in clip_eval.matches:
|
||||
if m.is_generic and self.ignore_generic:
|
||||
# Ignore gt sounds with unknown class
|
||||
continue
|
||||
|
||||
num_positives += int(m.is_ground_truth)
|
||||
|
||||
if not m.is_prediction and self.ignore_non_predictions:
|
||||
# Ignore non predictions
|
||||
continue
|
||||
|
||||
y_true.append(m.pred_class == m.true_class)
|
||||
y_score.append(m.score)
|
||||
|
||||
score = average_precision(y_true, y_score, num_positives=num_positives)
|
||||
return {self.label: score}
|
||||
|
||||
@top_class_metrics.register(TopClassAveragePrecisionConfig)
|
||||
@staticmethod
|
||||
def from_config(config: TopClassAveragePrecisionConfig):
|
||||
return TopClassAveragePrecision(
|
||||
ignore_generic=config.ignore_generic,
|
||||
label=config.label,
|
||||
)
|
||||
|
||||
|
||||
class TopClassROCAUCConfig(BaseConfig):
|
||||
name: Literal["roc_auc"] = "roc_auc"
|
||||
ignore_generic: bool = True
|
||||
ignore_non_predictions: bool = True
|
||||
label: str = "roc_auc"
|
||||
|
||||
|
||||
class TopClassROCAUC:
|
||||
def __init__(
|
||||
self,
|
||||
ignore_generic: bool = True,
|
||||
ignore_non_predictions: bool = True,
|
||||
label: str = "roc_auc",
|
||||
):
|
||||
self.ignore_generic = ignore_generic
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
self.label = label
|
||||
|
||||
def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
|
||||
y_true: List[bool] = []
|
||||
y_score: List[float] = []
|
||||
|
||||
for clip_eval in clip_evals:
|
||||
for m in clip_eval.matches:
|
||||
if m.is_generic and self.ignore_generic:
|
||||
# Ignore gt sounds with unknown class
|
||||
continue
|
||||
|
||||
if not m.is_prediction and self.ignore_non_predictions:
|
||||
# Ignore non predictions
|
||||
continue
|
||||
|
||||
y_true.append(m.pred_class == m.true_class)
|
||||
y_score.append(m.score)
|
||||
|
||||
score = float(metrics.roc_auc_score(y_true, y_score))
|
||||
return {self.label: score}
|
||||
|
||||
@top_class_metrics.register(TopClassROCAUCConfig)
|
||||
@staticmethod
|
||||
def from_config(config: TopClassROCAUCConfig):
|
||||
return TopClassROCAUC(
|
||||
ignore_generic=config.ignore_generic,
|
||||
label=config.label,
|
||||
)
|
||||
|
||||
|
||||
class TopClassRecallConfig(BaseConfig):
|
||||
name: Literal["recall"] = "recall"
|
||||
threshold: float = 0.5
|
||||
label: str = "recall"
|
||||
|
||||
|
||||
class TopClassRecall:
|
||||
def __init__(self, threshold: float, label: str = "recall"):
|
||||
self.threshold = threshold
|
||||
self.label = label
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Dict[str, float]:
|
||||
num_positives = 0
|
||||
true_positives = 0
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for m in clip_eval.matches:
|
||||
if m.is_ground_truth:
|
||||
num_positives += 1
|
||||
|
||||
if m.score >= self.threshold and m.pred_class == m.true_class:
|
||||
true_positives += 1
|
||||
|
||||
if num_positives == 0:
|
||||
return {self.label: np.nan}
|
||||
|
||||
score = true_positives / num_positives
|
||||
return {self.label: score}
|
||||
|
||||
@top_class_metrics.register(TopClassRecallConfig)
|
||||
@staticmethod
|
||||
    def from_config(config: TopClassRecallConfig):
        return TopClassRecall(
            threshold=config.threshold,
            label=config.label,
        )


class TopClassPrecisionConfig(BaseConfig):
    name: Literal["precision"] = "precision"
    threshold: float = 0.5
    label: str = "precision"


class TopClassPrecision:
    def __init__(self, threshold: float, label: str = "precision"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_detections = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                is_detection = m.score >= self.threshold

                if is_detection:
                    num_detections += 1

                if is_detection and m.pred_class == m.true_class:
                    true_positives += 1

        if num_detections == 0:
            return {self.label: np.nan}

        score = true_positives / num_detections
        return {self.label: score}

    @top_class_metrics.register(TopClassPrecisionConfig)
    @staticmethod
    def from_config(config: TopClassPrecisionConfig):
        return TopClassPrecision(
            threshold=config.threshold,
            label=config.label,
        )


class BalancedAccuracyConfig(BaseConfig):
    name: Literal["balanced_accuracy"] = "balanced_accuracy"
    label: str = "balanced_accuracy"
    exclude_noise: bool = False
    noise_class: str = "noise"


class BalancedAccuracy:
    def __init__(
        self,
        exclude_noise: bool = True,
        noise_class: str = "noise",
        label: str = "balanced_accuracy",
    ):
        self.exclude_noise = exclude_noise
        self.noise_class = noise_class
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true: List[str] = []
        y_pred: List[str] = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_generic:
                    # Ignore matches that correspond to a sound event
                    # with unknown class
                    continue

                if not m.is_ground_truth and self.exclude_noise:
                    # Ignore predictions that were not matched to a
                    # ground truth
                    continue

                if m.pred_class is None and self.exclude_noise:
                    # Ignore non-predictions
                    continue

                y_true.append(m.true_class or self.noise_class)
                y_pred.append(m.pred_class or self.noise_class)

        encoder = preprocessing.LabelEncoder()
        encoder.fit(list(set(y_true) | set(y_pred)))

        y_true = encoder.transform(y_true)
        y_pred = encoder.transform(y_pred)
        score = metrics.balanced_accuracy_score(y_true, y_pred)
        return {self.label: score}

    @top_class_metrics.register(BalancedAccuracyConfig)
    @staticmethod
    def from_config(config: BalancedAccuracyConfig):
        return BalancedAccuracy(
            exclude_noise=config.exclude_noise,
            noise_class=config.noise_class,
            label=config.label,
        )


TopClassMetricConfig = Annotated[
    Union[
        TopClassAveragePrecisionConfig,
        TopClassROCAUCConfig,
        TopClassRecallConfig,
        TopClassPrecisionConfig,
        BalancedAccuracyConfig,
    ],
    Field(discriminator="name"),
]


def build_top_class_metric(config: TopClassMetricConfig):
    return top_class_metrics.build(config)
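For orientation, a minimal usage sketch of the registry flow above; the threshold value and the `clip_evaluations` list are hypothetical placeholders standing in for the output of the top-class evaluation task:

```python
# Minimal sketch (hypothetical values): build a metric from its config
# and apply it to a list of ClipEval objects.
config = TopClassPrecisionConfig(threshold=0.3, label="precision_at_0.3")
metric = build_top_class_metric(config)
scores = metric(clip_evaluations)  # e.g. {"precision_at_0.3": 0.87}
```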
0 src/batdetect2/evaluate/plots/__init__.py (Normal file)

54 src/batdetect2/evaluate/plots/base.py (Normal file)
@ -0,0 +1,54 @@
from typing import Optional

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from batdetect2.core import BaseConfig
from batdetect2.typing import TargetProtocol


class BasePlotConfig(BaseConfig):
    label: str = "plot"
    theme: str = "default"
    title: Optional[str] = None
    figsize: tuple[int, int] = (10, 10)
    dpi: int = 100


class BasePlot:
    def __init__(
        self,
        targets: TargetProtocol,
        label: str = "plot",
        figsize: tuple[int, int] = (10, 10),
        title: Optional[str] = None,
        dpi: int = 100,
        theme: str = "default",
    ):
        self.targets = targets
        self.label = label
        self.figsize = figsize
        self.dpi = dpi
        self.theme = theme
        self.title = title

    def create_figure(self) -> Figure:
        plt.style.use(self.theme)
        fig = plt.figure(figsize=self.figsize, dpi=self.dpi)

        if self.title is not None:
            fig.suptitle(self.title)

        return fig

    @classmethod
    def build(cls, config: BasePlotConfig, targets: TargetProtocol, **kwargs):
        return cls(
            targets=targets,
            figsize=config.figsize,
            dpi=config.dpi,
            theme=config.theme,
            label=config.label,
            title=config.title,
            **kwargs,
        )
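A minimal sketch of the `BasePlot` contract above; the subclass name and its histogram body are hypothetical, and it assumes each clip evaluation exposes `matches` with a `score` attribute:

```python
# Hypothetical subclass: __call__ yields (label, figure) pairs built
# with the themed figure from create_figure().
class ScoreHistogram(BasePlot):
    def __call__(self, clip_evaluations):
        fig = self.create_figure()
        ax = fig.subplots()
        ax.hist([m.score for c in clip_evaluations for m in c.matches])
        yield self.label, fig
```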
370 src/batdetect2/evaluate/plots/classification.py (Normal file)
@ -0,0 +1,370 @@
from typing import (
    Annotated,
    Callable,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.core import Registry
from batdetect2.evaluate.metrics.classification import (
    ClipEval,
    _extract_per_class_metric_data,
)
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
    plot_pr_curve,
    plot_pr_curves,
    plot_roc_curve,
    plot_roc_curves,
    plot_threshold_precision_curve,
    plot_threshold_precision_curves,
    plot_threshold_recall_curve,
    plot_threshold_recall_curves,
)
from batdetect2.typing import TargetProtocol

ClassificationPlotter = Callable[
    [Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]

classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
    Registry("classification_plot")
)


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Classification Precision-Recall Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    separate_figures: bool = False


class PRCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        separate_figures: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        data = {
            class_name: compute_precision_recall(
                y_true[class_name],
                y_score[class_name],
                num_positives=num_positives[class_name],
            )
            for class_name in self.targets.class_names
        }

        if not self.separate_figures:
            fig = self.create_figure()
            ax = fig.subplots()
            plot_pr_curves(data, ax=ax)
            yield self.label, fig
            return

        for class_name, (precision, recall, thresholds) in data.items():
            fig = self.create_figure()
            ax = fig.subplots()

            ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
            ax.set_title(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @classification_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            separate_figures=config.separate_figures,
        )


class ThresholdPrecisionCurveConfig(BasePlotConfig):
    name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
    label: str = "threshold_precision_curve"
    title: Optional[str] = "Classification Threshold-Precision Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    separate_figures: bool = False


class ThresholdPrecisionCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        separate_figures: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        data = {
            class_name: compute_precision_recall(
                y_true[class_name],
                y_score[class_name],
                num_positives[class_name],
            )
            for class_name in self.targets.class_names
        }

        if not self.separate_figures:
            fig = self.create_figure()
            ax = fig.subplots()

            plot_threshold_precision_curves(data, ax=ax)

            yield self.label, fig

            return

        for class_name, (precision, _, thresholds) in data.items():
            fig = self.create_figure()
            ax = fig.subplots()

            ax = plot_threshold_precision_curve(
                thresholds,
                precision,
                ax=ax,
            )

            ax.set_title(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @classification_plots.register(ThresholdPrecisionCurveConfig)
    @staticmethod
    def from_config(
        config: ThresholdPrecisionCurveConfig, targets: TargetProtocol
    ):
        return ThresholdPrecisionCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            separate_figures=config.separate_figures,
        )


class ThresholdRecallCurveConfig(BasePlotConfig):
    name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
    label: str = "threshold_recall_curve"
    title: Optional[str] = "Classification Threshold-Recall Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    separate_figures: bool = False


class ThresholdRecallCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        separate_figures: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        data = {
            class_name: compute_precision_recall(
                y_true[class_name],
                y_score[class_name],
                num_positives[class_name],
            )
            for class_name in self.targets.class_names
        }

        if not self.separate_figures:
            fig = self.create_figure()
            ax = fig.subplots()

            plot_threshold_recall_curves(data, ax=ax, add_legend=True)

            yield self.label, fig

            return

        for class_name, (_, recall, thresholds) in data.items():
            fig = self.create_figure()
            ax = fig.subplots()

            ax = plot_threshold_recall_curve(
                thresholds,
                recall,
                ax=ax,
            )

            ax.set_title(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @classification_plots.register(ThresholdRecallCurveConfig)
    @staticmethod
    def from_config(
        config: ThresholdRecallCurveConfig, targets: TargetProtocol
    ):
        return ThresholdRecallCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            separate_figures=config.separate_figures,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Classification ROC Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    separate_figures: bool = False


class ROCCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        separate_figures: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true, y_score, _ = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        data = {
            class_name: metrics.roc_curve(
                y_true[class_name],
                y_score[class_name],
            )
            for class_name in self.targets.class_names
        }

        if not self.separate_figures:
            fig = self.create_figure()
            ax = fig.subplots()

            plot_roc_curves(data, ax=ax)

            yield self.label, fig

            return

        for class_name, (fpr, tpr, thresholds) in data.items():
            fig = self.create_figure()
            ax = fig.subplots()

            ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
            ax.set_title(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @classification_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            separate_figures=config.separate_figures,
        )


ClassificationPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ThresholdPrecisionCurveConfig,
        ThresholdRecallCurveConfig,
    ],
    Field(discriminator="name"),
]


def build_classification_plotter(
    config: ClassificationPlotConfig,
    targets: TargetProtocol,
) -> ClassificationPlotter:
    return classification_plots.build(config, targets)
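A sketch of the builder flow above (the `targets` object, `clip_evaluations` list, and output paths are placeholders); with `separate_figures=True` the plotter yields one labelled figure per class:

```python
config = PRCurveConfig(separate_figures=True)
plotter = build_classification_plotter(config, targets)
for name, fig in plotter(clip_evaluations):
    # names look like "pr_curve/<class_name>" in this mode
    fig.savefig(f"plots/{name}.png")
```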
189 src/batdetect2/evaluate/plots/clip_classification.py (Normal file)
@ -0,0 +1,189 @@
from typing import (
    Annotated,
    Callable,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_classification import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
    plot_pr_curve,
    plot_pr_curves,
    plot_roc_curve,
    plot_roc_curves,
)
from batdetect2.typing import TargetProtocol

__all__ = [
    "ClipClassificationPlotConfig",
    "ClipClassificationPlotter",
    "build_clip_classification_plotter",
]

ClipClassificationPlotter = Callable[
    [Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]

clip_classification_plots: Registry[
    ClipClassificationPlotter, [TargetProtocol]
] = Registry("clip_classification_plot")


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Clip Classification Precision-Recall Curve"
    separate_figures: bool = False


class PRCurve(BasePlot):
    def __init__(
        self,
        *args,
        separate_figures: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        data = {}

        for class_name in self.targets.class_names:
            y_true = [class_name in c.true_classes for c in clip_evaluations]
            y_score = [
                c.class_scores.get(class_name, 0) for c in clip_evaluations
            ]

            precision, recall, thresholds = compute_precision_recall(
                y_true,
                y_score,
            )

            data[class_name] = (precision, recall, thresholds)

        if not self.separate_figures:
            fig = self.create_figure()
            ax = fig.subplots()

            plot_pr_curves(data, ax=ax)

            yield self.label, fig

            return

        for class_name, (precision, recall, thresholds) in data.items():
            fig = self.create_figure()
            ax = fig.subplots()

            ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
            ax.set_title(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @clip_classification_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
            separate_figures=config.separate_figures,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Clip Classification ROC Curve"
    separate_figures: bool = False


class ROCCurve(BasePlot):
    def __init__(
        self,
        *args,
        separate_figures: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        data = {}

        for class_name in self.targets.class_names:
            y_true = [class_name in c.true_classes for c in clip_evaluations]
            y_score = [
                c.class_scores.get(class_name, 0) for c in clip_evaluations
            ]

            fpr, tpr, thresholds = metrics.roc_curve(
                y_true,
                y_score,
            )

            data[class_name] = (fpr, tpr, thresholds)

        if not self.separate_figures:
            fig = self.create_figure()
            ax = fig.subplots()
            plot_roc_curves(data, ax=ax)
            yield self.label, fig

            return

        for class_name, (fpr, tpr, thresholds) in data.items():
            fig = self.create_figure()
            ax = fig.subplots()

            ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
            ax.set_title(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @clip_classification_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
            separate_figures=config.separate_figures,
        )


ClipClassificationPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
    ],
    Field(discriminator="name"),
]


def build_clip_classification_plotter(
    config: ClipClassificationPlotConfig,
    targets: TargetProtocol,
) -> ClipClassificationPlotter:
    return clip_classification_plots.build(config, targets)
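The clip-level curves above treat each clip as one binary example per class; a minimal sketch of the extraction, with a placeholder class name:

```python
# "some_class" is a placeholder; real names come from targets.class_names.
y_true = ["some_class" in c.true_classes for c in clip_evaluations]
y_score = [c.class_scores.get("some_class", 0) for c in clip_evaluations]
```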
163 src/batdetect2/evaluate/plots/clip_detection.py (Normal file)
@ -0,0 +1,163 @@
from typing import (
    Annotated,
    Callable,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_detection import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol

__all__ = [
    "ClipDetectionPlotConfig",
    "ClipDetectionPlotter",
    "build_clip_detection_plotter",
]

ClipDetectionPlotter = Callable[
    [Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]


clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
    Registry("clip_detection_plot")
)


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Clip Detection Precision-Recall Curve"


class PRCurve(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = [c.gt_det for c in clip_evaluations]
        y_score = [c.score for c in clip_evaluations]

        precision, recall, thresholds = compute_precision_recall(
            y_true,
            y_score,
        )

        fig = self.create_figure()
        ax = fig.subplots()
        plot_pr_curve(precision, recall, thresholds, ax=ax)
        yield self.label, fig

    @clip_detection_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Clip Detection ROC Curve"


class ROCCurve(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = [c.gt_det for c in clip_evaluations]
        y_score = [c.score for c in clip_evaluations]

        fpr, tpr, thresholds = metrics.roc_curve(
            y_true,
            y_score,
        )

        fig = self.create_figure()
        ax = fig.subplots()
        plot_roc_curve(fpr, tpr, thresholds, ax=ax)
        yield self.label, fig

    @clip_detection_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
        )


class ScoreDistributionPlotConfig(BasePlotConfig):
    name: Literal["score_distribution"] = "score_distribution"
    label: str = "score_distribution"
    title: Optional[str] = "Clip Detection Score Distribution"


class ScoreDistributionPlot(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = [c.gt_det for c in clip_evaluations]
        y_score = [c.score for c in clip_evaluations]

        fig = self.create_figure()
        ax = fig.subplots()

        df = pd.DataFrame({"is_true": y_true, "score": y_score})
        sns.histplot(
            data=df,
            x="score",
            binwidth=0.025,
            binrange=(0, 1),
            hue="is_true",
            ax=ax,
            stat="probability",
            common_norm=False,
        )

        yield self.label, fig

    @clip_detection_plots.register(ScoreDistributionPlotConfig)
    @staticmethod
    def from_config(
        config: ScoreDistributionPlotConfig, targets: TargetProtocol
    ):
        return ScoreDistributionPlot.build(
            config=config,
            targets=targets,
        )


ClipDetectionPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ScoreDistributionPlotConfig,
    ],
    Field(discriminator="name"),
]


def build_clip_detection_plotter(
    config: ClipDetectionPlotConfig,
    targets: TargetProtocol,
) -> ClipDetectionPlotter:
    return clip_detection_plots.build(config, targets)
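A sketch of wiring one of the registered clip-detection plots (`targets` and `clip_evaluations` are placeholders). Note that `common_norm=False` in the histogram call above normalises the matched and unmatched score distributions independently, so class imbalance does not hide the false-alarm mass:

```python
config = ScoreDistributionPlotConfig(title="Validation clips")
plotter = build_clip_detection_plotter(config, targets)
label, fig = next(iter(plotter(clip_evaluations)))
```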
309 src/batdetect2/evaluate/plots/detection.py (Normal file)
@ -0,0 +1,309 @@
import random
from typing import (
    Annotated,
    Callable,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.detections import plot_clip_detections
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol

DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]

detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
    name="detection_plot"
)


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Detection Precision-Recall Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class PRCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evals: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                num_positives += int(m.is_ground_truth)

                # Ignore matches that don't correspond to a prediction
                if not m.is_prediction and self.ignore_non_predictions:
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        precision, recall, thresholds = compute_precision_recall(
            y_true,
            y_score,
            num_positives=num_positives,
        )

        fig = self.create_figure()
        ax = fig.subplots()

        plot_pr_curve(precision, recall, thresholds, ax=ax)

        yield self.label, fig

    @detection_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Detection ROC Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ROCCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        fpr, tpr, thresholds = metrics.roc_curve(
            y_true,
            y_score,
        )

        fig = self.create_figure()
        ax = fig.subplots()

        plot_roc_curve(fpr, tpr, thresholds, ax=ax)

        yield self.label, fig

    @detection_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ScoreDistributionPlotConfig(BasePlotConfig):
    name: Literal["score_distribution"] = "score_distribution"
    label: str = "score_distribution"
    title: Optional[str] = "Detection Score Distribution"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ScoreDistributionPlot(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        df = pd.DataFrame({"is_true": y_true, "score": y_score})

        fig = self.create_figure()
        ax = fig.subplots()

        sns.histplot(
            data=df,
            x="score",
            binwidth=0.025,
            binrange=(0, 1),
            hue="is_true",
            ax=ax,
            stat="probability",
            common_norm=False,
        )

        yield self.label, fig

    @detection_plots.register(ScoreDistributionPlotConfig)
    @staticmethod
    def from_config(
        config: ScoreDistributionPlotConfig, targets: TargetProtocol
    ):
        return ScoreDistributionPlot.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ExampleDetectionPlotConfig(BasePlotConfig):
    name: Literal["example_detection"] = "example_detection"
    label: str = "example_detection"
    title: Optional[str] = "Example Detection"
    figsize: tuple[int, int] = (10, 4)
    num_examples: int = 5
    threshold: float = 0.2
    audio: AudioConfig = Field(default_factory=AudioConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )


class ExampleDetectionPlot(BasePlot):
    def __init__(
        self,
        *args,
        num_examples: int = 5,
        threshold: float = 0.2,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.num_examples = num_examples
        self.audio_loader = audio_loader
        self.threshold = threshold
        self.preprocessor = preprocessor

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        sample = clip_evaluations

        if self.num_examples < len(sample):
            sample = random.sample(sample, self.num_examples)

        for num_example, clip_eval in enumerate(sample):
            fig = self.create_figure()
            ax = fig.subplots()

            plot_clip_detections(
                clip_eval,
                ax=ax,
                audio_loader=self.audio_loader,
                preprocessor=self.preprocessor,
            )

            yield f"{self.label}/example_{num_example}", fig

            plt.close(fig)

    @detection_plots.register(ExampleDetectionPlotConfig)
    @staticmethod
    def from_config(
        config: ExampleDetectionPlotConfig,
        targets: TargetProtocol,
    ):
        return ExampleDetectionPlot.build(
            config=config,
            targets=targets,
            num_examples=config.num_examples,
            threshold=config.threshold,
            audio_loader=build_audio_loader(config.audio),
            preprocessor=build_preprocessor(config.preprocessing),
        )


DetectionPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ScoreDistributionPlotConfig,
        ExampleDetectionPlotConfig,
    ],
    Field(discriminator="name"),
]


def build_detection_plotter(
    config: DetectionPlotConfig,
    targets: TargetProtocol,
) -> DetectionPlotter:
    return detection_plots.build(config, targets)
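A sketch of the detection-plot builder (placeholder `targets` and `clip_evaluations`); the audio and preprocessing sub-configs fall back to their defaults, which lets the example plot reload and spectrogram the evaluated clips:

```python
config = ExampleDetectionPlotConfig(num_examples=3, threshold=0.3)
plotter = build_detection_plotter(config, targets)
figures = dict(plotter(clip_evaluations))  # keyed "example_detection/example_<i>"
```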
444 src/batdetect2/evaluate/plots/top_class.py (Normal file)
@ -0,0 +1,444 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (
    Annotated,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol

TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]

top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
    name="top_class_plot"
)


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Top Class Precision-Recall Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class PRCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore gt sounds with unknown class
                    continue

                num_positives += int(m.is_ground_truth)

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        precision, recall, thresholds = compute_precision_recall(
            y_true,
            y_score,
            num_positives=num_positives,
        )

        fig = self.create_figure()
        ax = fig.subplots()

        plot_pr_curve(precision, recall, thresholds, ax=ax)

        yield self.label, fig

    @top_class_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Top Class ROC Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ROCCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore gt sounds with unknown class
                    continue

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        fpr, tpr, thresholds = metrics.roc_curve(
            y_true,
            y_score,
        )

        fig = self.create_figure()
        ax = fig.subplots()

        plot_roc_curve(fpr, tpr, thresholds, ax=ax)

        yield self.label, fig

    @top_class_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ConfusionMatrixConfig(BasePlotConfig):
    name: Literal["confusion_matrix"] = "confusion_matrix"
    title: Optional[str] = "Top Class Confusion Matrix"
    figsize: tuple[int, int] = (10, 10)
    label: str = "confusion_matrix"
    exclude_generic: bool = True
    exclude_noise: bool = False
    noise_class: str = "noise"
    normalize: Literal["true", "pred", "all", "none"] = "true"
    threshold: float = 0.2
    add_colorbar: bool = True
    cmap: str = "Blues"


class ConfusionMatrix(BasePlot):
    def __init__(
        self,
        *args,
        exclude_generic: bool = True,
        exclude_noise: bool = False,
        noise_class: str = "noise",
        add_colorbar: bool = True,
        normalize: Literal["true", "pred", "all", "none"] = "true",
        cmap: str = "Blues",
        threshold: float = 0.2,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.exclude_generic = exclude_generic
        self.exclude_noise = exclude_noise
        self.noise_class = noise_class
        self.normalize = normalize
        self.add_colorbar = add_colorbar
        self.threshold = threshold
        self.cmap = cmap

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true: List[str] = []
        y_pred: List[str] = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                true_class = m.true_class
                pred_class = m.pred_class

                if not m.is_prediction and self.exclude_noise:
                    # Ignore matches that don't correspond to a prediction
                    continue

                if not m.is_ground_truth and self.exclude_noise:
                    # Ignore matches that don't correspond to a ground truth
                    continue

                if m.score < self.threshold:
                    if self.exclude_noise:
                        continue

                    pred_class = self.noise_class

                if m.is_generic:
                    if self.exclude_generic:
                        # Ignore gt sounds with unknown class
                        continue

                    true_class = self.targets.detection_class_name

                y_true.append(true_class or self.noise_class)
                y_pred.append(pred_class or self.noise_class)

        fig = self.create_figure()
        ax = fig.subplots()

        class_names = [*self.targets.class_names]

        if not self.exclude_generic:
            class_names.append(self.targets.detection_class_name)

        if not self.exclude_noise:
            class_names.append(self.noise_class)

        metrics.ConfusionMatrixDisplay.from_predictions(
            y_true,
            y_pred,
            labels=class_names,
            ax=ax,
            xticks_rotation="vertical",
            cmap=self.cmap,
            colorbar=self.add_colorbar,
            normalize=self.normalize if self.normalize != "none" else None,
            values_format=".2f",
        )

        yield self.label, fig

    @top_class_plots.register(ConfusionMatrixConfig)
    @staticmethod
    def from_config(config: ConfusionMatrixConfig, targets: TargetProtocol):
        return ConfusionMatrix.build(
            config=config,
            targets=targets,
            exclude_generic=config.exclude_generic,
            exclude_noise=config.exclude_noise,
            noise_class=config.noise_class,
            add_colorbar=config.add_colorbar,
            normalize=config.normalize,
            cmap=config.cmap,
            threshold=config.threshold,
        )


class ExampleClassificationPlotConfig(BasePlotConfig):
    name: Literal["example_classification"] = "example_classification"
    label: str = "example_classification"
    title: Optional[str] = "Example Classification"
    num_examples: int = 4
    threshold: float = 0.2
    audio: AudioConfig = Field(default_factory=AudioConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )


class ExampleClassificationPlot(BasePlot):
    def __init__(
        self,
        *args,
        num_examples: int = 4,
        threshold: float = 0.2,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.num_examples = num_examples
        self.audio_loader = audio_loader
        self.threshold = threshold
        self.preprocessor = preprocessor

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        grouped = group_matches(clip_evaluations, threshold=self.threshold)

        for class_name, matches in grouped.items():
            true_positives: List[MatchEval] = get_binned_sample(
                matches.true_positives,
                n_examples=self.num_examples,
            )

            false_positives: List[MatchEval] = get_binned_sample(
                matches.false_positives,
                n_examples=self.num_examples,
            )

            false_negatives: List[MatchEval] = random.sample(
                matches.false_negatives,
                k=min(self.num_examples, len(matches.false_negatives)),
            )

            cross_triggers: List[MatchEval] = get_binned_sample(
                matches.cross_triggers, n_examples=self.num_examples
            )

            fig = self.create_figure()

            fig = plot_match_gallery(
                true_positives,
                false_positives,
                false_negatives,
                cross_triggers,
                preprocessor=self.preprocessor,
                audio_loader=self.audio_loader,
                n_examples=self.num_examples,
                fig=fig,
            )

            if self.title is not None:
                fig.suptitle(f"{self.title}: {class_name}")
            else:
                fig.suptitle(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @top_class_plots.register(ExampleClassificationPlotConfig)
    @staticmethod
    def from_config(
        config: ExampleClassificationPlotConfig,
        targets: TargetProtocol,
    ):
        return ExampleClassificationPlot.build(
            config=config,
            targets=targets,
            num_examples=config.num_examples,
            threshold=config.threshold,
            audio_loader=build_audio_loader(config.audio),
            preprocessor=build_preprocessor(config.preprocessing),
        )


TopClassPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ConfusionMatrixConfig,
        ExampleClassificationPlotConfig,
    ],
    Field(discriminator="name"),
]


def build_top_class_plotter(
    config: TopClassPlotConfig,
    targets: TargetProtocol,
) -> TopClassPlotter:
    return top_class_plots.build(config, targets)


@dataclass
class ClassMatches:
    false_positives: List[MatchEval] = field(default_factory=list)
    false_negatives: List[MatchEval] = field(default_factory=list)
    true_positives: List[MatchEval] = field(default_factory=list)
    cross_triggers: List[MatchEval] = field(default_factory=list)


def group_matches(
    clip_evals: Sequence[ClipEval],
    threshold: float = 0.2,
) -> Dict[str, ClassMatches]:
    class_examples = defaultdict(ClassMatches)

    for clip_eval in clip_evals:
        for match in clip_eval.matches:
            gt_class = match.true_class
            pred_class = match.pred_class
            is_pred = match.score >= threshold

            if not is_pred and gt_class is not None:
                class_examples[gt_class].false_negatives.append(match)
                continue

            if not is_pred:
                continue

            if gt_class is None:
                class_examples[pred_class].false_positives.append(match)
                continue

            if gt_class != pred_class:
                class_examples[pred_class].cross_triggers.append(match)
                continue

            class_examples[gt_class].true_positives.append(match)

    return class_examples


def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
    if len(matches) < n_examples:
        return matches

    indices, pred_scores = zip(
        *[(index, match.score) for index, match in enumerate(matches)]
    )

    bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
    df = pd.DataFrame({"indices": indices, "bins": bins})
    sample = df.groupby("bins").sample(1)
    return [matches[ind] for ind in sample["indices"]]
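A small self-contained illustration of `get_binned_sample`: instead of sampling uniformly at random it draws one match per score quantile, so galleries cover low- and high-confidence examples alike (the stand-in match class is hypothetical; only `.score` is needed):

```python
from dataclasses import dataclass


@dataclass
class _FakeMatch:  # stand-in for MatchEval, illustration only
    score: float


dummy = [_FakeMatch(score=s / 10) for s in range(10)]
picked = get_binned_sample(dummy, n_examples=5)
assert len(picked) == 5  # one match drawn from each score quantile
```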
106 src/batdetect2/evaluate/tables.py (Normal file)
@ -0,0 +1,106 @@
from typing import Annotated, Callable, Literal, Sequence, Union

import pandas as pd
from pydantic import Field
from soundevent.geometry import compute_bounds

from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipMatches

EvaluationTableGenerator = Callable[[Sequence[ClipMatches]], pd.DataFrame]


tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
    "evaluation_table"
)


class FullEvaluationTableConfig(BaseConfig):
    name: Literal["full_evaluation"] = "full_evaluation"


class FullEvaluationTable:
    def __call__(
        self, clip_evaluations: Sequence[ClipMatches]
    ) -> pd.DataFrame:
        return extract_matches_dataframe(clip_evaluations)

    @tables_registry.register(FullEvaluationTableConfig)
    @staticmethod
    def from_config(config: FullEvaluationTableConfig):
        return FullEvaluationTable()


def extract_matches_dataframe(
    clip_evaluations: Sequence[ClipMatches],
) -> pd.DataFrame:
    data = []

    for clip_evaluation in clip_evaluations:
        for match in clip_evaluation.matches:
            gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
            pred_start_time = pred_low_freq = pred_end_time = (
                pred_high_freq
            ) = None

            sound_event_annotation = match.sound_event_annotation

            if sound_event_annotation is not None:
                geometry = sound_event_annotation.sound_event.geometry
                assert geometry is not None
                gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
                    compute_bounds(geometry)
                )

            if match.pred_geometry is not None:
                (
                    pred_start_time,
                    pred_low_freq,
                    pred_end_time,
                    pred_high_freq,
                ) = compute_bounds(match.pred_geometry)

            data.append(
                {
                    ("recording", "uuid"): match.clip.recording.uuid,
                    ("clip", "uuid"): match.clip.uuid,
                    ("clip", "start_time"): match.clip.start_time,
                    ("clip", "end_time"): match.clip.end_time,
                    ("gt", "uuid"): match.sound_event_annotation.uuid
                    if match.sound_event_annotation is not None
                    else None,
                    ("gt", "class"): match.gt_class,
                    ("gt", "det"): match.gt_det,
                    ("gt", "start_time"): gt_start_time,
                    ("gt", "end_time"): gt_end_time,
                    ("gt", "low_freq"): gt_low_freq,
                    ("gt", "high_freq"): gt_high_freq,
                    ("pred", "score"): match.pred_score,
                    ("pred", "class"): match.top_class,
                    ("pred", "class_score"): match.top_class_score,
                    ("pred", "start_time"): pred_start_time,
                    ("pred", "end_time"): pred_end_time,
                    ("pred", "low_freq"): pred_low_freq,
                    ("pred", "high_freq"): pred_high_freq,
                    ("match", "affinity"): match.affinity,
                    **{
                        ("pred_class_score", key): value
                        for key, value in match.pred_class_scores.items()
                    },
                }
            )

    df = pd.DataFrame(data)
    df.columns = pd.MultiIndex.from_tuples(df.columns)  # type: ignore
    return df


EvaluationTableConfig = Annotated[
    Union[FullEvaluationTableConfig,], Field(discriminator="name")
]


def build_table_generator(
    config: EvaluationTableConfig,
) -> EvaluationTableGenerator:
    return tables_registry.build(config)
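A sketch of consuming the table (placeholder `clip_evaluations`): the generator returns a DataFrame with two-level columns, so whole groups can be selected by the first level:

```python
generator = build_table_generator(FullEvaluationTableConfig())
df = generator(clip_evaluations)
pred_scores = df[("pred", "score")]  # a single column
gt_columns = df["gt"]                # every ground-truth column at once
```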
39 src/batdetect2/evaluate/tasks/__init__.py (Normal file)
@ -0,0 +1,39 @@
from typing import Annotated, Optional, Union

from pydantic import Field

from batdetect2.evaluate.tasks.base import tasks_registry
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.clip_classification import (
    ClipClassificationTaskConfig,
)
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, TargetProtocol

__all__ = [
    "TaskConfig",
    "build_task",
]


TaskConfig = Annotated[
    Union[
        ClassificationTaskConfig,
        DetectionTaskConfig,
        ClipDetectionTaskConfig,
        ClipClassificationTaskConfig,
        TopClassDetectionTaskConfig,
    ],
    Field(discriminator="name"),
]


def build_task(
    config: TaskConfig,
    targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
    targets = targets or build_targets()
    return tasks_registry.build(config, targets)
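A sketch of the task-building flow (`clip_annotations` and `predictions` are placeholders; `ClassificationTaskConfig`, defined in tasks/classification.py below, supplies its own default name and prefix):

```python
task = build_task(ClassificationTaskConfig())
clip_evals = task.evaluate(clip_annotations, predictions)
scores = task.compute_metrics(clip_evals)  # keys prefixed "classification/..."
```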
175 src/batdetect2/evaluate/tasks/base.py (Normal file)
@ -0,0 +1,175 @@
from typing import (
    Callable,
    Dict,
    Generic,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
)

from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds

from batdetect2.core import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.match import (
    MatchConfig,
    StartTimeMatchConfig,
    build_matcher,
)
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol

__all__ = [
    "BaseTaskConfig",
    "BaseTask",
]

tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
    "tasks"
)


T_Output = TypeVar("T_Output")


class BaseTaskConfig(BaseConfig):
    prefix: str
    ignore_start_end: float = 0.01
    matching_strategy: MatchConfig = Field(
        default_factory=StartTimeMatchConfig
    )


class BaseTask(EvaluatorProtocol, Generic[T_Output]):
    targets: TargetProtocol

    matcher: MatcherProtocol

    metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]

    plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]

    ignore_start_end: float

    prefix: str

    def __init__(
        self,
        matcher: MatcherProtocol,
        targets: TargetProtocol,
        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
        prefix: str,
        ignore_start_end: float = 0.01,
        plots: Optional[
            List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
        ] = None,
    ):
        self.matcher = matcher
        self.metrics = metrics
        self.plots = plots or []
        self.targets = targets
        self.prefix = prefix
        self.ignore_start_end = ignore_start_end

    def compute_metrics(
        self,
        eval_outputs: List[T_Output],
    ) -> Dict[str, float]:
        scores = [metric(eval_outputs) for metric in self.metrics]
        return {
            f"{self.prefix}/{name}": score
            for metric_output in scores
            for name, score in metric_output.items()
        }

    def generate_plots(
        self, eval_outputs: List[T_Output]
    ) -> Iterable[Tuple[str, Figure]]:
        for plot in self.plots:
            for name, fig in plot(eval_outputs):
                yield f"{self.prefix}/{name}", fig

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[T_Output]:
        return [
            self.evaluate_clip(clip_annotation, preds)
            for clip_annotation, preds in zip(clip_annotations, predictions)
        ]

    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> T_Output: ...

    def include_sound_event_annotation(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
        clip: data.Clip,
    ) -> bool:
        if not self.targets.filter(sound_event_annotation):
            return False

        geometry = sound_event_annotation.sound_event.geometry
        if geometry is None:
            return False

        return is_in_bounds(
            geometry,
            clip,
            self.ignore_start_end,
        )

    def include_prediction(
        self,
        prediction: RawPrediction,
        clip: data.Clip,
    ) -> bool:
        return is_in_bounds(
            prediction.geometry,
            clip,
            self.ignore_start_end,
        )

    @classmethod
    def build(
        cls,
        config: BaseTaskConfig,
        targets: TargetProtocol,
        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
        plots: Optional[
            List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
        ] = None,
        **kwargs,
    ):
        matcher = build_matcher(config.matching_strategy)
        return cls(
            matcher=matcher,
            targets=targets,
            metrics=metrics,
            plots=plots,
            prefix=config.prefix,
            ignore_start_end=config.ignore_start_end,
            **kwargs,
        )


def is_in_bounds(
    geometry: data.Geometry,
    clip: data.Clip,
    buffer: float,
) -> bool:
    start_time = compute_bounds(geometry)[0]
    return (start_time >= clip.start_time + buffer) and (
        start_time <= clip.end_time - buffer
    )
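A sketch of how a concrete task config plugs into the base machinery above (using the classification config defined next; the widened edge buffer is an arbitrary example value):

```python
cfg = ClassificationTaskConfig(ignore_start_end=0.02)
matcher = build_matcher(cfg.matching_strategy)  # StartTimeMatchConfig default
# BaseTask.build wires this matcher, the metrics, and the prefix together.
```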
149 src/batdetect2/evaluate/tasks/classification.py (Normal file)
@ -0,0 +1,149 @@
from typing import (
    List,
    Literal,
    Sequence,
)

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.classification import (
    ClassificationAveragePrecisionConfig,
    ClassificationMetricConfig,
    ClipEval,
    MatchEval,
    build_classification_metric,
)
from batdetect2.evaluate.plots.classification import (
    ClassificationPlotConfig,
    build_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class ClassificationTaskConfig(BaseTaskConfig):
    name: Literal["sound_event_classification"] = "sound_event_classification"
    prefix: str = "classification"
    metrics: List[ClassificationMetricConfig] = Field(
        default_factory=lambda: [ClassificationAveragePrecisionConfig()]
    )
    plots: List[ClassificationPlotConfig] = Field(default_factory=list)
    include_generics: bool = True


class ClassificationTask(BaseTask[ClipEval]):
    def __init__(
        self,
        *args,
        include_generics: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.include_generics = include_generics

    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        preds = [
            pred for pred in predictions if self.include_prediction(pred, clip)
        ]

        all_gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.include_sound_event_annotation(sound_event, clip)
        ]

        per_class_matches = {}

        for class_name in self.targets.class_names:
            class_idx = self.targets.class_names.index(class_name)

            # Only match to targets of the given class
            gts = [
                sound_event
                for sound_event in all_gts
                if self.is_class(sound_event, class_name)
            ]
            scores = [float(pred.class_scores[class_idx]) for pred in preds]

            matches = []

            for pred_idx, gt_idx, _ in self.matcher(
                ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
                predictions=[pred.geometry for pred in preds],
                scores=scores,
            ):
                gt = gts[gt_idx] if gt_idx is not None else None
                pred = preds[pred_idx] if pred_idx is not None else None

                true_class = (
                    self.targets.encode_class(gt) if gt is not None else None
                )

                score = (
                    float(pred.class_scores[class_idx])
                    if pred is not None
                    else 0
                )

                matches.append(
                    MatchEval(
                        clip=clip,
                        gt=gt,
                        pred=pred,
                        is_prediction=pred is not None,
                        is_ground_truth=gt is not None,
                        is_generic=gt is not None and true_class is None,
                        true_class=true_class,
                        score=score,
                    )
                )

            per_class_matches[class_name] = matches

        return ClipEval(clip=clip, matches=per_class_matches)

    def is_class(
        self,
        sound_event: data.SoundEventAnnotation,
        class_name: str,
    ) -> bool:
        sound_event_class = self.targets.encode_class(sound_event)

        if sound_event_class is None and self.include_generics:
            # Sound events that are generic could be of the given
            # class
            return True

        return sound_event_class == class_name

    @tasks_registry.register(ClassificationTaskConfig)
    @staticmethod
    def from_config(
        config: ClassificationTaskConfig,
|
||||
targets: TargetProtocol,
|
||||
):
|
||||
metrics = [
|
||||
build_classification_metric(metric, targets)
|
||||
for metric in config.metrics
|
||||
]
|
||||
plots = [
|
||||
build_classification_plotter(plot, targets)
|
||||
for plot in config.plots
|
||||
]
|
||||
return ClassificationTask.build(
|
||||
config=config,
|
||||
plots=plots,
|
||||
targets=targets,
|
||||
metrics=metrics,
|
||||
)
|
||||
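For intuition, the per-class `MatchEval` records built above are exactly the (score, is_ground_truth) pairs an average-precision metric consumes. A rough sketch using scikit-learn (an assumption for illustration; the project's own metric lives in `batdetect2.evaluate.metrics.classification`):

```python
# Sketch only: reducing one class's match records to average precision.
from sklearn.metrics import average_precision_score

# Hypothetical (score, matched_ground_truth) pairs for a single class:
matches = [(0.9, True), (0.8, False), (0.6, True), (0.3, False)]
y_score = [score for score, _ in matches]
y_true = [int(is_gt) for _, is_gt in matches]
print(average_precision_score(y_true, y_score))
```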
85 src/batdetect2/evaluate/tasks/clip_classification.py Normal file
@@ -0,0 +1,85 @@
from collections import defaultdict
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.clip_classification import (
    ClipClassificationAveragePrecisionConfig,
    ClipClassificationMetricConfig,
    ClipEval,
    build_clip_metric,
)
from batdetect2.evaluate.plots.clip_classification import (
    ClipClassificationPlotConfig,
    build_clip_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class ClipClassificationTaskConfig(BaseTaskConfig):
    name: Literal["clip_classification"] = "clip_classification"
    prefix: str = "clip_classification"
    metrics: List[ClipClassificationMetricConfig] = Field(
        default_factory=lambda: [
            ClipClassificationAveragePrecisionConfig(),
        ]
    )
    plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)


class ClipClassificationTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gt_classes = set()
        for sound_event in clip_annotation.sound_events:
            if not self.include_sound_event_annotation(sound_event, clip):
                continue

            class_name = self.targets.encode_class(sound_event)

            if class_name is None:
                continue

            gt_classes.add(class_name)

        pred_scores = defaultdict(float)
        for pred in predictions:
            if not self.include_prediction(pred, clip):
                continue

            for class_idx, class_name in enumerate(self.targets.class_names):
                pred_scores[class_name] = max(
                    float(pred.class_scores[class_idx]),
                    pred_scores[class_name],
                )

        return ClipEval(true_classes=gt_classes, class_scores=pred_scores)

    @tasks_registry.register(ClipClassificationTaskConfig)
    @staticmethod
    def from_config(
        config: ClipClassificationTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_clip_metric(metric) for metric in config.metrics]
        plots = [
            build_clip_classification_plotter(plot, targets)
            for plot in config.plots
        ]
        return ClipClassificationTask.build(
            config=config,
            plots=plots,
            metrics=metrics,
            targets=targets,
        )
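The pooling step above reduces detection-level scores to clip-level scores by taking a per-class maximum; a toy sketch with hypothetical class names:

```python
# Toy version of the clip-level pooling above: the clip's score for each
# class is the maximum of that class's scores over all detections.
from collections import defaultdict

detections = [  # hypothetical per-detection class scores
    {"pippip": 0.8, "myomys": 0.1},
    {"pippip": 0.3, "myomys": 0.6},
]

clip_scores = defaultdict(float)
for det in detections:
    for class_name, score in det.items():
        clip_scores[class_name] = max(clip_scores[class_name], score)

print(dict(clip_scores))  # {'pippip': 0.8, 'myomys': 0.6}
```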
76 src/batdetect2/evaluate/tasks/clip_detection.py Normal file
@@ -0,0 +1,76 @@
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.clip_detection import (
    ClipDetectionAveragePrecisionConfig,
    ClipDetectionMetricConfig,
    ClipEval,
    build_clip_metric,
)
from batdetect2.evaluate.plots.clip_detection import (
    ClipDetectionPlotConfig,
    build_clip_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class ClipDetectionTaskConfig(BaseTaskConfig):
    name: Literal["clip_detection"] = "clip_detection"
    prefix: str = "clip_detection"
    metrics: List[ClipDetectionMetricConfig] = Field(
        default_factory=lambda: [
            ClipDetectionAveragePrecisionConfig(),
        ]
    )
    plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)


class ClipDetectionTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gt_det = any(
            self.include_sound_event_annotation(sound_event, clip)
            for sound_event in clip_annotation.sound_events
        )

        pred_score = 0
        for pred in predictions:
            if not self.include_prediction(pred, clip):
                continue

            pred_score = max(pred_score, pred.detection_score)

        return ClipEval(
            gt_det=gt_det,
            score=pred_score,
        )

    @tasks_registry.register(ClipDetectionTaskConfig)
    @staticmethod
    def from_config(
        config: ClipDetectionTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_clip_metric(metric) for metric in config.metrics]
        plots = [
            build_clip_detection_plotter(plot, targets)
            for plot in config.plots
        ]
        return ClipDetectionTask.build(
            config=config,
            metrics=metrics,
            targets=targets,
            plots=plots,
        )
88 src/batdetect2/evaluate/tasks/detection.py Normal file
@@ -0,0 +1,88 @@
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.detection import (
    ClipEval,
    DetectionAveragePrecisionConfig,
    DetectionMetricConfig,
    MatchEval,
    build_detection_metric,
)
from batdetect2.evaluate.plots.detection import (
    DetectionPlotConfig,
    build_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class DetectionTaskConfig(BaseTaskConfig):
    name: Literal["sound_event_detection"] = "sound_event_detection"
    prefix: str = "detection"
    metrics: List[DetectionMetricConfig] = Field(
        default_factory=lambda: [DetectionAveragePrecisionConfig()]
    )
    plots: List[DetectionPlotConfig] = Field(default_factory=list)


class DetectionTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.include_sound_event_annotation(sound_event, clip)
        ]
        preds = [
            pred for pred in predictions if self.include_prediction(pred, clip)
        ]
        scores = [pred.detection_score for pred in preds]

        matches = []
        for pred_idx, gt_idx, _ in self.matcher(
            ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
            predictions=[pred.geometry for pred in preds],
            scores=scores,
        ):
            gt = gts[gt_idx] if gt_idx is not None else None
            pred = preds[pred_idx] if pred_idx is not None else None

            matches.append(
                MatchEval(
                    gt=gt,
                    pred=pred,
                    is_prediction=pred is not None,
                    is_ground_truth=gt is not None,
                    score=pred.detection_score if pred is not None else 0,
                )
            )

        return ClipEval(clip=clip, matches=matches)

    @tasks_registry.register(DetectionTaskConfig)
    @staticmethod
    def from_config(
        config: DetectionTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_detection_metric(metric) for metric in config.metrics]
        plots = [
            build_detection_plotter(plot, targets) for plot in config.plots
        ]
        return DetectionTask.build(
            config=config,
            metrics=metrics,
            targets=targets,
            plots=plots,
        )
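The matcher contract assumed by `evaluate_clip` above: it yields `(pred_idx, gt_idx, affinity)` triples where either index may be `None`, encoding unmatched predictions (false positives) and unmatched ground truths (missed events). A sketch of how the loop interprets the triples:

```python
# Hypothetical matcher output: (pred_idx, gt_idx, affinity) triples.
matches = [(0, 1, 0.9), (1, None, 0.0), (None, 0, 0.0)]

for pred_idx, gt_idx, _ in matches:
    if pred_idx is not None and gt_idx is not None:
        kind = "matched detection"
    elif gt_idx is None:
        kind = "false positive (prediction without ground truth)"
    else:
        kind = "missed event (ground truth without prediction)"
    print(pred_idx, gt_idx, kind)
```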
111 src/batdetect2/evaluate/tasks/top_class.py Normal file
@@ -0,0 +1,111 @@
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.top_class import (
    ClipEval,
    MatchEval,
    TopClassAveragePrecisionConfig,
    TopClassMetricConfig,
    build_top_class_metric,
)
from batdetect2.evaluate.plots.top_class import (
    TopClassPlotConfig,
    build_top_class_plotter,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class TopClassDetectionTaskConfig(BaseTaskConfig):
    name: Literal["top_class_detection"] = "top_class_detection"
    prefix: str = "top_class"
    metrics: List[TopClassMetricConfig] = Field(
        default_factory=lambda: [TopClassAveragePrecisionConfig()]
    )
    plots: List[TopClassPlotConfig] = Field(default_factory=list)


class TopClassDetectionTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.include_sound_event_annotation(sound_event, clip)
        ]
        preds = [
            pred for pred in predictions if self.include_prediction(pred, clip)
        ]
        # Take the highest score for each prediction
        scores = [pred.class_scores.max() for pred in preds]

        matches = []
        for pred_idx, gt_idx, _ in self.matcher(
            ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
            predictions=[pred.geometry for pred in preds],
            scores=scores,
        ):
            gt = gts[gt_idx] if gt_idx is not None else None
            pred = preds[pred_idx] if pred_idx is not None else None

            true_class = (
                self.targets.encode_class(gt) if gt is not None else None
            )

            class_idx = (
                pred.class_scores.argmax() if pred is not None else None
            )

            score = (
                float(pred.class_scores[class_idx]) if pred is not None else 0
            )

            pred_class = (
                self.targets.class_names[class_idx]
                if class_idx is not None
                else None
            )

            matches.append(
                MatchEval(
                    clip=clip,
                    gt=gt,
                    pred=pred,
                    is_ground_truth=gt is not None,
                    is_prediction=pred is not None,
                    true_class=true_class,
                    is_generic=gt is not None and true_class is None,
                    pred_class=pred_class,
                    score=score,
                )
            )

        return ClipEval(clip=clip, matches=matches)

    @tasks_registry.register(TopClassDetectionTaskConfig)
    @staticmethod
    def from_config(
        config: TopClassDetectionTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_top_class_metric(metric) for metric in config.metrics]
        plots = [
            build_top_class_plotter(plot, targets) for plot in config.plots
        ]
        return TopClassDetectionTask.build(
            config=config,
            plots=plots,
            metrics=metrics,
            targets=targets,
        )
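The "top class" reduction above scores each detection by its single highest class probability and labels it with that class; in isolation:

```python
# Standalone sketch of the top-class reduction: argmax over class scores.
import numpy as np

class_names = ["myomys", "pippip", "eptser", "rhifer"]
class_scores = np.array([0.1, 0.7, 0.15, 0.05])  # hypothetical detection

class_idx = int(class_scores.argmax())
print(class_names[class_idx], float(class_scores[class_idx]))  # pippip 0.7
```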
@@ -1,40 +0,0 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol

from soundevent import data

__all__ = [
    "MetricsProtocol",
    "MatchEvaluation",
]


@dataclass
class MatchEvaluation:
    match: data.Match

    gt_det: bool
    gt_class: Optional[str]

    pred_score: float
    pred_class_scores: Dict[str, float]

    @property
    def pred_class(self) -> Optional[str]:
        if not self.pred_class_scores:
            return None

        return max(self.pred_class_scores, key=self.pred_class_scores.get)  # type: ignore

    @property
    def pred_class_score(self) -> float:
        pred_class = self.pred_class

        if pred_class is None:
            return 0

        return self.pred_class_scores[pred_class]


class MetricsProtocol(Protocol):
    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...
10 src/batdetect2/inference/__init__.py Normal file
@@ -0,0 +1,10 @@
from batdetect2.inference.batch import process_file_list, run_batch_inference
from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.config import InferenceConfig

__all__ = [
    "process_file_list",
    "run_batch_inference",
    "InferenceConfig",
    "get_clips_from_files",
]
88 src/batdetect2/inference/batch.py Normal file
@@ -0,0 +1,88 @@
from typing import TYPE_CHECKING, List, Optional, Sequence

from lightning import Trainer
from soundevent import data

from batdetect2.audio.loader import build_audio_loader
from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model
from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.typing.postprocess import BatDetect2Prediction

if TYPE_CHECKING:
    from batdetect2.config import BatDetect2Config
    from batdetect2.typing import (
        AudioLoader,
        PreprocessorProtocol,
        TargetProtocol,
    )


def run_batch_inference(
    model,
    clips: Sequence[data.Clip],
    targets: Optional["TargetProtocol"] = None,
    audio_loader: Optional["AudioLoader"] = None,
    preprocessor: Optional["PreprocessorProtocol"] = None,
    config: Optional["BatDetect2Config"] = None,
    num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
    from batdetect2.config import BatDetect2Config

    config = config or BatDetect2Config()

    audio_loader = audio_loader or build_audio_loader()

    preprocessor = preprocessor or build_preprocessor(
        input_samplerate=audio_loader.samplerate,
    )

    targets = targets or build_targets()

    loader = build_inference_loader(
        clips,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        config=config.inference.loader,
        num_workers=num_workers,
    )

    module = InferenceModule(model)
    trainer = Trainer(enable_checkpointing=False, logger=False)
    outputs = trainer.predict(module, loader)
    return [
        clip_prediction
        for clip_predictions in outputs  # type: ignore
        for clip_prediction in clip_predictions
    ]


def process_file_list(
    model: Model,
    paths: Sequence[data.PathLike],
    config: "BatDetect2Config",
    targets: Optional["TargetProtocol"] = None,
    audio_loader: Optional["AudioLoader"] = None,
    preprocessor: Optional["PreprocessorProtocol"] = None,
    num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
    clip_config = config.inference.clipping
    clips = get_clips_from_files(
        paths,
        duration=clip_config.duration,
        overlap=clip_config.overlap,
        max_empty=clip_config.max_empty,
        discard_empty=clip_config.discard_empty,
    )
    return run_batch_inference(
        model,
        clips,
        targets=targets,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        config=config,
        num_workers=num_workers,
    )
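A hedged usage sketch of the entry point above; the `model` variable and the file path are placeholders, not part of this diff:

```python
# Usage sketch for process_file_list. `model` is assumed to be a trained
# batdetect2.models.Model instance, and the path is a placeholder.
from batdetect2.config import BatDetect2Config
from batdetect2.inference import process_file_list

config = BatDetect2Config()
predictions = process_file_list(
    model,  # placeholder: a trained Model instance
    ["recordings/example.wav"],  # placeholder file list
    config=config,
)
for clip_prediction in predictions:
    print(clip_prediction.clip.start_time, len(clip_prediction.predictions))
```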
75 src/batdetect2/inference/clips.py Normal file
@@ -0,0 +1,75 @@
from typing import List, Sequence
from uuid import uuid5

import numpy as np
from soundevent import data


def get_clips_from_files(
    paths: Sequence[data.PathLike],
    duration: float,
    overlap: float = 0.0,
    max_empty: float = 0.0,
    discard_empty: bool = True,
    compute_hash: bool = False,
) -> List[data.Clip]:
    clips: List[data.Clip] = []

    for path in paths:
        recording = data.Recording.from_file(path, compute_hash=compute_hash)
        clips.extend(
            get_recording_clips(
                recording,
                duration,
                overlap=overlap,
                max_empty=max_empty,
                discard_empty=discard_empty,
            )
        )

    return clips


def get_recording_clips(
    recording: data.Recording,
    duration: float,
    overlap: float = 0.0,
    max_empty: float = 0.0,
    discard_empty: bool = True,
) -> Sequence[data.Clip]:
    start_time = 0
    # Keep the requested clip duration separate from the recording's total
    # duration; shadowing the `duration` parameter with `recording.duration`
    # would collapse the whole recording into a single clip.
    recording_duration = recording.duration
    hop = duration * (1 - overlap)

    num_clips = int(np.ceil(recording_duration / hop))

    if num_clips == 0:
        # This should only happen if the recording's duration is zero,
        # which should never happen in practice, but just in case...
        return []

    clips = []
    for i in range(num_clips):
        start = start_time + i * hop
        end = start + duration

        if end > recording_duration:
            empty_duration = end - recording_duration

            if empty_duration > max_empty and discard_empty:
                # Discard clips that contain too much empty space
                continue

        clips.append(
            data.Clip(
                uuid=uuid5(recording.uuid, f"{start}_{end}"),
                recording=recording,
                start_time=start,
                end_time=end,
            )
        )

    if discard_empty:
        clips = [clip for clip in clips if clip.duration > max_empty]

    return clips
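The windowing arithmetic used by `get_recording_clips` in a few lines: the hop between clip starts is `duration * (1 - overlap)`, and the clip count is the ceiling of the recording length over the hop.

```python
# Windowing sketch: 0.5 s clips with 25% overlap over a 2 s recording.
import numpy as np

recording_duration = 2.0
clip_duration = 0.5
overlap = 0.25

hop = clip_duration * (1 - overlap)  # 0.375 s between clip starts
num_clips = int(np.ceil(recording_duration / hop))  # 6

starts = [round(i * hop, 3) for i in range(num_clips)]
print(num_clips, starts)  # 6 [0.0, 0.375, 0.75, 1.125, 1.5, 1.875]
```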
21 src/batdetect2/inference/config.py Normal file
@@ -0,0 +1,21 @@
from pydantic import Field

from batdetect2.core.configs import BaseConfig
from batdetect2.inference.dataset import InferenceLoaderConfig

__all__ = ["InferenceConfig"]


class ClipingConfig(BaseConfig):
    enabled: bool = True
    duration: float = 0.5
    overlap: float = 0.0
    max_empty: float = 0.0
    discard_empty: bool = True


class InferenceConfig(BaseConfig):
    loader: InferenceLoaderConfig = Field(
        default_factory=InferenceLoaderConfig
    )
    clipping: ClipingConfig = Field(default_factory=ClipingConfig)
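As configured above, inference defaults to 0.5 s non-overlapping windows in batches of 8:

```python
# Quick check of the defaults declared in InferenceConfig above.
from batdetect2.inference.config import InferenceConfig

config = InferenceConfig()
print(config.clipping.duration, config.clipping.overlap)  # 0.5 0.0
print(config.loader.batch_size, config.loader.num_workers)  # 8 0
```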
120 src/batdetect2/inference/dataset.py Normal file
@@ -0,0 +1,120 @@
from typing import List, NamedTuple, Optional, Sequence

import torch
from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader, Dataset

from batdetect2.audio import build_audio_loader
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol

__all__ = [
    "InferenceDataset",
    "build_inference_dataset",
    "build_inference_loader",
]


DEFAULT_INFERENCE_CLIP_DURATION = 0.512


class DatasetItem(NamedTuple):
    spec: torch.Tensor
    idx: torch.Tensor
    start_time: torch.Tensor
    end_time: torch.Tensor


class InferenceDataset(Dataset[DatasetItem]):
    clips: List[data.Clip]

    def __init__(
        self,
        clips: Sequence[data.Clip],
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        audio_dir: Optional[data.PathLike] = None,
    ):
        self.clips = list(clips)
        self.preprocessor = preprocessor
        self.audio_loader = audio_loader
        self.audio_dir = audio_dir

    def __len__(self):
        return len(self.clips)

    def __getitem__(self, idx: int) -> DatasetItem:
        clip = self.clips[idx]
        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
        wav_tensor = torch.tensor(wav).unsqueeze(0)
        spectrogram = self.preprocessor(wav_tensor)
        return DatasetItem(
            spec=spectrogram,
            idx=torch.tensor(idx),
            start_time=torch.tensor(clip.start_time),
            end_time=torch.tensor(clip.end_time),
        )


class InferenceLoaderConfig(BaseConfig):
    num_workers: int = 0
    batch_size: int = 8


def build_inference_loader(
    clips: Sequence[data.Clip],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[InferenceLoaderConfig] = None,
    num_workers: Optional[int] = None,
) -> DataLoader[DatasetItem]:
    logger.info("Building inference data loader...")
    config = config or InferenceLoaderConfig()

    inference_dataset = build_inference_dataset(
        clips,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
    )

    # An explicit num_workers argument takes precedence over the config value.
    num_workers = num_workers or config.num_workers
    return DataLoader(
        inference_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=_collate_fn,
    )


def build_inference_dataset(
    clips: Sequence[data.Clip],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
) -> InferenceDataset:
    if audio_loader is None:
        audio_loader = build_audio_loader()

    if preprocessor is None:
        preprocessor = build_preprocessor()

    return InferenceDataset(
        clips,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
    )


def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
    max_width = max(item.spec.shape[-1] for item in batch)
    return DatasetItem(
        spec=torch.stack(
            [adjust_width(item.spec, max_width) for item in batch]
        ),
        idx=torch.stack([item.idx for item in batch]),
        start_time=torch.stack([item.start_time for item in batch]),
        end_time=torch.stack([item.end_time for item in batch]),
    )
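The `_collate_fn` above stacks variable-width spectrograms by padding each to the widest item in the batch. A self-contained sketch of the same idea (the exact behavior of `adjust_width` is an assumption here):

```python
# Sketch of the padding collate: right-pad each spectrogram along the time
# axis to the batch's maximum width, then stack into one tensor.
import torch
import torch.nn.functional as F


def pad_to_width(spec: torch.Tensor, width: int) -> torch.Tensor:
    # Stand-in for batdetect2.core.arrays.adjust_width (assumed semantics).
    return F.pad(spec, (0, width - spec.shape[-1]))


specs = [torch.rand(1, 128, 90), torch.rand(1, 128, 100)]
max_width = max(s.shape[-1] for s in specs)
batch = torch.stack([pad_to_width(s, max_width) for s in specs])
print(batch.shape)  # torch.Size([2, 1, 128, 100])
```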
52 src/batdetect2/inference/lightning.py Normal file
@@ -0,0 +1,52 @@
from typing import Sequence

from lightning import LightningModule
from torch.utils.data import DataLoader

from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing.postprocess import BatDetect2Prediction


class InferenceModule(LightningModule):
    def __init__(self, model: Model):
        super().__init__()
        self.model = model

    def predict_step(
        self,
        batch: DatasetItem,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> Sequence[BatDetect2Prediction]:
        dataset = self.get_dataset()

        clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]

        outputs = self.model.detector(batch.spec)

        clip_detections = self.model.postprocessor(
            outputs,
            start_times=[clip.start_time for clip in clips],
        )

        predictions = [
            BatDetect2Prediction(
                clip=clip,
                predictions=to_raw_predictions(
                    clip_dets.numpy(),
                    targets=self.model.targets,
                ),
            )
            for clip, clip_dets in zip(clips, clip_detections)
        ]

        return predictions

    def get_dataset(self) -> InferenceDataset:
        dataloaders = self.trainer.predict_dataloaders
        assert isinstance(dataloaders, DataLoader)
        dataset = dataloaders.dataset
        assert isinstance(dataset, InferenceDataset)
        return dataset
314 src/batdetect2/logging.py Normal file
@@ -0,0 +1,314 @@
import io
from collections.abc import Callable
from functools import partial
from pathlib import Path
from typing import (
    Annotated,
    Any,
    Dict,
    Generic,
    Literal,
    Optional,
    Protocol,
    TypeVar,
    Union,
)

import numpy as np
import pandas as pd
from lightning.pytorch.loggers import (
    CSVLogger,
    Logger,
    MLFlowLogger,
    TensorBoardLogger,
)
from loguru import logger
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data

from batdetect2.core.configs import BaseConfig

DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"


class BaseLoggerConfig(BaseConfig):
    log_dir: Path = DEFAULT_LOGS_DIR
    experiment_name: Optional[str] = None
    run_name: Optional[str] = None


class DVCLiveConfig(BaseLoggerConfig):
    name: Literal["dvclive"] = "dvclive"
    prefix: str = ""
    log_model: Union[bool, Literal["all"]] = False
    monitor_system: bool = False


class CSVLoggerConfig(BaseLoggerConfig):
    name: Literal["csv"] = "csv"
    flush_logs_every_n_steps: int = 100


class TensorBoardLoggerConfig(BaseLoggerConfig):
    name: Literal["tensorboard"] = "tensorboard"
    log_graph: bool = False


class MLFlowLoggerConfig(BaseLoggerConfig):
    name: Literal["mlflow"] = "mlflow"
    tracking_uri: Optional[str] = "http://localhost:5000"
    tags: Optional[dict[str, Any]] = None
    log_model: bool = False


LoggerConfig = Annotated[
    Union[
        DVCLiveConfig,
        CSVLoggerConfig,
        TensorBoardLoggerConfig,
        MLFlowLoggerConfig,
    ],
    Field(discriminator="name"),
]


T = TypeVar("T", bound=LoggerConfig, contravariant=True)


class LoggerBuilder(Protocol, Generic[T]):
    def __call__(
        self,
        config: T,
        log_dir: Optional[Path] = None,
        experiment_name: Optional[str] = None,
        run_name: Optional[str] = None,
    ) -> Logger: ...


def create_dvclive_logger(
    config: DVCLiveConfig,
    log_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
) -> Logger:
    try:
        from dvclive.lightning import DVCLiveLogger  # type: ignore
    except ImportError as error:
        raise ValueError(
            "DVCLive is not installed and cannot be used for logging. "
            "Make sure you have it installed by running `pip install dvclive` "
            "or `uv add dvclive`"
        ) from error

    return DVCLiveLogger(
        dir=log_dir if log_dir is not None else config.log_dir,
        run_name=run_name if run_name is not None else config.run_name,
        experiment=experiment_name
        if experiment_name is not None
        else config.experiment_name,
        prefix=config.prefix,
        log_model=config.log_model,
        monitor_system=config.monitor_system,
    )


def create_csv_logger(
    config: CSVLoggerConfig,
    log_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
) -> Logger:
    from lightning.pytorch.loggers import CSVLogger

    if log_dir is None:
        log_dir = Path(config.log_dir)

    if run_name is None:
        run_name = config.run_name

    if experiment_name is None:
        experiment_name = config.experiment_name

    name = run_name

    if run_name is not None and experiment_name is not None:
        name = str(Path(experiment_name) / run_name)

    return CSVLogger(
        save_dir=str(log_dir),
        name=name,
        flush_logs_every_n_steps=config.flush_logs_every_n_steps,
    )


def create_tensorboard_logger(
    config: TensorBoardLoggerConfig,
    log_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
) -> Logger:
    from lightning.pytorch.loggers import TensorBoardLogger

    if log_dir is None:
        log_dir = Path(config.log_dir)

    if run_name is None:
        run_name = config.run_name

    if experiment_name is None:
        experiment_name = config.experiment_name

    name = run_name

    if name is None:
        name = experiment_name

    if run_name is not None and experiment_name is not None:
        name = str(Path(experiment_name) / run_name)

    return TensorBoardLogger(
        save_dir=str(log_dir),
        name=name,
        log_graph=config.log_graph,
    )


def create_mlflow_logger(
    config: MLFlowLoggerConfig,
    log_dir: Optional[data.PathLike] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
) -> Logger:
    try:
        from lightning.pytorch.loggers import MLFlowLogger
    except ImportError as error:
        raise ValueError(
            "MLFlow is not installed and cannot be used for logging. "
            "Make sure you have it installed by running `pip install mlflow` "
            "or `uv add mlflow`"
        ) from error

    if experiment_name is None:
        experiment_name = config.experiment_name or "Default"

    if log_dir is None:
        log_dir = config.log_dir

    return MLFlowLogger(
        experiment_name=experiment_name,
        run_name=run_name if run_name is not None else config.run_name,
        save_dir=str(log_dir),
        tracking_uri=config.tracking_uri,
        tags=config.tags,
        log_model=config.log_model,
    )


LOGGER_FACTORY: Dict[str, LoggerBuilder] = {
    "dvclive": create_dvclive_logger,
    "csv": create_csv_logger,
    "tensorboard": create_tensorboard_logger,
    "mlflow": create_mlflow_logger,
}


def build_logger(
    config: LoggerConfig,
    log_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
) -> Logger:
    logger.opt(lazy=True).debug(
        "Building logger with config: \n{}",
        lambda: config.to_yaml_string(),
    )

    logger_type = config.name
    if logger_type not in LOGGER_FACTORY:
        raise ValueError(f"Unknown logger type: {logger_type}")

    creation_func = LOGGER_FACTORY[logger_type]
    return creation_func(
        config,
        log_dir=log_dir,
        experiment_name=experiment_name,
        run_name=run_name,
    )


PlotLogger = Callable[[str, Figure, int], None]


def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
    if isinstance(logger, TensorBoardLogger):
        return logger.experiment.add_figure

    if isinstance(logger, MLFlowLogger):

        def plot_figure(name, figure, step):
            image = _convert_figure_to_array(figure)
            name = name.replace("/", "_")
            return logger.experiment.log_image(
                logger.run_id,
                image,
                key=name,
                step=step,
            )

        return plot_figure

    if isinstance(logger, CSVLogger):
        return partial(save_figure, dir=Path(logger.log_dir))


TableLogger = Callable[[str, pd.DataFrame, int], None]


def get_table_logger(logger: Logger) -> Optional[TableLogger]:
    if isinstance(logger, TensorBoardLogger):
        return partial(save_table, dir=Path(logger.log_dir))

    if isinstance(logger, MLFlowLogger):

        def log_table(name: str, df: pd.DataFrame, step: int):
            return logger.experiment.log_table(
                logger.run_id,
                data=df,
                artifact_file=f"{name}_step_{step}.json",
            )

        return log_table

    if isinstance(logger, CSVLogger):
        return partial(save_table, dir=Path(logger.log_dir))


def save_table(name: str, df: pd.DataFrame, step: int, dir: Path) -> None:
    path = dir / "tables" / f"{name}_step_{step}.csv"

    if not path.parent.exists():
        path.parent.mkdir(parents=True)

    df.to_csv(path, index=False)


def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
    path = dir / "plots" / f"{name}_step_{step}.png"

    if not path.parent.exists():
        path.parent.mkdir(parents=True)

    fig.savefig(path, transparent=True, bbox_inches="tight")


def _convert_figure_to_array(figure: Figure) -> np.ndarray:
    with io.BytesIO() as buff:
        figure.savefig(buff, format="raw")
        buff.seek(0)
        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
        w, h = figure.canvas.get_width_height()
        im = data.reshape((int(h), int(w), -1))
        return im
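A short usage sketch of the registry above, picking the CSV backend (the paths and names are placeholders):

```python
# Usage sketch for build_logger with the CSV backend defined above.
from pathlib import Path

from batdetect2.logging import CSVLoggerConfig, build_logger

config = CSVLoggerConfig(experiment_name="bd2", run_name="baseline")
csv_logger = build_logger(config, log_dir=Path("outputs/logs"))
```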
@@ -26,15 +26,13 @@ for creating a standard BatDetect2 model instance is the `build_model` function
provided here.
"""

from typing import Optional
from typing import List, Optional

from loguru import logger
import torch

from batdetect2.models.backbones import (
    Backbone,
    BackboneConfig,
    build_backbone,
    load_backbone_config,
)
from batdetect2.models.blocks import (
    ConvConfig,
@@ -48,29 +46,37 @@ from batdetect2.models.bottleneck import (
    BottleneckConfig,
    build_bottleneck,
)
from batdetect2.models.config import (
    BackboneConfig,
    load_backbone_config,
)
from batdetect2.models.decoder import (
    DEFAULT_DECODER_CONFIG,
    DecoderConfig,
    build_decoder,
)
from batdetect2.models.detectors import (
    Detector,
    build_detector,
)
from batdetect2.models.detectors import Detector, build_detector
from batdetect2.models.encoder import (
    DEFAULT_ENCODER_CONFIG,
    EncoderConfig,
    build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import (
    ClipDetectionsTensor,
    PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol

__all__ = [
    "BBoxHead",
    "Backbone",
    "BackboneConfig",
    "BackboneModel",
    "BackboneModel",
    "Bottleneck",
    "BottleneckConfig",
    "ClassifierHead",
@@ -78,65 +84,68 @@ __all__ = [
    "DEFAULT_DECODER_CONFIG",
    "DEFAULT_ENCODER_CONFIG",
    "DecoderConfig",
    "DetectionModel",
    "Detector",
    "DetectorHead",
    "EncoderConfig",
    "FreqCoordConvDownConfig",
    "FreqCoordConvUpConfig",
    "ModelOutput",
    "StandardConvDownConfig",
    "StandardConvUpConfig",
    "build_backbone",
    "build_bottleneck",
    "build_decoder",
    "build_detector",
    "build_encoder",
    "build_model",
    "build_detector",
    "load_backbone_config",
    "Model",
    "build_model",
]


class Model(torch.nn.Module):
    detector: DetectionModel
    preprocessor: PreprocessorProtocol
    postprocessor: PostprocessorProtocol
    targets: TargetProtocol

    def __init__(
        self,
        detector: DetectionModel,
        preprocessor: PreprocessorProtocol,
        postprocessor: PostprocessorProtocol,
        targets: TargetProtocol,
    ):
        super().__init__()
        self.detector = detector
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        self.targets = targets

    def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
        spec = self.preprocessor(wav)
        outputs = self.detector(spec)
        return self.postprocessor(outputs)


def build_model(
    num_classes: int,
    config: Optional[BackboneConfig] = None,
) -> DetectionModel:
    """Build the complete BatDetect2 detection model.

    This high-level factory function constructs the standard BatDetect2 model
    architecture. It first builds the feature extraction backbone (typically an
    encoder-bottleneck-decoder structure) based on the provided
    `BackboneConfig` (or defaults if None), and then attaches the standard
    prediction heads (`DetectorHead`, `ClassifierHead`, `BBoxHead`) using the
    `build_detector` function.

    Parameters
    ----------
    num_classes : int
        The number of specific target classes the model should predict
        (required for the `ClassifierHead`). Must be positive.
    config : BackboneConfig, optional
        Configuration object specifying the architecture of the backbone
        (encoder, bottleneck, decoder). If None, default configurations defined
        within the respective builder functions (`build_encoder`, etc.) will be
        used to construct a default backbone architecture.

    Returns
    -------
    DetectionModel
        An initialized `Detector` model instance.

    Raises
    ------
    ValueError
        If `num_classes` is not positive, or if errors occur during the
        construction of the backbone or detector components (e.g., incompatible
        configurations, invalid parameters).
    """
    targets: Optional[TargetProtocol] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    postprocessor: Optional[PostprocessorProtocol] = None,
):
    config = config or BackboneConfig()
    logger.opt(lazy=True).debug(
        "Building model with config: \n{}",
        lambda: config.to_yaml_string(),
    targets = targets or build_targets()
    preprocessor = preprocessor or build_preprocessor()
    postprocessor = postprocessor or build_postprocessor(
        preprocessor=preprocessor,
    )
    detector = build_detector(
        num_classes=len(targets.class_names),
        config=config,
    )
    return Model(
        detector=detector,
        postprocessor=postprocessor,
        preprocessor=preprocessor,
        targets=targets,
    )
    backbone = build_backbone(config)
    return build_detector(num_classes, backbone)

@@ -18,37 +18,20 @@ automatic padding to handle input sizes not perfectly divisible by the
network's total downsampling factor.
"""

from typing import Optional, Tuple
from typing import Tuple

import torch
import torch.nn.functional as F
from soundevent import data
from torch import nn

from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
    DEFAULT_BOTTLENECK_CONFIG,
    BottleneckConfig,
    build_bottleneck,
)
from batdetect2.models.decoder import (
    DEFAULT_DECODER_CONFIG,
    Decoder,
    DecoderConfig,
    build_decoder,
)
from batdetect2.models.encoder import (
    DEFAULT_ENCODER_CONFIG,
    Encoder,
    EncoderConfig,
    build_encoder,
)
from batdetect2.models.types import BackboneModel
from batdetect2.models.bottleneck import build_bottleneck
from batdetect2.models.config import BackboneConfig
from batdetect2.models.decoder import Decoder, build_decoder
from batdetect2.models.encoder import Encoder, build_encoder
from batdetect2.typing.models import BackboneModel

__all__ = [
    "Backbone",
    "BackboneConfig",
    "load_backbone_config",
    "build_backbone",
]

@@ -161,82 +144,6 @@ class Backbone(BackboneModel):
        return x


class BackboneConfig(BaseConfig):
    """Configuration for the Encoder-Decoder Backbone network.

    Aggregates configurations for the encoder, bottleneck, and decoder
    components, along with defining the input and final output dimensions
    for the complete backbone.

    Attributes
    ----------
    input_height : int, default=128
        Expected height (frequency bins) of the input spectrograms to the
        backbone. Must be positive.
    in_channels : int, default=1
        Expected number of channels in the input spectrograms (e.g., 1 for
        mono). Must be positive.
    encoder : EncoderConfig, optional
        Configuration for the encoder. If None or omitted,
        the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
        encoder module) will be used.
    bottleneck : BottleneckConfig, optional
        Configuration for the bottleneck layer connecting encoder and decoder.
        If None or omitted, the default bottleneck configuration will be used.
    decoder : DecoderConfig, optional
        Configuration for the decoder. If None or omitted,
        the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
        decoder module) will be used.
    out_channels : int, default=32
        Desired number of channels in the final feature map output by the
        backbone. Must be positive.
    """

    input_height: int = 128
    in_channels: int = 1
    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
    out_channels: int = 32


def load_backbone_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> BackboneConfig:
    """Load the backbone configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `BackboneConfig` schema, potentially extracting data from a nested field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        backbone configuration (e.g., "model.backbone"). If None, the entire
        file content is used.

    Returns
    -------
    BackboneConfig
        The loaded and validated backbone configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded config data does not conform to `BackboneConfig`.
    KeyError, TypeError
        If `field` specifies an invalid path.
    """
    return load_config(path, schema=BackboneConfig, field=field)


def build_backbone(config: BackboneConfig) -> BackboneModel:
    """Factory function to build a Backbone from configuration.

@ -34,7 +34,7 @@ import torch.nn.functional as F
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"ConvBlock",
|
||||
@ -55,6 +55,12 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
class SelfAttentionConfig(BaseConfig):
|
||||
name: Literal["SelfAttention"] = "SelfAttention"
|
||||
attention_channels: int
|
||||
temperature: float = 1
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
"""Self-Attention mechanism operating along the time dimension.
|
||||
|
||||
@ -115,6 +121,7 @@ class SelfAttention(nn.Module):
|
||||
# Note, does not encode position information (absolute or relative)
|
||||
self.temperature = temperature
|
||||
self.att_dim = attention_channels
|
||||
|
||||
self.key_fun = nn.Linear(in_channels, attention_channels)
|
||||
self.value_fun = nn.Linear(in_channels, attention_channels)
|
||||
self.query_fun = nn.Linear(in_channels, attention_channels)
|
||||
@ -171,7 +178,7 @@ class SelfAttention(nn.Module):
|
||||
class ConvConfig(BaseConfig):
|
||||
"""Configuration for a basic ConvBlock."""
|
||||
|
||||
block_type: Literal["ConvBlock"] = "ConvBlock"
|
||||
name: Literal["ConvBlock"] = "ConvBlock"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
@ -218,7 +225,7 @@ class ConvBlock(nn.Module):
|
||||
kernel_size=kernel_size,
|
||||
padding=pad_size,
|
||||
)
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
self.batch_norm = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply Conv -> BN -> ReLU.
|
||||
@ -233,7 +240,7 @@ class ConvBlock(nn.Module):
|
||||
torch.Tensor
|
||||
Output tensor, shape `(B, C_out, H, W)`.
|
||||
"""
|
||||
return F.relu_(self.conv_bn(self.conv(x)))
|
||||
return F.relu_(self.batch_norm(self.conv(x)))
|
||||
|
||||
|
||||
class VerticalConv(nn.Module):
|
||||
@ -293,7 +300,7 @@ class VerticalConv(nn.Module):
|
||||
class FreqCoordConvDownConfig(BaseConfig):
|
||||
"""Configuration for a FreqCoordConvDownBlock."""
|
||||
|
||||
block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
|
||||
name: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
@ -357,7 +364,7 @@ class FreqCoordConvDownBlock(nn.Module):
|
||||
padding=pad_size,
|
||||
stride=1,
|
||||
)
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
self.batch_norm = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply CoordF -> Conv -> MaxPool -> BN -> ReLU.
|
||||
@ -376,14 +383,14 @@ class FreqCoordConvDownBlock(nn.Module):
|
||||
freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
|
||||
x = torch.cat((x, freq_info), 1)
|
||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||
x = F.relu(self.conv_bn(x), inplace=True)
|
||||
x = F.relu(self.batch_norm(x), inplace=True)
|
||||
return x
|
||||
|
||||
|
||||
class StandardConvDownConfig(BaseConfig):
|
||||
"""Configuration for a StandardConvDownBlock."""
|
||||
|
||||
block_type: Literal["StandardConvDown"] = "StandardConvDown"
|
||||
name: Literal["StandardConvDown"] = "StandardConvDown"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
@ -431,7 +438,7 @@ class StandardConvDownBlock(nn.Module):
|
||||
padding=pad_size,
|
||||
stride=1,
|
||||
)
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
self.batch_norm = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x):
|
||||
"""Apply Conv -> MaxPool -> BN -> ReLU.
|
||||
@ -447,13 +454,13 @@ class StandardConvDownBlock(nn.Module):
|
||||
Output tensor, shape `(B, C_out, H/2, W/2)`.
|
||||
"""
|
||||
x = F.max_pool2d(self.conv(x), 2, 2)
|
||||
return F.relu(self.conv_bn(x), inplace=True)
|
||||
return F.relu(self.batch_norm(x), inplace=True)
|
||||
|
||||
|
||||
class FreqCoordConvUpConfig(BaseConfig):
|
||||
"""Configuration for a FreqCoordConvUpBlock."""
|
||||
|
||||
block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
|
||||
name: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
|
||||
"""Discriminator field indicating the block type."""
|
||||
|
||||
out_channels: int
|
||||
@ -527,7 +534,7 @@ class FreqCoordConvUpBlock(nn.Module):
|
||||
kernel_size=kernel_size,
|
||||
padding=pad_size,
|
||||
)
|
||||
self.conv_bn = nn.BatchNorm2d(out_channels)
|
||||
self.batch_norm = nn.BatchNorm2d(out_channels)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply Interpolate -> Concat Coords -> Conv -> BN -> ReLU.
|
||||
@ -555,14 +562,14 @@ class FreqCoordConvUpBlock(nn.Module):
|
||||
freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
        op = torch.cat((op, freq_info), 1)
        op = self.conv(op)
        op = F.relu(self.conv_bn(op), inplace=True)
        op = F.relu(self.batch_norm(op), inplace=True)
        return op


class StandardConvUpConfig(BaseConfig):
    """Configuration for a StandardConvUpBlock."""

    block_type: Literal["StandardConvUp"] = "StandardConvUp"
    name: Literal["StandardConvUp"] = "StandardConvUp"
    """Discriminator field indicating the block type."""

    out_channels: int
@ -618,7 +625,7 @@ class StandardConvUpBlock(nn.Module):
            kernel_size=kernel_size,
            padding=pad_size,
        )
        self.conv_bn = nn.BatchNorm2d(out_channels)
        self.batch_norm = nn.BatchNorm2d(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply Interpolate -> Conv -> BN -> ReLU.
@ -643,7 +650,7 @@ class StandardConvUpBlock(nn.Module):
            align_corners=False,
        )
        op = self.conv(op)
        op = F.relu(self.conv_bn(op), inplace=True)
        op = F.relu(self.batch_norm(op), inplace=True)
        return op


@ -654,15 +661,16 @@ LayerConfig = Annotated[
        StandardConvDownConfig,
        FreqCoordConvUpConfig,
        StandardConvUpConfig,
        SelfAttentionConfig,
        "LayerGroupConfig",
    ],
    Field(discriminator="block_type"),
    Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configuration models."""


class LayerGroupConfig(BaseConfig):
    block_type: Literal["LayerGroup"] = "LayerGroup"
    name: Literal["LayerGroup"] = "LayerGroup"
    layers: List[LayerConfig]


@ -678,7 +686,7 @@ def build_layer_from_config(
    parameters derived from the config and the current pipeline state
    (`input_height`, `in_channels`).

    It uses the `block_type` field within the `config` object to determine
    It uses the `name` field within the `config` object to determine
    which block class to instantiate.

    Parameters
@ -690,7 +698,7 @@ def build_layer_from_config(
    config : LayerConfig
        A Pydantic configuration object for the desired block (e.g., an
        instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
        by its `block_type` field.
        by its `name` field.

    Returns
    -------
@ -703,11 +711,11 @@ def build_layer_from_config(
    Raises
    ------
    NotImplementedError
        If the `config.block_type` does not correspond to a known block type.
        If the `config.name` does not correspond to a known block type.
    ValueError
        If parameters derived from the config are invalid for the block.
    """
    if config.block_type == "ConvBlock":
    if config.name == "ConvBlock":
        return (
            ConvBlock(
                in_channels=in_channels,
@ -719,7 +727,7 @@ def build_layer_from_config(
            input_height,
        )

    if config.block_type == "FreqCoordConvDown":
    if config.name == "FreqCoordConvDown":
        return (
            FreqCoordConvDownBlock(
                in_channels=in_channels,
@ -732,7 +740,7 @@ def build_layer_from_config(
            input_height // 2,
        )

    if config.block_type == "StandardConvDown":
    if config.name == "StandardConvDown":
        return (
            StandardConvDownBlock(
                in_channels=in_channels,
@ -744,7 +752,7 @@ def build_layer_from_config(
            input_height // 2,
        )

    if config.block_type == "FreqCoordConvUp":
    if config.name == "FreqCoordConvUp":
        return (
            FreqCoordConvUpBlock(
                in_channels=in_channels,
@ -757,7 +765,7 @@ def build_layer_from_config(
            input_height * 2,
        )

    if config.block_type == "StandardConvUp":
    if config.name == "StandardConvUp":
        return (
            StandardConvUpBlock(
                in_channels=in_channels,
@ -769,7 +777,18 @@ def build_layer_from_config(
            input_height * 2,
        )

    if config.block_type == "LayerGroup":
    if config.name == "SelfAttention":
        return (
            SelfAttention(
                in_channels=in_channels,
                attention_channels=config.attention_channels,
                temperature=config.temperature,
            ),
            config.attention_channels,
            input_height,
        )

    if config.name == "LayerGroup":
        current_channels = in_channels
        current_height = input_height

@ -785,4 +804,4 @@ def build_layer_from_config(

    return nn.Sequential(*blocks), current_channels, current_height

    raise NotImplementedError(f"Unknown block type {config.block_type}")
    raise NotImplementedError(f"Unknown block type {config.name}")
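The hunks above migrate the block registry from a `block_type` discriminator to `name`, both in the Pydantic union and in the `build_layer_from_config` dispatch. For reference, a minimal, self-contained sketch of the discriminated-union pattern in play; the toy `ConvCfg`/`UpCfg` classes are hypothetical stand-ins, not batdetect2's own configs:

```python
# Minimal sketch of a Pydantic discriminated union keyed on a `name` field.
# ConvCfg/UpCfg are illustrative stand-ins for the real block configs.
from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Field


class ConvCfg(BaseModel):
    name: Literal["Conv"] = "Conv"
    out_channels: int


class UpCfg(BaseModel):
    name: Literal["Up"] = "Up"
    out_channels: int


LayerCfg = Annotated[Union[ConvCfg, UpCfg], Field(discriminator="name")]


class Stack(BaseModel):
    layers: List[LayerCfg]


# During validation, the `name` value selects the concrete config class:
stack = Stack.model_validate(
    {
        "layers": [
            {"name": "Conv", "out_channels": 32},
            {"name": "Up", "out_channels": 16},
        ]
    }
)
assert isinstance(stack.layers[0], ConvCfg)
assert isinstance(stack.layers[1], UpCfg)
```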
@ -14,47 +14,26 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
module based on the provided configuration.
"""

from typing import Optional
from typing import Annotated, List, Optional, Union

import torch
from pydantic import Field
from torch import nn

from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import SelfAttention, VerticalConv
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
    SelfAttentionConfig,
    VerticalConv,
    build_layer_from_config,
)

__all__ = [
    "BottleneckConfig",
    "Bottleneck",
    "BottleneckAttn",
    "build_bottleneck",
]


class BottleneckConfig(BaseConfig):
    """Configuration for the bottleneck layer(s).

    Defines the number of channels within the bottleneck and whether to include
    a self-attention mechanism.

    Attributes
    ----------
    channels : int
        The number of output channels produced by the main convolutional layer
        within the bottleneck. This often matches the number of channels coming
        from the last encoder stage, but can be different. Must be positive.
        This also defines the channel dimensions used within the optional
        `SelfAttention` layer.
    self_attention : bool
        If True, includes a `SelfAttention` layer operating on the time
        dimension after an initial `VerticalConv` layer within the bottleneck.
        If False, only the initial `VerticalConv` (and height repetition) is
        performed.
    """

    channels: int
    self_attention: bool


class Bottleneck(nn.Module):
    """Base Bottleneck module for Encoder-Decoder architectures.

@ -99,16 +78,24 @@ class Bottleneck(nn.Module):
        input_height: int,
        in_channels: int,
        out_channels: int,
        bottleneck_channels: Optional[int] = None,
        layers: Optional[List[torch.nn.Module]] = None,
    ) -> None:
        """Initialize the base Bottleneck layer."""
        super().__init__()
        self.in_channels = in_channels
        self.input_height = input_height
        self.out_channels = out_channels
        self.bottleneck_channels = (
            bottleneck_channels
            if bottleneck_channels is not None
            else out_channels
        )
        self.layers = nn.ModuleList(layers or [])

        self.conv_vert = VerticalConv(
            in_channels=in_channels,
            out_channels=out_channels,
            out_channels=self.bottleneck_channels,
            input_height=input_height,
        )

@ -132,73 +119,52 @@ class Bottleneck(nn.Module):
            convolution.
        """
        x = self.conv_vert(x)

        for layer in self.layers:
            x = layer(x)

        return x.repeat([1, 1, self.input_height, 1])


class BottleneckAttn(Bottleneck):
    """Bottleneck module including a Self-Attention layer.
BottleneckLayerConfig = Annotated[
    Union[SelfAttentionConfig,],
    Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""

    Extends the base `Bottleneck` by inserting a `SelfAttention` layer after
    the initial `VerticalConv`. This allows the bottleneck to capture global
    temporal dependencies in the summarized frequency features before passing
    them to the decoder.

    Sequence: VerticalConv -> SelfAttention -> Repeat Height.
class BottleneckConfig(BaseConfig):
    """Configuration for the bottleneck layer(s).

    Parameters
    Defines the number of channels within the bottleneck and whether to include
    a self-attention mechanism.

    Attributes
    ----------
    input_height : int
        Height (frequency bins) of the input tensor from the encoder.
    in_channels : int
        Number of channels in the input tensor from the encoder.
    out_channels : int
        Number of output channels produced by the `VerticalConv` and
        subsequently processed and output by this bottleneck. Also determines
        the input/output channels of the internal `SelfAttention` layer.
    attention : nn.Module
        An initialized `SelfAttention` module instance.

    Raises
    ------
    ValueError
        If `input_height`, `in_channels`, or `out_channels` are not positive.
    channels : int
        The number of output channels produced by the main convolutional layer
        within the bottleneck. This often matches the number of channels coming
        from the last encoder stage, but can be different. Must be positive.
        This also defines the channel dimensions used within the optional
        `SelfAttention` layer.
    self_attention : bool
        If True, includes a `SelfAttention` layer operating on the time
        dimension after an initial `VerticalConv` layer within the bottleneck.
        If False, only the initial `VerticalConv` (and height repetition) is
        performed.
    """

    def __init__(
        self,
        input_height: int,
        in_channels: int,
        out_channels: int,
        attention: nn.Module,
    ) -> None:
        """Initialize the Bottleneck with Self-Attention."""
        super().__init__(input_height, in_channels, out_channels)
        self.attention = attention

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Process input tensor.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor from the encoder bottleneck, shape
            `(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
            `H_in` must match `self.input_height`.

        Returns
        -------
        torch.Tensor
            Output tensor, shape `(B, C_out, H_in, W)`, after applying attention
            and repeating the height dimension.
        """
        x = self.conv_vert(x)
        x = self.attention(x)
        return x.repeat([1, 1, self.input_height, 1])
    channels: int
    layers: List[BottleneckLayerConfig] = Field(
        default_factory=list,
    )


DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
    channels=256,
    self_attention=True,
    layers=[
        SelfAttentionConfig(attention_channels=256),
    ],
)


@ -234,21 +200,25 @@ def build_bottleneck(
    """
    config = config or DEFAULT_BOTTLENECK_CONFIG

    if config.self_attention:
        attention = SelfAttention(
            in_channels=config.channels,
            attention_channels=config.channels,
        )
    current_channels = in_channels
    current_height = input_height

        return BottleneckAttn(
            input_height=input_height,
            in_channels=in_channels,
            out_channels=config.channels,
            attention=attention,
    layers = []

    for layer_config in config.layers:
        layer, current_channels, current_height = build_layer_from_config(
            input_height=current_height,
            in_channels=current_channels,
            config=layer_config,
        )
        assert current_height == input_height, (
            "Bottleneck layers should not change the spectrogram height"
        )
        layers.append(layer)

    return Bottleneck(
        input_height=input_height,
        in_channels=in_channels,
        out_channels=config.channels,
        layers=layers,
    )
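The refactor above replaces the `self_attention` flag and the `BottleneckAttn` subclass with a generic list of layer configs threaded through `build_layer_from_config`. A usage sketch, assuming `build_bottleneck` takes `input_height`, `in_channels`, and `config` as the hunk body suggests (the diff does not show its full signature):

```python
# Sketch only: the build_bottleneck parameter names below are inferred from
# the hunk body, not confirmed by the diff.
from batdetect2.models.blocks import SelfAttentionConfig
from batdetect2.models.bottleneck import BottleneckConfig, build_bottleneck

config = BottleneckConfig(
    channels=256,
    layers=[SelfAttentionConfig(attention_channels=256)],
)

# Each layer config is built via build_layer_from_config; an assertion
# guarantees that no bottleneck layer changes the spectrogram height.
bottleneck = build_bottleneck(
    input_height=8,
    in_channels=256,
    config=config,
)
```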
98  src/batdetect2/models/config.py  Normal file
@ -0,0 +1,98 @@
from typing import Optional

from soundevent import data

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
    DEFAULT_BOTTLENECK_CONFIG,
    BottleneckConfig,
)
from batdetect2.models.decoder import (
    DEFAULT_DECODER_CONFIG,
    DecoderConfig,
)
from batdetect2.models.encoder import (
    DEFAULT_ENCODER_CONFIG,
    EncoderConfig,
)

__all__ = [
    "BackboneConfig",
    "load_backbone_config",
]


class BackboneConfig(BaseConfig):
    """Configuration for the Encoder-Decoder Backbone network.

    Aggregates configurations for the encoder, bottleneck, and decoder
    components, along with defining the input and final output dimensions
    for the complete backbone.

    Attributes
    ----------
    input_height : int, default=128
        Expected height (frequency bins) of the input spectrograms to the
        backbone. Must be positive.
    in_channels : int, default=1
        Expected number of channels in the input spectrograms (e.g., 1 for
        mono). Must be positive.
    encoder : EncoderConfig, optional
        Configuration for the encoder. If None or omitted,
        the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
        encoder module) will be used.
    bottleneck : BottleneckConfig, optional
        Configuration for the bottleneck layer connecting encoder and decoder.
        If None or omitted, the default bottleneck configuration will be used.
    decoder : DecoderConfig, optional
        Configuration for the decoder. If None or omitted,
        the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
        decoder module) will be used.
    out_channels : int, default=32
        Desired number of channels in the final feature map output by the
        backbone. Must be positive.
    """

    input_height: int = 128
    in_channels: int = 1
    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
    out_channels: int = 32


def load_backbone_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> BackboneConfig:
    """Load the backbone configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `BackboneConfig` schema, potentially extracting data from a nested field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        backbone configuration (e.g., "model.backbone"). If None, the entire
        file content is used.

    Returns
    -------
    BackboneConfig
        The loaded and validated backbone configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded config data does not conform to `BackboneConfig`.
    KeyError, TypeError
        If `field` specifies an invalid path.
    """
    return load_config(path, schema=BackboneConfig, field=field)
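A usage sketch for the new `load_backbone_config` helper; the file name and the `model.backbone` nesting are illustrative, not taken from the repo:

```python
from batdetect2.models.config import load_backbone_config

# Reads the YAML file, descends into the nested "model.backbone" section,
# and validates it against the BackboneConfig schema.
config = load_backbone_config("train_config.yaml", field="model.backbone")
print(config.input_height, config.in_channels, config.out_channels)
```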
@ -24,7 +24,7 @@ import torch
from pydantic import Field
from torch import nn

from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
    ConvConfig,
    FreqCoordConvUpConfig,
@ -47,7 +47,7 @@ DecoderLayerConfig = Annotated[
        StandardConvUpConfig,
        LayerGroupConfig,
    ],
    Field(discriminator="block_type"),
    Field(discriminator="name"),
]
"""Type alias for the discriminated union of block configs usable in Decoder."""

@ -63,7 +63,7 @@ class DecoderConfig(BaseConfig):
    layers : List[DecoderLayerConfig]
        An ordered list of configuration objects, each defining one layer or
        block in the decoder sequence. Each item must be a valid block
        config including a `block_type` field and necessary parameters like
        config including a `name` field and necessary parameters like
        `out_channels`. Input channels for each layer are inferred sequentially.
        The list must contain at least one layer.
    """
@ -249,9 +249,9 @@ def build_decoder(
    ------
    ValueError
        If `in_channels` or `input_height` are not positive, or if the layer
        configuration is invalid (e.g., empty list, unknown `block_type`).
        configuration is invalid (e.g., empty list, unknown `name`).
    NotImplementedError
        If `build_layer_from_config` encounters an unknown `block_type`.
        If `build_layer_from_config` encounters an unknown `name`.
    """
    config = config or DEFAULT_DECODER_CONFIG

@ -8,18 +8,25 @@ classifying them.

The primary components are:
- `Detector`: The `torch.nn.Module` subclass representing the complete model.
- `build_detector`: A factory function to conveniently construct a standard
  `Detector` instance given a backbone and the number of target classes.

This module focuses purely on the neural network architecture definition. The
logic for preprocessing inputs and postprocessing/decoding outputs resides in
the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
"""

import torch
from typing import Optional

import torch
from loguru import logger

from batdetect2.models.backbones import BackboneConfig, build_backbone
from batdetect2.models.heads import BBoxHead, ClassifierHead
from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput

__all__ = [
    "Detector",
    "build_detector",
]


class Detector(DetectionModel):
@ -119,36 +126,41 @@ class Detector(DetectionModel):
    )


def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
    """Factory function to build a standard Detector model instance.

    Creates the standard prediction heads (`ClassifierHead`, `DetectorHead`,
    `BBoxHead`) configured appropriately based on the output channels of the
    provided `backbone` and the specified `num_classes`. It then assembles
    these components into a `Detector` model.
def build_detector(
    num_classes: int, config: Optional[BackboneConfig] = None
) -> DetectionModel:
    """Build the complete BatDetect2 detection model.

    Parameters
    ----------
    num_classes : int
        The number of specific target classes for the classification head
        (excluding any implicit background class). Must be positive.
    backbone : BackboneModel
        An initialized feature extraction backbone module instance. The number
        of output channels from this backbone (`backbone.out_channels`) is used
        to configure the input channels for the prediction heads.
        The number of specific target classes the model should predict
        (required for the `ClassifierHead`). Must be positive.
    config : BackboneConfig, optional
        Configuration object specifying the architecture of the backbone
        (encoder, bottleneck, decoder). If None, default configurations defined
        within the respective builder functions (`build_encoder`, etc.) will be
        used to construct a default backbone architecture.

    Returns
    -------
    Detector
    DetectionModel
        An initialized `Detector` model instance.

    Raises
    ------
    ValueError
        If `num_classes` is not positive.
    AttributeError
        If `backbone` does not have the required `out_channels` attribute.
        If `num_classes` is not positive, or if errors occur during the
        construction of the backbone or detector components (e.g., incompatible
        configurations, invalid parameters).
    """
    config = config or BackboneConfig()

    logger.opt(lazy=True).debug(
        "Building model with config: \n{}",
        lambda: config.to_yaml_string(),
    )
    backbone = build_backbone(config=config)
    classifier_head = ClassifierHead(
        num_classes=num_classes,
        in_channels=backbone.out_channels,
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.models.blocks import (
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
@ -49,7 +49,7 @@ EncoderLayerConfig = Annotated[
|
||||
StandardConvDownConfig,
|
||||
LayerGroupConfig,
|
||||
],
|
||||
Field(discriminator="block_type"),
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""Type alias for the discriminated union of block configs usable in Encoder."""
|
||||
|
||||
@ -66,7 +66,7 @@ class EncoderConfig(BaseConfig):
|
||||
An ordered list of configuration objects, each defining one layer or
|
||||
block in the encoder sequence. Each item must be a valid block config
|
||||
(e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
|
||||
`StandardConvDownConfig`) including a `block_type` field and necessary
|
||||
`StandardConvDownConfig`) including a `name` field and necessary
|
||||
parameters like `out_channels`. Input channels for each layer are
|
||||
inferred sequentially. The list must contain at least one layer.
|
||||
"""
|
||||
@ -287,9 +287,9 @@ def build_encoder(
|
||||
------
|
||||
ValueError
|
||||
If `in_channels` or `input_height` are not positive, or if the layer
|
||||
configuration is invalid (e.g., empty list, unknown `block_type`).
|
||||
configuration is invalid (e.g., empty list, unknown `name`).
|
||||
NotImplementedError
|
||||
If `build_layer_from_config` encounters an unknown `block_type`.
|
||||
If `build_layer_from_config` encounters an unknown `name`.
|
||||
"""
|
||||
if in_channels <= 0 or input_height <= 0:
|
||||
raise ValueError("in_channels and input_height must be positive.")
|
||||
|
||||
@ -1,11 +1,16 @@
from batdetect2.plotting.clip_annotations import plot_clip_annotation
from batdetect2.plotting.clip_predictions import plot_clip_prediction
from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.common import plot_spectrogram
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.heatmaps import (
    plot_classification_heatmap,
    plot_detection_heatmap,
)
from batdetect2.plotting.matches import (
    plot_cross_trigger_match,
    plot_false_negative_match,
    plot_false_positive_match,
    plot_matches,
    plot_true_positive_match,
)

@ -13,9 +18,12 @@ __all__ = [
    "plot_clip",
    "plot_clip_annotation",
    "plot_clip_prediction",
    "plot_matches",
    "plot_false_positive_match",
    "plot_true_positive_match",
    "plot_false_negative_match",
    "plot_cross_trigger_match",
    "plot_false_negative_match",
    "plot_false_positive_match",
    "plot_spectrogram",
    "plot_true_positive_match",
    "plot_detection_heatmap",
    "plot_classification_heatmap",
    "plot_match_gallery",
]
@ -4,7 +4,9 @@ from matplotlib.axes import Axes
from soundevent import data, plot

from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.plotting.common import create_ax
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol

__all__ = [
    "plot_clip_annotation",
@ -17,8 +19,6 @@ def plot_clip_annotation(
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    add_colorbar: bool = False,
    add_labels: bool = False,
    add_points: bool = False,
    cmap: str = "gray",
    alpha: float = 1,
@ -31,8 +31,6 @@ def plot_clip_annotation(
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        spec_cmap=cmap,
    )

@ -47,3 +45,29 @@ def plot_clip_annotation(
            facecolor="none" if not fill else None,
        )
    return ax


def plot_anchor_points(
    clip_annotation: data.ClipAnnotation,
    targets: TargetProtocol,
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
    size: int = 1,
    color: str = "red",
    marker: str = "x",
    alpha: float = 1,
) -> Axes:
    ax = create_ax(ax=ax, figsize=figsize)

    positions = []

    for sound_event in clip_annotation.sound_events:
        if not targets.filter(sound_event):
            continue

        position, _ = targets.encode_roi(sound_event)
        positions.append(position)

    X, Y = zip(*positions)
    ax.scatter(X, Y, s=size, c=color, marker=marker, alpha=alpha)
    return ax
@ -8,7 +8,7 @@ from soundevent.plot.geometries import plot_geometry
from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag

from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol

__all__ = [
    "plot_clip_prediction",
@ -21,8 +21,6 @@ def plot_clip_prediction(
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    add_colorbar: bool = False,
    add_labels: bool = False,
    add_legend: bool = False,
    spec_cmap: str = "gray",
    linewidth: float = 1,
@ -34,8 +32,6 @@ def plot_clip_prediction(
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        spec_cmap=spec_cmap,
    )

@ -1,13 +1,14 @@
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import torch
from matplotlib.axes import Axes
from soundevent import data

from batdetect2.preprocess import (
    PreprocessorProtocol,
    get_default_preprocessor,
)
from batdetect2.audio import build_audio_loader
from batdetect2.plotting.common import plot_spectrogram
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol

__all__ = [
    "plot_clip",
@ -16,29 +17,33 @@ __all__ = [

def plot_clip(
    clip: data.Clip,
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    add_colorbar: bool = False,
    add_labels: bool = False,
    spec_cmap: str = "gray",
) -> Axes:
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    if preprocessor is None:
        preprocessor = get_default_preprocessor()
        preprocessor = build_preprocessor()

    spec = preprocessor.preprocess_clip(clip, audio_dir=audio_dir)
    if audio_loader is None:
        audio_loader = build_audio_loader()

    spec.plot(  # type: ignore
    wav = torch.tensor(audio_loader.load_clip(clip, audio_dir=audio_dir))
    spec = preprocessor(wav)

    plot_spectrogram(
        spec,
        start_time=clip.start_time,
        end_time=clip.end_time,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        ax=ax,
        add_colorbar=add_colorbar,
        cmap=spec_cmap,
        add_labels=add_labels,
        vmin=spec.min().item(),
        vmax=spec.max().item(),
    )

    return ax

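The rewritten `plot_clip` now loads audio explicitly through an `AudioLoader` and runs the preprocessor on the waveform, rather than calling `preprocess_clip` on the clip. A usage sketch; the recording path is only an example:

```python
from soundevent import data

from batdetect2.plotting.clips import plot_clip

# Example file; any supported recording works here.
recording = data.Recording.from_file("example_data/audio/recording.wav")
clip = data.Clip(recording=recording, start_time=0.0, end_time=0.5)

# With no loader or preprocessor supplied, plot_clip falls back to
# build_audio_loader() and build_preprocessor().
ax = plot_clip(clip, add_colorbar=True)
```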
@ -1,8 +1,10 @@
"""General plotting utilities."""

from typing import Optional, Tuple
from typing import Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib import axes

__all__ = [
@ -12,11 +14,62 @@ __all__ = [

def create_ax(
    ax: Optional[axes.Axes] = None,
    figsize: Tuple[int, int] = (10, 10),
    figsize: Optional[Tuple[int, int]] = None,
    **kwargs,
) -> axes.Axes:
    """Create a new axis if none is provided"""
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore
        _, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs)  # type: ignore

    return ax  # type: ignore


def plot_spectrogram(
    spec: Union[torch.Tensor, np.ndarray],
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
    min_freq: Optional[float] = None,
    max_freq: Optional[float] = None,
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_colorbar: bool = False,
    colorbar_kwargs: Optional[dict] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    cmap="gray",
) -> axes.Axes:
    if isinstance(spec, torch.Tensor):
        spec = spec.numpy()

    spec = spec.squeeze()

    ax = create_ax(ax=ax, figsize=figsize)

    if start_time is None:
        start_time = 0

    if end_time is None:
        end_time = spec.shape[-1]

    if min_freq is None:
        min_freq = 0

    if max_freq is None:
        max_freq = spec.shape[-2]

    mappable = ax.pcolormesh(
        np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True),
        np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True),
        spec,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
    )

    ax.set_xlim(start_time, end_time)
    ax.set_ylim(min_freq, max_freq)

    if add_colorbar:
        plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {}))

    return ax

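A quick usage sketch for the new `plot_spectrogram` helper; the random array is a placeholder for a real spectrogram:

```python
import matplotlib.pyplot as plt
import numpy as np

from batdetect2.plotting.common import plot_spectrogram

spec = np.random.rand(128, 512)  # (frequency bins, time bins)

ax = plot_spectrogram(
    spec,
    start_time=0.0,
    end_time=0.5,       # seconds
    min_freq=10_000,
    max_freq=120_000,   # Hz
    add_colorbar=True,
)
plt.show()
```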
113  src/batdetect2/plotting/detections.py  Normal file
@ -0,0 +1,113 @@
from typing import Optional

from matplotlib import axes, patches
from soundevent.plot import plot_geometry

from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.plotting.clips import (
    AudioLoader,
    PreprocessorProtocol,
    plot_clip,
)
from batdetect2.plotting.common import create_ax

__all__ = [
    "plot_clip_detections",
]


def plot_clip_detections(
    clip_eval: ClipEval,
    figsize: tuple[int, int] = (10, 10),
    ax: Optional[axes.Axes] = None,
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    threshold: float = 0.2,
    add_legend: bool = True,
    add_title: bool = True,
    fill: bool = False,
    linewidth: float = 1.0,
    gt_color: str = "green",
    gt_linestyle: str = "-",
    true_pred_color: str = "yellow",
    true_pred_linestyle: str = "--",
    false_pred_color: str = "blue",
    false_pred_linestyle: str = "-",
    missed_gt_color: str = "red",
    missed_gt_linestyle: str = "-",
) -> axes.Axes:
    ax = create_ax(figsize=figsize, ax=ax)

    plot_clip(
        clip_eval.clip,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        ax=ax,
    )

    for m in clip_eval.matches:
        is_match = (
            m.pred is not None and m.gt is not None and m.score >= threshold
        )

        if m.pred is not None:
            color = true_pred_color if is_match else false_pred_color
            plot_geometry(
                m.pred.geometry,
                ax=ax,
                add_points=False,
                facecolor="none" if not fill else color,
                alpha=m.pred.detection_score,
                linewidth=linewidth,
                linestyle=true_pred_linestyle
                if is_match
                else missed_gt_linestyle,
                color=color,
            )

        if m.gt is not None:
            color = gt_color if is_match else missed_gt_color
            plot_geometry(
                m.gt.sound_event.geometry,  # type: ignore
                ax=ax,
                add_points=False,
                linewidth=linewidth,
                facecolor="none" if not fill else color,
                linestyle=gt_linestyle if is_match else false_pred_linestyle,
                color=color,
            )

    if add_title:
        ax.set_title(clip_eval.clip.recording.path.name)

    if add_legend:
        ax.legend(
            handles=[
                patches.Patch(
                    label="found GT",
                    edgecolor=gt_color,
                    facecolor="none" if not fill else gt_color,
                    linestyle=gt_linestyle,
                ),
                patches.Patch(
                    label="missed GT",
                    edgecolor=missed_gt_color,
                    facecolor="none" if not fill else missed_gt_color,
                    linestyle=missed_gt_linestyle,
                ),
                patches.Patch(
                    label="true Det",
                    edgecolor=true_pred_color,
                    facecolor="none" if not fill else true_pred_color,
                    linestyle=true_pred_linestyle,
                ),
                patches.Patch(
                    label="false Det",
                    edgecolor=false_pred_color,
                    facecolor="none" if not fill else false_pred_color,
                    linestyle=false_pred_linestyle,
                ),
            ]
        )

    return ax
@ -1,160 +0,0 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List

import matplotlib.pyplot as plt
import pandas as pd

from batdetect2 import plotting
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.preprocess.types import PreprocessorProtocol


@dataclass
class ClassExamples:
    false_positives: List[MatchEvaluation] = field(default_factory=list)
    false_negatives: List[MatchEvaluation] = field(default_factory=list)
    true_positives: List[MatchEvaluation] = field(default_factory=list)
    cross_triggers: List[MatchEvaluation] = field(default_factory=list)


def plot_example_gallery(
    matches: List[MatchEvaluation],
    preprocessor: PreprocessorProtocol,
    n_examples: int = 5,
):
    class_examples = defaultdict(ClassExamples)

    for match in matches:
        gt_class = match.gt_class
        pred_class = match.pred_class

        if pred_class is None:
            class_examples[gt_class].false_negatives.append(match)
            continue

        if gt_class is None:
            class_examples[pred_class].false_positives.append(match)
            continue

        if gt_class != pred_class:
            class_examples[gt_class].cross_triggers.append(match)
            class_examples[pred_class].cross_triggers.append(match)
            continue

        class_examples[gt_class].true_positives.append(match)

    for class_name, examples in class_examples.items():
        true_positives = get_binned_sample(
            examples.true_positives,
            n_examples=n_examples,
        )

        false_positives = get_binned_sample(
            examples.false_positives,
            n_examples=n_examples,
        )

        false_negatives = random.sample(
            examples.false_negatives,
            k=min(n_examples, len(examples.false_negatives)),
        )

        cross_triggers = get_binned_sample(
            examples.cross_triggers,
            n_examples=n_examples,
        )

        fig = plot_class_examples(
            true_positives,
            false_positives,
            false_negatives,
            cross_triggers,
            preprocessor=preprocessor,
            n_examples=n_examples,
        )

        yield class_name, fig

        plt.close(fig)


def plot_class_examples(
    true_positives: List[MatchEvaluation],
    false_positives: List[MatchEvaluation],
    false_negatives: List[MatchEvaluation],
    cross_triggers: List[MatchEvaluation],
    preprocessor: PreprocessorProtocol,
    n_examples: int = 5,
    duration: float = 0.1,
):
    fig = plt.figure(figsize=(20, 20))

    for index, match in enumerate(true_positives[:n_examples]):
        ax = plt.subplot(4, n_examples, index + 1)
        try:
            plotting.plot_true_positive_match(
                match,
                ax=ax,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError):
            continue

    for index, match in enumerate(false_positives[:n_examples]):
        ax = plt.subplot(4, n_examples, n_examples + index + 1)
        try:
            plotting.plot_false_positive_match(
                match,
                ax=ax,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError):
            continue

    for index, match in enumerate(false_negatives[:n_examples]):
        ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
        try:
            plotting.plot_false_negative_match(
                match,
                ax=ax,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError):
            continue

    for index, match in enumerate(cross_triggers[:n_examples]):
        ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
        try:
            plotting.plot_cross_trigger_match(
                match,
                ax=ax,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError):
            continue

    return fig


def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
    if len(matches) < n_examples:
        return matches

    indices, pred_scores = zip(
        *[
            (index, match.pred_class_scores[pred_class])
            for index, match in enumerate(matches)
            if (pred_class := match.pred_class) is not None
        ]
    )

    bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
    df = pd.DataFrame({"indices": indices, "bins": bins})
    sample = df.groupby("bins").apply(lambda x: x.sample(1))
    return [matches[ind] for ind in sample["indices"]]
109  src/batdetect2/plotting/gallery.py  Normal file
@ -0,0 +1,109 @@
from typing import Optional, Sequence

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from batdetect2.plotting.matches import (
    MatchProtocol,
    plot_cross_trigger_match,
    plot_false_negative_match,
    plot_false_positive_match,
    plot_true_positive_match,
)
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol

__all__ = ["plot_match_gallery"]


def plot_match_gallery(
    true_positives: Sequence[MatchProtocol],
    false_positives: Sequence[MatchProtocol],
    false_negatives: Sequence[MatchProtocol],
    cross_triggers: Sequence[MatchProtocol],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    n_examples: int = 5,
    duration: float = 0.1,
    fig: Optional[Figure] = None,
):
    if fig is None:
        fig = plt.figure(figsize=(20, 20))

    axes = fig.subplots(
        nrows=4,
        ncols=n_examples,
        sharex="none",
        sharey="row",
    )

    for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]):
        try:
            plot_true_positive_match(
                tp_match,
                ax=tp_ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (
            ValueError,
            AssertionError,
            RuntimeError,
            FileNotFoundError,
        ):
            continue

    for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]):
        try:
            plot_false_positive_match(
                fp_match,
                ax=fp_ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (
            ValueError,
            AssertionError,
            RuntimeError,
            FileNotFoundError,
        ):
            continue

    for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]):
        try:
            plot_false_negative_match(
                fn_match,
                ax=fn_ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (
            ValueError,
            AssertionError,
            RuntimeError,
            FileNotFoundError,
        ):
            continue

    for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]):
        try:
            plot_cross_trigger_match(
                ct_match,
                ax=ct_ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (
            ValueError,
            AssertionError,
            RuntimeError,
            FileNotFoundError,
        ):
            continue

    fig.tight_layout()

    return fig
@ -1,26 +1,117 @@
"""Plot heatmaps"""

from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union

import xarray as xr
from matplotlib import axes
import numpy as np
import torch
from matplotlib import axes, patches
from matplotlib.cm import get_cmap
from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba

from batdetect2.plotting.common import create_ax


def plot_heatmap(
    heatmap: xr.DataArray,
def plot_detection_heatmap(
    heatmap: Union[torch.Tensor, np.ndarray],
    ax: Optional[axes.Axes] = None,
    figsize: Tuple[int, int] = (10, 10),
    threshold: Optional[float] = None,
    alpha: float = 1,
    cmap: Union[str, Colormap] = "jet",
    color: Optional[str] = None,
) -> axes.Axes:
    ax = create_ax(ax, figsize=figsize)

    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.numpy()

    heatmap = heatmap.squeeze()

    if threshold is not None:
        heatmap = np.ma.masked_where(
            heatmap < threshold,
            heatmap,
        )

    if color is not None:
        cmap = create_colormap(color)

    ax.pcolormesh(
        heatmap.time,
        heatmap.frequency,
        heatmap,
        vmax=1,
        vmin=0,
        cmap=cmap,
        alpha=alpha,
    )

    return ax


def plot_classification_heatmap(
    heatmap: Union[torch.Tensor, np.ndarray],
    ax: Optional[axes.Axes] = None,
    figsize: Tuple[int, int] = (10, 10),
    class_names: Optional[List[str]] = None,
    threshold: Optional[float] = 0.1,
    alpha: float = 1,
    cmap: Union[str, Colormap] = "tab20",
):
    ax = create_ax(ax, figsize=figsize)

    if isinstance(heatmap, torch.Tensor):
        heatmap = heatmap.numpy()

    if heatmap.ndim == 4:
        heatmap = heatmap[0]

    if heatmap.ndim != 3:
        raise ValueError("Expecting a 3-dimensional array")

    num_classes = heatmap.shape[0]

    if class_names is None:
        class_names = [f"class_{i}" for i in range(num_classes)]

    if len(class_names) != num_classes:
        raise ValueError("Inconsistent number of class names")

    if not isinstance(cmap, Colormap):
        cmap = get_cmap(cmap)

    handles = []

    for index, class_heatmap in enumerate(heatmap):
        class_name = class_names[index]

        color = cmap(index / num_classes)

        max = class_heatmap.max()

        if max == 0:
            continue

        if threshold is not None:
            class_heatmap = np.ma.masked_where(
                class_heatmap < threshold,
                class_heatmap,
            )

        ax.pcolormesh(
            class_heatmap,
            vmax=1,
            vmin=0,
            cmap=create_colormap(color),  # type: ignore
            alpha=alpha,
        )

        handles.append(patches.Patch(color=color, label=class_name))

    ax.legend(handles=handles)
    return ax


def create_colormap(color: str) -> Colormap:
    (r, g, b, a) = to_rgba(color)
    return LinearSegmentedColormap.from_list(
        "cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
    )

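A usage sketch for the new heatmap helpers, assuming the numpy code path of `plot_detection_heatmap` (in the hunk above, the `heatmap.time`/`heatmap.frequency` arguments appear to be leftovers of the removed xarray version). The tensors and class names below are placeholders:

```python
import matplotlib.pyplot as plt
import torch

from batdetect2.plotting.heatmaps import (
    plot_classification_heatmap,
    plot_detection_heatmap,
)

detection = torch.rand(1, 1, 128, 512)  # (batch, 1, freq, time)
classes = torch.rand(1, 4, 128, 512)    # (batch, num_classes, freq, time)

# Detection scores rendered with a single-color transparent-to-opaque ramp
# built by create_colormap:
ax = plot_detection_heatmap(detection, threshold=0.3, color="red")

# Per-class heatmaps overlaid on the same axes, one color per class:
plot_classification_heatmap(
    classes,
    ax=ax,
    class_names=["sp1", "sp2", "sp3", "sp4"],
    threshold=0.2,
)
plt.show()
```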
@ -1,21 +1,17 @@
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Optional, Protocol, Tuple, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.axes import Axes
|
||||
from soundevent import data, plot
|
||||
from soundevent.geometry import compute_bounds
|
||||
from soundevent.plot.tags import TagColorMapper
|
||||
|
||||
from batdetect2.evaluate.types import MatchEvaluation
|
||||
from batdetect2.plotting.clip_predictions import plot_prediction
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.preprocess import (
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
PreprocessorProtocol,
|
||||
get_default_preprocessor,
|
||||
RawPrediction,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"plot_matches",
|
||||
"plot_false_positive_match",
|
||||
"plot_true_positive_match",
|
||||
"plot_false_negative_match",
|
||||
@ -23,6 +19,14 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
class MatchProtocol(Protocol):
|
||||
clip: data.Clip
|
||||
gt: Optional[data.SoundEventAnnotation]
|
||||
pred: Optional[RawPrediction]
|
||||
score: float
|
||||
true_class: Optional[str]
|
||||
|
||||
|
||||
DEFAULT_DURATION = 0.05
|
||||
DEFAULT_FALSE_POSITIVE_COLOR = "orange"
|
||||
DEFAULT_FALSE_NEGATIVE_COLOR = "red"
|
||||
@ -32,278 +36,191 @@ DEFAULT_ANNOTATION_LINE_STYLE = "-"
|
||||
DEFAULT_PREDICTION_LINE_STYLE = "--"
|
||||
|
||||
|
||||
def plot_matches(
|
||||
matches: List[data.Match],
|
||||
clip: data.Clip,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
color_mapper: Optional[TagColorMapper] = None,
|
||||
add_colorbar: bool = False,
|
||||
add_labels: bool = False,
|
||||
add_points: bool = False,
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
||||
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
||||
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||
) -> Axes:
|
||||
if preprocessor is None:
|
||||
preprocessor = get_default_preprocessor()
|
||||
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
ax=ax,
|
||||
figsize=figsize,
|
||||
audio_dir=audio_dir,
|
||||
add_colorbar=add_colorbar,
|
||||
add_labels=add_labels,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
|
||||
if color_mapper is None:
|
||||
color_mapper = TagColorMapper()
|
||||
|
||||
for match in matches:
|
||||
if match.source is None and match.target is not None:
|
||||
plot.plot_annotation(
|
||||
annotation=match.target,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=false_negative_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=annotation_linestyle,
|
||||
)
|
||||
elif match.target is None and match.source is not None:
|
||||
plot_prediction(
|
||||
prediction=match.source,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=false_positive_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=prediction_linestyle,
|
||||
)
|
||||
elif match.target is not None and match.source is not None:
|
||||
plot.plot_annotation(
|
||||
annotation=match.target,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=true_positive_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=annotation_linestyle,
|
||||
)
|
||||
plot_prediction(
|
||||
prediction=match.source,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=true_positive_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=prediction_linestyle,
|
||||
)
|
||||
else:
|
||||
continue
|
||||
|
||||
return ax
|
||||
|
||||
|
||||
def plot_false_positive_match(
|
||||
match: MatchEvaluation,
|
||||
match: MatchProtocol,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
add_colorbar: bool = False,
|
||||
add_labels: bool = False,
|
||||
use_score: bool = True,
|
||||
add_spectrogram: bool = True,
|
||||
add_text: bool = True,
|
||||
add_points: bool = False,
|
||||
add_title: bool = True,
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
time_offset: float = 0,
|
||||
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
||||
fontsize: Union[float, str] = "small",
|
||||
) -> Axes:
|
||||
assert match.match.source is not None
|
||||
assert match.match.target is None
|
||||
sound_event = match.match.source.sound_event
|
||||
geometry = sound_event.geometry
|
||||
assert geometry is not None
|
||||
assert match.pred is not None
|
||||
|
||||
start_time, _, _, high_freq = compute_bounds(geometry)
|
||||
start_time, _, _, high_freq = compute_bounds(match.pred.geometry)
|
||||
|
||||
clip = data.Clip(
|
||||
start_time=max(start_time - duration / 2, 0),
|
||||
start_time=max(
|
||||
start_time - duration / 2,
|
||||
0,
|
||||
),
|
||||
end_time=min(
|
||||
start_time + duration / 2,
|
||||
sound_event.recording.duration,
|
||||
match.clip.recording.duration,
|
||||
),
|
||||
recording=sound_event.recording,
|
||||
recording=match.clip.recording,
|
||||
)
|
||||
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
add_colorbar=add_colorbar,
|
||||
add_labels=add_labels,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
if add_spectrogram:
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
|
||||
plot_prediction(
|
||||
match.match.source,
|
||||
ax = plot.plot_geometry(
|
||||
match.pred.geometry,
|
||||
ax=ax,
|
||||
time_offset=time_offset,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
alpha=1,
|
||||
alpha=match.score if use_score else 1,
|
||||
color=color,
|
||||
)
|
||||
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
if add_text:
|
||||
ax.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"score={match.score:.2f}",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
|
||||
if add_title:
|
||||
ax.set_title("False Positive")
|
||||
|
||||
return ax
|
||||
|
||||
|
||||
def plot_false_negative_match(
|
||||
match: MatchEvaluation,
|
||||
match: MatchProtocol,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
add_colorbar: bool = False,
|
||||
add_labels: bool = False,
|
||||
add_spectrogram: bool = True,
|
||||
add_points: bool = False,
|
||||
add_title: bool = True,
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
||||
fontsize: Union[float, str] = "small",
|
||||
) -> Axes:
|
||||
assert match.match.source is None
|
||||
assert match.match.target is not None
|
||||
sound_event = match.match.target.sound_event
|
||||
geometry = sound_event.geometry
|
||||
assert match.gt is not None
|
||||
|
||||
geometry = match.gt.sound_event.geometry
|
||||
assert geometry is not None
|
||||
|
||||
start_time, _, _, high_freq = compute_bounds(geometry)
|
||||
start_time = compute_bounds(geometry)[0]
|
||||
|
||||
clip = data.Clip(
|
||||
start_time=max(start_time - duration / 2, 0),
|
||||
end_time=min(
|
||||
start_time + duration / 2, sound_event.recording.duration
|
||||
start_time=max(
|
||||
start_time - duration / 2,
|
||||
0,
|
||||
),
|
||||
recording=sound_event.recording,
|
||||
end_time=min(
|
||||
start_time + duration / 2,
|
||||
match.clip.recording.duration,
|
||||
),
|
||||
recording=match.clip.recording,
|
||||
)
|
||||
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
add_colorbar=add_colorbar,
|
||||
add_labels=add_labels,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
if add_spectrogram:
|
||||
ax = plot_clip(
|
||||
clip,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
figsize=figsize,
|
||||
ax=ax,
|
||||
audio_dir=audio_dir,
|
||||
spec_cmap=spec_cmap,
|
||||
)
|
||||
|
||||
plot.plot_annotation(
|
||||
match.match.target,
|
||||
ax = plot.plot_geometry(
|
||||
geometry,
|
||||
ax=ax,
|
||||
time_offset=0.001,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
alpha=1,
|
||||
color=color,
|
||||
)
|
||||
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"False Negative \nClass: {match.gt_class} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
fontsize=fontsize,
|
||||
)
|
||||
if add_title:
|
||||
ax.set_title("False Negative")
|
||||
|
||||
return ax
|
||||
|
||||
|
||||
def plot_true_positive_match(
|
||||
match: MatchEvaluation,
|
||||
match: MatchProtocol,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
ax: Optional[Axes] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
duration: float = DEFAULT_DURATION,
|
||||
add_colorbar: bool = False,
|
||||
add_labels: bool = False,
|
||||
use_score: bool = True,
|
||||
add_spectrogram: bool = True,
|
||||
add_points: bool = False,
|
||||
add_text: bool = True,
|
||||
fill: bool = False,
|
||||
spec_cmap: str = "gray",
|
||||
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
||||
fontsize: Union[float, str] = "small",
|
||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||
    add_title: bool = True,
) -> Axes:
    assert match.match.source is not None
    assert match.match.target is not None
    sound_event = match.match.target.sound_event
    geometry = sound_event.geometry
    assert match.gt is not None
    assert match.pred is not None

    geometry = match.gt.sound_event.geometry
    assert geometry is not None

    start_time, _, _, high_freq = compute_bounds(geometry)

    clip = data.Clip(
        start_time=max(start_time - duration / 2, 0),
        end_time=min(
            start_time + duration / 2, sound_event.recording.duration
        start_time=max(
            start_time - duration / 2,
            0,
        ),
        recording=sound_event.recording,
        end_time=min(
            start_time + duration / 2,
            match.clip.recording.duration,
        ),
        recording=match.clip.recording,
    )

    ax = plot_clip(
        clip,
        preprocessor=preprocessor,
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        spec_cmap=spec_cmap,
    )
    if add_spectrogram:
        ax = plot_clip(
            clip,
            ax=ax,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            figsize=figsize,
            audio_dir=audio_dir,
            spec_cmap=spec_cmap,
        )

    plot.plot_annotation(
        match.match.target,
    ax = plot.plot_geometry(
        geometry,
        ax=ax,
        time_offset=0.001,
        freq_offset=2_000,
        add_points=add_points,
        facecolor="none" if not fill else None,
        alpha=1,
@@ -311,41 +228,46 @@ def plot_true_positive_match(
        linestyle=annotation_linestyle,
    )

    plot_prediction(
        match.match.source,
    plot.plot_geometry(
        match.pred.geometry,
        ax=ax,
        time_offset=0.001,
        freq_offset=2_000,
        add_points=add_points,
        facecolor="none" if not fill else None,
        alpha=1,
        alpha=match.score if use_score else 1,
        color=color,
        linestyle=prediction_linestyle,
    )

    plt.text(
        start_time,
        high_freq,
        f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
        va="top",
        ha="right",
        color=color,
        fontsize=fontsize,
    )
    if add_text:
        ax.text(
            start_time,
            high_freq,
            f"score={match.score:.2f}",
            va="top",
            ha="right",
            color=color,
            fontsize=fontsize,
        )

    if add_title:
        ax.set_title("True Positive")

    return ax


def plot_cross_trigger_match(
    match: MatchEvaluation,
    match: MatchProtocol,
    preprocessor: Optional[PreprocessorProtocol] = None,
    audio_loader: Optional[AudioLoader] = None,
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    duration: float = DEFAULT_DURATION,
    add_colorbar: bool = False,
    add_labels: bool = False,
    use_score: bool = True,
    add_spectrogram: bool = True,
    add_points: bool = False,
    add_text: bool = True,
    add_title: bool = True,
    fill: bool = False,
    spec_cmap: str = "gray",
    color: str = DEFAULT_CROSS_TRIGGER_COLOR,
@@ -353,38 +275,40 @@ def plot_cross_trigger_match(
    annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
    prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
    assert match.match.source is not None
    assert match.match.target is not None
    sound_event = match.match.source.sound_event
    geometry = sound_event.geometry
    assert match.gt is not None
    assert match.pred is not None

    geometry = match.gt.sound_event.geometry
    assert geometry is not None

    start_time, _, _, high_freq = compute_bounds(geometry)

    clip = data.Clip(
        start_time=max(start_time - duration / 2, 0),
        end_time=min(
            start_time + duration / 2, sound_event.recording.duration
        start_time=max(
            start_time - duration / 2,
            0,
        ),
        recording=sound_event.recording,
        end_time=min(
            start_time + duration / 2,
            match.clip.recording.duration,
        ),
        recording=match.clip.recording,
    )

    ax = plot_clip(
        clip,
        preprocessor=preprocessor,
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        add_colorbar=add_colorbar,
        add_labels=add_labels,
        spec_cmap=spec_cmap,
    )
    if add_spectrogram:
        ax = plot_clip(
            clip,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            figsize=figsize,
            ax=ax,
            audio_dir=audio_dir,
            spec_cmap=spec_cmap,
        )

    plot.plot_annotation(
        match.match.target,
    ax = plot.plot_geometry(
        geometry,
        ax=ax,
        time_offset=0.001,
        freq_offset=2_000,
        add_points=add_points,
        facecolor="none" if not fill else None,
        alpha=1,
@@ -392,26 +316,29 @@ def plot_cross_trigger_match(
        linestyle=annotation_linestyle,
    )

    plot_prediction(
        match.match.source,
    ax = plot.plot_geometry(
        match.pred.geometry,
        ax=ax,
        time_offset=0.001,
        freq_offset=2_000,
        add_points=add_points,
        facecolor="none" if not fill else None,
        alpha=1,
        alpha=match.score if use_score else 1,
        color=color,
        linestyle=prediction_linestyle,
    )

    plt.text(
        start_time,
        high_freq,
        f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
        va="top",
        ha="right",
        color=color,
        fontsize=fontsize,
    )
    if add_text:
        ax.text(
            start_time,
            high_freq,
            f"score={match.score:.2f}\nclass={match.true_class}",
            va="top",
            ha="right",
            color=color,
            fontsize=fontsize,
        )

    if add_title:
        ax.set_title("Cross Trigger")

    return ax
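For orientation, a minimal sketch of how one of these match-plotting helpers might be invoked. The module path `batdetect2.plotting.matches` and the `match` argument are assumptions (only the function bodies appear in this diff); `match` should implement `MatchProtocol`, providing `.gt`, `.pred`, `.clip`, `.score`, and `.true_class`.

```python
import matplotlib.pyplot as plt

# Hypothetical import path -- the diff does not show the module name.
from batdetect2.plotting.matches import plot_cross_trigger_match


def show_match(match) -> None:
    # Render a single cross-trigger match on a fresh Axes.
    fig, ax = plt.subplots(figsize=(10, 4))
    plot_cross_trigger_match(
        match,
        ax=ax,
        duration=0.5,    # seconds of spectrogram context around the call
        add_text=True,   # annotate with the match score and true class
        use_score=True,  # fade the prediction outline by its score
    )
    plt.show()
```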
src/batdetect2/plotting/metrics.py (new file, 286 lines)
@@ -0,0 +1,286 @@
from typing import Dict, Optional, Tuple

import numpy as np
import seaborn as sns
from cycler import cycler
from matplotlib import axes

from batdetect2.plotting.common import create_ax


def set_default_styler(ax: axes.Axes) -> axes.Axes:
    color_cycler = cycler(color=sns.color_palette("muted"))
    style_cycler = cycler(linestyle=["-", "--", ":"]) * cycler(
        marker=["o", "s", "^"]
    )
    custom_cycler = color_cycler * len(style_cycler) + style_cycler * len(
        color_cycler
    )

    ax.set_prop_cycle(custom_cycler)
    return ax


def set_default_style(ax: axes.Axes) -> axes.Axes:
    ax = set_default_styler(ax)
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)
    return ax


def plot_pr_curve(
    precision: np.ndarray,
    recall: np.ndarray,
    thresholds: np.ndarray,
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_labels: bool = True,
    add_legend: bool = False,
    label: str = "PR Curve",
) -> axes.Axes:
    ax = create_ax(ax=ax, figsize=figsize)

    ax = set_default_style(ax)

    ax.plot(
        recall,
        precision,
        label=label,
        marker="o",
        markevery=_get_marker_positions(thresholds),
    )

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_legend:
        ax.legend()

    if add_labels:
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")

    return ax


def plot_pr_curves(
    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_legend: bool = True,
    add_labels: bool = True,
) -> axes.Axes:
    ax = create_ax(ax=ax, figsize=figsize)
    ax = set_default_style(ax)

    for name, (precision, recall, thresholds) in data.items():
        ax.plot(
            recall,
            precision,
            label=name,
            markevery=_get_marker_positions(thresholds),
        )

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")

    if add_legend:
        ax.legend(
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            borderaxespad=0.0,
        )
    return ax


def plot_threshold_precision_curve(
    threshold: np.ndarray,
    precision: np.ndarray,
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_labels: bool = True,
):
    ax = create_ax(ax=ax, figsize=figsize)

    ax = set_default_style(ax)

    ax.plot(threshold, precision, markevery=_get_marker_positions(threshold))

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Precision")

    return ax


def plot_threshold_precision_curves(
    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_legend: bool = True,
    add_labels: bool = True,
):
    ax = create_ax(ax=ax, figsize=figsize)
    ax = set_default_style(ax)

    for name, (precision, _, thresholds) in data.items():
        ax.plot(
            thresholds,
            precision,
            label=name,
            markevery=_get_marker_positions(thresholds),
        )

    if add_legend:
        ax.legend(
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            borderaxespad=0.0,
        )

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Precision")

    return ax


def plot_threshold_recall_curve(
    threshold: np.ndarray,
    recall: np.ndarray,
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_labels: bool = True,
):
    ax = create_ax(ax=ax, figsize=figsize)

    ax = set_default_style(ax)

    ax.plot(threshold, recall, markevery=_get_marker_positions(threshold))

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Recall")

    return ax


def plot_threshold_recall_curves(
    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_legend: bool = True,
    add_labels: bool = True,
):
    ax = create_ax(ax=ax, figsize=figsize)
    ax = set_default_style(ax)

    for name, (_, recall, thresholds) in data.items():
        ax.plot(
            thresholds,
            recall,
            label=name,
            markevery=_get_marker_positions(thresholds),
        )

    if add_legend:
        ax.legend(
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            borderaxespad=0.0,
        )

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Recall")

    return ax


def plot_roc_curve(
    fpr: np.ndarray,
    tpr: np.ndarray,
    thresholds: np.ndarray,
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_labels: bool = True,
) -> axes.Axes:
    ax = create_ax(ax=ax, figsize=figsize)

    ax = set_default_style(ax)

    ax.plot(
        fpr,
        tpr,
        markevery=_get_marker_positions(thresholds),
    )

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")

    return ax


def plot_roc_curves(
    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_legend: bool = True,
    add_labels: bool = True,
) -> axes.Axes:
    ax = create_ax(ax=ax, figsize=figsize)
    ax = set_default_style(ax)

    for name, (fpr, tpr, thresholds) in data.items():
        ax.plot(
            fpr,
            tpr,
            label=name,
            markevery=_get_marker_positions(thresholds),
        )

    if add_legend:
        ax.legend(
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            borderaxespad=0.0,
        )

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")

    return ax


def _get_marker_positions(
    thresholds: np.ndarray,
    n_points: int = 11,
) -> np.ndarray:
    size = len(thresholds)
    cut_points = np.linspace(0, 1, n_points)
    indices = np.searchsorted(thresholds[::-1], cut_points)
    return np.clip(size - indices, 0, size - 1)  # type: ignore
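A quick sketch of how these curve helpers compose, using synthetic score arrays (the precision/recall shapes below are made up; only the `(precision, recall, thresholds)` tuple layout matters):

```python
import numpy as np
import matplotlib.pyplot as plt

from batdetect2.plotting.metrics import plot_pr_curves

# Synthetic PR sweep: thresholds descend from 1 to 0 as recall grows,
# matching the ordering _get_marker_positions expects.
thresholds = np.linspace(1, 0, 50)
recall = np.linspace(0, 1, 50)
precision_a = 1 - 0.3 * recall        # toy monotone trade-off
precision_b = 1 - 0.5 * recall**2

ax = plot_pr_curves(
    {
        "class_a": (precision_a, recall, thresholds),
        "class_b": (precision_b, recall, thresholds),
    }
)
plt.show()
```

The `markevery` indices returned by `_get_marker_positions` place one marker roughly every 0.1 of threshold, so curves remain comparable even when they are sampled at different densities.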
@@ -1,598 +1,25 @@
"""Main entry point for the BatDetect2 Postprocessing pipeline.
|
||||
"""Main entry point for the BatDetect2 Postprocessing pipeline."""
|
||||
|
||||
This package (`batdetect2.postprocess`) takes the raw outputs from a trained
|
||||
BatDetect2 neural network model and transforms them into meaningful, structured
|
||||
predictions, typically in the form of `soundevent.data.ClipPrediction` objects
|
||||
containing detected sound events with associated class tags and geometry.
|
||||
|
||||
The pipeline involves several configurable steps, implemented in submodules:
|
||||
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
|
||||
2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
|
||||
model output arrays.
|
||||
3. Detection Extraction (`.detection`): Identifies candidate detection points
|
||||
(location and score) based on thresholds and score ranking (top-k).
|
||||
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
|
||||
class probabilities, features) at the detected locations.
|
||||
5. Decoding & Formatting (`.decoding`): Converts extracted numerical data and
|
||||
class predictions into interpretable `soundevent` objects, including
|
||||
recovering geometry (ROIs) and decoding class names back to standard tags.
|
||||
|
||||
This module provides the primary interface:
|
||||
- `PostprocessConfig`: A configuration object for postprocessing parameters
|
||||
(thresholds, NMS kernel size, etc.).
|
||||
- `load_postprocess_config`: Function to load the configuration from a file.
|
||||
- `Postprocessor`: The main class (implementing `PostprocessorProtocol`) that
|
||||
holds the configured pipeline logic.
|
||||
- `build_postprocessor`: A factory function to create a `Postprocessor`
|
||||
instance, linking it to the necessary target definitions (`TargetProtocol`).
|
||||
It also re-exports key components from submodules for convenience.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import xarray as xr
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.postprocess.config import (
|
||||
PostprocessConfig,
|
||||
load_postprocess_config,
|
||||
)
|
||||
from batdetect2.postprocess.decoding import (
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
convert_raw_prediction_to_sound_event_prediction,
|
||||
convert_raw_predictions_to_clip_prediction,
|
||||
convert_xr_dataset_to_raw_prediction,
|
||||
to_raw_predictions,
|
||||
)
|
||||
from batdetect2.postprocess.detection import (
|
||||
DEFAULT_DETECTION_THRESHOLD,
|
||||
TOP_K_PER_SEC,
|
||||
extract_detections_from_array,
|
||||
get_max_detections,
|
||||
from batdetect2.postprocess.nms import non_max_suppression
|
||||
from batdetect2.postprocess.postprocessor import (
|
||||
Postprocessor,
|
||||
build_postprocessor,
|
||||
)
|
||||
from batdetect2.postprocess.extraction import (
|
||||
extract_detection_xr_dataset,
|
||||
)
|
||||
from batdetect2.postprocess.nms import (
|
||||
NMS_KERNEL_SIZE,
|
||||
non_max_suppression,
|
||||
)
|
||||
from batdetect2.postprocess.remapping import (
|
||||
classification_to_xarray,
|
||||
detection_to_xarray,
|
||||
features_to_xarray,
|
||||
sizes_to_xarray,
|
||||
)
|
||||
from batdetect2.postprocess.types import (
|
||||
BatDetect2Prediction,
|
||||
PostprocessorProtocol,
|
||||
RawPrediction,
|
||||
)
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||
"DEFAULT_DETECTION_THRESHOLD",
|
||||
"MAX_FREQ",
|
||||
"MIN_FREQ",
|
||||
"ModelOutput",
|
||||
"NMS_KERNEL_SIZE",
|
||||
"PostprocessConfig",
|
||||
"Postprocessor",
|
||||
"PostprocessorProtocol",
|
||||
"RawPrediction",
|
||||
"TOP_K_PER_SEC",
|
||||
"build_postprocessor",
|
||||
"classification_to_xarray",
|
||||
"convert_raw_predictions_to_clip_prediction",
|
||||
"convert_xr_dataset_to_raw_prediction",
|
||||
"detection_to_xarray",
|
||||
"extract_detection_xr_dataset",
|
||||
"extract_detections_from_array",
|
||||
"features_to_xarray",
|
||||
"get_max_detections",
|
||||
"to_raw_predictions",
|
||||
"load_postprocess_config",
|
||||
"non_max_suppression",
|
||||
"sizes_to_xarray",
|
||||
]
|
||||
|
||||
|
||||
class PostprocessConfig(BaseConfig):
|
||||
"""Configuration settings for the postprocessing pipeline.
|
||||
|
||||
Defines tunable parameters that control how raw model outputs are
|
||||
converted into final detections.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
nms_kernel_size : int, default=NMS_KERNEL_SIZE
|
||||
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
|
||||
Used to suppress weaker detections near stronger peaks. Must be
|
||||
positive.
|
||||
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
|
||||
Minimum confidence score from the detection heatmap required to
|
||||
consider a point as a potential detection. Must be >= 0.
|
||||
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
|
||||
Minimum confidence score for a specific class prediction to be included
|
||||
in the decoded tags for a detection. Must be >= 0.
|
||||
top_k_per_sec : int, default=TOP_K_PER_SEC
|
||||
Desired maximum number of detections per second of audio. Used by
|
||||
`get_max_detections` to calculate an absolute limit based on clip
|
||||
duration before applying `extract_detections_from_array`. Must be
|
||||
positive.
|
||||
"""
|
||||
|
||||
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
||||
detection_threshold: float = Field(
|
||||
default=DEFAULT_DETECTION_THRESHOLD,
|
||||
ge=0,
|
||||
)
|
||||
classification_threshold: float = Field(
|
||||
default=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
ge=0,
|
||||
)
|
||||
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
||||
|
||||
|
||||
def load_postprocess_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> PostprocessConfig:
|
||||
"""Load the postprocessing configuration from a file.
|
||||
|
||||
Reads a configuration file (YAML) and validates it against the
|
||||
`PostprocessConfig` schema, potentially extracting data from a nested
|
||||
field.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : PathLike
|
||||
Path to the configuration file.
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing the
|
||||
postprocessing configuration (e.g., "inference.postprocessing").
|
||||
If None, the entire file content is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PostprocessConfig
|
||||
The loaded and validated postprocessing configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
yaml.YAMLError
|
||||
If the file content is not valid YAML.
|
||||
pydantic.ValidationError
|
||||
If the loaded configuration data does not conform to the
|
||||
`PostprocessConfig` schema.
|
||||
KeyError, TypeError
|
||||
If `field` specifies an invalid path within the loaded data.
|
||||
"""
|
||||
return load_config(path, schema=PostprocessConfig, field=field)
|
||||
|
||||
|
||||
def build_postprocessor(
|
||||
targets: TargetProtocol,
|
||||
config: Optional[PostprocessConfig] = None,
|
||||
max_freq: float = MAX_FREQ,
|
||||
min_freq: float = MIN_FREQ,
|
||||
) -> PostprocessorProtocol:
|
||||
"""Factory function to build the standard postprocessor.
|
||||
|
||||
Creates and initializes the `Postprocessor` instance, providing it with the
|
||||
necessary `targets` object and the `PostprocessConfig`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
An initialized object conforming to the `TargetProtocol`, providing
|
||||
methods like `.decode()` and `.recover_roi()`, and attributes like
|
||||
`.class_names` and `.generic_class_tags`. This links postprocessing
|
||||
to the defined target semantics and geometry mappings.
|
||||
config : PostprocessConfig, optional
|
||||
Configuration object specifying postprocessing parameters (thresholds,
|
||||
NMS kernel size, etc.). If None, default settings defined in
|
||||
`PostprocessConfig` will be used.
|
||||
min_freq : int, default=MIN_FREQ
|
||||
The minimum frequency (Hz) corresponding to the frequency axis of the
|
||||
model outputs. Required for coordinate remapping. Consider setting via
|
||||
`PostprocessConfig` instead for better encapsulation.
|
||||
max_freq : int, default=MAX_FREQ
|
||||
The maximum frequency (Hz) corresponding to the frequency axis of the
|
||||
model outputs. Required for coordinate remapping. Consider setting via
|
||||
`PostprocessConfig`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
PostprocessorProtocol
|
||||
An initialized `Postprocessor` instance ready to process model outputs.
|
||||
"""
|
||||
config = config or PostprocessConfig()
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building postprocessor with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
)
|
||||
return Postprocessor(
|
||||
targets=targets,
|
||||
config=config,
|
||||
min_freq=min_freq,
|
||||
max_freq=max_freq,
|
||||
)
|
||||
|
||||
|
||||
class Postprocessor(PostprocessorProtocol):
|
||||
"""Standard implementation of the postprocessing pipeline.
|
||||
|
||||
This class orchestrates the steps required to convert raw model outputs
|
||||
into interpretable `soundevent` predictions. It uses configured parameters
|
||||
and leverages functions from the `batdetect2.postprocess` submodules for
|
||||
each stage (NMS, remapping, detection, extraction, decoding).
|
||||
|
||||
It requires a `TargetProtocol` object during initialization to access
|
||||
necessary decoding information (class name to tag mapping,
|
||||
ROI recovery logic) ensuring consistency with the target definitions used
|
||||
during training or specified for inference.
|
||||
|
||||
Instances are typically created using the `build_postprocessor` factory
|
||||
function.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
The configured target definition object providing decoding and ROI
|
||||
recovery.
|
||||
config : PostprocessConfig
|
||||
Configuration object holding parameters for NMS, thresholds, etc.
|
||||
min_freq : float
|
||||
Minimum frequency (Hz) assumed for the model output's frequency axis.
|
||||
max_freq : float
|
||||
Maximum frequency (Hz) assumed for the model output's frequency axis.
|
||||
"""
|
||||
|
||||
targets: TargetProtocol
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
targets: TargetProtocol,
|
||||
config: PostprocessConfig,
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
):
|
||||
"""Initialize the Postprocessor.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
targets : TargetProtocol
|
||||
Initialized target definition object.
|
||||
config : PostprocessConfig
|
||||
Configuration for postprocessing parameters.
|
||||
min_freq : int, default=MIN_FREQ
|
||||
Minimum frequency (Hz) for coordinate remapping.
|
||||
max_freq : int, default=MAX_FREQ
|
||||
Maximum frequency (Hz) for coordinate remapping.
|
||||
"""
|
||||
self.targets = targets
|
||||
self.config = config
|
||||
self.min_freq = min_freq
|
||||
self.max_freq = max_freq
|
||||
|
||||
def get_feature_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Extract and remap raw feature tensors for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.features` tensor for the batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of coordinate-aware feature DataArrays, one per clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.features` and `clips` do not match.
|
||||
"""
|
||||
if len(clips) != len(output.features):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of feature array"
|
||||
"do not match. "
|
||||
f"(clips: {len(clips)}, features: {len(output.features)})"
|
||||
)
|
||||
|
||||
return [
|
||||
features_to_xarray(
|
||||
feats,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for feats, clip in zip(output.features, clips)
|
||||
]
|
||||
|
||||
def get_detection_arrays(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[xr.DataArray]:
|
||||
"""Apply NMS and remap detection heatmaps for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.detection_probs` tensor for the
|
||||
batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of NMS-applied, coordinate-aware detection heatmaps, one per
|
||||
clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.detection_probs` and `clips` do not match.
|
||||
"""
|
||||
detections = output.detection_probs
|
||||
|
||||
if len(clips) != len(output.detection_probs):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of detection array "
|
||||
"do not match. "
|
||||
f"(clips: {len(clips)}, detection: {len(detections)})"
|
||||
)
|
||||
|
||||
detections = non_max_suppression(
|
||||
detections,
|
||||
kernel_size=self.config.nms_kernel_size,
|
||||
)
|
||||
|
||||
return [
|
||||
detection_to_xarray(
|
||||
dets,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for dets, clip in zip(detections, clips)
|
||||
]
|
||||
|
||||
def get_classification_arrays(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[xr.DataArray]:
|
||||
"""Extract and remap raw classification tensors for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.class_probs` tensor for the
|
||||
batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of coordinate-aware class probability maps, one per clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.class_probs` and `clips` do not match, or
|
||||
if number of classes mismatches `self.targets.class_names`.
|
||||
"""
|
||||
classifications = output.class_probs
|
||||
|
||||
if len(clips) != len(classifications):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of classification array "
|
||||
"do not match. "
|
||||
f"(clips: {len(clips)}, classification: {len(classifications)})"
|
||||
)
|
||||
|
||||
return [
|
||||
classification_to_xarray(
|
||||
class_probs,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
class_names=self.targets.class_names,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for class_probs, clip in zip(classifications, clips)
|
||||
]
|
||||
|
||||
def get_sizes_arrays(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[xr.DataArray]:
|
||||
"""Extract and remap raw size prediction tensors for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw model output containing `output.size_preds` tensor for the
|
||||
batch.
|
||||
clips : List[data.Clip]
|
||||
List of Clip objects corresponding to the batch items.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.DataArray]
|
||||
List of coordinate-aware size prediction maps, one per clip.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If batch sizes of `output.size_preds` and `clips` do not match.
|
||||
"""
|
||||
sizes = output.size_preds
|
||||
|
||||
if len(clips) != len(sizes):
|
||||
raise ValueError(
|
||||
"Number of clips and batch size of sizes array do not match. "
|
||||
f"(clips: {len(clips)}, sizes: {len(sizes)})"
|
||||
)
|
||||
|
||||
return [
|
||||
sizes_to_xarray(
|
||||
size_preds,
|
||||
start_time=clip.start_time,
|
||||
end_time=clip.end_time,
|
||||
min_freq=self.min_freq,
|
||||
max_freq=self.max_freq,
|
||||
)
|
||||
for size_preds, clip in zip(sizes, clips)
|
||||
]
|
||||
|
||||
def get_detection_datasets(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[xr.Dataset]:
|
||||
"""Perform NMS, remapping, detection, and data extraction for a batch.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
List of `soundevent.data.Clip` objects corresponding to the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[xr.Dataset]
|
||||
List of xarray Datasets (one per clip). Each Dataset contains
|
||||
aligned scores, dimensions, class probabilities, and features for
|
||||
detections found in that clip.
|
||||
"""
|
||||
detection_arrays = self.get_detection_arrays(output, clips)
|
||||
classification_arrays = self.get_classification_arrays(output, clips)
|
||||
size_arrays = self.get_sizes_arrays(output, clips)
|
||||
features_arrays = self.get_feature_arrays(output, clips)
|
||||
|
||||
datasets = []
|
||||
for det_array, class_array, sizes_array, feats_array in zip(
|
||||
detection_arrays,
|
||||
classification_arrays,
|
||||
size_arrays,
|
||||
features_arrays,
|
||||
):
|
||||
max_detections = get_max_detections(
|
||||
det_array,
|
||||
top_k_per_sec=self.config.top_k_per_sec,
|
||||
)
|
||||
|
||||
positions = extract_detections_from_array(
|
||||
det_array,
|
||||
max_detections=max_detections,
|
||||
threshold=self.config.detection_threshold,
|
||||
)
|
||||
|
||||
datasets.append(
|
||||
extract_detection_xr_dataset(
|
||||
positions,
|
||||
sizes_array,
|
||||
class_array,
|
||||
feats_array,
|
||||
)
|
||||
)
|
||||
|
||||
return datasets
|
||||
|
||||
def get_raw_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[List[RawPrediction]]:
|
||||
"""Extract intermediate RawPrediction objects for a batch.
|
||||
|
||||
Processes raw model output through remapping, NMS, detection, data
|
||||
extraction, and geometry recovery via the configured
|
||||
`targets.recover_roi`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
List of `soundevent.data.Clip` objects corresponding to the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[List[RawPrediction]]
|
||||
List of lists (one inner list per input clip). Each inner list
|
||||
contains `RawPrediction` objects for detections in that clip.
|
||||
"""
|
||||
detection_datasets = self.get_detection_datasets(output, clips)
|
||||
return [
|
||||
convert_xr_dataset_to_raw_prediction(
|
||||
dataset,
|
||||
self.targets.decode_roi,
|
||||
)
|
||||
for dataset in detection_datasets
|
||||
]
|
||||
|
||||
def get_sound_event_predictions(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[List[BatDetect2Prediction]]:
|
||||
raw_predictions = self.get_raw_predictions(output, clips)
|
||||
return [
|
||||
[
|
||||
BatDetect2Prediction(
|
||||
raw=raw,
|
||||
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
|
||||
raw,
|
||||
recording=clip.recording,
|
||||
targets=self.targets,
|
||||
classification_threshold=self.config.classification_threshold,
|
||||
),
|
||||
)
|
||||
for raw in predictions
|
||||
]
|
||||
for predictions, clip in zip(raw_predictions, clips)
|
||||
]
|
||||
|
||||
def get_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
) -> List[data.ClipPrediction]:
|
||||
"""Perform the full postprocessing pipeline for a batch.
|
||||
|
||||
Takes raw model output and corresponding clips, applies the entire
|
||||
configured chain (NMS, remapping, extraction, geometry recovery, class
|
||||
decoding), producing final `soundevent.data.ClipPrediction` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
output : ModelOutput
|
||||
Raw output from the neural network model for a batch.
|
||||
clips : List[data.Clip]
|
||||
List of `soundevent.data.Clip` objects corresponding to the batch.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.ClipPrediction]
|
||||
List containing one `ClipPrediction` object for each input clip,
|
||||
populated with `SoundEventPrediction` objects.
|
||||
"""
|
||||
raw_predictions = self.get_raw_predictions(output, clips)
|
||||
return [
|
||||
convert_raw_predictions_to_clip_prediction(
|
||||
prediction,
|
||||
clip,
|
||||
targets=self.targets,
|
||||
classification_threshold=self.config.classification_threshold,
|
||||
)
|
||||
for prediction, clip in zip(raw_predictions, clips)
|
||||
]
|
||||
|
||||
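Putting the factory and the pipeline entry point together, a minimal usage sketch (the `targets` and `model_output` arguments are assumed to be produced upstream by the targets and model packages; they are not constructed here):

```python
from typing import List

from soundevent import data

from batdetect2.postprocess import PostprocessConfig, build_postprocessor


def postprocess_batch(targets, model_output, clips: List[data.Clip]):
    # `targets` must implement TargetProtocol; `model_output` is a
    # ModelOutput batch from the network (both assumed to exist upstream).
    config = PostprocessConfig(detection_threshold=0.1, top_k_per_sec=100)
    postprocessor = build_postprocessor(targets, config=config)

    # One ClipPrediction per input clip, with decoded sound events inside.
    return postprocessor.get_predictions(model_output, clips)
```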
src/batdetect2/postprocess/config.py (new file, 94 lines)
@@ -0,0 +1,94 @@
from typing import Optional

from pydantic import Field
from soundevent import data

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE

__all__ = [
    "PostprocessConfig",
    "load_postprocess_config",
]

DEFAULT_DETECTION_THRESHOLD = 0.01

TOP_K_PER_SEC = 100


class PostprocessConfig(BaseConfig):
    """Configuration settings for the postprocessing pipeline.

    Defines tunable parameters that control how raw model outputs are
    converted into final detections.

    Attributes
    ----------
    nms_kernel_size : int, default=NMS_KERNEL_SIZE
        Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
        Used to suppress weaker detections near stronger peaks. Must be
        positive.
    detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
        Minimum confidence score from the detection heatmap required to
        consider a point as a potential detection. Must be >= 0.
    classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
        Minimum confidence score for a specific class prediction to be
        included in the decoded tags for a detection. Must be >= 0.
    top_k_per_sec : int, default=TOP_K_PER_SEC
        Desired maximum number of detections per second of audio. Used by
        `get_max_detections` to calculate an absolute limit based on clip
        duration before applying `extract_detections_from_array`. Must be
        positive.
    """

    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
    detection_threshold: float = Field(
        default=DEFAULT_DETECTION_THRESHOLD,
        ge=0,
    )
    classification_threshold: float = Field(
        default=DEFAULT_CLASSIFICATION_THRESHOLD,
        ge=0,
    )
    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)


def load_postprocess_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> PostprocessConfig:
    """Load the postprocessing configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `PostprocessConfig` schema, potentially extracting data from a nested
    field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        postprocessing configuration (e.g., "inference.postprocessing").
        If None, the entire file content is used.

    Returns
    -------
    PostprocessConfig
        The loaded and validated postprocessing configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded configuration data does not conform to the
        `PostprocessConfig` schema.
    KeyError, TypeError
        If `field` specifies an invalid path within the loaded data.
    """
    return load_config(path, schema=PostprocessConfig, field=field)
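To make the expected file shape concrete, a small sketch that writes a matching YAML snippet and loads it back. The file name and the `postprocessing:` nesting are illustrative choices, not conventions fixed by the repository:

```python
from pathlib import Path

from batdetect2.postprocess.config import load_postprocess_config

# Hypothetical config file; only the keys defined on PostprocessConfig matter.
cfg_path = Path("postprocess.yaml")
cfg_path.write_text(
    "postprocessing:\n"
    "  nms_kernel_size: 9\n"
    "  detection_threshold: 0.01\n"
    "  classification_threshold: 0.1\n"
    "  top_k_per_sec: 100\n"
)

# `field` selects the nested section holding the postprocess settings.
config = load_postprocess_config(cfg_path, field="postprocessing")
print(config.nms_kernel_size, config.top_k_per_sec)
```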
@@ -1,42 +1,18 @@
"""Decodes extracted detection data into standard soundevent predictions.
|
||||
|
||||
This module handles the final stages of the BatDetect2 postprocessing pipeline.
|
||||
It takes the structured detection data extracted by the `extraction` module
|
||||
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
|
||||
class probabilities, and features for each detection point) and converts it
|
||||
into standardized prediction objects based on the `soundevent` data model.
|
||||
|
||||
The process involves:
|
||||
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
|
||||
objects, using a configured geometry builder to recover bounding boxes from
|
||||
predicted positions and sizes (`convert_xr_dataset_to_raw_prediction`).
|
||||
2. Converting each `RawPrediction` into a
|
||||
`soundevent.data.SoundEventPrediction`, which involves:
|
||||
- Creating the `soundevent.data.SoundEvent` with geometry and features.
|
||||
- Decoding the predicted class probabilities into representative tags using
|
||||
a configured class decoder (`SoundEventDecoder`).
|
||||
- Applying a classification threshold.
|
||||
- Optionally selecting only the single highest-scoring class (top-1) or
|
||||
including tags for all classes above the threshold (multi-label).
|
||||
- Adding generic class tags as a baseline.
|
||||
- Associating scores with the final prediction and tags.
|
||||
(`convert_raw_prediction_to_sound_event_prediction`)
|
||||
3. Grouping the `SoundEventPrediction` objects for a given audio segment into
|
||||
a `soundevent.data.ClipPrediction`
|
||||
(`convert_raw_predictions_to_clip_prediction`).
|
||||
"""
|
||||
"""Decodes extracted detection data into standard soundevent predictions."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.typing.postprocess import (
|
||||
ClipDetectionsArray,
|
||||
RawPrediction,
|
||||
)
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"convert_xr_dataset_to_raw_prediction",
|
||||
"to_raw_predictions",
|
||||
"convert_raw_predictions_to_clip_prediction",
|
||||
"convert_raw_prediction_to_sound_event_prediction",
|
||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||
@ -51,65 +27,29 @@ decoding.
|
||||
"""
|
||||
|
||||
|
||||
def convert_xr_dataset_to_raw_prediction(
|
||||
detection_dataset: xr.Dataset,
|
||||
geometry_decoder: GeometryDecoder,
|
||||
def to_raw_predictions(
|
||||
detections: ClipDetectionsArray,
|
||||
targets: TargetProtocol,
|
||||
) -> List[RawPrediction]:
|
||||
"""Convert an xarray.Dataset of detections to RawPrediction objects.
|
||||
|
||||
Takes the output of the extraction step (`extract_detection_xr_dataset`)
|
||||
and transforms each detection entry into an intermediate `RawPrediction`
|
||||
object. This involves recovering the geometry (e.g., bounding box) from
|
||||
the predicted position and scaled size dimensions using the provided
|
||||
`geometry_builder` function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_dataset : xr.Dataset
|
||||
An xarray Dataset containing aligned detection information, typically
|
||||
output by `extract_detection_xr_dataset`. Expected variables include
|
||||
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
|
||||
Must have a 'detection' dimension.
|
||||
geometry_decoder : GeometryDecoder
|
||||
A function that takes a position tuple `(time, freq)` and a NumPy array
|
||||
of dimensions, and returns the corresponding reconstructed
|
||||
`soundevent.data.Geometry`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[RawPrediction]
|
||||
A list of `RawPrediction` objects, each containing the detection score,
|
||||
recovered bounding box coordinates (start/end time, low/high freq),
|
||||
the vector of class scores, and the feature vector for one detection.
|
||||
|
||||
Raises
|
||||
------
|
||||
AttributeError, KeyError, ValueError
|
||||
If `detection_dataset` is missing expected variables ('scores',
|
||||
'dimensions', 'classes', 'features') or coordinates ('time', 'freq'
|
||||
associated with 'scores'), or if `geometry_builder` fails.
|
||||
"""
|
||||
detections = []
|
||||
|
||||
categories = detection_dataset.category.values
|
||||
predictions = []
|
||||
|
||||
for score, class_scores, time, freq, dims, feats in zip(
|
||||
detection_dataset["scores"].values,
|
||||
detection_dataset["classes"].values,
|
||||
detection_dataset["time"].values,
|
||||
detection_dataset["frequency"].values,
|
||||
detection_dataset["dimensions"].values,
|
||||
detection_dataset["features"].values,
|
||||
detections.scores,
|
||||
detections.class_scores,
|
||||
detections.times,
|
||||
detections.frequencies,
|
||||
detections.sizes,
|
||||
detections.features,
|
||||
):
|
||||
highest_scoring_class = categories[class_scores.argmax()]
|
||||
highest_scoring_class = targets.class_names[class_scores.argmax()]
|
||||
|
||||
geom = geometry_decoder(
|
||||
geom = targets.decode_roi(
|
||||
(time, freq),
|
||||
dims,
|
||||
class_name=highest_scoring_class,
|
||||
)
|
||||
|
||||
detections.append(
|
||||
predictions.append(
|
||||
RawPrediction(
|
||||
detection_score=score,
|
||||
geometry=geom,
|
||||
@ -118,7 +58,7 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
)
|
||||
)
|
||||
|
||||
return detections
|
||||
return predictions
|
||||
|
||||
|
||||
def convert_raw_predictions_to_clip_prediction(
|
||||
@ -128,35 +68,7 @@ def convert_raw_predictions_to_clip_prediction(
|
||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only: bool = False,
|
||||
) -> data.ClipPrediction:
|
||||
"""Convert a list of RawPredictions into a soundevent ClipPrediction.
|
||||
|
||||
Iterates through `raw_predictions` (assumed to belong to a single clip),
|
||||
converts each one into a `soundevent.data.SoundEventPrediction` using
|
||||
`convert_raw_prediction_to_sound_event_prediction`, and packages them
|
||||
into a `soundevent.data.ClipPrediction` associated with the original `clip`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw_predictions : List[RawPrediction]
|
||||
List of raw prediction objects for a single clip.
|
||||
clip : data.Clip
|
||||
The original `soundevent.data.Clip` object these predictions belong to.
|
||||
sound_event_decoder : SoundEventDecoder
|
||||
Function to decode class names into representative tags.
|
||||
generic_class_tags : List[data.Tag]
|
||||
List of tags representing the generic class category.
|
||||
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
|
||||
Threshold applied to class scores during decoding.
|
||||
top_class_only : bool, default=False
|
||||
If True, only decode tags for the single highest-scoring class above
|
||||
the threshold. If False, decode tags for all classes above threshold.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.ClipPrediction
|
||||
A `ClipPrediction` object containing a list of `SoundEventPrediction`
|
||||
objects corresponding to the input `raw_predictions`.
|
||||
"""
|
||||
"""Convert a list of RawPredictions into a soundevent ClipPrediction."""
|
||||
return data.ClipPrediction(
|
||||
clip=clip,
|
||||
sound_events=[
|
||||
@ -181,68 +93,7 @@ def convert_raw_prediction_to_sound_event_prediction(
|
||||
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only: bool = False,
|
||||
):
|
||||
"""Convert a single RawPrediction into a soundevent SoundEventPrediction.
|
||||
|
||||
This function performs the core decoding steps for a single detected event:
|
||||
1. Creates a `soundevent.data.SoundEvent` containing the geometry
|
||||
(BoundingBox derived from `raw_prediction` bounds) and any associated
|
||||
feature vectors.
|
||||
2. Initializes a list of predicted tags using the provided
|
||||
`generic_class_tags`, assigning the overall `detection_score` from the
|
||||
`raw_prediction` to these generic tags.
|
||||
3. Processes the `class_scores` from the `raw_prediction`:
|
||||
a. Optionally filters out scores below `classification_threshold`
|
||||
(if it's not None).
|
||||
b. Sorts the remaining scores in descending order.
|
||||
c. Iterates through the sorted, thresholded class scores.
|
||||
d. For each class, uses the `sound_event_decoder` to get the
|
||||
representative base tags for that class name.
|
||||
e. Wraps these base tags in `soundevent.data.PredictedTag`, associating
|
||||
the specific `score` of that class prediction.
|
||||
f. Appends these specific predicted tags to the list.
|
||||
g. If `top_class_only` is True, stops after processing the first
|
||||
(highest-scoring) class that passed the threshold.
|
||||
4. Creates and returns the final `soundevent.data.SoundEventPrediction`,
|
||||
associating the `SoundEvent`, the overall `detection_score`, and the
|
||||
compiled list of `PredictedTag` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
raw_prediction : RawPrediction
|
||||
The raw prediction object containing score, bounds, class scores,
|
||||
features. Assumes `class_scores` is an `xr.DataArray` with a 'category'
|
||||
coordinate. Assumes `features` is an `xr.DataArray` with a 'feature'
|
||||
coordinate.
|
||||
recording : data.Recording
|
||||
The recording the sound event belongs to.
|
||||
sound_event_decoder : SoundEventDecoder
|
||||
Configured function mapping class names (str) to lists of base
|
||||
`data.Tag` objects.
|
||||
generic_class_tags : List[data.Tag]
|
||||
List of base tags representing the generic category.
|
||||
classification_threshold : float, optional
|
||||
The minimum score a class prediction must have to be considered
|
||||
significant enough to have its tags decoded and added. If None, no
|
||||
thresholding is applied based on class score (all predicted classes,
|
||||
or the top one if `top_class_only` is True, will be processed).
|
||||
Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`.
|
||||
top_class_only : bool, default=False
|
||||
If True, only includes tags for the single highest-scoring class that
|
||||
exceeds the threshold. If False (default), includes tags for all classes
|
||||
exceeding the threshold.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.SoundEventPrediction
|
||||
The fully formed sound event prediction object.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `raw_prediction.features` has unexpected structure or if
|
||||
`data.term_from_key` (if used internally) fails.
|
||||
If `sound_event_decoder` fails for a class name and errors are raised.
|
||||
"""
|
||||
"""Convert a single RawPrediction into a soundevent SoundEventPrediction."""
|
||||
sound_event = data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=raw_prediction.geometry,
|
||||
@ -252,7 +103,7 @@ def convert_raw_prediction_to_sound_event_prediction(
|
||||
tags = [
|
||||
*get_generic_tags(
|
||||
raw_prediction.detection_score,
|
||||
generic_class_tags=targets.generic_class_tags,
|
||||
generic_class_tags=targets.detection_class_tags,
|
||||
),
|
||||
*get_class_tags(
|
||||
raw_prediction.class_scores,
|
||||
@ -273,25 +124,7 @@ def get_generic_tags(
|
||||
detection_score: float,
|
||||
generic_class_tags: List[data.Tag],
|
||||
) -> List[data.PredictedTag]:
|
||||
"""Create PredictedTag objects for the generic category.
|
||||
|
||||
Takes the base list of generic tags and assigns the overall detection
|
||||
score to each one, wrapping them in `PredictedTag` objects.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_score : float
|
||||
The overall confidence score of the detection event.
|
||||
generic_class_tags : List[data.Tag]
|
||||
The list of base `soundevent.data.Tag` objects that define the
|
||||
generic category (e.g., ['call_type:Echolocation', 'order:Chiroptera']).
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.PredictedTag]
|
||||
A list of `PredictedTag` objects for the generic category, each
|
||||
assigned the `detection_score`.
|
||||
"""
|
||||
"""Create PredictedTag objects for the generic category."""
|
||||
return [
|
||||
data.PredictedTag(tag=tag, score=detection_score)
|
||||
for tag in generic_class_tags
|
||||
@ -299,25 +132,7 @@ def get_generic_tags(
|
||||
|
||||
|
||||
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
|
||||
"""Convert an extracted feature vector DataArray into soundevent Features.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
features : xr.DataArray
|
||||
A 1D xarray DataArray containing feature values, indexed by a coordinate
|
||||
named 'feature' which holds the feature names (e.g., output of selecting
|
||||
features for one detection from `extract_detection_xr_dataset`).
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.Feature]
|
||||
A list of `soundevent.data.Feature` objects.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- This function creates basic `Term` objects using the feature coordinate
|
||||
names with a "batdetect2:" prefix.
|
||||
"""
|
||||
"""Convert an extracted feature vector DataArray into soundevent Features."""
|
||||
return [
|
||||
data.Feature(
|
||||
term=data.Term(
|
||||
|
||||
@ -1,162 +0,0 @@
|
||||
"""Extracts candidate detection points from a model output heatmap.
|
||||
|
||||
This module implements Step 3 within the BatDetect2 postprocessing
|
||||
pipeline. Its primary function is to identify potential sound event locations
|
||||
by finding peaks (local maxima or high-scoring points) in the detection heatmap
|
||||
produced by the neural network (usually after Non-Maximum Suppression and
|
||||
coordinate remapping have been applied).
|
||||
|
||||
It provides functionality to:
|
||||
- Identify the locations (time, frequency) of the highest-scoring points.
|
||||
- Filter these points based on a minimum confidence score threshold.
|
||||
- Limit the maximum number of detection points returned (top-k).
|
||||
|
||||
The main output is an `xarray.DataArray` containing the scores and
|
||||
corresponding time/frequency coordinates for the extracted detection points.
|
||||
This output serves as the input for subsequent postprocessing steps, such as
|
||||
extracting predicted class probabilities and bounding box sizes at these
|
||||
specific locations.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions, get_dim_width
|
||||
|
||||
__all__ = [
|
||||
"extract_detections_from_array",
|
||||
"get_max_detections",
|
||||
"DEFAULT_DETECTION_THRESHOLD",
|
||||
"TOP_K_PER_SEC",
|
||||
]
|
||||
|
||||
DEFAULT_DETECTION_THRESHOLD = 0.01
|
||||
"""Default confidence score threshold used for filtering detections."""
|
||||
|
||||
TOP_K_PER_SEC = 200
|
||||
"""Default desired maximum number of detections per second of audio."""
|
||||
|
||||
|
||||
def extract_detections_from_array(
|
||||
detection_array: xr.DataArray,
|
||||
max_detections: Optional[int] = None,
|
||||
threshold: Optional[float] = DEFAULT_DETECTION_THRESHOLD,
|
||||
) -> xr.DataArray:
|
||||
"""Extract detection locations (time, freq) and scores from a heatmap.
|
||||
|
||||
Identifies the pixels with the highest scores in the input detection
|
||||
heatmap, filters them based on an optional score `threshold`, limits the
|
||||
number to an optional `max_detections`, and returns their scores along with
|
||||
their corresponding time and frequency coordinates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_array : xr.DataArray
|
||||
A 2D xarray DataArray representing the detection heatmap. Must have
|
||||
dimensions and coordinates named 'time' and 'frequency'. Higher values
|
||||
are assumed to indicate higher detection confidence.
|
||||
max_detections : int, optional
|
||||
The absolute maximum number of detections to return. If specified, only
|
||||
the top `max_detections` highest-scoring detections (passing the
|
||||
threshold) are returned. If None (default), all detections passing
|
||||
the threshold are returned, sorted by score.
|
||||
threshold : float, optional
|
||||
The minimum confidence score required for a detection peak to be
|
||||
kept. Detections with scores below this value are discarded.
|
||||
Defaults to `DEFAULT_DETECTION_THRESHOLD`. If set to None, no
|
||||
thresholding is applied.
|
||||
|
||||
Returns
|
||||
-------
|
||||
xr.DataArray
|
||||
A 1D xarray DataArray named 'score' with a 'detection' dimension.
|
||||
- The data values are the scores of the extracted detections, sorted
|
||||
in descending order.
|
||||
- It has coordinates 'time' and 'frequency' (also indexed by the
|
||||
'detection' dimension) indicating the location of each detection
|
||||
peak in the original coordinate system.
|
||||
- Returns an empty DataArray if no detections pass the criteria.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `max_detections` is not None and not a positive integer, or if
|
||||
`detection_array` lacks required dimensions/coordinates.
|
||||
"""
|
||||
if max_detections is not None:
|
||||
if max_detections <= 0:
|
||||
raise ValueError("Max detections must be positive")
|
||||
|
||||
values = detection_array.values.flatten()
|
||||
|
||||
if max_detections is not None:
|
||||
top_indices = np.argpartition(-values, max_detections)[:max_detections]
|
||||
top_sorted_indices = top_indices[np.argsort(-values[top_indices])]
|
||||
else:
|
||||
top_sorted_indices = np.argsort(-values)
|
||||
|
||||
top_values = values[top_sorted_indices]
|
||||
|
||||
if threshold is not None:
|
||||
mask = top_values > threshold
|
||||
top_values = top_values[mask]
|
||||
top_sorted_indices = top_sorted_indices[mask]
|
||||
|
||||
freq_indices, time_indices = np.unravel_index(
|
||||
top_sorted_indices,
|
||||
detection_array.shape,
|
||||
)
|
||||
|
||||
times = detection_array.coords[Dimensions.time.value].values[time_indices]
|
||||
freqs = detection_array.coords[Dimensions.frequency.value].values[
|
||||
freq_indices
|
||||
]
|
||||
|
||||
return xr.DataArray(
|
||||
data=top_values,
|
||||
coords={
|
||||
Dimensions.frequency.value: ("detection", freqs),
|
||||
Dimensions.time.value: ("detection", times),
|
||||
},
|
||||
dims="detection",
|
||||
name="score",
|
||||
)
|
||||
|
||||
|
||||
def get_max_detections(
|
||||
detection_array: xr.DataArray,
|
||||
top_k_per_sec: int = TOP_K_PER_SEC,
|
||||
) -> int:
|
||||
"""Calculate max detections allowed based on duration and rate.
|
||||
|
||||
Determines the total maximum number of detections to extract from a
|
||||
heatmap based on its time duration and a desired rate of detections
|
||||
per second.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
detection_array : xr.DataArray
|
||||
The detection heatmap, requiring 'time' coordinates from which the
|
||||
total duration can be calculated using
|
||||
`soundevent.arrays.get_dim_width`.
|
||||
top_k_per_sec : int, default=TOP_K_PER_SEC
|
||||
The desired maximum number of detections to allow per second of audio.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The calculated total maximum number of detections allowed for the
|
||||
entire duration of the `detection_array`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the duration cannot be calculated from the `detection_array` (e.g.,
|
||||
missing or invalid 'time' coordinates/dimension).
|
||||
"""
|
||||
if top_k_per_sec < 0:
|
||||
raise ValueError("top_k_per_sec cannot be negative.")
|
||||
|
||||
duration = get_dim_width(detection_array, Dimensions.time.value)
|
||||
return int(duration * top_k_per_sec)
|
||||
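The top-k-then-threshold selection in the (now removed) `extract_detections_from_array` is easy to verify on a toy heatmap. A self-contained numpy illustration with a synthetic 3x4 grid:

```python
import numpy as np

# Toy 3x4 detection heatmap (rows = frequency bins, cols = time bins).
heatmap = np.array(
    [
        [0.02, 0.90, 0.01, 0.10],
        [0.05, 0.03, 0.70, 0.02],
        [0.40, 0.01, 0.02, 0.08],
    ]
)

values = heatmap.flatten()
max_detections = 3

# Same strategy as the removed function: take the top-k scores first,
# sort them in descending order, then drop anything below the threshold.
top = np.argpartition(-values, max_detections)[:max_detections]
top = top[np.argsort(-values[top])]
top = top[values[top] > 0.05]

freq_idx, time_idx = np.unravel_index(top, heatmap.shape)
print(values[top])         # [0.9 0.7 0.4]
print(freq_idx, time_idx)  # (array([0, 1, 2]), array([1, 2, 0]))
```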
@ -15,108 +15,75 @@ precise time-frequency location of each detection. The final output aggregates
|
||||
all extracted information into a structured `xarray.Dataset`.
|
||||
"""
|
||||
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from batdetect2.typing.postprocess import ClipDetectionsTensor
|
||||
|
||||
__all__ = [
|
||||
"extract_values_at_positions",
|
||||
"extract_detection_xr_dataset",
|
||||
"extract_detection_peaks",
|
||||
]
|
||||
|
||||
|
||||
def extract_values_at_positions(
    array: xr.DataArray,
    positions: xr.DataArray,
) -> xr.DataArray:
    """Extract values from an array at specified time-frequency positions.

    Uses coordinate-based indexing to retrieve values from a source `array`
    (e.g., class probabilities, size predictions, features) at the time and
    frequency coordinates defined in the `positions` array.

    Parameters
    ----------
    array : xr.DataArray
        The source DataArray from which to extract values. Must have 'time'
        and 'frequency' dimensions and coordinates matching the space of
        `positions`.
    positions : xr.DataArray
        A 1D DataArray whose 'time' and 'frequency' coordinates specify the
        locations from which to extract values.

    Returns
    -------
    xr.DataArray
        A DataArray containing the values extracted from `array` at the given
        positions.

    Raises
    ------
    ValueError, IndexError, KeyError
        If dimensions or coordinates are missing or incompatible between
        `array` and `positions`, or if selection fails.
    """
    return array.sel(
        **{
            Dimensions.frequency.value: positions.coords[
                Dimensions.frequency.value
            ],
            Dimensions.time.value: positions.coords[Dimensions.time.value],
        }
    ).T


def extract_detection_xr_dataset(
    positions: xr.DataArray,
    sizes: xr.DataArray,
    classes: xr.DataArray,
    features: xr.DataArray,
) -> xr.Dataset:
    """Combine extracted detection information into a structured xr.Dataset.

    Takes the detection positions/scores and the full model output heatmaps
    (sizes, classes, optional features), extracts the relevant data at the
    detection positions, and packages everything into a single
    `xarray.Dataset` where all variables are indexed by a common 'detection'
    dimension.

    Parameters
    ----------
    positions : xr.DataArray
        Output from `extract_detections_from_array`, containing detection
        scores as data and 'time', 'frequency' coordinates along the
        'detection' dimension.
    sizes : xr.DataArray
        The full size prediction heatmap from the model, with dimensions like
        ('dimension', 'time', 'frequency').
    classes : xr.DataArray
        The full class probability heatmap from the model, with dimensions
        like ('category', 'time', 'frequency').
    features : xr.DataArray
        The full feature map from the model, with dimensions like
        ('feature', 'time', 'frequency').

    Returns
    -------
    xr.Dataset
        An xarray Dataset containing aligned information for each detection:
        - 'scores': DataArray from `positions` (score data, time/freq coords).
        - 'dimensions': DataArray with extracted size values
          (dims: 'detection', 'dimension').
        - 'classes': DataArray with extracted class probabilities
          (dims: 'detection', 'category').
        - 'features': DataArray with extracted feature vectors
          (dims: 'detection', 'feature'), if `features` was provided. All
          DataArrays share the 'detection' dimension and associated
          time/frequency coordinates.
    """
    sizes = extract_values_at_positions(sizes, positions)
    classes = extract_values_at_positions(classes, positions)
    features = extract_values_at_positions(features, positions)
    return xr.Dataset(
        {
            "scores": positions,
            "dimensions": sizes,
            "classes": classes,
            "features": features,
        }
    )
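For reference, a minimal, self-contained sketch of the coordinate-based selection the removed xarray helpers relied on (toy shapes and coordinate values are hypothetical):

```python
import numpy as np
import xarray as xr

# Toy heatmap: 3 classes over a 4 x 5 time-frequency grid.
heatmap = xr.DataArray(
    np.random.rand(3, 4, 5),
    dims=("category", "time", "frequency"),
    coords={"time": [0.0, 0.1, 0.2, 0.3], "frequency": [10, 20, 30, 40, 50]},
)

# Two detections, each carrying its own time/frequency coordinates
# along a shared 'detection' dimension.
positions = xr.DataArray(
    [0.9, 0.7],
    dims="detection",
    coords={
        "time": ("detection", [0.1, 0.3]),
        "frequency": ("detection", [20, 50]),
    },
)

# Vectorized point-wise selection: one class-score vector per detection.
values = heatmap.sel(
    time=positions.coords["time"],
    frequency=positions.coords["frequency"],
)
assert dict(values.sizes) == {"category": 3, "detection": 2}
```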
def extract_detection_peaks(
    detection_heatmap: torch.Tensor,
    size_heatmap: torch.Tensor,
    feature_heatmap: torch.Tensor,
    classification_heatmap: torch.Tensor,
    max_detections: int = 200,
    threshold: Optional[float] = None,
) -> List[ClipDetectionsTensor]:
    height = detection_heatmap.shape[-2]
    width = detection_heatmap.shape[-1]

    freqs, times = torch.meshgrid(
        torch.arange(height, dtype=torch.int32),
        torch.arange(width, dtype=torch.int32),
        indexing="ij",
    )

    freqs = freqs.flatten().to(detection_heatmap.device)
    times = times.flatten().to(detection_heatmap.device)

    output_size_preds = size_heatmap.detach()
    output_features = feature_heatmap.detach()
    output_class_probs = classification_heatmap.detach()

    predictions = []
    for idx, item in enumerate(detection_heatmap):
        item = item.squeeze().flatten()  # Remove channel dim
        indices = torch.argsort(item, descending=True)[:max_detections]

        detection_scores = item.take(indices)
        detection_freqs = freqs.take(indices)
        detection_times = times.take(indices)

        if threshold is not None:
            mask = detection_scores >= threshold

            detection_scores = detection_scores[mask]
            detection_times = detection_times[mask]
            detection_freqs = detection_freqs[mask]

        sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
        features = output_features[idx, :, detection_freqs, detection_times].T
        class_scores = output_class_probs[
            idx,
            :,
            detection_freqs,
            detection_times,
        ].T

        predictions.append(
            ClipDetectionsTensor(
                scores=detection_scores,
                sizes=sizes,
                features=features,
                class_scores=class_scores,
                times=detection_times.to(torch.float32) / width,
                frequencies=(detection_freqs.to(torch.float32) / height),
            )
        )

    return predictions
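A hedged usage sketch of the new peak extraction (random tensors stand in for real model output; the shapes are assumptions, not the model's actual output sizes):

```python
import torch

# Fake model output for a batch of 2 clips: a 1-channel detection heatmap,
# 2 size dimensions, 32 features, and 4 classes over a 128 x 256 grid.
batch, height, width = 2, 128, 256
detections = extract_detection_peaks(
    detection_heatmap=torch.rand(batch, 1, height, width),
    size_heatmap=torch.rand(batch, 2, height, width),
    feature_heatmap=torch.rand(batch, 32, height, width),
    classification_heatmap=torch.rand(batch, 4, height, width),
    max_detections=100,
    threshold=0.5,
)
assert len(detections) == batch
# Times/frequencies come back normalized to [0, 1); remapping to
# seconds/Hz happens later via map_detection_to_clip.
```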
src/batdetect2/postprocess/postprocessor.py (new file)
@@ -0,0 +1,100 @@
from typing import List, Optional, Tuple, Union

import torch
from loguru import logger

from batdetect2.postprocess.config import (
    PostprocessConfig,
)
from batdetect2.postprocess.extraction import extract_detection_peaks
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
    ClipDetectionsTensor,
    PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol

__all__ = [
    "build_postprocessor",
    "Postprocessor",
]


def build_postprocessor(
    preprocessor: PreprocessorProtocol,
    config: Optional[PostprocessConfig] = None,
) -> PostprocessorProtocol:
    """Factory function to build the standard postprocessor."""
    config = config or PostprocessConfig()
    logger.opt(lazy=True).debug(
        "Building postprocessor with config: \n{}",
        lambda: config.to_yaml_string(),
    )
    return Postprocessor(
        samplerate=preprocessor.output_samplerate,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        top_k_per_sec=config.top_k_per_sec,
        detection_threshold=config.detection_threshold,
    )


class Postprocessor(torch.nn.Module, PostprocessorProtocol):
    """Standard implementation of the postprocessing pipeline."""

    def __init__(
        self,
        samplerate: float,
        min_freq: float,
        max_freq: float,
        top_k_per_sec: int = 200,
        detection_threshold: float = 0.01,
        nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
    ):
        """Initialize the Postprocessor."""
        super().__init__()

        self.output_samplerate = samplerate
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.top_k_per_sec = top_k_per_sec
        self.detection_threshold = detection_threshold
        self.nms_kernel_size = nms_kernel_size

    def forward(
        self,
        output: ModelOutput,
        start_times: Optional[List[float]] = None,
    ) -> List[ClipDetectionsTensor]:
        detection_heatmap = non_max_suppression(
            output.detection_probs.detach(),
            kernel_size=self.nms_kernel_size,
        )

        width = output.detection_probs.shape[-1]
        duration = width / self.output_samplerate
        max_detections = int(self.top_k_per_sec * duration)
        detections = extract_detection_peaks(
            detection_heatmap,
            size_heatmap=output.size_preds,
            feature_heatmap=output.features,
            classification_heatmap=output.class_probs,
            max_detections=max_detections,
            threshold=self.detection_threshold,
        )

        if start_times is None:
            start_times = [0.0 for _ in range(len(detections))]

        # Remap each clip's normalized detections using that clip's own
        # start-time offset.
        return [
            map_detection_to_clip(
                detection,
                start_time=start_time,
                end_time=start_time + duration,
                min_freq=self.min_freq,
                max_freq=self.max_freq,
            )
            for detection, start_time in zip(detections, start_times)
        ]
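A hedged end-to-end sketch of the new module (it assumes `preprocessor` was built with the preprocessing factory shown further below and `output` is a `ModelOutput` from a forward pass on a batch of two clips; the start times are illustrative):

```python
# Hypothetical wiring: build a postprocessor from an existing preprocessor
# and turn raw model output into per-clip detection tensors.
postprocessor = build_postprocessor(preprocessor, config=PostprocessConfig())

# Two clips taken from the same recording, starting at 0.0 s and 1.0 s.
clip_detections = postprocessor(output, start_times=[0.0, 1.0])

for dets in clip_detections:
    # Times/frequencies are now in seconds/Hz rather than normalized units.
    print(dets.scores.shape, float(dets.times.min()), float(dets.frequencies.max()))
```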
@@ -20,6 +20,7 @@ import xarray as xr
from soundevent.arrays import Dimensions

from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing.postprocess import ClipDetectionsTensor

__all__ = [
    "features_to_xarray",
@@ -29,6 +30,25 @@ __all__ = [
]

def map_detection_to_clip(
    detections: ClipDetectionsTensor,
    start_time: float,
    end_time: float,
    min_freq: float,
    max_freq: float,
) -> ClipDetectionsTensor:
    duration = end_time - start_time
    bandwidth = max_freq - min_freq
    return ClipDetectionsTensor(
        scores=detections.scores,
        sizes=detections.sizes,
        features=detections.features,
        class_scores=detections.class_scores,
        times=(detections.times * duration + start_time),
        frequencies=(detections.frequencies * bandwidth + min_freq),
    )
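The remapping is a plain linear rescale; for example (hypothetical numbers):

```python
# A normalized time of 0.25 in a clip spanning 2.0 s to 3.0 s maps to
# 2.0 + 0.25 * (3.0 - 2.0) = 2.25 s; a normalized frequency of 0.5 with
# min_freq=10_000 Hz and max_freq=120_000 Hz maps to 65_000 Hz.
assert 2.0 + 0.25 * (3.0 - 2.0) == 2.25
assert 10_000 + 0.5 * (120_000 - 10_000) == 65_000.0
```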


def features_to_xarray(
    features: torch.Tensor,
    start_time: float,
@@ -1,295 +0,0 @@
"""Defines shared interfaces and data structures for postprocessing.

This module centralizes the Protocol definitions and common data structures
used throughout the `batdetect2.postprocess` module.

The main component is the `PostprocessorProtocol`, which outlines the standard
interface for an object responsible for executing the entire postprocessing
pipeline. This pipeline transforms raw neural network outputs into
interpretable detections represented as `soundevent` objects. Using protocols
ensures modularity and consistent interaction between different parts of the
BatDetect2 system that deal with model predictions.
"""

from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol

import numpy as np
import xarray as xr
from soundevent import data

from batdetect2.models.types import ModelOutput
from batdetect2.targets.types import Position, Size

__all__ = [
    "RawPrediction",
    "PostprocessorProtocol",
    "GeometryDecoder",
]


# TODO: update the docstring
class GeometryDecoder(Protocol):
    """Protocol for a callable that recovers geometry from position and size.

    The callable takes:
    1. A position tuple `(time, frequency)`.
    2. A NumPy array of size dimensions (e.g., `[width, height]`).
    3. Optionally, the class name of the highest-scoring class. This is to
       accommodate ways of decoding geometry that depend on the predicted
       class.
    It should return the reconstructed `soundevent.data.Geometry` (typically a
    `BoundingBox`).
    """

    def __call__(
        self, position: Position, size: Size, class_name: Optional[str] = None
    ) -> data.Geometry: ...


class RawPrediction(NamedTuple):
    """Intermediate representation of a single detected sound event.

    Holds extracted information about a detection after initial processing
    (like peak finding, coordinate remapping, geometry recovery) but before
    final class decoding and conversion into a `SoundEventPrediction`. This
    can be useful for evaluation or simpler data handling formats.

    Attributes
    ----------
    geometry : data.Geometry
        The recovered estimated geometry of the detected sound event.
        Usually a bounding box.
    detection_score : float
        The confidence score associated with this detection, typically from
        the detection heatmap peak.
    class_scores : np.ndarray
        Array containing the predicted probabilities or scores for each
        target class at the detection location, ordered by class name.
    features : np.ndarray
        Array containing the feature vector extracted at the detection
        location.
    """

    geometry: data.Geometry
    detection_score: float
    class_scores: np.ndarray
    features: np.ndarray


@dataclass
class BatDetect2Prediction:
    raw: RawPrediction
    sound_event_prediction: data.SoundEventPrediction


class PostprocessorProtocol(Protocol):
    """Protocol defining the interface for the full postprocessing pipeline.

    This protocol outlines the standard methods for an object that takes raw
    output from a BatDetect2 model and the corresponding input clip metadata,
    and processes it through various stages (e.g., coordinate remapping, NMS,
    detection extraction, data extraction, decoding) to produce interpretable
    results at different levels of completion.

    Implementations manage the configured logic for all postprocessing steps.
    """

    def get_feature_arrays(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[xr.DataArray]:
        """Remap feature tensors to coordinate-aware DataArrays.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch,
            expected to contain the necessary feature tensors.
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects, one for each item in
            the processed batch. This list provides the timing, recording,
            and other metadata context needed to calculate real-world
            coordinates (seconds, Hz) for the output arrays. The length of
            this list must correspond to the batch size of the `output`.

        Returns
        -------
        List[xr.DataArray]
            A list of xarray DataArrays, one for each input clip in the
            batch, in the same order. Each DataArray contains the feature
            vectors with dimensions like ('feature', 'time', 'frequency')
            and corresponding real-world coordinates.
        """
        ...

    def get_detection_arrays(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[xr.DataArray]:
        """Remap detection tensors to coordinate-aware DataArrays.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch,
            containing detection heatmaps.
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects corresponding to the
            batch items, providing coordinate context. Must match the batch
            size of `output`.

        Returns
        -------
        List[xr.DataArray]
            A list of 2D xarray DataArrays (one per input clip, in order),
            representing the detection heatmap with 'time' and 'frequency'
            coordinates. Values typically indicate detection confidence.
        """
        ...

    def get_classification_arrays(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[xr.DataArray]:
        """Remap classification tensors to coordinate-aware DataArrays.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch,
            containing class probability tensors.
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects corresponding to the
            batch items, providing coordinate context. Must match the batch
            size of `output`.

        Returns
        -------
        List[xr.DataArray]
            A list of 3D xarray DataArrays (one per input clip, in order),
            representing class probabilities with 'category', 'time', and
            'frequency' dimensions and coordinates.
        """
        ...

    def get_sizes_arrays(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[xr.DataArray]:
        """Remap size prediction tensors to coordinate-aware DataArrays.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch,
            containing predicted size tensors (e.g., width and height).
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects corresponding to the
            batch items, providing coordinate context. Must match the batch
            size of `output`.

        Returns
        -------
        List[xr.DataArray]
            A list of 3D xarray DataArrays (one per input clip, in order),
            representing predicted sizes with 'dimension'
            (e.g., ['width', 'height']), 'time', and 'frequency' dimensions
            and coordinates. Values represent estimated detection sizes.
        """
        ...

    def get_detection_datasets(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[xr.Dataset]:
        """Perform remapping, NMS, detection, and data extraction for a batch.

        Processes the raw model output for a batch to identify detection
        peaks and extract all associated information (score, position, size,
        class probs, features) at those peak locations, returning a
        structured dataset for each input clip in the batch.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch.
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects corresponding to the
            batch items, providing context. Must match the batch size of
            `output`.

        Returns
        -------
        List[xr.Dataset]
            A list of xarray Datasets (one per input clip, in order). Each
            Dataset contains multiple DataArrays ('scores', 'dimensions',
            'classes', 'features') sharing a common 'detection' dimension,
            providing aligned data for each detected event in that clip.
        """
        ...

    def get_raw_predictions(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[List[RawPrediction]]:
        """Extract intermediate RawPrediction objects for a batch.

        Processes the raw model output for a batch through remapping, NMS,
        detection, data extraction, and geometry recovery to produce a list
        of `RawPrediction` objects for each corresponding input clip. This
        provides a simplified, intermediate representation before final tag
        decoding.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch.
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects corresponding to the
            batch items, providing context. Must match the batch size of
            `output`.

        Returns
        -------
        List[List[RawPrediction]]
            A list of lists (one inner list per input clip, in order). Each
            inner list contains the `RawPrediction` objects extracted for
            the corresponding input clip.
        """
        ...

    def get_sound_event_predictions(
        self, output: ModelOutput, clips: List[data.Clip]
    ) -> List[List[BatDetect2Prediction]]: ...

    def get_predictions(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[data.ClipPrediction]:
        """Perform the full postprocessing pipeline for a batch.

        Takes raw model output for a batch and corresponding clips, applies
        the entire postprocessing chain, and returns the final, interpretable
        predictions as a list of `soundevent.data.ClipPrediction` objects.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch.
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects corresponding to the
            batch items, providing context. Must match the batch size of
            `output`.

        Returns
        -------
        List[data.ClipPrediction]
            A list containing one `ClipPrediction` object for each input
            clip (in the same order), populated with `SoundEventPrediction`
            objects representing the final detections with decoded tags and
            geometry.
        """
        ...
@@ -1,458 +1,19 @@
"""Main entry point for the BatDetect2 Preprocessing subsystem.

This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline
for converting raw audio input (from files or data objects) into processed
spectrograms suitable for input to BatDetect2 models. This ensures consistent
data handling between model training and inference.

The preprocessing pipeline consists of two main stages, configured via nested
data structures:
1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial
   processing like resampling, duration adjustment, centering, and scaling.
   Configured via `AudioConfig`.
2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from
   the processed waveform using STFT, followed by frequency cropping, optional
   PCEN, amplitude scaling (dB, power, linear), optional denoising, optional
   resizing, and optional peak normalization. Configured via
   `SpectrogramConfig`.

This module provides the primary interface:

- `PreprocessingConfig`: A unified configuration object holding `AudioConfig`
  and `SpectrogramConfig`.
- `load_preprocessing_config`: Function to load the unified configuration.
- `Preprocessor`: A protocol defining the interface for the end-to-end
  pipeline.
- `StandardPreprocessor`: The default implementation of the `Preprocessor`.
- `build_preprocessor`: A factory function to create a `StandardPreprocessor`
  instance from a `PreprocessingConfig`.

"""

from typing import Optional, Union

import numpy as np
import xarray as xr
from loguru import logger
from pydantic import Field
from soundevent import data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
    DEFAULT_DURATION,
    SCALE_RAW_AUDIO,
    TARGET_SAMPLERATE_HZ,
    AudioConfig,
    ResampleConfig,
    build_audio_loader,
)
from batdetect2.preprocess.spectrogram import (
    MAX_FREQ,
    MIN_FREQ,
    ConfigurableSpectrogramBuilder,
    FrequencyConfig,
    PcenConfig,
    SpecSizeConfig,
    SpectrogramConfig,
    STFTConfig,
    build_spectrogram_builder,
    get_spectrogram_resolution,
)
from batdetect2.preprocess.types import (
    AudioLoader,
    PreprocessorProtocol,
    SpectrogramBuilder,
)

"""Main entry point for the BatDetect2 preprocessing subsystem."""

from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.preprocess.config import (
    PreprocessingConfig,
    load_preprocessing_config,
)
from batdetect2.preprocess.preprocessor import Preprocessor, build_preprocessor
from batdetect2.preprocess.spectrogram import MAX_FREQ, MIN_FREQ

__all__ = [
    "AudioConfig",
    "AudioLoader",
    "ConfigurableSpectrogramBuilder",
    "DEFAULT_DURATION",
    "FrequencyConfig",
    "MAX_FREQ",
    "MIN_FREQ",
    "PcenConfig",
    "PreprocessingConfig",
    "ResampleConfig",
    "SCALE_RAW_AUDIO",
    "STFTConfig",
    "SpecSizeConfig",
    "SpectrogramBuilder",
    "SpectrogramConfig",
    "StandardPreprocessor",
    "Preprocessor",
    "TARGET_SAMPLERATE_HZ",
    "build_audio_loader",
    "build_preprocessor",
    "build_spectrogram_builder",
    "get_spectrogram_resolution",
    "load_preprocessing_config",
    "get_default_preprocessor",
]


class PreprocessingConfig(BaseConfig):
    """Unified configuration for the audio preprocessing pipeline.

    Aggregates the configuration for both the initial audio processing stage
    and the subsequent spectrogram generation stage.

    Attributes
    ----------
    audio : AudioConfig
        Configuration settings for the audio loading and initial waveform
        processing steps (e.g., resampling, duration adjustment, scaling).
        Defaults to default `AudioConfig` settings if omitted.
    spectrogram : SpectrogramConfig
        Configuration settings for the spectrogram generation process
        (e.g., STFT parameters, frequency cropping, scaling, denoising,
        resizing). Defaults to default `SpectrogramConfig` settings if
        omitted.
    """

    audio: AudioConfig = Field(default_factory=AudioConfig)
    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
class StandardPreprocessor(PreprocessorProtocol):
    """Standard implementation of the `Preprocessor` protocol.

    Orchestrates the audio loading and spectrogram generation pipeline using
    an `AudioLoader` and a `SpectrogramBuilder` internally, which are
    configured according to a `PreprocessingConfig`.

    This class is typically instantiated using the `build_preprocessor`
    factory function.

    Attributes
    ----------
    audio_loader : AudioLoader
        The configured audio loader instance used for waveform loading and
        initial processing.
    spectrogram_builder : SpectrogramBuilder
        The configured spectrogram builder instance used for generating
        spectrograms from waveforms.
    default_samplerate : int
        The sample rate (in Hz) assumed for input waveforms when they are
        provided as raw NumPy arrays without coordinate information (e.g.,
        when calling `compute_spectrogram` directly with `np.ndarray`).
        This value is derived from the `AudioConfig` (target resample rate
        or default if resampling is off) and also serves as documentation
        for the pipeline's intended operating sample rate. Note that when
        processing `xr.DataArray` inputs that have coordinate information
        (the standard internal workflow), the sample rate embedded in the
        coordinates takes precedence over this default value during
        spectrogram calculation.
    """

    audio_loader: AudioLoader
    spectrogram_builder: SpectrogramBuilder
    default_samplerate: int
    max_freq: float
    min_freq: float

    def __init__(
        self,
        audio_loader: AudioLoader,
        spectrogram_builder: SpectrogramBuilder,
        default_samplerate: int,
        max_freq: float,
        min_freq: float,
    ) -> None:
        """Initialize the StandardPreprocessor.

        Parameters
        ----------
        audio_loader : AudioLoader
            An initialized audio loader conforming to the AudioLoader
            protocol.
        spectrogram_builder : SpectrogramBuilder
            An initialized spectrogram builder conforming to the
            SpectrogramBuilder protocol.
        default_samplerate : int
            The sample rate to assume for NumPy array inputs, potentially
            reflecting the target rate of the audio config.
        """
        self.audio_loader = audio_loader
        self.spectrogram_builder = spectrogram_builder
        self.default_samplerate = default_samplerate
        self.max_freq = max_freq
        self.min_freq = min_freq

    def load_file_audio(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess *only* the audio waveform from a file path.

        Delegates to the internal `audio_loader`.

        Parameters
        ----------
        path : PathLike
            Path to the audio file.
        audio_dir : PathLike, optional
            A directory prefix if `path` is relative.

        Returns
        -------
        xr.DataArray
            The loaded and preprocessed audio waveform (typically first
            channel).
        """
        return self.audio_loader.load_file(
            path,
            audio_dir=audio_dir,
        )

    def load_recording_audio(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess *only* the audio waveform for a Recording.

        Delegates to the internal `audio_loader`.

        Parameters
        ----------
        recording : data.Recording
            The Recording object.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The loaded and preprocessed audio waveform (typically first
            channel).
        """
        return self.audio_loader.load_recording(
            recording,
            audio_dir=audio_dir,
        )

    def load_clip_audio(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load and preprocess *only* the audio waveform for a Clip.

        Delegates to the internal `audio_loader`.

        Parameters
        ----------
        clip : data.Clip
            The Clip object defining the segment.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The loaded and preprocessed audio waveform segment (typically
            first channel).
        """
        return self.audio_loader.load_clip(
            clip,
            audio_dir=audio_dir,
        )

    def preprocess_file(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load audio from a file and compute the final processed spectrogram.

        Performs the full pipeline:

        Load -> Preprocess Audio -> Compute Spectrogram.

        Parameters
        ----------
        path : PathLike
            Path to the audio file.
        audio_dir : PathLike, optional
            A directory prefix if `path` is relative.

        Returns
        -------
        xr.DataArray
            The final processed spectrogram.
        """
        wav = self.load_file_audio(path, audio_dir=audio_dir)
        return self.spectrogram_builder(
            wav,
            samplerate=self.default_samplerate,
        )

    def preprocess_recording(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load audio for a Recording and compute the processed spectrogram.

        Performs the full pipeline for the entire duration of the recording.

        Parameters
        ----------
        recording : data.Recording
            The Recording object.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The final processed spectrogram.
        """
        wav = self.load_recording_audio(recording, audio_dir=audio_dir)
        return self.spectrogram_builder(
            wav,
            samplerate=self.default_samplerate,
        )

    def preprocess_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> xr.DataArray:
        """Load audio for a Clip and compute the final processed spectrogram.

        Performs the full pipeline for the specified clip segment.

        Parameters
        ----------
        clip : data.Clip
            The Clip object defining the audio segment.
        audio_dir : PathLike, optional
            Directory containing the audio file.

        Returns
        -------
        xr.DataArray
            The final processed spectrogram.
        """
        wav = self.load_clip_audio(clip, audio_dir=audio_dir)
        return self.spectrogram_builder(
            wav,
            samplerate=self.default_samplerate,
        )

    def compute_spectrogram(
        self, wav: Union[xr.DataArray, np.ndarray]
    ) -> xr.DataArray:
        """Compute the spectrogram from a pre-loaded audio waveform.

        Applies the configured spectrogram generation steps
        (STFT, scaling, etc.) using the internal `spectrogram_builder`.

        If `wav` is a NumPy array, the `default_samplerate` stored in this
        preprocessor instance will be used. If `wav` is an xarray DataArray
        with time coordinates, the sample rate derived from those coordinates
        will take precedence over `default_samplerate`.

        Parameters
        ----------
        wav : Union[xr.DataArray, np.ndarray]
            The input audio waveform. If a NumPy array, the
            `default_samplerate` stored in this object will be assumed.

        Returns
        -------
        xr.DataArray
            The computed spectrogram.
        """
        return self.spectrogram_builder(
            wav,
            samplerate=self.default_samplerate,
        )
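A brief, hedged usage sketch of the (now removed) `StandardPreprocessor` workflow; the file path is a placeholder:

```python
# Hypothetical usage: build a preprocessor with default settings, then run
# the full pipeline or its two stages separately.
preprocessor = build_preprocessor()  # default AudioConfig + SpectrogramConfig

# Full pipeline: load audio from file, then compute the spectrogram.
spec = preprocessor.preprocess_file("example_data/audio/recording.wav")

# Or run the stages separately: waveform first, spectrogram second.
wav = preprocessor.load_file_audio("example_data/audio/recording.wav")
spec = preprocessor.compute_spectrogram(wav)
```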
def load_preprocessing_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> PreprocessingConfig:
    """Load the unified preprocessing configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `PreprocessingConfig` schema, potentially extracting data from a nested
    field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        preprocessing configuration (e.g., "train.preprocessing"). If None,
        the entire file content is validated as the PreprocessingConfig.

    Returns
    -------
    PreprocessingConfig
        Loaded and validated preprocessing configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded config data does not conform to PreprocessingConfig.
    KeyError, TypeError
        If `field` specifies an invalid path.
    """
    return load_config(path, schema=PreprocessingConfig, field=field)
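For illustration, a hypothetical config file and the corresponding load call; the `train.preprocessing` nesting follows the docstring's example and the field values are assumptions:

```python
# config.yaml (hypothetical):
#
# train:
#   preprocessing:
#     audio:
#       resample:
#         samplerate: 256000
#     spectrogram:
#       frequencies:
#         min_freq: 10000
#         max_freq: 120000
#
# Load just the nested preprocessing section:
config = load_preprocessing_config("config.yaml", field="train.preprocessing")
print(config.audio.resample.samplerate)  # -> 256000
```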
def build_preprocessor(
    config: Optional[PreprocessingConfig] = None,
) -> PreprocessorProtocol:
    """Factory function to build the standard preprocessor from configuration.

    Creates instances of the required `AudioLoader` and `SpectrogramBuilder`
    based on the provided `PreprocessingConfig` (or defaults if config is
    None), determines the effective default sample rate, and initializes the
    `StandardPreprocessor`.

    Parameters
    ----------
    config : PreprocessingConfig, optional
        The unified preprocessing configuration object. If None, default
        configurations for audio and spectrogram processing will be used.

    Returns
    -------
    PreprocessorProtocol
        An initialized `StandardPreprocessor` instance ready to process audio
        according to the configuration.
    """
    config = config or PreprocessingConfig()
    logger.opt(lazy=True).debug(
        "Building preprocessor with config: \n{}",
        lambda: config.to_yaml_string(),
    )

    default_samplerate = (
        config.audio.resample.samplerate
        if config.audio.resample
        else TARGET_SAMPLERATE_HZ
    )

    min_freq = config.spectrogram.frequencies.min_freq
    max_freq = config.spectrogram.frequencies.max_freq

    return StandardPreprocessor(
        audio_loader=build_audio_loader(config.audio),
        spectrogram_builder=build_spectrogram_builder(config.spectrogram),
        default_samplerate=default_samplerate,
        min_freq=min_freq,
        max_freq=max_freq,
    )


def get_default_preprocessor() -> PreprocessorProtocol:
    return build_preprocessor()
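And the convenience wrapper in use, a minimal sketch assuming the default settings are acceptable:

```python
# Equivalent ways to obtain a preprocessor with default settings.
pre_a = get_default_preprocessor()
pre_b = build_preprocessor()  # same defaults, config=None

# Both derive their sample rate from the default AudioConfig
# (TARGET_SAMPLERATE_HZ when resampling is disabled).
assert pre_a.default_samplerate == pre_b.default_samplerate
```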