Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-11 17:29:34 +01:00)

Compare commits: 2d796394f6...d4f249366e

No commits in common. "2d796394f69e5ada5fa36ae5a1a867c538447786" and "d4f249366ecc8a50e954f1bdad44f8546524fa84" have entirely different histories.
@@ -133,7 +133,7 @@ When you need to specify a tag, you typically use a structure with two fields:
   **It defaults to `class`** if you omit it, which is common when defining the main target classes.
 - `value`: The specific _value_ of the tag (e.g., `Myotis daubentonii`, `Good`, `Rain`).
 
-**Example YAML Configuration (e.g., inside a filter rule):**
+**Example YAML Configuration using TagInfo (e.g., inside a filter rule):**
 
 ```yaml
 # ... inside a filtering configuration section ...
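The documented default is easy to pin down in code. A minimal sketch, assuming a pydantic-style model (the configs elsewhere in this diff are pydantic `BaseConfig` subclasses); the class body is illustrative, and only the two fields and the `class` default come from the doc text above:

```python
from pydantic import BaseModel


class TagInfo(BaseModel):
    """Illustrative stand-in for the tag structure described above."""

    key: str = "class"  # defaults to "class" when omitted
    value: str


# Omitting `key` picks up the default, as the documentation states:
tag = TagInfo(value="Myotis daubentonii")
assert tag.key == "class"
```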
@@ -1,125 +1,119 @@
-audio:
-  samplerate: 256000
-  resample:
-    enabled: True
-    method: "poly"
+datasets:
+  train:
+    name: example dataset
+    description: Only for demonstration purposes
+    sources:
+      - format: batdetect2
+        name: Example Data
+        description: Examples included for testing batdetect2
+        annotations_dir: example_data/anns
+        audio_dir: example_data/audio
+
+targets:
+  classes:
+    classes:
+      - name: myomys
+        tags:
+          - value: Myotis mystacinus
+      - name: pippip
+        tags:
+          - value: Pipistrellus pipistrellus
+      - name: eptser
+        tags:
+          - value: Eptesicus serotinus
+      - name: rhifer
+        tags:
+          - value: Rhinolophus ferrumequinum
+    generic_class:
+      - key: class
+        value: Bat
+  filtering:
+    rules:
+      - match_type: all
+        tags:
+          - key: event
+            value: Echolocation
+      - match_type: exclude
+        tags:
+          - key: class
+            value: Unknown
 
 preprocess:
-  stft:
-    window_duration: 0.002
-    window_overlap: 0.75
-    window_fn: hann
-  frequencies:
-    max_freq: 120000
-    min_freq: 10000
-  size:
-    height: 128
-    resize_factor: 0.5
-  spectrogram_transforms:
-    - name: pcen
+  audio:
+    resample:
+      samplerate: 256000
+      method: "poly"
+    scale: false
+    center: true
+    duration: null
+  spectrogram:
+    stft:
+      window_duration: 0.002
+      window_overlap: 0.75
+      window_fn: hann
+    frequencies:
+      max_freq: 120000
+      min_freq: 10000
+    pcen:
       time_constant: 0.1
       gain: 0.98
       bias: 2
       power: 0.5
-    - name: spectral_mean_substraction
+    scale: "amplitude"
+    size:
+      height: 128
+      resize_factor: 0.5
+    spectral_mean_substraction: true
+    peak_normalize: false
 
 postprocess:
   nms_kernel_size: 9
   detection_threshold: 0.01
+  min_freq: 10000
+  max_freq: 120000
   top_k_per_sec: 200
 
+labels:
+  sigma: 3
+
 model:
   input_height: 128
   in_channels: 1
   out_channels: 32
   encoder:
     layers:
-      - name: FreqCoordConvDown
+      - block_type: FreqCoordConvDown
         out_channels: 32
-      - name: FreqCoordConvDown
+      - block_type: FreqCoordConvDown
         out_channels: 64
-      - name: LayerGroup
+      - block_type: LayerGroup
         layers:
-          - name: FreqCoordConvDown
+          - block_type: FreqCoordConvDown
            out_channels: 128
-          - name: ConvBlock
+          - block_type: ConvBlock
            out_channels: 256
   bottleneck:
     channels: 256
-    layers:
-      - name: SelfAttention
-        attention_channels: 256
+    self_attention: true
   decoder:
     layers:
-      - name: FreqCoordConvUp
+      - block_type: FreqCoordConvUp
         out_channels: 64
-      - name: FreqCoordConvUp
+      - block_type: FreqCoordConvUp
         out_channels: 32
-      - name: LayerGroup
+      - block_type: LayerGroup
         layers:
-          - name: FreqCoordConvUp
+          - block_type: FreqCoordConvUp
            out_channels: 32
-          - name: ConvBlock
+          - block_type: ConvBlock
            out_channels: 32
 
 train:
-  optimizer:
-    learning_rate: 0.001
-    t_max: 100
-
-  labels:
-    sigma: 3
-
-  trainer:
-    max_epochs: 10
-    check_val_every_n_epoch: 5
-
-  train_loader:
-    batch_size: 8
-    num_workers: 2
-    shuffle: True
-
-  clipping_strategy:
-    name: random_subclip
-    duration: 0.256
-
-  augmentations:
-    enabled: true
-    audio:
-      - name: mix_audio
-        probability: 0.2
-        min_weight: 0.3
-        max_weight: 0.7
-      - name: add_echo
-        probability: 0.2
-        max_delay: 0.005
-        min_weight: 0.0
-        max_weight: 1.0
-    spectrogram:
-      - name: scale_volume
-        probability: 0.2
-        min_scaling: 0.0
-        max_scaling: 2.0
-      - name: warp
-        probability: 0.2
-        delta: 0.04
-      - name: mask_time
-        probability: 0.2
-        max_perc: 0.05
-        max_masks: 3
-      - name: mask_freq
-        probability: 0.2
-        max_perc: 0.10
-        max_masks: 3
-
-  val_loader:
-    num_workers: 2
-    clipping_strategy:
-      name: whole_audio_padded
-      chunk_size: 0.256
+  batch_size: 8
+  learning_rate: 0.001
+  t_max: 100
 
   loss:
     detection:
       weight: 1.0
@@ -133,54 +127,37 @@ train:
       alpha: 2
     size:
       weight: 0.1
 
   logger:
-    name: csv
-
-validation:
-  tasks:
-    - name: sound_event_detection
-      metrics:
-        - name: average_precision
-    - name: sound_event_classification
-      metrics:
-        - name: average_precision
-
-evaluation:
-  tasks:
-    - name: sound_event_detection
-      metrics:
-        - name: average_precision
-        - name: roc_auc
-      plots:
-        - name: pr_curve
-        - name: score_distribution
-        - name: example_detection
-    - name: sound_event_classification
-      metrics:
-        - name: average_precision
-        - name: roc_auc
-      plots:
-        - name: pr_curve
-    - name: top_class_detection
-      metrics:
-        - name: average_precision
-      plots:
-        - name: pr_curve
-        - name: confusion_matrix
-        - name: example_classification
-    - name: clip_detection
-      metrics:
-        - name: average_precision
-        - name: roc_auc
-      plots:
-        - name: pr_curve
-        - name: roc_curve
-        - name: score_distribution
-    - name: clip_classification
-      metrics:
-        - name: average_precision
-        - name: roc_auc
-      plots:
-        - name: pr_curve
-        - name: roc_curve
+    logger_type: mlflow
+    experiment_name: batdetect2
+    tracking_uri: http://localhost:5000
+    log_model: true
+    save_dir: outputs/log/
+    artifact_location: outputs/artifacts/
+    checkpoint_path_prefix: outputs/checkpoints/
+
+  augmentations:
+    steps:
+      - augmentation_type: mix_audio
+        probability: 0.2
+        min_weight: 0.3
+        max_weight: 0.7
+      - augmentation_type: add_echo
+        probability: 0.2
+        max_delay: 0.005
+        min_weight: 0.0
+        max_weight: 1.0
+      - augmentation_type: scale_volume
+        probability: 0.2
+        min_scaling: 0.0
+        max_scaling: 2.0
+      - augmentation_type: warp
+        probability: 0.2
+        delta: 0.04
+      - augmentation_type: mask_time
+        probability: 0.2
+        max_perc: 0.05
+        max_masks: 3
+      - augmentation_type: mask_freq
+        probability: 0.2
+        max_perc: 0.10
+        max_masks: 3
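For orientation, the right-hand side of the hunk above can be read back with PyYAML. A sketch, assuming the file is the `example_data/config.yaml` that the justfile below passes via `--config`; the indentation in the reconstruction above is inferred, so treat the exact nesting as approximate:

```python
import yaml

with open("example_data/config.yaml") as fp:
    cfg = yaml.safe_load(fp)

# The four target class names defined under targets.classes.classes:
names = [c["name"] for c in cfg["targets"]["classes"]["classes"]]
print(names)  # expected: ['myomys', 'pippip', 'eptser', 'rhifer']
```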
@ -1,8 +0,0 @@
|
|||||||
name: example dataset
|
|
||||||
description: Only for demonstration purposes
|
|
||||||
sources:
|
|
||||||
- format: batdetect2
|
|
||||||
name: Example Data
|
|
||||||
description: Examples included for testing batdetect2
|
|
||||||
annotations_dir: example_data/anns
|
|
||||||
audio_dir: example_data/audio
|
|
||||||
@ -1,36 +0,0 @@
|
|||||||
detection_target:
|
|
||||||
name: bat
|
|
||||||
match_if:
|
|
||||||
name: all_of
|
|
||||||
conditions:
|
|
||||||
- name: has_tag
|
|
||||||
tag: { key: event, value: Echolocation }
|
|
||||||
- name: not
|
|
||||||
condition:
|
|
||||||
name: has_tag
|
|
||||||
tag: { key: class, value: Unknown }
|
|
||||||
assign_tags:
|
|
||||||
- key: class
|
|
||||||
value: Bat
|
|
||||||
|
|
||||||
classification_targets:
|
|
||||||
- name: myomys
|
|
||||||
tags:
|
|
||||||
- key: class
|
|
||||||
value: Myotis mystacinus
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: class
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
- name: eptser
|
|
||||||
tags:
|
|
||||||
- key: class
|
|
||||||
value: Eptesicus serotinus
|
|
||||||
- name: rhifer
|
|
||||||
tags:
|
|
||||||
- key: class
|
|
||||||
value: Rhinolophus ferrumequinum
|
|
||||||
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: top-left
|
|
||||||
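Reading the removed targets file: a sound event qualified as a detection target when all conditions held, i.e. it carried `event: Echolocation` and did not carry `class: Unknown`. A hedged paraphrase in plain Python; the helper below is hypothetical, not batdetect2 API:

```python
def is_detection_target(tags: dict) -> bool:
    # all_of([has_tag(event=Echolocation), not(has_tag(class=Unknown))])
    return (
        tags.get("event") == "Echolocation"
        and tags.get("class") != "Unknown"
    )


assert is_detection_target({"event": "Echolocation", "class": "Bat"})
assert not is_detection_target({"event": "Echolocation", "class": "Unknown"})
```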
justfile
@@ -92,11 +92,19 @@ clean-build:
 clean: clean-build clean-pyc clean-test clean-docs
 
 # Examples
 
+# Preprocess example data.
+example-preprocess OPTIONS="":
+    batdetect2 preprocess \
+        --base-dir . \
+        --dataset-field datasets.train \
+        --config example_data/config.yaml \
+        {{OPTIONS}} \
+        example_data/datasets.yaml example_data/preprocessed
+
 # Train on example data.
 example-train OPTIONS="":
     batdetect2 train \
-        --val-dataset example_data/dataset.yaml \
+        --val-dir example_data/preprocessed \
         --config example_data/config.yaml \
         {{OPTIONS}} \
-        example_data/dataset.yaml
+        example_data/preprocessed
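The new recipe shells out to the `batdetect2 preprocess` subcommand added later in this diff. The same invocation can be exercised from Python with click's test runner; a sketch, reusing the `cli` group imported throughout the CLI modules below:

```python
from click.testing import CliRunner

from batdetect2.cli.base import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "preprocess",
        "--base-dir", ".",
        "--dataset-field", "datasets.train",
        "--config", "example_data/config.yaml",
        "example_data/datasets.yaml",
        "example_data/preprocessed",
    ],
)
print(result.exit_code, result.output)
```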
@@ -17,13 +17,13 @@ dependencies = [
     "torch>=1.13.1,<2.5.0",
     "torchaudio>=1.13.1,<2.5.0",
     "torchvision>=0.14.0",
-    "soundevent[audio,geometry,plot]>=2.9.1",
+    "soundevent[audio,geometry,plot]>=2.7.0",
     "click>=8.1.7",
     "netcdf4>=1.6.5",
     "tqdm>=4.66.2",
     "cf-xarray>=0.9.0",
     "onnx>=1.16.0",
-    "lightning[extra]==2.5.0",
+    "lightning[extra]>=2.2.2",
     "tensorboard>=2.16.2",
     "omegaconf>=2.3.0",
     "pyyaml>=6.0.2",
@@ -1,272 +0,0 @@
-from pathlib import Path
-from typing import List, Optional, Sequence
-
-import numpy as np
-import torch
-from soundevent import data
-from soundevent.audio.files import get_audio_files
-
-from batdetect2.audio import build_audio_loader
-from batdetect2.config import BatDetect2Config
-from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
-from batdetect2.inference import process_file_list, run_batch_inference
-from batdetect2.logging import DEFAULT_LOGS_DIR
-from batdetect2.models import Model, build_model
-from batdetect2.postprocess import build_postprocessor, to_raw_predictions
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets import build_targets
-from batdetect2.train import (
-    DEFAULT_CHECKPOINT_DIR,
-    load_model_from_checkpoint,
-    train,
-)
-from batdetect2.typing import (
-    AudioLoader,
-    BatDetect2Prediction,
-    EvaluatorProtocol,
-    PostprocessorProtocol,
-    PreprocessorProtocol,
-    RawPrediction,
-    TargetProtocol,
-)
-
-
-class BatDetect2API:
-    def __init__(
-        self,
-        config: BatDetect2Config,
-        targets: TargetProtocol,
-        audio_loader: AudioLoader,
-        preprocessor: PreprocessorProtocol,
-        postprocessor: PostprocessorProtocol,
-        evaluator: EvaluatorProtocol,
-        model: Model,
-    ):
-        self.config = config
-        self.targets = targets
-        self.audio_loader = audio_loader
-        self.preprocessor = preprocessor
-        self.postprocessor = postprocessor
-        self.evaluator = evaluator
-        self.model = model
-
-        self.model.eval()
-
-    def train(
-        self,
-        train_annotations: Sequence[data.ClipAnnotation],
-        val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
-        train_workers: Optional[int] = None,
-        val_workers: Optional[int] = None,
-        checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
-        log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
-        experiment_name: Optional[str] = None,
-        run_name: Optional[str] = None,
-        seed: Optional[int] = None,
-    ):
-        train(
-            train_annotations=train_annotations,
-            val_annotations=val_annotations,
-            targets=self.targets,
-            config=self.config,
-            audio_loader=self.audio_loader,
-            preprocessor=self.preprocessor,
-            train_workers=train_workers,
-            val_workers=val_workers,
-            checkpoint_dir=checkpoint_dir,
-            log_dir=log_dir,
-            experiment_name=experiment_name,
-            run_name=run_name,
-            seed=seed,
-        )
-        return self
-
-    def evaluate(
-        self,
-        test_annotations: Sequence[data.ClipAnnotation],
-        num_workers: Optional[int] = None,
-        output_dir: data.PathLike = DEFAULT_EVAL_DIR,
-        experiment_name: Optional[str] = None,
-        run_name: Optional[str] = None,
-    ):
-        return evaluate(
-            self.model,
-            test_annotations,
-            targets=self.targets,
-            audio_loader=self.audio_loader,
-            preprocessor=self.preprocessor,
-            config=self.config,
-            num_workers=num_workers,
-            output_dir=output_dir,
-            experiment_name=experiment_name,
-            run_name=run_name,
-        )
-
-    def load_audio(self, path: data.PathLike) -> np.ndarray:
-        return self.audio_loader.load_file(path)
-
-    def load_clip(self, clip: data.Clip) -> np.ndarray:
-        return self.audio_loader.load_clip(clip)
-
-    def generate_spectrogram(
-        self,
-        audio: np.ndarray,
-    ) -> torch.Tensor:
-        tensor = torch.tensor(audio).unsqueeze(0)
-        return self.preprocessor(tensor)
-
-    def process_file(self, audio_file: str) -> BatDetect2Prediction:
-        recording = data.Recording.from_file(audio_file, compute_hash=False)
-        wav = self.audio_loader.load_recording(recording)
-        detections = self.process_audio(wav)
-        return BatDetect2Prediction(
-            clip=data.Clip(
-                uuid=recording.uuid,
-                recording=recording,
-                start_time=0,
-                end_time=recording.duration,
-            ),
-            predictions=detections,
-        )
-
-    def process_audio(
-        self,
-        audio: np.ndarray,
-    ) -> List[RawPrediction]:
-        spec = self.generate_spectrogram(audio)
-        return self.process_spectrogram(spec)
-
-    def process_spectrogram(
-        self,
-        spec: torch.Tensor,
-        start_time: float = 0,
-    ) -> List[RawPrediction]:
-        if spec.ndim == 4 and spec.shape[0] > 1:
-            raise ValueError("Batched spectrograms not supported.")
-
-        if spec.ndim == 3:
-            spec = spec.unsqueeze(0)
-
-        outputs = self.model.detector(spec)
-
-        detections = self.model.postprocessor(
-            outputs,
-            start_times=[start_time],
-        )[0]
-
-        return to_raw_predictions(detections.numpy(), targets=self.targets)
-
-    def process_directory(
-        self,
-        audio_dir: data.PathLike,
-    ) -> List[BatDetect2Prediction]:
-        files = list(get_audio_files(audio_dir))
-        return self.process_files(files)
-
-    def process_files(
-        self,
-        audio_files: Sequence[data.PathLike],
-        num_workers: Optional[int] = None,
-    ) -> List[BatDetect2Prediction]:
-        return process_file_list(
-            self.model,
-            audio_files,
-            config=self.config,
-            targets=self.targets,
-            audio_loader=self.audio_loader,
-            preprocessor=self.preprocessor,
-            num_workers=num_workers,
-        )
-
-    def process_clips(
-        self,
-        clips: Sequence[data.Clip],
-    ) -> List[BatDetect2Prediction]:
-        return run_batch_inference(
-            self.model,
-            clips,
-            targets=self.targets,
-            audio_loader=self.audio_loader,
-            preprocessor=self.preprocessor,
-            config=self.config,
-        )
-
-    @classmethod
-    def from_config(cls, config: BatDetect2Config):
-        targets = build_targets(config=config.targets)
-
-        audio_loader = build_audio_loader(config=config.audio)
-
-        preprocessor = build_preprocessor(
-            input_samplerate=audio_loader.samplerate,
-            config=config.preprocess,
-        )
-
-        postprocessor = build_postprocessor(
-            preprocessor,
-            config=config.postprocess,
-        )
-
-        evaluator = build_evaluator(config=config.evaluation, targets=targets)
-
-        # NOTE: Better to have a separate instance of
-        # preprocessor and postprocessor as these may be moved
-        # to another device.
-        model = build_model(
-            config=config.model,
-            targets=targets,
-            preprocessor=build_preprocessor(
-                input_samplerate=audio_loader.samplerate,
-                config=config.preprocess,
-            ),
-            postprocessor=build_postprocessor(
-                preprocessor,
-                config=config.postprocess,
-            ),
-        )
-
-        return cls(
-            config=config,
-            targets=targets,
-            audio_loader=audio_loader,
-            preprocessor=preprocessor,
-            postprocessor=postprocessor,
-            evaluator=evaluator,
-            model=model,
-        )
-
-    @classmethod
-    def from_checkpoint(
-        cls,
-        path: data.PathLike,
-        config: Optional[BatDetect2Config] = None,
-    ):
-        model, stored_config = load_model_from_checkpoint(path)
-
-        config = config or stored_config
-
-        targets = build_targets(config=config.targets)
-
-        audio_loader = build_audio_loader(config=config.audio)
-
-        preprocessor = build_preprocessor(
-            input_samplerate=audio_loader.samplerate,
-            config=config.preprocess,
-        )
-
-        postprocessor = build_postprocessor(
-            preprocessor,
-            config=config.postprocess,
-        )
-
-        evaluator = build_evaluator(config=config.evaluation, targets=targets)
-
-        return cls(
-            config=config,
-            targets=targets,
-            audio_loader=audio_loader,
-            preprocessor=preprocessor,
-            postprocessor=postprocessor,
-            evaluator=evaluator,
-            model=model,
-        )
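For context on what this deletion removes: the `BatDetect2API` class above bundled loading, preprocessing, inference and evaluation behind one object. A sketch of how it was driven, using only methods defined above; the checkpoint and audio paths are placeholders:

```python
from batdetect2.api_v2 import BatDetect2API

# Restore a trained model plus its stored configuration.
api = BatDetect2API.from_checkpoint("outputs/checkpoints/last.ckpt")

# One file end to end: load audio -> spectrogram -> detector -> postprocess.
prediction = api.process_file("example_data/audio/recording.wav")
for raw in prediction.predictions:
    print(raw)
```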
@@ -1,16 +0,0 @@
-from batdetect2.audio.clips import ClipConfig, build_clipper
-from batdetect2.audio.loader import (
-    TARGET_SAMPLERATE_HZ,
-    AudioConfig,
-    SoundEventAudioLoader,
-    build_audio_loader,
-)
-
-__all__ = [
-    "TARGET_SAMPLERATE_HZ",
-    "AudioConfig",
-    "SoundEventAudioLoader",
-    "build_audio_loader",
-    "ClipConfig",
-    "build_clipper",
-]
@@ -1,264 +0,0 @@
-from typing import Annotated, List, Literal, Optional, Union
-
-import numpy as np
-from loguru import logger
-from pydantic import Field
-from soundevent import data
-from soundevent.geometry import compute_bounds, intervals_overlap
-
-from batdetect2.core import BaseConfig, Registry
-from batdetect2.typing import ClipperProtocol
-
-DEFAULT_TRAIN_CLIP_DURATION = 0.256
-DEFAULT_MAX_EMPTY_CLIP = 0.1
-
-
-__all__ = [
-    "build_clipper",
-    "ClipConfig",
-]
-
-
-clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")
-
-
-class RandomClipConfig(BaseConfig):
-    name: Literal["random_subclip"] = "random_subclip"
-    duration: float = DEFAULT_TRAIN_CLIP_DURATION
-    random: bool = True
-    max_empty: float = DEFAULT_MAX_EMPTY_CLIP
-    min_sound_event_overlap: float = 0
-
-
-class RandomClip:
-    def __init__(
-        self,
-        duration: float = 0.5,
-        max_empty: float = 0.2,
-        random: bool = True,
-        min_sound_event_overlap: float = 0,
-    ):
-        super().__init__()
-        self.duration = duration
-        self.random = random
-        self.max_empty = max_empty
-        self.min_sound_event_overlap = min_sound_event_overlap
-
-    def __call__(
-        self,
-        clip_annotation: data.ClipAnnotation,
-    ) -> data.ClipAnnotation:
-        subclip = self.get_subclip(clip_annotation.clip)
-        sound_events = select_sound_event_annotations(
-            clip_annotation,
-            subclip,
-            min_overlap=self.min_sound_event_overlap,
-        )
-        return clip_annotation.model_copy(
-            update=dict(
-                clip=subclip,
-                sound_events=sound_events,
-            )
-        )
-
-    def get_subclip(self, clip: data.Clip) -> data.Clip:
-        return select_random_subclip(
-            clip,
-            random=self.random,
-            duration=self.duration,
-            max_empty=self.max_empty,
-        )
-
-    @clipper_registry.register(RandomClipConfig)
-    @staticmethod
-    def from_config(config: RandomClipConfig):
-        return RandomClip(
-            duration=config.duration,
-            max_empty=config.max_empty,
-            min_sound_event_overlap=config.min_sound_event_overlap,
-        )
-
-
-def get_subclip_annotation(
-    clip_annotation: data.ClipAnnotation,
-    random: bool = True,
-    duration: float = 0.5,
-    max_empty: float = 0.2,
-    min_sound_event_overlap: float = 0,
-) -> data.ClipAnnotation:
-    clip = clip_annotation.clip
-
-    subclip = select_random_subclip(
-        clip,
-        random=random,
-        duration=duration,
-        max_empty=max_empty,
-    )
-
-    sound_events = select_sound_event_annotations(
-        clip_annotation,
-        subclip,
-        min_overlap=min_sound_event_overlap,
-    )
-
-    return clip_annotation.model_copy(
-        update=dict(
-            clip=subclip,
-            sound_events=sound_events,
-        )
-    )
-
-
-def select_random_subclip(
-    clip: data.Clip,
-    random: bool = True,
-    duration: float = 0.5,
-    max_empty: float = 0.2,
-) -> data.Clip:
-    start_time = clip.start_time
-    end_time = clip.end_time
-
-    if duration > clip.duration + max_empty or not random:
-        return clip.model_copy(
-            update=dict(
-                start_time=start_time,
-                end_time=start_time + duration,
-            )
-        )
-
-    random_start_time = np.random.uniform(
-        low=start_time,
-        high=end_time + max_empty - duration,
-    )
-
-    return clip.model_copy(
-        update=dict(
-            start_time=random_start_time,
-            end_time=random_start_time + duration,
-        )
-    )
-
-
-def select_sound_event_annotations(
-    clip_annotation: data.ClipAnnotation,
-    subclip: data.Clip,
-    min_overlap: float = 0,
-) -> List[data.SoundEventAnnotation]:
-    selected = []
-
-    start_time = subclip.start_time
-    end_time = subclip.end_time
-
-    for sound_event_annotation in clip_annotation.sound_events:
-        geometry = sound_event_annotation.sound_event.geometry
-
-        if geometry is None:
-            continue
-
-        geom_start_time, _, geom_end_time, _ = compute_bounds(geometry)
-
-        if not intervals_overlap(
-            (start_time, end_time),
-            (geom_start_time, geom_end_time),
-            min_absolute_overlap=min_overlap,
-        ):
-            continue
-
-        selected.append(sound_event_annotation)
-
-    return selected
-
-
-class PaddedClipConfig(BaseConfig):
-    name: Literal["whole_audio_padded"] = "whole_audio_padded"
-    chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION
-
-
-class PaddedClip:
-    def __init__(self, chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION):
-        self.chunk_size = chunk_size
-
-    def __call__(
-        self,
-        clip_annotation: data.ClipAnnotation,
-    ) -> data.ClipAnnotation:
-        clip = clip_annotation.clip
-        clip = self.get_subclip(clip)
-        return clip_annotation.model_copy(update=dict(clip=clip))
-
-    def get_subclip(self, clip: data.Clip) -> data.Clip:
-        duration = clip.duration
-
-        target_duration = float(
-            self.chunk_size * np.ceil(duration / self.chunk_size)
-        )
-        clip = clip.model_copy(
-            update=dict(
-                end_time=clip.start_time + target_duration,
-            )
-        )
-        return clip
-
-    @clipper_registry.register(PaddedClipConfig)
-    @staticmethod
-    def from_config(config: PaddedClipConfig):
-        return PaddedClip(chunk_size=config.chunk_size)
-
-
-class FixedDurationClipConfig(BaseConfig):
-    name: Literal["fixed_duration"] = "fixed_duration"
-    duration: float = DEFAULT_TRAIN_CLIP_DURATION
-
-
-class FixedDurationClip:
-    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
-        self.duration = duration
-
-    def __call__(
-        self,
-        clip_annotation: data.ClipAnnotation,
-    ) -> data.ClipAnnotation:
-        clip = self.get_subclip(clip_annotation.clip)
-        sound_events = select_sound_event_annotations(
-            clip_annotation,
-            clip,
-            min_overlap=0,
-        )
-        return clip_annotation.model_copy(
-            update=dict(
-                clip=clip,
-                sound_events=sound_events,
-            )
-        )
-
-    def get_subclip(self, clip: data.Clip) -> data.Clip:
-        return clip.model_copy(
-            update=dict(
-                end_time=clip.start_time + self.duration,
-            )
-        )
-
-    @clipper_registry.register(FixedDurationClipConfig)
-    @staticmethod
-    def from_config(config: FixedDurationClipConfig):
-        return FixedDurationClip(duration=config.duration)
-
-
-ClipConfig = Annotated[
-    Union[
-        RandomClipConfig,
-        PaddedClipConfig,
-        FixedDurationClipConfig,
-    ],
-    Field(discriminator="name"),
-]
-
-
-def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
-    config = config or RandomClipConfig()
-
-    logger.opt(lazy=True).debug(
-        "Building clipper with config: \n{}",
-        lambda: config.to_yaml_string(),
-    )
-    return clipper_registry.build(config)
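The module above wires three clipper implementations into a registry keyed by the discriminated `name` field. A short usage sketch based only on the definitions above:

```python
from batdetect2.audio.clips import RandomClipConfig, build_clipper

# Selects RandomClip via the "random_subclip" discriminator value.
clipper = build_clipper(RandomClipConfig(duration=0.256, max_empty=0.1))

# clipper(clip_annotation) then returns a ClipAnnotation cropped to a
# random 0.256 s window, keeping only sound events overlapping the window.
```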
@@ -1,295 +0,0 @@
-from typing import Optional
-
-import numpy as np
-from numpy.typing import DTypeLike
-from pydantic import Field
-from scipy.signal import resample, resample_poly
-from soundevent import audio, data
-from soundfile import LibsndfileError
-
-from batdetect2.core import BaseConfig
-from batdetect2.typing import AudioLoader
-
-__all__ = [
-    "SoundEventAudioLoader",
-    "build_audio_loader",
-    "load_file_audio",
-    "load_recording_audio",
-    "load_clip_audio",
-    "resample_audio",
-]
-
-TARGET_SAMPLERATE_HZ = 256_000
-"""Default target sample rate in Hz used if resampling is enabled."""
-
-
-class ResampleConfig(BaseConfig):
-    """Configuration for audio resampling.
-
-    Attributes
-    ----------
-    samplerate : int, default=256000
-        The target sample rate in Hz to resample the audio to. Must be > 0.
-    method : str, default="poly"
-        The resampling algorithm to use. Options:
-        - "poly": Polyphase resampling using `scipy.signal.resample_poly`.
-          Generally fast.
-        - "fourier": Resampling via Fourier method using
-          `scipy.signal.resample`. May handle non-integer
-          resampling factors differently.
-    """
-
-    enabled: bool = True
-    method: str = "poly"
-
-
-class AudioConfig(BaseConfig):
-    """Configuration for loading and initial audio preprocessing."""
-
-    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
-    resample: ResampleConfig = Field(default_factory=ResampleConfig)
-
-
-def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
-    """Factory function to create an AudioLoader based on configuration."""
-    config = config or AudioConfig()
-    return SoundEventAudioLoader(
-        samplerate=config.samplerate,
-        config=config.resample,
-    )
-
-
-class SoundEventAudioLoader(AudioLoader):
-    """Concrete implementation of the `AudioLoader`."""
-
-    def __init__(
-        self,
-        samplerate: int = TARGET_SAMPLERATE_HZ,
-        config: Optional[ResampleConfig] = None,
-    ):
-        self.samplerate = samplerate
-        self.config = config or ResampleConfig()
-
-    def load_file(
-        self,
-        path: data.PathLike,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> np.ndarray:
-        """Load and preprocess audio directly from a file path."""
-        return load_file_audio(
-            path,
-            samplerate=self.samplerate,
-            config=self.config,
-            audio_dir=audio_dir,
-        )
-
-    def load_recording(
-        self,
-        recording: data.Recording,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> np.ndarray:
-        """Load and preprocess the entire audio for a Recording object."""
-        return load_recording_audio(
-            recording,
-            samplerate=self.samplerate,
-            config=self.config,
-            audio_dir=audio_dir,
-        )
-
-    def load_clip(
-        self,
-        clip: data.Clip,
-        audio_dir: Optional[data.PathLike] = None,
-    ) -> np.ndarray:
-        """Load and preprocess the audio segment defined by a Clip object."""
-        return load_clip_audio(
-            clip,
-            samplerate=self.samplerate,
-            config=self.config,
-            audio_dir=audio_dir,
-        )
-
-
-def load_file_audio(
-    path: data.PathLike,
-    samplerate: Optional[int] = None,
-    config: Optional[ResampleConfig] = None,
-    audio_dir: Optional[data.PathLike] = None,
-    dtype: DTypeLike = np.float32,  # type: ignore
-) -> np.ndarray:
-    """Load and preprocess audio from a file path using specified config."""
-    try:
-        recording = data.Recording.from_file(path)
-    except LibsndfileError as e:
-        raise FileNotFoundError(
-            f"Could not load the recording at path: {path}. Error: {e}"
-        ) from e
-
-    return load_recording_audio(
-        recording,
-        samplerate=samplerate,
-        config=config,
-        dtype=dtype,
-        audio_dir=audio_dir,
-    )
-
-
-def load_recording_audio(
-    recording: data.Recording,
-    samplerate: Optional[int] = None,
-    config: Optional[ResampleConfig] = None,
-    audio_dir: Optional[data.PathLike] = None,
-    dtype: DTypeLike = np.float32,  # type: ignore
-) -> np.ndarray:
-    """Load and preprocess the entire audio content of a recording using config."""
-    clip = data.Clip(
-        recording=recording,
-        start_time=0,
-        end_time=recording.duration,
-    )
-    return load_clip_audio(
-        clip,
-        samplerate=samplerate,
-        config=config,
-        dtype=dtype,
-        audio_dir=audio_dir,
-    )
-
-
-def load_clip_audio(
-    clip: data.Clip,
-    samplerate: Optional[int] = None,
-    config: Optional[ResampleConfig] = None,
-    audio_dir: Optional[data.PathLike] = None,
-    dtype: DTypeLike = np.float32,  # type: ignore
-) -> np.ndarray:
-    """Load and preprocess a specific audio clip segment based on config."""
-    try:
-        wav = (
-            audio.load_clip(clip, audio_dir=audio_dir)
-            .sel(channel=0)
-            .astype(dtype)
-        )
-    except LibsndfileError as e:
-        raise FileNotFoundError(
-            f"Could not load the recording at path: {clip.recording.path}. "
-            f"Error: {e}"
-        ) from e
-
-    if not config or not config.enabled or samplerate is None:
-        return wav.data.astype(dtype)
-
-    sr = int(1 / wav.time.attrs["step"])
-    return resample_audio(
-        wav.data,
-        sr=sr,
-        samplerate=samplerate,
-        method=config.method,
-    )
-
-
-def resample_audio(
-    wav: np.ndarray,
-    sr: int,
-    samplerate: int = TARGET_SAMPLERATE_HZ,
-    method: str = "poly",
-) -> np.ndarray:
-    """Resample an audio waveform DataArray to a target sample rate."""
-    if sr == samplerate:
-        return wav
-
-    if method == "poly":
-        return resample_audio_poly(
-            wav,
-            sr_orig=sr,
-            sr_new=samplerate,
-        )
-    elif method == "fourier":
-        return resample_audio_fourier(
-            wav,
-            sr_orig=sr,
-            sr_new=samplerate,
-        )
-    else:
-        raise NotImplementedError(
-            f"Resampling method '{method}' not implemented"
-        )
-
-
-def resample_audio_poly(
-    array: np.ndarray,
-    sr_orig: int,
-    sr_new: int,
-    axis: int = -1,
-) -> np.ndarray:
-    """Resample a numpy array using `scipy.signal.resample_poly`.
-
-    This method is often preferred for signals when the ratio of new
-    to old sample rates can be expressed as a rational number. It uses
-    polyphase filtering.
-
-    Parameters
-    ----------
-    array : np.ndarray
-        The input array to resample.
-    sr_orig : int
-        The original sample rate in Hz.
-    sr_new : int
-        The target sample rate in Hz.
-    axis : int, default=-1
-        The axis of `array` along which to resample.
-
-    Returns
-    -------
-    np.ndarray
-        The array resampled to the target sample rate.
-
-    Raises
-    ------
-    ValueError
-        If sample rates are not positive.
-    """
-    gcd = np.gcd(sr_orig, sr_new)
-    return resample_poly(
-        array,
-        sr_new // gcd,
-        sr_orig // gcd,
-        axis=axis,
-    )
-
-
-def resample_audio_fourier(
-    array: np.ndarray,
-    sr_orig: int,
-    sr_new: int,
-    axis: int = -1,
-) -> np.ndarray:
-    """Resample a numpy array using `scipy.signal.resample`.
-
-    This method uses FFTs to resample the signal.
-
-    Parameters
-    ----------
-    array : np.ndarray
-        The input array to resample.
-    num : int
-        The desired number of samples in the output array along `axis`.
-    axis : int, default=-1
-        The axis of `array` along which to resample.
-
-    Returns
-    -------
-    np.ndarray
-        The array resampled to have `num` samples along `axis`.
-
-    Raises
-    ------
-    ValueError
-        If `num` is negative.
-    """
-    ratio = sr_new / sr_orig
-    return resample(  # type: ignore
-        array,
-        int(array.shape[axis] * ratio),
-        axis=axis,
-    )
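The polyphase branch above resamples by the reduced rational factor `(sr_new/gcd) / (sr_orig/gcd)`. A self-contained check of that arithmetic with SciPy, on synthetic data rather than batdetect2 code:

```python
import numpy as np
from scipy.signal import resample_poly

sr_orig, sr_new = 384_000, 256_000
gcd = int(np.gcd(sr_orig, sr_new))        # 128_000
up, down = sr_new // gcd, sr_orig // gcd  # 2, 3

wav = np.random.randn(sr_orig)            # one second at the original rate
out = resample_poly(wav, up, down)
assert out.shape[0] == sr_new             # one second at the target rate
```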
@@ -1,7 +1,7 @@
 from batdetect2.cli.base import cli
 from batdetect2.cli.compat import detect
 from batdetect2.cli.data import data
-from batdetect2.cli.evaluate import evaluate_command
+from batdetect2.cli.preprocess import preprocess
 from batdetect2.cli.train import train_command
 
 __all__ = [
@@ -9,7 +9,7 @@ __all__ = [
     "detect",
     "data",
     "train_command",
-    "evaluate_command",
+    "preprocess",
 ]
 
 
@@ -1,15 +1,10 @@
-import os
-
 import click
 
+from batdetect2 import api
 from batdetect2.cli.base import cli
-
-DEFAULT_MODEL_PATH = os.path.join(
-    os.path.dirname(os.path.dirname(__file__)),
-    "models",
-    "checkpoints",
-    "Net2DFast_UK_same.pth.tar",
-)
+from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
+from batdetect2.types import ProcessingConfiguration
+from batdetect2.utils.detector_utils import save_results_to_file
 
 
 @cli.command()
@@ -79,9 +74,6 @@ def detect(
 
     Input files should be short in duration e.g. < 30 seconds.
     """
-    from batdetect2 import api
-    from batdetect2.utils.detector_utils import save_results_to_file
-
     click.echo(f"Loading model: {args['model_path']}")
    model, params = api.load_model(args["model_path"])
 
@@ -131,7 +123,7 @@ def detect(
         click.echo(f"  {err}")
 
 
-def print_config(config):
+def print_config(config: ProcessingConfiguration):
     """Print the processing configuration."""
     click.echo("\nProcessing Configuration:")
     click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
@@ -4,6 +4,7 @@ from typing import Optional
 import click
 
 from batdetect2.cli.base import cli
+from batdetect2.data import load_dataset_from_config
 
 __all__ = ["data"]
 
@@ -32,8 +33,6 @@ def summary(
     field: Optional[str] = None,
     base_dir: Optional[Path] = None,
 ):
-    from batdetect2.data import load_dataset_from_config
-
     base_dir = base_dir or Path.cwd()
     dataset = load_dataset_from_config(
         dataset_config,
@@ -1,78 +0,0 @@
-import sys
-from pathlib import Path
-from typing import Optional
-
-import click
-from loguru import logger
-
-from batdetect2.cli.base import cli
-
-__all__ = ["evaluate_command"]
-
-
-DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
-
-
-@cli.command(name="evaluate")
-@click.argument("model-path", type=click.Path(exists=True))
-@click.argument("test_dataset", type=click.Path(exists=True))
-@click.option("--config", "config_path", type=click.Path())
-@click.option("--base-dir", type=click.Path(), default=Path.cwd())
-@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
-@click.option("--experiment-name", type=str)
-@click.option("--run-name", type=str)
-@click.option("--workers", "num_workers", type=int)
-@click.option(
-    "-v",
-    "--verbose",
-    count=True,
-    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
-)
-def evaluate_command(
-    model_path: Path,
-    test_dataset: Path,
-    base_dir: Path,
-    config_path: Optional[Path],
-    output_dir: Path = DEFAULT_OUTPUT_DIR,
-    num_workers: Optional[int] = None,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
-    verbose: int = 0,
-):
-    from batdetect2.api_v2 import BatDetect2API
-    from batdetect2.config import load_full_config
-    from batdetect2.data import load_dataset_from_config
-
-    logger.remove()
-    if verbose == 0:
-        log_level = "WARNING"
-    elif verbose == 1:
-        log_level = "INFO"
-    else:
-        log_level = "DEBUG"
-    logger.add(sys.stderr, level=log_level)
-
-    logger.info("Initiating evaluation process...")
-
-    test_annotations = load_dataset_from_config(
-        test_dataset,
-        base_dir=base_dir,
-    )
-    logger.debug(
-        "Loaded {num_annotations} test examples",
-        num_annotations=len(test_annotations),
-    )
-
-    config = None
-    if config_path is not None:
-        config = load_full_config(config_path)
-
-    api = BatDetect2API.from_checkpoint(model_path, config=config)
-
-    api.evaluate(
-        test_annotations,
-        num_workers=num_workers,
-        output_dir=output_dir,
-        experiment_name=experiment_name,
-        run_name=run_name,
-    )
src/batdetect2/cli/preprocess.py (new file)
@@ -0,0 +1,154 @@
+import sys
+from pathlib import Path
+from typing import Optional
+
+import click
+import yaml
+from loguru import logger
+
+from batdetect2.cli.base import cli
+from batdetect2.data import load_dataset_from_config
+from batdetect2.train.preprocess import (
+    TrainPreprocessConfig,
+    load_train_preprocessing_config,
+    preprocess_dataset,
+)
+
+__all__ = ["preprocess"]
+
+
+@cli.command()
+@click.argument(
+    "dataset_config",
+    type=click.Path(exists=True),
+)
+@click.argument(
+    "output",
+    type=click.Path(),
+)
+@click.option(
+    "--dataset-field",
+    type=str,
+    help=(
+        "Specifies the key to access the dataset information within the "
+        "dataset configuration file, if the information is nested inside a "
+        "dictionary. If the dataset information is at the top level of the "
+        "config file, you don't need to specify this."
+    ),
+)
+@click.option(
+    "--base-dir",
+    type=click.Path(exists=True),
+    help=(
+        "The main directory where your audio recordings and annotation "
+        "files are stored. This helps the program find your data, "
+        "especially if the paths in your dataset configuration file "
+        "are relative."
+    ),
+)
+@click.option(
+    "--config",
+    type=click.Path(exists=True),
+    help=(
+        "Path to the configuration file. This file tells "
+        "the program how to prepare your audio data before training, such "
+        "as resampling or applying filters."
+    ),
+)
+@click.option(
+    "--config-field",
+    type=str,
+    help=(
+        "If the preprocessing settings are inside a nested dictionary "
+        "within the preprocessing configuration file, specify the key "
+        "here to access them. If the preprocessing settings are at the "
+        "top level, you don't need to specify this."
+    ),
+)
+@click.option(
+    "--force",
+    is_flag=True,
+    help=(
+        "If a preprocessed file already exists, this option tells the "
+        "program to overwrite it with the new preprocessed data. Use "
+        "this if you want to re-do the preprocessing even if the files "
+        "already exist."
+    ),
+)
+@click.option(
+    "--num-workers",
+    type=int,
+    help=(
+        "The maximum number of computer cores to use when processing "
+        "your audio data. Using more cores can speed up the preprocessing, "
+        "but don't use more than your computer has available. By default, "
+        "the program will use all available cores."
+    ),
+)
+@click.option(
+    "-v",
+    "--verbose",
+    count=True,
+    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
+)
+def preprocess(
+    dataset_config: Path,
+    output: Path,
+    base_dir: Optional[Path] = None,
+    config: Optional[Path] = None,
+    config_field: Optional[str] = None,
+    force: bool = False,
+    num_workers: Optional[int] = None,
+    dataset_field: Optional[str] = None,
+    verbose: int = 0,
+):
+    logger.remove()
+    if verbose == 0:
+        log_level = "WARNING"
+    elif verbose == 1:
+        log_level = "INFO"
+    else:
+        log_level = "DEBUG"
+    logger.add(sys.stderr, level=log_level)
+
+    logger.info("Starting preprocessing.")
+
+    output = Path(output)
+    logger.info("Will save outputs to {output}", output=output)
+
+    base_dir = base_dir or Path.cwd()
+    logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
+
+    if config:
+        logger.info(
+            "Loading preprocessing config from: {config}", config=config
+        )
+
+    conf = (
+        load_train_preprocessing_config(config, field=config_field)
+        if config is not None
+        else TrainPreprocessConfig()
+    )
+    logger.debug(
+        "Preprocessing config:\n{conf}",
+        conf=yaml.dump(conf.model_dump()),
+    )
+
+    dataset = load_dataset_from_config(
+        dataset_config,
+        field=dataset_field,
+        base_dir=base_dir,
+    )
+
+    logger.info(
+        "Loaded {num_examples} annotated clips from the configured dataset",
+        num_examples=len(dataset),
+    )
+
+    preprocess_dataset(
+        dataset,
+        conf,
+        output=output,
+        force=force,
+        max_workers=num_workers,
+    )
@@ -6,24 +6,24 @@ import click
 from loguru import logger
 
 from batdetect2.cli.base import cli
+from batdetect2.train import (
+    FullTrainingConfig,
+    load_full_training_config,
+    train,
+)
+from batdetect2.train.dataset import list_preprocessed_files
 
 __all__ = ["train_command"]
 
 
 @cli.command(name="train")
-@click.argument("train_dataset", type=click.Path(exists=True))
-@click.option("--val-dataset", type=click.Path(exists=True))
-@click.option("--model", "model_path", type=click.Path(exists=True))
-@click.option("--targets", "targets_config", type=click.Path(exists=True))
-@click.option("--ckpt-dir", type=click.Path(exists=True))
-@click.option("--log-dir", type=click.Path(exists=True))
+@click.argument("train_dir", type=click.Path(exists=True))
+@click.option("--val-dir", type=click.Path(exists=True))
+@click.option("--model-path", type=click.Path(exists=True))
 @click.option("--config", type=click.Path(exists=True))
 @click.option("--config-field", type=str)
 @click.option("--train-workers", type=int)
 @click.option("--val-workers", type=int)
-@click.option("--experiment-name", type=str)
-@click.option("--run-name", type=str)
-@click.option("--seed", type=int)
 @click.option(
     "-v",
     "--verbose",
@@ -31,29 +31,15 @@ __all__ = ["train_command"]
     help="Increase verbosity. -v for INFO, -vv for DEBUG.",
 )
 def train_command(
-    train_dataset: Path,
-    val_dataset: Optional[Path] = None,
+    train_dir: Path,
+    val_dir: Optional[Path] = None,
     model_path: Optional[Path] = None,
-    ckpt_dir: Optional[Path] = None,
-    log_dir: Optional[Path] = None,
     config: Optional[Path] = None,
-    targets_config: Optional[Path] = None,
     config_field: Optional[str] = None,
-    seed: Optional[int] = None,
     train_workers: int = 0,
     val_workers: int = 0,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
     verbose: int = 0,
 ):
-    from batdetect2.api_v2 import BatDetect2API
-    from batdetect2.config import (
-        BatDetect2Config,
-        load_full_config,
-    )
-    from batdetect2.data import load_dataset_from_config
-    from batdetect2.targets import load_target_config
-
     logger.remove()
     if verbose == 0:
         log_level = "WARNING"
@@ -62,53 +48,41 @@ def train_command(
     else:
         log_level = "DEBUG"
     logger.add(sys.stderr, level=log_level)
 
     logger.info("Initiating training process...")
 
-    logger.info("Loading configuration...")
+    logger.info("Loading training configuration...")
     conf = (
-        load_full_config(config, field=config_field)
+        load_full_training_config(config, field=config_field)
         if config is not None
-        else BatDetect2Config()
+        else FullTrainingConfig()
     )
 
-    if targets_config is not None:
-        logger.info("Loading targets configuration...")
-        conf = conf.model_copy(
-            update=dict(targets=load_target_config(targets_config))
-        )
-
-    logger.info("Loading training dataset...")
-    train_annotations = load_dataset_from_config(train_dataset)
+    logger.info("Scanning for training and validation data...")
+    train_examples = list_preprocessed_files(train_dir)
     logger.debug(
-        "Loaded {num_annotations} training examples",
-        num_annotations=len(train_annotations),
+        "Found {num_files} training examples in {path}",
+        num_files=len(train_examples),
+        path=train_dir,
     )
 
-    val_annotations = None
-    if val_dataset is not None:
-        val_annotations = load_dataset_from_config(val_dataset)
+    val_examples = None
+    if val_dir is not None:
+        val_examples = list_preprocessed_files(val_dir)
         logger.debug(
-            "Loaded {num_annotations} validation examples",
-            num_annotations=len(val_annotations),
+            "Found {num_files} validation examples in {path}",
+            num_files=len(val_examples),
+            path=val_dir,
         )
     else:
         logger.debug("No validation directory provided.")
 
     logger.info("Configuration and data loaded. Starting training...")
-    if model_path is None:
-        api = BatDetect2API.from_config(conf)
-    else:
-        api = BatDetect2API.from_checkpoint(model_path)
-
-    return api.train(
-        train_annotations=train_annotations,
-        val_annotations=val_annotations,
+    train(
+        train_examples=train_examples,
+        val_examples=val_examples,
+        config=conf,
+        model_path=model_path,
         train_workers=train_workers,
         val_workers=val_workers,
-        checkpoint_dir=ckpt_dir,
-        log_dir=log_dir,
-        experiment_name=experiment_name,
-        run_name=run_name,
-        seed=seed,
     )
@@ -11,6 +1,7 @@ from soundevent import data
 from soundevent.geometry import compute_bounds
 from soundevent.types import ClassMapper

+from batdetect2.targets.terms import get_term_from_key
 from batdetect2.types import (
     Annotation,
     AudioLoaderAnnotationGroup,
@@ -172,9 +173,18 @@ def annotation_to_sound_event_annotation(
         uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
         sound_event=sound_event,
         tags=[
-            data.Tag(key=label_key, value=annotation["class"]),
-            data.Tag(key=event_key, value=annotation["event"]),
-            data.Tag(key=individual_key, value=str(annotation["individual"])),
+            data.Tag(
+                term=get_term_from_key(label_key),
+                value=annotation["class"],
+            ),
+            data.Tag(
+                term=get_term_from_key(event_key),
+                value=annotation["event"],
+            ),
+            data.Tag(
+                term=get_term_from_key(individual_key),
+                value=str(annotation["individual"]),
+            ),
         ],
     )

@@ -209,11 +219,17 @@ def annotation_to_sound_event_prediction(
         tags=[
             data.PredictedTag(
                 score=annotation["class_prob"],
-                tag=data.Tag(key=label_key, value=annotation["class"]),
+                tag=data.Tag(
+                    term=get_term_from_key(label_key),
+                    value=annotation["class"],
+                ),
             ),
             data.PredictedTag(
                 score=annotation["det_prob"],
-                tag=data.Tag(key=event_key, value=annotation["event"]),
+                tag=data.Tag(
+                    term=get_term_from_key(event_key),
+                    value=annotation["event"],
+                ),
             ),
         ],
     )
@@ -1,42 +0,0 @@
-from typing import Literal, Optional
-
-from pydantic import Field
-from soundevent.data import PathLike
-
-from batdetect2.audio import AudioConfig
-from batdetect2.core import BaseConfig
-from batdetect2.core.configs import load_config
-from batdetect2.evaluate.config import EvaluationConfig
-from batdetect2.inference.config import InferenceConfig
-from batdetect2.models.config import BackboneConfig
-from batdetect2.postprocess.config import PostprocessConfig
-from batdetect2.preprocess.config import PreprocessingConfig
-from batdetect2.targets.config import TargetConfig
-from batdetect2.train.config import TrainingConfig
-
-__all__ = [
-    "BatDetect2Config",
-    "load_full_config",
-]
-
-
-class BatDetect2Config(BaseConfig):
-    config_version: Literal["v1"] = "v1"
-
-    train: TrainingConfig = Field(default_factory=TrainingConfig)
-    evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
-    model: BackboneConfig = Field(default_factory=BackboneConfig)
-    preprocess: PreprocessingConfig = Field(
-        default_factory=PreprocessingConfig
-    )
-    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
-    audio: AudioConfig = Field(default_factory=AudioConfig)
-    targets: TargetConfig = Field(default_factory=TargetConfig)
-    inference: InferenceConfig = Field(default_factory=InferenceConfig)
-
-
-def load_full_config(
-    path: PathLike,
-    field: Optional[str] = None,
-) -> BatDetect2Config:
-    return load_config(path, schema=BatDetect2Config, field=field)
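The deleted `batdetect2/config.py` above bundled every sub-configuration into a single top-level object. A minimal usage sketch, assuming the module as defined above and that `to_yaml_string`'s keyword arguments have sensible defaults (the file name is hypothetical):

```python
from batdetect2.config import BatDetect2Config, load_full_config

# A default instance pulls in every sub-config via its default factory,
# and BaseConfig.to_yaml_string() serializes it back to YAML.
config = BatDetect2Config()
print(config.to_yaml_string())  # train:, evaluation:, model:, preprocess:, ...

# Load a full configuration from disk; `field` optionally points at a
# sub-section of the YAML document (None means the whole file).
config = load_full_config("batdetect2.yaml", field=None)
print(config.train, config.model, config.targets)
```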
@@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
     and serialization capabilities.
     """

-    model_config = ConfigDict(extra="forbid")
+    model_config = ConfigDict(extra="ignore")

     def to_yaml_string(
         self,
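For context on the `extra="forbid"` → `extra="ignore"` switch above, this is standard Pydantic v2 behavior, sketched here with hypothetical models: with `forbid` an unknown key in a config file raises a `ValidationError`, while with `ignore` it is silently dropped, so configs written for other versions still load.

```python
from pydantic import BaseModel, ConfigDict, ValidationError

class Strict(BaseModel):
    model_config = ConfigDict(extra="forbid")
    samplerate: int = 256000

class Lenient(BaseModel):
    model_config = ConfigDict(extra="ignore")
    samplerate: int = 256000

try:
    Strict(samplerate=192000, unknown_option=True)
except ValidationError as err:
    print(err)  # reports "Extra inputs are not permitted" for unknown_option

# The same payload loads cleanly once extras are ignored:
print(Lenient(samplerate=192000, unknown_option=True))  # samplerate=192000
```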
@@ -53,7 +53,6 @@ class BaseConfig(BaseModel):
         """
         return yaml.dump(
             self.model_dump(
-                mode="json",
                 exclude_none=exclude_none,
                 exclude_unset=exclude_unset,
                 exclude_defaults=exclude_defaults,
@@ -1,8 +0,0 @@
-from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.core.registries import Registry
-
-__all__ = [
-    "BaseConfig",
-    "load_config",
-    "Registry",
-]
@@ -1,95 +0,0 @@
-from typing import Optional
-
-import numpy as np
-import torch
-import xarray as xr
-
-
-def spec_to_xarray(
-    spec: np.ndarray,
-    start_time: float,
-    end_time: float,
-    min_freq: float,
-    max_freq: float,
-) -> xr.DataArray:
-    if spec.ndim != 2:
-        raise ValueError(
-            "Input numpy spectrogram array should be 2-dimensional"
-        )
-
-    height, width = spec.shape
-    return xr.DataArray(
-        data=spec,
-        dims=["frequency", "time"],
-        coords={
-            "frequency": np.linspace(
-                min_freq,
-                max_freq,
-                height,
-                endpoint=False,
-            ),
-            "time": np.linspace(
-                start_time,
-                end_time,
-                width,
-                endpoint=False,
-            ),
-        },
-    )
-
-
-def extend_width(
-    tensor: torch.Tensor,
-    extra: int,
-    axis: int = -1,
-    value: float = 0,
-) -> torch.Tensor:
-    dims = len(tensor.shape)
-    axis = dims - axis % dims - 1
-    pad = [0 for _ in range(2 * dims)]
-    pad[2 * axis + 1] = extra
-    return torch.nn.functional.pad(
-        tensor,
-        pad,
-        mode="constant",
-        value=value,
-    )
-
-
-def adjust_width(
-    tensor: torch.Tensor,
-    width: int,
-    axis: int = -1,
-    value: float = 0,
-) -> torch.Tensor:
-    dims = len(tensor.shape)
-    axis = axis % dims
-    current_width = tensor.shape[axis]
-
-    if current_width == width:
-        return tensor
-
-    if current_width < width:
-        return extend_width(
-            tensor,
-            extra=width - current_width,
-            axis=axis,
-            value=value,
-        )
-
-    slices = [
-        slice(None, None) if index != axis else slice(None, width)
-        for index in range(dims)
-    ]
-    return tensor[tuple(slices)]
-
-
-def slice_tensor(
-    tensor: torch.Tensor,
-    start: Optional[int] = None,
-    end: Optional[int] = None,
-    dim: int = -1,
-) -> torch.Tensor:
-    slices = [slice(None)] * tensor.ndim
-    slices[dim] = slice(start, end)
-    return tensor[tuple(slices)]
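A quick sketch of how the deleted width helpers behaved, assuming the definitions in the removed file above: `adjust_width` either zero-pads or slices a tensor so the chosen axis reaches an exact size, which is how spectrogram widths were matched to a fixed model input.

```python
import torch

spec = torch.ones(128, 100)  # (frequency, time)

# Shorter than requested: pads 20 zero-valued frames on the right.
wider = adjust_width(spec, 120)
assert wider.shape == (128, 120)
assert wider[:, 100:].abs().sum() == 0

# Longer than requested: slices down to the first 80 frames.
narrower = adjust_width(spec, 80)
assert narrower.shape == (128, 80)
```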
@@ -1,98 +0,0 @@
-import sys
-from typing import Callable, Dict, Generic, Tuple, Type, TypeVar
-
-from pydantic import BaseModel
-
-if sys.version_info >= (3, 10):
-    from typing import Concatenate, ParamSpec
-else:
-    from typing_extensions import Concatenate, ParamSpec
-
-__all__ = [
-    "Registry",
-    "SimpleRegistry",
-]
-
-T_Config = TypeVar("T_Config", bound=BaseModel, contravariant=True)
-T_Type = TypeVar("T_Type", covariant=True)
-P_Type = ParamSpec("P_Type")
-
-
-T = TypeVar("T")
-
-
-class SimpleRegistry(Generic[T]):
-    def __init__(self, name: str):
-        self._name = name
-        self._registry = {}
-
-    def register(self, name: str):
-        def decorator(obj: T) -> T:
-            self._registry[name] = obj
-            return obj
-
-        return decorator
-
-    def get(self, name: str) -> T:
-        return self._registry[name]
-
-    def has(self, name: str) -> bool:
-        return name in self._registry
-
-
-class Registry(Generic[T_Type, P_Type]):
-    """A generic class to create and manage a registry of items."""
-
-    def __init__(self, name: str):
-        self._name = name
-        self._registry: Dict[
-            str, Callable[Concatenate[..., P_Type], T_Type]
-        ] = {}
-        self._config_types: Dict[str, Type[BaseModel]] = {}
-
-    def register(
-        self,
-        config_cls: Type[T_Config],
-    ):
-        fields = config_cls.model_fields
-
-        if "name" not in fields:
-            raise ValueError("Configuration object must have a 'name' field.")
-
-        name = fields["name"].default
-
-        self._config_types[name] = config_cls
-
-        if not isinstance(name, str):
-            raise ValueError("'name' field must be a string literal.")
-
-        def decorator(
-            func: Callable[Concatenate[T_Config, P_Type], T_Type],
-        ):
-            self._registry[name] = func
-            return func
-
-        return decorator
-
-    def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
-        return tuple(self._config_types.values())
-
-    def build(
-        self,
-        config: BaseModel,
-        *args: P_Type.args,
-        **kwargs: P_Type.kwargs,
-    ) -> T_Type:
-        """Builds a logic instance from a config object."""
-
-        name = getattr(config, "name")  # noqa: B009
-
-        if name is None:
-            raise ValueError("Config does not have a name field")
-
-        if name not in self._registry:
-            raise NotImplementedError(
-                f"No {self._name} with name '{name}' is registered."
-            )
-
-        return self._registry[name](config, *args, **kwargs)
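A minimal usage sketch of the deleted `Registry`, assuming the class above: builders are keyed by the default value of a config's `name` field, so a validated config object alone is enough to construct the matching component (the greeter example is hypothetical):

```python
from typing import Literal

from pydantic import BaseModel

greeters = Registry("greeter")  # registry of string-producing builders

class ShoutConfig(BaseModel):
    name: Literal["shout"] = "shout"
    text: str

@greeters.register(ShoutConfig)
def build_shout(config: ShoutConfig) -> str:
    # The registered builder receives the validated config object.
    return config.text.upper()

print(greeters.build(ShoutConfig(text="hello")))  # -> "HELLO"
```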
@@ -14,9 +14,8 @@ format-specific loading function to retrieve the annotations as a standard
 """

 from pathlib import Path
-from typing import Annotated, Optional, Union
+from typing import Optional, Union

-from pydantic import Field
 from soundevent import data

 from batdetect2.data.annotations.aoef import (
@@ -43,13 +42,10 @@ __all__ = [
 ]


-AnnotationFormats = Annotated[
-    Union[
-        BatDetect2MergedAnnotations,
-        BatDetect2FilesAnnotations,
-        AOEFAnnotations,
-    ],
-    Field(discriminator="format"),
+AnnotationFormats = Union[
+    BatDetect2MergedAnnotations,
+    BatDetect2FilesAnnotations,
+    AOEFAnnotations,
 ]
 """Type Alias representing all supported data source configurations.

@@ -18,7 +18,7 @@ from uuid import uuid5
 from pydantic import Field
 from soundevent import data, io

-from batdetect2.core.configs import BaseConfig
+from batdetect2.configs import BaseConfig
 from batdetect2.data.annotations.types import AnnotatedDataset

 __all__ = [
@@ -33,7 +33,7 @@ from loguru import logger
 from pydantic import Field, ValidationError
 from soundevent import data

-from batdetect2.core.configs import BaseConfig
+from batdetect2.configs import BaseConfig
 from batdetect2.data.annotations.legacy import (
     FileAnnotation,
     file_annotation_to_clip,
@@ -301,8 +301,7 @@ def load_batdetect2_merged_annotated_dataset(
     for ann in content:
         try:
             ann = FileAnnotation.model_validate(ann)
-        except ValueError as err:
-            logger.warning(f"Invalid annotation file: {err}")
+        except ValueError:
             continue

         if (
@@ -310,17 +309,14 @@ def load_batdetect2_merged_annotated_dataset(
             and dataset.filter.only_annotated
             and not ann.annotated
         ):
-            logger.debug(f"Skipping incomplete annotation {ann.id}")
             continue

         if dataset.filter and dataset.filter.exclude_issues and ann.issues:
-            logger.debug(f"Skipping annotation with issues {ann.id}")
             continue

         try:
             clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
-        except FileNotFoundError as err:
-            logger.warning(f"Error loading annotations: {err}")
+        except FileNotFoundError:
             continue

         annotations.append(file_annotation_to_clip_annotation(ann, clip))
@@ -8,6 +1,7 @@ from typing import Callable, List, Optional, Union
 from pydantic import BaseModel, Field
 from soundevent import data

+from batdetect2.targets import get_term_from_key
+
 PathLike = Union[Path, str, os.PathLike]

 __all__ = []
@@ -89,9 +91,18 @@ def annotation_to_sound_event(
         uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
         sound_event=sound_event,
         tags=[
-            data.Tag(key=label_key, value=annotation.label),
-            data.Tag(key=event_key, value=annotation.event),
-            data.Tag(key=individual_key, value=str(annotation.individual)),
+            data.Tag(
+                term=get_term_from_key(label_key),
+                value=annotation.label,
+            ),
+            data.Tag(
+                term=get_term_from_key(event_key),
+                value=annotation.event,
+            ),
+            data.Tag(
+                term=get_term_from_key(individual_key),
+                value=str(annotation.individual),
+            ),
         ],
     )

@@ -112,7 +123,12 @@ def file_annotation_to_clip(
     recording = data.Recording.from_file(
         full_path,
         time_expansion=file_annotation.time_exp,
-        tags=[data.Tag(key=label_key, value=file_annotation.label)],
+        tags=[
+            data.Tag(
+                term=get_term_from_key(label_key),
+                value=file_annotation.label,
+            )
+        ],
     )

     return data.Clip(
@@ -139,7 +155,11 @@ def file_annotation_to_clip_annotation(
         uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
         clip=clip,
         notes=notes,
-        tags=[data.Tag(key=label_key, value=file_annotation.label)],
+        tags=[
+            data.Tag(
+                term=get_term_from_key(label_key), value=file_annotation.label
+            )
+        ],
         sound_events=[
             annotation_to_sound_event(
                 annotation,
@@ -1,6 +1,6 @@
 from pathlib import Path

-from batdetect2.core.configs import BaseConfig
+from batdetect2.configs import BaseConfig

 __all__ = [
     "AnnotatedDataset",
@@ -1,287 +0,0 @@
-from collections.abc import Callable
-from typing import Annotated, List, Literal, Sequence, Union
-
-from pydantic import Field
-from soundevent import data
-from soundevent.geometry import compute_bounds
-
-from batdetect2.core.configs import BaseConfig
-from batdetect2.core.registries import Registry
-
-SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
-
-conditions: Registry[SoundEventCondition, []] = Registry("condition")
-
-
-class HasTagConfig(BaseConfig):
-    name: Literal["has_tag"] = "has_tag"
-    tag: data.Tag
-
-
-class HasTag:
-    def __init__(self, tag: data.Tag):
-        self.tag = tag
-
-    def __call__(
-        self, sound_event_annotation: data.SoundEventAnnotation
-    ) -> bool:
-        return self.tag in sound_event_annotation.tags
-
-    @conditions.register(HasTagConfig)
-    @staticmethod
-    def from_config(config: HasTagConfig):
-        return HasTag(tag=config.tag)
-
-
-class HasAllTagsConfig(BaseConfig):
-    name: Literal["has_all_tags"] = "has_all_tags"
-    tags: List[data.Tag]
-
-
-class HasAllTags:
-    def __init__(self, tags: List[data.Tag]):
-        if not tags:
-            raise ValueError("Need to specify at least one tag")
-
-        self.tags = set(tags)
-
-    def __call__(
-        self, sound_event_annotation: data.SoundEventAnnotation
-    ) -> bool:
-        return self.tags.issubset(sound_event_annotation.tags)
-
-    @conditions.register(HasAllTagsConfig)
-    @staticmethod
-    def from_config(config: HasAllTagsConfig):
-        return HasAllTags(tags=config.tags)
-
-
-class HasAnyTagConfig(BaseConfig):
-    name: Literal["has_any_tag"] = "has_any_tag"
-    tags: List[data.Tag]
-
-
-class HasAnyTag:
-    def __init__(self, tags: List[data.Tag]):
-        if not tags:
-            raise ValueError("Need to specify at least one tag")
-
-        self.tags = set(tags)
-
-    def __call__(
-        self, sound_event_annotation: data.SoundEventAnnotation
-    ) -> bool:
-        return bool(self.tags.intersection(sound_event_annotation.tags))
-
-    @conditions.register(HasAnyTagConfig)
-    @staticmethod
-    def from_config(config: HasAnyTagConfig):
-        return HasAnyTag(tags=config.tags)
-
-
-Operator = Literal["gt", "gte", "lt", "lte", "eq"]
-
-
-class DurationConfig(BaseConfig):
-    name: Literal["duration"] = "duration"
-    operator: Operator
-    seconds: float
-
-
-def _build_comparator(
-    operator: Operator, value: float
-) -> Callable[[float], bool]:
-    if operator == "gt":
-        return lambda x: x > value
-
-    if operator == "gte":
-        return lambda x: x >= value
-
-    if operator == "lt":
-        return lambda x: x < value
-
-    if operator == "lte":
-        return lambda x: x <= value
-
-    if operator == "eq":
-        return lambda x: x == value
-
-    raise ValueError(f"Invalid operator {operator}")
-
-
-class Duration:
-    def __init__(self, operator: Operator, seconds: float):
-        self.operator = operator
-        self.seconds = seconds
-        self._comparator = _build_comparator(self.operator, self.seconds)
-
-    def __call__(
-        self,
-        sound_event_annotation: data.SoundEventAnnotation,
-    ) -> bool:
-        geometry = sound_event_annotation.sound_event.geometry
-
-        if geometry is None:
-            return False
-
-        start_time, _, end_time, _ = compute_bounds(geometry)
-        duration = end_time - start_time
-
-        return self._comparator(duration)
-
-    @conditions.register(DurationConfig)
-    @staticmethod
-    def from_config(config: DurationConfig):
-        return Duration(operator=config.operator, seconds=config.seconds)
-
-
-class FrequencyConfig(BaseConfig):
-    name: Literal["frequency"] = "frequency"
-    boundary: Literal["low", "high"]
-    operator: Operator
-    hertz: float
-
-
-class Frequency:
-    def __init__(
-        self,
-        operator: Operator,
-        boundary: Literal["low", "high"],
-        hertz: float,
-    ):
-        self.operator = operator
-        self.hertz = hertz
-        self.boundary = boundary
-        self._comparator = _build_comparator(self.operator, self.hertz)
-
-    def __call__(
-        self,
-        sound_event_annotation: data.SoundEventAnnotation,
-    ) -> bool:
-        geometry = sound_event_annotation.sound_event.geometry
-
-        if geometry is None:
-            return False
-
-        # Automatically false if geometry does not have a frequency range
-        if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
-            return False
-
-        _, low_freq, _, high_freq = compute_bounds(geometry)
-
-        if self.boundary == "low":
-            return self._comparator(low_freq)
-
-        return self._comparator(high_freq)
-
-    @conditions.register(FrequencyConfig)
-    @staticmethod
-    def from_config(config: FrequencyConfig):
-        return Frequency(
-            operator=config.operator,
-            boundary=config.boundary,
-            hertz=config.hertz,
-        )
-
-
-class AllOfConfig(BaseConfig):
-    name: Literal["all_of"] = "all_of"
-    conditions: Sequence["SoundEventConditionConfig"]
-
-
-class AllOf:
-    def __init__(self, conditions: List[SoundEventCondition]):
-        self.conditions = conditions
-
-    def __call__(
-        self, sound_event_annotation: data.SoundEventAnnotation
-    ) -> bool:
-        return all(c(sound_event_annotation) for c in self.conditions)
-
-    @conditions.register(AllOfConfig)
-    @staticmethod
-    def from_config(config: AllOfConfig):
-        conditions = [
-            build_sound_event_condition(cond) for cond in config.conditions
-        ]
-        return AllOf(conditions)
-
-
-class AnyOfConfig(BaseConfig):
-    name: Literal["any_of"] = "any_of"
-    conditions: List["SoundEventConditionConfig"]
-
-
-class AnyOf:
-    def __init__(self, conditions: List[SoundEventCondition]):
-        self.conditions = conditions
-
-    def __call__(
-        self, sound_event_annotation: data.SoundEventAnnotation
-    ) -> bool:
-        return any(c(sound_event_annotation) for c in self.conditions)
-
-    @conditions.register(AnyOfConfig)
-    @staticmethod
-    def from_config(config: AnyOfConfig):
-        conditions = [
-            build_sound_event_condition(cond) for cond in config.conditions
-        ]
-        return AnyOf(conditions)
-
-
-class NotConfig(BaseConfig):
-    name: Literal["not"] = "not"
-    condition: "SoundEventConditionConfig"
-
-
-class Not:
-    def __init__(self, condition: SoundEventCondition):
-        self.condition = condition
-
-    def __call__(
-        self, sound_event_annotation: data.SoundEventAnnotation
-    ) -> bool:
-        return not self.condition(sound_event_annotation)
-
-    @conditions.register(NotConfig)
-    @staticmethod
-    def from_config(config: NotConfig):
-        condition = build_sound_event_condition(config.condition)
-        return Not(condition)
-
-
-SoundEventConditionConfig = Annotated[
-    Union[
-        HasTagConfig,
-        HasAllTagsConfig,
-        HasAnyTagConfig,
-        DurationConfig,
-        FrequencyConfig,
-        AllOfConfig,
-        AnyOfConfig,
-        NotConfig,
-    ],
-    Field(discriminator="name"),
-]
-
-
-def build_sound_event_condition(
-    config: SoundEventConditionConfig,
-) -> SoundEventCondition:
-    return conditions.build(config)
-
-
-def filter_clip_annotation(
-    clip_annotation: data.ClipAnnotation,
-    condition: SoundEventCondition,
-) -> data.ClipAnnotation:
-    return clip_annotation.model_copy(
-        update=dict(
-            sound_events=[
-                sound_event
-                for sound_event in clip_annotation.sound_events
-                if condition(sound_event)
-            ]
-        )
-    )
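The removed conditions module above composed into arbitrary boolean trees. A sketch of building one such filter, assuming the classes defined in the deleted file (the tag keys and values here are illustrative):

```python
from soundevent import data

condition = build_sound_event_condition(
    AllOfConfig(
        conditions=[
            HasTagConfig(tag=data.Tag(key="event", value="Echolocation")),
            FrequencyConfig(boundary="low", operator="gte", hertz=10000),
            NotConfig(
                condition=HasTagConfig(
                    tag=data.Tag(key="class", value="Unknown")
                )
            ),
        ]
    )
)

# condition is a plain callable; drop non-matching events from a clip:
# filtered = filter_clip_annotation(clip_annotation, condition)
```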
@@ -19,29 +19,18 @@ The core components are:
 """

 from pathlib import Path
-from typing import List, Optional
+from typing import Annotated, List, Optional

 from loguru import logger
 from pydantic import Field
 from soundevent import data, io

-from batdetect2.core.configs import BaseConfig, load_config
+from batdetect2.configs import BaseConfig, load_config
 from batdetect2.data.annotations import (
     AnnotatedDataset,
     AnnotationFormats,
     load_annotated_dataset,
 )
-from batdetect2.data.conditions import (
-    SoundEventConditionConfig,
-    build_sound_event_condition,
-    filter_clip_annotation,
-)
-from batdetect2.data.transforms import (
-    ApplyAll,
-    SoundEventTransformConfig,
-    build_sound_event_transform,
-    transform_clip_annotation,
-)
 from batdetect2.targets.terms import data_source

 __all__ = [
@@ -63,68 +52,79 @@ sources.


 class DatasetConfig(BaseConfig):
-    """Configuration model defining the structure of a BatDetect2 dataset."""
+    """Configuration model defining the structure of a BatDetect2 dataset.
+
+    This class is typically loaded from a YAML file and describes the components
+    of the dataset, including metadata and a list of data sources.
+
+    Attributes
+    ----------
+    name : str
+        A descriptive name for the dataset (e.g., "UK_Bats_Project_2024").
+    description : str
+        A longer description of the dataset's contents, origin, purpose, etc.
+    sources : List[AnnotationFormats]
+        A list defining the different data sources contributing to this
+        dataset. Each item in the list must conform to one of the Pydantic
+        models defined in the `AnnotationFormats` type union. The specific
+        model used for each source is determined by the mandatory `format`
+        field within the source's configuration, allowing BatDetect2 to use the
+        correct parser for different annotation styles.
+    """
+
     name: str
     description: str
-    sources: List[AnnotationFormats]
-    sound_event_filter: Optional[SoundEventConditionConfig] = None
-    sound_event_transforms: List[SoundEventTransformConfig] = Field(
-        default_factory=list
-    )
+    sources: List[
+        Annotated[AnnotationFormats, Field(..., discriminator="format")]
+    ]


 def load_dataset(
-    config: DatasetConfig,
+    dataset: DatasetConfig,
     base_dir: Optional[Path] = None,
 ) -> Dataset:
-    """Load all clip annotations from the sources defined in a DatasetConfig."""
+    """Load all clip annotations from the sources defined in a DatasetConfig.
+
+    Iterates through each data source specified in the `dataset_config`,
+    delegates the loading and parsing of that source's annotations to
+    `batdetect2.data.annotations.load_annotated_dataset` (which handles
+    different data formats), and aggregates all resulting `ClipAnnotation`
+    objects into a single flat list.
+
+    Parameters
+    ----------
+    dataset_config : DatasetConfig
+        The configuration object describing the dataset and its sources.
+    base_dir : Path, optional
+        An optional base directory path. If provided, relative paths for
+        metadata files or data directories within the `dataset_config`'s
+        sources might be resolved relative to this directory. Defaults to None.
+
+    Returns
+    -------
+    Dataset (List[data.ClipAnnotation])
+        A flat list containing all loaded `ClipAnnotation` metadata objects
+        from all specified sources.
+
+    Raises
+    ------
+    Exception
+        Can raise various exceptions during the delegated loading process
+        (`load_annotated_dataset`) if files are not found, cannot be parsed
+        according to the specified format, or other I/O errors occur.
+    """
     clip_annotations = []
-    condition = (
-        build_sound_event_condition(config.sound_event_filter)
-        if config.sound_event_filter is not None
-        else None
-    )
-
-    transform = (
-        ApplyAll(
-            [
-                build_sound_event_transform(step)
-                for step in config.sound_event_transforms
-            ]
-        )
-        if config.sound_event_transforms
-        else None
-    )
-
-    for source in config.sources:
+    for source in dataset.sources:
         annotated_source = load_annotated_dataset(source, base_dir=base_dir)

         logger.debug(
             "Loaded {num_examples} from dataset source '{source_name}'",
             num_examples=len(annotated_source.clip_annotations),
             source_name=source.name,
         )
-        for clip_annotation in annotated_source.clip_annotations:
-            clip_annotation = insert_source_tag(clip_annotation, source)
-            if condition is not None:
-                clip_annotation = filter_clip_annotation(
-                    clip_annotation,
-                    condition,
-                )
-
-            if transform is not None:
-                clip_annotation = transform_clip_annotation(
-                    clip_annotation,
-                    transform,
-                )
-
-            clip_annotations.append(clip_annotation)
+        clip_annotations.extend(
+            insert_source_tag(clip_annotation, source)
+            for clip_annotation in annotated_source.clip_annotations
+        )

     return clip_annotations

@@ -161,6 +161,7 @@ def insert_source_tag(
     )


+# TODO: add documentation
 def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
     return load_config(path=path, schema=DatasetConfig, field=field)

@@ -4,14 +4,22 @@ from typing import Optional, Tuple
 from soundevent import data

 from batdetect2.data.datasets import Dataset
-from batdetect2.typing.targets import TargetProtocol
+from batdetect2.targets.types import TargetProtocol


 def iterate_over_sound_events(
     dataset: Dataset,
     targets: TargetProtocol,
+    apply_filter: bool = True,
+    apply_transform: bool = True,
+    exclude_generic: bool = True,
 ) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
-    """Iterate over sound events in a dataset.
+    """Iterate over sound events in a dataset, applying filtering and
+    transformations.
+
+    This generator function processes sound event annotations from a given
+    dataset, allowing for optional filtering, transformation, and exclusion of
+    unclassifiable (generic) events based on the provided target definitions.

     Parameters
     ----------
@@ -21,6 +29,18 @@ def iterate_over_sound_events(
     targets : TargetProtocol
         An object implementing the `TargetProtocol`, which provides methods
         for filtering, transforming, and encoding sound events.
+    apply_filter : bool, optional
+        If True, sound events will be filtered using `targets.filter()`.
+        Only events for which `targets.filter()` returns True will be yielded.
+        Defaults to True.
+    apply_transform : bool, optional
+        If True, sound events will be transformed using `targets.transform()`
+        before being yielded. Defaults to True.
+    exclude_generic : bool, optional
+        If True, sound events that result in a `None` class name after
+        `targets.encode()` will be excluded. This is typically used to
+        filter out events that cannot be mapped to a specific target class.
+        Defaults to True.

     Yields
     ------
@@ -43,9 +63,17 @@ def iterate_over_sound_events(
     """
     for clip_annotation in dataset:
         for sound_event_annotation in clip_annotation.sound_events:
-            if not targets.filter(sound_event_annotation):
-                continue
+            if apply_filter:
+                if not targets.filter(sound_event_annotation):
+                    continue
+
+            if apply_transform:
+                sound_event_annotation = targets.transform(
+                    sound_event_annotation
+                )

             class_name = targets.encode_class(sound_event_annotation)
+            if class_name is None and exclude_generic:
+                continue

             yield class_name, sound_event_annotation
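A hedged usage sketch of the extended iterator, assuming `dataset` and `targets` objects are already built; the new keyword flags make each processing step opt-out:

```python
for class_name, annotation in iterate_over_sound_events(
    dataset,
    targets,
    apply_filter=True,      # drop events rejected by targets.filter()
    apply_transform=True,   # apply targets.transform() before yielding
    exclude_generic=False,  # keep events whose class name encodes to None
):
    print(class_name, annotation.sound_event.geometry)
```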
@@ -7,7 +7,7 @@ from batdetect2.data.summary import (
     extract_recordings_df,
     extract_sound_events_df,
 )
-from batdetect2.typing.targets import TargetProtocol
+from batdetect2.targets.types import TargetProtocol


 def split_dataset_by_recordings(
@@ -2,7 +2,7 @@ import pandas as pd
 from soundevent.geometry import compute_bounds

 from batdetect2.data.datasets import Dataset
-from batdetect2.typing.targets import TargetProtocol
+from batdetect2.targets.types import TargetProtocol

 __all__ = [
     "extract_recordings_df",
@@ -100,11 +100,8 @@ def extract_sound_events_df(

         class_name = targets.encode_class(sound_event)

-        if class_name is None:
-            if exclude_generic:
-                continue
-            else:
-                class_name = targets.detection_class_name
+        if class_name is None and exclude_generic:
+            continue

         start_time, low_freq, end_time, high_freq = compute_bounds(
             sound_event.sound_event.geometry
@@ -156,7 +153,7 @@ def compute_class_summary(
     sound_events = extract_sound_events_df(
         dataset,
         targets,
-        exclude_generic=False,
+        exclude_generic=True,
         exclude_non_target=True,
     )
     recordings = extract_recordings_df(dataset)
@@ -1,252 +0,0 @@
-from collections.abc import Callable
-from typing import Annotated, Dict, List, Literal, Optional, Union
-
-from pydantic import Field
-from soundevent import data
-
-from batdetect2.core.configs import BaseConfig
-from batdetect2.core.registries import Registry
-from batdetect2.data.conditions import (
-    SoundEventCondition,
-    SoundEventConditionConfig,
-    build_sound_event_condition,
-)
-
-SoundEventTransform = Callable[
-    [data.SoundEventAnnotation],
-    data.SoundEventAnnotation,
-]
-
-transforms: Registry[SoundEventTransform, []] = Registry("transform")
-
-
-class SetFrequencyBoundConfig(BaseConfig):
-    name: Literal["set_frequency"] = "set_frequency"
-    boundary: Literal["low", "high"] = "low"
-    hertz: float
-
-
-class SetFrequencyBound:
-    def __init__(self, hertz: float, boundary: Literal["low", "high"] = "low"):
-        self.hertz = hertz
-        self.boundary = boundary
-
-    def __call__(
-        self,
-        sound_event_annotation: data.SoundEventAnnotation,
-    ) -> data.SoundEventAnnotation:
-        sound_event = sound_event_annotation.sound_event
-        geometry = sound_event.geometry
-
-        if geometry is None:
-            return sound_event_annotation
-
-        if not isinstance(geometry, data.BoundingBox):
-            return sound_event_annotation
-
-        start_time, low_freq, end_time, high_freq = geometry.coordinates
-
-        if self.boundary == "low":
-            low_freq = self.hertz
-            high_freq = max(high_freq, low_freq)
-
-        elif self.boundary == "high":
-            high_freq = self.hertz
-            low_freq = min(high_freq, low_freq)
-
-        geometry = data.BoundingBox(
-            coordinates=[start_time, low_freq, end_time, high_freq],
-        )
-
-        sound_event = sound_event.model_copy(update=dict(geometry=geometry))
-        return sound_event_annotation.model_copy(
-            update=dict(sound_event=sound_event)
-        )
-
-    @transforms.register(SetFrequencyBoundConfig)
-    @staticmethod
-    def from_config(config: SetFrequencyBoundConfig):
-        return SetFrequencyBound(hertz=config.hertz, boundary=config.boundary)
-
-
-class ApplyIfConfig(BaseConfig):
-    name: Literal["apply_if"] = "apply_if"
-    transform: "SoundEventTransformConfig"
-    condition: SoundEventConditionConfig
-
-
-class ApplyIf:
-    def __init__(
-        self,
-        condition: SoundEventCondition,
-        transform: SoundEventTransform,
-    ):
-        self.condition = condition
-        self.transform = transform
-
-    def __call__(
-        self,
-        sound_event_annotation: data.SoundEventAnnotation,
-    ) -> data.SoundEventAnnotation:
-        if not self.condition(sound_event_annotation):
-            return sound_event_annotation
-
-        return self.transform(sound_event_annotation)
-
-    @transforms.register(ApplyIfConfig)
-    @staticmethod
-    def from_config(config: ApplyIfConfig):
-        transform = build_sound_event_transform(config.transform)
-        condition = build_sound_event_condition(config.condition)
-        return ApplyIf(condition=condition, transform=transform)
-
-
-class ReplaceTagConfig(BaseConfig):
-    name: Literal["replace_tag"] = "replace_tag"
-    original: data.Tag
-    replacement: data.Tag
-
-
-class ReplaceTag:
-    def __init__(
-        self,
-        original: data.Tag,
-        replacement: data.Tag,
-    ):
-        self.original = original
-        self.replacement = replacement
-
-    def __call__(
-        self,
-        sound_event_annotation: data.SoundEventAnnotation,
-    ) -> data.SoundEventAnnotation:
-        tags = []
-
-        for tag in sound_event_annotation.tags:
-            if tag == self.original:
-                tags.append(self.replacement)
-            else:
-                tags.append(tag)
-
-        return sound_event_annotation.model_copy(update=dict(tags=tags))
-
-    @transforms.register(ReplaceTagConfig)
-    @staticmethod
-    def from_config(config: ReplaceTagConfig):
-        return ReplaceTag(
-            original=config.original, replacement=config.replacement
-        )
-
-
-class MapTagValueConfig(BaseConfig):
-    name: Literal["map_tag_value"] = "map_tag_value"
-    tag_key: str
-    value_mapping: Dict[str, str]
-    target_key: Optional[str] = None
-
-
-class MapTagValue:
-    def __init__(
-        self,
-        tag_key: str,
-        value_mapping: Dict[str, str],
-        target_key: Optional[str] = None,
-    ):
-        self.tag_key = tag_key
-        self.value_mapping = value_mapping
-        self.target_key = target_key
-
-    def __call__(
-        self,
-        sound_event_annotation: data.SoundEventAnnotation,
-    ) -> data.SoundEventAnnotation:
-        tags = []
-
-        for tag in sound_event_annotation.tags:
-            if tag.key != self.tag_key:
-                tags.append(tag)
-                continue
-
-            value = self.value_mapping.get(tag.value)
-
-            if value is None:
-                tags.append(tag)
-                continue
-
-            if self.target_key is None:
-                tags.append(tag.model_copy(update=dict(value=value)))
-            else:
-                tags.append(
-                    data.Tag(
-                        key=self.target_key,  # type: ignore
-                        value=value,
-                    )
-                )
-
-        return sound_event_annotation.model_copy(update=dict(tags=tags))
-
-    @transforms.register(MapTagValueConfig)
-    @staticmethod
-    def from_config(config: MapTagValueConfig):
-        return MapTagValue(
-            tag_key=config.tag_key,
-            value_mapping=config.value_mapping,
-            target_key=config.target_key,
-        )
-
-
-class ApplyAllConfig(BaseConfig):
-    name: Literal["apply_all"] = "apply_all"
-    steps: List["SoundEventTransformConfig"] = Field(default_factory=list)
-
-
-class ApplyAll:
-    def __init__(self, steps: List[SoundEventTransform]):
-        self.steps = steps
-
-    def __call__(
-        self,
-        sound_event_annotation: data.SoundEventAnnotation,
-    ) -> data.SoundEventAnnotation:
-        for step in self.steps:
-            sound_event_annotation = step(sound_event_annotation)
-
-        return sound_event_annotation
-
-    @transforms.register(ApplyAllConfig)
-    @staticmethod
-    def from_config(config: ApplyAllConfig):
-        steps = [build_sound_event_transform(step) for step in config.steps]
-        return ApplyAll(steps)
-
-
-SoundEventTransformConfig = Annotated[
-    Union[
-        SetFrequencyBoundConfig,
-        ReplaceTagConfig,
-        MapTagValueConfig,
-        ApplyIfConfig,
-        ApplyAllConfig,
-    ],
-    Field(discriminator="name"),
-]
-
-
-def build_sound_event_transform(
-    config: SoundEventTransformConfig,
-) -> SoundEventTransform:
-    return transforms.build(config)
-
-
-def transform_clip_annotation(
-    clip_annotation: data.ClipAnnotation,
-    transform: SoundEventTransform,
-) -> data.ClipAnnotation:
-    return clip_annotation.model_copy(
-        update=dict(
-            sound_events=[
-                transform(sound_event)
-                for sound_event in clip_annotation.sound_events
-            ]
-        )
-    )
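A sketch of chaining the deleted transforms, assuming the classes above: `ApplyAll` runs each configured step in order, and the tag values here are purely illustrative:

```python
from soundevent import data

transform = build_sound_event_transform(
    ApplyAllConfig(
        steps=[
            ReplaceTagConfig(
                original=data.Tag(key="class", value="pippip"),
                replacement=data.Tag(
                    key="class", value="Pipistrellus pipistrellus"
                ),
            ),
            SetFrequencyBoundConfig(boundary="low", hertz=10000),
        ]
    )
)

# Apply the composed transform to every sound event in a clip:
# clip_annotation = transform_clip_annotation(clip_annotation, transform)
```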
@@ -1,15 +1,15 @@
-from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
-from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate
-from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
-from batdetect2.evaluate.tasks import TaskConfig, build_task
+from batdetect2.evaluate.config import (
+    EvaluationConfig,
+    load_evaluation_config,
+)
+from batdetect2.evaluate.match import (
+    match_predictions_and_annotations,
+    match_sound_events_and_raw_predictions,
+)

 __all__ = [
     "EvaluationConfig",
-    "Evaluator",
-    "TaskConfig",
-    "build_evaluator",
-    "build_task",
-    "evaluate",
     "load_evaluation_config",
-    "DEFAULT_EVAL_DIR",
+    "match_predictions_and_annotations",
+    "match_sound_events_and_raw_predictions",
 ]
@ -1,230 +0,0 @@
|
|||||||
from typing import Annotated, Literal, Optional, Union
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
|
||||||
from soundevent.evaluation import compute_affinity
|
|
||||||
from soundevent.geometry import compute_interval_overlap
|
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
|
||||||
from batdetect2.core.registries import Registry
|
|
||||||
from batdetect2.typing.evaluate import AffinityFunction
|
|
||||||
|
|
||||||
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
|
||||||
"matching_strategy"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TimeAffinityConfig(BaseConfig):
|
|
||||||
name: Literal["time_affinity"] = "time_affinity"
|
|
||||||
time_buffer: float = 0.01
|
|
||||||
|
|
||||||
|
|
||||||
class TimeAffinity(AffinityFunction):
|
|
||||||
def __init__(self, time_buffer: float):
|
|
||||||
self.time_buffer = time_buffer
|
|
||||||
|
|
||||||
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
|
|
||||||
return compute_timestamp_affinity(
|
|
||||||
geometry1, geometry2, time_buffer=self.time_buffer
|
|
||||||
)
|
|
||||||
|
|
||||||
@affinity_functions.register(TimeAffinityConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: TimeAffinityConfig):
|
|
||||||
return TimeAffinity(time_buffer=config.time_buffer)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_timestamp_affinity(
|
|
||||||
geometry1: data.Geometry,
|
|
||||||
geometry2: data.Geometry,
|
|
||||||
time_buffer: float = 0.01,
|
|
||||||
) -> float:
|
|
||||||
assert isinstance(geometry1, data.TimeStamp)
|
|
||||||
assert isinstance(geometry2, data.TimeStamp)
|
|
||||||
|
|
||||||
start_time1 = geometry1.coordinates
|
|
||||||
start_time2 = geometry2.coordinates
|
|
||||||
|
|
||||||
a = min(start_time1, start_time2)
|
|
||||||
b = max(start_time1, start_time2)
|
|
||||||
|
|
||||||
if b - a >= 2 * time_buffer:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
intersection = a - b + 2 * time_buffer
|
|
||||||
union = b - a + 2 * time_buffer
|
|
||||||
return intersection / union
|
|
||||||
|
|
||||||
|
|
||||||
class IntervalIOUConfig(BaseConfig):
|
|
||||||
name: Literal["interval_iou"] = "interval_iou"
|
|
||||||
time_buffer: float = 0.01
|
|
||||||
|
|
||||||
|
|
||||||
class IntervalIOU(AffinityFunction):
|
|
||||||
def __init__(self, time_buffer: float):
|
|
||||||
self.time_buffer = time_buffer
|
|
||||||
|
|
||||||
def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
|
|
||||||
return compute_interval_iou(
|
|
||||||
geometry1,
|
|
||||||
geometry2,
|
|
||||||
time_buffer=self.time_buffer,
|
|
||||||
)
|
|
||||||
|
|
||||||
@affinity_functions.register(IntervalIOUConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: IntervalIOUConfig):
|
|
||||||
return IntervalIOU(time_buffer=config.time_buffer)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_interval_iou(
|
|
||||||
geometry1: data.Geometry,
|
|
||||||
geometry2: data.Geometry,
|
|
||||||
time_buffer: float = 0.01,
|
|
||||||
) -> float:
|
|
||||||
assert isinstance(geometry1, data.TimeInterval)
|
|
||||||
assert isinstance(geometry2, data.TimeInterval)
|
|
||||||
|
|
||||||
start_time1, end_time1 = geometry1.coordinates
|
|
||||||
start_time2, end_time2 = geometry1.coordinates
|
|
||||||
|
|
||||||
start_time1 -= time_buffer
|
|
||||||
start_time2 -= time_buffer
|
|
||||||
end_time1 += time_buffer
|
|
||||||
end_time2 += time_buffer
|
|
||||||
|
|
||||||
intersection = compute_interval_overlap(
|
|
||||||
(start_time1, end_time1),
|
|
||||||
(start_time2, end_time2),
|
|
||||||
)
|
|
||||||
|
|
||||||
union = (
|
|
||||||
(end_time1 - start_time1) + (end_time2 - start_time2) - intersection
|
|
Continuation of the old revision's affinity module (`batdetect2/evaluate/affinity.py`, inferred from the imports in the matching module below); these lines appear only on the old side of the comparison, picking up mid-function from the previous hunk:

```python
    )

    if union == 0:
        return 0

    return intersection / union


class BBoxIOUConfig(BaseConfig):
    name: Literal["bbox_iou"] = "bbox_iou"
    time_buffer: float = 0.01
    freq_buffer: float = 1000


class BBoxIOU(AffinityFunction):
    def __init__(self, time_buffer: float, freq_buffer: float):
        self.time_buffer = time_buffer
        self.freq_buffer = freq_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        if not isinstance(geometry1, data.BoundingBox):
            raise TypeError(
                f"Expected geometry1 to be a BoundingBox, got {type(geometry1)}"
            )

        if not isinstance(geometry2, data.BoundingBox):
            raise TypeError(
                f"Expected geometry2 to be a BoundingBox, got {type(geometry2)}"
            )

        return bbox_iou(
            geometry1,
            geometry2,
            time_buffer=self.time_buffer,
            freq_buffer=self.freq_buffer,
        )

    @affinity_functions.register(BBoxIOUConfig)
    @staticmethod
    def from_config(config: BBoxIOUConfig):
        return BBoxIOU(
            time_buffer=config.time_buffer,
            freq_buffer=config.freq_buffer,
        )


def bbox_iou(
    geometry1: data.BoundingBox,
    geometry2: data.BoundingBox,
    time_buffer: float = 0.01,
    freq_buffer: float = 1000,
) -> float:
    start_time1, low_freq1, end_time1, high_freq1 = geometry1.coordinates
    start_time2, low_freq2, end_time2, high_freq2 = geometry2.coordinates

    start_time1 -= time_buffer
    start_time2 -= time_buffer
    end_time1 += time_buffer
    end_time2 += time_buffer

    low_freq1 -= freq_buffer
    low_freq2 -= freq_buffer
    high_freq1 += freq_buffer
    high_freq2 += freq_buffer

    time_intersection = compute_interval_overlap(
        (start_time1, end_time1),
        (start_time2, end_time2),
    )

    freq_intersection = max(
        0,
        min(high_freq1, high_freq2) - max(low_freq1, low_freq2),
    )

    intersection = time_intersection * freq_intersection

    if intersection == 0:
        return 0

    union = (
        (end_time1 - start_time1) * (high_freq1 - low_freq1)
        + (end_time2 - start_time2) * (high_freq2 - low_freq2)
        - intersection
    )

    return intersection / union


class GeometricIOUConfig(BaseConfig):
    name: Literal["geometric_iou"] = "geometric_iou"
    time_buffer: float = 0.01
    freq_buffer: float = 1000


class GeometricIOU(AffinityFunction):
    def __init__(self, time_buffer: float):
        self.time_buffer = time_buffer

    def __call__(self, geometry1: data.Geometry, geometry2: data.Geometry):
        return compute_affinity(
            geometry1,
            geometry2,
            time_buffer=self.time_buffer,
        )

    @affinity_functions.register(GeometricIOUConfig)
    @staticmethod
    def from_config(config: GeometricIOUConfig):
        return GeometricIOU(time_buffer=config.time_buffer)


AffinityConfig = Annotated[
    Union[
        TimeAffinityConfig,
        IntervalIOUConfig,
        BBoxIOUConfig,
        GeometricIOUConfig,
    ],
    Field(discriminator="name"),
]


def build_affinity_function(
    config: Optional[AffinityConfig] = None,
) -> AffinityFunction:
    config = config or GeometricIOUConfig()
    return affinity_functions.build(config)
```
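For orientation, a minimal usage sketch of `bbox_iou` (not part of the diff). It assumes, as the unpacking above implies, that a `soundevent` `data.BoundingBox` stores `coordinates` as `[start_time, low_freq, end_time, high_freq]`; the numbers are illustrative only:

```python
from soundevent import data

# Two calls overlapping in time and frequency; coordinates follow the
# [start_time, low_freq, end_time, high_freq] order unpacked in bbox_iou.
a = data.BoundingBox(coordinates=[0.100, 40_000, 0.105, 80_000])
b = data.BoundingBox(coordinates=[0.103, 50_000, 0.108, 90_000])

# Each box is padded by 10 ms in time and 1 kHz in frequency on every side
# before the intersection-over-union is computed, so near-misses score > 0.
iou = bbox_iou(a, b, time_buffer=0.01, freq_buffer=1000)
assert 0 < iou <= 1
```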
The evaluation config module (`batdetect2/evaluate/config.py`, inferred from imports elsewhere in this diff) shrinks from a task-registry config to a single matching config:

```diff
@@ -1,15 +1,10 @@
-from typing import List, Optional
+from typing import Optional

 from pydantic import Field
 from soundevent import data

-from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.evaluate.tasks import (
-    TaskConfig,
-)
-from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
-from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
-from batdetect2.logging import CSVLoggerConfig, LoggerConfig
+from batdetect2.configs import BaseConfig, load_config
+from batdetect2.evaluate.match import MatchConfig

 __all__ = [
     "EvaluationConfig",
@@ -18,13 +13,7 @@ __all__ = [

 class EvaluationConfig(BaseConfig):
-    tasks: List[TaskConfig] = Field(
-        default_factory=lambda: [
-            DetectionTaskConfig(),
-            ClassificationTaskConfig(),
-        ]
-    )
-    logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
+    match: MatchConfig = Field(default_factory=MatchConfig)


 def load_evaluation_config(
```
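The net effect is a much smaller config surface. A hypothetical construction of the new `EvaluationConfig` (field names taken from the diff; values invented):

```python
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.match import MatchConfig

# After this change the evaluation config only carries matching options;
# the per-task and logger configuration is gone from this revision.
config = EvaluationConfig(
    match=MatchConfig(strategy="optimal", affinity_threshold=0.25)
)
print(config.match.strategy)  # "optimal"
```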
`@@ -1,144 +0,0 @@` — the test-dataset module is removed entirely in the new revision (`batdetect2/evaluate/dataset.py`, inferred from imports in the modules below). Its full content on the old side:

```python
from typing import List, NamedTuple, Optional, Sequence

import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader, Dataset

from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import (
    AudioLoader,
    ClipperProtocol,
    PreprocessorProtocol,
)

__all__ = [
    "TestDataset",
    "build_test_dataset",
    "build_test_loader",
]


class TestExample(NamedTuple):
    spec: torch.Tensor
    idx: torch.Tensor
    start_time: torch.Tensor
    end_time: torch.Tensor


class TestDataset(Dataset[TestExample]):
    clip_annotations: List[data.ClipAnnotation]

    def __init__(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        clipper: Optional[ClipperProtocol] = None,
        audio_dir: Optional[data.PathLike] = None,
    ):
        self.clip_annotations = list(clip_annotations)
        self.clipper = clipper
        self.preprocessor = preprocessor
        self.audio_loader = audio_loader
        self.audio_dir = audio_dir

    def __len__(self):
        return len(self.clip_annotations)

    def __getitem__(self, idx: int) -> TestExample:
        clip_annotation = self.clip_annotations[idx]

        if self.clipper is not None:
            clip_annotation = self.clipper(clip_annotation)

        clip = clip_annotation.clip
        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
        wav_tensor = torch.tensor(wav).unsqueeze(0)
        spectrogram = self.preprocessor(wav_tensor)
        return TestExample(
            spec=spectrogram,
            idx=torch.tensor(idx),
            start_time=torch.tensor(clip.start_time),
            end_time=torch.tensor(clip.end_time),
        )


class TestLoaderConfig(BaseConfig):
    num_workers: int = 0
    clipping_strategy: ClipConfig = Field(
        default_factory=lambda: PaddedClipConfig()
    )


def build_test_loader(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[TestLoaderConfig] = None,
    num_workers: Optional[int] = None,
) -> DataLoader[TestExample]:
    logger.info("Building test data loader...")
    config = config or TestLoaderConfig()
    logger.opt(lazy=True).debug(
        "Test data loader config: \n{config}",
        config=lambda: config.to_yaml_string(exclude_none=True),
    )

    test_dataset = build_test_dataset(
        clip_annotations,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        config=config,
    )

    num_workers = num_workers or config.num_workers
    return DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=_collate_fn,
    )


def build_test_dataset(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[TestLoaderConfig] = None,
) -> TestDataset:
    logger.info("Building test dataset...")
    config = config or TestLoaderConfig()

    clipper = build_clipper(config=config.clipping_strategy)

    if audio_loader is None:
        audio_loader = build_audio_loader()

    if preprocessor is None:
        preprocessor = build_preprocessor()

    return TestDataset(
        clip_annotations,
        audio_loader=audio_loader,
        clipper=clipper,
        preprocessor=preprocessor,
    )


def _collate_fn(batch: List[TestExample]) -> TestExample:
    max_width = max(item.spec.shape[-1] for item in batch)
    return TestExample(
        spec=torch.stack(
            [adjust_width(item.spec, max_width) for item in batch]
        ),
        idx=torch.stack([item.idx for item in batch]),
        start_time=torch.stack([item.start_time for item in batch]),
        end_time=torch.stack([item.end_time for item in batch]),
    )
```
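`_collate_fn` pads every spectrogram in the batch to the widest example before stacking. A self-contained sketch of that behaviour, assuming `adjust_width` zero-pads (or crops) the last, time-like axis to the requested width:

```python
import torch

specs = [torch.rand(1, 128, 90), torch.rand(1, 128, 100)]
max_width = max(s.shape[-1] for s in specs)  # 100

# Stand-in for adjust_width in this sketch: right-pad the time axis with zeros.
padded = [
    torch.nn.functional.pad(s, (0, max_width - s.shape[-1])) for s in specs
]
batch = torch.stack(padded)
assert batch.shape == (2, 1, 128, 100)
```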
`@@ -1,69 +0,0 @@` — the Lightning-based `evaluate` entry point is also removed in the new revision:

```python
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Sequence

from lightning import Trainer
from soundevent import data

from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger
from batdetect2.models import Model
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets

if TYPE_CHECKING:
    from batdetect2.config import BatDetect2Config
    from batdetect2.typing import (
        AudioLoader,
        PreprocessorProtocol,
        TargetProtocol,
    )

DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"


def evaluate(
    model: Model,
    test_annotations: Sequence[data.ClipAnnotation],
    targets: Optional["TargetProtocol"] = None,
    audio_loader: Optional["AudioLoader"] = None,
    preprocessor: Optional["PreprocessorProtocol"] = None,
    config: Optional["BatDetect2Config"] = None,
    num_workers: Optional[int] = None,
    output_dir: data.PathLike = DEFAULT_EVAL_DIR,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
):
    from batdetect2.config import BatDetect2Config

    config = config or BatDetect2Config()

    audio_loader = audio_loader or build_audio_loader(config=config.audio)

    preprocessor = preprocessor or build_preprocessor(
        config=config.preprocess,
        input_samplerate=audio_loader.samplerate,
    )

    targets = targets or build_targets(config=config.targets)

    loader = build_test_loader(
        test_annotations,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        num_workers=num_workers,
    )

    evaluator = build_evaluator(config=config.evaluation, targets=targets)

    logger = build_logger(
        config.evaluation.logger,
        log_dir=Path(output_dir),
        experiment_name=experiment_name,
        run_name=run_name,
    )
    module = EvaluationModule(model, evaluator)
    trainer = Trainer(logger=logger, enable_checkpointing=False)
    return trainer.test(module, loader)
```
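A hypothetical invocation of this removed entry point; `model` stands for a trained batdetect2 `Model` and `test_clips` for a `Sequence[data.ClipAnnotation]`, both assumed to have been built elsewhere:

```python
results = evaluate(
    model=model,
    test_annotations=test_clips,
    output_dir="outputs/evaluations",
    run_name="baseline-eval",
)
# `results` is whatever Trainer.test returns: a list with one metrics
# dict per test dataloader.
```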
`@@ -1,67 +0,0 @@` — the task-based `Evaluator` (`batdetect2/evaluate/evaluator.py`, inferred from imports) is removed as well:

```python
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

from matplotlib.figure import Figure
from soundevent import data

from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol

__all__ = [
    "Evaluator",
    "build_evaluator",
]


class Evaluator:
    def __init__(
        self,
        targets: TargetProtocol,
        tasks: Sequence[EvaluatorProtocol],
    ):
        self.targets = targets
        self.tasks = tasks

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[Any]:
        return [
            task.evaluate(clip_annotations, predictions) for task in self.tasks
        ]

    def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
        results = {}

        for task, outputs in zip(self.tasks, eval_outputs):
            results.update(task.compute_metrics(outputs))

        return results

    def generate_plots(
        self,
        eval_outputs: List[Any],
    ) -> Iterable[Tuple[str, Figure]]:
        for task, outputs in zip(self.tasks, eval_outputs):
            for name, fig in task.generate_plots(outputs):
                yield name, fig


def build_evaluator(
    config: Optional[Union[EvaluationConfig, dict]] = None,
    targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
    targets = targets or build_targets()

    if config is None:
        config = EvaluationConfig()

    if not isinstance(config, EvaluationConfig):
        config = EvaluationConfig.model_validate(config)

    return Evaluator(
        targets=targets,
        tasks=[build_task(task, targets=targets) for task in config.tasks],
    )
```
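The `Evaluator` only relies on the three methods it calls on each task (`evaluate`, `compute_metrics`, `generate_plots`). A toy task satisfying that implicit interface, purely illustrative and not part of the codebase:

```python
from typing import Dict, Iterable, List, Tuple

from matplotlib.figure import Figure


class CountTask:
    """Toy task: counts predictions per clip and reports the mean."""

    def evaluate(self, clip_annotations, predictions) -> List[int]:
        return [len(preds) for preds in predictions]

    def compute_metrics(self, outputs: List[int]) -> Dict[str, float]:
        return {"mean_preds_per_clip": sum(outputs) / max(len(outputs), 1)}

    def generate_plots(self, outputs: List[int]) -> Iterable[Tuple[str, Figure]]:
        return iter(())  # no plots for this toy task
```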
`@@ -1,82 +0,0 @@` — the `EvaluationModule` Lightning wrapper (`batdetect2/evaluate/lightning.py`, inferred from imports) is removed:

```python
from typing import Any, List

from lightning import LightningModule
from soundevent import data
from torch.utils.data import DataLoader

from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import RawPrediction


class EvaluationModule(LightningModule):
    def __init__(
        self,
        model: Model,
        evaluator: EvaluatorProtocol,
    ):
        super().__init__()

        self.model = model
        self.evaluator = evaluator

        self.clip_annotations: List[data.ClipAnnotation] = []
        self.predictions: List[List[RawPrediction]] = []

    def test_step(self, batch: TestExample, batch_idx: int):
        dataset = self.get_dataset()
        clip_annotations = [
            dataset.clip_annotations[int(example_idx)]
            for example_idx in batch.idx
        ]

        outputs = self.model.detector(batch.spec)
        clip_detections = self.model.postprocessor(
            outputs,
            start_times=[ca.clip.start_time for ca in clip_annotations],
        )
        predictions = [
            to_raw_predictions(
                clip_dets.numpy(),
                targets=self.evaluator.targets,
            )
            for clip_dets in clip_detections
        ]

        self.clip_annotations.extend(clip_annotations)
        self.predictions.extend(predictions)

    def on_test_epoch_start(self):
        self.clip_annotations = []
        self.predictions = []

    def on_test_epoch_end(self):
        clip_evals = self.evaluator.evaluate(
            self.clip_annotations,
            self.predictions,
        )
        self.log_metrics(clip_evals)
        self.generate_plots(clip_evals)

    def generate_plots(self, evaluated_clips: Any):
        plotter = get_image_logger(self.logger)  # type: ignore

        if plotter is None:
            return

        for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
            plotter(figure_name, fig, self.global_step)

    def log_metrics(self, evaluated_clips: Any):
        metrics = self.evaluator.compute_metrics(evaluated_clips)
        self.log_dict(metrics)

    def get_dataset(self) -> TestDataset:
        dataloaders = self.trainer.test_dataloaders
        assert isinstance(dataloaders, DataLoader)
        dataset = dataloaders.dataset
        assert isinstance(dataset, TestDataset)
        return dataset
```
The matching module (`batdetect2/evaluate/match.py`, inferred from imports) is the largest change: the registry of matcher strategies is replaced by a single `MatchConfig` plus plain functions.

```diff
@@ -1,211 +1,50 @@
 from collections.abc import Callable, Iterable, Mapping
-from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
+from typing import List, Literal, Optional, Tuple

 import numpy as np
-from pydantic import Field
 from soundevent import data
 from soundevent.evaluation import compute_affinity
-from soundevent.evaluation import match_geometries as optimal_match
+from soundevent.evaluation import (
+    match_geometries as optimal_match,
+)
 from soundevent.geometry import compute_bounds

-from batdetect2.core import BaseConfig, Registry
-from batdetect2.evaluate.affinity import (
-    AffinityConfig,
-    GeometricIOUConfig,
-    build_affinity_function,
-)
-from batdetect2.targets import build_targets
-from batdetect2.typing import (
-    AffinityFunction,
-    MatcherProtocol,
-    MatchEvaluation,
-    RawPrediction,
-    TargetProtocol,
-)
-from batdetect2.typing.evaluate import ClipMatches
+from batdetect2.configs import BaseConfig
+from batdetect2.evaluate.types import MatchEvaluation
+from batdetect2.postprocess.types import BatDetect2Prediction
+from batdetect2.targets.types import TargetProtocol
+
+MatchingStrategy = Literal["greedy", "optimal"]
+"""The type of matching algorithm to use: 'greedy' or 'optimal'."""

 MatchingGeometry = Literal["bbox", "interval", "timestamp"]
 """The geometry representation to use for matching."""

-matching_strategies = Registry("matching_strategy")
-
-
-def match(
-    sound_event_annotations: Sequence[data.SoundEventAnnotation],
-    raw_predictions: Sequence[RawPrediction],
-    clip: data.Clip,
-    scores: Optional[Sequence[float]] = None,
-    targets: Optional[TargetProtocol] = None,
-    matcher: Optional[MatcherProtocol] = None,
-) -> ClipMatches:
-    if matcher is None:
-        matcher = build_matcher()
-
-    if targets is None:
-        targets = build_targets()
-
-    target_geometries: List[data.Geometry] = [  # type: ignore
-        sound_event_annotation.sound_event.geometry
-        for sound_event_annotation in sound_event_annotations
-    ]
-
-    predicted_geometries = [
-        raw_prediction.geometry for raw_prediction in raw_predictions
-    ]
-
-    if scores is None:
-        scores = [
-            raw_prediction.detection_score
-            for raw_prediction in raw_predictions
-        ]
-
-    matches = []
-
-    for source_idx, target_idx, affinity in matcher(
-        ground_truth=target_geometries,
-        predictions=predicted_geometries,
-        scores=scores,
-    ):
-        target = (
-            sound_event_annotations[target_idx]
-            if target_idx is not None
-            else None
-        )
-        prediction = (
-            raw_predictions[source_idx] if source_idx is not None else None
-        )
-
-        gt_det = target_idx is not None
-        gt_class = targets.encode_class(target) if target is not None else None
-        gt_geometry = (
-            target_geometries[target_idx] if target_idx is not None else None
-        )
-
-        pred_score = float(prediction.detection_score) if prediction else 0
-        pred_geometry = (
-            predicted_geometries[source_idx]
-            if source_idx is not None
-            else None
-        )
-
-        class_scores = (
-            {
-                class_name: score
-                for class_name, score in zip(
-                    targets.class_names,
-                    prediction.class_scores,
-                )
-            }
-            if prediction is not None
-            else {}
-        )
-
-        matches.append(
-            MatchEvaluation(
-                clip=clip,
-                sound_event_annotation=target,
-                gt_det=gt_det,
-                gt_class=gt_class,
-                gt_geometry=gt_geometry,
-                pred_score=pred_score,
-                pred_class_scores=class_scores,
-                pred_geometry=pred_geometry,
-                affinity=affinity,
-            )
-        )
-
-    return ClipMatches(clip=clip, matches=matches)
-
-
-class StartTimeMatchConfig(BaseConfig):
-    name: Literal["start_time_match"] = "start_time_match"
-    distance_threshold: float = 0.01
-
-
-class StartTimeMatcher(MatcherProtocol):
-    def __init__(self, distance_threshold: float):
-        self.distance_threshold = distance_threshold
-
-    def __call__(
-        self,
-        ground_truth: Sequence[data.Geometry],
-        predictions: Sequence[data.Geometry],
-        scores: Sequence[float],
-    ):
-        return match_start_times(
-            ground_truth,
-            predictions,
-            scores,
-            distance_threshold=self.distance_threshold,
-        )
-
-    @matching_strategies.register(StartTimeMatchConfig)
-    @staticmethod
-    def from_config(config: StartTimeMatchConfig):
-        return StartTimeMatcher(distance_threshold=config.distance_threshold)
-
-
-def match_start_times(
-    ground_truth: Sequence[data.Geometry],
-    predictions: Sequence[data.Geometry],
-    scores: Sequence[float],
-    distance_threshold: float = 0.01,
-) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
-    if not ground_truth:
-        for index in range(len(predictions)):
-            yield index, None, 0
-
-        return
-
-    if not predictions:
-        for index in range(len(ground_truth)):
-            yield None, index, 0
-
-        return
-
-    gt_times = np.array([compute_bounds(geom)[0] for geom in ground_truth])
-    pred_times = np.array([compute_bounds(geom)[0] for geom in predictions])
-
-    scores = np.array(scores)
-    sort_args = np.argsort(scores)[::-1]
-
-    distances = np.abs(gt_times[None, :] - pred_times[:, None])
-    closests = np.argmin(distances, axis=-1)
-
-    unmatched_gt = set(range(len(gt_times)))
-
-    for pred_index in sort_args:
-        # Get the closest ground truth
-        gt_closest_index = closests[pred_index]
-
-        if gt_closest_index not in unmatched_gt:
-            # Does not match if the closest has already been assigned
-            yield pred_index, None, 0
-            continue
-
-        # Get the actual distance
-        distance = distances[pred_index, gt_closest_index]
-
-        if distance > distance_threshold:
-            # Does not match if too far from the closest
-            yield pred_index, None, 0
-            continue
-
-        # Affinity is a linear ramp between 0 and 1: a distance at the
-        # threshold maps to 0 affinity and a zero distance maps to 1.
-        affinity = np.interp(
-            distance,
-            [0, distance_threshold],
-            [1, 0],
-            left=1,
-            right=0,
-        )
-        unmatched_gt.remove(gt_closest_index)
-        yield pred_index, gt_closest_index, affinity
-
-    for missing_index in unmatched_gt:
-        yield None, missing_index, 0
+
+class MatchConfig(BaseConfig):
+    """Configuration for matching geometries.
+
+    Attributes
+    ----------
+    strategy : MatchingStrategy, default="greedy"
+        The matching algorithm to use. 'greedy' prioritizes high-confidence
+        predictions, while 'optimal' finds the globally best set of matches.
+    geometry : MatchingGeometry, default="timestamp"
+        The geometric representation to use when computing affinity.
+    affinity_threshold : float, default=0.0
+        The minimum affinity score (e.g., IoU) required for a valid match.
+    time_buffer : float, default=0.005
+        Time tolerance in seconds used in affinity calculations.
+    frequency_buffer : float, default=1000
+        Frequency tolerance in Hertz used in affinity calculations.
+    """
+
+    strategy: MatchingStrategy = "greedy"
+    geometry: MatchingGeometry = "timestamp"
+    affinity_threshold: float = 0.0
+    time_buffer: float = 0.005
+    frequency_buffer: float = 1_000


 def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
```
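The affinity returned by the removed `match_start_times` is a linear ramp over the start-time distance. A quick numeric check of that `np.interp` call:

```python
import numpy as np

# Distance 0 -> affinity 1; distance at the threshold -> affinity 0.
threshold = 0.01
for distance in (0.0, 0.005, 0.01):
    affinity = np.interp(distance, [0, threshold], [1, 0], left=1, right=0)
    print(f"{distance:.3f}s -> {affinity:.2f}")
# 0.000s -> 1.00
# 0.005s -> 0.50
# 0.010s -> 0.00
```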
```diff
@@ -234,58 +73,45 @@ _geometry_cast_functions: Mapping[
 }


-class GreedyMatchConfig(BaseConfig):
-    name: Literal["greedy_match"] = "greedy_match"
-    geometry: MatchingGeometry = "timestamp"
-    affinity_threshold: float = 0.5
-    affinity_function: AffinityConfig = Field(
-        default_factory=GeometricIOUConfig
-    )
-
-
-class GreedyMatcher(MatcherProtocol):
-    def __init__(
-        self,
-        geometry: MatchingGeometry,
-        affinity_threshold: float,
-        affinity_function: AffinityFunction,
-    ):
-        self.geometry = geometry
-        self.affinity_function = affinity_function
-        self.affinity_threshold = affinity_threshold
-        self.cast_geometry = _geometry_cast_functions[self.geometry]
-
-    def __call__(
-        self,
-        ground_truth: Sequence[data.Geometry],
-        predictions: Sequence[data.Geometry],
-        scores: Sequence[float],
-    ):
-        return greedy_match(
-            ground_truth=[self.cast_geometry(geom) for geom in ground_truth],
-            predictions=[self.cast_geometry(geom) for geom in predictions],
-            scores=scores,
-            affinity_function=self.affinity_function,
-            affinity_threshold=self.affinity_threshold,
-        )
-
-    @matching_strategies.register(GreedyMatchConfig)
-    @staticmethod
-    def from_config(config: GreedyMatchConfig):
-        affinity_function = build_affinity_function(config.affinity_function)
-        return GreedyMatcher(
-            geometry=config.geometry,
-            affinity_threshold=config.affinity_threshold,
-            affinity_function=affinity_function,
-        )
+def match_geometries(
+    source: List[data.Geometry],
+    target: List[data.Geometry],
+    config: MatchConfig,
+    scores: Optional[List[float]] = None,
+) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
+    geometry_cast = _geometry_cast_functions[config.geometry]
+
+    if config.strategy == "optimal":
+        return optimal_match(
+            source=[geometry_cast(geom) for geom in source],
+            target=[geometry_cast(geom) for geom in target],
+            time_buffer=config.time_buffer,
+            freq_buffer=config.frequency_buffer,
+            affinity_threshold=config.affinity_threshold,
+        )
+
+    if config.strategy == "greedy":
+        return greedy_match(
+            source=[geometry_cast(geom) for geom in source],
+            target=[geometry_cast(geom) for geom in target],
+            time_buffer=config.time_buffer,
+            freq_buffer=config.frequency_buffer,
+            affinity_threshold=config.affinity_threshold,
+            scores=scores,
+        )
+
+    raise NotImplementedError(
+        f"Matching strategy not implemented {config.strategy}"
+    )


 def greedy_match(
-    ground_truth: Sequence[data.Geometry],
-    predictions: Sequence[data.Geometry],
-    scores: Sequence[float],
+    source: List[data.Geometry],
+    target: List[data.Geometry],
+    scores: Optional[List[float]] = None,
     affinity_threshold: float = 0.5,
-    affinity_function: AffinityFunction = compute_affinity,
+    time_buffer: float = 0.001,
+    freq_buffer: float = 1000,
 ) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
     """Performs a greedy, one-to-one matching of source to target geometries.
```
```diff
@@ -303,6 +129,10 @@ def greedy_match(
         Confidence scores for each source geometry for prioritization.
     affinity_threshold
         The minimum affinity score required for a valid match.
+    time_buffer
+        Time tolerance in seconds for affinity calculation.
+    freq_buffer
+        Frequency tolerance in Hertz for affinity calculation.

     Yields
     ------
```
```diff
@@ -313,29 +143,37 @@ def greedy_match(
     - Unmatched Source (False Positive): `(source_idx, None, 0)`
     - Unmatched Target (False Negative): `(None, target_idx, 0)`
     """
-    unassigned_gt = set(range(len(ground_truth)))
+    assigned = set()

-    if not predictions:
-        for gt_idx in range(len(ground_truth)):
-            yield None, gt_idx, 0
+    if not source:
+        for target_idx in range(len(target)):
+            yield None, target_idx, 0

         return

-    if not ground_truth:
-        for pred_idx in range(len(predictions)):
-            yield pred_idx, None, 0
+    if not target:
+        for source_idx in range(len(source)):
+            yield source_idx, None, 0

         return

-    indices = np.argsort(scores)[::-1]
+    if scores is None:
+        indices = np.arange(len(source))
+    else:
+        indices = np.argsort(scores)[::-1]

-    for pred_idx in indices:
-        source_geometry = predictions[pred_idx]
+    for source_idx in indices:
+        source_geometry = source[source_idx]

         affinities = np.array(
             [
-                affinity_function(source_geometry, target_geometry)
-                for target_geometry in ground_truth
+                compute_affinity(
+                    source_geometry,
+                    target_geometry,
+                    time_buffer=time_buffer,
+                    freq_buffer=freq_buffer,
+                )
+                for target_geometry in target
             ]
         )

@@ -343,72 +181,162 @@ def greedy_match(
         affinity = affinities[closest_target]

         if affinities[closest_target] <= affinity_threshold:
-            yield pred_idx, None, 0
+            yield source_idx, None, 0
             continue

-        if closest_target not in unassigned_gt:
-            yield pred_idx, None, 0
+        if closest_target in assigned:
+            yield source_idx, None, 0
             continue

-        unassigned_gt.remove(closest_target)
-        yield pred_idx, closest_target, affinity
+        assigned.add(closest_target)
+        yield source_idx, closest_target, affinity

-    for gt_idx in unassigned_gt:
-        yield None, gt_idx, 0
+    missed_ground_truth = set(range(len(target))) - assigned
+    for target_idx in missed_ground_truth:
+        yield None, target_idx, 0
```
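To make the yielded triples concrete, a toy run of the new `greedy_match` on one prediction and two ground-truth intervals. This assumes `soundevent`'s `data.TimeInterval` geometry (coordinates `[start, end]`) and that `compute_affinity` returns the buffered IoU of two geometries; the printed values are approximate:

```python
from soundevent import data

preds = [data.TimeInterval(coordinates=[0.10, 0.20])]
truth = [
    data.TimeInterval(coordinates=[0.11, 0.21]),  # overlaps the prediction
    data.TimeInterval(coordinates=[0.50, 0.60]),  # no prediction nearby
]

for source_idx, target_idx, affinity in greedy_match(
    source=preds, target=truth, affinity_threshold=0.1
):
    print(source_idx, target_idx, round(float(affinity), 2))
# (0, 0, ~0.8)  -> matched pair
# (None, 1, 0)  -> unmatched ground truth (false negative)
```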
Same hunk, continued. The old registry-based optimal matcher is removed, and two new matching entry points are added in its place:

```diff
-class OptimalMatchConfig(BaseConfig):
-    name: Literal["optimal_match"] = "optimal_match"
-    affinity_threshold: float = 0.5
-    time_buffer: float = 0.005
-    frequency_buffer: float = 1_000
-
-
-class OptimalMatcher(MatcherProtocol):
-    def __init__(
-        self,
-        affinity_threshold: float,
-        time_buffer: float,
-        frequency_buffer: float,
-    ):
-        self.affinity_threshold = affinity_threshold
-        self.time_buffer = time_buffer
-        self.frequency_buffer = frequency_buffer
-
-    def __call__(
-        self,
-        ground_truth: Sequence[data.Geometry],
-        predictions: Sequence[data.Geometry],
-        scores: Sequence[float],
-    ):
-        return optimal_match(
-            source=predictions,
-            target=ground_truth,
-            time_buffer=self.time_buffer,
-            freq_buffer=self.frequency_buffer,
-            affinity_threshold=self.affinity_threshold,
-        )
-
-    @matching_strategies.register(OptimalMatchConfig)
-    @staticmethod
-    def from_config(config: OptimalMatchConfig):
-        return OptimalMatcher(
-            affinity_threshold=config.affinity_threshold,
-            time_buffer=config.time_buffer,
-            frequency_buffer=config.frequency_buffer,
-        )
-
-
-MatchConfig = Annotated[
-    Union[
-        GreedyMatchConfig,
-        StartTimeMatchConfig,
-        OptimalMatchConfig,
-    ],
-    Field(discriminator="name"),
-]
-
-
-def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
-    config = config or StartTimeMatchConfig()
-    return matching_strategies.build(config)
+def match_sound_events_and_raw_predictions(
+    clip_annotation: data.ClipAnnotation,
+    raw_predictions: List[BatDetect2Prediction],
+    targets: TargetProtocol,
+    config: Optional[MatchConfig] = None,
+) -> List[MatchEvaluation]:
+    config = config or MatchConfig()
+
+    target_sound_events = [
+        targets.transform(sound_event_annotation)
+        for sound_event_annotation in clip_annotation.sound_events
+        if targets.filter(sound_event_annotation)
+        and sound_event_annotation.sound_event.geometry is not None
+    ]
+
+    target_geometries: List[data.Geometry] = [  # type: ignore
+        sound_event_annotation.sound_event.geometry
+        for sound_event_annotation in target_sound_events
+        if sound_event_annotation.sound_event.geometry is not None
+    ]
+
+    predicted_geometries = [
+        raw_prediction.raw.geometry for raw_prediction in raw_predictions
+    ]
+
+    scores = [
+        raw_prediction.raw.detection_score
+        for raw_prediction in raw_predictions
+    ]
+
+    matches = []
+
+    for source_idx, target_idx, affinity in match_geometries(
+        source=predicted_geometries,
+        target=target_geometries,
+        config=config,
+        scores=scores,
+    ):
+        target = (
+            target_sound_events[target_idx] if target_idx is not None else None
+        )
+        prediction = (
+            raw_predictions[source_idx] if source_idx is not None else None
+        )
+
+        gt_det = target is not None
+        gt_class = targets.encode_class(target) if target is not None else None
+
+        pred_score = float(prediction.raw.detection_score) if prediction else 0
+
+        class_scores = (
+            {
+                str(class_name): float(score)
+                for class_name, score in zip(
+                    targets.class_names,
+                    prediction.raw.class_scores,
+                )
+            }
+            if prediction is not None
+            else {}
+        )
+
+        matches.append(
+            MatchEvaluation(
+                match=data.Match(
+                    source=None
+                    if prediction is None
+                    else prediction.sound_event_prediction,
+                    target=target,
+                    affinity=affinity,
+                ),
+                gt_det=gt_det,
+                gt_class=gt_class,
+                pred_score=pred_score,
+                pred_class_scores=class_scores,
+            )
+        )
+
+    return matches
+
+
+def match_predictions_and_annotations(
+    clip_annotation: data.ClipAnnotation,
+    clip_prediction: data.ClipPrediction,
+    config: Optional[MatchConfig] = None,
+) -> List[data.Match]:
+    config = config or MatchConfig()
+
+    annotated_sound_events = [
+        sound_event_annotation
+        for sound_event_annotation in clip_annotation.sound_events
+        if sound_event_annotation.sound_event.geometry is not None
+    ]
+
+    predicted_sound_events = [
+        sound_event_prediction
+        for sound_event_prediction in clip_prediction.sound_events
+        if sound_event_prediction.sound_event.geometry is not None
+    ]
+
+    annotated_geometries: List[data.Geometry] = [
+        sound_event.sound_event.geometry
+        for sound_event in annotated_sound_events
+        if sound_event.sound_event.geometry is not None
+    ]
+
+    predicted_geometries: List[data.Geometry] = [
+        sound_event.sound_event.geometry
+        for sound_event in predicted_sound_events
+        if sound_event.sound_event.geometry is not None
+    ]
+
+    scores = [
+        sound_event.score
+        for sound_event in predicted_sound_events
+        if sound_event.sound_event.geometry is not None
+    ]
+
+    matches = []
+    for source_idx, target_idx, affinity in match_geometries(
+        source=predicted_geometries,
+        target=annotated_geometries,
+        config=config,
+        scores=scores,
+    ):
+        target = (
+            annotated_sound_events[target_idx]
+            if target_idx is not None
+            else None
+        )
+        source = (
+            predicted_sound_events[source_idx]
+            if source_idx is not None
+            else None
+        )
+        matches.append(
+            data.Match(
+                source=source,
+                target=target,
+                affinity=affinity,
+            )
+        )
+
+    return matches
```
`src/batdetect2/evaluate/metrics.py` (new file, 97 lines), `@@ -0,0 +1,97 @@`:

```python
from typing import Dict, List

import pandas as pd
from sklearn import metrics
from sklearn.preprocessing import label_binarize

from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol

__all__ = ["DetectionAveragePrecision"]


class DetectionAveragePrecision(MetricsProtocol):
    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
        y_true, y_score = zip(
            *[(match.gt_det, match.pred_score) for match in matches]
        )
        score = float(metrics.average_precision_score(y_true, y_score))
        return {"detection_AP": score}


class ClassificationMeanAveragePrecision(MetricsProtocol):
    def __init__(self, class_names: List[str], per_class: bool = True):
        self.class_names = class_names
        self.per_class = per_class

    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
        y_true = label_binarize(
            [
                match.gt_class if match.gt_class is not None else "__NONE__"
                for match in matches
            ],
            classes=self.class_names,
        )
        y_pred = pd.DataFrame(
            [
                {
                    name: match.pred_class_scores.get(name, 0)
                    for name in self.class_names
                }
                for match in matches
            ]
        ).fillna(0)
        mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])

        ret = {
            "classification_mAP": float(mAP),
        }

        if not self.per_class:
            return ret

        for class_index, class_name in enumerate(self.class_names):
            y_true_class = y_true[:, class_index]
            y_pred_class = y_pred[class_name]
            class_ap = metrics.average_precision_score(
                y_true_class,
                y_pred_class,
            )
            ret[f"classification_AP/{class_name}"] = float(class_ap)

        return ret


class ClassificationAccuracy(MetricsProtocol):
    def __init__(self, class_names: List[str]):
        self.class_names = class_names

    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
        y_true = [
            match.gt_class if match.gt_class is not None else "__NONE__"
            for match in matches
        ]

        y_pred = pd.DataFrame(
            [
                {
                    name: match.pred_class_scores.get(name, 0)
                    for name in self.class_names
                }
                for match in matches
            ]
        ).fillna(0)
        y_pred = y_pred.apply(
            lambda row: row.idxmax()
            if row.max() >= (1 - row.sum())
            else "__NONE__",
            axis=1,
        )

        accuracy = metrics.balanced_accuracy_score(
            y_true,
            y_pred,
        )

        return {
            "classification_acc": float(accuracy),
        }
```
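A sketch of how `DetectionAveragePrecision` consumes matches. The stand-in objects carry only the two fields the metric reads, and the expected value follows from scikit-learn's `average_precision_score`:

```python
from types import SimpleNamespace

# Stand-ins for MatchEvaluation with just the fields the metric reads.
matches = [
    SimpleNamespace(gt_det=True, pred_score=0.9),   # true positive
    SimpleNamespace(gt_det=False, pred_score=0.4),  # false positive
    SimpleNamespace(gt_det=True, pred_score=0.0),   # missed ground truth
]

metric = DetectionAveragePrecision()
print(metric(matches))  # {'detection_AP': 0.833...}
```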
`@@ -1,267 +0,0 @@` — a registry-based sound-event classification metrics module (under `batdetect2/evaluate/metrics/`, judging by the `metrics.common` import) is removed in the new revision. One NaN check is corrected below: the ROC-AUC mean originally filtered with `v != np.nan`, which is always true; `np.isnan` is the intended test, as used by the average-precision metric in the same file.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import (
    Annotated,
    Callable,
    Dict,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Union,
)

import numpy as np
from pydantic import Field
from sklearn import metrics
from soundevent import data

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction, TargetProtocol

__all__ = [
    "ClassificationMetric",
    "ClassificationMetricConfig",
    "build_classification_metric",
]


@dataclass
class MatchEval:
    clip: data.Clip
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]

    is_prediction: bool
    is_ground_truth: bool
    is_generic: bool
    true_class: Optional[str]
    score: float


@dataclass
class ClipEval:
    clip: data.Clip
    matches: Mapping[str, List[MatchEval]]


ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]


classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
    Registry("classification_metric")
)


class BaseClassificationConfig(BaseConfig):
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


class BaseClassificationMetric:
    def __init__(
        self,
        targets: TargetProtocol,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        self.targets = targets
        self.include = include
        self.exclude = exclude

    def include_class(self, class_name: str) -> bool:
        if self.include is not None:
            return class_name in self.include

        if self.exclude is not None:
            return class_name not in self.exclude

        return True


class ClassificationAveragePrecisionConfig(BaseClassificationConfig):
    name: Literal["average_precision"] = "average_precision"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    label: str = "average_precision"


class ClassificationAveragePrecision(BaseClassificationMetric):
    def __init__(
        self,
        targets: TargetProtocol,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        label: str = "average_precision",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        super().__init__(include=include, exclude=exclude, targets=targets)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.label = label

    def __call__(
        self, clip_evaluations: Sequence[ClipEval]
    ) -> Dict[str, float]:
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        class_scores = {
            class_name: average_precision(
                y_true[class_name],
                y_score[class_name],
                num_positives=num_positives[class_name],
            )
            for class_name in self.targets.class_names
        }

        mean_score = float(
            np.mean([v for v in class_scores.values() if not np.isnan(v)])
        )

        return {
            f"mean_{self.label}": mean_score,
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if self.include_class(class_name)
            },
        }

    @classification_metrics.register(ClassificationAveragePrecisionConfig)
    @staticmethod
    def from_config(
        config: ClassificationAveragePrecisionConfig,
        targets: TargetProtocol,
    ):
        return ClassificationAveragePrecision(
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            label=config.label,
            include=config.include,
            exclude=config.exclude,
        )


class ClassificationROCAUCConfig(BaseClassificationConfig):
    name: Literal["roc_auc"] = "roc_auc"
    label: str = "roc_auc"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ClassificationROCAUC(BaseClassificationMetric):
    def __init__(
        self,
        targets: TargetProtocol,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        label: str = "roc_auc",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        self.targets = targets
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.label = label
        self.include = include
        self.exclude = exclude

    def __call__(
        self, clip_evaluations: Sequence[ClipEval]
    ) -> Dict[str, float]:
        y_true, y_score, _ = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        class_scores = {
            class_name: float(
                metrics.roc_auc_score(
                    y_true[class_name],
                    y_score[class_name],
                )
            )
            for class_name in self.targets.class_names
        }

        mean_score = float(
            np.mean([v for v in class_scores.values() if not np.isnan(v)])
        )

        return {
            f"mean_{self.label}": mean_score,
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if self.include_class(class_name)
            },
        }

    @classification_metrics.register(ClassificationROCAUCConfig)
    @staticmethod
    def from_config(
        config: ClassificationROCAUCConfig, targets: TargetProtocol
    ):
        return ClassificationROCAUC(
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            label=config.label,
        )


ClassificationMetricConfig = Annotated[
    Union[
        ClassificationAveragePrecisionConfig,
        ClassificationROCAUCConfig,
    ],
    Field(discriminator="name"),
]


def build_classification_metric(
    config: ClassificationMetricConfig,
    targets: TargetProtocol,
) -> ClassificationMetric:
    return classification_metrics.build(config, targets)


def _extract_per_class_metric_data(
    clip_evaluations: Sequence[ClipEval],
    ignore_non_predictions: bool = True,
    ignore_generic: bool = True,
):
    y_true = defaultdict(list)
    y_score = defaultdict(list)
    num_positives = defaultdict(lambda: 0)

    for clip_eval in clip_evaluations:
        for class_name, matches in clip_eval.matches.items():
            for m in matches:
                # Exclude matches with ground truth sounds where the class
                # is unknown
                if m.is_generic and ignore_generic:
                    continue

                is_class = m.true_class == class_name

                if is_class:
                    num_positives[class_name] += 1

                # Ignore matches that don't correspond to a prediction
                if not m.is_prediction and ignore_non_predictions:
                    continue

                y_true[class_name].append(is_class)
                y_score[class_name].append(m.score)

    return y_true, y_score, num_positives
```
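The `num_positives` argument is the interesting part: when `ignore_non_predictions` drops unmatched ground truth from `y_true`, a plain AP over the remaining pairs would overestimate recall. A sketch of an AP with a corrected recall denominator, offered as an assumption about what `average_precision` in `metrics.common` does, not its actual code:

```python
import numpy as np


def average_precision_sketch(y_true, y_score, num_positives=None):
    """Non-interpolated AP whose recall denominator can exceed sum(y_true)."""
    order = np.argsort(y_score)[::-1]
    hits = np.asarray(y_true, dtype=float)[order]

    if num_positives is None:
        num_positives = int(hits.sum())
    if num_positives == 0:
        return float("nan")

    precision_at_hit = np.cumsum(hits) / np.arange(1, len(hits) + 1)
    # Dividing by num_positives (not hits.sum()) charges the metric for
    # ground-truth positives that were filtered out of y_true upstream.
    return float(np.sum(precision_at_hit * hits) / num_positives)
```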
`@@ -1,135 +0,0 @@` — the clip-level classification metrics module (same `metrics` package) is removed:

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union

import numpy as np
from pydantic import Field
from sklearn import metrics

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.metrics.common import average_precision


@dataclass
class ClipEval:
    true_classes: Set[str]
    class_scores: Dict[str, float]


ClipClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]

clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
    "clip_classification_metric"
)


class ClipClassificationAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"


class ClipClassificationAveragePrecision:
    def __init__(self, label: str = "average_precision"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = defaultdict(list)
        y_score = defaultdict(list)

        for clip_eval in clip_evaluations:
            for class_name, score in clip_eval.class_scores.items():
                y_true[class_name].append(class_name in clip_eval.true_classes)
                y_score[class_name].append(score)

        class_scores = {
            class_name: float(
                average_precision(
                    y_true=y_true[class_name],
                    y_score=y_score[class_name],
                )
            )
            for class_name in y_true
        }

        mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])

        return {
            f"mean_{self.label}": float(mean),
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if not np.isnan(score)
            },
        }

    @clip_classification_metrics.register(
        ClipClassificationAveragePrecisionConfig
    )
    @staticmethod
    def from_config(config: ClipClassificationAveragePrecisionConfig):
        return ClipClassificationAveragePrecision(label=config.label)


class ClipClassificationROCAUCConfig(BaseConfig):
    name: Literal["roc_auc"] = "roc_auc"
    label: str = "roc_auc"


class ClipClassificationROCAUC:
    def __init__(self, label: str = "roc_auc"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = defaultdict(list)
        y_score = defaultdict(list)

        for clip_eval in clip_evaluations:
            for class_name, score in clip_eval.class_scores.items():
                y_true[class_name].append(class_name in clip_eval.true_classes)
                y_score[class_name].append(score)

        class_scores = {
            class_name: float(
                metrics.roc_auc_score(
                    y_true=y_true[class_name],
                    y_score=y_score[class_name],
                )
            )
            for class_name in y_true
        }

        mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])

        return {
            f"mean_{self.label}": float(mean),
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if not np.isnan(score)
            },
        }

    @clip_classification_metrics.register(ClipClassificationROCAUCConfig)
    @staticmethod
    def from_config(config: ClipClassificationROCAUCConfig):
        return ClipClassificationROCAUC(label=config.label)


ClipClassificationMetricConfig = Annotated[
    Union[
        ClipClassificationAveragePrecisionConfig,
        ClipClassificationROCAUCConfig,
    ],
    Field(discriminator="name"),
]


def build_clip_metric(config: ClipClassificationMetricConfig):
    return clip_classification_metrics.build(config)
```
@ -1,173 +0,0 @@
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Union

import numpy as np
from pydantic import Field
from sklearn import metrics

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.metrics.common import average_precision


@dataclass
class ClipEval:
    gt_det: bool
    score: float


ClipDetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]

clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
    "clip_detection_metric"
)


class ClipDetectionAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"


class ClipDetectionAveragePrecision:
    def __init__(self, label: str = "average_precision"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            y_true.append(clip_eval.gt_det)
            y_score.append(clip_eval.score)

        score = average_precision(y_true=y_true, y_score=y_score)
        return {self.label: score}

    @clip_detection_metrics.register(ClipDetectionAveragePrecisionConfig)
    @staticmethod
    def from_config(config: ClipDetectionAveragePrecisionConfig):
        return ClipDetectionAveragePrecision(label=config.label)


class ClipDetectionROCAUCConfig(BaseConfig):
    name: Literal["roc_auc"] = "roc_auc"
    label: str = "roc_auc"


class ClipDetectionROCAUC:
    def __init__(self, label: str = "roc_auc"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            y_true.append(clip_eval.gt_det)
            y_score.append(clip_eval.score)

        score = float(metrics.roc_auc_score(y_true=y_true, y_score=y_score))
        return {self.label: score}

    @clip_detection_metrics.register(ClipDetectionROCAUCConfig)
    @staticmethod
    def from_config(config: ClipDetectionROCAUCConfig):
        return ClipDetectionROCAUC(label=config.label)


class ClipDetectionRecallConfig(BaseConfig):
    name: Literal["recall"] = "recall"
    threshold: float = 0.5
    label: str = "recall"


class ClipDetectionRecall:
    def __init__(self, threshold: float, label: str = "recall"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_positives = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            if clip_eval.gt_det:
                num_positives += 1

            if clip_eval.score >= self.threshold and clip_eval.gt_det:
                true_positives += 1

        if num_positives == 0:
            return {self.label: np.nan}

        score = true_positives / num_positives
        return {self.label: score}

    @clip_detection_metrics.register(ClipDetectionRecallConfig)
    @staticmethod
    def from_config(config: ClipDetectionRecallConfig):
        return ClipDetectionRecall(
            threshold=config.threshold, label=config.label
        )


class ClipDetectionPrecisionConfig(BaseConfig):
    name: Literal["precision"] = "precision"
    threshold: float = 0.5
    label: str = "precision"


class ClipDetectionPrecision:
    def __init__(self, threshold: float, label: str = "precision"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_detections = 0
        true_positives = 0
        for clip_eval in clip_evaluations:
            if clip_eval.score >= self.threshold:
                num_detections += 1

            if clip_eval.score >= self.threshold and clip_eval.gt_det:
                true_positives += 1

        if num_detections == 0:
            return {self.label: np.nan}

        score = true_positives / num_detections
        return {self.label: score}

    @clip_detection_metrics.register(ClipDetectionPrecisionConfig)
    @staticmethod
    def from_config(config: ClipDetectionPrecisionConfig):
        return ClipDetectionPrecision(
            threshold=config.threshold, label=config.label
        )


ClipDetectionMetricConfig = Annotated[
    Union[
        ClipDetectionAveragePrecisionConfig,
        ClipDetectionROCAUCConfig,
        ClipDetectionRecallConfig,
        ClipDetectionPrecisionConfig,
    ],
    Field(discriminator="name"),
]


def build_clip_metric(config: ClipDetectionMetricConfig):
    return clip_detection_metrics.build(config)
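Each metrics module above ends with the same pattern: an `Annotated[Union[...], Field(discriminator="name")]` type that lets a single config list mix metric types, with the `name` field selecting the concrete config class. A self-contained sketch of how that discriminator resolves (plain pydantic `BaseModel` stands in for the repo's `BaseConfig`):

```python
from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class RecallConfig(BaseModel):
    name: Literal["recall"] = "recall"
    threshold: float = 0.5


class PrecisionConfig(BaseModel):
    name: Literal["precision"] = "precision"
    threshold: float = 0.5


MetricConfig = Annotated[
    Union[RecallConfig, PrecisionConfig],
    Field(discriminator="name"),
]

# The "name" field picks the config class, which is what lets a YAML
# list of metrics mix types freely.
adapter = TypeAdapter(MetricConfig)
config = adapter.validate_python({"name": "recall", "threshold": 0.3})
print(type(config).__name__, config.threshold)  # RecallConfig 0.3
```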
@ -1,63 +0,0 @@
from typing import Optional, Tuple

import numpy as np

__all__ = [
    "compute_precision_recall",
    "average_precision",
]


def compute_precision_recall(
    y_true,
    y_score,
    num_positives: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    y_true = np.array(y_true)
    y_score = np.array(y_score)

    if num_positives is None:
        num_positives = y_true.sum()

    # Sort by score, descending
    sort_ind = np.argsort(y_score)[::-1]
    y_true_sorted = y_true[sort_ind]
    y_score_sorted = y_score[sort_ind]

    false_pos_c = np.cumsum(1 - y_true_sorted)
    true_pos_c = np.cumsum(y_true_sorted)

    recall = true_pos_c / num_positives
    precision = true_pos_c / np.maximum(
        true_pos_c + false_pos_c,
        np.finfo(np.float64).eps,
    )

    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0
    return precision, recall, y_score_sorted


def average_precision(
    y_true,
    y_score,
    num_positives: Optional[int] = None,
) -> float:
    if num_positives == 0:
        return np.nan

    precision, recall, _ = compute_precision_recall(
        y_true,
        y_score,
        num_positives=num_positives,
    )

    # PASCAL VOC 2012-style interpolated average precision
    mprec = np.hstack((0, precision, 0))
    mrec = np.hstack((0, recall, 1))
    for ii in range(mprec.shape[0] - 2, -1, -1):
        mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
    inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
    ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()

    return ave_prec
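A worked example of the interpolated average precision above, small enough to check by hand:

```python
# Two positives; descending scores rank them [TP (0.9), FP (0.8), TP (0.3)].
ap = average_precision(y_true=[1, 0, 1], y_score=[0.9, 0.8, 0.3])

# Interpolated precision is 1.0 up to recall 0.5, then 2/3 up to recall
# 1.0, so AP = 0.5 * 1.0 + 0.5 * (2 / 3) ≈ 0.833.
print(round(float(ap), 3))  # 0.833
```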
@ -1,226 +0,0 @@
from dataclasses import dataclass
from typing import (
    Annotated,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

import numpy as np
from pydantic import Field
from sklearn import metrics
from soundevent import data

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction

__all__ = [
    "DetectionMetricConfig",
    "DetectionMetric",
    "build_detection_metric",
]


@dataclass
class MatchEval:
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]

    is_prediction: bool
    is_ground_truth: bool
    score: float


@dataclass
class ClipEval:
    clip: data.Clip
    matches: List[MatchEval]


DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]


detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")


class DetectionAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"
    ignore_non_predictions: bool = True


class DetectionAveragePrecision:
    def __init__(self, label: str, ignore_non_predictions: bool = True):
        self.ignore_non_predictions = ignore_non_predictions
        self.label = label

    def __call__(
        self,
        clip_evals: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                num_positives += int(m.is_ground_truth)

                # Ignore matches that don't correspond to a prediction
                if not m.is_prediction and self.ignore_non_predictions:
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        ap = average_precision(y_true, y_score, num_positives=num_positives)
        return {self.label: ap}

    @detection_metrics.register(DetectionAveragePrecisionConfig)
    @staticmethod
    def from_config(config: DetectionAveragePrecisionConfig):
        return DetectionAveragePrecision(
            label=config.label,
            ignore_non_predictions=config.ignore_non_predictions,
        )


class DetectionROCAUCConfig(BaseConfig):
    name: Literal["roc_auc"] = "roc_auc"
    label: str = "roc_auc"
    ignore_non_predictions: bool = True


class DetectionROCAUC:
    def __init__(
        self,
        label: str = "roc_auc",
        ignore_non_predictions: bool = True,
    ):
        self.label = label
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
        y_true: List[bool] = []
        y_score: List[float] = []

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        score = float(metrics.roc_auc_score(y_true, y_score))
        return {self.label: score}

    @detection_metrics.register(DetectionROCAUCConfig)
    @staticmethod
    def from_config(config: DetectionROCAUCConfig):
        return DetectionROCAUC(
            label=config.label,
            ignore_non_predictions=config.ignore_non_predictions,
        )


class DetectionRecallConfig(BaseConfig):
    name: Literal["recall"] = "recall"
    label: str = "recall"
    threshold: float = 0.5


class DetectionRecall:
    def __init__(self, threshold: float, label: str = "recall"):
        self.label = label
        self.threshold = threshold

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_positives = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_ground_truth:
                    num_positives += 1

                if m.score >= self.threshold and m.is_ground_truth:
                    true_positives += 1

        if num_positives == 0:
            return {self.label: np.nan}

        score = true_positives / num_positives
        return {self.label: score}

    @detection_metrics.register(DetectionRecallConfig)
    @staticmethod
    def from_config(config: DetectionRecallConfig):
        return DetectionRecall(threshold=config.threshold, label=config.label)


class DetectionPrecisionConfig(BaseConfig):
    name: Literal["precision"] = "precision"
    label: str = "precision"
    threshold: float = 0.5


class DetectionPrecision:
    def __init__(self, threshold: float, label: str = "precision"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_detections = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                is_detection = m.score >= self.threshold

                if is_detection:
                    num_detections += 1

                if is_detection and m.is_ground_truth:
                    true_positives += 1

        if num_detections == 0:
            return {self.label: np.nan}

        score = true_positives / num_detections
        return {self.label: score}

    @detection_metrics.register(DetectionPrecisionConfig)
    @staticmethod
    def from_config(config: DetectionPrecisionConfig):
        return DetectionPrecision(
            threshold=config.threshold,
            label=config.label,
        )


DetectionMetricConfig = Annotated[
    Union[
        DetectionAveragePrecisionConfig,
        DetectionROCAUCConfig,
        DetectionRecallConfig,
        DetectionPrecisionConfig,
    ],
    Field(discriminator="name"),
]


def build_detection_metric(config: DetectionMetricConfig):
    return detection_metrics.build(config)
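A minimal sketch of exercising one of these metrics directly, assuming the module above is importable; `clip=None` stands in for a real soundevent `data.Clip`, which dataclasses do not enforce at runtime:

```python
matches = [
    # A matched prediction on a ground-truth sound.
    MatchEval(gt=None, pred=None, is_prediction=True, is_ground_truth=True, score=0.9),
    # A confident false positive.
    MatchEval(gt=None, pred=None, is_prediction=True, is_ground_truth=False, score=0.7),
    # A missed ground-truth sound (no prediction, score 0).
    MatchEval(gt=None, pred=None, is_prediction=False, is_ground_truth=True, score=0.0),
]
clip_eval = ClipEval(clip=None, matches=matches)

recall = DetectionRecall(threshold=0.5)
print(recall([clip_eval]))  # {'recall': 0.5}: one of two GT sounds scores >= 0.5
```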
@ -1,314 +0,0 @@
from dataclasses import dataclass
from typing import (
    Annotated,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

import numpy as np
from pydantic import Field
from sklearn import metrics, preprocessing
from soundevent import data

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction

__all__ = [
    "TopClassMetricConfig",
    "TopClassMetric",
    "build_top_class_metric",
]


@dataclass
class MatchEval:
    clip: data.Clip
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]

    is_ground_truth: bool
    is_generic: bool
    is_prediction: bool
    pred_class: Optional[str]
    true_class: Optional[str]
    score: float


@dataclass
class ClipEval:
    clip: data.Clip
    matches: List[MatchEval]


TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]


top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")


class TopClassAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"
    ignore_generic: bool = True
    ignore_non_predictions: bool = True


class TopClassAveragePrecision:
    def __init__(
        self,
        ignore_generic: bool = True,
        ignore_non_predictions: bool = True,
        label: str = "average_precision",
    ):
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions
        self.label = label

    def __call__(
        self,
        clip_evals: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore gt sounds with unknown class
                    continue

                num_positives += int(m.is_ground_truth)

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        score = average_precision(y_true, y_score, num_positives=num_positives)
        return {self.label: score}

    @top_class_metrics.register(TopClassAveragePrecisionConfig)
    @staticmethod
    def from_config(config: TopClassAveragePrecisionConfig):
        return TopClassAveragePrecision(
            ignore_generic=config.ignore_generic,
            ignore_non_predictions=config.ignore_non_predictions,
            label=config.label,
        )


class TopClassROCAUCConfig(BaseConfig):
    name: Literal["roc_auc"] = "roc_auc"
    ignore_generic: bool = True
    ignore_non_predictions: bool = True
    label: str = "roc_auc"


class TopClassROCAUC:
    def __init__(
        self,
        ignore_generic: bool = True,
        ignore_non_predictions: bool = True,
        label: str = "roc_auc",
    ):
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions
        self.label = label

    def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
        y_true: List[bool] = []
        y_score: List[float] = []

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore gt sounds with unknown class
                    continue

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        score = float(metrics.roc_auc_score(y_true, y_score))
        return {self.label: score}

    @top_class_metrics.register(TopClassROCAUCConfig)
    @staticmethod
    def from_config(config: TopClassROCAUCConfig):
        return TopClassROCAUC(
            ignore_generic=config.ignore_generic,
            ignore_non_predictions=config.ignore_non_predictions,
            label=config.label,
        )


class TopClassRecallConfig(BaseConfig):
    name: Literal["recall"] = "recall"
    threshold: float = 0.5
    label: str = "recall"


class TopClassRecall:
    def __init__(self, threshold: float, label: str = "recall"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_positives = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_ground_truth:
                    num_positives += 1

                if m.score >= self.threshold and m.pred_class == m.true_class:
                    true_positives += 1

        if num_positives == 0:
            return {self.label: np.nan}

        score = true_positives / num_positives
        return {self.label: score}

    @top_class_metrics.register(TopClassRecallConfig)
    @staticmethod
    def from_config(config: TopClassRecallConfig):
        return TopClassRecall(
            threshold=config.threshold,
            label=config.label,
        )


class TopClassPrecisionConfig(BaseConfig):
    name: Literal["precision"] = "precision"
    threshold: float = 0.5
    label: str = "precision"


class TopClassPrecision:
    def __init__(self, threshold: float, label: str = "precision"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_detections = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                is_detection = m.score >= self.threshold

                if is_detection:
                    num_detections += 1

                if is_detection and m.pred_class == m.true_class:
                    true_positives += 1

        if num_detections == 0:
            return {self.label: np.nan}

        score = true_positives / num_detections
        return {self.label: score}

    @top_class_metrics.register(TopClassPrecisionConfig)
    @staticmethod
    def from_config(config: TopClassPrecisionConfig):
        return TopClassPrecision(
            threshold=config.threshold,
            label=config.label,
        )


class BalancedAccuracyConfig(BaseConfig):
    name: Literal["balanced_accuracy"] = "balanced_accuracy"
    label: str = "balanced_accuracy"
    exclude_noise: bool = False
    noise_class: str = "noise"


class BalancedAccuracy:
    def __init__(
        self,
        exclude_noise: bool = True,
        noise_class: str = "noise",
        label: str = "balanced_accuracy",
    ):
        self.exclude_noise = exclude_noise
        self.noise_class = noise_class
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true: List[str] = []
        y_pred: List[str] = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_generic:
                    # Ignore matches that correspond to a sound event
                    # with unknown class
                    continue

                if not m.is_ground_truth and self.exclude_noise:
                    # Ignore predictions that were not matched to a
                    # ground truth
                    continue

                if m.pred_class is None and self.exclude_noise:
                    # Ignore non-predictions
                    continue

                y_true.append(m.true_class or self.noise_class)
                y_pred.append(m.pred_class or self.noise_class)

        encoder = preprocessing.LabelEncoder()
        encoder.fit(list(set(y_true) | set(y_pred)))

        y_true = encoder.transform(y_true)
        y_pred = encoder.transform(y_pred)
        score = metrics.balanced_accuracy_score(y_true, y_pred)
        return {self.label: score}

    @top_class_metrics.register(BalancedAccuracyConfig)
    @staticmethod
    def from_config(config: BalancedAccuracyConfig):
        return BalancedAccuracy(
            exclude_noise=config.exclude_noise,
            noise_class=config.noise_class,
            label=config.label,
        )


TopClassMetricConfig = Annotated[
    Union[
        TopClassAveragePrecisionConfig,
        TopClassROCAUCConfig,
        TopClassRecallConfig,
        TopClassPrecisionConfig,
        BalancedAccuracyConfig,
    ],
    Field(discriminator="name"),
]


def build_top_class_metric(config: TopClassMetricConfig):
    return top_class_metrics.build(config)
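`BalancedAccuracy` reduces to `sklearn.metrics.balanced_accuracy_score`, i.e. the unweighted mean of per-class recalls (the `LabelEncoder` round-trip is belt-and-braces: sklearn accepts string labels here directly). A self-contained example with illustrative class names:

```python
from sklearn import metrics

y_true = ["pippip", "pippip", "myomys", "noise"]
y_pred = ["pippip", "myomys", "myomys", "noise"]

# Mean of per-class recalls: pippip 1/2, myomys 1/1, noise 1/1, so the
# rare classes count as much as the frequent one.
print(metrics.balanced_accuracy_score(y_true, y_pred))  # (0.5 + 1 + 1) / 3 ≈ 0.833
```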
@ -1,54 +0,0 @@
from typing import Optional

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from batdetect2.core import BaseConfig
from batdetect2.typing import TargetProtocol


class BasePlotConfig(BaseConfig):
    label: str = "plot"
    theme: str = "default"
    title: Optional[str] = None
    figsize: tuple[int, int] = (10, 10)
    dpi: int = 100


class BasePlot:
    def __init__(
        self,
        targets: TargetProtocol,
        label: str = "plot",
        figsize: tuple[int, int] = (10, 10),
        title: Optional[str] = None,
        dpi: int = 100,
        theme: str = "default",
    ):
        self.targets = targets
        self.label = label
        self.figsize = figsize
        self.dpi = dpi
        self.theme = theme
        self.title = title

    def create_figure(self) -> Figure:
        plt.style.use(self.theme)
        fig = plt.figure(figsize=self.figsize, dpi=self.dpi)

        if self.title is not None:
            fig.suptitle(self.title)

        return fig

    @classmethod
    def build(cls, config: BasePlotConfig, targets: TargetProtocol, **kwargs):
        return cls(
            targets=targets,
            figsize=config.figsize,
            dpi=config.dpi,
            theme=config.theme,
            label=config.label,
            title=config.title,
            **kwargs,
        )
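A minimal sketch of the intended extension point, assuming the `BasePlot`/`BasePlotConfig` above; the `ScoreHistogram` name is hypothetical, not a plot shipped in this diff:

```python
from typing import Iterable, Sequence, Tuple

from matplotlib.figure import Figure


class ScoreHistogramConfig(BasePlotConfig):
    label: str = "score_histogram"


class ScoreHistogram(BasePlot):
    def __call__(self, scores: Sequence[float]) -> Iterable[Tuple[str, Figure]]:
        fig = self.create_figure()  # applies theme, figsize, dpi and title
        ax = fig.subplots()
        ax.hist(scores, bins=20, range=(0, 1))
        yield self.label, fig
```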
@ -1,370 +0,0 @@
from typing import (
    Annotated,
    Callable,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.core import Registry
from batdetect2.evaluate.metrics.classification import (
    ClipEval,
    _extract_per_class_metric_data,
)
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
    plot_pr_curve,
    plot_pr_curves,
    plot_roc_curve,
    plot_roc_curves,
    plot_threshold_precision_curve,
    plot_threshold_precision_curves,
    plot_threshold_recall_curve,
    plot_threshold_recall_curves,
)
from batdetect2.typing import TargetProtocol

ClassificationPlotter = Callable[
    [Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]

classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
    Registry("classification_plot")
)


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Classification Precision-Recall Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    separate_figures: bool = False


class PRCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        separate_figures: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        data = {
            class_name: compute_precision_recall(
                y_true[class_name],
                y_score[class_name],
                num_positives=num_positives[class_name],
            )
            for class_name in self.targets.class_names
        }

        if not self.separate_figures:
            fig = self.create_figure()
            ax = fig.subplots()
            plot_pr_curves(data, ax=ax)
            yield self.label, fig
            return

        for class_name, (precision, recall, thresholds) in data.items():
            fig = self.create_figure()
            ax = fig.subplots()

            ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
            ax.set_title(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @classification_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            separate_figures=config.separate_figures,
        )


class ThresholdPrecisionCurveConfig(BasePlotConfig):
    name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
    label: str = "threshold_precision_curve"
    title: Optional[str] = "Classification Threshold-Precision Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    separate_figures: bool = False


class ThresholdPrecisionCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        separate_figures: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        data = {
            class_name: compute_precision_recall(
                y_true[class_name],
                y_score[class_name],
                num_positives[class_name],
            )
            for class_name in self.targets.class_names
        }

        if not self.separate_figures:
            fig = self.create_figure()
            ax = fig.subplots()

            plot_threshold_precision_curves(data, ax=ax)

            yield self.label, fig

            return

        for class_name, (precision, _, thresholds) in data.items():
            fig = self.create_figure()
            ax = fig.subplots()

            ax = plot_threshold_precision_curve(
                thresholds,
                precision,
                ax=ax,
            )

            ax.set_title(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @classification_plots.register(ThresholdPrecisionCurveConfig)
    @staticmethod
    def from_config(
        config: ThresholdPrecisionCurveConfig, targets: TargetProtocol
    ):
        return ThresholdPrecisionCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            separate_figures=config.separate_figures,
        )


class ThresholdRecallCurveConfig(BasePlotConfig):
    name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
    label: str = "threshold_recall_curve"
    title: Optional[str] = "Classification Threshold-Recall Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    separate_figures: bool = False


class ThresholdRecallCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        separate_figures: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        data = {
            class_name: compute_precision_recall(
                y_true[class_name],
                y_score[class_name],
                num_positives[class_name],
            )
            for class_name in self.targets.class_names
        }

        if not self.separate_figures:
            fig = self.create_figure()
            ax = fig.subplots()

            plot_threshold_recall_curves(data, ax=ax, add_legend=True)

            yield self.label, fig

            return

        for class_name, (_, recall, thresholds) in data.items():
            fig = self.create_figure()
            ax = fig.subplots()

            ax = plot_threshold_recall_curve(
                thresholds,
                recall,
                ax=ax,
            )

            ax.set_title(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @classification_plots.register(ThresholdRecallCurveConfig)
    @staticmethod
    def from_config(
        config: ThresholdRecallCurveConfig, targets: TargetProtocol
    ):
        return ThresholdRecallCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            separate_figures=config.separate_figures,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Classification ROC Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    separate_figures: bool = False


class ROCCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        separate_figures: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true, y_score, _ = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        data = {
            class_name: metrics.roc_curve(
                y_true[class_name],
                y_score[class_name],
            )
            for class_name in self.targets.class_names
        }

        if not self.separate_figures:
            fig = self.create_figure()
            ax = fig.subplots()

            plot_roc_curves(data, ax=ax)

            yield self.label, fig

            return

        for class_name, (fpr, tpr, thresholds) in data.items():
            fig = self.create_figure()
            ax = fig.subplots()

            ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
            ax.set_title(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @classification_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            separate_figures=config.separate_figures,
        )


ClassificationPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ThresholdPrecisionCurveConfig,
        ThresholdRecallCurveConfig,
    ],
    Field(discriminator="name"),
]


def build_classification_plotter(
    config: ClassificationPlotConfig,
    targets: TargetProtocol,
) -> ClassificationPlotter:
    return classification_plots.build(config, targets)
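Every plotter above is a generator of `(label, Figure)` pairs, so callers can save or log figures uniformly whether `separate_figures` produced one figure or one per class. A sketch of a consumer, with `plotter` and `clip_evals` assumed to exist:

```python
for label, fig in plotter(clip_evals):
    # Labels such as "pr_curve/pippip" contain "/", so flatten them
    # before using them as file names.
    fig.savefig(f"{label.replace('/', '_')}.png")
```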
@ -1,189 +0,0 @@
from typing import (
    Annotated,
    Callable,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_classification import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
    plot_pr_curve,
    plot_pr_curves,
    plot_roc_curve,
    plot_roc_curves,
)
from batdetect2.typing import TargetProtocol

__all__ = [
    "ClipClassificationPlotConfig",
    "ClipClassificationPlotter",
    "build_clip_classification_plotter",
]

ClipClassificationPlotter = Callable[
    [Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]

clip_classification_plots: Registry[
    ClipClassificationPlotter, [TargetProtocol]
] = Registry("clip_classification_plot")


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Clip Classification Precision-Recall Curve"
    separate_figures: bool = False


class PRCurve(BasePlot):
    def __init__(
        self,
        *args,
        separate_figures: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        data = {}

        for class_name in self.targets.class_names:
            y_true = [class_name in c.true_classes for c in clip_evaluations]
            y_score = [
                c.class_scores.get(class_name, 0) for c in clip_evaluations
            ]

            precision, recall, thresholds = compute_precision_recall(
                y_true,
                y_score,
            )

            data[class_name] = (precision, recall, thresholds)

        if not self.separate_figures:
            fig = self.create_figure()
            ax = fig.subplots()

            plot_pr_curves(data, ax=ax)

            yield self.label, fig

            return

        for class_name, (precision, recall, thresholds) in data.items():
            fig = self.create_figure()
            ax = fig.subplots()

            ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
            ax.set_title(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @clip_classification_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
            separate_figures=config.separate_figures,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Clip Classification ROC Curve"
    separate_figures: bool = False


class ROCCurve(BasePlot):
    def __init__(
        self,
        *args,
        separate_figures: bool = False,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.separate_figures = separate_figures

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        data = {}

        for class_name in self.targets.class_names:
            y_true = [class_name in c.true_classes for c in clip_evaluations]
            y_score = [
                c.class_scores.get(class_name, 0) for c in clip_evaluations
            ]

            fpr, tpr, thresholds = metrics.roc_curve(
                y_true,
                y_score,
            )

            data[class_name] = (fpr, tpr, thresholds)

        if not self.separate_figures:
            fig = self.create_figure()
            ax = fig.subplots()
            plot_roc_curves(data, ax=ax)
            yield self.label, fig

            return

        for class_name, (fpr, tpr, thresholds) in data.items():
            fig = self.create_figure()
            ax = fig.subplots()

            ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
            ax.set_title(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @clip_classification_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
            separate_figures=config.separate_figures,
        )


ClipClassificationPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
    ],
    Field(discriminator="name"),
]


def build_clip_classification_plotter(
    config: ClipClassificationPlotConfig,
    targets: TargetProtocol,
) -> ClipClassificationPlotter:
    return clip_classification_plots.build(config, targets)
@ -1,163 +0,0 @@
from typing import (
    Annotated,
    Callable,
    Iterable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_detection import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol

__all__ = [
    "ClipDetectionPlotConfig",
    "ClipDetectionPlotter",
    "build_clip_detection_plotter",
]

ClipDetectionPlotter = Callable[
    [Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]


clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
    Registry("clip_detection_plot")
)


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Clip Detection Precision-Recall Curve"


class PRCurve(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = [c.gt_det for c in clip_evaluations]
        y_score = [c.score for c in clip_evaluations]

        precision, recall, thresholds = compute_precision_recall(
            y_true,
            y_score,
        )

        fig = self.create_figure()
        ax = fig.subplots()
        plot_pr_curve(precision, recall, thresholds, ax=ax)
        yield self.label, fig

    @clip_detection_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Clip Detection ROC Curve"


class ROCCurve(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = [c.gt_det for c in clip_evaluations]
        y_score = [c.score for c in clip_evaluations]

        fpr, tpr, thresholds = metrics.roc_curve(
            y_true,
            y_score,
        )

        fig = self.create_figure()
        ax = fig.subplots()
        plot_roc_curve(fpr, tpr, thresholds, ax=ax)
        yield self.label, fig

    @clip_detection_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
        )


class ScoreDistributionPlotConfig(BasePlotConfig):
    name: Literal["score_distribution"] = "score_distribution"
    label: str = "score_distribution"
    title: Optional[str] = "Clip Detection Score Distribution"


class ScoreDistributionPlot(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = [c.gt_det for c in clip_evaluations]
        y_score = [c.score for c in clip_evaluations]

        fig = self.create_figure()
        ax = fig.subplots()

        df = pd.DataFrame({"is_true": y_true, "score": y_score})
        sns.histplot(
            data=df,
            x="score",
            binwidth=0.025,
            binrange=(0, 1),
            hue="is_true",
            ax=ax,
            stat="probability",
            common_norm=False,
        )

        yield self.label, fig

    @clip_detection_plots.register(ScoreDistributionPlotConfig)
    @staticmethod
    def from_config(
        config: ScoreDistributionPlotConfig, targets: TargetProtocol
    ):
        return ScoreDistributionPlot.build(
            config=config,
            targets=targets,
        )


ClipDetectionPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ScoreDistributionPlotConfig,
    ],
    Field(discriminator="name"),
]


def build_clip_detection_plotter(
    config: ClipDetectionPlotConfig,
    targets: TargetProtocol,
) -> ClipDetectionPlotter:
    return clip_detection_plots.build(config, targets)
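The `stat="probability", common_norm=False` choice in `ScoreDistributionPlot` normalizes each hue group separately, so the (usually few) true detections are not dwarfed by the many false ones. A self-contained illustration with synthetic scores:

```python
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(0)
df = pd.DataFrame(
    {
        "is_true": [True] * 50 + [False] * 500,
        "score": np.concatenate([rng.beta(5, 2, 50), rng.beta(2, 5, 500)]),
    }
)

# Each hue group sums to 1 on its own, making the two shapes comparable
# despite the 10x class imbalance.
sns.histplot(
    data=df,
    x="score",
    hue="is_true",
    stat="probability",
    common_norm=False,
    binrange=(0, 1),
    binwidth=0.05,
)
```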
@ -1,309 +0,0 @@
|
|||||||
import random
|
|
||||||
from typing import (
|
|
||||||
Annotated,
|
|
||||||
Callable,
|
|
||||||
Iterable,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import pandas as pd
|
|
||||||
import seaborn as sns
|
|
||||||
from matplotlib.figure import Figure
|
|
||||||
from pydantic import Field
|
|
||||||
from sklearn import metrics
|
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
|
||||||
from batdetect2.core import Registry
|
|
||||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
|
||||||
from batdetect2.evaluate.metrics.detection import ClipEval
|
|
||||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
|
||||||
from batdetect2.plotting.detections import plot_clip_detections
|
|
||||||
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
|
||||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
|
||||||
|
|
||||||
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
|
|
||||||
|
|
||||||
detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
|
|
||||||
name="detection_plot"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PRCurveConfig(BasePlotConfig):
|
|
||||||
name: Literal["pr_curve"] = "pr_curve"
|
|
||||||
label: str = "pr_curve"
|
|
||||||
title: Optional[str] = "Detection Precision-Recall Curve"
|
|
||||||
ignore_non_predictions: bool = True
|
|
||||||
ignore_generic: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
class PRCurve(BasePlot):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*args,
|
|
||||||
ignore_non_predictions: bool = True,
|
|
||||||
ignore_generic: bool = True,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.ignore_non_predictions = ignore_non_predictions
|
|
||||||
self.ignore_generic = ignore_generic
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
clip_evals: Sequence[ClipEval],
|
|
||||||
) -> Iterable[Tuple[str, Figure]]:
|
|
||||||
y_true = []
|
|
||||||
y_score = []
|
|
||||||
num_positives = 0
|
|
||||||
|
|
||||||
for clip_eval in clip_evals:
|
|
||||||
for m in clip_eval.matches:
|
|
||||||
num_positives += int(m.is_ground_truth)
|
|
||||||
|
|
||||||
# Ignore matches that don't correspond to a prediction
|
|
||||||
if not m.is_prediction and self.ignore_non_predictions:
|
|
||||||
continue
|
|
||||||
|
|
||||||
y_true.append(m.is_ground_truth)
|
|
||||||
y_score.append(m.score)
|
|
||||||
|
|
||||||
precision, recall, thresholds = compute_precision_recall(
|
|
||||||
y_true,
|
|
||||||
y_score,
|
|
||||||
num_positives=num_positives,
|
|
||||||
)
|
|
||||||
|
|
||||||
fig = self.create_figure()
|
|
||||||
ax = fig.subplots()
|
|
||||||
|
|
||||||
plot_pr_curve(precision, recall, thresholds, ax=ax)
|
|
||||||
|
|
||||||
yield self.label, fig
|
|
||||||
|
|
||||||
@detection_plots.register(PRCurveConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: PRCurveConfig, targets: TargetProtocol):
|
|
||||||
return PRCurve.build(
|
|
||||||
config=config,
|
|
||||||
targets=targets,
|
|
||||||
ignore_non_predictions=config.ignore_non_predictions,
|
|
||||||
ignore_generic=config.ignore_generic,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ROCCurveConfig(BasePlotConfig):
|
|
||||||
name: Literal["roc_curve"] = "roc_curve"
|
|
||||||
label: str = "roc_curve"
|
|
||||||
title: Optional[str] = "Detection ROC Curve"
|
|
||||||
ignore_non_predictions: bool = True
|
|
||||||
ignore_generic: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
class ROCCurve(BasePlot):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*args,
|
|
||||||
ignore_non_predictions: bool = True,
|
|
||||||
ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        fpr, tpr, thresholds = metrics.roc_curve(
            y_true,
            y_score,
        )

        fig = self.create_figure()
        ax = fig.subplots()

        plot_roc_curve(fpr, tpr, thresholds, ax=ax)

        yield self.label, fig

    @detection_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ScoreDistributionPlotConfig(BasePlotConfig):
    name: Literal["score_distribution"] = "score_distribution"
    label: str = "score_distribution"
    title: Optional[str] = "Detection Score Distribution"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ScoreDistributionPlot(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        df = pd.DataFrame({"is_true": y_true, "score": y_score})

        fig = self.create_figure()
        ax = fig.subplots()

        sns.histplot(
            data=df,
            x="score",
            binwidth=0.025,
            binrange=(0, 1),
            hue="is_true",
            ax=ax,
            stat="probability",
            common_norm=False,
        )

        yield self.label, fig

    @detection_plots.register(ScoreDistributionPlotConfig)
    @staticmethod
    def from_config(
        config: ScoreDistributionPlotConfig, targets: TargetProtocol
    ):
        return ScoreDistributionPlot.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ExampleDetectionPlotConfig(BasePlotConfig):
    name: Literal["example_detection"] = "example_detection"
    label: str = "example_detection"
    title: Optional[str] = "Example Detection"
    figsize: tuple[int, int] = (10, 4)
    num_examples: int = 5
    threshold: float = 0.2
    audio: AudioConfig = Field(default_factory=AudioConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )


class ExampleDetectionPlot(BasePlot):
    def __init__(
        self,
        *args,
        num_examples: int = 5,
        threshold: float = 0.2,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.num_examples = num_examples
        self.audio_loader = audio_loader
        self.threshold = threshold
        self.preprocessor = preprocessor

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        sample = clip_evaluations

        if self.num_examples < len(sample):
            sample = random.sample(sample, self.num_examples)

        for num_example, clip_eval in enumerate(sample):
            fig = self.create_figure()
            ax = fig.subplots()

            plot_clip_detections(
                clip_eval,
                ax=ax,
                audio_loader=self.audio_loader,
                preprocessor=self.preprocessor,
            )

            yield f"{self.label}/example_{num_example}", fig

            plt.close(fig)

    @detection_plots.register(ExampleDetectionPlotConfig)
    @staticmethod
    def from_config(
        config: ExampleDetectionPlotConfig,
        targets: TargetProtocol,
    ):
        return ExampleDetectionPlot.build(
            config=config,
            targets=targets,
            num_examples=config.num_examples,
            threshold=config.threshold,
            audio_loader=build_audio_loader(config.audio),
            preprocessor=build_preprocessor(config.preprocessing),
        )


DetectionPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ScoreDistributionPlotConfig,
        ExampleDetectionPlotConfig,
    ],
    Field(discriminator="name"),
]


def build_detection_plotter(
    config: DetectionPlotConfig,
    targets: TargetProtocol,
) -> DetectionPlotter:
    return detection_plots.build(config, targets)
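For orientation, here is how one of these discriminated-union plot configs might be instantiated and turned into a plotter. This is a sketch, not repository code: it assumes pydantic v2's `TypeAdapter`, plus pre-existing `targets` (a `TargetProtocol`) and `clip_evaluations` objects.

```python
# Usage sketch (assumes pydantic v2 and existing `targets` /
# `clip_evaluations` objects; not part of the module above).
from pydantic import TypeAdapter

raw = {"name": "roc_curve", "ignore_generic": False}

# The `name` field is the discriminator that selects ROCCurveConfig
# from the DetectionPlotConfig union.
config = TypeAdapter(DetectionPlotConfig).validate_python(raw)

plotter = build_detection_plotter(config, targets)
for label, fig in plotter(clip_evaluations):
    fig.savefig(f"{label.replace('/', '_')}.png")
```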
@ -1,444 +0,0 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (
    Annotated,
    Callable,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol

TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]

top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
    name="top_class_plot"
)


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Top Class Precision-Recall Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class PRCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore ground truth sounds with unknown class
                    continue

                num_positives += int(m.is_ground_truth)

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        precision, recall, thresholds = compute_precision_recall(
            y_true,
            y_score,
            num_positives=num_positives,
        )

        fig = self.create_figure()
        ax = fig.subplots()

        plot_pr_curve(precision, recall, thresholds, ax=ax)

        yield self.label, fig

    @top_class_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "Top Class ROC Curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ROCCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore ground truth sounds with unknown class
                    continue

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        fpr, tpr, thresholds = metrics.roc_curve(
            y_true,
            y_score,
        )

        fig = self.create_figure()
        ax = fig.subplots()

        plot_roc_curve(fpr, tpr, thresholds, ax=ax)

        yield self.label, fig

    @top_class_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ConfusionMatrixConfig(BasePlotConfig):
    name: Literal["confusion_matrix"] = "confusion_matrix"
    title: Optional[str] = "Top Class Confusion Matrix"
    figsize: tuple[int, int] = (10, 10)
    label: str = "confusion_matrix"
    exclude_generic: bool = True
    exclude_noise: bool = False
    noise_class: str = "noise"
    normalize: Literal["true", "pred", "all", "none"] = "true"
    threshold: float = 0.2
    add_colorbar: bool = True
    cmap: str = "Blues"


class ConfusionMatrix(BasePlot):
    def __init__(
        self,
        *args,
        exclude_generic: bool = True,
        exclude_noise: bool = False,
        noise_class: str = "noise",
        add_colorbar: bool = True,
        normalize: Literal["true", "pred", "all", "none"] = "true",
        cmap: str = "Blues",
        threshold: float = 0.2,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.exclude_generic = exclude_generic
        self.exclude_noise = exclude_noise
        self.noise_class = noise_class
        self.normalize = normalize
        self.add_colorbar = add_colorbar
        self.threshold = threshold
        self.cmap = cmap

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        y_true: List[str] = []
        y_pred: List[str] = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                true_class = m.true_class
                pred_class = m.pred_class

                if not m.is_prediction and self.exclude_noise:
                    # Ignore matches that don't correspond to a prediction
                    continue

                if not m.is_ground_truth and self.exclude_noise:
                    # Ignore matches that don't correspond to a ground truth
                    continue

                if m.score < self.threshold:
                    if self.exclude_noise:
                        continue

                    pred_class = self.noise_class

                if m.is_generic:
                    if self.exclude_generic:
                        # Ignore ground truth sounds with unknown class
                        continue

                    true_class = self.targets.detection_class_name

                y_true.append(true_class or self.noise_class)
                y_pred.append(pred_class or self.noise_class)

        fig = self.create_figure()
        ax = fig.subplots()

        class_names = [*self.targets.class_names]

        if not self.exclude_generic:
            class_names.append(self.targets.detection_class_name)

        if not self.exclude_noise:
            class_names.append(self.noise_class)

        metrics.ConfusionMatrixDisplay.from_predictions(
            y_true,
            y_pred,
            labels=class_names,
            ax=ax,
            xticks_rotation="vertical",
            cmap=self.cmap,
            colorbar=self.add_colorbar,
            normalize=self.normalize if self.normalize != "none" else None,
            values_format=".2f",
        )

        yield self.label, fig

    @top_class_plots.register(ConfusionMatrixConfig)
    @staticmethod
    def from_config(config: ConfusionMatrixConfig, targets: TargetProtocol):
        return ConfusionMatrix.build(
            config=config,
            targets=targets,
            exclude_generic=config.exclude_generic,
            exclude_noise=config.exclude_noise,
            noise_class=config.noise_class,
            add_colorbar=config.add_colorbar,
            normalize=config.normalize,
            cmap=config.cmap,
            threshold=config.threshold,
        )


class ExampleClassificationPlotConfig(BasePlotConfig):
    name: Literal["example_classification"] = "example_classification"
    label: str = "example_classification"
    title: Optional[str] = "Example Classification"
    num_examples: int = 4
    threshold: float = 0.2
    audio: AudioConfig = Field(default_factory=AudioConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )


class ExampleClassificationPlot(BasePlot):
    def __init__(
        self,
        *args,
        num_examples: int = 4,
        threshold: float = 0.2,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.num_examples = num_examples
        self.audio_loader = audio_loader
        self.threshold = threshold
        self.preprocessor = preprocessor

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Iterable[Tuple[str, Figure]]:
        grouped = group_matches(clip_evaluations, threshold=self.threshold)

        for class_name, matches in grouped.items():
            true_positives: List[MatchEval] = get_binned_sample(
                matches.true_positives,
                n_examples=self.num_examples,
            )

            false_positives: List[MatchEval] = get_binned_sample(
                matches.false_positives,
                n_examples=self.num_examples,
            )

            false_negatives: List[MatchEval] = random.sample(
                matches.false_negatives,
                k=min(self.num_examples, len(matches.false_negatives)),
            )

            cross_triggers: List[MatchEval] = get_binned_sample(
                matches.cross_triggers, n_examples=self.num_examples
            )

            fig = self.create_figure()

            fig = plot_match_gallery(
                true_positives,
                false_positives,
                false_negatives,
                cross_triggers,
                preprocessor=self.preprocessor,
                audio_loader=self.audio_loader,
                n_examples=self.num_examples,
                fig=fig,
            )

            if self.title is not None:
                fig.suptitle(f"{self.title}: {class_name}")
            else:
                fig.suptitle(class_name)

            yield f"{self.label}/{class_name}", fig

            plt.close(fig)

    @top_class_plots.register(ExampleClassificationPlotConfig)
    @staticmethod
    def from_config(
        config: ExampleClassificationPlotConfig,
        targets: TargetProtocol,
    ):
        return ExampleClassificationPlot.build(
            config=config,
            targets=targets,
            num_examples=config.num_examples,
            threshold=config.threshold,
            audio_loader=build_audio_loader(config.audio),
            preprocessor=build_preprocessor(config.preprocessing),
        )


TopClassPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ConfusionMatrixConfig,
        ExampleClassificationPlotConfig,
    ],
    Field(discriminator="name"),
]


def build_top_class_plotter(
    config: TopClassPlotConfig,
    targets: TargetProtocol,
) -> TopClassPlotter:
    return top_class_plots.build(config, targets)


@dataclass
class ClassMatches:
    false_positives: List[MatchEval] = field(default_factory=list)
    false_negatives: List[MatchEval] = field(default_factory=list)
    true_positives: List[MatchEval] = field(default_factory=list)
    cross_triggers: List[MatchEval] = field(default_factory=list)


def group_matches(
    clip_evals: Sequence[ClipEval],
    threshold: float = 0.2,
) -> Dict[str, ClassMatches]:
    class_examples = defaultdict(ClassMatches)

    for clip_eval in clip_evals:
        for match in clip_eval.matches:
            gt_class = match.true_class
            pred_class = match.pred_class
            is_pred = match.score >= threshold

            if not is_pred and gt_class is not None:
                class_examples[gt_class].false_negatives.append(match)
                continue

            if not is_pred:
                continue

            if gt_class is None:
                class_examples[pred_class].false_positives.append(match)
                continue

            if gt_class != pred_class:
                class_examples[pred_class].cross_triggers.append(match)
                continue

            class_examples[gt_class].true_positives.append(match)

    return class_examples


def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
    if len(matches) < n_examples:
        return matches

    indices, pred_scores = zip(
        *[(index, match.score) for index, match in enumerate(matches)]
    )

    bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
    df = pd.DataFrame({"indices": indices, "bins": bins})
    sample = df.groupby("bins").sample(1)
    return [matches[ind] for ind in sample["indices"]]
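The `pd.qcut` call in `get_binned_sample` is worth a note: it stratifies the sample by prediction score, drawing one match per score quantile, so the resulting gallery spans low- to high-confidence examples rather than clustering at one end. A small self-contained illustration of the binning:

```python
# Self-contained illustration of the quantile binning used above.
import pandas as pd

scores = [0.05, 0.1, 0.2, 0.4, 0.5, 0.7, 0.9, 0.95]
bins = pd.qcut(scores, q=4, labels=False, duplicates="drop")
print(list(bins))  # [0, 0, 1, 1, 2, 2, 3, 3]
```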
@ -1,106 +0,0 @@
from typing import Annotated, Callable, Literal, Sequence, Union

import pandas as pd
from pydantic import Field
from soundevent.geometry import compute_bounds

from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipMatches

EvaluationTableGenerator = Callable[[Sequence[ClipMatches]], pd.DataFrame]


tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
    "evaluation_table"
)


class FullEvaluationTableConfig(BaseConfig):
    name: Literal["full_evaluation"] = "full_evaluation"


class FullEvaluationTable:
    def __call__(
        self, clip_evaluations: Sequence[ClipMatches]
    ) -> pd.DataFrame:
        return extract_matches_dataframe(clip_evaluations)

    @tables_registry.register(FullEvaluationTableConfig)
    @staticmethod
    def from_config(config: FullEvaluationTableConfig):
        return FullEvaluationTable()


def extract_matches_dataframe(
    clip_evaluations: Sequence[ClipMatches],
) -> pd.DataFrame:
    data = []

    for clip_evaluation in clip_evaluations:
        for match in clip_evaluation.matches:
            gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
            pred_start_time = pred_low_freq = pred_end_time = (
                pred_high_freq
            ) = None

            sound_event_annotation = match.sound_event_annotation

            if sound_event_annotation is not None:
                geometry = sound_event_annotation.sound_event.geometry
                assert geometry is not None
                gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
                    compute_bounds(geometry)
                )

            if match.pred_geometry is not None:
                (
                    pred_start_time,
                    pred_low_freq,
                    pred_end_time,
                    pred_high_freq,
                ) = compute_bounds(match.pred_geometry)

            data.append(
                {
                    ("recording", "uuid"): match.clip.recording.uuid,
                    ("clip", "uuid"): match.clip.uuid,
                    ("clip", "start_time"): match.clip.start_time,
                    ("clip", "end_time"): match.clip.end_time,
                    ("gt", "uuid"): match.sound_event_annotation.uuid
                    if match.sound_event_annotation is not None
                    else None,
                    ("gt", "class"): match.gt_class,
                    ("gt", "det"): match.gt_det,
                    ("gt", "start_time"): gt_start_time,
                    ("gt", "end_time"): gt_end_time,
                    ("gt", "low_freq"): gt_low_freq,
                    ("gt", "high_freq"): gt_high_freq,
                    ("pred", "score"): match.pred_score,
                    ("pred", "class"): match.top_class,
                    ("pred", "class_score"): match.top_class_score,
                    ("pred", "start_time"): pred_start_time,
                    ("pred", "end_time"): pred_end_time,
                    ("pred", "low_freq"): pred_low_freq,
                    ("pred", "high_freq"): pred_high_freq,
                    ("match", "affinity"): match.affinity,
                    **{
                        ("pred_class_score", key): value
                        for key, value in match.pred_class_scores.items()
                    },
                }
            )

    df = pd.DataFrame(data)
    df.columns = pd.MultiIndex.from_tuples(df.columns)  # type: ignore
    return df


EvaluationTableConfig = Annotated[
    Union[FullEvaluationTableConfig,], Field(discriminator="name")
]


def build_table_generator(
    config: EvaluationTableConfig,
) -> EvaluationTableGenerator:
    return tables_registry.build(config)
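The two-level column labels used by `extract_matches_dataframe` make the resulting table easy to slice by group (`gt`, `pred`, `match`, and so on). A minimal sketch of that pattern, with hypothetical values:

```python
# Minimal sketch of the two-level column pattern (hypothetical values).
import pandas as pd

df = pd.DataFrame([{("gt", "class"): "pippip", ("pred", "score"): 0.9}])
df.columns = pd.MultiIndex.from_tuples(df.columns)
print(df["gt"]["class"].iloc[0])  # pippip
```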
@ -1,39 +0,0 @@
from typing import Annotated, Optional, Union

from pydantic import Field

from batdetect2.evaluate.tasks.base import tasks_registry
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.clip_classification import (
    ClipClassificationTaskConfig,
)
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, TargetProtocol

__all__ = [
    "TaskConfig",
    "build_task",
]


TaskConfig = Annotated[
    Union[
        ClassificationTaskConfig,
        DetectionTaskConfig,
        ClipDetectionTaskConfig,
        ClipClassificationTaskConfig,
        TopClassDetectionTaskConfig,
    ],
    Field(discriminator="name"),
]


def build_task(
    config: TaskConfig,
    targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
    targets = targets or build_targets()
    return tasks_registry.build(config, targets)
@ -1,175 +0,0 @@
from typing import (
    Callable,
    Dict,
    Generic,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    TypeVar,
)

from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds

from batdetect2.core import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.match import (
    MatchConfig,
    StartTimeMatchConfig,
    build_matcher,
)
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol

__all__ = [
    "BaseTaskConfig",
    "BaseTask",
]

tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
    "tasks"
)


T_Output = TypeVar("T_Output")


class BaseTaskConfig(BaseConfig):
    prefix: str
    ignore_start_end: float = 0.01
    matching_strategy: MatchConfig = Field(
        default_factory=StartTimeMatchConfig
    )


class BaseTask(EvaluatorProtocol, Generic[T_Output]):
    targets: TargetProtocol

    matcher: MatcherProtocol

    metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]

    plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]

    ignore_start_end: float

    prefix: str

    def __init__(
        self,
        matcher: MatcherProtocol,
        targets: TargetProtocol,
        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
        prefix: str,
        ignore_start_end: float = 0.01,
        plots: Optional[
            List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
        ] = None,
    ):
        self.matcher = matcher
        self.metrics = metrics
        self.plots = plots or []
        self.targets = targets
        self.prefix = prefix
        self.ignore_start_end = ignore_start_end

    def compute_metrics(
        self,
        eval_outputs: List[T_Output],
    ) -> Dict[str, float]:
        scores = [metric(eval_outputs) for metric in self.metrics]
        return {
            f"{self.prefix}/{name}": score
            for metric_output in scores
            for name, score in metric_output.items()
        }

    def generate_plots(
        self, eval_outputs: List[T_Output]
    ) -> Iterable[Tuple[str, Figure]]:
        for plot in self.plots:
            for name, fig in plot(eval_outputs):
                yield f"{self.prefix}/{name}", fig

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[T_Output]:
        return [
            self.evaluate_clip(clip_annotation, preds)
            for clip_annotation, preds in zip(clip_annotations, predictions)
        ]

    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> T_Output: ...

    def include_sound_event_annotation(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
        clip: data.Clip,
    ) -> bool:
        if not self.targets.filter(sound_event_annotation):
            return False

        geometry = sound_event_annotation.sound_event.geometry
        if geometry is None:
            return False

        return is_in_bounds(
            geometry,
            clip,
            self.ignore_start_end,
        )

    def include_prediction(
        self,
        prediction: RawPrediction,
        clip: data.Clip,
    ) -> bool:
        return is_in_bounds(
            prediction.geometry,
            clip,
            self.ignore_start_end,
        )

    @classmethod
    def build(
        cls,
        config: BaseTaskConfig,
        targets: TargetProtocol,
        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
        plots: Optional[
            List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
        ] = None,
        **kwargs,
    ):
        matcher = build_matcher(config.matching_strategy)
        return cls(
            matcher=matcher,
            targets=targets,
            metrics=metrics,
            plots=plots,
            prefix=config.prefix,
            ignore_start_end=config.ignore_start_end,
            **kwargs,
        )


def is_in_bounds(
    geometry: data.Geometry,
    clip: data.Clip,
    buffer: float,
) -> bool:
    start_time = compute_bounds(geometry)[0]
    return (start_time >= clip.start_time + buffer) and (
        start_time <= clip.end_time - buffer
    )
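One design point in `BaseTask` worth spelling out: every metric and plot name is namespaced under the task's `prefix`, so several tasks can log into one flat dictionary without key collisions. A standalone sketch of the same reduction, with no batdetect2 dependencies:

```python
# Standalone sketch of the prefix namespacing in BaseTask.compute_metrics.
from typing import Dict

def namespace(prefix: str, metric_outputs) -> Dict[str, float]:
    return {
        f"{prefix}/{name}": score
        for output in metric_outputs
        for name, score in output.items()
    }

print(namespace("detection", [{"average_precision": 0.8}]))
# {'detection/average_precision': 0.8}
```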
@ -1,149 +0,0 @@
from typing import (
    List,
    Literal,
    Sequence,
)

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.classification import (
    ClassificationAveragePrecisionConfig,
    ClassificationMetricConfig,
    ClipEval,
    MatchEval,
    build_classification_metric,
)
from batdetect2.evaluate.plots.classification import (
    ClassificationPlotConfig,
    build_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class ClassificationTaskConfig(BaseTaskConfig):
    name: Literal["sound_event_classification"] = "sound_event_classification"
    prefix: str = "classification"
    metrics: List[ClassificationMetricConfig] = Field(
        default_factory=lambda: [ClassificationAveragePrecisionConfig()]
    )
    plots: List[ClassificationPlotConfig] = Field(default_factory=list)
    include_generics: bool = True


class ClassificationTask(BaseTask[ClipEval]):
    def __init__(
        self,
        *args,
        include_generics: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.include_generics = include_generics

    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        preds = [
            pred for pred in predictions if self.include_prediction(pred, clip)
        ]

        all_gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.include_sound_event_annotation(sound_event, clip)
        ]

        per_class_matches = {}

        for class_name in self.targets.class_names:
            class_idx = self.targets.class_names.index(class_name)

            # Only match to targets of the given class
            gts = [
                sound_event
                for sound_event in all_gts
                if self.is_class(sound_event, class_name)
            ]
            scores = [float(pred.class_scores[class_idx]) for pred in preds]

            matches = []

            for pred_idx, gt_idx, _ in self.matcher(
                ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
                predictions=[pred.geometry for pred in preds],
                scores=scores,
            ):
                gt = gts[gt_idx] if gt_idx is not None else None
                pred = preds[pred_idx] if pred_idx is not None else None

                true_class = (
                    self.targets.encode_class(gt) if gt is not None else None
                )

                score = (
                    float(pred.class_scores[class_idx])
                    if pred is not None
                    else 0
                )

                matches.append(
                    MatchEval(
                        clip=clip,
                        gt=gt,
                        pred=pred,
                        is_prediction=pred is not None,
                        is_ground_truth=gt is not None,
                        is_generic=gt is not None and true_class is None,
                        true_class=true_class,
                        score=score,
                    )
                )

            per_class_matches[class_name] = matches

        return ClipEval(clip=clip, matches=per_class_matches)

    def is_class(
        self,
        sound_event: data.SoundEventAnnotation,
        class_name: str,
    ) -> bool:
        sound_event_class = self.targets.encode_class(sound_event)

        if sound_event_class is None and self.include_generics:
            # Sound events that are generic could be of the given class
            return True

        return sound_event_class == class_name

    @tasks_registry.register(ClassificationTaskConfig)
    @staticmethod
    def from_config(
        config: ClassificationTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [
            build_classification_metric(metric, targets)
            for metric in config.metrics
        ]
        plots = [
            build_classification_plotter(plot, targets)
            for plot in config.plots
        ]
        return ClassificationTask.build(
            config=config,
            plots=plots,
            targets=targets,
            metrics=metrics,
            include_generics=config.include_generics,
        )
@ -1,85 +0,0 @@
from collections import defaultdict
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.clip_classification import (
    ClipClassificationAveragePrecisionConfig,
    ClipClassificationMetricConfig,
    ClipEval,
    build_clip_metric,
)
from batdetect2.evaluate.plots.clip_classification import (
    ClipClassificationPlotConfig,
    build_clip_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class ClipClassificationTaskConfig(BaseTaskConfig):
    name: Literal["clip_classification"] = "clip_classification"
    prefix: str = "clip_classification"
    metrics: List[ClipClassificationMetricConfig] = Field(
        default_factory=lambda: [
            ClipClassificationAveragePrecisionConfig(),
        ]
    )
    plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)


class ClipClassificationTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gt_classes = set()
        for sound_event in clip_annotation.sound_events:
            if not self.include_sound_event_annotation(sound_event, clip):
                continue

            class_name = self.targets.encode_class(sound_event)

            if class_name is None:
                continue

            gt_classes.add(class_name)

        pred_scores = defaultdict(float)
        for pred in predictions:
            if not self.include_prediction(pred, clip):
                continue

            for class_idx, class_name in enumerate(self.targets.class_names):
                pred_scores[class_name] = max(
                    float(pred.class_scores[class_idx]),
                    pred_scores[class_name],
                )

        return ClipEval(true_classes=gt_classes, class_scores=pred_scores)

    @tasks_registry.register(ClipClassificationTaskConfig)
    @staticmethod
    def from_config(
        config: ClipClassificationTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_clip_metric(metric) for metric in config.metrics]
        plots = [
            build_clip_classification_plotter(plot, targets)
            for plot in config.plots
        ]
        return ClipClassificationTask.build(
            config=config,
            plots=plots,
            metrics=metrics,
            targets=targets,
        )
@ -1,76 +0,0 @@
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.clip_detection import (
    ClipDetectionAveragePrecisionConfig,
    ClipDetectionMetricConfig,
    ClipEval,
    build_clip_metric,
)
from batdetect2.evaluate.plots.clip_detection import (
    ClipDetectionPlotConfig,
    build_clip_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class ClipDetectionTaskConfig(BaseTaskConfig):
    name: Literal["clip_detection"] = "clip_detection"
    prefix: str = "clip_detection"
    metrics: List[ClipDetectionMetricConfig] = Field(
        default_factory=lambda: [
            ClipDetectionAveragePrecisionConfig(),
        ]
    )
    plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)


class ClipDetectionTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gt_det = any(
            self.include_sound_event_annotation(sound_event, clip)
            for sound_event in clip_annotation.sound_events
        )

        pred_score = 0
        for pred in predictions:
            if not self.include_prediction(pred, clip):
                continue

            pred_score = max(pred_score, pred.detection_score)

        return ClipEval(
            gt_det=gt_det,
            score=pred_score,
        )

    @tasks_registry.register(ClipDetectionTaskConfig)
    @staticmethod
    def from_config(
        config: ClipDetectionTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_clip_metric(metric) for metric in config.metrics]
        plots = [
            build_clip_detection_plotter(plot, targets)
            for plot in config.plots
        ]
        return ClipDetectionTask.build(
            config=config,
            metrics=metrics,
            targets=targets,
            plots=plots,
        )
@ -1,88 +0,0 @@
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.detection import (
    ClipEval,
    DetectionAveragePrecisionConfig,
    DetectionMetricConfig,
    MatchEval,
    build_detection_metric,
)
from batdetect2.evaluate.plots.detection import (
    DetectionPlotConfig,
    build_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class DetectionTaskConfig(BaseTaskConfig):
    name: Literal["sound_event_detection"] = "sound_event_detection"
    prefix: str = "detection"
    metrics: List[DetectionMetricConfig] = Field(
        default_factory=lambda: [DetectionAveragePrecisionConfig()]
    )
    plots: List[DetectionPlotConfig] = Field(default_factory=list)


class DetectionTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.include_sound_event_annotation(sound_event, clip)
        ]
        preds = [
            pred for pred in predictions if self.include_prediction(pred, clip)
        ]
        scores = [pred.detection_score for pred in preds]

        matches = []
        for pred_idx, gt_idx, _ in self.matcher(
            ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
            predictions=[pred.geometry for pred in preds],
            scores=scores,
        ):
            gt = gts[gt_idx] if gt_idx is not None else None
            pred = preds[pred_idx] if pred_idx is not None else None

            matches.append(
                MatchEval(
                    gt=gt,
                    pred=pred,
                    is_prediction=pred is not None,
                    is_ground_truth=gt is not None,
                    score=pred.detection_score if pred is not None else 0,
                )
            )

        return ClipEval(clip=clip, matches=matches)

    @tasks_registry.register(DetectionTaskConfig)
    @staticmethod
    def from_config(
        config: DetectionTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_detection_metric(metric) for metric in config.metrics]
        plots = [
            build_detection_plotter(plot, targets) for plot in config.plots
        ]
        return DetectionTask.build(
            config=config,
            metrics=metrics,
            targets=targets,
            plots=plots,
        )
@ -1,111 +0,0 @@
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.top_class import (
    ClipEval,
    MatchEval,
    TopClassAveragePrecisionConfig,
    TopClassMetricConfig,
    build_top_class_metric,
)
from batdetect2.evaluate.plots.top_class import (
    TopClassPlotConfig,
    build_top_class_plotter,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class TopClassDetectionTaskConfig(BaseTaskConfig):
    name: Literal["top_class_detection"] = "top_class_detection"
    prefix: str = "top_class"
    metrics: List[TopClassMetricConfig] = Field(
        default_factory=lambda: [TopClassAveragePrecisionConfig()]
    )
    plots: List[TopClassPlotConfig] = Field(default_factory=list)


class TopClassDetectionTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.include_sound_event_annotation(sound_event, clip)
        ]
        preds = [
            pred for pred in predictions if self.include_prediction(pred, clip)
        ]
        # Take the highest score for each prediction
        scores = [pred.class_scores.max() for pred in preds]

        matches = []
        for pred_idx, gt_idx, _ in self.matcher(
            ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
            predictions=[pred.geometry for pred in preds],
            scores=scores,
        ):
            gt = gts[gt_idx] if gt_idx is not None else None
            pred = preds[pred_idx] if pred_idx is not None else None

            true_class = (
                self.targets.encode_class(gt) if gt is not None else None
            )

            class_idx = (
                pred.class_scores.argmax() if pred is not None else None
            )

            score = (
                float(pred.class_scores[class_idx]) if pred is not None else 0
            )

            pred_class = (
                self.targets.class_names[class_idx]
                if class_idx is not None
                else None
            )

            matches.append(
                MatchEval(
                    clip=clip,
                    gt=gt,
                    pred=pred,
                    is_ground_truth=gt is not None,
                    is_prediction=pred is not None,
                    true_class=true_class,
                    is_generic=gt is not None and true_class is None,
                    pred_class=pred_class,
                    score=score,
                )
            )

        return ClipEval(clip=clip, matches=matches)

    @tasks_registry.register(TopClassDetectionTaskConfig)
    @staticmethod
    def from_config(
        config: TopClassDetectionTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_top_class_metric(metric) for metric in config.metrics]
        plots = [
            build_top_class_plotter(plot, targets) for plot in config.plots
        ]
        return TopClassDetectionTask.build(
            config=config,
            plots=plots,
            metrics=metrics,
            targets=targets,
        )
src/batdetect2/evaluate/types.py (new file, 40 lines)
@ -0,0 +1,40 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol

from soundevent import data

__all__ = [
    "MetricsProtocol",
    "MatchEvaluation",
]


@dataclass
class MatchEvaluation:
    match: data.Match

    gt_det: bool
    gt_class: Optional[str]

    pred_score: float
    pred_class_scores: Dict[str, float]

    @property
    def pred_class(self) -> Optional[str]:
        if not self.pred_class_scores:
            return None

        return max(self.pred_class_scores, key=self.pred_class_scores.get)  # type: ignore

    @property
    def pred_class_score(self) -> float:
        pred_class = self.pred_class

        if pred_class is None:
            return 0

        return self.pred_class_scores[pred_class]


class MetricsProtocol(Protocol):
    def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...
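`MatchEvaluation.pred_class` is simply an argmax over the per-class score dictionary. A tiny illustration with hypothetical class names and scores:

```python
# The pred_class property reduces to an argmax over a score dict
# (hypothetical values).
scores = {"pippip": 0.7, "myomys": 0.2}
print(max(scores, key=scores.get))  # pippip
```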
@ -1,10 +0,0 @@
from batdetect2.inference.batch import process_file_list, run_batch_inference
from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.config import InferenceConfig

__all__ = [
    "process_file_list",
    "run_batch_inference",
    "InferenceConfig",
    "get_clips_from_files",
]
@ -1,88 +0,0 @@
from typing import TYPE_CHECKING, List, Optional, Sequence

from lightning import Trainer
from soundevent import data

from batdetect2.audio.loader import build_audio_loader
from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model
from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.typing.postprocess import BatDetect2Prediction

if TYPE_CHECKING:
    from batdetect2.config import BatDetect2Config
    from batdetect2.typing import (
        AudioLoader,
        PreprocessorProtocol,
        TargetProtocol,
    )


def run_batch_inference(
    model,
    clips: Sequence[data.Clip],
    targets: Optional["TargetProtocol"] = None,
    audio_loader: Optional["AudioLoader"] = None,
    preprocessor: Optional["PreprocessorProtocol"] = None,
    config: Optional["BatDetect2Config"] = None,
    num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
    from batdetect2.config import BatDetect2Config

    config = config or BatDetect2Config()

    audio_loader = audio_loader or build_audio_loader()

    preprocessor = preprocessor or build_preprocessor(
        input_samplerate=audio_loader.samplerate,
    )

    targets = targets or build_targets()

    loader = build_inference_loader(
        clips,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        config=config.inference.loader,
        num_workers=num_workers,
    )

    module = InferenceModule(model)
    trainer = Trainer(enable_checkpointing=False, logger=False)
    outputs = trainer.predict(module, loader)
    return [
        clip_prediction
        for clip_predictions in outputs  # type: ignore
        for clip_prediction in clip_predictions
    ]


def process_file_list(
    model: Model,
    paths: Sequence[data.PathLike],
    config: "BatDetect2Config",
    targets: Optional["TargetProtocol"] = None,
    audio_loader: Optional["AudioLoader"] = None,
    preprocessor: Optional["PreprocessorProtocol"] = None,
    num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
    clip_config = config.inference.clipping
    clips = get_clips_from_files(
        paths,
        duration=clip_config.duration,
        overlap=clip_config.overlap,
        max_empty=clip_config.max_empty,
        discard_empty=clip_config.discard_empty,
    )
    return run_batch_inference(
        model,
        clips,
        targets=targets,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        config=config,
        num_workers=num_workers,
    )
@ -1,75 +0,0 @@
from typing import List, Sequence
from uuid import uuid5

import numpy as np
from soundevent import data


def get_clips_from_files(
    paths: Sequence[data.PathLike],
    duration: float,
    overlap: float = 0.0,
    max_empty: float = 0.0,
    discard_empty: bool = True,
    compute_hash: bool = False,
) -> List[data.Clip]:
    clips: List[data.Clip] = []

    for path in paths:
        recording = data.Recording.from_file(path, compute_hash=compute_hash)
        clips.extend(
            get_recording_clips(
                recording,
                duration,
                overlap=overlap,
                max_empty=max_empty,
                discard_empty=discard_empty,
            )
        )

    return clips


def get_recording_clips(
    recording: data.Recording,
    duration: float,
    overlap: float = 0.0,
    max_empty: float = 0.0,
    discard_empty: bool = True,
) -> Sequence[data.Clip]:
    start_time = 0
    total_duration = recording.duration
    hop = duration * (1 - overlap)

    num_clips = int(np.ceil(total_duration / hop))

    if num_clips == 0:
        # This should only happen if the recording's duration is zero,
        # which should never happen in practice, but just in case...
        return []

    clips = []
    for i in range(num_clips):
        start = start_time + i * hop
        end = start + duration

        if end > total_duration:
            empty_duration = end - total_duration

            if empty_duration > max_empty and discard_empty:
                # Discard clips that contain too much empty space
                continue

        clips.append(
            data.Clip(
                uuid=uuid5(recording.uuid, f"{start}_{end}"),
                recording=recording,
                start_time=start,
                end_time=end,
            )
        )

    if discard_empty:
        clips = [clip for clip in clips if clip.duration > max_empty]

    return clips
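To make the windowing arithmetic concrete, a worked example with assumed numbers (a 1.3 s recording split into 0.5 s clips with no overlap; not repository defaults):

```python
# Worked example of the clip windowing in get_recording_clips
# (assumed numbers, not repository defaults).
import numpy as np

total, duration, overlap = 1.3, 0.5, 0.0
hop = duration * (1 - overlap)         # 0.5
num_clips = int(np.ceil(total / hop))  # 3
print([(i * hop, i * hop + duration) for i in range(num_clips)])
# [(0.0, 0.5), (0.5, 1.0), (1.0, 1.5)]
# The last window overruns the recording by 0.2 s, so whether it is
# kept depends on max_empty and discard_empty.
```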
@ -1,21 +0,0 @@
from pydantic import Field

from batdetect2.core.configs import BaseConfig
from batdetect2.inference.dataset import InferenceLoaderConfig

__all__ = ["InferenceConfig"]


class ClipingConfig(BaseConfig):
    enabled: bool = True
    duration: float = 0.5
    overlap: float = 0.0
    max_empty: float = 0.0
    discard_empty: bool = True


class InferenceConfig(BaseConfig):
    loader: InferenceLoaderConfig = Field(
        default_factory=InferenceLoaderConfig
    )
    clipping: ClipingConfig = Field(default_factory=ClipingConfig)
@@ -1,120 +0,0 @@
-from typing import List, NamedTuple, Optional, Sequence
-
-import torch
-from loguru import logger
-from soundevent import data
-from torch.utils.data import DataLoader, Dataset
-
-from batdetect2.audio import build_audio_loader
-from batdetect2.core import BaseConfig
-from batdetect2.core.arrays import adjust_width
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.typing import AudioLoader, PreprocessorProtocol
-
-__all__ = [
-    "InferenceDataset",
-    "build_inference_dataset",
-    "build_inference_loader",
-]
-
-
-DEFAULT_INFERENCE_CLIP_DURATION = 0.512
-
-
-class DatasetItem(NamedTuple):
-    spec: torch.Tensor
-    idx: torch.Tensor
-    start_time: torch.Tensor
-    end_time: torch.Tensor
-
-
-class InferenceDataset(Dataset[DatasetItem]):
-    clips: List[data.Clip]
-
-    def __init__(
-        self,
-        clips: Sequence[data.Clip],
-        audio_loader: AudioLoader,
-        preprocessor: PreprocessorProtocol,
-        audio_dir: Optional[data.PathLike] = None,
-    ):
-        self.clips = list(clips)
-        self.preprocessor = preprocessor
-        self.audio_loader = audio_loader
-        self.audio_dir = audio_dir
-
-    def __len__(self):
-        return len(self.clips)
-
-    def __getitem__(self, idx: int) -> DatasetItem:
-        clip = self.clips[idx]
-        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
-        wav_tensor = torch.tensor(wav).unsqueeze(0)
-        spectrogram = self.preprocessor(wav_tensor)
-        return DatasetItem(
-            spec=spectrogram,
-            idx=torch.tensor(idx),
-            start_time=torch.tensor(clip.start_time),
-            end_time=torch.tensor(clip.end_time),
-        )
-
-
-class InferenceLoaderConfig(BaseConfig):
-    num_workers: int = 0
-    batch_size: int = 8
-
-
-def build_inference_loader(
-    clips: Sequence[data.Clip],
-    audio_loader: Optional[AudioLoader] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    config: Optional[InferenceLoaderConfig] = None,
-    num_workers: Optional[int] = None,
-) -> DataLoader[DatasetItem]:
-    logger.info("Building inference data loader...")
-    config = config or InferenceLoaderConfig()
-
-    inference_dataset = build_inference_dataset(
-        clips,
-        audio_loader=audio_loader,
-        preprocessor=preprocessor,
-    )
-
-    num_workers = num_workers or config.num_workers
-    return DataLoader(
-        inference_dataset,
-        batch_size=config.batch_size,
-        shuffle=False,
-        num_workers=config.num_workers,
-        collate_fn=_collate_fn,
-    )
-
-
-def build_inference_dataset(
-    clips: Sequence[data.Clip],
-    audio_loader: Optional[AudioLoader] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-) -> InferenceDataset:
-    if audio_loader is None:
-        audio_loader = build_audio_loader()
-
-    if preprocessor is None:
-        preprocessor = build_preprocessor()
-
-    return InferenceDataset(
-        clips,
-        audio_loader=audio_loader,
-        preprocessor=preprocessor,
-    )
-
-
-def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
-    max_width = max(item.spec.shape[-1] for item in batch)
-    return DatasetItem(
-        spec=torch.stack(
-            [adjust_width(item.spec, max_width) for item in batch]
-        ),
-        idx=torch.stack([item.idx for item in batch]),
-        start_time=torch.stack([item.start_time for item in batch]),
-        end_time=torch.stack([item.end_time for item in batch]),
-    )
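A sketch of how the deleted loader builder was driven. Here `clips` is assumed to be a list of soundevent `data.Clip` objects produced by the clipping helper shown earlier:

```python
# Hypothetical driver for the deleted build_inference_loader above.
loader = build_inference_loader(
    clips,
    config=InferenceLoaderConfig(batch_size=4, num_workers=2),
)
for batch in loader:
    # batch.spec is padded to a common width by _collate_fn
    print(batch.spec.shape, batch.idx.tolist())
```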
@@ -1,52 +0,0 @@
-from typing import Sequence
-
-from lightning import LightningModule
-from torch.utils.data import DataLoader
-
-from batdetect2.inference.dataset import DatasetItem, InferenceDataset
-from batdetect2.models import Model
-from batdetect2.postprocess import to_raw_predictions
-from batdetect2.typing.postprocess import BatDetect2Prediction
-
-
-class InferenceModule(LightningModule):
-    def __init__(self, model: Model):
-        super().__init__()
-        self.model = model
-
-    def predict_step(
-        self,
-        batch: DatasetItem,
-        batch_idx: int,
-        dataloader_idx: int = 0,
-    ) -> Sequence[BatDetect2Prediction]:
-        dataset = self.get_dataset()
-
-        clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
-
-        outputs = self.model.detector(batch.spec)
-
-        clip_detections = self.model.postprocessor(
-            outputs,
-            start_times=[clip.start_time for clip in clips],
-        )
-
-        predictions = [
-            BatDetect2Prediction(
-                clip=clip,
-                predictions=to_raw_predictions(
-                    clip_dets.numpy(),
-                    targets=self.model.targets,
-                ),
-            )
-            for clip, clip_dets in zip(clips, clip_detections)
-        ]
-
-        return predictions
-
-    def get_dataset(self) -> InferenceDataset:
-        dataloaders = self.trainer.predict_dataloaders
-        assert isinstance(dataloaders, DataLoader)
-        dataset = dataloaders.dataset
-        assert isinstance(dataset, InferenceDataset)
-        return dataset
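The deleted `InferenceModule` was designed to run under Lightning's prediction loop. A sketch, assuming `model` and `loader` come from the builders shown above:

```python
# Hypothetical end-to-end use of the deleted InferenceModule.
from lightning import Trainer

module = InferenceModule(model)
trainer = Trainer(accelerator="auto", logger=False)
# predict_step recovers each clip from the dataset via batch.idx
predictions = trainer.predict(module, dataloaders=loader)
```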
@@ -1,314 +0,0 @@
-import io
-from collections.abc import Callable
-from functools import partial
-from pathlib import Path
-from typing import (
-    Annotated,
-    Any,
-    Dict,
-    Generic,
-    Literal,
-    Optional,
-    Protocol,
-    TypeVar,
-    Union,
-)
-
-import numpy as np
-import pandas as pd
-from lightning.pytorch.loggers import (
-    CSVLogger,
-    Logger,
-    MLFlowLogger,
-    TensorBoardLogger,
-)
-from loguru import logger
-from matplotlib.figure import Figure
-from pydantic import Field
-from soundevent import data
-
-from batdetect2.core.configs import BaseConfig
-
-DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
-
-
-class BaseLoggerConfig(BaseConfig):
-    log_dir: Path = DEFAULT_LOGS_DIR
-    experiment_name: Optional[str] = None
-    run_name: Optional[str] = None
-
-
-class DVCLiveConfig(BaseLoggerConfig):
-    name: Literal["dvclive"] = "dvclive"
-    prefix: str = ""
-    log_model: Union[bool, Literal["all"]] = False
-    monitor_system: bool = False
-
-
-class CSVLoggerConfig(BaseLoggerConfig):
-    name: Literal["csv"] = "csv"
-    flush_logs_every_n_steps: int = 100
-
-
-class TensorBoardLoggerConfig(BaseLoggerConfig):
-    name: Literal["tensorboard"] = "tensorboard"
-    log_graph: bool = False
-
-
-class MLFlowLoggerConfig(BaseLoggerConfig):
-    name: Literal["mlflow"] = "mlflow"
-    tracking_uri: Optional[str] = "http://localhost:5000"
-    tags: Optional[dict[str, Any]] = None
-    log_model: bool = False
-
-
-LoggerConfig = Annotated[
-    Union[
-        DVCLiveConfig,
-        CSVLoggerConfig,
-        TensorBoardLoggerConfig,
-        MLFlowLoggerConfig,
-    ],
-    Field(discriminator="name"),
-]
-
-
-T = TypeVar("T", bound=LoggerConfig, contravariant=True)
-
-
-class LoggerBuilder(Protocol, Generic[T]):
-    def __call__(
-        self,
-        config: T,
-        log_dir: Optional[Path] = None,
-        experiment_name: Optional[str] = None,
-        run_name: Optional[str] = None,
-    ) -> Logger: ...
-
-
-def create_dvclive_logger(
-    config: DVCLiveConfig,
-    log_dir: Optional[Path] = None,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
-) -> Logger:
-    try:
-        from dvclive.lightning import DVCLiveLogger  # type: ignore
-    except ImportError as error:
-        raise ValueError(
-            "DVCLive is not installed and cannot be used for logging"
-            "Make sure you have it installed by running `pip install dvclive`"
-            "or `uv add dvclive`"
-        ) from error
-
-    return DVCLiveLogger(
-        dir=log_dir if log_dir is not None else config.log_dir,
-        run_name=run_name if run_name is not None else config.run_name,
-        experiment=experiment_name
-        if experiment_name is not None
-        else config.experiment_name,
-        prefix=config.prefix,
-        log_model=config.log_model,
-        monitor_system=config.monitor_system,
-    )
-
-
-def create_csv_logger(
-    config: CSVLoggerConfig,
-    log_dir: Optional[Path] = None,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
-) -> Logger:
-    from lightning.pytorch.loggers import CSVLogger
-
-    if log_dir is None:
-        log_dir = Path(config.log_dir)
-
-    if run_name is None:
-        run_name = config.run_name
-
-    if experiment_name is None:
-        experiment_name = config.experiment_name
-
-    name = run_name
-
-    if run_name is not None and experiment_name is not None:
-        name = str(Path(experiment_name) / run_name)
-
-    return CSVLogger(
-        save_dir=str(log_dir),
-        name=name,
-        flush_logs_every_n_steps=config.flush_logs_every_n_steps,
-    )
-
-
-def create_tensorboard_logger(
-    config: TensorBoardLoggerConfig,
-    log_dir: Optional[Path] = None,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
-) -> Logger:
-    from lightning.pytorch.loggers import TensorBoardLogger
-
-    if log_dir is None:
-        log_dir = Path(config.log_dir)
-
-    if run_name is None:
-        run_name = config.run_name
-
-    if experiment_name is None:
-        experiment_name = config.experiment_name
-
-    name = run_name
-
-    if name is None:
-        name = experiment_name
-
-    if run_name is not None and experiment_name is not None:
-        name = str(Path(experiment_name) / run_name)
-
-    return TensorBoardLogger(
-        save_dir=str(log_dir),
-        name=name,
-        log_graph=config.log_graph,
-    )
-
-
-def create_mlflow_logger(
-    config: MLFlowLoggerConfig,
-    log_dir: Optional[data.PathLike] = None,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
-) -> Logger:
-    try:
-        from lightning.pytorch.loggers import MLFlowLogger
-    except ImportError as error:
-        raise ValueError(
-            "MLFlow is not installed and cannot be used for logging. "
-            "Make sure you have it installed by running `pip install mlflow` "
-            "or `uv add mlflow`"
-        ) from error
-
-    if experiment_name is None:
-        experiment_name = config.experiment_name or "Default"
-
-    if log_dir is None:
-        log_dir = config.log_dir
-
-    return MLFlowLogger(
-        experiment_name=experiment_name
-        if experiment_name is not None
-        else config.experiment_name,
-        run_name=run_name if run_name is not None else config.run_name,
-        save_dir=str(log_dir),
-        tracking_uri=config.tracking_uri,
-        tags=config.tags,
-        log_model=config.log_model,
-    )
-
-
-LOGGER_FACTORY: Dict[str, LoggerBuilder] = {
-    "dvclive": create_dvclive_logger,
-    "csv": create_csv_logger,
-    "tensorboard": create_tensorboard_logger,
-    "mlflow": create_mlflow_logger,
-}
-
-
-def build_logger(
-    config: LoggerConfig,
-    log_dir: Optional[Path] = None,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
-) -> Logger:
-    logger.opt(lazy=True).debug(
-        "Building logger with config: \n{}",
-        lambda: config.to_yaml_string(),
-    )
-
-    logger_type = config.name
-    if logger_type not in LOGGER_FACTORY:
-        raise ValueError(f"Unknown logger type: {logger_type}")
-
-    creation_func = LOGGER_FACTORY[logger_type]
-    return creation_func(
-        config,
-        log_dir=log_dir,
-        experiment_name=experiment_name,
-        run_name=run_name,
-    )
-
-
-PlotLogger = Callable[[str, Figure, int], None]
-
-
-def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
-    if isinstance(logger, TensorBoardLogger):
-        return logger.experiment.add_figure
-
-    if isinstance(logger, MLFlowLogger):
-
-        def plot_figure(name, figure, step):
-            image = _convert_figure_to_array(figure)
-            name = name.replace("/", "_")
-            return logger.experiment.log_image(
-                logger.run_id,
-                image,
-                key=name,
-                step=step,
-            )
-
-        return plot_figure
-
-    if isinstance(logger, CSVLogger):
-        return partial(save_figure, dir=Path(logger.log_dir))
-
-
-TableLogger = Callable[[str, pd.DataFrame, int], None]
-
-
-def get_table_logger(logger: Logger) -> Optional[TableLogger]:
-    if isinstance(logger, TensorBoardLogger):
-        return partial(save_table, dir=Path(logger.log_dir))
-
-    if isinstance(logger, MLFlowLogger):
-
-        def plot_figure(name: str, df: pd.DataFrame, step: int):
-            return logger.experiment.log_table(
-                logger.run_id,
-                data=df,
-                artifact_file=f"{name}_step_{step}.json",
-            )
-
-        return plot_figure
-
-    if isinstance(logger, CSVLogger):
-        return partial(save_table, dir=Path(logger.log_dir))
-
-
-def save_table(name: str, df: pd.DataFrame, step: int, dir: Path) -> None:
-    path = dir / "tables" / f"{name}_step_{step}.csv"
-
-    if not path.parent.exists():
-        path.parent.mkdir(parents=True)
-
-    df.to_csv(path, index=False)
-
-
-def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
-    path = dir / "plots" / f"{name}_step_{step}.png"
-
-    if not path.parent.exists():
-        path.parent.mkdir(parents=True)
-
-    fig.savefig(path, transparent=True, bbox_inches="tight")
-
-
-def _convert_figure_to_array(figure: Figure) -> np.ndarray:
-    with io.BytesIO() as buff:
-        figure.savefig(buff, format="raw")
-        buff.seek(0)
-        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
-    w, h = figure.canvas.get_width_height()
-    im = data.reshape((int(h), int(w), -1))
-    return im
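In the deleted module above, the `name` literal on each config class is the discriminator that selects a builder from `LOGGER_FACTORY`. A sketch of the dispatch:

```python
# Hypothetical use of the deleted logger factory: the discriminator on the
# config class ("tensorboard" here) routes the call to the right builder.
config = TensorBoardLoggerConfig(
    experiment_name="batdetect2",
    run_name="baseline",
)
tb_logger = build_logger(config)  # dispatches via LOGGER_FACTORY["tensorboard"]
```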
@@ -26,13 +26,15 @@ for creating a standard BatDetect2 model instance is the `build_model` function
 provided here.
 """
 
-from typing import List, Optional
+from typing import Optional
 
-import torch
+from loguru import logger
 
 from batdetect2.models.backbones import (
     Backbone,
+    BackboneConfig,
     build_backbone,
+    load_backbone_config,
 )
 from batdetect2.models.blocks import (
     ConvConfig,
@@ -46,37 +48,29 @@ from batdetect2.models.bottleneck import (
     BottleneckConfig,
     build_bottleneck,
 )
-from batdetect2.models.config import (
-    BackboneConfig,
-    load_backbone_config,
-)
 from batdetect2.models.decoder import (
     DEFAULT_DECODER_CONFIG,
     DecoderConfig,
     build_decoder,
 )
-from batdetect2.models.detectors import Detector, build_detector
+from batdetect2.models.detectors import (
+    Detector,
+    build_detector,
+)
 from batdetect2.models.encoder import (
     DEFAULT_ENCODER_CONFIG,
     EncoderConfig,
     build_encoder,
 )
 from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets import build_targets
-from batdetect2.typing.models import DetectionModel
-from batdetect2.typing.postprocess import (
-    ClipDetectionsTensor,
-    PostprocessorProtocol,
-)
-from batdetect2.typing.preprocess import PreprocessorProtocol
-from batdetect2.typing.targets import TargetProtocol
+from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
 
 __all__ = [
     "BBoxHead",
     "Backbone",
     "BackboneConfig",
+    "BackboneModel",
+    "BackboneModel",
     "Bottleneck",
     "BottleneckConfig",
     "ClassifierHead",
@@ -84,68 +78,65 @@ __all__ = [
     "DEFAULT_DECODER_CONFIG",
     "DEFAULT_ENCODER_CONFIG",
     "DecoderConfig",
+    "DetectionModel",
     "Detector",
     "DetectorHead",
     "EncoderConfig",
     "FreqCoordConvDownConfig",
     "FreqCoordConvUpConfig",
+    "ModelOutput",
     "StandardConvDownConfig",
     "StandardConvUpConfig",
     "build_backbone",
     "build_bottleneck",
     "build_decoder",
-    "build_encoder",
     "build_detector",
-    "load_backbone_config",
-    "Model",
+    "build_encoder",
     "build_model",
+    "load_backbone_config",
 ]
 
 
-class Model(torch.nn.Module):
-    detector: DetectionModel
-    preprocessor: PreprocessorProtocol
-    postprocessor: PostprocessorProtocol
-    targets: TargetProtocol
-
-    def __init__(
-        self,
-        detector: DetectionModel,
-        preprocessor: PreprocessorProtocol,
-        postprocessor: PostprocessorProtocol,
-        targets: TargetProtocol,
-    ):
-        super().__init__()
-        self.detector = detector
-        self.preprocessor = preprocessor
-        self.postprocessor = postprocessor
-        self.targets = targets
-
-    def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
-        spec = self.preprocessor(wav)
-        outputs = self.detector(spec)
-        return self.postprocessor(outputs)
-
-
 def build_model(
+    num_classes: int,
     config: Optional[BackboneConfig] = None,
-    targets: Optional[TargetProtocol] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    postprocessor: Optional[PostprocessorProtocol] = None,
-):
+) -> DetectionModel:
+    """Build the complete BatDetect2 detection model.
+
+    This high-level factory function constructs the standard BatDetect2 model
+    architecture. It first builds the feature extraction backbone (typically an
+    encoder-bottleneck-decoder structure) based on the provided
+    `BackboneConfig` (or defaults if None), and then attaches the standard
+    prediction heads (`DetectorHead`, `ClassifierHead`, `BBoxHead`) using the
+    `build_detector` function.
+
+    Parameters
+    ----------
+    num_classes : int
+        The number of specific target classes the model should predict
+        (required for the `ClassifierHead`). Must be positive.
+    config : BackboneConfig, optional
+        Configuration object specifying the architecture of the backbone
+        (encoder, bottleneck, decoder). If None, default configurations defined
+        within the respective builder functions (`build_encoder`, etc.) will be
+        used to construct a default backbone architecture.
+
+    Returns
+    -------
+    DetectionModel
+        An initialized `Detector` model instance.
+
+    Raises
+    ------
+    ValueError
+        If `num_classes` is not positive, or if errors occur during the
+        construction of the backbone or detector components (e.g., incompatible
+        configurations, invalid parameters).
+    """
     config = config or BackboneConfig()
-    targets = targets or build_targets()
-    preprocessor = preprocessor or build_preprocessor()
-    postprocessor = postprocessor or build_postprocessor(
-        preprocessor=preprocessor,
-    )
-    detector = build_detector(
-        num_classes=len(targets.class_names),
-        config=config,
-    )
-    return Model(
-        detector=detector,
-        postprocessor=postprocessor,
-        preprocessor=preprocessor,
-        targets=targets,
+    logger.opt(lazy=True).debug(
+        "Building model with config: \n{}",
+        lambda: config.to_yaml_string(),
     )
+    backbone = build_backbone(config)
+    return build_detector(num_classes, backbone)
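The net effect of this hunk on callers: `build_model` no longer assembles a targets/preprocessor/postprocessor bundle and instead needs an explicit class count. A sketch of the new call shape (the class count of 4 is illustrative):

```python
# Old call shape (removed): build_model(config=..., targets=..., ...)
# New call shape (added): the caller supplies num_classes directly and
# receives a plain DetectionModel.
model = build_model(num_classes=4, config=BackboneConfig())
```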
@@ -18,20 +18,37 @@ automatic padding to handle input sizes not perfectly divisible by the
 network's total downsampling factor.
 """
 
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.nn.functional as F
+from soundevent import data
 from torch import nn
 
-from batdetect2.models.bottleneck import build_bottleneck
-from batdetect2.models.config import BackboneConfig
-from batdetect2.models.decoder import Decoder, build_decoder
-from batdetect2.models.encoder import Encoder, build_encoder
-from batdetect2.typing.models import BackboneModel
+from batdetect2.configs import BaseConfig, load_config
+from batdetect2.models.bottleneck import (
+    DEFAULT_BOTTLENECK_CONFIG,
+    BottleneckConfig,
+    build_bottleneck,
+)
+from batdetect2.models.decoder import (
+    DEFAULT_DECODER_CONFIG,
+    Decoder,
+    DecoderConfig,
+    build_decoder,
+)
+from batdetect2.models.encoder import (
+    DEFAULT_ENCODER_CONFIG,
+    Encoder,
+    EncoderConfig,
+    build_encoder,
+)
+from batdetect2.models.types import BackboneModel
 
 __all__ = [
     "Backbone",
+    "BackboneConfig",
+    "load_backbone_config",
     "build_backbone",
 ]
 
@@ -144,6 +161,82 @@ class Backbone(BackboneModel):
         return x
 
 
+class BackboneConfig(BaseConfig):
+    """Configuration for the Encoder-Decoder Backbone network.
+
+    Aggregates configurations for the encoder, bottleneck, and decoder
+    components, along with defining the input and final output dimensions
+    for the complete backbone.
+
+    Attributes
+    ----------
+    input_height : int, default=128
+        Expected height (frequency bins) of the input spectrograms to the
+        backbone. Must be positive.
+    in_channels : int, default=1
+        Expected number of channels in the input spectrograms (e.g., 1 for
+        mono). Must be positive.
+    encoder : EncoderConfig, optional
+        Configuration for the encoder. If None or omitted,
+        the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
+        encoder module) will be used.
+    bottleneck : BottleneckConfig, optional
+        Configuration for the bottleneck layer connecting encoder and decoder.
+        If None or omitted, the default bottleneck configuration will be used.
+    decoder : DecoderConfig, optional
+        Configuration for the decoder. If None or omitted,
+        the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
+        decoder module) will be used.
+    out_channels : int, default=32
+        Desired number of channels in the final feature map output by the
+        backbone. Must be positive.
+    """
+
+    input_height: int = 128
+    in_channels: int = 1
+    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
+    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
+    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
+    out_channels: int = 32
+
+
+def load_backbone_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+) -> BackboneConfig:
+    """Load the backbone configuration from a file.
+
+    Reads a configuration file (YAML) and validates it against the
+    `BackboneConfig` schema, potentially extracting data from a nested field.
+
+    Parameters
+    ----------
+    path : PathLike
+        Path to the configuration file.
+    field : str, optional
+        Dot-separated path to a nested section within the file containing the
+        backbone configuration (e.g., "model.backbone"). If None, the entire
+        file content is used.
+
+    Returns
+    -------
+    BackboneConfig
+        The loaded and validated backbone configuration object.
+
+    Raises
+    ------
+    FileNotFoundError
+        If the config file path does not exist.
+    yaml.YAMLError
+        If the file content is not valid YAML.
+    pydantic.ValidationError
+        If the loaded config data does not conform to `BackboneConfig`.
+    KeyError, TypeError
+        If `field` specifies an invalid path.
+    """
+    return load_config(path, schema=BackboneConfig, field=field)
+
+
 def build_backbone(config: BackboneConfig) -> BackboneModel:
     """Factory function to build a Backbone from configuration.
 
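With `BackboneConfig` and `load_backbone_config` now living in the backbones module, loading a backbone from a nested YAML section looks like this. The file name and the "model.backbone" key are hypothetical:

```python
# Sketch: load the relocated BackboneConfig from a nested YAML section,
# then build the backbone from it.
backbone_config = load_backbone_config("config.yaml", field="model.backbone")
backbone = build_backbone(backbone_config)
print(backbone_config.input_height, backbone_config.out_channels)  # 128 32 by default
```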
@@ -34,7 +34,7 @@ import torch.nn.functional as F
 from pydantic import Field
 from torch import nn
 
-from batdetect2.core.configs import BaseConfig
+from batdetect2.configs import BaseConfig
 
 __all__ = [
     "ConvBlock",
@@ -55,12 +55,6 @@
 ]
 
 
-class SelfAttentionConfig(BaseConfig):
-    name: Literal["SelfAttention"] = "SelfAttention"
-    attention_channels: int
-    temperature: float = 1
-
-
 class SelfAttention(nn.Module):
     """Self-Attention mechanism operating along the time dimension.
 
@@ -121,7 +115,6 @@ class SelfAttention(nn.Module):
         # Note, does not encode position information (absolute or relative)
         self.temperature = temperature
        self.att_dim = attention_channels
-
         self.key_fun = nn.Linear(in_channels, attention_channels)
         self.value_fun = nn.Linear(in_channels, attention_channels)
         self.query_fun = nn.Linear(in_channels, attention_channels)
@@ -178,7 +171,7 @@ class SelfAttention(nn.Module):
 class ConvConfig(BaseConfig):
     """Configuration for a basic ConvBlock."""
 
-    name: Literal["ConvBlock"] = "ConvBlock"
+    block_type: Literal["ConvBlock"] = "ConvBlock"
     """Discriminator field indicating the block type."""
 
     out_channels: int
@@ -225,7 +218,7 @@ class ConvBlock(nn.Module):
             kernel_size=kernel_size,
             padding=pad_size,
         )
-        self.batch_norm = nn.BatchNorm2d(out_channels)
+        self.conv_bn = nn.BatchNorm2d(out_channels)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply Conv -> BN -> ReLU.
@@ -240,7 +233,7 @@ class ConvBlock(nn.Module):
         torch.Tensor
             Output tensor, shape `(B, C_out, H, W)`.
         """
-        return F.relu_(self.batch_norm(self.conv(x)))
+        return F.relu_(self.conv_bn(self.conv(x)))
 
 
 class VerticalConv(nn.Module):
@@ -300,7 +293,7 @@ class VerticalConv(nn.Module):
 class FreqCoordConvDownConfig(BaseConfig):
     """Configuration for a FreqCoordConvDownBlock."""
 
-    name: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
+    block_type: Literal["FreqCoordConvDown"] = "FreqCoordConvDown"
     """Discriminator field indicating the block type."""
 
     out_channels: int
@@ -364,7 +357,7 @@ class FreqCoordConvDownBlock(nn.Module):
             padding=pad_size,
             stride=1,
         )
-        self.batch_norm = nn.BatchNorm2d(out_channels)
+        self.conv_bn = nn.BatchNorm2d(out_channels)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply CoordF -> Conv -> MaxPool -> BN -> ReLU.
@@ -383,14 +376,14 @@ class FreqCoordConvDownBlock(nn.Module):
         freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
         x = torch.cat((x, freq_info), 1)
         x = F.max_pool2d(self.conv(x), 2, 2)
-        x = F.relu(self.batch_norm(x), inplace=True)
+        x = F.relu(self.conv_bn(x), inplace=True)
         return x
 
 
 class StandardConvDownConfig(BaseConfig):
     """Configuration for a StandardConvDownBlock."""
 
-    name: Literal["StandardConvDown"] = "StandardConvDown"
+    block_type: Literal["StandardConvDown"] = "StandardConvDown"
     """Discriminator field indicating the block type."""
 
     out_channels: int
@@ -438,7 +431,7 @@ class StandardConvDownBlock(nn.Module):
             padding=pad_size,
             stride=1,
         )
-        self.batch_norm = nn.BatchNorm2d(out_channels)
+        self.conv_bn = nn.BatchNorm2d(out_channels)
 
     def forward(self, x):
         """Apply Conv -> MaxPool -> BN -> ReLU.
@@ -454,13 +447,13 @@ class StandardConvDownBlock(nn.Module):
             Output tensor, shape `(B, C_out, H/2, W/2)`.
         """
         x = F.max_pool2d(self.conv(x), 2, 2)
-        return F.relu(self.batch_norm(x), inplace=True)
+        return F.relu(self.conv_bn(x), inplace=True)
 
 
 class FreqCoordConvUpConfig(BaseConfig):
     """Configuration for a FreqCoordConvUpBlock."""
 
-    name: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
+    block_type: Literal["FreqCoordConvUp"] = "FreqCoordConvUp"
     """Discriminator field indicating the block type."""
 
     out_channels: int
@@ -534,7 +527,7 @@ class FreqCoordConvUpBlock(nn.Module):
             kernel_size=kernel_size,
             padding=pad_size,
         )
-        self.batch_norm = nn.BatchNorm2d(out_channels)
+        self.conv_bn = nn.BatchNorm2d(out_channels)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply Interpolate -> Concat Coords -> Conv -> BN -> ReLU.
@@ -562,14 +555,14 @@ class FreqCoordConvUpBlock(nn.Module):
         freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
         op = torch.cat((op, freq_info), 1)
         op = self.conv(op)
-        op = F.relu(self.batch_norm(op), inplace=True)
+        op = F.relu(self.conv_bn(op), inplace=True)
         return op
 
 
 class StandardConvUpConfig(BaseConfig):
     """Configuration for a StandardConvUpBlock."""
 
-    name: Literal["StandardConvUp"] = "StandardConvUp"
+    block_type: Literal["StandardConvUp"] = "StandardConvUp"
     """Discriminator field indicating the block type."""
 
     out_channels: int
@@ -625,7 +618,7 @@ class StandardConvUpBlock(nn.Module):
             kernel_size=kernel_size,
             padding=pad_size,
         )
-        self.batch_norm = nn.BatchNorm2d(out_channels)
+        self.conv_bn = nn.BatchNorm2d(out_channels)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Apply Interpolate -> Conv -> BN -> ReLU.
@@ -650,7 +643,7 @@ class StandardConvUpBlock(nn.Module):
             align_corners=False,
         )
         op = self.conv(op)
-        op = F.relu(self.batch_norm(op), inplace=True)
+        op = F.relu(self.conv_bn(op), inplace=True)
         return op
 
 
@@ -661,16 +654,15 @@ LayerConfig = Annotated[
         StandardConvDownConfig,
         FreqCoordConvUpConfig,
         StandardConvUpConfig,
-        SelfAttentionConfig,
         "LayerGroupConfig",
     ],
-    Field(discriminator="name"),
+    Field(discriminator="block_type"),
 ]
 """Type alias for the discriminated union of block configuration models."""
 
 
 class LayerGroupConfig(BaseConfig):
-    name: Literal["LayerGroup"] = "LayerGroup"
+    block_type: Literal["LayerGroup"] = "LayerGroup"
     layers: List[LayerConfig]
 
 
@@ -686,7 +678,7 @@ def build_layer_from_config(
     parameters derived from the config and the current pipeline state
     (`input_height`, `in_channels`).
 
-    It uses the `name` field within the `config` object to determine
+    It uses the `block_type` field within the `config` object to determine
     which block class to instantiate.
 
     Parameters
@@ -698,7 +690,7 @@ def build_layer_from_config(
     config : LayerConfig
         A Pydantic configuration object for the desired block (e.g., an
         instance of `ConvConfig`, `FreqCoordConvDownConfig`, etc.), identified
-        by its `name` field.
+        by its `block_type` field.
 
     Returns
     -------
@@ -711,11 +703,11 @@ def build_layer_from_config(
     Raises
     ------
     NotImplementedError
-        If the `config.name` does not correspond to a known block type.
+        If the `config.block_type` does not correspond to a known block type.
     ValueError
         If parameters derived from the config are invalid for the block.
     """
-    if config.name == "ConvBlock":
+    if config.block_type == "ConvBlock":
         return (
             ConvBlock(
                 in_channels=in_channels,
@@ -727,7 +719,7 @@ def build_layer_from_config(
             input_height,
         )
 
-    if config.name == "FreqCoordConvDown":
+    if config.block_type == "FreqCoordConvDown":
         return (
             FreqCoordConvDownBlock(
                 in_channels=in_channels,
@@ -740,7 +732,7 @@ def build_layer_from_config(
             input_height // 2,
         )
 
-    if config.name == "StandardConvDown":
+    if config.block_type == "StandardConvDown":
         return (
             StandardConvDownBlock(
                 in_channels=in_channels,
@@ -752,7 +744,7 @@ def build_layer_from_config(
             input_height // 2,
         )
 
-    if config.name == "FreqCoordConvUp":
+    if config.block_type == "FreqCoordConvUp":
         return (
             FreqCoordConvUpBlock(
                 in_channels=in_channels,
@@ -765,7 +757,7 @@ def build_layer_from_config(
             input_height * 2,
        )
 
-    if config.name == "StandardConvUp":
+    if config.block_type == "StandardConvUp":
         return (
             StandardConvUpBlock(
                 in_channels=in_channels,
@@ -777,18 +769,7 @@ def build_layer_from_config(
             input_height * 2,
         )
 
-    if config.name == "SelfAttention":
-        return (
-            SelfAttention(
-                in_channels=in_channels,
-                attention_channels=config.attention_channels,
-                temperature=config.temperature,
-            ),
-            config.attention_channels,
-            input_height,
-        )
-
-    if config.name == "LayerGroup":
+    if config.block_type == "LayerGroup":
         current_channels = in_channels
         current_height = input_height
 
@@ -804,4 +785,4 @@ def build_layer_from_config(
 
         return nn.Sequential(*blocks), current_channels, current_height
 
-    raise NotImplementedError(f"Unknown block type {config.name}")
+    raise NotImplementedError(f"Unknown block type {config.block_type}")
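The `name` to `block_type` rename above is what drives block selection. A sketch of the round trip under the new field name (the keyword call shape of `build_layer_from_config` is inferred from its docstring in this diff):

```python
# Sketch: `block_type` defaults on each config class, so constructing a
# config and dispatching on it looks like this.
layer_config = FreqCoordConvDownConfig(out_channels=32)
block, out_channels, out_height = build_layer_from_config(
    input_height=128,
    in_channels=1,
    config=layer_config,
)
print(out_channels, out_height)  # 32 64 (down blocks halve the height)
```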
@@ -14,26 +14,47 @@ A factory function `build_bottleneck` constructs the appropriate bottleneck
 module based on the provided configuration.
 """
 
-from typing import Annotated, List, Optional, Union
+from typing import Optional
 
 import torch
-from pydantic import Field
 from torch import nn
 
-from batdetect2.core.configs import BaseConfig
-from batdetect2.models.blocks import (
-    SelfAttentionConfig,
-    VerticalConv,
-    build_layer_from_config,
-)
+from batdetect2.configs import BaseConfig
+from batdetect2.models.blocks import SelfAttention, VerticalConv
 
 __all__ = [
     "BottleneckConfig",
     "Bottleneck",
+    "BottleneckAttn",
     "build_bottleneck",
 ]
 
 
+class BottleneckConfig(BaseConfig):
+    """Configuration for the bottleneck layer(s).
+
+    Defines the number of channels within the bottleneck and whether to include
+    a self-attention mechanism.
+
+    Attributes
+    ----------
+    channels : int
+        The number of output channels produced by the main convolutional layer
+        within the bottleneck. This often matches the number of channels coming
+        from the last encoder stage, but can be different. Must be positive.
+        This also defines the channel dimensions used within the optional
+        `SelfAttention` layer.
+    self_attention : bool
+        If True, includes a `SelfAttention` layer operating on the time
+        dimension after an initial `VerticalConv` layer within the bottleneck.
+        If False, only the initial `VerticalConv` (and height repetition) is
+        performed.
+    """
+
+    channels: int
+    self_attention: bool
+
+
 class Bottleneck(nn.Module):
     """Base Bottleneck module for Encoder-Decoder architectures.
 
@@ -78,24 +99,16 @@ class Bottleneck(nn.Module):
         input_height: int,
         in_channels: int,
         out_channels: int,
-        bottleneck_channels: Optional[int] = None,
-        layers: Optional[List[torch.nn.Module]] = None,
     ) -> None:
         """Initialize the base Bottleneck layer."""
         super().__init__()
         self.in_channels = in_channels
         self.input_height = input_height
         self.out_channels = out_channels
-        self.bottleneck_channels = (
-            bottleneck_channels
-            if bottleneck_channels is not None
-            else out_channels
-        )
-        self.layers = nn.ModuleList(layers or [])
 
         self.conv_vert = VerticalConv(
             in_channels=in_channels,
-            out_channels=self.bottleneck_channels,
+            out_channels=out_channels,
             input_height=input_height,
         )
 
@@ -119,52 +132,73 @@ class Bottleneck(nn.Module):
             convolution.
         """
         x = self.conv_vert(x)
-
-        for layer in self.layers:
-            x = layer(x)
-
         return x.repeat([1, 1, self.input_height, 1])
 
 
-BottleneckLayerConfig = Annotated[
-    Union[SelfAttentionConfig,],
-    Field(discriminator="name"),
-]
-"""Type alias for the discriminated union of block configs usable in Decoder."""
-
-
-class BottleneckConfig(BaseConfig):
-    """Configuration for the bottleneck layer(s).
-
-    Defines the number of channels within the bottleneck and whether to include
-    a self-attention mechanism.
-
-    Attributes
-    ----------
-    channels : int
-        The number of output channels produced by the main convolutional layer
-        within the bottleneck. This often matches the number of channels coming
-        from the last encoder stage, but can be different. Must be positive.
-        This also defines the channel dimensions used within the optional
-        `SelfAttention` layer.
-    self_attention : bool
-        If True, includes a `SelfAttention` layer operating on the time
-        dimension after an initial `VerticalConv` layer within the bottleneck.
-        If False, only the initial `VerticalConv` (and height repetition) is
-        performed.
-    """
-
-    channels: int
-    layers: List[BottleneckLayerConfig] = Field(
-        default_factory=list,
-    )
+class BottleneckAttn(Bottleneck):
+    """Bottleneck module including a Self-Attention layer.
+
+    Extends the base `Bottleneck` by inserting a `SelfAttention` layer after
+    the initial `VerticalConv`. This allows the bottleneck to capture global
+    temporal dependencies in the summarized frequency features before passing
+    them to the decoder.
+
+    Sequence: VerticalConv -> SelfAttention -> Repeat Height.
+
+    Parameters
+    ----------
+    input_height : int
+        Height (frequency bins) of the input tensor from the encoder.
+    in_channels : int
+        Number of channels in the input tensor from the encoder.
+    out_channels : int
+        Number of output channels produced by the `VerticalConv` and
+        subsequently processed and output by this bottleneck. Also determines
+        the input/output channels of the internal `SelfAttention` layer.
+    attention : nn.Module
+        An initialized `SelfAttention` module instance.
+
+    Raises
+    ------
+    ValueError
+        If `input_height`, `in_channels`, or `out_channels` are not positive.
+    """
+
+    def __init__(
+        self,
+        input_height: int,
+        in_channels: int,
+        out_channels: int,
+        attention: nn.Module,
+    ) -> None:
+        """Initialize the Bottleneck with Self-Attention."""
+        super().__init__(input_height, in_channels, out_channels)
+        self.attention = attention
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Process input tensor.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor from the encoder bottleneck, shape
+            `(B, C_in, H_in, W)`. `C_in` must match `self.in_channels`,
+            `H_in` must match `self.input_height`.
+
+        Returns
+        -------
+        torch.Tensor
+            Output tensor, shape `(B, C_out, H_in, W)`, after applying attention
+            and repeating the height dimension.
+        """
+        x = self.conv_vert(x)
+        x = self.attention(x)
+        return x.repeat([1, 1, self.input_height, 1])
 
 
 DEFAULT_BOTTLENECK_CONFIG: BottleneckConfig = BottleneckConfig(
     channels=256,
-    layers=[
-        SelfAttentionConfig(attention_channels=256),
-    ],
+    self_attention=True,
 )
 
 
@@ -200,25 +234,21 @@ def build_bottleneck(
     """
     config = config or DEFAULT_BOTTLENECK_CONFIG
 
-    current_channels = in_channels
-    current_height = input_height
-
-    layers = []
-
-    for layer_config in config.layers:
-        layer, current_channels, current_height = build_layer_from_config(
-            input_height=current_height,
-            in_channels=current_channels,
-            config=layer_config,
+    if config.self_attention:
+        attention = SelfAttention(
+            in_channels=config.channels,
+            attention_channels=config.channels,
         )
-        assert current_height == input_height, (
-            "Bottleneck layers should not change the spectrogram height"
+
+        return BottleneckAttn(
+            input_height=input_height,
+            in_channels=in_channels,
+            out_channels=config.channels,
+            attention=attention,
        )
-        layers.append(layer)
 
     return Bottleneck(
         input_height=input_height,
         in_channels=in_channels,
         out_channels=config.channels,
-        layers=layers,
     )
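Under the version this diff moves to, self-attention becomes a boolean switch rather than a list of layer configs. A sketch, with keyword names taken from the `build_bottleneck` body shown in this hunk (the exact parameter order is not visible, so keywords are used throughout):

```python
# Sketch: a bottleneck with self-attention enabled returns a BottleneckAttn.
config = BottleneckConfig(channels=256, self_attention=True)
bottleneck = build_bottleneck(
    input_height=128,
    in_channels=256,
    config=config,
)
```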
@@ -1,98 +0,0 @@
-from typing import Optional
-
-from soundevent import data
-
-from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.models.bottleneck import (
-    DEFAULT_BOTTLENECK_CONFIG,
-    BottleneckConfig,
-)
-from batdetect2.models.decoder import (
-    DEFAULT_DECODER_CONFIG,
-    DecoderConfig,
-)
-from batdetect2.models.encoder import (
-    DEFAULT_ENCODER_CONFIG,
-    EncoderConfig,
-)
-
-__all__ = [
-    "BackboneConfig",
-    "load_backbone_config",
-]
-
-
-class BackboneConfig(BaseConfig):
-    """Configuration for the Encoder-Decoder Backbone network.
-
-    Aggregates configurations for the encoder, bottleneck, and decoder
-    components, along with defining the input and final output dimensions
-    for the complete backbone.
-
-    Attributes
-    ----------
-    input_height : int, default=128
-        Expected height (frequency bins) of the input spectrograms to the
-        backbone. Must be positive.
-    in_channels : int, default=1
-        Expected number of channels in the input spectrograms (e.g., 1 for
-        mono). Must be positive.
-    encoder : EncoderConfig, optional
-        Configuration for the encoder. If None or omitted,
-        the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
-        encoder module) will be used.
-    bottleneck : BottleneckConfig, optional
-        Configuration for the bottleneck layer connecting encoder and decoder.
-        If None or omitted, the default bottleneck configuration will be used.
-    decoder : DecoderConfig, optional
-        Configuration for the decoder. If None or omitted,
-        the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
-        decoder module) will be used.
-    out_channels : int, default=32
-        Desired number of channels in the final feature map output by the
-        backbone. Must be positive.
-    """
-
-    input_height: int = 128
-    in_channels: int = 1
-    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
-    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
-    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
-    out_channels: int = 32
-
-
-def load_backbone_config(
-    path: data.PathLike,
-    field: Optional[str] = None,
-) -> BackboneConfig:
-    """Load the backbone configuration from a file.
-
-    Reads a configuration file (YAML) and validates it against the
-    `BackboneConfig` schema, potentially extracting data from a nested field.
-
-    Parameters
-    ----------
-    path : PathLike
-        Path to the configuration file.
-    field : str, optional
-        Dot-separated path to a nested section within the file containing the
-        backbone configuration (e.g., "model.backbone"). If None, the entire
-        file content is used.
-
-    Returns
-    -------
-    BackboneConfig
-        The loaded and validated backbone configuration object.
-
-    Raises
-    ------
-    FileNotFoundError
-        If the config file path does not exist.
-    yaml.YAMLError
-        If the file content is not valid YAML.
-    pydantic.ValidationError
-        If the loaded config data does not conform to `BackboneConfig`.
-    KeyError, TypeError
-        If `field` specifies an invalid path.
-    """
-    return load_config(path, schema=BackboneConfig, field=field)
@ -24,7 +24,7 @@ import torch
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvUpConfig,
|
FreqCoordConvUpConfig,
|
||||||
@ -47,7 +47,7 @@ DecoderLayerConfig = Annotated[
|
|||||||
StandardConvUpConfig,
|
StandardConvUpConfig,
|
||||||
LayerGroupConfig,
|
LayerGroupConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="block_type"),
|
||||||
]
|
]
|
||||||
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
"""Type alias for the discriminated union of block configs usable in Decoder."""
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ class DecoderConfig(BaseConfig):
|
|||||||
layers : List[DecoderLayerConfig]
|
layers : List[DecoderLayerConfig]
|
||||||
An ordered list of configuration objects, each defining one layer or
|
An ordered list of configuration objects, each defining one layer or
|
||||||
block in the decoder sequence. Each item must be a valid block
|
block in the decoder sequence. Each item must be a valid block
|
||||||
config including a `name` field and necessary parameters like
|
config including a `block_type` field and necessary parameters like
|
||||||
`out_channels`. Input channels for each layer are inferred sequentially.
|
`out_channels`. Input channels for each layer are inferred sequentially.
|
||||||
The list must contain at least one layer.
|
The list must contain at least one layer.
|
||||||
"""
|
"""
|
||||||
@@ -249,9 +249,9 @@ def build_decoder(
     ------
     ValueError
         If `in_channels` or `input_height` are not positive, or if the layer
-        configuration is invalid (e.g., empty list, unknown `name`).
+        configuration is invalid (e.g., empty list, unknown `block_type`).
     NotImplementedError
-        If `build_layer_from_config` encounters an unknown `name`.
+        If `build_layer_from_config` encounters an unknown `block_type`.
     """
     config = config or DEFAULT_DECODER_CONFIG
 
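For reference on the discriminator rename above, here is a minimal, self-contained sketch (assuming pydantic v2; these config classes are simplified stand-ins, not the library's real definitions) of how a discriminated union keyed on `block_type` dispatches during validation:

```python
from typing import Annotated, List, Literal, Union

from pydantic import BaseModel, Field


class ConvConfig(BaseModel):
    # Simplified stand-in for the real block config.
    block_type: Literal["conv"] = "conv"
    out_channels: int


class FreqCoordConvUpConfig(BaseModel):
    block_type: Literal["freq_coord_conv_up"] = "freq_coord_conv_up"
    out_channels: int


DecoderLayerConfig = Annotated[
    Union[ConvConfig, FreqCoordConvUpConfig],
    Field(discriminator="block_type"),
]


class DecoderConfig(BaseModel):
    layers: List[DecoderLayerConfig]


# Validation dispatches on the `block_type` value of each entry:
config = DecoderConfig.model_validate(
    {
        "layers": [
            {"block_type": "conv", "out_channels": 64},
            {"block_type": "freq_coord_conv_up", "out_channels": 32},
        ]
    }
)
assert isinstance(config.layers[1], FreqCoordConvUpConfig)
```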
@@ -8,25 +8,18 @@ classifying them.
 
 The primary components are:
 - `Detector`: The `torch.nn.Module` subclass representing the complete model.
+- `build_detector`: A factory function to conveniently construct a standard
+  `Detector` instance given a backbone and the number of target classes.
 
 This module focuses purely on the neural network architecture definition. The
 logic for preprocessing inputs and postprocessing/decoding outputs resides in
 the `batdetect2.preprocess` and `batdetect2.postprocess` packages, respectively.
 """
 
-from typing import Optional
-
 import torch
-from loguru import logger
 
-from batdetect2.models.backbones import BackboneConfig, build_backbone
 from batdetect2.models.heads import BBoxHead, ClassifierHead
-from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
+from batdetect2.models.types import BackboneModel, DetectionModel, ModelOutput
 
-__all__ = [
-    "Detector",
-    "build_detector",
-]
-
 
 class Detector(DetectionModel):
@@ -126,41 +119,36 @@ class Detector(DetectionModel):
         )
 
 
-def build_detector(
-    num_classes: int, config: Optional[BackboneConfig] = None
-) -> DetectionModel:
-    """Build the complete BatDetect2 detection model.
+def build_detector(num_classes: int, backbone: BackboneModel) -> Detector:
+    """Factory function to build a standard Detector model instance.
+    Creates the standard prediction heads (`ClassifierHead`, `DetectorHead`,
+    `BBoxHead`) configured appropriately based on the output channels of the
+    provided `backbone` and the specified `num_classes`. It then assembles
+    these components into a `Detector` model.
 
     Parameters
     ----------
     num_classes : int
-        The number of specific target classes the model should predict
-        (required for the `ClassifierHead`). Must be positive.
-    config : BackboneConfig, optional
-        Configuration object specifying the architecture of the backbone
-        (encoder, bottleneck, decoder). If None, default configurations defined
-        within the respective builder functions (`build_encoder`, etc.) will be
-        used to construct a default backbone architecture.
+        The number of specific target classes for the classification head
+        (excluding any implicit background class). Must be positive.
+    backbone : BackboneModel
+        An initialized feature extraction backbone module instance. The number
+        of output channels from this backbone (`backbone.out_channels`) is used
+        to configure the input channels for the prediction heads.
 
     Returns
     -------
-    DetectionModel
+    Detector
         An initialized `Detector` model instance.
 
     Raises
    ------
     ValueError
-        If `num_classes` is not positive, or if errors occur during the
-        construction of the backbone or detector components (e.g., incompatible
-        configurations, invalid parameters).
+        If `num_classes` is not positive.
+    AttributeError
+        If `backbone` does not have the required `out_channels` attribute.
     """
-    config = config or BackboneConfig()
-
-    logger.opt(lazy=True).debug(
-        "Building model with config: \n{}",
-        lambda: config.to_yaml_string(),
-    )
-    backbone = build_backbone(config=config)
     classifier_head = ClassifierHead(
         num_classes=num_classes,
         in_channels=backbone.out_channels,
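A sketch of the wiring the reworked `build_detector` performs. The modules below are hypothetical stand-ins, not the real backbone or heads; the point is only that each head's `in_channels` is taken from `backbone.out_channels`:

```python
import torch
from torch import nn

# Stand-in backbone: nn.Conv2d already exposes an `out_channels` attribute,
# which is the only contract the factory relies on.
backbone = nn.Conv2d(1, 32, kernel_size=3, padding=1)

# Each head is sized from the backbone's output channels.
classifier_head = nn.Conv2d(backbone.out_channels, 4, kernel_size=1)
bbox_head = nn.Conv2d(backbone.out_channels, 2, kernel_size=1)

spec = torch.rand(1, 1, 128, 256)          # (batch, channel, freq, time)
features = backbone(spec)
class_logits = classifier_head(features)   # -> (1, 4, 128, 256)
size_preds = bbox_head(features)           # -> (1, 2, 128, 256)
```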
@@ -26,7 +26,7 @@ import torch
 from pydantic import Field
 from torch import nn
 
-from batdetect2.core.configs import BaseConfig
+from batdetect2.configs import BaseConfig
 from batdetect2.models.blocks import (
     ConvConfig,
     FreqCoordConvDownConfig,
@@ -49,7 +49,7 @@ EncoderLayerConfig = Annotated[
         StandardConvDownConfig,
         LayerGroupConfig,
     ],
-    Field(discriminator="name"),
+    Field(discriminator="block_type"),
 ]
 """Type alias for the discriminated union of block configs usable in Encoder."""
 
@@ -66,7 +66,7 @@ class EncoderConfig(BaseConfig):
         An ordered list of configuration objects, each defining one layer or
         block in the encoder sequence. Each item must be a valid block config
         (e.g., `ConvConfig`, `FreqCoordConvDownConfig`,
-        `StandardConvDownConfig`) including a `name` field and necessary
+        `StandardConvDownConfig`) including a `block_type` field and necessary
         parameters like `out_channels`. Input channels for each layer are
         inferred sequentially. The list must contain at least one layer.
     """
@@ -287,9 +287,9 @@ def build_encoder(
     ------
     ValueError
         If `in_channels` or `input_height` are not positive, or if the layer
-        configuration is invalid (e.g., empty list, unknown `name`).
+        configuration is invalid (e.g., empty list, unknown `block_type`).
     NotImplementedError
-        If `build_layer_from_config` encounters an unknown `name`.
+        If `build_layer_from_config` encounters an unknown `block_type`.
     """
     if in_channels <= 0 or input_height <= 0:
         raise ValueError("in_channels and input_height must be positive.")
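The encoder and decoder docstrings above both state that input channels are inferred sequentially; a standalone sketch of that inference (not the library's code):

```python
from typing import List, Tuple


def infer_channels(
    in_channels: int, out_channels_per_layer: List[int]
) -> List[Tuple[int, int]]:
    # Each layer consumes the previous layer's out_channels as its
    # in_channels, starting from the encoder/decoder input.
    pairs = []
    current = in_channels
    for out_channels in out_channels_per_layer:
        pairs.append((current, out_channels))
        current = out_channels
    return pairs


assert infer_channels(1, [32, 64, 128]) == [(1, 32), (32, 64), (64, 128)]
```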
@@ -19,6 +19,7 @@ from abc import ABC, abstractmethod
 from typing import NamedTuple
 
 import torch
+import torch.nn as nn
 
 __all__ = [
     "ModelOutput",
@@ -64,7 +65,7 @@ class ModelOutput(NamedTuple):
     features: torch.Tensor
 
 
-class BackboneModel(ABC, torch.nn.Module):
+class BackboneModel(ABC, nn.Module):
     """Abstract Base Class for generic feature extraction backbone models.
 
     Defines the minimal interface for a feature extractor network within a
@@ -190,7 +191,7 @@ class EncoderDecoderModel(BackboneModel):
         ...
 
 
-class DetectionModel(ABC, torch.nn.Module):
+class DetectionModel(ABC, nn.Module):
     """Abstract Base Class for complete BatDetect2 detection models.
 
     Defines the interface for the overall model that takes an input spectrogram
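On the `ABC` plus `nn.Module` bases touched above: `ABCMeta` is a subclass of `type`, so the mixin is legal, and abstract methods are enforced at instantiation time on top of the usual parameter tracking. A minimal sketch with hypothetical class names:

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn


class Backbone(ABC, nn.Module):
    out_channels: int

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor: ...


class TinyBackbone(Backbone):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.out_channels = 8

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        return self.conv(spec)


try:
    Backbone()  # type: ignore[abstract]
except TypeError:
    pass  # cannot instantiate: forward is abstract

model = TinyBackbone()
assert model(torch.rand(1, 1, 128, 64)).shape[1] == model.out_channels
```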
@@ -1,16 +1,11 @@
 from batdetect2.plotting.clip_annotations import plot_clip_annotation
 from batdetect2.plotting.clip_predictions import plot_clip_prediction
 from batdetect2.plotting.clips import plot_clip
-from batdetect2.plotting.common import plot_spectrogram
-from batdetect2.plotting.gallery import plot_match_gallery
-from batdetect2.plotting.heatmaps import (
-    plot_classification_heatmap,
-    plot_detection_heatmap,
-)
 from batdetect2.plotting.matches import (
     plot_cross_trigger_match,
     plot_false_negative_match,
     plot_false_positive_match,
+    plot_matches,
     plot_true_positive_match,
 )
 
@@ -18,12 +13,9 @@ __all__ = [
     "plot_clip",
     "plot_clip_annotation",
     "plot_clip_prediction",
-    "plot_cross_trigger_match",
-    "plot_false_negative_match",
+    "plot_matches",
     "plot_false_positive_match",
-    "plot_spectrogram",
     "plot_true_positive_match",
-    "plot_detection_heatmap",
-    "plot_classification_heatmap",
-    "plot_match_gallery",
+    "plot_false_negative_match",
+    "plot_cross_trigger_match",
 ]
@@ -4,9 +4,7 @@ from matplotlib.axes import Axes
 from soundevent import data, plot
 
 from batdetect2.plotting.clips import plot_clip
-from batdetect2.plotting.common import create_ax
-from batdetect2.typing.preprocess import PreprocessorProtocol
-from batdetect2.typing.targets import TargetProtocol
+from batdetect2.preprocess import PreprocessorProtocol
 
 __all__ = [
     "plot_clip_annotation",
@@ -19,6 +17,8 @@ def plot_clip_annotation(
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
+    add_colorbar: bool = False,
+    add_labels: bool = False,
     add_points: bool = False,
     cmap: str = "gray",
     alpha: float = 1,
@@ -31,6 +31,8 @@ def plot_clip_annotation(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
+        add_colorbar=add_colorbar,
+        add_labels=add_labels,
         spec_cmap=cmap,
     )
 
@@ -45,29 +47,3 @@ def plot_clip_annotation(
         facecolor="none" if not fill else None,
     )
     return ax
-
-
-def plot_anchor_points(
-    clip_annotation: data.ClipAnnotation,
-    targets: TargetProtocol,
-    figsize: Optional[Tuple[int, int]] = None,
-    ax: Optional[Axes] = None,
-    size: int = 1,
-    color: str = "red",
-    marker: str = "x",
-    alpha: float = 1,
-) -> Axes:
-    ax = create_ax(ax=ax, figsize=figsize)
-
-    positions = []
-
-    for sound_event in clip_annotation.sound_events:
-        if not targets.filter(sound_event):
-            continue
-
-        position, _ = targets.encode_roi(sound_event)
-        positions.append(position)
-
-    X, Y = zip(*positions)
-    ax.scatter(X, Y, s=size, c=color, marker=marker, alpha=alpha)
-    return ax
@@ -8,7 +8,7 @@ from soundevent.plot.geometries import plot_geometry
 from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
 
 from batdetect2.plotting.clips import plot_clip
-from batdetect2.typing.preprocess import PreprocessorProtocol
+from batdetect2.preprocess import PreprocessorProtocol
 
 __all__ = [
     "plot_clip_prediction",
@@ -21,6 +21,8 @@ def plot_clip_prediction(
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
+    add_colorbar: bool = False,
+    add_labels: bool = False,
     add_legend: bool = False,
     spec_cmap: str = "gray",
     linewidth: float = 1,
@@ -32,6 +34,8 @@ def plot_clip_prediction(
         figsize=figsize,
         ax=ax,
         audio_dir=audio_dir,
+        add_colorbar=add_colorbar,
+        add_labels=add_labels,
         spec_cmap=spec_cmap,
     )
 
@@ -1,14 +1,13 @@
 from typing import Optional, Tuple
 
 import matplotlib.pyplot as plt
-import torch
 from matplotlib.axes import Axes
 from soundevent import data
 
-from batdetect2.audio import build_audio_loader
-from batdetect2.plotting.common import plot_spectrogram
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.typing import AudioLoader, PreprocessorProtocol
+from batdetect2.preprocess import (
+    PreprocessorProtocol,
+    get_default_preprocessor,
+)
 
 __all__ = [
     "plot_clip",
@@ -17,33 +16,29 @@ __all__ = [
 
 def plot_clip(
     clip: data.Clip,
-    audio_loader: Optional[AudioLoader] = None,
     preprocessor: Optional[PreprocessorProtocol] = None,
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
+    add_colorbar: bool = False,
+    add_labels: bool = False,
     spec_cmap: str = "gray",
 ) -> Axes:
     if ax is None:
         _, ax = plt.subplots(figsize=figsize)
 
     if preprocessor is None:
-        preprocessor = build_preprocessor()
+        preprocessor = get_default_preprocessor()
 
-    if audio_loader is None:
-        audio_loader = build_audio_loader()
+    spec = preprocessor.preprocess_clip(clip, audio_dir=audio_dir)
 
-    wav = torch.tensor(audio_loader.load_clip(clip, audio_dir=audio_dir))
-    spec = preprocessor(wav)
-
-    plot_spectrogram(
-        spec,
-        start_time=clip.start_time,
-        end_time=clip.end_time,
-        min_freq=preprocessor.min_freq,
-        max_freq=preprocessor.max_freq,
+    spec.plot(  # type: ignore
         ax=ax,
+        add_colorbar=add_colorbar,
         cmap=spec_cmap,
+        add_labels=add_labels,
+        vmin=spec.min().item(),
+        vmax=spec.max().item(),
     )
 
     return ax
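The new `plot_clip` body delegates to xarray's `DataArray.plot`, which forwards `cmap`, `vmin`, `vmax`, `add_colorbar`, and `add_labels` through to matplotlib. A standalone sketch with synthetic data (the `time`/`frequency` coordinate names are an assumption about what the preprocessor returns):

```python
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Fake spectrogram with labeled axes, standing in for the preprocessor output.
spec = xr.DataArray(
    np.random.rand(128, 256),
    coords={
        "frequency": np.linspace(10_000, 120_000, 128),
        "time": np.linspace(0, 0.5, 256),
    },
    dims=("frequency", "time"),
)

_, ax = plt.subplots()
spec.plot(
    ax=ax,
    add_colorbar=False,
    add_labels=False,
    cmap="gray",
    vmin=spec.min().item(),
    vmax=spec.max().item(),
)
```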
@@ -1,10 +1,8 @@
 """General plotting utilities."""
 
-from typing import Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import matplotlib.pyplot as plt
-import numpy as np
-import torch
 from matplotlib import axes
 
 __all__ = [
@@ -14,62 +12,11 @@ __all__ = [
 
 def create_ax(
     ax: Optional[axes.Axes] = None,
-    figsize: Optional[Tuple[int, int]] = None,
+    figsize: Tuple[int, int] = (10, 10),
     **kwargs,
 ) -> axes.Axes:
     """Create a new axis if none is provided"""
     if ax is None:
-        _, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs)  # type: ignore
+        _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore
 
     return ax  # type: ignore
-
-
-def plot_spectrogram(
-    spec: Union[torch.Tensor, np.ndarray],
-    start_time: Optional[float] = None,
-    end_time: Optional[float] = None,
-    min_freq: Optional[float] = None,
-    max_freq: Optional[float] = None,
-    ax: Optional[axes.Axes] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    add_colorbar: bool = False,
-    colorbar_kwargs: Optional[dict] = None,
-    vmin: Optional[float] = None,
-    vmax: Optional[float] = None,
-    cmap="gray",
-) -> axes.Axes:
-    if isinstance(spec, torch.Tensor):
-        spec = spec.numpy()
-
-    spec = spec.squeeze()
-
-    ax = create_ax(ax=ax, figsize=figsize)
-
-    if start_time is None:
-        start_time = 0
-
-    if end_time is None:
-        end_time = spec.shape[-1]
-
-    if min_freq is None:
-        min_freq = 0
-
-    if max_freq is None:
-        max_freq = spec.shape[-2]
-
-    mappable = ax.pcolormesh(
-        np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True),
-        np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True),
-        spec,
-        cmap=cmap,
-        vmin=vmin,
-        vmax=vmax,
-    )
-
-    ax.set_xlim(start_time, end_time)
-    ax.set_ylim(min_freq, max_freq)
-
-    if add_colorbar:
-        plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {}))
-
-    return ax
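The removed `plot_spectrogram` relied on a standard `pcolormesh` idiom worth noting: an array with N columns needs N + 1 cell edges per axis, which is why it called `np.linspace(..., shape + 1)`. A standalone sketch:

```python
import matplotlib.pyplot as plt
import numpy as np

# pcolormesh takes cell *edges*, so a (F, T) array needs T + 1 time edges
# and F + 1 frequency edges.
spec = np.random.rand(128, 256)
time_edges = np.linspace(0.0, 0.5, spec.shape[-1] + 1)
freq_edges = np.linspace(10_000, 120_000, spec.shape[-2] + 1)

_, ax = plt.subplots()
ax.pcolormesh(time_edges, freq_edges, spec, cmap="gray")
```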
@@ -1,113 +0,0 @@
-from typing import Optional
-
-from matplotlib import axes, patches
-from soundevent.plot import plot_geometry
-
-from batdetect2.evaluate.metrics.detection import ClipEval
-from batdetect2.plotting.clips import (
-    AudioLoader,
-    PreprocessorProtocol,
-    plot_clip,
-)
-from batdetect2.plotting.common import create_ax
-
-__all__ = [
-    "plot_clip_detections",
-]
-
-
-def plot_clip_detections(
-    clip_eval: ClipEval,
-    figsize: tuple[int, int] = (10, 10),
-    ax: Optional[axes.Axes] = None,
-    audio_loader: Optional[AudioLoader] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    threshold: float = 0.2,
-    add_legend: bool = True,
-    add_title: bool = True,
-    fill: bool = False,
-    linewidth: float = 1.0,
-    gt_color: str = "green",
-    gt_linestyle: str = "-",
-    true_pred_color: str = "yellow",
-    true_pred_linestyle: str = "--",
-    false_pred_color: str = "blue",
-    false_pred_linestyle: str = "-",
-    missed_gt_color: str = "red",
-    missed_gt_linestyle: str = "-",
-) -> axes.Axes:
-    ax = create_ax(figsize=figsize, ax=ax)
-
-    plot_clip(
-        clip_eval.clip,
-        audio_loader=audio_loader,
-        preprocessor=preprocessor,
-        ax=ax,
-    )
-
-    for m in clip_eval.matches:
-        is_match = (
-            m.pred is not None and m.gt is not None and m.score >= threshold
-        )
-
-        if m.pred is not None:
-            color = true_pred_color if is_match else false_pred_color
-            plot_geometry(
-                m.pred.geometry,
-                ax=ax,
-                add_points=False,
-                facecolor="none" if not fill else color,
-                alpha=m.pred.detection_score,
-                linewidth=linewidth,
-                linestyle=true_pred_linestyle
-                if is_match
-                else missed_gt_linestyle,
-                color=color,
-            )
-
-        if m.gt is not None:
-            color = gt_color if is_match else missed_gt_color
-            plot_geometry(
-                m.gt.sound_event.geometry,  # type: ignore
-                ax=ax,
-                add_points=False,
-                linewidth=linewidth,
-                facecolor="none" if not fill else color,
-                linestyle=gt_linestyle if is_match else false_pred_linestyle,
-                color=color,
-            )
-
-    if add_title:
-        ax.set_title(clip_eval.clip.recording.path.name)
-
-    if add_legend:
-        ax.legend(
-            handles=[
-                patches.Patch(
-                    label="found GT",
-                    edgecolor=gt_color,
-                    facecolor="none" if not fill else gt_color,
-                    linestyle=gt_linestyle,
-                ),
-                patches.Patch(
-                    label="missed GT",
-                    edgecolor=missed_gt_color,
-                    facecolor="none" if not fill else missed_gt_color,
-                    linestyle=missed_gt_linestyle,
-                ),
-                patches.Patch(
-                    label="true Det",
-                    edgecolor=true_pred_color,
-                    facecolor="none" if not fill else true_pred_color,
-                    linestyle=true_pred_linestyle,
-                ),
-                patches.Patch(
-                    label="false Det",
-                    edgecolor=false_pred_color,
-                    facecolor="none" if not fill else false_pred_color,
-                    linestyle=false_pred_linestyle,
-                ),
-            ]
-        )
-
-    return ax
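The deleted `plot_clip_detections` built its legend from proxy artists; a minimal sketch of that matplotlib idiom, independent of the removed code:

```python
import matplotlib.pyplot as plt
from matplotlib import patches

# Proxy artists (Patch objects never added to the axes) let you build a
# legend for styles that were drawn through many separate geometry calls
# rather than through labeled line objects.
_, ax = plt.subplots()
ax.legend(
    handles=[
        patches.Patch(label="found GT", edgecolor="green", facecolor="none"),
        patches.Patch(label="missed GT", edgecolor="red", facecolor="none"),
    ]
)
```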
src/batdetect2/plotting/evaluation.py (new file, 160 lines)
@@ -0,0 +1,160 @@
+import random
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import List
+
+import matplotlib.pyplot as plt
+import pandas as pd
+
+from batdetect2 import plotting
+from batdetect2.evaluate.types import MatchEvaluation
+from batdetect2.preprocess.types import PreprocessorProtocol
+
+
+@dataclass
+class ClassExamples:
+    false_positives: List[MatchEvaluation] = field(default_factory=list)
+    false_negatives: List[MatchEvaluation] = field(default_factory=list)
+    true_positives: List[MatchEvaluation] = field(default_factory=list)
+    cross_triggers: List[MatchEvaluation] = field(default_factory=list)
+
+
+def plot_example_gallery(
+    matches: List[MatchEvaluation],
+    preprocessor: PreprocessorProtocol,
+    n_examples: int = 5,
+):
+    class_examples = defaultdict(ClassExamples)
+
+    for match in matches:
+        gt_class = match.gt_class
+        pred_class = match.pred_class
+
+        if pred_class is None:
+            class_examples[gt_class].false_negatives.append(match)
+            continue
+
+        if gt_class is None:
+            class_examples[pred_class].false_positives.append(match)
+            continue
+
+        if gt_class != pred_class:
+            class_examples[gt_class].cross_triggers.append(match)
+            class_examples[pred_class].cross_triggers.append(match)
+            continue
+
+        class_examples[gt_class].true_positives.append(match)
+
+    for class_name, examples in class_examples.items():
+        true_positives = get_binned_sample(
+            examples.true_positives,
+            n_examples=n_examples,
+        )
+
+        false_positives = get_binned_sample(
+            examples.false_positives,
+            n_examples=n_examples,
+        )
+
+        false_negatives = random.sample(
+            examples.false_negatives,
+            k=min(n_examples, len(examples.false_negatives)),
+        )
+
+        cross_triggers = get_binned_sample(
+            examples.cross_triggers,
+            n_examples=n_examples,
+        )
+
+        fig = plot_class_examples(
+            true_positives,
+            false_positives,
+            false_negatives,
+            cross_triggers,
+            preprocessor=preprocessor,
+            n_examples=n_examples,
+        )
+
+        yield class_name, fig
+
+        plt.close(fig)
+
+
+def plot_class_examples(
+    true_positives: List[MatchEvaluation],
+    false_positives: List[MatchEvaluation],
+    false_negatives: List[MatchEvaluation],
+    cross_triggers: List[MatchEvaluation],
+    preprocessor: PreprocessorProtocol,
+    n_examples: int = 5,
+    duration: float = 0.1,
+):
+    fig = plt.figure(figsize=(20, 20))
+
+    for index, match in enumerate(true_positives[:n_examples]):
+        ax = plt.subplot(4, n_examples, index + 1)
+        try:
+            plotting.plot_true_positive_match(
+                match,
+                ax=ax,
+                preprocessor=preprocessor,
+                duration=duration,
+            )
+        except (ValueError, AssertionError):
+            continue
+
+    for index, match in enumerate(false_positives[:n_examples]):
+        ax = plt.subplot(4, n_examples, n_examples + index + 1)
+        try:
+            plotting.plot_false_positive_match(
+                match,
+                ax=ax,
+                preprocessor=preprocessor,
+                duration=duration,
+            )
+        except (ValueError, AssertionError):
+            continue
+
+    for index, match in enumerate(false_negatives[:n_examples]):
+        ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
+        try:
+            plotting.plot_false_negative_match(
+                match,
+                ax=ax,
+                preprocessor=preprocessor,
+                duration=duration,
+            )
+        except (ValueError, AssertionError):
+            continue
+
+    for index, match in enumerate(cross_triggers[:n_examples]):
+        ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
+        try:
+            plotting.plot_cross_trigger_match(
+                match,
+                ax=ax,
+                preprocessor=preprocessor,
+                duration=duration,
+            )
+        except (ValueError, AssertionError):
+            continue
+
+    return fig
+
+
+def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
+    if len(matches) < n_examples:
+        return matches
+
+    indices, pred_scores = zip(
+        *[
+            (index, match.pred_class_scores[pred_class])
+            for index, match in enumerate(matches)
+            if (pred_class := match.pred_class) is not None
+        ]
+    )
+
+    bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
+    df = pd.DataFrame({"indices": indices, "bins": bins})
+    sample = df.groupby("bins").apply(lambda x: x.sample(1))
+    return [matches[ind] for ind in sample["indices"]]
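The `get_binned_sample` helper in the new file stratifies examples by predicted score so the gallery spans low- to high-confidence predictions rather than clustering at one end. A standalone sketch of the same `pd.qcut` idiom, here using `groupby(...).sample(1)` as a shorter equivalent of the `apply` call above:

```python
import random

import pandas as pd

# Quantile-bin the scores, then draw one example per bin.
scores = [random.random() for _ in range(100)]
bins = pd.qcut(scores, q=5, labels=False, duplicates="drop")
df = pd.DataFrame({"index": range(len(scores)), "bin": bins})
sample = df.groupby("bin").sample(1)
print(sorted(sample["index"].tolist()))  # five indices, one per score band
```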
@@ -1,109 +0,0 @@
-from typing import Optional, Sequence
-
-import matplotlib.pyplot as plt
-from matplotlib.figure import Figure
-
-from batdetect2.plotting.matches import (
-    MatchProtocol,
-    plot_cross_trigger_match,
-    plot_false_negative_match,
-    plot_false_positive_match,
-    plot_true_positive_match,
-)
-from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
-
-__all__ = ["plot_match_gallery"]
-
-
-def plot_match_gallery(
-    true_positives: Sequence[MatchProtocol],
-    false_positives: Sequence[MatchProtocol],
-    false_negatives: Sequence[MatchProtocol],
-    cross_triggers: Sequence[MatchProtocol],
-    audio_loader: Optional[AudioLoader] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    n_examples: int = 5,
-    duration: float = 0.1,
-    fig: Optional[Figure] = None,
-):
-    if fig is None:
-        fig = plt.figure(figsize=(20, 20))
-
-    axes = fig.subplots(
-        nrows=4,
-        ncols=n_examples,
-        sharex="none",
-        sharey="row",
-    )
-
-    for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]):
-        try:
-            plot_true_positive_match(
-                tp_match,
-                ax=tp_ax,
-                audio_loader=audio_loader,
-                preprocessor=preprocessor,
-                duration=duration,
-            )
-        except (
-            ValueError,
-            AssertionError,
-            RuntimeError,
-            FileNotFoundError,
-        ):
-            continue
-
-    for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]):
-        try:
-            plot_false_positive_match(
-                fp_match,
-                ax=fp_ax,
-                audio_loader=audio_loader,
-                preprocessor=preprocessor,
-                duration=duration,
-            )
-        except (
-            ValueError,
-            AssertionError,
-            RuntimeError,
-            FileNotFoundError,
-        ):
-            continue
-
-    for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]):
-        try:
-            plot_false_negative_match(
-                fn_match,
-                ax=fn_ax,
-                audio_loader=audio_loader,
-                preprocessor=preprocessor,
-                duration=duration,
-            )
-        except (
-            ValueError,
-            AssertionError,
-            RuntimeError,
-            FileNotFoundError,
-        ):
-            continue
-
-    for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]):
-        try:
-            plot_cross_trigger_match(
-                ct_match,
-                ax=ct_ax,
-                audio_loader=audio_loader,
-                preprocessor=preprocessor,
-                duration=duration,
-            )
-        except (
-            ValueError,
-            AssertionError,
-            RuntimeError,
-            FileNotFoundError,
-        ):
-            continue
-
-    fig.tight_layout()
-
-    return fig
@@ -1,117 +1,26 @@
 """Plot heatmaps"""
 
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple
 
-import numpy as np
-import torch
-from matplotlib import axes, patches
-from matplotlib.cm import get_cmap
-from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba
+import xarray as xr
+from matplotlib import axes
 
 from batdetect2.plotting.common import create_ax
 
 
-def plot_detection_heatmap(
-    heatmap: Union[torch.Tensor, np.ndarray],
+def plot_heatmap(
+    heatmap: xr.DataArray,
     ax: Optional[axes.Axes] = None,
     figsize: Tuple[int, int] = (10, 10),
-    threshold: Optional[float] = None,
-    alpha: float = 1,
-    cmap: Union[str, Colormap] = "jet",
-    color: Optional[str] = None,
 ) -> axes.Axes:
     ax = create_ax(ax, figsize=figsize)
 
-    if isinstance(heatmap, torch.Tensor):
-        heatmap = heatmap.numpy()
-
-    heatmap = heatmap.squeeze()
-
-    if threshold is not None:
-        heatmap = np.ma.masked_where(
-            heatmap < threshold,
-            heatmap,
-        )
-
-    if color is not None:
-        cmap = create_colormap(color)
-
     ax.pcolormesh(
+        heatmap.time,
+        heatmap.frequency,
         heatmap,
         vmax=1,
         vmin=0,
-        cmap=cmap,
-        alpha=alpha,
     )
 
     return ax
-
-
-def plot_classification_heatmap(
-    heatmap: Union[torch.Tensor, np.ndarray],
-    ax: Optional[axes.Axes] = None,
-    figsize: Tuple[int, int] = (10, 10),
-    class_names: Optional[List[str]] = None,
-    threshold: Optional[float] = 0.1,
-    alpha: float = 1,
-    cmap: Union[str, Colormap] = "tab20",
-):
-    ax = create_ax(ax, figsize=figsize)
-
-    if isinstance(heatmap, torch.Tensor):
-        heatmap = heatmap.numpy()
-
-    if heatmap.ndim == 4:
-        heatmap = heatmap[0]
-
-    if heatmap.ndim != 3:
-        raise ValueError("Expecting a 3-dimensional array")
-
-    num_classes = heatmap.shape[0]
-
-    if class_names is None:
-        class_names = [f"class_{i}" for i in range(num_classes)]
-
-    if len(class_names) != num_classes:
-        raise ValueError("Inconsistent number of class names")
-
-    if not isinstance(cmap, Colormap):
-        cmap = get_cmap(cmap)
-
-    handles = []
-
-    for index, class_heatmap in enumerate(heatmap):
-        class_name = class_names[index]
-
-        color = cmap(index / num_classes)
-
-        max = class_heatmap.max()
-
-        if max == 0:
-            continue
-
-        if threshold is not None:
-            class_heatmap = np.ma.masked_where(
-                class_heatmap < threshold,
-                class_heatmap,
-            )
-
-        ax.pcolormesh(
-            class_heatmap,
-            vmax=1,
-            vmin=0,
-            cmap=create_colormap(color),  # type: ignore
-            alpha=alpha,
-        )
-
-        handles.append(patches.Patch(color=color, label=class_name))
-
-    ax.legend(handles=handles)
-    return ax
-
-
-def create_colormap(color: str) -> Colormap:
-    (r, g, b, a) = to_rgba(color)
-    return LinearSegmentedColormap.from_list(
-        "cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
-    )
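The removed `create_colormap` helper built a transparent-to-color ramp so that low heatmap values vanish when drawn over a spectrogram; the same idea as a standalone sketch:

```python
from matplotlib.colors import LinearSegmentedColormap, to_rgba


def transparent_colormap(color: str) -> LinearSegmentedColormap:
    # Ramp from fully transparent to the target color: zero-valued cells
    # disappear, and only confident regions are overlaid.
    r, g, b, a = to_rgba(color)
    return LinearSegmentedColormap.from_list(
        "cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
    )


cmap = transparent_colormap("red")
```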
@@ -1,17 +1,21 @@
-from typing import Optional, Protocol, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
+import matplotlib.pyplot as plt
 from matplotlib.axes import Axes
 from soundevent import data, plot
 from soundevent.geometry import compute_bounds
+from soundevent.plot.tags import TagColorMapper
 
+from batdetect2.evaluate.types import MatchEvaluation
+from batdetect2.plotting.clip_predictions import plot_prediction
 from batdetect2.plotting.clips import plot_clip
-from batdetect2.typing import (
-    AudioLoader,
+from batdetect2.preprocess import (
     PreprocessorProtocol,
-    RawPrediction,
+    get_default_preprocessor,
 )
 
 __all__ = [
+    "plot_matches",
     "plot_false_positive_match",
     "plot_true_positive_match",
     "plot_false_negative_match",
@@ -19,14 +23,6 @@ __all__ = [
 ]
 
 
-class MatchProtocol(Protocol):
-    clip: data.Clip
-    gt: Optional[data.SoundEventAnnotation]
-    pred: Optional[RawPrediction]
-    score: float
-    true_class: Optional[str]
-
-
 DEFAULT_DURATION = 0.05
 DEFAULT_FALSE_POSITIVE_COLOR = "orange"
 DEFAULT_FALSE_NEGATIVE_COLOR = "red"
@@ -36,191 +32,278 @@ DEFAULT_ANNOTATION_LINE_STYLE = "-"
 DEFAULT_PREDICTION_LINE_STYLE = "--"
 
 
-def plot_false_positive_match(
-    match: MatchProtocol,
-    audio_loader: Optional[AudioLoader] = None,
+def plot_matches(
+    matches: List[data.Match],
+    clip: data.Clip,
     preprocessor: Optional[PreprocessorProtocol] = None,
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
-    duration: float = DEFAULT_DURATION,
-    use_score: bool = True,
-    add_spectrogram: bool = True,
-    add_text: bool = True,
+    color_mapper: Optional[TagColorMapper] = None,
+    add_colorbar: bool = False,
+    add_labels: bool = False,
     add_points: bool = False,
-    add_title: bool = True,
     fill: bool = False,
     spec_cmap: str = "gray",
-    color: str = DEFAULT_FALSE_POSITIVE_COLOR,
-    fontsize: Union[float, str] = "small",
+    false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
+    false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
+    true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
+    annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
+    prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
 ) -> Axes:
-    assert match.pred is not None
+    if preprocessor is None:
+        preprocessor = get_default_preprocessor()
 
-    start_time, _, _, high_freq = compute_bounds(match.pred.geometry)
-
-    clip = data.Clip(
-        start_time=max(
-            start_time - duration / 2,
-            0,
-        ),
-        end_time=min(
-            start_time + duration / 2,
-            match.clip.recording.duration,
-        ),
-        recording=match.clip.recording,
-    )
-
-    if add_spectrogram:
-        ax = plot_clip(
-            clip,
-            audio_loader=audio_loader,
-            preprocessor=preprocessor,
-            figsize=figsize,
-            ax=ax,
-            audio_dir=audio_dir,
-            spec_cmap=spec_cmap,
-        )
-
-    ax = plot.plot_geometry(
-        match.pred.geometry,
+    ax = plot_clip(
+        clip,
         ax=ax,
-        add_points=add_points,
-        facecolor="none" if not fill else None,
-        alpha=match.score if use_score else 1,
-        color=color,
+        figsize=figsize,
+        audio_dir=audio_dir,
+        add_colorbar=add_colorbar,
+        add_labels=add_labels,
+        spec_cmap=spec_cmap,
     )
 
-    if add_text:
-        ax.text(
-            start_time,
-            high_freq,
-            f"score={match.score:.2f}",
-            va="top",
-            ha="right",
-            color=color,
-            fontsize=fontsize,
-        )
+    if color_mapper is None:
+        color_mapper = TagColorMapper()
 
-    if add_title:
-        ax.set_title("False Positive")
+    for match in matches:
+        if match.source is None and match.target is not None:
+            plot.plot_annotation(
+                annotation=match.target,
+                ax=ax,
+                time_offset=0.004,
+                freq_offset=2_000,
+                add_points=add_points,
+                facecolor="none" if not fill else None,
+                color=false_negative_color,
+                color_mapper=color_mapper,
+                linestyle=annotation_linestyle,
+            )
+        elif match.target is None and match.source is not None:
+            plot_prediction(
+                prediction=match.source,
+                ax=ax,
+                time_offset=0.004,
+                freq_offset=2_000,
+                add_points=add_points,
+                facecolor="none" if not fill else None,
+                color=false_positive_color,
+                color_mapper=color_mapper,
+                linestyle=prediction_linestyle,
+            )
+        elif match.target is not None and match.source is not None:
+            plot.plot_annotation(
+                annotation=match.target,
+                ax=ax,
+                time_offset=0.004,
+                freq_offset=2_000,
+                add_points=add_points,
+                facecolor="none" if not fill else None,
+                color=true_positive_color,
+                color_mapper=color_mapper,
+                linestyle=annotation_linestyle,
+            )
+            plot_prediction(
+                prediction=match.source,
+                ax=ax,
+                time_offset=0.004,
+                freq_offset=2_000,
+                add_points=add_points,
+                facecolor="none" if not fill else None,
+                color=true_positive_color,
+                color_mapper=color_mapper,
+                linestyle=prediction_linestyle,
+            )
+        else:
+            continue
 
     return ax
 
 
-def plot_false_negative_match(
-    match: MatchProtocol,
-    audio_loader: Optional[AudioLoader] = None,
+def plot_false_positive_match(
+    match: MatchEvaluation,
     preprocessor: Optional[PreprocessorProtocol] = None,
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    add_spectrogram: bool = True,
+    add_colorbar: bool = False,
+    add_labels: bool = False,
     add_points: bool = False,
-    add_title: bool = True,
     fill: bool = False,
     spec_cmap: str = "gray",
-    color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
+    time_offset: float = 0,
+    color: str = DEFAULT_FALSE_POSITIVE_COLOR,
+    fontsize: Union[float, str] = "small",
 ) -> Axes:
-    assert match.gt is not None
-
-    geometry = match.gt.sound_event.geometry
+    assert match.match.source is not None
+    assert match.match.target is None
+    sound_event = match.match.source.sound_event
+    geometry = sound_event.geometry
     assert geometry is not None
 
-    start_time = compute_bounds(geometry)[0]
+    start_time, _, _, high_freq = compute_bounds(geometry)
 
     clip = data.Clip(
-        start_time=max(
-            start_time - duration / 2,
-            0,
-        ),
+        start_time=max(start_time - duration / 2, 0),
         end_time=min(
             start_time + duration / 2,
-            match.clip.recording.duration,
+            sound_event.recording.duration,
         ),
-        recording=match.clip.recording,
+        recording=sound_event.recording,
     )
 
-    if add_spectrogram:
-        ax = plot_clip(
-            clip,
-            audio_loader=audio_loader,
-            preprocessor=preprocessor,
-            figsize=figsize,
-            ax=ax,
-            audio_dir=audio_dir,
-            spec_cmap=spec_cmap,
-        )
-
-    ax = plot.plot_geometry(
-        geometry,
+    ax = plot_clip(
+        clip,
+        preprocessor=preprocessor,
+        figsize=figsize,
         ax=ax,
+        audio_dir=audio_dir,
+        add_colorbar=add_colorbar,
+        add_labels=add_labels,
+        spec_cmap=spec_cmap,
+    )
+
+    plot_prediction(
+        match.match.source,
+        ax=ax,
+        time_offset=time_offset,
+        freq_offset=2_000,
         add_points=add_points,
         facecolor="none" if not fill else None,
         alpha=1,
         color=color,
     )
 
-    if add_title:
-        ax.set_title("False Negative")
+    plt.text(
+        start_time,
+        high_freq,
+        f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
+        va="top",
+        ha="right",
+        color=color,
+        fontsize=fontsize,
+    )
+
+    return ax
+
+
+def plot_false_negative_match(
+    match: MatchEvaluation,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    figsize: Optional[Tuple[int, int]] = None,
+    ax: Optional[Axes] = None,
+    audio_dir: Optional[data.PathLike] = None,
+    duration: float = DEFAULT_DURATION,
+    add_colorbar: bool = False,
+    add_labels: bool = False,
+    add_points: bool = False,
+    fill: bool = False,
+    spec_cmap: str = "gray",
+    color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
+    fontsize: Union[float, str] = "small",
+) -> Axes:
+    assert match.match.source is None
+    assert match.match.target is not None
+    sound_event = match.match.target.sound_event
+    geometry = sound_event.geometry
+    assert geometry is not None
+
+    start_time, _, _, high_freq = compute_bounds(geometry)
+
+    clip = data.Clip(
+        start_time=max(start_time - duration / 2, 0),
+        end_time=min(
+            start_time + duration / 2, sound_event.recording.duration
+        ),
+        recording=sound_event.recording,
+    )
+
+    ax = plot_clip(
+        clip,
+        preprocessor=preprocessor,
+        figsize=figsize,
+        ax=ax,
+        audio_dir=audio_dir,
+        add_colorbar=add_colorbar,
+        add_labels=add_labels,
+        spec_cmap=spec_cmap,
+    )
+
+    plot.plot_annotation(
+        match.match.target,
+        ax=ax,
+        time_offset=0.001,
+        freq_offset=2_000,
+        add_points=add_points,
+        facecolor="none" if not fill else None,
+        alpha=1,
+        color=color,
+    )
+
+    plt.text(
+        start_time,
+        high_freq,
+        f"False Negative \nClass: {match.gt_class} ",
+        va="top",
+        ha="right",
+        color=color,
+        fontsize=fontsize,
+    )
 
     return ax
 
 
 def plot_true_positive_match(
-    match: MatchProtocol,
+    match: MatchEvaluation,
     preprocessor: Optional[PreprocessorProtocol] = None,
-    audio_loader: Optional[AudioLoader] = None,
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    use_score: bool = True,
-    add_spectrogram: bool = True,
+    add_colorbar: bool = False,
+    add_labels: bool = False,
     add_points: bool = False,
-    add_text: bool = True,
     fill: bool = False,
     spec_cmap: str = "gray",
     color: str = DEFAULT_TRUE_POSITIVE_COLOR,
     fontsize: Union[float, str] = "small",
     annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
     prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
-    add_title: bool = True,
 ) -> Axes:
-    assert match.gt is not None
-    assert match.pred is not None
-
-    geometry = match.gt.sound_event.geometry
+    assert match.match.source is not None
+    assert match.match.target is not None
+    sound_event = match.match.target.sound_event
+    geometry = sound_event.geometry
     assert geometry is not None
 
     start_time, _, _, high_freq = compute_bounds(geometry)
 
     clip = data.Clip(
-        start_time=max(
-            start_time - duration / 2,
-            0,
-        ),
+        start_time=max(start_time - duration / 2, 0),
         end_time=min(
-            start_time + duration / 2,
-            match.clip.recording.duration,
+            start_time + duration / 2, sound_event.recording.duration
         ),
-        recording=match.clip.recording,
+        recording=sound_event.recording,
     )
 
-    if add_spectrogram:
-        ax = plot_clip(
-            clip,
-            ax=ax,
-            audio_loader=audio_loader,
-            preprocessor=preprocessor,
-            figsize=figsize,
-            audio_dir=audio_dir,
-            spec_cmap=spec_cmap,
-        )
-
-    ax = plot.plot_geometry(
-        geometry,
+    ax = plot_clip(
+        clip,
+        preprocessor=preprocessor,
+        figsize=figsize,
         ax=ax,
+        audio_dir=audio_dir,
+        add_colorbar=add_colorbar,
+        add_labels=add_labels,
+        spec_cmap=spec_cmap,
+    )
+
+    plot.plot_annotation(
+        match.match.target,
+        ax=ax,
+        time_offset=0.001,
+        freq_offset=2_000,
         add_points=add_points,
         facecolor="none" if not fill else None,
         alpha=1,
@@ -228,46 +311,41 @@ def plot_true_positive_match(
         linestyle=annotation_linestyle,
     )
 
-    plot.plot_geometry(
-        match.pred.geometry,
+    plot_prediction(
+        match.match.source,
         ax=ax,
+        time_offset=0.001,
+        freq_offset=2_000,
         add_points=add_points,
         facecolor="none" if not fill else None,
-        alpha=match.score if use_score else 1,
+        alpha=1,
         color=color,
         linestyle=prediction_linestyle,
     )
 
-    if add_text:
-        ax.text(
-            start_time,
-            high_freq,
-            f"score={match.score:.2f}",
-            va="top",
-            ha="right",
-            color=color,
-            fontsize=fontsize,
-        )
-
-    if add_title:
-        ax.set_title("True Positive")
+    plt.text(
+        start_time,
+        high_freq,
+        f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
+        va="top",
+        ha="right",
+        color=color,
+        fontsize=fontsize,
+    )
 
     return ax
 
 
 def plot_cross_trigger_match(
-    match: MatchProtocol,
+    match: MatchEvaluation,
     preprocessor: Optional[PreprocessorProtocol] = None,
-    audio_loader: Optional[AudioLoader] = None,
     figsize: Optional[Tuple[int, int]] = None,
     ax: Optional[Axes] = None,
     audio_dir: Optional[data.PathLike] = None,
     duration: float = DEFAULT_DURATION,
-    use_score: bool = True,
-    add_spectrogram: bool = True,
+    add_colorbar: bool = False,
+    add_labels: bool = False,
     add_points: bool = False,
-    add_text: bool = True,
-    add_title: bool = True,
     fill: bool = False,
     spec_cmap: str = "gray",
     color: str = DEFAULT_CROSS_TRIGGER_COLOR,
@@ -275,40 +353,38 @@ def plot_cross_trigger_match(
     annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
     prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
 ) -> Axes:
-    assert match.gt is not None
-    assert match.pred is not None
-    geometry = match.gt.sound_event.geometry
+    assert match.match.source is not None
+    assert match.match.target is not None
+    sound_event = match.match.source.sound_event
+    geometry = sound_event.geometry
     assert geometry is not None
 
     start_time, _, _, high_freq = compute_bounds(geometry)
 
     clip = data.Clip(
-        start_time=max(
-            start_time - duration / 2,
-            0,
-        ),
+        start_time=max(start_time - duration / 2, 0),
         end_time=min(
-            start_time + duration / 2,
-            match.clip.recording.duration,
+            start_time + duration / 2, sound_event.recording.duration
         ),
-        recording=match.clip.recording,
+        recording=sound_event.recording,
     )
 
-    if add_spectrogram:
-        ax = plot_clip(
-            clip,
-            audio_loader=audio_loader,
-            preprocessor=preprocessor,
-            figsize=figsize,
-            ax=ax,
-            audio_dir=audio_dir,
-            spec_cmap=spec_cmap,
-        )
-
-    ax = plot.plot_geometry(
-        geometry,
+    ax = plot_clip(
+        clip,
+        preprocessor=preprocessor,
+        figsize=figsize,
         ax=ax,
+        audio_dir=audio_dir,
+        add_colorbar=add_colorbar,
+        add_labels=add_labels,
+        spec_cmap=spec_cmap,
+    )
+
+    plot.plot_annotation(
+        match.match.target,
+        ax=ax,
+        time_offset=0.001,
+        freq_offset=2_000,
         add_points=add_points,
         facecolor="none" if not fill else None,
         alpha=1,
@ -316,29 +392,26 @@ def plot_cross_trigger_match(
|
|||||||
linestyle=annotation_linestyle,
|
linestyle=annotation_linestyle,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot.plot_geometry(
|
plot_prediction(
|
||||||
match.pred.geometry,
|
match.match.source,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
|
time_offset=0.001,
|
||||||
|
freq_offset=2_000,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=match.score if use_score else 1,
|
alpha=1,
|
||||||
color=color,
|
color=color,
|
||||||
linestyle=prediction_linestyle,
|
linestyle=prediction_linestyle,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_text:
|
plt.text(
|
||||||
ax.text(
|
start_time,
|
||||||
start_time,
|
high_freq,
|
||||||
high_freq,
|
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||||
f"score={match.score:.2f}\nclass={match.true_class}",
|
va="top",
|
||||||
va="top",
|
ha="right",
|
||||||
ha="right",
|
color=color,
|
||||||
color=color,
|
fontsize=fontsize,
|
||||||
fontsize=fontsize,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
if add_title:
|
|
||||||
ax.set_title("Cross Trigger")
|
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
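For orientation (not part of the diff): a minimal sketch of how the reworked `plot_cross_trigger_match` might be called after this change. The `matches` list, the `preprocessor` object, and the `gt_class`/`pred_class` attribute names on `MatchEvaluation` are assumptions inferred from the signature and the text template above, not verified API:

```python
import matplotlib.pyplot as plt

# Assumed: `matches` is a list of MatchEvaluation objects from an evaluation
# run, and `preprocessor` satisfies PreprocessorProtocol.
cross_triggers = [m for m in matches if m.gt_class != m.pred_class]

for match in cross_triggers[:4]:
    ax = plot_cross_trigger_match(
        match,
        preprocessor=preprocessor,
        audio_dir="example_data/audio",  # hypothetical path
        duration=0.5,  # seconds of audio context around the event
    )
    plt.show()
```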
@@ -1,286 +0,0 @@
-from typing import Dict, Optional, Tuple
-
-import numpy as np
-import seaborn as sns
-from cycler import cycler
-from matplotlib import axes
-
-from batdetect2.plotting.common import create_ax
-
-
-def set_default_styler(ax: axes.Axes) -> axes.Axes:
-    color_cycler = cycler(color=sns.color_palette("muted"))
-    style_cycler = cycler(linestyle=["-", "--", ":"]) * cycler(
-        marker=["o", "s", "^"]
-    )
-    custom_cycler = color_cycler * len(style_cycler) + style_cycler * len(
-        color_cycler
-    )
-
-    ax.set_prop_cycle(custom_cycler)
-    return ax
-
-
-def set_default_style(ax: axes.Axes) -> axes.Axes:
-    ax = set_default_styler(ax)
-    ax.spines.right.set_visible(False)
-    ax.spines.top.set_visible(False)
-    return ax
-
-
-def plot_pr_curve(
-    precision: np.ndarray,
-    recall: np.ndarray,
-    thresholds: np.ndarray,
-    ax: Optional[axes.Axes] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    add_labels: bool = True,
-    add_legend: bool = False,
-    label: str = "PR Curve",
-) -> axes.Axes:
-    ax = create_ax(ax=ax, figsize=figsize)
-
-    ax = set_default_style(ax)
-
-    ax.plot(
-        recall,
-        precision,
-        label=label,
-        marker="o",
-        markevery=_get_marker_positions(thresholds),
-    )
-
-    ax.set_xlim(0, 1.05)
-    ax.set_ylim(0, 1.05)
-
-    if add_legend:
-        ax.legend()
-
-    if add_labels:
-        ax.set_xlabel("Recall")
-        ax.set_ylabel("Precision")
-
-    return ax
-
-
-def plot_pr_curves(
-    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
-    ax: Optional[axes.Axes] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    add_legend: bool = True,
-    add_labels: bool = True,
-) -> axes.Axes:
-    ax = create_ax(ax=ax, figsize=figsize)
-    ax = set_default_style(ax)
-
-    for name, (precision, recall, thresholds) in data.items():
-        ax.plot(
-            recall,
-            precision,
-            label=name,
-            markevery=_get_marker_positions(thresholds),
-        )
-
-    ax.set_xlim(0, 1.05)
-    ax.set_ylim(0, 1.05)
-
-    if add_labels:
-        ax.set_xlabel("Recall")
-        ax.set_ylabel("Precision")
-
-    if add_legend:
-        ax.legend(
-            bbox_to_anchor=(1.05, 1),
-            loc="upper left",
-            borderaxespad=0.0,
-        )
-    return ax
-
-
-def plot_threshold_precision_curve(
-    threshold: np.ndarray,
-    precision: np.ndarray,
-    ax: Optional[axes.Axes] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    add_labels: bool = True,
-):
-    ax = create_ax(ax=ax, figsize=figsize)
-
-    ax = set_default_style(ax)
-
-    ax.plot(threshold, precision, markevery=_get_marker_positions(threshold))
-
-    ax.set_xlim(0, 1.05)
-    ax.set_ylim(0, 1.05)
-
-    if add_labels:
-        ax.set_xlabel("Threshold")
-        ax.set_ylabel("Precision")
-
-    return ax
-
-
-def plot_threshold_precision_curves(
-    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
-    ax: Optional[axes.Axes] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    add_legend: bool = True,
-    add_labels: bool = True,
-):
-    ax = create_ax(ax=ax, figsize=figsize)
-    ax = set_default_style(ax)
-
-    for name, (precision, _, thresholds) in data.items():
-        ax.plot(
-            thresholds,
-            precision,
-            label=name,
-            markevery=_get_marker_positions(thresholds),
-        )
-
-    if add_legend:
-        ax.legend(
-            bbox_to_anchor=(1.05, 1),
-            loc="upper left",
-            borderaxespad=0.0,
-        )
-
-    ax.set_xlim(0, 1.05)
-    ax.set_ylim(0, 1.05)
-
-    if add_labels:
-        ax.set_xlabel("Threshold")
-        ax.set_ylabel("Precision")
-
-    return ax
-
-
-def plot_threshold_recall_curve(
-    threshold: np.ndarray,
-    recall: np.ndarray,
-    ax: Optional[axes.Axes] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    add_labels: bool = True,
-):
-    ax = create_ax(ax=ax, figsize=figsize)
-
-    ax = set_default_style(ax)
-
-    ax.plot(threshold, recall, markevery=_get_marker_positions(threshold))
-
-    ax.set_xlim(0, 1.05)
-    ax.set_ylim(0, 1.05)
-
-    if add_labels:
-        ax.set_xlabel("Threshold")
-        ax.set_ylabel("Recall")
-
-    return ax
-
-
-def plot_threshold_recall_curves(
-    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
-    ax: Optional[axes.Axes] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    add_legend: bool = True,
-    add_labels: bool = True,
-):
-    ax = create_ax(ax=ax, figsize=figsize)
-    ax = set_default_style(ax)
-
-    for name, (_, recall, thresholds) in data.items():
-        ax.plot(
-            thresholds,
-            recall,
-            label=name,
-            markevery=_get_marker_positions(thresholds),
-        )
-
-    if add_legend:
-        ax.legend(
-            bbox_to_anchor=(1.05, 1),
-            loc="upper left",
-            borderaxespad=0.0,
-        )
-
-    ax.set_xlim(0, 1.05)
-    ax.set_ylim(0, 1.05)
-
-    if add_labels:
-        ax.set_xlabel("Threshold")
-        ax.set_ylabel("Recall")
-
-    return ax
-
-
-def plot_roc_curve(
-    fpr: np.ndarray,
-    tpr: np.ndarray,
-    thresholds: np.ndarray,
-    ax: Optional[axes.Axes] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    add_labels: bool = True,
-) -> axes.Axes:
-    ax = create_ax(ax=ax, figsize=figsize)
-
-    ax = set_default_style(ax)
-
-    ax.plot(
-        fpr,
-        tpr,
-        markevery=_get_marker_positions(thresholds),
-    )
-
-    ax.set_xlim(0, 1.05)
-    ax.set_ylim(0, 1.05)
-
-    if add_labels:
-        ax.set_xlabel("False Positive Rate")
-        ax.set_ylabel("True Positive Rate")
-
-    return ax
-
-
-def plot_roc_curves(
-    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
-    ax: Optional[axes.Axes] = None,
-    figsize: Optional[Tuple[int, int]] = None,
-    add_legend: bool = True,
-    add_labels: bool = True,
-) -> axes.Axes:
-    ax = create_ax(ax=ax, figsize=figsize)
-    ax = set_default_style(ax)
-
-    for name, (fpr, tpr, thresholds) in data.items():
-        ax.plot(
-            fpr,
-            tpr,
-            label=name,
-            markevery=_get_marker_positions(thresholds),
-        )
-
-    if add_legend:
-        ax.legend(
-            bbox_to_anchor=(1.05, 1),
-            loc="upper left",
-            borderaxespad=0.0,
-        )
-
-    ax.set_xlim(0, 1.05)
-    ax.set_ylim(0, 1.05)
-
-    if add_labels:
-        ax.set_xlabel("False Positive Rate")
-        ax.set_ylabel("True Positive Rate")
-
-    return ax
-
-
-def _get_marker_positions(
-    thresholds: np.ndarray,
-    n_points: int = 11,
-) -> np.ndarray:
-    size = len(thresholds)
-    cut_points = np.linspace(0, 1, n_points)
-    indices = np.searchsorted(thresholds[::-1], cut_points)
-    return np.clip(size - indices, 0, size - 1)  # type: ignore
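The only non-obvious piece of the removed module is `_get_marker_positions`, which maps evenly spaced score cut points (0.0, 0.1, ..., 1.0) to indices along a descending threshold array, so that curve markers land at comparable thresholds across curves. A self-contained sketch of that same arithmetic on synthetic input:

```python
import numpy as np

def get_marker_positions(thresholds: np.ndarray, n_points: int = 11) -> np.ndarray:
    # Same logic as the deleted helper: thresholds are assumed sorted in
    # descending order, so reverse them for searchsorted, then map the
    # insertion points back onto the original (descending) indexing.
    size = len(thresholds)
    cut_points = np.linspace(0, 1, n_points)
    indices = np.searchsorted(thresholds[::-1], cut_points)
    return np.clip(size - indices, 0, size - 1)

thresholds = np.linspace(1.0, 0.0, 50)  # synthetic descending scores
print(get_marker_positions(thresholds))  # one marker index per 0.1 step
```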
@@ -1,25 +1,598 @@
-"""Main entry point for the BatDetect2 Postprocessing pipeline."""
+"""Main entry point for the BatDetect2 Postprocessing pipeline.
+
+This package (`batdetect2.postprocess`) takes the raw outputs from a trained
+BatDetect2 neural network model and transforms them into meaningful, structured
+predictions, typically in the form of `soundevent.data.ClipPrediction` objects
+containing detected sound events with associated class tags and geometry.
+
+The pipeline involves several configurable steps, implemented in submodules:
+1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
+2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
+   model output arrays.
+3. Detection Extraction (`.detection`): Identifies candidate detection points
+   (location and score) based on thresholds and score ranking (top-k).
+4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
+   class probabilities, features) at the detected locations.
+5. Decoding & Formatting (`.decoding`): Converts extracted numerical data and
+   class predictions into interpretable `soundevent` objects, including
+   recovering geometry (ROIs) and decoding class names back to standard tags.
+
+This module provides the primary interface:
+- `PostprocessConfig`: A configuration object for postprocessing parameters
+  (thresholds, NMS kernel size, etc.).
+- `load_postprocess_config`: Function to load the configuration from a file.
+- `Postprocessor`: The main class (implementing `PostprocessorProtocol`) that
+  holds the configured pipeline logic.
+- `build_postprocessor`: A factory function to create a `Postprocessor`
+  instance, linking it to the necessary target definitions (`TargetProtocol`).
+
+It also re-exports key components from submodules for convenience.
+"""

-from batdetect2.postprocess.config import (
-    PostprocessConfig,
-    load_postprocess_config,
-)
+from typing import List, Optional
+
+import xarray as xr
+from loguru import logger
+from pydantic import Field
+from soundevent import data
+
+from batdetect2.configs import BaseConfig, load_config
+from batdetect2.models.types import ModelOutput
 from batdetect2.postprocess.decoding import (
+    DEFAULT_CLASSIFICATION_THRESHOLD,
+    convert_raw_prediction_to_sound_event_prediction,
     convert_raw_predictions_to_clip_prediction,
-    to_raw_predictions,
+    convert_xr_dataset_to_raw_prediction,
 )
-from batdetect2.postprocess.nms import non_max_suppression
-from batdetect2.postprocess.postprocessor import (
-    Postprocessor,
-    build_postprocessor,
+from batdetect2.postprocess.detection import (
+    DEFAULT_DETECTION_THRESHOLD,
+    TOP_K_PER_SEC,
+    extract_detections_from_array,
+    get_max_detections,
 )
+from batdetect2.postprocess.extraction import (
+    extract_detection_xr_dataset,
+)
+from batdetect2.postprocess.nms import (
+    NMS_KERNEL_SIZE,
+    non_max_suppression,
+)
+from batdetect2.postprocess.remapping import (
+    classification_to_xarray,
+    detection_to_xarray,
+    features_to_xarray,
+    sizes_to_xarray,
+)
+from batdetect2.postprocess.types import (
+    BatDetect2Prediction,
+    PostprocessorProtocol,
+    RawPrediction,
+)
+from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
+from batdetect2.targets.types import TargetProtocol

 __all__ = [
+    "DEFAULT_CLASSIFICATION_THRESHOLD",
+    "DEFAULT_DETECTION_THRESHOLD",
+    "MAX_FREQ",
+    "MIN_FREQ",
+    "ModelOutput",
+    "NMS_KERNEL_SIZE",
     "PostprocessConfig",
     "Postprocessor",
+    "PostprocessorProtocol",
+    "RawPrediction",
+    "TOP_K_PER_SEC",
     "build_postprocessor",
+    "classification_to_xarray",
     "convert_raw_predictions_to_clip_prediction",
-    "to_raw_predictions",
+    "convert_xr_dataset_to_raw_prediction",
+    "detection_to_xarray",
+    "extract_detection_xr_dataset",
+    "extract_detections_from_array",
+    "features_to_xarray",
+    "get_max_detections",
     "load_postprocess_config",
     "non_max_suppression",
+    "sizes_to_xarray",
 ]
+
+
+class PostprocessConfig(BaseConfig):
+    """Configuration settings for the postprocessing pipeline.
+
+    Defines tunable parameters that control how raw model outputs are
+    converted into final detections.
+
+    Attributes
+    ----------
+    nms_kernel_size : int, default=NMS_KERNEL_SIZE
+        Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
+        Used to suppress weaker detections near stronger peaks. Must be
+        positive.
+    detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
+        Minimum confidence score from the detection heatmap required to
+        consider a point as a potential detection. Must be >= 0.
+    classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
+        Minimum confidence score for a specific class prediction to be included
+        in the decoded tags for a detection. Must be >= 0.
+    top_k_per_sec : int, default=TOP_K_PER_SEC
+        Desired maximum number of detections per second of audio. Used by
+        `get_max_detections` to calculate an absolute limit based on clip
+        duration before applying `extract_detections_from_array`. Must be
+        positive.
+    """
+
+    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
+    detection_threshold: float = Field(
+        default=DEFAULT_DETECTION_THRESHOLD,
+        ge=0,
+    )
+    classification_threshold: float = Field(
+        default=DEFAULT_CLASSIFICATION_THRESHOLD,
+        ge=0,
+    )
+    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
+
+
+def load_postprocess_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+) -> PostprocessConfig:
+    """Load the postprocessing configuration from a file.
+
+    Reads a configuration file (YAML) and validates it against the
+    `PostprocessConfig` schema, potentially extracting data from a nested
+    field.
+
+    Parameters
+    ----------
+    path : PathLike
+        Path to the configuration file.
+    field : str, optional
+        Dot-separated path to a nested section within the file containing the
+        postprocessing configuration (e.g., "inference.postprocessing").
+        If None, the entire file content is used.
+
+    Returns
+    -------
+    PostprocessConfig
+        The loaded and validated postprocessing configuration object.
+
+    Raises
+    ------
+    FileNotFoundError
+        If the config file path does not exist.
+    yaml.YAMLError
+        If the file content is not valid YAML.
+    pydantic.ValidationError
+        If the loaded configuration data does not conform to the
+        `PostprocessConfig` schema.
+    KeyError, TypeError
+        If `field` specifies an invalid path within the loaded data.
+    """
+    return load_config(path, schema=PostprocessConfig, field=field)
+
+
+def build_postprocessor(
+    targets: TargetProtocol,
+    config: Optional[PostprocessConfig] = None,
+    max_freq: float = MAX_FREQ,
+    min_freq: float = MIN_FREQ,
+) -> PostprocessorProtocol:
+    """Factory function to build the standard postprocessor.
+
+    Creates and initializes the `Postprocessor` instance, providing it with the
+    necessary `targets` object and the `PostprocessConfig`.
+
+    Parameters
+    ----------
+    targets : TargetProtocol
+        An initialized object conforming to the `TargetProtocol`, providing
+        methods like `.decode()` and `.recover_roi()`, and attributes like
+        `.class_names` and `.generic_class_tags`. This links postprocessing
+        to the defined target semantics and geometry mappings.
+    config : PostprocessConfig, optional
+        Configuration object specifying postprocessing parameters (thresholds,
+        NMS kernel size, etc.). If None, default settings defined in
+        `PostprocessConfig` will be used.
+    min_freq : int, default=MIN_FREQ
+        The minimum frequency (Hz) corresponding to the frequency axis of the
+        model outputs. Required for coordinate remapping. Consider setting via
+        `PostprocessConfig` instead for better encapsulation.
+    max_freq : int, default=MAX_FREQ
+        The maximum frequency (Hz) corresponding to the frequency axis of the
+        model outputs. Required for coordinate remapping. Consider setting via
+        `PostprocessConfig`.
+
+    Returns
+    -------
+    PostprocessorProtocol
+        An initialized `Postprocessor` instance ready to process model outputs.
+    """
+    config = config or PostprocessConfig()
+    logger.opt(lazy=True).debug(
+        "Building postprocessor with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
+    return Postprocessor(
+        targets=targets,
+        config=config,
+        min_freq=min_freq,
+        max_freq=max_freq,
+    )
+
+
+class Postprocessor(PostprocessorProtocol):
+    """Standard implementation of the postprocessing pipeline.
+
+    This class orchestrates the steps required to convert raw model outputs
+    into interpretable `soundevent` predictions. It uses configured parameters
+    and leverages functions from the `batdetect2.postprocess` submodules for
+    each stage (NMS, remapping, detection, extraction, decoding).
+
+    It requires a `TargetProtocol` object during initialization to access
+    necessary decoding information (class name to tag mapping,
+    ROI recovery logic) ensuring consistency with the target definitions used
+    during training or specified for inference.
+
+    Instances are typically created using the `build_postprocessor` factory
+    function.
+
+    Attributes
+    ----------
+    targets : TargetProtocol
+        The configured target definition object providing decoding and ROI
+        recovery.
+    config : PostprocessConfig
+        Configuration object holding parameters for NMS, thresholds, etc.
+    min_freq : float
+        Minimum frequency (Hz) assumed for the model output's frequency axis.
+    max_freq : float
+        Maximum frequency (Hz) assumed for the model output's frequency axis.
+    """
+
+    targets: TargetProtocol
+
+    def __init__(
+        self,
+        targets: TargetProtocol,
+        config: PostprocessConfig,
+        min_freq: float = MIN_FREQ,
+        max_freq: float = MAX_FREQ,
+    ):
+        """Initialize the Postprocessor.
+
+        Parameters
+        ----------
+        targets : TargetProtocol
+            Initialized target definition object.
+        config : PostprocessConfig
+            Configuration for postprocessing parameters.
+        min_freq : int, default=MIN_FREQ
+            Minimum frequency (Hz) for coordinate remapping.
+        max_freq : int, default=MAX_FREQ
+            Maximum frequency (Hz) for coordinate remapping.
+        """
+        self.targets = targets
+        self.config = config
+        self.min_freq = min_freq
+        self.max_freq = max_freq
+
+    def get_feature_arrays(
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
+    ) -> List[xr.DataArray]:
+        """Extract and remap raw feature tensors for a batch.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            Raw model output containing `output.features` tensor for the batch.
+        clips : List[data.Clip]
+            List of Clip objects corresponding to the batch items.
+
+        Returns
+        -------
+        List[xr.DataArray]
+            List of coordinate-aware feature DataArrays, one per clip.
+
+        Raises
+        ------
+        ValueError
+            If batch sizes of `output.features` and `clips` do not match.
+        """
+        if len(clips) != len(output.features):
+            raise ValueError(
+                "Number of clips and batch size of feature array"
+                "do not match. "
+                f"(clips: {len(clips)}, features: {len(output.features)})"
+            )
+
+        return [
+            features_to_xarray(
+                feats,
+                start_time=clip.start_time,
+                end_time=clip.end_time,
+                min_freq=self.min_freq,
+                max_freq=self.max_freq,
+            )
+            for feats, clip in zip(output.features, clips)
+        ]
+
+    def get_detection_arrays(
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
+    ) -> List[xr.DataArray]:
+        """Apply NMS and remap detection heatmaps for a batch.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            Raw model output containing `output.detection_probs` tensor for the
+            batch.
+        clips : List[data.Clip]
+            List of Clip objects corresponding to the batch items.
+
+        Returns
+        -------
+        List[xr.DataArray]
+            List of NMS-applied, coordinate-aware detection heatmaps, one per
+            clip.
+
+        Raises
+        ------
+        ValueError
+            If batch sizes of `output.detection_probs` and `clips` do not match.
+        """
+        detections = output.detection_probs
+
+        if len(clips) != len(output.detection_probs):
+            raise ValueError(
+                "Number of clips and batch size of detection array "
+                "do not match. "
+                f"(clips: {len(clips)}, detection: {len(detections)})"
+            )
+
+        detections = non_max_suppression(
+            detections,
+            kernel_size=self.config.nms_kernel_size,
+        )
+
+        return [
+            detection_to_xarray(
+                dets,
+                start_time=clip.start_time,
+                end_time=clip.end_time,
+                min_freq=self.min_freq,
+                max_freq=self.max_freq,
+            )
+            for dets, clip in zip(detections, clips)
+        ]
+
+    def get_classification_arrays(
+        self, output: ModelOutput, clips: List[data.Clip]
+    ) -> List[xr.DataArray]:
+        """Extract and remap raw classification tensors for a batch.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            Raw model output containing `output.class_probs` tensor for the
+            batch.
+        clips : List[data.Clip]
+            List of Clip objects corresponding to the batch items.
+
+        Returns
+        -------
+        List[xr.DataArray]
+            List of coordinate-aware class probability maps, one per clip.
+
+        Raises
+        ------
+        ValueError
+            If batch sizes of `output.class_probs` and `clips` do not match, or
+            if number of classes mismatches `self.targets.class_names`.
+        """
+        classifications = output.class_probs
+
+        if len(clips) != len(classifications):
+            raise ValueError(
+                "Number of clips and batch size of classification array "
+                "do not match. "
+                f"(clips: {len(clips)}, classification: {len(classifications)})"
+            )
+
+        return [
+            classification_to_xarray(
+                class_probs,
+                start_time=clip.start_time,
+                end_time=clip.end_time,
+                class_names=self.targets.class_names,
+                min_freq=self.min_freq,
+                max_freq=self.max_freq,
+            )
+            for class_probs, clip in zip(classifications, clips)
+        ]
+
+    def get_sizes_arrays(
+        self, output: ModelOutput, clips: List[data.Clip]
+    ) -> List[xr.DataArray]:
+        """Extract and remap raw size prediction tensors for a batch.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            Raw model output containing `output.size_preds` tensor for the
+            batch.
+        clips : List[data.Clip]
+            List of Clip objects corresponding to the batch items.
+
+        Returns
+        -------
+        List[xr.DataArray]
+            List of coordinate-aware size prediction maps, one per clip.
+
+        Raises
+        ------
+        ValueError
+            If batch sizes of `output.size_preds` and `clips` do not match.
+        """
+        sizes = output.size_preds
+
+        if len(clips) != len(sizes):
+            raise ValueError(
+                "Number of clips and batch size of sizes array do not match. "
+                f"(clips: {len(clips)}, sizes: {len(sizes)})"
+            )
+
+        return [
+            sizes_to_xarray(
+                size_preds,
+                start_time=clip.start_time,
+                end_time=clip.end_time,
+                min_freq=self.min_freq,
+                max_freq=self.max_freq,
+            )
+            for size_preds, clip in zip(sizes, clips)
+        ]
+
+    def get_detection_datasets(
+        self, output: ModelOutput, clips: List[data.Clip]
+    ) -> List[xr.Dataset]:
+        """Perform NMS, remapping, detection, and data extraction for a batch.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            Raw output from the neural network model for a batch.
+        clips : List[data.Clip]
+            List of `soundevent.data.Clip` objects corresponding to the batch.
+
+        Returns
+        -------
+        List[xr.Dataset]
+            List of xarray Datasets (one per clip). Each Dataset contains
+            aligned scores, dimensions, class probabilities, and features for
+            detections found in that clip.
+        """
+        detection_arrays = self.get_detection_arrays(output, clips)
+        classification_arrays = self.get_classification_arrays(output, clips)
+        size_arrays = self.get_sizes_arrays(output, clips)
+        features_arrays = self.get_feature_arrays(output, clips)
+
+        datasets = []
+        for det_array, class_array, sizes_array, feats_array in zip(
+            detection_arrays,
+            classification_arrays,
+            size_arrays,
+            features_arrays,
+        ):
+            max_detections = get_max_detections(
+                det_array,
+                top_k_per_sec=self.config.top_k_per_sec,
+            )
+
+            positions = extract_detections_from_array(
+                det_array,
+                max_detections=max_detections,
+                threshold=self.config.detection_threshold,
+            )
+
+            datasets.append(
+                extract_detection_xr_dataset(
+                    positions,
+                    sizes_array,
+                    class_array,
+                    feats_array,
+                )
+            )
+
+        return datasets
+
+    def get_raw_predictions(
+        self, output: ModelOutput, clips: List[data.Clip]
+    ) -> List[List[RawPrediction]]:
+        """Extract intermediate RawPrediction objects for a batch.
+
+        Processes raw model output through remapping, NMS, detection, data
+        extraction, and geometry recovery via the configured
+        `targets.recover_roi`.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            Raw output from the neural network model for a batch.
+        clips : List[data.Clip]
+            List of `soundevent.data.Clip` objects corresponding to the batch.
+
+        Returns
+        -------
+        List[List[RawPrediction]]
+            List of lists (one inner list per input clip). Each inner list
+            contains `RawPrediction` objects for detections in that clip.
+        """
+        detection_datasets = self.get_detection_datasets(output, clips)
+        return [
+            convert_xr_dataset_to_raw_prediction(
+                dataset,
+                self.targets.decode_roi,
+            )
+            for dataset in detection_datasets
+        ]
+
+    def get_sound_event_predictions(
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
+    ) -> List[List[BatDetect2Prediction]]:
+        raw_predictions = self.get_raw_predictions(output, clips)
+        return [
+            [
+                BatDetect2Prediction(
+                    raw=raw,
+                    sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
+                        raw,
+                        recording=clip.recording,
+                        targets=self.targets,
+                        classification_threshold=self.config.classification_threshold,
+                    ),
+                )
+                for raw in predictions
+            ]
+            for predictions, clip in zip(raw_predictions, clips)
+        ]
+
+    def get_predictions(
+        self, output: ModelOutput, clips: List[data.Clip]
+    ) -> List[data.ClipPrediction]:
+        """Perform the full postprocessing pipeline for a batch.
+
+        Takes raw model output and corresponding clips, applies the entire
+        configured chain (NMS, remapping, extraction, geometry recovery, class
+        decoding), producing final `soundevent.data.ClipPrediction` objects.
+
+        Parameters
+        ----------
+        output : ModelOutput
+            Raw output from the neural network model for a batch.
+        clips : List[data.Clip]
+            List of `soundevent.data.Clip` objects corresponding to the batch.
+
+        Returns
+        -------
+        List[data.ClipPrediction]
+            List containing one `ClipPrediction` object for each input clip,
+            populated with `SoundEventPrediction` objects.
+        """
+        raw_predictions = self.get_raw_predictions(output, clips)
+        return [
+            convert_raw_predictions_to_clip_prediction(
+                prediction,
+                clip,
+                targets=self.targets,
+                classification_threshold=self.config.classification_threshold,
+            )
+            for prediction, clip in zip(raw_predictions, clips)
+        ]
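For context, a minimal sketch of the call sequence this new `__init__.py` is built around. The `targets`, `model`, `spectrograms`, and `clips` names are illustrative assumptions; only the functions and classes exported above are taken from the diff:

```python
# Sketch only: wiring the components exported by this module.
postprocessor = build_postprocessor(
    targets,  # assumed: an initialized object satisfying TargetProtocol
    config=PostprocessConfig(
        nms_kernel_size=9,
        detection_threshold=0.01,
        top_k_per_sec=200,
    ),
)

output = model(spectrograms)  # assumed: a ModelOutput for a batch of clips
clip_predictions = postprocessor.get_predictions(output, clips)

for clip_prediction in clip_predictions:
    for prediction in clip_prediction.sound_events:
        print(prediction.score, [pt.tag for pt in prediction.tags])
```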
@@ -1,94 +0,0 @@
-from typing import Optional
-
-from pydantic import Field
-from soundevent import data
-
-from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD
-from batdetect2.postprocess.nms import NMS_KERNEL_SIZE
-
-__all__ = [
-    "PostprocessConfig",
-    "load_postprocess_config",
-]
-
-DEFAULT_DETECTION_THRESHOLD = 0.01
-
-TOP_K_PER_SEC = 100
-
-
-class PostprocessConfig(BaseConfig):
-    """Configuration settings for the postprocessing pipeline.
-
-    Defines tunable parameters that control how raw model outputs are
-    converted into final detections.
-
-    Attributes
-    ----------
-    nms_kernel_size : int, default=NMS_KERNEL_SIZE
-        Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
-        Used to suppress weaker detections near stronger peaks. Must be
-        positive.
-    detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
-        Minimum confidence score from the detection heatmap required to
-        consider a point as a potential detection. Must be >= 0.
-    classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
-        Minimum confidence score for a specific class prediction to be included
-        in the decoded tags for a detection. Must be >= 0.
-    top_k_per_sec : int, default=TOP_K_PER_SEC
-        Desired maximum number of detections per second of audio. Used by
-        `get_max_detections` to calculate an absolute limit based on clip
-        duration before applying `extract_detections_from_array`. Must be
-        positive.
-    """
-
-    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
-    detection_threshold: float = Field(
-        default=DEFAULT_DETECTION_THRESHOLD,
-        ge=0,
-    )
-    classification_threshold: float = Field(
-        default=DEFAULT_CLASSIFICATION_THRESHOLD,
-        ge=0,
-    )
-    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
-
-
-def load_postprocess_config(
-    path: data.PathLike,
-    field: Optional[str] = None,
-) -> PostprocessConfig:
-    """Load the postprocessing configuration from a file.
-
-    Reads a configuration file (YAML) and validates it against the
-    `PostprocessConfig` schema, potentially extracting data from a nested
-    field.
-
-    Parameters
-    ----------
-    path : PathLike
-        Path to the configuration file.
-    field : str, optional
-        Dot-separated path to a nested section within the file containing the
-        postprocessing configuration (e.g., "inference.postprocessing").
-        If None, the entire file content is used.
-
-    Returns
-    -------
-    PostprocessConfig
-        The loaded and validated postprocessing configuration object.
-
-    Raises
-    ------
-    FileNotFoundError
-        If the config file path does not exist.
-    yaml.YAMLError
-        If the file content is not valid YAML.
-    pydantic.ValidationError
-        If the loaded configuration data does not conform to the
-        `PostprocessConfig` schema.
-    KeyError, TypeError
-        If `field` specifies an invalid path within the loaded data.
-    """
-    return load_config(path, schema=PostprocessConfig, field=field)
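Both the deleted `config.py` above and its replacement rely on pydantic `Field` constraints (`gt=0`, `ge=0`) to reject invalid settings at validation time. A small sketch of the behaviour this buys, using plain pydantic (standard library semantics, not repo-specific code):

```python
from pydantic import BaseModel, Field, ValidationError

class Config(BaseModel):
    # Mirrors the constraints used by PostprocessConfig.
    nms_kernel_size: int = Field(default=9, gt=0)
    detection_threshold: float = Field(default=0.01, ge=0)

Config(nms_kernel_size=9)        # ok
Config(detection_threshold=0.0)  # ok: ge=0 allows zero

try:
    Config(nms_kernel_size=0)    # violates gt=0
except ValidationError as err:
    print(err)                   # reports which field failed and why
```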
"""Decodes extracted detection data into standard soundevent predictions."""
|
"""Decodes extracted detection data into standard soundevent predictions.
|
||||||
|
|
||||||
|
This module handles the final stages of the BatDetect2 postprocessing pipeline.
|
||||||
|
It takes the structured detection data extracted by the `extraction` module
|
||||||
|
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
|
||||||
|
class probabilities, and features for each detection point) and converts it
|
||||||
|
into standardized prediction objects based on the `soundevent` data model.
|
||||||
|
|
||||||
|
The process involves:
|
||||||
|
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
|
||||||
|
objects, using a configured geometry builder to recover bounding boxes from
|
||||||
|
predicted positions and sizes (`convert_xr_dataset_to_raw_prediction`).
|
||||||
|
2. Converting each `RawPrediction` into a
|
||||||
|
`soundevent.data.SoundEventPrediction`, which involves:
|
||||||
|
- Creating the `soundevent.data.SoundEvent` with geometry and features.
|
||||||
|
- Decoding the predicted class probabilities into representative tags using
|
||||||
|
a configured class decoder (`SoundEventDecoder`).
|
||||||
|
- Applying a classification threshold.
|
||||||
|
- Optionally selecting only the single highest-scoring class (top-1) or
|
||||||
|
including tags for all classes above the threshold (multi-label).
|
||||||
|
- Adding generic class tags as a baseline.
|
||||||
|
- Associating scores with the final prediction and tags.
|
||||||
|
(`convert_raw_prediction_to_sound_event_prediction`)
|
||||||
|
3. Grouping the `SoundEventPrediction` objects for a given audio segment into
|
||||||
|
a `soundevent.data.ClipPrediction`
|
||||||
|
(`convert_raw_predictions_to_clip_prediction`).
|
||||||
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.typing.postprocess import (
|
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
||||||
ClipDetectionsArray,
|
from batdetect2.targets.types import TargetProtocol
|
||||||
RawPrediction,
|
|
||||||
)
|
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"to_raw_predictions",
|
"convert_xr_dataset_to_raw_prediction",
|
||||||
"convert_raw_predictions_to_clip_prediction",
|
"convert_raw_predictions_to_clip_prediction",
|
||||||
"convert_raw_prediction_to_sound_event_prediction",
|
"convert_raw_prediction_to_sound_event_prediction",
|
||||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||||
@ -27,29 +51,65 @@ decoding.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def to_raw_predictions(
|
def convert_xr_dataset_to_raw_prediction(
|
||||||
detections: ClipDetectionsArray,
|
detection_dataset: xr.Dataset,
|
||||||
targets: TargetProtocol,
|
geometry_decoder: GeometryDecoder,
|
||||||
) -> List[RawPrediction]:
|
) -> List[RawPrediction]:
|
||||||
predictions = []
|
"""Convert an xarray.Dataset of detections to RawPrediction objects.
|
||||||
|
|
||||||
|
Takes the output of the extraction step (`extract_detection_xr_dataset`)
|
||||||
|
and transforms each detection entry into an intermediate `RawPrediction`
|
||||||
|
object. This involves recovering the geometry (e.g., bounding box) from
|
||||||
|
the predicted position and scaled size dimensions using the provided
|
||||||
|
`geometry_builder` function.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
detection_dataset : xr.Dataset
|
||||||
|
An xarray Dataset containing aligned detection information, typically
|
||||||
|
output by `extract_detection_xr_dataset`. Expected variables include
|
||||||
|
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
|
||||||
|
Must have a 'detection' dimension.
|
||||||
|
geometry_decoder : GeometryDecoder
|
||||||
|
A function that takes a position tuple `(time, freq)` and a NumPy array
|
||||||
|
of dimensions, and returns the corresponding reconstructed
|
||||||
|
`soundevent.data.Geometry`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[RawPrediction]
|
||||||
|
A list of `RawPrediction` objects, each containing the detection score,
|
||||||
|
recovered bounding box coordinates (start/end time, low/high freq),
|
||||||
|
the vector of class scores, and the feature vector for one detection.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
AttributeError, KeyError, ValueError
|
||||||
|
If `detection_dataset` is missing expected variables ('scores',
|
||||||
|
'dimensions', 'classes', 'features') or coordinates ('time', 'freq'
|
||||||
|
associated with 'scores'), or if `geometry_builder` fails.
|
||||||
|
"""
|
||||||
|
detections = []
|
||||||
|
|
||||||
|
categories = detection_dataset.category.values
|
||||||
|
|
||||||
for score, class_scores, time, freq, dims, feats in zip(
|
for score, class_scores, time, freq, dims, feats in zip(
|
||||||
detections.scores,
|
detection_dataset["scores"].values,
|
||||||
detections.class_scores,
|
detection_dataset["classes"].values,
|
||||||
detections.times,
|
detection_dataset["time"].values,
|
||||||
detections.frequencies,
|
detection_dataset["frequency"].values,
|
||||||
detections.sizes,
|
detection_dataset["dimensions"].values,
|
||||||
detections.features,
|
detection_dataset["features"].values,
|
||||||
):
|
):
|
||||||
highest_scoring_class = targets.class_names[class_scores.argmax()]
|
highest_scoring_class = categories[class_scores.argmax()]
|
||||||
|
|
||||||
geom = targets.decode_roi(
|
geom = geometry_decoder(
|
||||||
(time, freq),
|
(time, freq),
|
||||||
dims,
|
dims,
|
||||||
class_name=highest_scoring_class,
|
class_name=highest_scoring_class,
|
||||||
)
|
)
|
||||||
|
|
||||||
predictions.append(
|
detections.append(
|
||||||
RawPrediction(
|
RawPrediction(
|
||||||
detection_score=score,
|
detection_score=score,
|
||||||
geometry=geom,
|
geometry=geom,
|
||||||
@ -58,7 +118,7 @@ def to_raw_predictions(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return predictions
|
return detections
|
||||||
|
|
||||||
|
|
||||||
def convert_raw_predictions_to_clip_prediction(
|
def convert_raw_predictions_to_clip_prediction(
|
||||||
@ -68,7 +128,35 @@ def convert_raw_predictions_to_clip_prediction(
|
|||||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
top_class_only: bool = False,
|
top_class_only: bool = False,
|
||||||
) -> data.ClipPrediction:
|
) -> data.ClipPrediction:
|
||||||
"""Convert a list of RawPredictions into a soundevent ClipPrediction."""
|
"""Convert a list of RawPredictions into a soundevent ClipPrediction.
|
||||||
|
|
||||||
|
Iterates through `raw_predictions` (assumed to belong to a single clip),
|
||||||
|
converts each one into a `soundevent.data.SoundEventPrediction` using
|
||||||
|
`convert_raw_prediction_to_sound_event_prediction`, and packages them
|
||||||
|
into a `soundevent.data.ClipPrediction` associated with the original `clip`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
raw_predictions : List[RawPrediction]
|
||||||
|
List of raw prediction objects for a single clip.
|
||||||
|
clip : data.Clip
|
||||||
|
The original `soundevent.data.Clip` object these predictions belong to.
|
||||||
|
sound_event_decoder : SoundEventDecoder
|
||||||
|
Function to decode class names into representative tags.
|
||||||
|
generic_class_tags : List[data.Tag]
|
||||||
|
List of tags representing the generic class category.
|
||||||
|
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
|
||||||
|
Threshold applied to class scores during decoding.
|
||||||
|
top_class_only : bool, default=False
|
||||||
|
If True, only decode tags for the single highest-scoring class above
|
||||||
|
the threshold. If False, decode tags for all classes above threshold.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
data.ClipPrediction
|
||||||
|
A `ClipPrediction` object containing a list of `SoundEventPrediction`
|
||||||
|
objects corresponding to the input `raw_predictions`.
|
||||||
|
"""
|
||||||
return data.ClipPrediction(
|
return data.ClipPrediction(
|
||||||
clip=clip,
|
clip=clip,
|
||||||
sound_events=[
|
sound_events=[
|
||||||
@ -93,7 +181,68 @@ def convert_raw_prediction_to_sound_event_prediction(
|
|||||||
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
top_class_only: bool = False,
|
top_class_only: bool = False,
|
||||||
):
|
):
|
||||||
"""Convert a single RawPrediction into a soundevent SoundEventPrediction."""
|
"""Convert a single RawPrediction into a soundevent SoundEventPrediction.
|
||||||
|
|
||||||
|
This function performs the core decoding steps for a single detected event:
|
||||||
|
1. Creates a `soundevent.data.SoundEvent` containing the geometry
|
||||||
|
(BoundingBox derived from `raw_prediction` bounds) and any associated
|
||||||
|
feature vectors.
|
||||||
|
2. Initializes a list of predicted tags using the provided
|
||||||
|
`generic_class_tags`, assigning the overall `detection_score` from the
|
||||||
|
`raw_prediction` to these generic tags.
|
||||||
|
3. Processes the `class_scores` from the `raw_prediction`:
|
||||||
|
a. Optionally filters out scores below `classification_threshold`
|
||||||
|
(if it's not None).
|
||||||
|
b. Sorts the remaining scores in descending order.
|
||||||
|
c. Iterates through the sorted, thresholded class scores.
|
||||||
|
d. For each class, uses the `sound_event_decoder` to get the
|
||||||
|
representative base tags for that class name.
|
||||||
|
e. Wraps these base tags in `soundevent.data.PredictedTag`, associating
|
||||||
|
the specific `score` of that class prediction.
|
||||||
|
f. Appends these specific predicted tags to the list.
|
||||||
|
g. If `top_class_only` is True, stops after processing the first
|
||||||
|
(highest-scoring) class that passed the threshold.
|
||||||
|
4. Creates and returns the final `soundevent.data.SoundEventPrediction`,
|
||||||
|
associating the `SoundEvent`, the overall `detection_score`, and the
|
||||||
|
compiled list of `PredictedTag` objects.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
raw_prediction : RawPrediction
|
||||||
|
The raw prediction object containing score, bounds, class scores,
|
||||||
|
features. Assumes `class_scores` is an `xr.DataArray` with a 'category'
|
||||||
|
coordinate. Assumes `features` is an `xr.DataArray` with a 'feature'
|
||||||
|
coordinate.
|
||||||
|
recording : data.Recording
|
||||||
|
The recording the sound event belongs to.
|
||||||
|
sound_event_decoder : SoundEventDecoder
|
||||||
|
Configured function mapping class names (str) to lists of base
|
||||||
|
`data.Tag` objects.
|
||||||
|
generic_class_tags : List[data.Tag]
|
||||||
|
List of base tags representing the generic category.
|
||||||
|
classification_threshold : float, optional
|
||||||
|
The minimum score a class prediction must have to be considered
|
||||||
|
significant enough to have its tags decoded and added. If None, no
|
||||||
|
thresholding is applied based on class score (all predicted classes,
|
||||||
|
or the top one if `top_class_only` is True, will be processed).
|
||||||
|
Defaults to `DEFAULT_CLASSIFICATION_THRESHOLD`.
|
||||||
|
top_class_only : bool, default=False
|
||||||
|
If True, only includes tags for the single highest-scoring class that
|
||||||
|
exceeds the threshold. If False (default), includes tags for all classes
|
||||||
|
exceeding the threshold.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
data.SoundEventPrediction
|
||||||
|
The fully formed sound event prediction object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If `raw_prediction.features` has unexpected structure or if
|
||||||
|
`data.term_from_key` (if used internally) fails.
|
||||||
|
If `sound_event_decoder` fails for a class name and errors are raised.
|
||||||
|
"""
|
||||||
sound_event = data.SoundEvent(
|
sound_event = data.SoundEvent(
|
||||||
recording=recording,
|
recording=recording,
|
||||||
geometry=raw_prediction.geometry,
|
geometry=raw_prediction.geometry,
|
||||||
@ -103,7 +252,7 @@ def convert_raw_prediction_to_sound_event_prediction(
|
|||||||
tags = [
|
tags = [
|
||||||
*get_generic_tags(
|
*get_generic_tags(
|
||||||
raw_prediction.detection_score,
|
raw_prediction.detection_score,
|
||||||
generic_class_tags=targets.detection_class_tags,
|
generic_class_tags=targets.generic_class_tags,
|
||||||
),
|
),
|
||||||
*get_class_tags(
|
*get_class_tags(
|
||||||
raw_prediction.class_scores,
|
raw_prediction.class_scores,
|
||||||
@ -124,7 +273,25 @@ def get_generic_tags(
|
|||||||
detection_score: float,
|
detection_score: float,
|
||||||
generic_class_tags: List[data.Tag],
|
generic_class_tags: List[data.Tag],
|
||||||
) -> List[data.PredictedTag]:
|
) -> List[data.PredictedTag]:
|
||||||
"""Create PredictedTag objects for the generic category."""
|
"""Create PredictedTag objects for the generic category.
|
||||||
|
|
||||||
|
Takes the base list of generic tags and assigns the overall detection
|
||||||
|
score to each one, wrapping them in `PredictedTag` objects.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
detection_score : float
|
||||||
|
The overall confidence score of the detection event.
|
||||||
|
generic_class_tags : List[data.Tag]
|
||||||
|
The list of base `soundevent.data.Tag` objects that define the
|
||||||
|
generic category (e.g., ['call_type:Echolocation', 'order:Chiroptera']).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[data.PredictedTag]
|
||||||
|
A list of `PredictedTag` objects for the generic category, each
|
||||||
|
assigned the `detection_score`.
|
||||||
|
"""
|
||||||
return [
|
return [
|
||||||
data.PredictedTag(tag=tag, score=detection_score)
|
data.PredictedTag(tag=tag, score=detection_score)
|
||||||
for tag in generic_class_tags
|
for tag in generic_class_tags
|
||||||
@ -132,7 +299,25 @@ def get_generic_tags(
|
|||||||
|
|
||||||
|
|
||||||
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
|
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
|
||||||
"""Convert an extracted feature vector DataArray into soundevent Features."""
|
"""Convert an extracted feature vector DataArray into soundevent Features.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
features : xr.DataArray
|
||||||
|
A 1D xarray DataArray containing feature values, indexed by a coordinate
|
||||||
|
named 'feature' which holds the feature names (e.g., output of selecting
|
||||||
|
features for one detection from `extract_detection_xr_dataset`).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[data.Feature]
|
||||||
|
A list of `soundevent.data.Feature` objects.
|
||||||
|
|
||||||
|
Notes
|
||||||
|
-----
|
||||||
|
- This function creates basic `Term` objects using the feature coordinate
|
||||||
|
names with a "batdetect2:" prefix.
|
||||||
|
"""
|
||||||
return [
|
return [
|
||||||
data.Feature(
|
data.Feature(
|
||||||
term=data.Term(
|
term=data.Term(
|
||||||
|
|||||||
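The numbered steps 3a-3g in the docstring above describe the class-decoding loop. A compact, self-contained sketch of just that thresholding, sorting, and top-1 logic on plain arrays (illustrative only; the real function operates on `xr.DataArray` scores and emits `soundevent` tags):

```python
from typing import List, Optional, Tuple

import numpy as np

def decode_class_scores(
    class_names: List[str],
    class_scores: np.ndarray,
    threshold: Optional[float] = 0.1,
    top_class_only: bool = False,
) -> List[Tuple[str, float]]:
    # (a) optionally drop classes scoring below the threshold
    pairs = list(zip(class_names, class_scores.tolist()))
    if threshold is not None:
        pairs = [(name, score) for name, score in pairs if score >= threshold]
    # (b) sort the survivors by descending score
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    # (c)-(g) decode all surviving classes, or only the top one
    return pairs[:1] if top_class_only else pairs

print(decode_class_scores(["classA", "classB"], np.array([0.05, 0.8])))
# -> [('classB', 0.8)]
```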
162
src/batdetect2/postprocess/detection.py
Normal file
162
src/batdetect2/postprocess/detection.py
Normal file
@ -0,0 +1,162 @@
"""Extracts candidate detection points from a model output heatmap.

This module implements Step 3 within the BatDetect2 postprocessing
pipeline. Its primary function is to identify potential sound event locations
by finding peaks (local maxima or high-scoring points) in the detection heatmap
produced by the neural network (usually after Non-Maximum Suppression and
coordinate remapping have been applied).

It provides functionality to:

- Identify the locations (time, frequency) of the highest-scoring points.
- Filter these points based on a minimum confidence score threshold.
- Limit the maximum number of detection points returned (top-k).

The main output is an `xarray.DataArray` containing the scores and
corresponding time/frequency coordinates for the extracted detection points.
This output serves as the input for subsequent postprocessing steps, such as
extracting predicted class probabilities and bounding box sizes at these
specific locations.
"""

from typing import Optional

import numpy as np
import xarray as xr
from soundevent.arrays import Dimensions, get_dim_width

__all__ = [
    "extract_detections_from_array",
    "get_max_detections",
    "DEFAULT_DETECTION_THRESHOLD",
    "TOP_K_PER_SEC",
]

DEFAULT_DETECTION_THRESHOLD = 0.01
"""Default confidence score threshold used for filtering detections."""

TOP_K_PER_SEC = 200
"""Default desired maximum number of detections per second of audio."""


def extract_detections_from_array(
    detection_array: xr.DataArray,
    max_detections: Optional[int] = None,
    threshold: Optional[float] = DEFAULT_DETECTION_THRESHOLD,
) -> xr.DataArray:
    """Extract detection locations (time, freq) and scores from a heatmap.

    Identifies the pixels with the highest scores in the input detection
    heatmap, filters them based on an optional score `threshold`, limits the
    number to an optional `max_detections`, and returns their scores along
    with their corresponding time and frequency coordinates.

    Parameters
    ----------
    detection_array : xr.DataArray
        A 2D xarray DataArray representing the detection heatmap. Must have
        dimensions and coordinates named 'time' and 'frequency'. Higher values
        are assumed to indicate higher detection confidence.
    max_detections : int, optional
        The absolute maximum number of detections to return. If specified,
        only the top `max_detections` highest-scoring detections (passing the
        threshold) are returned. If None (default), all detections passing
        the threshold are returned, sorted by score.
    threshold : float, optional
        The minimum confidence score required for a detection peak to be
        kept. Detections with scores below this value are discarded.
        Defaults to `DEFAULT_DETECTION_THRESHOLD`. If set to None, no
        thresholding is applied.

    Returns
    -------
    xr.DataArray
        A 1D xarray DataArray named 'score' with a 'detection' dimension.
        - The data values are the scores of the extracted detections, sorted
          in descending order.
        - It has coordinates 'time' and 'frequency' (also indexed by the
          'detection' dimension) indicating the location of each detection
          peak in the original coordinate system.
        - Returns an empty DataArray if no detections pass the criteria.

    Raises
    ------
    ValueError
        If `max_detections` is not None and not a positive integer, or if
        `detection_array` lacks required dimensions/coordinates.
    """
    if max_detections is not None:
        if max_detections <= 0:
            raise ValueError("Max detections must be positive")

    values = detection_array.values.flatten()

    if max_detections is not None:
        top_indices = np.argpartition(-values, max_detections)[:max_detections]
        top_sorted_indices = top_indices[np.argsort(-values[top_indices])]
    else:
        top_sorted_indices = np.argsort(-values)

    top_values = values[top_sorted_indices]

    if threshold is not None:
        mask = top_values > threshold
        top_values = top_values[mask]
        top_sorted_indices = top_sorted_indices[mask]

    freq_indices, time_indices = np.unravel_index(
        top_sorted_indices,
        detection_array.shape,
    )

    times = detection_array.coords[Dimensions.time.value].values[time_indices]
    freqs = detection_array.coords[Dimensions.frequency.value].values[
        freq_indices
    ]

    return xr.DataArray(
        data=top_values,
        coords={
            Dimensions.frequency.value: ("detection", freqs),
            Dimensions.time.value: ("detection", times),
        },
        dims="detection",
        name="score",
    )


def get_max_detections(
    detection_array: xr.DataArray,
    top_k_per_sec: int = TOP_K_PER_SEC,
) -> int:
    """Calculate max detections allowed based on duration and rate.

    Determines the total maximum number of detections to extract from a
    heatmap based on its time duration and a desired rate of detections
    per second.

    Parameters
    ----------
    detection_array : xr.DataArray
        The detection heatmap, requiring 'time' coordinates from which the
        total duration can be calculated using
        `soundevent.arrays.get_dim_width`.
    top_k_per_sec : int, default=TOP_K_PER_SEC
        The desired maximum number of detections to allow per second of audio.

    Returns
    -------
    int
        The calculated total maximum number of detections allowed for the
        entire duration of the `detection_array`.

    Raises
    ------
    ValueError
        If the duration cannot be calculated from the `detection_array` (e.g.,
        missing or invalid 'time' coordinates/dimension).
    """
    if top_k_per_sec < 0:
        raise ValueError("top_k_per_sec cannot be negative.")

    duration = get_dim_width(detection_array, Dimensions.time.value)
    return int(duration * top_k_per_sec)
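For orientation, here is a minimal usage sketch for the new module; it is not part of the diff, and the synthetic heatmap, its coordinate values, and the chosen thresholds are illustrative assumptions.

```python
# Hypothetical usage sketch for extract_detections_from_array (not part of
# the diff). Coordinate values and RNG data are assumptions for illustration.
import numpy as np
import xarray as xr

from batdetect2.postprocess.detection import (
    extract_detections_from_array,
    get_max_detections,
)

# Tiny synthetic detection heatmap: 4 frequency bins x 5 time steps.
heatmap = xr.DataArray(
    np.random.default_rng(0).uniform(0, 1, size=(4, 5)),
    dims=("frequency", "time"),
    coords={
        "frequency": np.linspace(10_000, 120_000, 4),
        "time": np.linspace(0, 0.5, 5),
    },
)

# Detection budget implied by the default rate of 200 peaks per second.
budget = get_max_detections(heatmap)

# Note: max_detections must stay below the number of heatmap cells, since
# the implementation partitions the flattened score array.
detections = extract_detections_from_array(
    heatmap, max_detections=5, threshold=0.5
)
print(budget, detections.values, detections.coords["time"].values)
```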
@ -15,75 +15,108 @@ precise time-frequency location of each detection. The final output aggregates
 all extracted information into a structured `xarray.Dataset`.
 """
 
-from typing import List, Optional
-
-import torch
-
-from batdetect2.typing.postprocess import ClipDetectionsTensor
+import xarray as xr
+from soundevent.arrays import Dimensions
 
 __all__ = [
-    "extract_detection_peaks",
+    "extract_values_at_positions",
+    "extract_detection_xr_dataset",
 ]
 
 
-def extract_detection_peaks(
-    detection_heatmap: torch.Tensor,
-    size_heatmap: torch.Tensor,
-    feature_heatmap: torch.Tensor,
-    classification_heatmap: torch.Tensor,
-    max_detections: int = 200,
-    threshold: Optional[float] = None,
-) -> List[ClipDetectionsTensor]:
-    height = detection_heatmap.shape[-2]
-    width = detection_heatmap.shape[-1]
-
-    freqs, times = torch.meshgrid(
-        torch.arange(height, dtype=torch.int32),
-        torch.arange(width, dtype=torch.int32),
-        indexing="ij",
-    )
-
-    freqs = freqs.flatten().to(detection_heatmap.device)
-    times = times.flatten().to(detection_heatmap.device)
-
-    output_size_preds = size_heatmap.detach()
-    output_features = feature_heatmap.detach()
-    output_class_probs = classification_heatmap.detach()
-
-    predictions = []
-    for idx, item in enumerate(detection_heatmap):
-        item = item.squeeze().flatten()  # Remove channel dim
-        indices = torch.argsort(item, descending=True)[:max_detections]
-
-        detection_scores = item.take(indices)
-        detection_freqs = freqs.take(indices)
-        detection_times = times.take(indices)
-
-        if threshold is not None:
-            mask = detection_scores >= threshold
-
-            detection_scores = detection_scores[mask]
-            detection_times = detection_times[mask]
-            detection_freqs = detection_freqs[mask]
-
-        sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
-        features = output_features[idx, :, detection_freqs, detection_times].T
-        class_scores = output_class_probs[
-            idx,
-            :,
-            detection_freqs,
-            detection_times,
-        ].T
-
-        predictions.append(
-            ClipDetectionsTensor(
-                scores=detection_scores,
-                sizes=sizes,
-                features=features,
-                class_scores=class_scores,
-                times=detection_times.to(torch.float32) / width,
-                frequencies=(detection_freqs.to(torch.float32) / height),
-            )
-        )
-
-    return predictions
+def extract_values_at_positions(
+    array: xr.DataArray,
+    positions: xr.DataArray,
+) -> xr.DataArray:
+    """Extract values from an array at specified time-frequency positions.
+
+    Uses coordinate-based indexing to retrieve values from a source `array`
+    (e.g., class probabilities, size predictions, features) at the time and
+    frequency coordinates defined in the `positions` array.
+
+    Parameters
+    ----------
+    array : xr.DataArray
+        The source DataArray from which to extract values. Must have 'time'
+        and 'frequency' dimensions and coordinates matching the space of
+        `positions`.
+    positions : xr.DataArray
+        A 1D DataArray whose 'time' and 'frequency' coordinates specify the
+        locations from which to extract values.
+
+    Returns
+    -------
+    xr.DataArray
+        A DataArray containing the values extracted from `array` at the given
+        positions.
+
+    Raises
+    ------
+    ValueError, IndexError, KeyError
+        If dimensions or coordinates are missing or incompatible between
+        `array` and `positions`, or if selection fails.
+    """
+    return array.sel(
+        **{
+            Dimensions.frequency.value: positions.coords[
+                Dimensions.frequency.value
+            ],
+            Dimensions.time.value: positions.coords[Dimensions.time.value],
+        }
+    ).T
+
+
+def extract_detection_xr_dataset(
+    positions: xr.DataArray,
+    sizes: xr.DataArray,
+    classes: xr.DataArray,
+    features: xr.DataArray,
+) -> xr.Dataset:
+    """Combine extracted detection information into a structured xr.Dataset.
+
+    Takes the detection positions/scores and the full model output heatmaps
+    (sizes, classes, optional features), extracts the relevant data at the
+    detection positions, and packages everything into a single
+    `xarray.Dataset` where all variables are indexed by a common 'detection'
+    dimension.
+
+    Parameters
+    ----------
+    positions : xr.DataArray
+        Output from `extract_detections_from_array`, containing detection
+        scores as data and 'time', 'frequency' coordinates along the
+        'detection' dimension.
+    sizes : xr.DataArray
+        The full size prediction heatmap from the model, with dimensions like
+        ('dimension', 'time', 'frequency').
+    classes : xr.DataArray
+        The full class probability heatmap from the model, with dimensions
+        like ('category', 'time', 'frequency').
+    features : xr.DataArray
+        The full feature map from the model, with dimensions like
+        ('feature', 'time', 'frequency').
+
+    Returns
+    -------
+    xr.Dataset
+        An xarray Dataset containing aligned information for each detection:
+        - 'scores': DataArray from `positions` (score data, time/freq coords).
+        - 'dimensions': DataArray with extracted size values
+          (dims: 'detection', 'dimension').
+        - 'classes': DataArray with extracted class probabilities
+          (dims: 'detection', 'category').
+        - 'features': DataArray with extracted feature vectors
+          (dims: 'detection', 'feature'), if `features` was provided. All
+          DataArrays share the 'detection' dimension and associated
+          time/frequency coordinates.
+    """
+    sizes = extract_values_at_positions(sizes, positions)
+    classes = extract_values_at_positions(classes, positions)
+    features = extract_values_at_positions(features, positions)
+    return xr.Dataset(
+        {
+            "scores": positions,
+            "dimensions": sizes,
+            "classes": classes,
+            "features": features,
+        }
+    )
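To show how the two new helpers fit together, here is a hedged, self-contained sketch; the heatmaps, class names, and detection positions below are assumptions for illustration, not values from the diff.

```python
# Hypothetical sketch (not part of the diff): pairing detection positions
# with model heatmaps via the new extraction helpers.
import numpy as np
import xarray as xr

from batdetect2.postprocess.extraction import (
    extract_detection_xr_dataset,
    extract_values_at_positions,
)

coords = {
    "frequency": np.linspace(10_000, 120_000, 4),
    "time": np.linspace(0, 0.5, 5),
}
rng = np.random.default_rng(1)

classes = xr.DataArray(
    rng.uniform(0, 1, (3, 4, 5)),
    dims=("category", "frequency", "time"),
    coords={"category": ["myomys", "pippip", "eptser"], **coords},
)
sizes = xr.DataArray(
    rng.uniform(0, 1, (2, 4, 5)),
    dims=("dimension", "frequency", "time"),
    coords={"dimension": ["width", "height"], **coords},
)
features = xr.DataArray(
    rng.normal(size=(8, 4, 5)),
    dims=("feature", "frequency", "time"),
    coords={"feature": np.arange(8), **coords},
)

# Two detection peaks, shaped as extract_detections_from_array returns them:
# scores on a 'detection' dimension carrying time/frequency coordinates.
positions = xr.DataArray(
    [0.9, 0.7],
    dims="detection",
    coords={
        "frequency": ("detection", [10_000.0, 120_000.0]),
        "time": ("detection", [0.0, 0.5]),
    },
    name="score",
)

per_detection = extract_values_at_positions(classes, positions)
dataset = extract_detection_xr_dataset(positions, sizes, classes, features)
print(per_detection.dims, dict(dataset.sizes))
```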
@ -1,100 +0,0 @@
from typing import List, Optional, Tuple, Union

import torch
from loguru import logger

from batdetect2.postprocess.config import (
    PostprocessConfig,
)
from batdetect2.postprocess.extraction import extract_detection_peaks
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
    ClipDetectionsTensor,
    PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol

__all__ = [
    "build_postprocessor",
    "Postprocessor",
]


def build_postprocessor(
    preprocessor: PreprocessorProtocol,
    config: Optional[PostprocessConfig] = None,
) -> PostprocessorProtocol:
    """Factory function to build the standard postprocessor."""
    config = config or PostprocessConfig()
    logger.opt(lazy=True).debug(
        "Building postprocessor with config: \n{}",
        lambda: config.to_yaml_string(),
    )
    return Postprocessor(
        samplerate=preprocessor.output_samplerate,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        top_k_per_sec=config.top_k_per_sec,
        detection_threshold=config.detection_threshold,
    )


class Postprocessor(torch.nn.Module, PostprocessorProtocol):
    """Standard implementation of the postprocessing pipeline."""

    def __init__(
        self,
        samplerate: float,
        min_freq: float,
        max_freq: float,
        top_k_per_sec: int = 200,
        detection_threshold: float = 0.01,
        nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
    ):
        """Initialize the Postprocessor."""
        super().__init__()

        self.output_samplerate = samplerate
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.top_k_per_sec = top_k_per_sec
        self.detection_threshold = detection_threshold
        self.nms_kernel_size = nms_kernel_size

    def forward(
        self,
        output: ModelOutput,
        start_times: Optional[List[float]] = None,
    ) -> List[ClipDetectionsTensor]:
        detection_heatmap = non_max_suppression(
            output.detection_probs.detach(),
            kernel_size=self.nms_kernel_size,
        )

        width = output.detection_probs.shape[-1]
        duration = width / self.output_samplerate
        max_detections = int(self.top_k_per_sec * duration)
        detections = extract_detection_peaks(
            detection_heatmap,
            size_heatmap=output.size_preds,
            feature_heatmap=output.features,
            classification_heatmap=output.class_probs,
            max_detections=max_detections,
            threshold=self.detection_threshold,
        )

        if start_times is None:
            start_times = [0 for _ in range(len(detections))]

        return [
            map_detection_to_clip(
                detection,
                start_time=0,
                end_time=duration,
                min_freq=self.min_freq,
                max_freq=self.max_freq,
            )
            for detection in detections
        ]
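The detection budget computed in the removed `forward` above is the same arithmetic the new `get_max_detections` performs. A tiny worked example, with assumed numbers:

```python
# Illustrative arithmetic only (not part of the diff). The heatmap width and
# output samplerate below are assumed example values.
width = 512                 # time bins in the detection heatmap
output_samplerate = 500.0   # heatmap frames per second (assumed)
top_k_per_sec = 200

duration = width / output_samplerate            # 1.024 s of audio
max_detections = int(top_k_per_sec * duration)  # 204
print(duration, max_detections)
```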
@ -20,7 +20,6 @@ import xarray as xr
 from soundevent.arrays import Dimensions
 
 from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
-from batdetect2.typing.postprocess import ClipDetectionsTensor
 
 __all__ = [
     "features_to_xarray",
@ -30,25 +29,6 @@ __all__ = [
 ]
 
 
-def map_detection_to_clip(
-    detections: ClipDetectionsTensor,
-    start_time: float,
-    end_time: float,
-    min_freq: float,
-    max_freq: float,
-) -> ClipDetectionsTensor:
-    duration = end_time - start_time
-    bandwidth = max_freq - min_freq
-    return ClipDetectionsTensor(
-        scores=detections.scores,
-        sizes=detections.sizes,
-        features=detections.features,
-        class_scores=detections.class_scores,
-        times=(detections.times * duration + start_time),
-        frequencies=(detections.frequencies * bandwidth + min_freq),
-    )
-
-
 def features_to_xarray(
     features: torch.Tensor,
     start_time: float,
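The removed `map_detection_to_clip` applied a simple linear rescaling from the model's normalized [0, 1] coordinates back to clip seconds and Hz. A worked example with assumed numbers:

```python
# Worked example (assumed values) of the linear rescaling that the removed
# map_detection_to_clip applied to normalized model coordinates.
start_time, end_time = 0.0, 2.0           # clip spans 2 s
min_freq, max_freq = 10_000.0, 120_000.0  # Hz, matching the config above

duration = end_time - start_time          # 2.0 s
bandwidth = max_freq - min_freq           # 110 kHz

normalized_time, normalized_freq = 0.25, 0.5
time_s = normalized_time * duration + start_time  # 0.5 s
freq_hz = normalized_freq * bandwidth + min_freq  # 65_000.0 Hz
print(time_s, freq_hz)
```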
295	src/batdetect2/postprocess/types.py	Normal file
@ -0,0 +1,295 @@
"""Defines shared interfaces and data structures for postprocessing.

This module centralizes the Protocol definitions and common data structures
used throughout the `batdetect2.postprocess` module.

The main component is the `PostprocessorProtocol`, which outlines the standard
interface for an object responsible for executing the entire postprocessing
pipeline. This pipeline transforms raw neural network outputs into
interpretable detections represented as `soundevent` objects. Using protocols
ensures modularity and consistent interaction between different parts of the
BatDetect2 system that deal with model predictions.
"""

from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol

import numpy as np
import xarray as xr
from soundevent import data

from batdetect2.models.types import ModelOutput
from batdetect2.targets.types import Position, Size

__all__ = [
    "RawPrediction",
    "PostprocessorProtocol",
    "GeometryDecoder",
]


# TODO: update the docstring
class GeometryDecoder(Protocol):
    """Protocol for a callable that recovers geometry from position and size.

    This callable takes:
    1. A position tuple `(time, frequency)`.
    2. A NumPy array of size dimensions (e.g., `[width, height]`).
    3. Optionally, the class name of the highest scoring class. This is to
       accommodate ways of decoding geometry that depend on the predicted
       class.

    It should return the reconstructed `soundevent.data.Geometry` (typically a
    `BoundingBox`).
    """

    def __call__(
        self, position: Position, size: Size, class_name: Optional[str] = None
    ) -> data.Geometry: ...


class RawPrediction(NamedTuple):
    """Intermediate representation of a single detected sound event.

    Holds extracted information about a detection after initial processing
    (like peak finding, coordinate remapping, geometry recovery) but before
    final class decoding and conversion into a `SoundEventPrediction`. This
    can be useful for evaluation or simpler data handling formats.

    Attributes
    ----------
    geometry : data.Geometry
        The recovered estimated geometry of the detected sound event.
        Usually a bounding box.
    detection_score : float
        The confidence score associated with this detection, typically from
        the detection heatmap peak.
    class_scores : np.ndarray
        The predicted probabilities or scores for each target class at the
        detection location.
    features : np.ndarray
        The feature vector extracted at the detection location.
    """

    geometry: data.Geometry
    detection_score: float
    class_scores: np.ndarray
    features: np.ndarray


@dataclass
class BatDetect2Prediction:
    raw: RawPrediction
    sound_event_prediction: data.SoundEventPrediction


class PostprocessorProtocol(Protocol):
    """Protocol defining the interface for the full postprocessing pipeline.

    This protocol outlines the standard methods for an object that takes raw
    output from a BatDetect2 model and the corresponding input clip metadata,
    and processes it through various stages (e.g., coordinate remapping, NMS,
    detection extraction, data extraction, decoding) to produce interpretable
    results at different levels of completion.

    Implementations manage the configured logic for all postprocessing steps.
    """

    def get_feature_arrays(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[xr.DataArray]:
        """Remap feature tensors to coordinate-aware DataArrays.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch, expected
            to contain the necessary feature tensors.
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects, one for each item in the
            processed batch. This list provides the timing, recording, and
            other metadata context needed to calculate real-world coordinates
            (seconds, Hz) for the output arrays. The length of this list must
            correspond to the batch size of the `output`.

        Returns
        -------
        List[xr.DataArray]
            A list of xarray DataArrays, one for each input clip in the batch,
            in the same order. Each DataArray contains the feature vectors
            with dimensions like ('feature', 'time', 'frequency') and
            corresponding real-world coordinates.
        """
        ...

    def get_detection_arrays(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[xr.DataArray]:
        """Remap detection tensors to coordinate-aware DataArrays.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch,
            containing detection heatmaps.
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects corresponding to the
            batch items, providing coordinate context. Must match the batch
            size of `output`.

        Returns
        -------
        List[xr.DataArray]
            A list of 2D xarray DataArrays (one per input clip, in order),
            representing the detection heatmap with 'time' and 'frequency'
            coordinates. Values typically indicate detection confidence.
        """
        ...

    def get_classification_arrays(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[xr.DataArray]:
        """Remap classification tensors to coordinate-aware DataArrays.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch,
            containing class probability tensors.
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects corresponding to the
            batch items, providing coordinate context. Must match the batch
            size of `output`.

        Returns
        -------
        List[xr.DataArray]
            A list of 3D xarray DataArrays (one per input clip, in order),
            representing class probabilities with 'category', 'time', and
            'frequency' dimensions and coordinates.
        """
        ...

    def get_sizes_arrays(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[xr.DataArray]:
        """Remap size prediction tensors to coordinate-aware DataArrays.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch,
            containing predicted size tensors (e.g., width and height).
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects corresponding to the
            batch items, providing coordinate context. Must match the batch
            size of `output`.

        Returns
        -------
        List[xr.DataArray]
            A list of 3D xarray DataArrays (one per input clip, in order),
            representing predicted sizes with 'dimension'
            (e.g., ['width', 'height']), 'time', and 'frequency' dimensions
            and coordinates. Values represent estimated detection sizes.
        """
        ...

    def get_detection_datasets(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[xr.Dataset]:
        """Perform remapping, NMS, detection, and data extraction for a batch.

        Processes the raw model output for a batch to identify detection peaks
        and extract all associated information (score, position, size, class
        probs, features) at those peak locations, returning a structured
        dataset for each input clip in the batch.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch.
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects corresponding to the
            batch items, providing context. Must match the batch size of
            `output`.

        Returns
        -------
        List[xr.Dataset]
            A list of xarray Datasets (one per input clip, in order). Each
            Dataset contains multiple DataArrays ('scores', 'dimensions',
            'classes', 'features') sharing a common 'detection' dimension,
            providing aligned data for each detected event in that clip.
        """
        ...

    def get_raw_predictions(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[List[RawPrediction]]:
        """Extract intermediate RawPrediction objects for a batch.

        Processes the raw model output for a batch through remapping, NMS,
        detection, data extraction, and geometry recovery to produce a list
        of `RawPrediction` objects for each corresponding input clip. This
        provides a simplified, intermediate representation before final tag
        decoding.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch.
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects corresponding to the
            batch items, providing context. Must match the batch size of
            `output`.

        Returns
        -------
        List[List[RawPrediction]]
            A list of lists (one inner list per input clip, in order). Each
            inner list contains the `RawPrediction` objects extracted for the
            corresponding input clip.
        """
        ...

    def get_sound_event_predictions(
        self, output: ModelOutput, clips: List[data.Clip]
    ) -> List[List[BatDetect2Prediction]]: ...

    def get_predictions(
        self,
        output: ModelOutput,
        clips: List[data.Clip],
    ) -> List[data.ClipPrediction]:
        """Perform the full postprocessing pipeline for a batch.

        Takes raw model output for a batch and corresponding clips, applies
        the entire postprocessing chain, and returns the final, interpretable
        predictions as a list of `soundevent.data.ClipPrediction` objects.

        Parameters
        ----------
        output : ModelOutput
            The raw output from the neural network model for a batch.
        clips : List[data.Clip]
            A list of `soundevent.data.Clip` objects corresponding to the
            batch items, providing context. Must match the batch size of
            `output`.

        Returns
        -------
        List[data.ClipPrediction]
            A list containing one `ClipPrediction` object for each input clip
            (in the same order), populated with `SoundEventPrediction` objects
            representing the final detections with decoded tags and geometry.
        """
        ...
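To make the `GeometryDecoder` contract concrete, here is a hedged sketch of a conforming decoder; the anchor convention (position as the low-left corner) and the `[duration, bandwidth]` size layout are assumptions for illustration, not the project's actual decoding scheme.

```python
# Hypothetical GeometryDecoder implementation (a sketch, not the project's
# actual decoder). Assumes position = (time, frequency) marks the low-left
# corner of the event and size = [duration_s, bandwidth_hz].
from typing import Optional

import numpy as np
from soundevent import data


def simple_bbox_decoder(
    position, size, class_name: Optional[str] = None
) -> data.Geometry:
    time, frequency = position
    duration, bandwidth = float(size[0]), float(size[1])
    # class_name is accepted to satisfy the protocol; this sketch ignores it.
    return data.BoundingBox(
        coordinates=[time, frequency, time + duration, frequency + bandwidth]
    )


box = simple_bbox_decoder((0.5, 40_000.0), np.array([0.004, 25_000.0]))
print(box.coordinates)  # [0.5, 40000.0, 0.504, 65000.0]
```

Because `GeometryDecoder` is a structural `Protocol`, any callable with this signature satisfies it without explicit subclassing.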
Some files were not shown because too many files have changed in this diff.