mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Compare commits
9 Commits
60e922d565
...
4cd983a2c2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4cd983a2c2 | ||
|
|
e65df81db2 | ||
|
|
6c25787123 | ||
|
|
8c80402f08 | ||
|
|
b81a882b58 | ||
|
|
6e217380f2 | ||
|
|
957c0735d2 | ||
|
|
bbb96b33a2 | ||
|
|
7d6cba5465 |
@ -1,66 +1,27 @@
|
|||||||
targets:
|
audio:
|
||||||
detection_target:
|
samplerate: 256000
|
||||||
name: bat
|
resample:
|
||||||
match_if:
|
enabled: True
|
||||||
name: all_of
|
method: "poly"
|
||||||
conditions:
|
|
||||||
- name: has_tag
|
|
||||||
tag: { key: event, value: Echolocation }
|
|
||||||
- name: not
|
|
||||||
condition:
|
|
||||||
name: has_tag
|
|
||||||
tag: { key: class, value: Unknown }
|
|
||||||
assign_tags:
|
|
||||||
- key: class
|
|
||||||
value: Bat
|
|
||||||
|
|
||||||
classification_targets:
|
|
||||||
- name: myomys
|
|
||||||
tags:
|
|
||||||
- key: class
|
|
||||||
value: Myotis mystacinus
|
|
||||||
- name: pippip
|
|
||||||
tags:
|
|
||||||
- key: class
|
|
||||||
value: Pipistrellus pipistrellus
|
|
||||||
- name: eptser
|
|
||||||
tags:
|
|
||||||
- key: class
|
|
||||||
value: Eptesicus serotinus
|
|
||||||
- name: rhifer
|
|
||||||
tags:
|
|
||||||
- key: class
|
|
||||||
value: Rhinolophus ferrumequinum
|
|
||||||
|
|
||||||
roi:
|
|
||||||
name: anchor_bbox
|
|
||||||
anchor: top-left
|
|
||||||
|
|
||||||
preprocess:
|
preprocess:
|
||||||
audio:
|
stft:
|
||||||
samplerate: 256000
|
window_duration: 0.002
|
||||||
resample:
|
window_overlap: 0.75
|
||||||
enabled: True
|
window_fn: hann
|
||||||
method: "poly"
|
frequencies:
|
||||||
|
max_freq: 120000
|
||||||
spectrogram:
|
min_freq: 10000
|
||||||
stft:
|
size:
|
||||||
window_duration: 0.002
|
height: 128
|
||||||
window_overlap: 0.75
|
resize_factor: 0.5
|
||||||
window_fn: hann
|
spectrogram_transforms:
|
||||||
frequencies:
|
- name: pcen
|
||||||
max_freq: 120000
|
time_constant: 0.1
|
||||||
min_freq: 10000
|
gain: 0.98
|
||||||
size:
|
bias: 2
|
||||||
height: 128
|
power: 0.5
|
||||||
resize_factor: 0.5
|
- name: spectral_mean_substraction
|
||||||
transforms:
|
|
||||||
- name: pcen
|
|
||||||
time_constant: 0.1
|
|
||||||
gain: 0.98
|
|
||||||
bias: 2
|
|
||||||
power: 0.5
|
|
||||||
- name: spectral_mean_substraction
|
|
||||||
|
|
||||||
postprocess:
|
postprocess:
|
||||||
nms_kernel_size: 9
|
nms_kernel_size: 9
|
||||||
@ -102,23 +63,57 @@ model:
|
|||||||
out_channels: 32
|
out_channels: 32
|
||||||
|
|
||||||
train:
|
train:
|
||||||
learning_rate: 0.001
|
optimizer:
|
||||||
t_max: 100
|
learning_rate: 0.001
|
||||||
|
t_max: 100
|
||||||
|
|
||||||
labels:
|
labels:
|
||||||
sigma: 3
|
sigma: 3
|
||||||
|
|
||||||
trainer:
|
trainer:
|
||||||
max_epochs: 5
|
max_epochs: 10
|
||||||
|
check_val_every_n_epoch: 5
|
||||||
|
|
||||||
train_loader:
|
train_loader:
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
|
|
||||||
num_workers: 2
|
num_workers: 2
|
||||||
|
|
||||||
shuffle: True
|
shuffle: True
|
||||||
|
|
||||||
clipping_strategy:
|
clipping_strategy:
|
||||||
name: random_subclip
|
name: random_subclip
|
||||||
duration: 0.256
|
duration: 0.256
|
||||||
|
|
||||||
|
augmentations:
|
||||||
|
enabled: true
|
||||||
|
audio:
|
||||||
|
- name: mix_audio
|
||||||
|
probability: 0.2
|
||||||
|
min_weight: 0.3
|
||||||
|
max_weight: 0.7
|
||||||
|
- name: add_echo
|
||||||
|
probability: 0.2
|
||||||
|
max_delay: 0.005
|
||||||
|
min_weight: 0.0
|
||||||
|
max_weight: 1.0
|
||||||
|
spectrogram:
|
||||||
|
- name: scale_volume
|
||||||
|
probability: 0.2
|
||||||
|
min_scaling: 0.0
|
||||||
|
max_scaling: 2.0
|
||||||
|
- name: warp
|
||||||
|
probability: 0.2
|
||||||
|
delta: 0.04
|
||||||
|
- name: mask_time
|
||||||
|
probability: 0.2
|
||||||
|
max_perc: 0.05
|
||||||
|
max_masks: 3
|
||||||
|
- name: mask_freq
|
||||||
|
probability: 0.2
|
||||||
|
max_perc: 0.10
|
||||||
|
max_masks: 3
|
||||||
|
|
||||||
val_loader:
|
val_loader:
|
||||||
num_workers: 2
|
num_workers: 2
|
||||||
clipping_strategy:
|
clipping_strategy:
|
||||||
@ -142,31 +137,28 @@ train:
|
|||||||
logger:
|
logger:
|
||||||
name: csv
|
name: csv
|
||||||
|
|
||||||
augmentations:
|
validation:
|
||||||
enabled: true
|
metrics:
|
||||||
audio:
|
- name: detection_ap
|
||||||
- name: mix_audio
|
- name: detection_roc_auc
|
||||||
probability: 0.2
|
- name: classification_ap
|
||||||
min_weight: 0.3
|
- name: classification_roc_auc
|
||||||
max_weight: 0.7
|
- name: top_class_ap
|
||||||
- name: add_echo
|
- name: classification_balanced_accuracy
|
||||||
probability: 0.2
|
- name: clip_ap
|
||||||
max_delay: 0.005
|
- name: clip_roc_auc
|
||||||
min_weight: 0.0
|
|
||||||
max_weight: 1.0
|
evaluation:
|
||||||
spectrogram:
|
match_strategy:
|
||||||
- name: scale_volume
|
name: start_time_match
|
||||||
probability: 0.2
|
distance_threshold: 0.01
|
||||||
min_scaling: 0.0
|
metrics:
|
||||||
max_scaling: 2.0
|
- name: classification_ap
|
||||||
- name: warp
|
- name: detection_ap
|
||||||
probability: 0.2
|
plots:
|
||||||
delta: 0.04
|
- name: example_gallery
|
||||||
- name: mask_time
|
- name: example_clip
|
||||||
probability: 0.2
|
- name: detection_pr_curve
|
||||||
max_perc: 0.05
|
- name: classification_pr_curves
|
||||||
max_masks: 3
|
- name: detection_roc_curve
|
||||||
- name: mask_freq
|
- name: classification_roc_curves
|
||||||
probability: 0.2
|
|
||||||
max_perc: 0.10
|
|
||||||
max_masks: 3
|
|
||||||
|
|||||||
36
example_data/targets.yaml
Normal file
36
example_data/targets.yaml
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
detection_target:
|
||||||
|
name: bat
|
||||||
|
match_if:
|
||||||
|
name: all_of
|
||||||
|
conditions:
|
||||||
|
- name: has_tag
|
||||||
|
tag: { key: event, value: Echolocation }
|
||||||
|
- name: not
|
||||||
|
condition:
|
||||||
|
name: has_tag
|
||||||
|
tag: { key: class, value: Unknown }
|
||||||
|
assign_tags:
|
||||||
|
- key: class
|
||||||
|
value: Bat
|
||||||
|
|
||||||
|
classification_targets:
|
||||||
|
- name: myomys
|
||||||
|
tags:
|
||||||
|
- key: class
|
||||||
|
value: Myotis mystacinus
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: class
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
- name: eptser
|
||||||
|
tags:
|
||||||
|
- key: class
|
||||||
|
value: Eptesicus serotinus
|
||||||
|
- name: rhifer
|
||||||
|
tags:
|
||||||
|
- key: class
|
||||||
|
value: Rhinolophus ferrumequinum
|
||||||
|
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
179
src/batdetect2/api/base.py
Normal file
179
src/batdetect2/api/base.py
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.audio import build_audio_loader
|
||||||
|
from batdetect2.config import BatDetect2Config
|
||||||
|
from batdetect2.evaluate import build_evaluator, evaluate
|
||||||
|
from batdetect2.models import Model, build_model
|
||||||
|
from batdetect2.postprocess import build_postprocessor
|
||||||
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.targets.targets import build_targets
|
||||||
|
from batdetect2.train import train
|
||||||
|
from batdetect2.train.lightning import load_model_from_checkpoint
|
||||||
|
from batdetect2.typing import (
|
||||||
|
AudioLoader,
|
||||||
|
EvaluatorProtocol,
|
||||||
|
PostprocessorProtocol,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BatDetect2API:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: BatDetect2Config,
|
||||||
|
targets: TargetProtocol,
|
||||||
|
audio_loader: AudioLoader,
|
||||||
|
preprocessor: PreprocessorProtocol,
|
||||||
|
postprocessor: PostprocessorProtocol,
|
||||||
|
evaluator: EvaluatorProtocol,
|
||||||
|
model: Model,
|
||||||
|
):
|
||||||
|
self.config = config
|
||||||
|
self.targets = targets
|
||||||
|
self.audio_loader = audio_loader
|
||||||
|
self.preprocessor = preprocessor
|
||||||
|
self.postprocessor = postprocessor
|
||||||
|
self.evaluator = evaluator
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
def train(
|
||||||
|
self,
|
||||||
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
|
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
||||||
|
train_workers: Optional[int] = None,
|
||||||
|
val_workers: Optional[int] = None,
|
||||||
|
checkpoint_dir: Optional[Path] = None,
|
||||||
|
log_dir: Optional[Path] = None,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
|
run_name: Optional[str] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
):
|
||||||
|
train(
|
||||||
|
train_annotations=train_annotations,
|
||||||
|
val_annotations=val_annotations,
|
||||||
|
targets=self.targets,
|
||||||
|
config=self.config,
|
||||||
|
audio_loader=self.audio_loader,
|
||||||
|
preprocessor=self.preprocessor,
|
||||||
|
train_workers=train_workers,
|
||||||
|
val_workers=val_workers,
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
log_dir=log_dir,
|
||||||
|
experiment_name=experiment_name,
|
||||||
|
run_name=run_name,
|
||||||
|
seed=seed,
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
self,
|
||||||
|
test_annotations: Sequence[data.ClipAnnotation],
|
||||||
|
num_workers: Optional[int] = None,
|
||||||
|
output_dir: data.PathLike = ".",
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
|
run_name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
return evaluate(
|
||||||
|
self.model,
|
||||||
|
test_annotations,
|
||||||
|
targets=self.targets,
|
||||||
|
audio_loader=self.audio_loader,
|
||||||
|
preprocessor=self.preprocessor,
|
||||||
|
config=self.config,
|
||||||
|
num_workers=num_workers,
|
||||||
|
output_dir=output_dir,
|
||||||
|
experiment_name=experiment_name,
|
||||||
|
run_name=run_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: BatDetect2Config):
|
||||||
|
targets = build_targets(config=config.targets)
|
||||||
|
|
||||||
|
audio_loader = build_audio_loader(config=config.audio)
|
||||||
|
|
||||||
|
preprocessor = build_preprocessor(
|
||||||
|
input_samplerate=audio_loader.samplerate,
|
||||||
|
config=config.preprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
postprocessor = build_postprocessor(
|
||||||
|
preprocessor,
|
||||||
|
config=config.postprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
evaluator = build_evaluator(
|
||||||
|
config=config.evaluation,
|
||||||
|
targets=targets,
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE: Better to have a separate instance of
|
||||||
|
# preprocessor and postprocessor as these may be moved
|
||||||
|
# to another device.
|
||||||
|
model = build_model(
|
||||||
|
config=config.model,
|
||||||
|
targets=targets,
|
||||||
|
preprocessor=build_preprocessor(
|
||||||
|
input_samplerate=audio_loader.samplerate,
|
||||||
|
config=config.preprocess,
|
||||||
|
),
|
||||||
|
postprocessor=build_postprocessor(
|
||||||
|
preprocessor,
|
||||||
|
config=config.postprocess,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
config=config,
|
||||||
|
targets=targets,
|
||||||
|
audio_loader=audio_loader,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
evaluator=evaluator,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_checkpoint(
|
||||||
|
cls,
|
||||||
|
path: data.PathLike,
|
||||||
|
config: Optional[BatDetect2Config] = None,
|
||||||
|
):
|
||||||
|
model, stored_config = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
|
config = config or stored_config
|
||||||
|
|
||||||
|
targets = build_targets(config=config.targets)
|
||||||
|
|
||||||
|
audio_loader = build_audio_loader(config=config.audio)
|
||||||
|
|
||||||
|
preprocessor = build_preprocessor(
|
||||||
|
input_samplerate=audio_loader.samplerate,
|
||||||
|
config=config.preprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
postprocessor = build_postprocessor(
|
||||||
|
preprocessor,
|
||||||
|
config=config.postprocess,
|
||||||
|
)
|
||||||
|
|
||||||
|
evaluator = build_evaluator(
|
||||||
|
config=config.evaluation,
|
||||||
|
targets=targets,
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
config=config,
|
||||||
|
targets=targets,
|
||||||
|
audio_loader=audio_loader,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
evaluator=evaluator,
|
||||||
|
model=model,
|
||||||
|
)
|
||||||
16
src/batdetect2/audio/__init__.py
Normal file
16
src/batdetect2/audio/__init__.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from batdetect2.audio.clips import ClipConfig, build_clipper
|
||||||
|
from batdetect2.audio.loader import (
|
||||||
|
TARGET_SAMPLERATE_HZ,
|
||||||
|
AudioConfig,
|
||||||
|
SoundEventAudioLoader,
|
||||||
|
build_audio_loader,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TARGET_SAMPLERATE_HZ",
|
||||||
|
"AudioConfig",
|
||||||
|
"SoundEventAudioLoader",
|
||||||
|
"build_audio_loader",
|
||||||
|
"ClipConfig",
|
||||||
|
"build_clipper",
|
||||||
|
]
|
||||||
@ -6,14 +6,19 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.geometry import compute_bounds, intervals_overlap
|
from soundevent.geometry import compute_bounds, intervals_overlap
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core import BaseConfig, Registry
|
||||||
from batdetect2.data._core import Registry
|
|
||||||
from batdetect2.typing import ClipperProtocol
|
from batdetect2.typing import ClipperProtocol
|
||||||
|
|
||||||
DEFAULT_TRAIN_CLIP_DURATION = 0.256
|
DEFAULT_TRAIN_CLIP_DURATION = 0.256
|
||||||
DEFAULT_MAX_EMPTY_CLIP = 0.1
|
DEFAULT_MAX_EMPTY_CLIP = 0.1
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"build_clipper",
|
||||||
|
"ClipConfig",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")
|
clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")
|
||||||
|
|
||||||
|
|
||||||
295
src/batdetect2/audio/loader.py
Normal file
295
src/batdetect2/audio/loader.py
Normal file
@ -0,0 +1,295 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import DTypeLike
|
||||||
|
from pydantic import Field
|
||||||
|
from scipy.signal import resample, resample_poly
|
||||||
|
from soundevent import audio, data
|
||||||
|
from soundfile import LibsndfileError
|
||||||
|
|
||||||
|
from batdetect2.core import BaseConfig
|
||||||
|
from batdetect2.typing import AudioLoader
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SoundEventAudioLoader",
|
||||||
|
"build_audio_loader",
|
||||||
|
"load_file_audio",
|
||||||
|
"load_recording_audio",
|
||||||
|
"load_clip_audio",
|
||||||
|
"resample_audio",
|
||||||
|
]
|
||||||
|
|
||||||
|
TARGET_SAMPLERATE_HZ = 256_000
|
||||||
|
"""Default target sample rate in Hz used if resampling is enabled."""
|
||||||
|
|
||||||
|
|
||||||
|
class ResampleConfig(BaseConfig):
|
||||||
|
"""Configuration for audio resampling.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
samplerate : int, default=256000
|
||||||
|
The target sample rate in Hz to resample the audio to. Must be > 0.
|
||||||
|
method : str, default="poly"
|
||||||
|
The resampling algorithm to use. Options:
|
||||||
|
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
|
||||||
|
Generally fast.
|
||||||
|
- "fourier": Resampling via Fourier method using
|
||||||
|
`scipy.signal.resample`. May handle non-integer
|
||||||
|
resampling factors differently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
enabled: bool = True
|
||||||
|
method: str = "poly"
|
||||||
|
|
||||||
|
|
||||||
|
class AudioConfig(BaseConfig):
|
||||||
|
"""Configuration for loading and initial audio preprocessing."""
|
||||||
|
|
||||||
|
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
||||||
|
resample: ResampleConfig = Field(default_factory=ResampleConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
|
||||||
|
"""Factory function to create an AudioLoader based on configuration."""
|
||||||
|
config = config or AudioConfig()
|
||||||
|
return SoundEventAudioLoader(
|
||||||
|
samplerate=config.samplerate,
|
||||||
|
config=config.resample,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SoundEventAudioLoader(AudioLoader):
|
||||||
|
"""Concrete implementation of the `AudioLoader`."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
|
config: Optional[ResampleConfig] = None,
|
||||||
|
):
|
||||||
|
self.samplerate = samplerate
|
||||||
|
self.config = config or ResampleConfig()
|
||||||
|
|
||||||
|
def load_file(
|
||||||
|
self,
|
||||||
|
path: data.PathLike,
|
||||||
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Load and preprocess audio directly from a file path."""
|
||||||
|
return load_file_audio(
|
||||||
|
path,
|
||||||
|
samplerate=self.samplerate,
|
||||||
|
config=self.config,
|
||||||
|
audio_dir=audio_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_recording(
|
||||||
|
self,
|
||||||
|
recording: data.Recording,
|
||||||
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Load and preprocess the entire audio for a Recording object."""
|
||||||
|
return load_recording_audio(
|
||||||
|
recording,
|
||||||
|
samplerate=self.samplerate,
|
||||||
|
config=self.config,
|
||||||
|
audio_dir=audio_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_clip(
|
||||||
|
self,
|
||||||
|
clip: data.Clip,
|
||||||
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Load and preprocess the audio segment defined by a Clip object."""
|
||||||
|
return load_clip_audio(
|
||||||
|
clip,
|
||||||
|
samplerate=self.samplerate,
|
||||||
|
config=self.config,
|
||||||
|
audio_dir=audio_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_file_audio(
|
||||||
|
path: data.PathLike,
|
||||||
|
samplerate: Optional[int] = None,
|
||||||
|
config: Optional[ResampleConfig] = None,
|
||||||
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Load and preprocess audio from a file path using specified config."""
|
||||||
|
try:
|
||||||
|
recording = data.Recording.from_file(path)
|
||||||
|
except LibsndfileError as e:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Could not load the recording at path: {path}. Error: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
return load_recording_audio(
|
||||||
|
recording,
|
||||||
|
samplerate=samplerate,
|
||||||
|
config=config,
|
||||||
|
dtype=dtype,
|
||||||
|
audio_dir=audio_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_recording_audio(
|
||||||
|
recording: data.Recording,
|
||||||
|
samplerate: Optional[int] = None,
|
||||||
|
config: Optional[ResampleConfig] = None,
|
||||||
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Load and preprocess the entire audio content of a recording using config."""
|
||||||
|
clip = data.Clip(
|
||||||
|
recording=recording,
|
||||||
|
start_time=0,
|
||||||
|
end_time=recording.duration,
|
||||||
|
)
|
||||||
|
return load_clip_audio(
|
||||||
|
clip,
|
||||||
|
samplerate=samplerate,
|
||||||
|
config=config,
|
||||||
|
dtype=dtype,
|
||||||
|
audio_dir=audio_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_clip_audio(
|
||||||
|
clip: data.Clip,
|
||||||
|
samplerate: Optional[int] = None,
|
||||||
|
config: Optional[ResampleConfig] = None,
|
||||||
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
|
dtype: DTypeLike = np.float32, # type: ignore
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Load and preprocess a specific audio clip segment based on config."""
|
||||||
|
try:
|
||||||
|
wav = (
|
||||||
|
audio.load_clip(clip, audio_dir=audio_dir)
|
||||||
|
.sel(channel=0)
|
||||||
|
.astype(dtype)
|
||||||
|
)
|
||||||
|
except LibsndfileError as e:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Could not load the recording at path: {clip.recording.path}. "
|
||||||
|
f"Error: {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
if not config or not config.enabled or samplerate is None:
|
||||||
|
return wav.data.astype(dtype)
|
||||||
|
|
||||||
|
sr = int(1 / wav.time.attrs["step"])
|
||||||
|
return resample_audio(
|
||||||
|
wav.data,
|
||||||
|
sr=sr,
|
||||||
|
samplerate=samplerate,
|
||||||
|
method=config.method,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resample_audio(
|
||||||
|
wav: np.ndarray,
|
||||||
|
sr: int,
|
||||||
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
|
method: str = "poly",
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Resample an audio waveform DataArray to a target sample rate."""
|
||||||
|
if sr == samplerate:
|
||||||
|
return wav
|
||||||
|
|
||||||
|
if method == "poly":
|
||||||
|
return resample_audio_poly(
|
||||||
|
wav,
|
||||||
|
sr_orig=sr,
|
||||||
|
sr_new=samplerate,
|
||||||
|
)
|
||||||
|
elif method == "fourier":
|
||||||
|
return resample_audio_fourier(
|
||||||
|
wav,
|
||||||
|
sr_orig=sr,
|
||||||
|
sr_new=samplerate,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Resampling method '{method}' not implemented"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resample_audio_poly(
|
||||||
|
array: np.ndarray,
|
||||||
|
sr_orig: int,
|
||||||
|
sr_new: int,
|
||||||
|
axis: int = -1,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Resample a numpy array using `scipy.signal.resample_poly`.
|
||||||
|
|
||||||
|
This method is often preferred for signals when the ratio of new
|
||||||
|
to old sample rates can be expressed as a rational number. It uses
|
||||||
|
polyphase filtering.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
array : np.ndarray
|
||||||
|
The input array to resample.
|
||||||
|
sr_orig : int
|
||||||
|
The original sample rate in Hz.
|
||||||
|
sr_new : int
|
||||||
|
The target sample rate in Hz.
|
||||||
|
axis : int, default=-1
|
||||||
|
The axis of `array` along which to resample.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray
|
||||||
|
The array resampled to the target sample rate.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If sample rates are not positive.
|
||||||
|
"""
|
||||||
|
gcd = np.gcd(sr_orig, sr_new)
|
||||||
|
return resample_poly(
|
||||||
|
array,
|
||||||
|
sr_new // gcd,
|
||||||
|
sr_orig // gcd,
|
||||||
|
axis=axis,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resample_audio_fourier(
|
||||||
|
array: np.ndarray,
|
||||||
|
sr_orig: int,
|
||||||
|
sr_new: int,
|
||||||
|
axis: int = -1,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""Resample a numpy array using `scipy.signal.resample`.
|
||||||
|
|
||||||
|
This method uses FFTs to resample the signal.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
array : np.ndarray
|
||||||
|
The input array to resample.
|
||||||
|
num : int
|
||||||
|
The desired number of samples in the output array along `axis`.
|
||||||
|
axis : int, default=-1
|
||||||
|
The axis of `array` along which to resample.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray
|
||||||
|
The array resampled to have `num` samples along `axis`.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If `num` is negative.
|
||||||
|
"""
|
||||||
|
ratio = sr_new / sr_orig
|
||||||
|
return resample( # type: ignore
|
||||||
|
array,
|
||||||
|
int(array.shape[axis] * ratio),
|
||||||
|
axis=axis,
|
||||||
|
)
|
||||||
@ -1,10 +1,15 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
from batdetect2 import api
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
|
|
||||||
from batdetect2.types import ProcessingConfiguration
|
DEFAULT_MODEL_PATH = os.path.join(
|
||||||
from batdetect2.utils.detector_utils import save_results_to_file
|
os.path.dirname(os.path.dirname(__file__)),
|
||||||
|
"models",
|
||||||
|
"checkpoints",
|
||||||
|
"Net2DFast_UK_same.pth.tar",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@ -74,6 +79,9 @@ def detect(
|
|||||||
|
|
||||||
Input files should be short in duration e.g. < 30 seconds.
|
Input files should be short in duration e.g. < 30 seconds.
|
||||||
"""
|
"""
|
||||||
|
from batdetect2 import api
|
||||||
|
from batdetect2.utils.detector_utils import save_results_to_file
|
||||||
|
|
||||||
click.echo(f"Loading model: {args['model_path']}")
|
click.echo(f"Loading model: {args['model_path']}")
|
||||||
model, params = api.load_model(args["model_path"])
|
model, params = api.load_model(args["model_path"])
|
||||||
|
|
||||||
@ -123,7 +131,7 @@ def detect(
|
|||||||
click.echo(f" {err}")
|
click.echo(f" {err}")
|
||||||
|
|
||||||
|
|
||||||
def print_config(config: ProcessingConfiguration):
|
def print_config(config):
|
||||||
"""Print the processing configuration."""
|
"""Print the processing configuration."""
|
||||||
click.echo("\nProcessing Configuration:")
|
click.echo("\nProcessing Configuration:")
|
||||||
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
|
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from typing import Optional
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.data import load_dataset_from_config
|
|
||||||
|
|
||||||
__all__ = ["data"]
|
__all__ = ["data"]
|
||||||
|
|
||||||
@ -33,6 +32,8 @@ def summary(
|
|||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
base_dir: Optional[Path] = None,
|
base_dir: Optional[Path] = None,
|
||||||
):
|
):
|
||||||
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
|
||||||
base_dir = base_dir or Path.cwd()
|
base_dir = base_dir or Path.cwd()
|
||||||
dataset = load_dataset_from_config(
|
dataset = load_dataset_from_config(
|
||||||
dataset_config,
|
dataset_config,
|
||||||
|
|||||||
@ -6,18 +6,21 @@ import click
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.data import load_dataset_from_config
|
|
||||||
from batdetect2.evaluate.evaluate import evaluate
|
|
||||||
from batdetect2.train.lightning import load_model_from_checkpoint
|
|
||||||
|
|
||||||
__all__ = ["evaluate_command"]
|
__all__ = ["evaluate_command"]
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
||||||
|
|
||||||
|
|
||||||
@cli.command(name="evaluate")
|
@cli.command(name="evaluate")
|
||||||
@click.argument("model-path", type=click.Path(exists=True))
|
@click.argument("model-path", type=click.Path(exists=True))
|
||||||
@click.argument("test_dataset", type=click.Path(exists=True))
|
@click.argument("test_dataset", type=click.Path(exists=True))
|
||||||
@click.option("--output-dir", type=click.Path())
|
@click.option("--config", "config_path", type=click.Path())
|
||||||
@click.option("--workers", type=int)
|
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
|
||||||
|
@click.option("--experiment-name", type=str)
|
||||||
|
@click.option("--run-name", type=str)
|
||||||
|
@click.option("--workers", "num_workers", type=int)
|
||||||
@click.option(
|
@click.option(
|
||||||
"-v",
|
"-v",
|
||||||
"--verbose",
|
"--verbose",
|
||||||
@ -27,10 +30,17 @@ __all__ = ["evaluate_command"]
|
|||||||
def evaluate_command(
|
def evaluate_command(
|
||||||
model_path: Path,
|
model_path: Path,
|
||||||
test_dataset: Path,
|
test_dataset: Path,
|
||||||
output_dir: Optional[Path] = None,
|
config_path: Optional[Path],
|
||||||
workers: Optional[int] = None,
|
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
||||||
|
num_workers: Optional[int] = None,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
|
run_name: Optional[str] = None,
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
):
|
):
|
||||||
|
from batdetect2.api.base import BatDetect2API
|
||||||
|
from batdetect2.config import load_full_config
|
||||||
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
|
||||||
logger.remove()
|
logger.remove()
|
||||||
if verbose == 0:
|
if verbose == 0:
|
||||||
log_level = "WARNING"
|
log_level = "WARNING"
|
||||||
@ -48,16 +58,16 @@ def evaluate_command(
|
|||||||
num_annotations=len(test_annotations),
|
num_annotations=len(test_annotations),
|
||||||
)
|
)
|
||||||
|
|
||||||
model, train_config = load_model_from_checkpoint(model_path)
|
config = None
|
||||||
|
if config_path is not None:
|
||||||
|
config = load_full_config(config_path)
|
||||||
|
|
||||||
df, results = evaluate(
|
api = BatDetect2API.from_checkpoint(model_path, config=config)
|
||||||
model,
|
|
||||||
|
api.evaluate(
|
||||||
test_annotations,
|
test_annotations,
|
||||||
config=train_config,
|
num_workers=num_workers,
|
||||||
num_workers=workers,
|
output_dir=output_dir,
|
||||||
|
experiment_name=experiment_name,
|
||||||
|
run_name=run_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(results)
|
|
||||||
|
|
||||||
if output_dir:
|
|
||||||
df.to_csv(output_dir / "results.csv")
|
|
||||||
|
|||||||
@ -6,13 +6,6 @@ import click
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.data import load_dataset_from_config
|
|
||||||
from batdetect2.targets import load_target_config
|
|
||||||
from batdetect2.train import (
|
|
||||||
FullTrainingConfig,
|
|
||||||
load_full_training_config,
|
|
||||||
train,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = ["train_command"]
|
__all__ = ["train_command"]
|
||||||
|
|
||||||
@ -20,8 +13,8 @@ __all__ = ["train_command"]
|
|||||||
@cli.command(name="train")
|
@cli.command(name="train")
|
||||||
@click.argument("train_dataset", type=click.Path(exists=True))
|
@click.argument("train_dataset", type=click.Path(exists=True))
|
||||||
@click.option("--val-dataset", type=click.Path(exists=True))
|
@click.option("--val-dataset", type=click.Path(exists=True))
|
||||||
@click.option("--model-path", type=click.Path(exists=True))
|
@click.option("--model", "model_path", type=click.Path(exists=True))
|
||||||
@click.option("--targets", type=click.Path(exists=True))
|
@click.option("--targets", "targets_config", type=click.Path(exists=True))
|
||||||
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
||||||
@click.option("--log-dir", type=click.Path(exists=True))
|
@click.option("--log-dir", type=click.Path(exists=True))
|
||||||
@click.option("--config", type=click.Path(exists=True))
|
@click.option("--config", type=click.Path(exists=True))
|
||||||
@ -44,7 +37,7 @@ def train_command(
|
|||||||
ckpt_dir: Optional[Path] = None,
|
ckpt_dir: Optional[Path] = None,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Optional[Path] = None,
|
||||||
config: Optional[Path] = None,
|
config: Optional[Path] = None,
|
||||||
targets: Optional[Path] = None,
|
targets_config: Optional[Path] = None,
|
||||||
config_field: Optional[str] = None,
|
config_field: Optional[str] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
@ -53,6 +46,14 @@ def train_command(
|
|||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
verbose: int = 0,
|
verbose: int = 0,
|
||||||
):
|
):
|
||||||
|
from batdetect2.api.base import BatDetect2API
|
||||||
|
from batdetect2.config import (
|
||||||
|
BatDetect2Config,
|
||||||
|
load_full_config,
|
||||||
|
)
|
||||||
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
from batdetect2.targets import load_target_config
|
||||||
|
|
||||||
logger.remove()
|
logger.remove()
|
||||||
if verbose == 0:
|
if verbose == 0:
|
||||||
log_level = "WARNING"
|
log_level = "WARNING"
|
||||||
@ -61,21 +62,20 @@ def train_command(
|
|||||||
else:
|
else:
|
||||||
log_level = "DEBUG"
|
log_level = "DEBUG"
|
||||||
logger.add(sys.stderr, level=log_level)
|
logger.add(sys.stderr, level=log_level)
|
||||||
|
|
||||||
logger.info("Initiating training process...")
|
logger.info("Initiating training process...")
|
||||||
|
|
||||||
logger.info("Loading training configuration...")
|
logger.info("Loading configuration...")
|
||||||
|
|
||||||
conf = (
|
conf = (
|
||||||
load_full_training_config(config, field=config_field)
|
load_full_config(config, field=config_field)
|
||||||
if config is not None
|
if config is not None
|
||||||
else FullTrainingConfig()
|
else BatDetect2Config()
|
||||||
)
|
)
|
||||||
|
|
||||||
if targets is not None:
|
if targets_config is not None:
|
||||||
logger.info("Loading targets configuration...")
|
logger.info("Loading targets configuration...")
|
||||||
targets_config = load_target_config(targets)
|
conf = conf.model_copy(
|
||||||
conf = conf.model_copy(update=dict(targets=targets_config))
|
update=dict(targets=load_target_config(targets_config))
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Loading training dataset...")
|
logger.info("Loading training dataset...")
|
||||||
train_annotations = load_dataset_from_config(train_dataset)
|
train_annotations = load_dataset_from_config(train_dataset)
|
||||||
@ -95,16 +95,20 @@ def train_command(
|
|||||||
logger.debug("No validation directory provided.")
|
logger.debug("No validation directory provided.")
|
||||||
|
|
||||||
logger.info("Configuration and data loaded. Starting training...")
|
logger.info("Configuration and data loaded. Starting training...")
|
||||||
train(
|
|
||||||
|
if model_path is None:
|
||||||
|
api = BatDetect2API.from_config(conf)
|
||||||
|
else:
|
||||||
|
api = BatDetect2API.from_checkpoint(model_path)
|
||||||
|
|
||||||
|
return api.train(
|
||||||
train_annotations=train_annotations,
|
train_annotations=train_annotations,
|
||||||
val_annotations=val_annotations,
|
val_annotations=val_annotations,
|
||||||
config=conf,
|
|
||||||
model_path=model_path,
|
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
val_workers=val_workers,
|
val_workers=val_workers,
|
||||||
experiment_name=experiment_name,
|
|
||||||
log_dir=log_dir,
|
|
||||||
checkpoint_dir=ckpt_dir,
|
checkpoint_dir=ckpt_dir,
|
||||||
seed=seed,
|
log_dir=log_dir,
|
||||||
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
|
seed=seed,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -11,7 +11,6 @@ from soundevent import data
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
from soundevent.types import ClassMapper
|
from soundevent.types import ClassMapper
|
||||||
|
|
||||||
from batdetect2.targets.terms import get_term_from_key
|
|
||||||
from batdetect2.types import (
|
from batdetect2.types import (
|
||||||
Annotation,
|
Annotation,
|
||||||
AudioLoaderAnnotationGroup,
|
AudioLoaderAnnotationGroup,
|
||||||
@ -173,18 +172,9 @@ def annotation_to_sound_event_annotation(
|
|||||||
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
|
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
|
||||||
sound_event=sound_event,
|
sound_event=sound_event,
|
||||||
tags=[
|
tags=[
|
||||||
data.Tag(
|
data.Tag(key=label_key, value=annotation["class"]),
|
||||||
term=get_term_from_key(label_key),
|
data.Tag(key=event_key, value=annotation["event"]),
|
||||||
value=annotation["class"],
|
data.Tag(key=individual_key, value=str(annotation["individual"])),
|
||||||
),
|
|
||||||
data.Tag(
|
|
||||||
term=get_term_from_key(event_key),
|
|
||||||
value=annotation["event"],
|
|
||||||
),
|
|
||||||
data.Tag(
|
|
||||||
term=get_term_from_key(individual_key),
|
|
||||||
value=str(annotation["individual"]),
|
|
||||||
),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -219,17 +209,11 @@ def annotation_to_sound_event_prediction(
|
|||||||
tags=[
|
tags=[
|
||||||
data.PredictedTag(
|
data.PredictedTag(
|
||||||
score=annotation["class_prob"],
|
score=annotation["class_prob"],
|
||||||
tag=data.Tag(
|
tag=data.Tag(key=label_key, value=annotation["class"]),
|
||||||
term=get_term_from_key(label_key),
|
|
||||||
value=annotation["class"],
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
data.PredictedTag(
|
data.PredictedTag(
|
||||||
score=annotation["det_prob"],
|
score=annotation["det_prob"],
|
||||||
tag=data.Tag(
|
tag=data.Tag(key=event_key, value=annotation["event"]),
|
||||||
term=get_term_from_key(event_key),
|
|
||||||
value=annotation["event"],
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
40
src/batdetect2/config.py
Normal file
40
src/batdetect2/config.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
|
from batdetect2.audio import AudioConfig
|
||||||
|
from batdetect2.core import BaseConfig
|
||||||
|
from batdetect2.core.configs import load_config
|
||||||
|
from batdetect2.evaluate.config import EvaluationConfig
|
||||||
|
from batdetect2.models.config import BackboneConfig
|
||||||
|
from batdetect2.postprocess.config import PostprocessConfig
|
||||||
|
from batdetect2.preprocess.config import PreprocessingConfig
|
||||||
|
from batdetect2.targets.config import TargetConfig
|
||||||
|
from batdetect2.train.config import TrainingConfig
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BatDetect2Config",
|
||||||
|
"load_full_config",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class BatDetect2Config(BaseConfig):
|
||||||
|
config_version: Literal["v1"] = "v1"
|
||||||
|
|
||||||
|
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||||
|
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
||||||
|
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||||
|
preprocess: PreprocessingConfig = Field(
|
||||||
|
default_factory=PreprocessingConfig
|
||||||
|
)
|
||||||
|
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||||
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def load_full_config(
|
||||||
|
path: PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> BatDetect2Config:
|
||||||
|
return load_config(path, schema=BatDetect2Config, field=field)
|
||||||
8
src/batdetect2/core/__init__.py
Normal file
8
src/batdetect2/core/__init__.py
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.core.registries import Registry
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseConfig",
|
||||||
|
"load_config",
|
||||||
|
"Registry",
|
||||||
|
]
|
||||||
@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
|
|||||||
and serialization capabilities.
|
and serialization capabilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="ignore")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
def to_yaml_string(
|
def to_yaml_string(
|
||||||
self,
|
self,
|
||||||
@ -1,7 +1,13 @@
|
|||||||
|
import sys
|
||||||
from typing import Generic, Protocol, Type, TypeVar
|
from typing import Generic, Protocol, Type, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import assert_type
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
from typing import ParamSpec
|
||||||
|
else:
|
||||||
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Registry",
|
"Registry",
|
||||||
@ -39,7 +45,6 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
config_cls: Type[T_Config],
|
config_cls: Type[T_Config],
|
||||||
logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
|
logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""A decorator factory to register a new item."""
|
|
||||||
fields = config_cls.model_fields
|
fields = config_cls.model_fields
|
||||||
|
|
||||||
if "name" not in fields:
|
if "name" not in fields:
|
||||||
@ -18,7 +18,7 @@ from uuid import uuid5
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data, io
|
from soundevent import data, io
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data.annotations.types import AnnotatedDataset
|
from batdetect2.data.annotations.types import AnnotatedDataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@ -33,7 +33,7 @@ from loguru import logger
|
|||||||
from pydantic import Field, ValidationError
|
from pydantic import Field, ValidationError
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data.annotations.legacy import (
|
from batdetect2.data.annotations.legacy import (
|
||||||
FileAnnotation,
|
FileAnnotation,
|
||||||
file_annotation_to_clip,
|
file_annotation_to_clip,
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnnotatedDataset",
|
"AnnotatedDataset",
|
||||||
|
|||||||
@ -5,8 +5,8 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data._core import Registry
|
from batdetect2.core.registries import Registry
|
||||||
|
|
||||||
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
||||||
|
|
||||||
|
|||||||
@ -25,7 +25,7 @@ from loguru import logger
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data, io
|
from soundevent import data, io
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.data.annotations import (
|
from batdetect2.data.annotations import (
|
||||||
AnnotatedDataset,
|
AnnotatedDataset,
|
||||||
AnnotationFormats,
|
AnnotationFormats,
|
||||||
|
|||||||
@ -4,8 +4,8 @@ from typing import Annotated, Dict, List, Literal, Optional, Union
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data._core import Registry
|
from batdetect2.core.registries import Registry
|
||||||
from batdetect2.data.conditions import (
|
from batdetect2.data.conditions import (
|
||||||
SoundEventCondition,
|
SoundEventCondition,
|
||||||
SoundEventConditionConfig,
|
SoundEventConditionConfig,
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
||||||
|
from batdetect2.evaluate.evaluate import evaluate
|
||||||
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EvaluationConfig",
|
"EvaluationConfig",
|
||||||
"load_evaluation_config",
|
"load_evaluation_config",
|
||||||
|
"evaluate",
|
||||||
"Evaluator",
|
"Evaluator",
|
||||||
"build_evaluator",
|
"build_evaluator",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -4,8 +4,8 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import compute_affinity
|
from soundevent.evaluation import compute_affinity
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data._core import Registry
|
from batdetect2.core.registries import Registry
|
||||||
from batdetect2.typing.evaluate import AffinityFunction
|
from batdetect2.typing.evaluate import AffinityFunction
|
||||||
|
|
||||||
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
||||||
|
|||||||
@ -3,14 +3,15 @@ from typing import List, Optional
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
|
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
|
||||||
from batdetect2.evaluate.metrics import (
|
from batdetect2.evaluate.metrics import (
|
||||||
ClassificationAPConfig,
|
ClassificationAPConfig,
|
||||||
DetectionAPConfig,
|
DetectionAPConfig,
|
||||||
MetricConfig,
|
MetricConfig,
|
||||||
)
|
)
|
||||||
from batdetect2.evaluate.plots import ExampleGalleryConfig, PlotConfig
|
from batdetect2.evaluate.plots import PlotConfig
|
||||||
|
from batdetect2.logging import CSVLoggerConfig, LoggerConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EvaluationConfig",
|
"EvaluationConfig",
|
||||||
@ -20,18 +21,15 @@ __all__ = [
|
|||||||
|
|
||||||
class EvaluationConfig(BaseConfig):
|
class EvaluationConfig(BaseConfig):
|
||||||
ignore_start_end: float = 0.01
|
ignore_start_end: float = 0.01
|
||||||
match: MatchConfig = Field(default_factory=StartTimeMatchConfig)
|
match_strategy: MatchConfig = Field(default_factory=StartTimeMatchConfig)
|
||||||
metrics: List[MetricConfig] = Field(
|
metrics: List[MetricConfig] = Field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
DetectionAPConfig(),
|
DetectionAPConfig(),
|
||||||
ClassificationAPConfig(),
|
ClassificationAPConfig(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
plots: List[PlotConfig] = Field(
|
plots: List[PlotConfig] = Field(default_factory=list)
|
||||||
default_factory=lambda: [
|
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||||
ExampleGalleryConfig(),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_evaluation_config(
|
def load_evaluation_config(
|
||||||
|
|||||||
144
src/batdetect2/evaluate/dataset.py
Normal file
144
src/batdetect2/evaluate/dataset.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
from typing import List, NamedTuple, Optional, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
|
||||||
|
from batdetect2.audio.clips import PaddedClipConfig
|
||||||
|
from batdetect2.core import BaseConfig
|
||||||
|
from batdetect2.core.arrays import adjust_width
|
||||||
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.typing import (
|
||||||
|
AudioLoader,
|
||||||
|
ClipperProtocol,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TestDataset",
|
||||||
|
"build_test_dataset",
|
||||||
|
"build_test_loader",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TestExample(NamedTuple):
|
||||||
|
spec: torch.Tensor
|
||||||
|
idx: torch.Tensor
|
||||||
|
start_time: torch.Tensor
|
||||||
|
end_time: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataset(Dataset[TestExample]):
|
||||||
|
clip_annotations: List[data.ClipAnnotation]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
|
audio_loader: AudioLoader,
|
||||||
|
preprocessor: PreprocessorProtocol,
|
||||||
|
clipper: Optional[ClipperProtocol] = None,
|
||||||
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
|
):
|
||||||
|
self.clip_annotations = list(clip_annotations)
|
||||||
|
self.clipper = clipper
|
||||||
|
self.preprocessor = preprocessor
|
||||||
|
self.audio_loader = audio_loader
|
||||||
|
self.audio_dir = audio_dir
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.clip_annotations)
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> TestExample:
|
||||||
|
clip_annotation = self.clip_annotations[idx]
|
||||||
|
|
||||||
|
if self.clipper is not None:
|
||||||
|
clip_annotation = self.clipper(clip_annotation)
|
||||||
|
|
||||||
|
clip = clip_annotation.clip
|
||||||
|
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
|
||||||
|
wav_tensor = torch.tensor(wav).unsqueeze(0)
|
||||||
|
spectrogram = self.preprocessor(wav_tensor)
|
||||||
|
return TestExample(
|
||||||
|
spec=spectrogram,
|
||||||
|
idx=torch.tensor(idx),
|
||||||
|
start_time=torch.tensor(clip.start_time),
|
||||||
|
end_time=torch.tensor(clip.end_time),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoaderConfig(BaseConfig):
|
||||||
|
num_workers: int = 0
|
||||||
|
clipping_strategy: ClipConfig = Field(
|
||||||
|
default_factory=lambda: PaddedClipConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_test_loader(
|
||||||
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
|
config: Optional[TestLoaderConfig] = None,
|
||||||
|
num_workers: Optional[int] = None,
|
||||||
|
) -> DataLoader[TestExample]:
|
||||||
|
logger.info("Building test data loader...")
|
||||||
|
config = config or TestLoaderConfig()
|
||||||
|
logger.opt(lazy=True).debug(
|
||||||
|
"Test data loader config: \n{config}",
|
||||||
|
config=lambda: config.to_yaml_string(exclude_none=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
test_dataset = build_test_dataset(
|
||||||
|
clip_annotations,
|
||||||
|
audio_loader=audio_loader,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
num_workers = num_workers or config.num_workers
|
||||||
|
return DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=num_workers,
|
||||||
|
collate_fn=_collate_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_test_dataset(
|
||||||
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
|
config: Optional[TestLoaderConfig] = None,
|
||||||
|
) -> TestDataset:
|
||||||
|
logger.info("Building training dataset...")
|
||||||
|
config = config or TestLoaderConfig()
|
||||||
|
|
||||||
|
clipper = build_clipper(config=config.clipping_strategy)
|
||||||
|
|
||||||
|
if audio_loader is None:
|
||||||
|
audio_loader = build_audio_loader()
|
||||||
|
|
||||||
|
if preprocessor is None:
|
||||||
|
preprocessor = build_preprocessor()
|
||||||
|
|
||||||
|
return TestDataset(
|
||||||
|
clip_annotations,
|
||||||
|
audio_loader=audio_loader,
|
||||||
|
clipper=clipper,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _collate_fn(batch: List[TestExample]) -> TestExample:
|
||||||
|
max_width = max(item.spec.shape[-1] for item in batch)
|
||||||
|
return TestExample(
|
||||||
|
spec=torch.stack(
|
||||||
|
[adjust_width(item.spec, max_width) for item in batch]
|
||||||
|
),
|
||||||
|
idx=torch.stack([item.idx for item in batch]),
|
||||||
|
start_time=torch.stack([item.start_time for item in batch]),
|
||||||
|
end_time=torch.stack([item.end_time for item in batch]),
|
||||||
|
)
|
||||||
@ -1,92 +1,68 @@
|
|||||||
from typing import List, Optional, Tuple
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Optional, Sequence
|
||||||
|
|
||||||
import pandas as pd
|
from lightning import Trainer
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.evaluate.dataframe import extract_matches_dataframe
|
from batdetect2.audio import build_audio_loader
|
||||||
|
from batdetect2.evaluate.dataset import build_test_loader
|
||||||
from batdetect2.evaluate.evaluator import build_evaluator
|
from batdetect2.evaluate.evaluator import build_evaluator
|
||||||
from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP
|
from batdetect2.evaluate.lightning import EvaluationModule
|
||||||
|
from batdetect2.logging import build_logger
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.plotting.clips import build_audio_loader
|
|
||||||
from batdetect2.postprocess import get_raw_predictions
|
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.train.config import FullTrainingConfig
|
|
||||||
from batdetect2.train.dataset import ValidationDataset
|
if TYPE_CHECKING:
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.train.train import build_val_loader
|
from batdetect2.typing import (
|
||||||
|
AudioLoader,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
DEFAULT_OUTPUT_DIR: Path = Path("outputs") / "evaluations"
|
||||||
|
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
model: Model,
|
model: Model,
|
||||||
test_annotations: List[data.ClipAnnotation],
|
test_annotations: Sequence[data.ClipAnnotation],
|
||||||
config: Optional[FullTrainingConfig] = None,
|
targets: Optional["TargetProtocol"] = None,
|
||||||
|
audio_loader: Optional["AudioLoader"] = None,
|
||||||
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
|
config: Optional["BatDetect2Config"] = None,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: Optional[int] = None,
|
||||||
) -> Tuple[pd.DataFrame, dict]:
|
output_dir: data.PathLike = DEFAULT_OUTPUT_DIR,
|
||||||
config = config or FullTrainingConfig()
|
experiment_name: Optional[str] = None,
|
||||||
|
run_name: Optional[str] = None,
|
||||||
|
):
|
||||||
|
from batdetect2.config import BatDetect2Config
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config.preprocess.audio)
|
config = config or BatDetect2Config()
|
||||||
|
|
||||||
preprocessor = build_preprocessor(config.preprocess)
|
audio_loader = audio_loader or build_audio_loader()
|
||||||
|
|
||||||
targets = build_targets(config.targets)
|
preprocessor = preprocessor or build_preprocessor(
|
||||||
|
input_samplerate=audio_loader.samplerate,
|
||||||
labeller = build_clip_labeler(
|
|
||||||
targets,
|
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
config=config.train.labels,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
loader = build_val_loader(
|
targets = targets or build_targets()
|
||||||
|
|
||||||
|
loader = build_test_loader(
|
||||||
test_annotations,
|
test_annotations,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
labeller=labeller,
|
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
config=config.train.val_loader,
|
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset: ValidationDataset = loader.dataset # type: ignore
|
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||||
|
|
||||||
clip_annotations = []
|
logger = build_logger(
|
||||||
predictions = []
|
config.evaluation.logger,
|
||||||
|
log_dir=Path(output_dir),
|
||||||
evaluator = build_evaluator(config=config.evaluation)
|
experiment_name=experiment_name,
|
||||||
|
run_name=run_name,
|
||||||
for batch in loader:
|
)
|
||||||
outputs = model.detector(batch.spec)
|
module = EvaluationModule(model, evaluator)
|
||||||
|
trainer = Trainer(logger=logger, enable_checkpointing=False)
|
||||||
clip_annotations = [
|
return trainer.test(module, loader)
|
||||||
dataset.clip_annotations[int(example_idx)]
|
|
||||||
for example_idx in batch.idx
|
|
||||||
]
|
|
||||||
|
|
||||||
predictions = get_raw_predictions(
|
|
||||||
outputs,
|
|
||||||
start_times=[
|
|
||||||
clip_annotation.clip.start_time
|
|
||||||
for clip_annotation in clip_annotations
|
|
||||||
],
|
|
||||||
targets=targets,
|
|
||||||
postprocessor=model.postprocessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
clip_annotations.extend(clip_annotations)
|
|
||||||
predictions.extend(predictions)
|
|
||||||
|
|
||||||
matches = evaluator.evaluate(clip_annotations, predictions)
|
|
||||||
df = extract_matches_dataframe(matches)
|
|
||||||
|
|
||||||
metrics = [
|
|
||||||
DetectionAP(),
|
|
||||||
ClassificationAP(class_names=targets.class_names),
|
|
||||||
]
|
|
||||||
|
|
||||||
results = {
|
|
||||||
name: value
|
|
||||||
for metric in metrics
|
|
||||||
for name, value in metric(matches).items()
|
|
||||||
}
|
|
||||||
|
|
||||||
return df, results
|
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from batdetect2.evaluate.plots import build_plotter
|
|||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.typing.evaluate import (
|
from batdetect2.typing.evaluate import (
|
||||||
ClipEvaluation,
|
ClipEvaluation,
|
||||||
|
EvaluatorProtocol,
|
||||||
MatcherProtocol,
|
MatcherProtocol,
|
||||||
MetricsProtocol,
|
MetricsProtocol,
|
||||||
PlotterProtocol,
|
PlotterProtocol,
|
||||||
@ -135,10 +136,10 @@ def build_evaluator(
|
|||||||
matcher: Optional[MatcherProtocol] = None,
|
matcher: Optional[MatcherProtocol] = None,
|
||||||
plots: Optional[List[PlotterProtocol]] = None,
|
plots: Optional[List[PlotterProtocol]] = None,
|
||||||
metrics: Optional[List[MetricsProtocol]] = None,
|
metrics: Optional[List[MetricsProtocol]] = None,
|
||||||
) -> Evaluator:
|
) -> EvaluatorProtocol:
|
||||||
config = config or EvaluationConfig()
|
config = config or EvaluationConfig()
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
matcher = matcher or build_matcher(config.match)
|
matcher = matcher or build_matcher(config.match_strategy)
|
||||||
|
|
||||||
if metrics is None:
|
if metrics is None:
|
||||||
metrics = [
|
metrics = [
|
||||||
@ -147,7 +148,10 @@ def build_evaluator(
|
|||||||
]
|
]
|
||||||
|
|
||||||
if plots is None:
|
if plots is None:
|
||||||
plots = [build_plotter(config) for config in config.plots]
|
plots = [
|
||||||
|
build_plotter(config, targets.class_names)
|
||||||
|
for config in config.plots
|
||||||
|
]
|
||||||
|
|
||||||
return Evaluator(
|
return Evaluator(
|
||||||
config=config,
|
config=config,
|
||||||
|
|||||||
86
src/batdetect2/evaluate/lightning.py
Normal file
86
src/batdetect2/evaluate/lightning.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
from lightning import LightningModule
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
||||||
|
from batdetect2.evaluate.tables import FullEvaluationTable
|
||||||
|
from batdetect2.logging import get_image_logger, get_table_logger
|
||||||
|
from batdetect2.models import Model
|
||||||
|
from batdetect2.postprocess import to_raw_predictions
|
||||||
|
from batdetect2.typing import ClipEvaluation, EvaluatorProtocol
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationModule(LightningModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Model,
|
||||||
|
evaluator: EvaluatorProtocol,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.evaluator = evaluator
|
||||||
|
|
||||||
|
self.clip_evaluations = []
|
||||||
|
|
||||||
|
def test_step(self, batch: TestExample):
|
||||||
|
dataset = self.get_dataset()
|
||||||
|
clip_annotations = [
|
||||||
|
dataset.clip_annotations[int(example_idx)]
|
||||||
|
for example_idx in batch.idx
|
||||||
|
]
|
||||||
|
|
||||||
|
outputs = self.model.detector(batch.spec)
|
||||||
|
clip_detections = self.model.postprocessor(
|
||||||
|
outputs,
|
||||||
|
start_times=[ca.clip.start_time for ca in clip_annotations],
|
||||||
|
)
|
||||||
|
predictions = [
|
||||||
|
to_raw_predictions(
|
||||||
|
clip_dets.numpy(),
|
||||||
|
targets=self.evaluator.targets,
|
||||||
|
)
|
||||||
|
for clip_dets in clip_detections
|
||||||
|
]
|
||||||
|
|
||||||
|
self.clip_evaluations.extend(
|
||||||
|
self.evaluator.evaluate(clip_annotations, predictions)
|
||||||
|
)
|
||||||
|
|
||||||
|
def on_test_epoch_start(self):
|
||||||
|
self.clip_evaluations = []
|
||||||
|
|
||||||
|
def on_test_epoch_end(self):
|
||||||
|
self.log_metrics(self.clip_evaluations)
|
||||||
|
self.plot_examples(self.clip_evaluations)
|
||||||
|
self.log_table(self.clip_evaluations)
|
||||||
|
|
||||||
|
def log_table(self, evaluated_clips: Sequence[ClipEvaluation]):
|
||||||
|
table_logger = get_table_logger(self.logger) # type: ignore
|
||||||
|
|
||||||
|
if table_logger is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
df = FullEvaluationTable()(evaluated_clips)
|
||||||
|
table_logger("full_evaluation", df, 0)
|
||||||
|
|
||||||
|
def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]):
|
||||||
|
plotter = get_image_logger(self.logger) # type: ignore
|
||||||
|
|
||||||
|
if plotter is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
|
||||||
|
plotter(figure_name, fig, self.global_step)
|
||||||
|
|
||||||
|
def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]):
|
||||||
|
metrics = self.evaluator.compute_metrics(evaluated_clips)
|
||||||
|
self.log_dict(metrics)
|
||||||
|
|
||||||
|
def get_dataset(self) -> TestDataset:
|
||||||
|
dataloaders = self.trainer.test_dataloaders
|
||||||
|
assert isinstance(dataloaders, DataLoader)
|
||||||
|
dataset = dataloaders.dataset
|
||||||
|
assert isinstance(dataset, TestDataset)
|
||||||
|
return dataset
|
||||||
@ -8,8 +8,8 @@ from soundevent.evaluation import compute_affinity
|
|||||||
from soundevent.evaluation import match_geometries as optimal_match
|
from soundevent.evaluation import match_geometries as optimal_match
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data._core import Registry
|
from batdetect2.core.registries import Registry
|
||||||
from batdetect2.evaluate.affinity import (
|
from batdetect2.evaluate.affinity import (
|
||||||
AffinityConfig,
|
AffinityConfig,
|
||||||
GeometricIOUConfig,
|
GeometricIOUConfig,
|
||||||
@ -111,7 +111,7 @@ def match(
|
|||||||
|
|
||||||
|
|
||||||
class StartTimeMatchConfig(BaseConfig):
|
class StartTimeMatchConfig(BaseConfig):
|
||||||
name: Literal["start_time"] = "start_time"
|
name: Literal["start_time_match"] = "start_time_match"
|
||||||
distance_threshold: float = 0.01
|
distance_threshold: float = 0.01
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from collections import defaultdict
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Callable, Mapping
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
@ -12,13 +13,10 @@ from typing import (
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from sklearn import metrics
|
from sklearn import metrics, preprocessing
|
||||||
from sklearn.preprocessing import label_binarize
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core import BaseConfig, Registry
|
||||||
from batdetect2.data._core import Registry
|
from batdetect2.typing import ClipEvaluation, MetricsProtocol
|
||||||
from batdetect2.typing import MetricsProtocol
|
|
||||||
from batdetect2.typing.evaluate import ClipEvaluation
|
|
||||||
|
|
||||||
__all__ = ["DetectionAP", "ClassificationAP"]
|
__all__ = ["DetectionAP", "ClassificationAP"]
|
||||||
|
|
||||||
@ -26,57 +24,18 @@ __all__ = ["DetectionAP", "ClassificationAP"]
|
|||||||
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
|
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
|
||||||
|
|
||||||
|
|
||||||
AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"]
|
APImplementation = Literal["sklearn", "pascal_voc"]
|
||||||
|
|
||||||
|
|
||||||
class DetectionAPConfig(BaseConfig):
|
class DetectionAPConfig(BaseConfig):
|
||||||
name: Literal["detection_ap"] = "detection_ap"
|
name: Literal["detection_ap"] = "detection_ap"
|
||||||
implementation: AveragePrecisionImplementation = "pascal_voc"
|
ap_implementation: APImplementation = "pascal_voc"
|
||||||
|
|
||||||
|
|
||||||
def pascal_voc_average_precision(y_true, y_score) -> float:
|
|
||||||
y_true = np.array(y_true)
|
|
||||||
y_score = np.array(y_score)
|
|
||||||
|
|
||||||
sort_ind = np.argsort(y_score)[::-1]
|
|
||||||
y_true_sorted = y_true[sort_ind]
|
|
||||||
|
|
||||||
num_positives = y_true.sum()
|
|
||||||
false_pos_c = np.cumsum(1 - y_true_sorted)
|
|
||||||
true_pos_c = np.cumsum(y_true_sorted)
|
|
||||||
|
|
||||||
recall = true_pos_c / num_positives
|
|
||||||
precision = true_pos_c / np.maximum(
|
|
||||||
true_pos_c + false_pos_c,
|
|
||||||
np.finfo(np.float64).eps,
|
|
||||||
)
|
|
||||||
|
|
||||||
precision[np.isnan(precision)] = 0
|
|
||||||
recall[np.isnan(recall)] = 0
|
|
||||||
|
|
||||||
# pascal 12 way
|
|
||||||
mprec = np.hstack((0, precision, 0))
|
|
||||||
mrec = np.hstack((0, recall, 1))
|
|
||||||
for ii in range(mprec.shape[0] - 2, -1, -1):
|
|
||||||
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
|
|
||||||
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
|
|
||||||
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
|
|
||||||
|
|
||||||
return ave_prec
|
|
||||||
|
|
||||||
|
|
||||||
_ap_impl_mapping: Mapping[
|
|
||||||
AveragePrecisionImplementation, Callable[[Any, Any], float]
|
|
||||||
] = {
|
|
||||||
"sklearn": metrics.average_precision_score,
|
|
||||||
"pascal_voc": pascal_voc_average_precision,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionAP(MetricsProtocol):
|
class DetectionAP(MetricsProtocol):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
implementation: AveragePrecisionImplementation = "pascal_voc",
|
implementation: APImplementation = "pascal_voc",
|
||||||
):
|
):
|
||||||
self.implementation = implementation
|
self.implementation = implementation
|
||||||
self.metric = _ap_impl_mapping[self.implementation]
|
self.metric = _ap_impl_mapping[self.implementation]
|
||||||
@ -96,14 +55,43 @@ class DetectionAP(MetricsProtocol):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
|
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
|
||||||
return cls(implementation=config.implementation)
|
return cls(implementation=config.ap_implementation)
|
||||||
|
|
||||||
|
|
||||||
metrics_registry.register(DetectionAPConfig, DetectionAP)
|
metrics_registry.register(DetectionAPConfig, DetectionAP)
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionROCAUCConfig(BaseConfig):
|
||||||
|
name: Literal["detection_roc_auc"] = "detection_roc_auc"
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionROCAUC(MetricsProtocol):
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true, y_score = zip(
|
||||||
|
*[
|
||||||
|
(match.gt_det, match.pred_score)
|
||||||
|
for clip_eval in clip_evaluations
|
||||||
|
for match in clip_eval.matches
|
||||||
|
]
|
||||||
|
)
|
||||||
|
score = float(metrics.roc_auc_score(y_true, y_score))
|
||||||
|
return {"detection_ROC_AUC": score}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls, config: DetectionROCAUCConfig, class_names: List[str]
|
||||||
|
):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(DetectionROCAUCConfig, DetectionROCAUC)
|
||||||
|
|
||||||
|
|
||||||
class ClassificationAPConfig(BaseConfig):
|
class ClassificationAPConfig(BaseConfig):
|
||||||
name: Literal["classification_ap"] = "classification_ap"
|
name: Literal["classification_ap"] = "classification_ap"
|
||||||
|
ap_implementation: APImplementation = "pascal_voc"
|
||||||
include: Optional[List[str]] = None
|
include: Optional[List[str]] = None
|
||||||
exclude: Optional[List[str]] = None
|
exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
@ -112,7 +100,7 @@ class ClassificationAP(MetricsProtocol):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
class_names: List[str],
|
class_names: List[str],
|
||||||
implementation: AveragePrecisionImplementation = "pascal_voc",
|
implementation: APImplementation = "pascal_voc",
|
||||||
include: Optional[List[str]] = None,
|
include: Optional[List[str]] = None,
|
||||||
exclude: Optional[List[str]] = None,
|
exclude: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
@ -163,7 +151,7 @@ class ClassificationAP(MetricsProtocol):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
y_true = label_binarize(y_true, classes=self.class_names)
|
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
|
||||||
y_pred = np.stack(y_pred)
|
y_pred = np.stack(y_pred)
|
||||||
|
|
||||||
class_scores = {}
|
class_scores = {}
|
||||||
@ -193,6 +181,7 @@ class ClassificationAP(MetricsProtocol):
|
|||||||
):
|
):
|
||||||
return cls(
|
return cls(
|
||||||
class_names,
|
class_names,
|
||||||
|
implementation=config.ap_implementation,
|
||||||
include=config.include,
|
include=config.include,
|
||||||
exclude=config.exclude,
|
exclude=config.exclude,
|
||||||
)
|
)
|
||||||
@ -201,11 +190,523 @@ class ClassificationAP(MetricsProtocol):
|
|||||||
metrics_registry.register(ClassificationAPConfig, ClassificationAP)
|
metrics_registry.register(ClassificationAPConfig, ClassificationAP)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationROCAUCConfig(BaseConfig):
|
||||||
|
name: Literal["classification_roc_auc"] = "classification_roc_auc"
|
||||||
|
include: Optional[List[str]] = None
|
||||||
|
exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationROCAUC(MetricsProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
class_names: List[str],
|
||||||
|
include: Optional[List[str]] = None,
|
||||||
|
exclude: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
self.class_names = class_names
|
||||||
|
self.selected = class_names
|
||||||
|
|
||||||
|
if include is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name in include
|
||||||
|
]
|
||||||
|
|
||||||
|
if exclude is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name not in exclude
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true = []
|
||||||
|
y_pred = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
# Ignore generic unclassified targets
|
||||||
|
if match.gt_det and match.gt_class is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true.append(
|
||||||
|
match.gt_class
|
||||||
|
if match.gt_class is not None
|
||||||
|
else "__NONE__"
|
||||||
|
)
|
||||||
|
|
||||||
|
y_pred.append(
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
match.pred_class_scores.get(name, 0)
|
||||||
|
for name in self.class_names
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
|
||||||
|
y_pred = np.stack(y_pred)
|
||||||
|
|
||||||
|
class_scores = {}
|
||||||
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
|
y_true_class = y_true[:, class_index]
|
||||||
|
y_pred_class = y_pred[:, class_index]
|
||||||
|
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
|
||||||
|
class_scores[class_name] = float(class_roc_auc)
|
||||||
|
|
||||||
|
mean_roc_auc = np.mean(
|
||||||
|
[value for value in class_scores.values() if value != 0]
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"classification_macro_average_ROC_AUC": float(mean_roc_auc),
|
||||||
|
**{
|
||||||
|
f"classification_ROC_AUC/{class_name}": class_scores[
|
||||||
|
class_name
|
||||||
|
]
|
||||||
|
for class_name in self.selected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClassificationROCAUCConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls(
|
||||||
|
class_names,
|
||||||
|
include=config.include,
|
||||||
|
exclude=config.exclude,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(ClassificationROCAUCConfig, ClassificationROCAUC)
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassAPConfig(BaseConfig):
|
||||||
|
name: Literal["top_class_ap"] = "top_class_ap"
|
||||||
|
ap_implementation: APImplementation = "pascal_voc"
|
||||||
|
|
||||||
|
|
||||||
|
class TopClassAP(MetricsProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
implementation: APImplementation = "pascal_voc",
|
||||||
|
):
|
||||||
|
self.implementation = implementation
|
||||||
|
self.metric = _ap_impl_mapping[self.implementation]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true = []
|
||||||
|
y_score = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
# Ignore generic unclassified targets
|
||||||
|
if match.gt_det and match.gt_class is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
top_class = match.pred_class
|
||||||
|
|
||||||
|
y_true.append(top_class == match.gt_class)
|
||||||
|
y_score.append(match.pred_class_score)
|
||||||
|
|
||||||
|
score = float(self.metric(y_true, y_score))
|
||||||
|
return {"top_class_AP": score}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: TopClassAPConfig, class_names: List[str]):
|
||||||
|
return cls(implementation=config.ap_implementation)
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(TopClassAPConfig, TopClassAP)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationBalancedAccuracyConfig(BaseConfig):
|
||||||
|
name: Literal["classification_balanced_accuracy"] = (
|
||||||
|
"classification_balanced_accuracy"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationBalancedAccuracy(MetricsProtocol):
|
||||||
|
def __init__(self, class_names: List[str]):
|
||||||
|
self.class_names = class_names
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true = []
|
||||||
|
y_pred = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
top_class = match.pred_class
|
||||||
|
|
||||||
|
# Focus on matches
|
||||||
|
if match.gt_class is None or top_class is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true.append(self.class_names.index(match.gt_class))
|
||||||
|
y_pred.append(self.class_names.index(top_class))
|
||||||
|
|
||||||
|
score = float(metrics.balanced_accuracy_score(y_true, y_pred))
|
||||||
|
return {"classification_balanced_accuracy": score}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClassificationBalancedAccuracyConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls(class_names)
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(
|
||||||
|
ClassificationBalancedAccuracyConfig,
|
||||||
|
ClassificationBalancedAccuracy,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionAPConfig(BaseConfig):
|
||||||
|
name: Literal["clip_detection_ap"] = "clip_detection_ap"
|
||||||
|
ap_implementation: APImplementation = "pascal_voc"
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionAP(MetricsProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
implementation: APImplementation,
|
||||||
|
):
|
||||||
|
self.implementation = implementation
|
||||||
|
self.metric = _ap_impl_mapping[self.implementation]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true = []
|
||||||
|
y_score = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
clip_det = []
|
||||||
|
clip_scores = []
|
||||||
|
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
clip_det.append(match.gt_det)
|
||||||
|
clip_scores.append(match.pred_score)
|
||||||
|
|
||||||
|
y_true.append(any(clip_det))
|
||||||
|
y_score.append(max(clip_scores or [0]))
|
||||||
|
|
||||||
|
return {"clip_detection_ap": self.metric(y_true, y_score)}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClipDetectionAPConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls(implementation=config.ap_implementation)
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionROCAUCConfig(BaseConfig):
|
||||||
|
name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc"
|
||||||
|
|
||||||
|
|
||||||
|
class ClipDetectionROCAUC(MetricsProtocol):
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true = []
|
||||||
|
y_score = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
clip_det = []
|
||||||
|
clip_scores = []
|
||||||
|
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
clip_det.append(match.gt_det)
|
||||||
|
clip_scores.append(match.pred_score)
|
||||||
|
|
||||||
|
y_true.append(any(clip_det))
|
||||||
|
y_score.append(max(clip_scores or [0]))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"clip_detection_ap": float(metrics.roc_auc_score(y_true, y_score))
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClipDetectionROCAUCConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipMulticlassAPConfig(BaseConfig):
|
||||||
|
name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap"
|
||||||
|
ap_implementation: APImplementation = "pascal_voc"
|
||||||
|
include: Optional[List[str]] = None
|
||||||
|
exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ClipMulticlassAP(MetricsProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
class_names: List[str],
|
||||||
|
implementation: APImplementation,
|
||||||
|
include: Optional[Sequence[str]] = None,
|
||||||
|
exclude: Optional[Sequence[str]] = None,
|
||||||
|
):
|
||||||
|
self.implementation = implementation
|
||||||
|
self.metric = _ap_impl_mapping[self.implementation]
|
||||||
|
self.class_names = class_names
|
||||||
|
|
||||||
|
self.selected = class_names
|
||||||
|
|
||||||
|
if include is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name in include
|
||||||
|
]
|
||||||
|
|
||||||
|
if exclude is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name not in exclude
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true = []
|
||||||
|
y_pred = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
clip_classes = set()
|
||||||
|
clip_scores = defaultdict(list)
|
||||||
|
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
if match.gt_class is not None:
|
||||||
|
clip_classes.add(match.gt_class)
|
||||||
|
|
||||||
|
for class_name, score in match.pred_class_scores.items():
|
||||||
|
clip_scores[class_name].append(score)
|
||||||
|
|
||||||
|
y_true.append(clip_classes)
|
||||||
|
y_pred.append(
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
# Get max score for each class
|
||||||
|
max(clip_scores.get(class_name, [0]))
|
||||||
|
for class_name in self.class_names
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
y_true = preprocessing.MultiLabelBinarizer(
|
||||||
|
classes=self.class_names
|
||||||
|
).fit_transform(y_true)
|
||||||
|
y_pred = np.stack(y_pred)
|
||||||
|
|
||||||
|
class_scores = {}
|
||||||
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
|
y_true_class = y_true[:, class_index]
|
||||||
|
y_pred_class = y_pred[:, class_index]
|
||||||
|
class_ap = self.metric(y_true_class, y_pred_class)
|
||||||
|
class_scores[class_name] = float(class_ap)
|
||||||
|
|
||||||
|
mean_ap = np.mean(
|
||||||
|
[value for value in class_scores.values() if value != 0]
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"clip_multiclass_mAP": float(mean_ap),
|
||||||
|
**{
|
||||||
|
f"clip_multiclass_AP/{class_name}": class_scores[class_name]
|
||||||
|
for class_name in self.selected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls, config: ClipMulticlassAPConfig, class_names: List[str]
|
||||||
|
):
|
||||||
|
return cls(
|
||||||
|
implementation=config.ap_implementation,
|
||||||
|
include=config.include,
|
||||||
|
exclude=config.exclude,
|
||||||
|
class_names=class_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipMulticlassROCAUCConfig(BaseConfig):
|
||||||
|
name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc"
|
||||||
|
include: Optional[List[str]] = None
|
||||||
|
exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ClipMulticlassROCAUC(MetricsProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
class_names: List[str],
|
||||||
|
include: Optional[Sequence[str]] = None,
|
||||||
|
exclude: Optional[Sequence[str]] = None,
|
||||||
|
):
|
||||||
|
self.class_names = class_names
|
||||||
|
self.selected = class_names
|
||||||
|
|
||||||
|
if include is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name in include
|
||||||
|
]
|
||||||
|
|
||||||
|
if exclude is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name not in exclude
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
y_true = []
|
||||||
|
y_pred = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
clip_classes = set()
|
||||||
|
clip_scores = defaultdict(list)
|
||||||
|
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
if match.gt_class is not None:
|
||||||
|
clip_classes.add(match.gt_class)
|
||||||
|
|
||||||
|
for class_name, score in match.pred_class_scores.items():
|
||||||
|
clip_scores[class_name].append(score)
|
||||||
|
|
||||||
|
y_true.append(clip_classes)
|
||||||
|
y_pred.append(
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
# Get maximum score for each class
|
||||||
|
max(clip_scores.get(class_name, [0]))
|
||||||
|
for class_name in self.class_names
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
y_true = preprocessing.MultiLabelBinarizer(
|
||||||
|
classes=self.class_names
|
||||||
|
).fit_transform(y_true)
|
||||||
|
y_pred = np.stack(y_pred)
|
||||||
|
|
||||||
|
class_scores = {}
|
||||||
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
|
y_true_class = y_true[:, class_index]
|
||||||
|
y_pred_class = y_pred[:, class_index]
|
||||||
|
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
|
||||||
|
class_scores[class_name] = float(class_roc_auc)
|
||||||
|
|
||||||
|
mean_roc_auc = np.mean(
|
||||||
|
[value for value in class_scores.values() if value != 0]
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"clip_multiclass_macro_ROC_AUC": float(mean_roc_auc),
|
||||||
|
**{
|
||||||
|
f"clip_multiclass_ROC_AUC/{class_name}": class_scores[
|
||||||
|
class_name
|
||||||
|
]
|
||||||
|
for class_name in self.selected
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClipMulticlassROCAUCConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls(
|
||||||
|
include=config.include,
|
||||||
|
exclude=config.exclude,
|
||||||
|
class_names=class_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC)
|
||||||
|
|
||||||
MetricConfig = Annotated[
|
MetricConfig = Annotated[
|
||||||
Union[ClassificationAPConfig, DetectionAPConfig],
|
Union[
|
||||||
|
DetectionAPConfig,
|
||||||
|
DetectionROCAUCConfig,
|
||||||
|
ClassificationAPConfig,
|
||||||
|
ClassificationROCAUCConfig,
|
||||||
|
TopClassAPConfig,
|
||||||
|
ClassificationBalancedAccuracyConfig,
|
||||||
|
ClipDetectionAPConfig,
|
||||||
|
ClipDetectionROCAUCConfig,
|
||||||
|
ClipMulticlassAPConfig,
|
||||||
|
ClipMulticlassROCAUCConfig,
|
||||||
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_metric(config: MetricConfig, class_names: List[str]):
|
def build_metric(config: MetricConfig, class_names: List[str]):
|
||||||
return metrics_registry.build(config, class_names)
|
return metrics_registry.build(config, class_names)
|
||||||
|
|
||||||
|
|
||||||
|
def pascal_voc_average_precision(y_true, y_score) -> float:
|
||||||
|
y_true = np.array(y_true)
|
||||||
|
y_score = np.array(y_score)
|
||||||
|
|
||||||
|
sort_ind = np.argsort(y_score)[::-1]
|
||||||
|
y_true_sorted = y_true[sort_ind]
|
||||||
|
|
||||||
|
num_positives = y_true.sum()
|
||||||
|
false_pos_c = np.cumsum(1 - y_true_sorted)
|
||||||
|
true_pos_c = np.cumsum(y_true_sorted)
|
||||||
|
|
||||||
|
recall = true_pos_c / num_positives
|
||||||
|
precision = true_pos_c / np.maximum(
|
||||||
|
true_pos_c + false_pos_c,
|
||||||
|
np.finfo(np.float64).eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
precision[np.isnan(precision)] = 0
|
||||||
|
recall[np.isnan(recall)] = 0
|
||||||
|
|
||||||
|
# pascal 12 way
|
||||||
|
mprec = np.hstack((0, precision, 0))
|
||||||
|
mrec = np.hstack((0, recall, 1))
|
||||||
|
for ii in range(mprec.shape[0] - 2, -1, -1):
|
||||||
|
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
|
||||||
|
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
|
||||||
|
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
|
||||||
|
|
||||||
|
return ave_prec
|
||||||
|
|
||||||
|
|
||||||
|
_ap_impl_mapping: Mapping[APImplementation, Callable[[Any, Any], float]] = {
|
||||||
|
"sklearn": metrics.average_precision_score,
|
||||||
|
"pascal_voc": pascal_voc_average_precision,
|
||||||
|
}
|
||||||
|
|||||||
@ -4,20 +4,24 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
|
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from sklearn import metrics
|
||||||
|
from sklearn.preprocessing import label_binarize
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
from batdetect2.data._core import Registry
|
from batdetect2.core import BaseConfig, Registry
|
||||||
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
|
|
||||||
from batdetect2.plotting.gallery import plot_match_gallery
|
from batdetect2.plotting.gallery import plot_match_gallery
|
||||||
|
from batdetect2.plotting.matches import plot_matches
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.typing.evaluate import (
|
from batdetect2.typing import (
|
||||||
|
AudioLoader,
|
||||||
ClipEvaluation,
|
ClipEvaluation,
|
||||||
MatchEvaluation,
|
MatchEvaluation,
|
||||||
PlotterProtocol,
|
PlotterProtocol,
|
||||||
|
PreprocessorProtocol,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"build_plotter",
|
"build_plotter",
|
||||||
@ -26,12 +30,13 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
plots_registry: Registry[PlotterProtocol, []] = Registry("plot")
|
plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot")
|
||||||
|
|
||||||
|
|
||||||
class ExampleGalleryConfig(BaseConfig):
|
class ExampleGalleryConfig(BaseConfig):
|
||||||
name: Literal["example_gallery"] = "example_gallery"
|
name: Literal["example_gallery"] = "example_gallery"
|
||||||
examples_per_class: int = 5
|
examples_per_class: int = 5
|
||||||
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
preprocessing: PreprocessingConfig = Field(
|
preprocessing: PreprocessingConfig = Field(
|
||||||
default_factory=PreprocessingConfig
|
default_factory=PreprocessingConfig
|
||||||
)
|
)
|
||||||
@ -87,9 +92,12 @@ class ExampleGallery(PlotterProtocol):
|
|||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: ExampleGalleryConfig):
|
def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]):
|
||||||
preprocessor = build_preprocessor(config.preprocessing)
|
audio_loader = build_audio_loader(config.audio)
|
||||||
audio_loader = build_audio_loader(config.preprocessing.audio)
|
preprocessor = build_preprocessor(
|
||||||
|
config.preprocessing,
|
||||||
|
input_samplerate=audio_loader.samplerate,
|
||||||
|
)
|
||||||
return cls(
|
return cls(
|
||||||
examples_per_class=config.examples_per_class,
|
examples_per_class=config.examples_per_class,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
@ -100,13 +108,402 @@ class ExampleGallery(PlotterProtocol):
|
|||||||
plots_registry.register(ExampleGalleryConfig, ExampleGallery)
|
plots_registry.register(ExampleGalleryConfig, ExampleGallery)
|
||||||
|
|
||||||
|
|
||||||
|
class ClipEvaluationPlotConfig(BaseConfig):
|
||||||
|
name: Literal["example_clip"] = "example_clip"
|
||||||
|
num_plots: int = 5
|
||||||
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
|
preprocessing: PreprocessingConfig = Field(
|
||||||
|
default_factory=PreprocessingConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PlotClipEvaluation(PlotterProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_plots: int = 3,
|
||||||
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
|
):
|
||||||
|
self.preprocessor = preprocessor
|
||||||
|
self.audio_loader = audio_loader
|
||||||
|
self.num_plots = num_plots
|
||||||
|
|
||||||
|
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||||
|
examples = random.sample(
|
||||||
|
clip_evaluations,
|
||||||
|
k=min(self.num_plots, len(clip_evaluations)),
|
||||||
|
)
|
||||||
|
|
||||||
|
for index, clip_evaluation in enumerate(examples):
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
plot_matches(
|
||||||
|
clip_evaluation.matches,
|
||||||
|
clip=clip_evaluation.clip,
|
||||||
|
audio_loader=self.audio_loader,
|
||||||
|
ax=ax,
|
||||||
|
)
|
||||||
|
yield f"clip_evaluation/example_{index}", fig
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClipEvaluationPlotConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
audio_loader = build_audio_loader(config.audio)
|
||||||
|
preprocessor = build_preprocessor(
|
||||||
|
config.preprocessing,
|
||||||
|
input_samplerate=audio_loader.samplerate,
|
||||||
|
)
|
||||||
|
return cls(
|
||||||
|
num_plots=config.num_plots,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
audio_loader=audio_loader,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation)
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionPRCurveConfig(BaseConfig):
|
||||||
|
name: Literal["detection_pr_curve"] = "detection_pr_curve"
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionPRCurve(PlotterProtocol):
|
||||||
|
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||||
|
y_true, y_score = zip(
|
||||||
|
*[
|
||||||
|
(match.gt_det, match.pred_score)
|
||||||
|
for clip_eval in clip_evaluations
|
||||||
|
for match in clip_eval.matches
|
||||||
|
]
|
||||||
|
)
|
||||||
|
precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
|
||||||
|
ax.plot(recall, precision, label="Detector")
|
||||||
|
ax.set_xlabel("Recall")
|
||||||
|
ax.set_ylabel("Precision")
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
yield "detection_pr_curve", fig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: DetectionPRCurveConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
|
plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationPRCurvesConfig(BaseConfig):
|
||||||
|
name: Literal["classification_pr_curves"] = "classification_pr_curves"
|
||||||
|
include: Optional[List[str]] = None
|
||||||
|
exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationPRCurves(PlotterProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
class_names: List[str],
|
||||||
|
include: Optional[List[str]] = None,
|
||||||
|
exclude: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
self.class_names = class_names
|
||||||
|
self.selected = class_names
|
||||||
|
|
||||||
|
if include is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name in include
|
||||||
|
]
|
||||||
|
|
||||||
|
if exclude is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name not in exclude
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||||
|
y_true = []
|
||||||
|
y_pred = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
# Ignore generic unclassified targets
|
||||||
|
if match.gt_det and match.gt_class is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true.append(
|
||||||
|
match.gt_class
|
||||||
|
if match.gt_class is not None
|
||||||
|
else "__NONE__"
|
||||||
|
)
|
||||||
|
|
||||||
|
y_pred.append(
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
match.pred_class_scores.get(name, 0)
|
||||||
|
for name in self.class_names
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
y_true = label_binarize(y_true, classes=self.class_names)
|
||||||
|
y_pred = np.stack(y_pred)
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 10))
|
||||||
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
|
if class_name not in self.selected:
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true_class = y_true[:, class_index]
|
||||||
|
y_pred_class = y_pred[:, class_index]
|
||||||
|
precision, recall, _ = metrics.precision_recall_curve(
|
||||||
|
y_true_class,
|
||||||
|
y_pred_class,
|
||||||
|
)
|
||||||
|
ax.plot(recall, precision, label=class_name)
|
||||||
|
|
||||||
|
ax.set_xlabel("Recall")
|
||||||
|
ax.set_ylabel("Precision")
|
||||||
|
ax.legend(
|
||||||
|
bbox_to_anchor=(1.05, 1),
|
||||||
|
loc="upper left",
|
||||||
|
borderaxespad=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "classification_pr_curve", fig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClassificationPRCurvesConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls(
|
||||||
|
class_names=class_names,
|
||||||
|
include=config.include,
|
||||||
|
exclude=config.exclude,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves)
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionROCCurveConfig(BaseConfig):
|
||||||
|
name: Literal["detection_roc_curve"] = "detection_roc_curve"
|
||||||
|
|
||||||
|
|
||||||
|
class DetectionROCCurve(PlotterProtocol):
|
||||||
|
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||||
|
y_true, y_score = zip(
|
||||||
|
*[
|
||||||
|
(match.gt_det, match.pred_score)
|
||||||
|
for clip_eval in clip_evaluations
|
||||||
|
for match in clip_eval.matches
|
||||||
|
]
|
||||||
|
)
|
||||||
|
fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
|
||||||
|
ax.plot(fpr, tpr, label="Detection")
|
||||||
|
ax.set_xlabel("False Positive Rate")
|
||||||
|
ax.set_ylabel("True Positive Rate")
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
yield "detection_roc_curve", fig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: DetectionROCCurveConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
|
plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve)
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationROCCurvesConfig(BaseConfig):
|
||||||
|
name: Literal["classification_roc_curves"] = "classification_roc_curves"
|
||||||
|
include: Optional[List[str]] = None
|
||||||
|
exclude: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationROCCurves(PlotterProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
class_names: List[str],
|
||||||
|
include: Optional[List[str]] = None,
|
||||||
|
exclude: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
self.class_names = class_names
|
||||||
|
self.selected = class_names
|
||||||
|
|
||||||
|
if include is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name in include
|
||||||
|
]
|
||||||
|
|
||||||
|
if exclude is not None:
|
||||||
|
self.selected = [
|
||||||
|
class_name
|
||||||
|
for class_name in self.selected
|
||||||
|
if class_name not in exclude
|
||||||
|
]
|
||||||
|
|
||||||
|
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||||
|
y_true = []
|
||||||
|
y_pred = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
# Ignore generic unclassified targets
|
||||||
|
if match.gt_det and match.gt_class is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true.append(
|
||||||
|
match.gt_class
|
||||||
|
if match.gt_class is not None
|
||||||
|
else "__NONE__"
|
||||||
|
)
|
||||||
|
|
||||||
|
y_pred.append(
|
||||||
|
np.array(
|
||||||
|
[
|
||||||
|
match.pred_class_scores.get(name, 0)
|
||||||
|
for name in self.class_names
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
y_true = label_binarize(y_true, classes=self.class_names)
|
||||||
|
y_pred = np.stack(y_pred)
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(figsize=(10, 10))
|
||||||
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
|
if class_name not in self.selected:
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true_class = y_true[:, class_index]
|
||||||
|
y_roced_class = y_pred[:, class_index]
|
||||||
|
fpr, tpr, _ = metrics.roc_curve(
|
||||||
|
y_true_class,
|
||||||
|
y_roced_class,
|
||||||
|
)
|
||||||
|
ax.plot(fpr, tpr, label=class_name)
|
||||||
|
|
||||||
|
ax.set_xlabel("False Positive Rate")
|
||||||
|
ax.set_ylabel("True Positive Rate")
|
||||||
|
ax.legend(
|
||||||
|
bbox_to_anchor=(1.05, 1),
|
||||||
|
loc="upper left",
|
||||||
|
borderaxespad=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "classification_roc_curve", fig
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ClassificationROCCurvesConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls(
|
||||||
|
class_names=class_names,
|
||||||
|
include=config.include,
|
||||||
|
exclude=config.exclude,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfusionMatrixConfig(BaseConfig):
|
||||||
|
name: Literal["confusion_matrix"] = "confusion_matrix"
|
||||||
|
background_class: str = "noise"
|
||||||
|
|
||||||
|
|
||||||
|
class ConfusionMatrix(PlotterProtocol):
|
||||||
|
def __init__(self, background_class: str, class_names: List[str]):
|
||||||
|
self.background_class = background_class
|
||||||
|
self.class_names = class_names
|
||||||
|
|
||||||
|
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
|
||||||
|
y_true = []
|
||||||
|
y_pred = []
|
||||||
|
|
||||||
|
for clip_eval in clip_evaluations:
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
# Ignore generic unclassified targets
|
||||||
|
if match.gt_det and match.gt_class is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
y_true.append(
|
||||||
|
match.gt_class
|
||||||
|
if match.gt_class is not None
|
||||||
|
else self.background_class
|
||||||
|
)
|
||||||
|
|
||||||
|
top_class = match.pred_class
|
||||||
|
y_pred.append(
|
||||||
|
top_class
|
||||||
|
if top_class is not None
|
||||||
|
else self.background_class
|
||||||
|
)
|
||||||
|
|
||||||
|
display = metrics.ConfusionMatrixDisplay.from_predictions(
|
||||||
|
y_true,
|
||||||
|
y_pred,
|
||||||
|
labels=[*self.class_names, self.background_class],
|
||||||
|
)
|
||||||
|
|
||||||
|
yield "confusion_matrix", display.figure_
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: ConfusionMatrixConfig,
|
||||||
|
class_names: List[str],
|
||||||
|
):
|
||||||
|
return cls(
|
||||||
|
background_class=config.background_class,
|
||||||
|
class_names=class_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix)
|
||||||
|
|
||||||
|
|
||||||
PlotConfig = Annotated[
|
PlotConfig = Annotated[
|
||||||
Union[ExampleGalleryConfig,], Field(discriminator="name")
|
Union[
|
||||||
|
ExampleGalleryConfig,
|
||||||
|
ClipEvaluationPlotConfig,
|
||||||
|
DetectionPRCurveConfig,
|
||||||
|
ClassificationPRCurvesConfig,
|
||||||
|
DetectionROCCurveConfig,
|
||||||
|
ClassificationROCCurvesConfig,
|
||||||
|
ConfusionMatrixConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_plotter(config: PlotConfig) -> PlotterProtocol:
|
def build_plotter(
|
||||||
return plots_registry.build(config)
|
config: PlotConfig, class_names: List[str]
|
||||||
|
) -> PlotterProtocol:
|
||||||
|
return plots_registry.build(config, class_names)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -1,18 +1,49 @@
|
|||||||
from typing import List
|
from typing import Annotated, Callable, Literal, Sequence, Union
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from pydantic import Field
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.typing.evaluate import ClipEvaluation
|
from batdetect2.core import BaseConfig, Registry
|
||||||
|
from batdetect2.typing import ClipEvaluation
|
||||||
|
|
||||||
|
EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame]
|
||||||
|
|
||||||
|
|
||||||
def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame:
|
tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
|
||||||
|
"evaluation_table"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FullEvaluationTableConfig(BaseConfig):
|
||||||
|
name: Literal["full_evaluation"] = "full_evaluation"
|
||||||
|
|
||||||
|
|
||||||
|
class FullEvaluationTable:
|
||||||
|
def __call__(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
return extract_matches_dataframe(clip_evaluations)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: FullEvaluationTableConfig):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
|
tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_matches_dataframe(
|
||||||
|
clip_evaluations: Sequence[ClipEvaluation],
|
||||||
|
) -> pd.DataFrame:
|
||||||
data = []
|
data = []
|
||||||
|
|
||||||
for clip_evaluation in clip_evaluations:
|
for clip_evaluation in clip_evaluations:
|
||||||
for match in clip_evaluation.matches:
|
for match in clip_evaluation.matches:
|
||||||
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
|
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
|
||||||
pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
|
pred_start_time = pred_low_freq = pred_end_time = (
|
||||||
|
pred_high_freq
|
||||||
|
) = None
|
||||||
|
|
||||||
sound_event_annotation = match.sound_event_annotation
|
sound_event_annotation = match.sound_event_annotation
|
||||||
|
|
||||||
@ -24,9 +55,12 @@ def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.Data
|
|||||||
)
|
)
|
||||||
|
|
||||||
if match.pred_geometry is not None:
|
if match.pred_geometry is not None:
|
||||||
pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
|
(
|
||||||
compute_bounds(match.pred_geometry)
|
pred_start_time,
|
||||||
)
|
pred_low_freq,
|
||||||
|
pred_end_time,
|
||||||
|
pred_high_freq,
|
||||||
|
) = compute_bounds(match.pred_geometry)
|
||||||
|
|
||||||
data.append(
|
data.append(
|
||||||
{
|
{
|
||||||
@ -61,3 +95,14 @@ def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.Data
|
|||||||
df = pd.DataFrame(data)
|
df = pd.DataFrame(data)
|
||||||
df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore
|
df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore
|
||||||
return df
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
EvaluationTableConfig = Annotated[
|
||||||
|
Union[FullEvaluationTableConfig,], Field(discriminator="name")
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_table_generator(
|
||||||
|
config: EvaluationTableConfig,
|
||||||
|
) -> EvaluationTableGenerator:
|
||||||
|
return tables_registry.build(config)
|
||||||
0
src/batdetect2/inference/__init__.py
Normal file
0
src/batdetect2/inference/__init__.py
Normal file
@ -1,4 +1,6 @@
|
|||||||
import io
|
import io
|
||||||
|
from collections.abc import Callable
|
||||||
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
@ -13,12 +15,19 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
|
import pandas as pd
|
||||||
|
from lightning.pytorch.loggers import (
|
||||||
|
CSVLogger,
|
||||||
|
Logger,
|
||||||
|
MLFlowLogger,
|
||||||
|
TensorBoardLogger,
|
||||||
|
)
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from matplotlib.figure import Figure
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
|
||||||
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
||||||
|
|
||||||
@ -48,7 +57,7 @@ class TensorBoardLoggerConfig(BaseLoggerConfig):
|
|||||||
|
|
||||||
class MLFlowLoggerConfig(BaseLoggerConfig):
|
class MLFlowLoggerConfig(BaseLoggerConfig):
|
||||||
name: Literal["mlflow"] = "mlflow"
|
name: Literal["mlflow"] = "mlflow"
|
||||||
tracking_uri: Optional[str] = None
|
tracking_uri: Optional[str] = "http://localhost:5000"
|
||||||
tags: Optional[dict[str, Any]] = None
|
tags: Optional[dict[str, Any]] = None
|
||||||
log_model: bool = False
|
log_model: bool = False
|
||||||
|
|
||||||
@ -152,6 +161,9 @@ def create_tensorboard_logger(
|
|||||||
|
|
||||||
name = run_name
|
name = run_name
|
||||||
|
|
||||||
|
if name is None:
|
||||||
|
name = experiment_name
|
||||||
|
|
||||||
if run_name is not None and experiment_name is not None:
|
if run_name is not None and experiment_name is not None:
|
||||||
name = str(Path(experiment_name) / run_name)
|
name = str(Path(experiment_name) / run_name)
|
||||||
|
|
||||||
@ -231,18 +243,18 @@ def build_logger(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_image_plotter(logger: Logger):
|
PlotLogger = Callable[[str, Figure, int], None]
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
|
||||||
if isinstance(logger, TensorBoardLogger):
|
if isinstance(logger, TensorBoardLogger):
|
||||||
|
return logger.experiment.add_figure
|
||||||
def plot_figure(name, figure, step):
|
|
||||||
return logger.experiment.add_figure(name, figure, step)
|
|
||||||
|
|
||||||
return plot_figure
|
|
||||||
|
|
||||||
if isinstance(logger, MLFlowLogger):
|
if isinstance(logger, MLFlowLogger):
|
||||||
|
|
||||||
def plot_figure(name, figure, step):
|
def plot_figure(name, figure, step):
|
||||||
image = _convert_figure_to_image(figure)
|
image = _convert_figure_to_array(figure)
|
||||||
|
name = name.replace("/", "_")
|
||||||
return logger.experiment.log_image(
|
return logger.experiment.log_image(
|
||||||
logger.run_id,
|
logger.run_id,
|
||||||
image,
|
image,
|
||||||
@ -252,8 +264,51 @@ def get_image_plotter(logger: Logger):
|
|||||||
|
|
||||||
return plot_figure
|
return plot_figure
|
||||||
|
|
||||||
|
if isinstance(logger, CSVLogger):
|
||||||
|
return partial(save_figure, dir=Path(logger.log_dir))
|
||||||
|
|
||||||
def _convert_figure_to_image(figure):
|
|
||||||
|
TableLogger = Callable[[str, pd.DataFrame, int], None]
|
||||||
|
|
||||||
|
|
||||||
|
def get_table_logger(logger: Logger) -> Optional[TableLogger]:
|
||||||
|
if isinstance(logger, TensorBoardLogger):
|
||||||
|
return partial(save_table, dir=Path(logger.log_dir))
|
||||||
|
|
||||||
|
if isinstance(logger, MLFlowLogger):
|
||||||
|
|
||||||
|
def plot_figure(name: str, df: pd.DataFrame, step: int):
|
||||||
|
return logger.experiment.log_table(
|
||||||
|
logger.run_id,
|
||||||
|
data=df,
|
||||||
|
artifact_file=f"{name}_step_{step}.json",
|
||||||
|
)
|
||||||
|
|
||||||
|
return plot_figure
|
||||||
|
|
||||||
|
if isinstance(logger, CSVLogger):
|
||||||
|
return partial(save_table, dir=Path(logger.log_dir))
|
||||||
|
|
||||||
|
|
||||||
|
def save_table(name: str, df: pd.DataFrame, step: int, dir: Path) -> None:
|
||||||
|
path = dir / "tables" / f"{name}_step_{step}.csv"
|
||||||
|
|
||||||
|
if not path.parent.exists():
|
||||||
|
path.parent.mkdir(parents=True)
|
||||||
|
|
||||||
|
df.to_csv(path, index=False)
|
||||||
|
|
||||||
|
|
||||||
|
def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
|
||||||
|
path = dir / "plots" / f"{name}_step_{step}.png"
|
||||||
|
|
||||||
|
if not path.parent.exists():
|
||||||
|
path.parent.mkdir(parents=True)
|
||||||
|
|
||||||
|
fig.savefig(path, transparent=True, bbox_inches="tight")
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_figure_to_array(figure: Figure) -> np.ndarray:
|
||||||
with io.BytesIO() as buff:
|
with io.BytesIO() as buff:
|
||||||
figure.savefig(buff, format="raw")
|
figure.savefig(buff, format="raw")
|
||||||
buff.seek(0)
|
buff.seek(0)
|
||||||
@ -29,15 +29,10 @@ provided here.
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
|
||||||
from soundevent.data import PathLike
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
from batdetect2.models.backbones import (
|
from batdetect2.models.backbones import (
|
||||||
Backbone,
|
Backbone,
|
||||||
BackboneConfig,
|
|
||||||
build_backbone,
|
build_backbone,
|
||||||
load_backbone_config,
|
|
||||||
)
|
)
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
ConvConfig,
|
ConvConfig,
|
||||||
@ -51,6 +46,10 @@ from batdetect2.models.bottleneck import (
|
|||||||
BottleneckConfig,
|
BottleneckConfig,
|
||||||
build_bottleneck,
|
build_bottleneck,
|
||||||
)
|
)
|
||||||
|
from batdetect2.models.config import (
|
||||||
|
BackboneConfig,
|
||||||
|
load_backbone_config,
|
||||||
|
)
|
||||||
from batdetect2.models.decoder import (
|
from batdetect2.models.decoder import (
|
||||||
DEFAULT_DECODER_CONFIG,
|
DEFAULT_DECODER_CONFIG,
|
||||||
DecoderConfig,
|
DecoderConfig,
|
||||||
@ -63,12 +62,12 @@ from batdetect2.models.encoder import (
|
|||||||
build_encoder,
|
build_encoder,
|
||||||
)
|
)
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||||
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import TargetConfig, build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.typing.models import DetectionModel
|
from batdetect2.typing.models import DetectionModel
|
||||||
from batdetect2.typing.postprocess import (
|
from batdetect2.typing.postprocess import (
|
||||||
DetectionsTensor,
|
ClipDetectionsTensor,
|
||||||
PostprocessorProtocol,
|
PostprocessorProtocol,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||||
@ -99,20 +98,10 @@ __all__ = [
|
|||||||
"build_detector",
|
"build_detector",
|
||||||
"load_backbone_config",
|
"load_backbone_config",
|
||||||
"Model",
|
"Model",
|
||||||
"ModelConfig",
|
|
||||||
"build_model",
|
"build_model",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseConfig):
|
|
||||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
|
||||||
preprocess: PreprocessingConfig = Field(
|
|
||||||
default_factory=PreprocessingConfig
|
|
||||||
)
|
|
||||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
|
||||||
|
|
||||||
|
|
||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
detector: DetectionModel
|
detector: DetectionModel
|
||||||
preprocessor: PreprocessorProtocol
|
preprocessor: PreprocessorProtocol
|
||||||
@ -125,47 +114,38 @@ class Model(torch.nn.Module):
|
|||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
postprocessor: PostprocessorProtocol,
|
postprocessor: PostprocessorProtocol,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
config: ModelConfig,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.detector = detector
|
self.detector = detector
|
||||||
self.preprocessor = preprocessor
|
self.preprocessor = preprocessor
|
||||||
self.postprocessor = postprocessor
|
self.postprocessor = postprocessor
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
self.config = config
|
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
|
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
|
||||||
spec = self.preprocessor(wav)
|
spec = self.preprocessor(wav)
|
||||||
outputs = self.detector(spec)
|
outputs = self.detector(spec)
|
||||||
return self.postprocessor(outputs)
|
return self.postprocessor(outputs)
|
||||||
|
|
||||||
|
|
||||||
def build_model(config: Optional[ModelConfig] = None):
|
def build_model(
|
||||||
config = config or ModelConfig()
|
config: Optional[BackboneConfig] = None,
|
||||||
|
targets: Optional[TargetProtocol] = None,
|
||||||
targets = build_targets(config=config.targets)
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
|
postprocessor: Optional[PostprocessorProtocol] = None,
|
||||||
preprocessor = build_preprocessor(config=config.preprocess)
|
):
|
||||||
|
config = config or BackboneConfig()
|
||||||
postprocessor = build_postprocessor(
|
targets = targets or build_targets()
|
||||||
|
preprocessor = preprocessor or build_preprocessor()
|
||||||
|
postprocessor = postprocessor or build_postprocessor(
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
config=config.postprocess,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(targets.class_names),
|
num_classes=len(targets.class_names),
|
||||||
config=config.model,
|
config=config,
|
||||||
)
|
)
|
||||||
return Model(
|
return Model(
|
||||||
config=config,
|
|
||||||
detector=detector,
|
detector=detector,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_model_config(
|
|
||||||
path: PathLike, field: Optional[str] = None
|
|
||||||
) -> ModelConfig:
|
|
||||||
return load_config(path, schema=ModelConfig, field=field)
|
|
||||||
|
|||||||
@ -18,37 +18,20 @@ automatic padding to handle input sizes not perfectly divisible by the
|
|||||||
network's total downsampling factor.
|
network's total downsampling factor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from soundevent import data
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.models.bottleneck import build_bottleneck
|
||||||
from batdetect2.models.bottleneck import (
|
from batdetect2.models.config import BackboneConfig
|
||||||
DEFAULT_BOTTLENECK_CONFIG,
|
from batdetect2.models.decoder import Decoder, build_decoder
|
||||||
BottleneckConfig,
|
from batdetect2.models.encoder import Encoder, build_encoder
|
||||||
build_bottleneck,
|
|
||||||
)
|
|
||||||
from batdetect2.models.decoder import (
|
|
||||||
DEFAULT_DECODER_CONFIG,
|
|
||||||
Decoder,
|
|
||||||
DecoderConfig,
|
|
||||||
build_decoder,
|
|
||||||
)
|
|
||||||
from batdetect2.models.encoder import (
|
|
||||||
DEFAULT_ENCODER_CONFIG,
|
|
||||||
Encoder,
|
|
||||||
EncoderConfig,
|
|
||||||
build_encoder,
|
|
||||||
)
|
|
||||||
from batdetect2.typing.models import BackboneModel
|
from batdetect2.typing.models import BackboneModel
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Backbone",
|
"Backbone",
|
||||||
"BackboneConfig",
|
|
||||||
"load_backbone_config",
|
|
||||||
"build_backbone",
|
"build_backbone",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -161,82 +144,6 @@ class Backbone(BackboneModel):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class BackboneConfig(BaseConfig):
|
|
||||||
"""Configuration for the Encoder-Decoder Backbone network.
|
|
||||||
|
|
||||||
Aggregates configurations for the encoder, bottleneck, and decoder
|
|
||||||
components, along with defining the input and final output dimensions
|
|
||||||
for the complete backbone.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
input_height : int, default=128
|
|
||||||
Expected height (frequency bins) of the input spectrograms to the
|
|
||||||
backbone. Must be positive.
|
|
||||||
in_channels : int, default=1
|
|
||||||
Expected number of channels in the input spectrograms (e.g., 1 for
|
|
||||||
mono). Must be positive.
|
|
||||||
encoder : EncoderConfig, optional
|
|
||||||
Configuration for the encoder. If None or omitted,
|
|
||||||
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
|
|
||||||
encoder module) will be used.
|
|
||||||
bottleneck : BottleneckConfig, optional
|
|
||||||
Configuration for the bottleneck layer connecting encoder and decoder.
|
|
||||||
If None or omitted, the default bottleneck configuration will be used.
|
|
||||||
decoder : DecoderConfig, optional
|
|
||||||
Configuration for the decoder. If None or omitted,
|
|
||||||
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
|
|
||||||
decoder module) will be used.
|
|
||||||
out_channels : int, default=32
|
|
||||||
Desired number of channels in the final feature map output by the
|
|
||||||
backbone. Must be positive.
|
|
||||||
"""
|
|
||||||
|
|
||||||
input_height: int = 128
|
|
||||||
in_channels: int = 1
|
|
||||||
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
|
|
||||||
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
|
|
||||||
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
|
||||||
out_channels: int = 32
|
|
||||||
|
|
||||||
|
|
||||||
def load_backbone_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
) -> BackboneConfig:
|
|
||||||
"""Load the backbone configuration from a file.
|
|
||||||
|
|
||||||
Reads a configuration file (YAML) and validates it against the
|
|
||||||
`BackboneConfig` schema, potentially extracting data from a nested field.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : PathLike
|
|
||||||
Path to the configuration file.
|
|
||||||
field : str, optional
|
|
||||||
Dot-separated path to a nested section within the file containing the
|
|
||||||
backbone configuration (e.g., "model.backbone"). If None, the entire
|
|
||||||
file content is used.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
BackboneConfig
|
|
||||||
The loaded and validated backbone configuration object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
If the config file path does not exist.
|
|
||||||
yaml.YAMLError
|
|
||||||
If the file content is not valid YAML.
|
|
||||||
pydantic.ValidationError
|
|
||||||
If the loaded config data does not conform to `BackboneConfig`.
|
|
||||||
KeyError, TypeError
|
|
||||||
If `field` specifies an invalid path.
|
|
||||||
"""
|
|
||||||
return load_config(path, schema=BackboneConfig, field=field)
|
|
||||||
|
|
||||||
|
|
||||||
def build_backbone(config: BackboneConfig) -> BackboneModel:
|
def build_backbone(config: BackboneConfig) -> BackboneModel:
|
||||||
"""Factory function to build a Backbone from configuration.
|
"""Factory function to build a Backbone from configuration.
|
||||||
|
|
||||||
|
|||||||
@ -34,7 +34,7 @@ import torch.nn.functional as F
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ConvBlock",
|
"ConvBlock",
|
||||||
|
|||||||
@ -20,7 +20,7 @@ import torch
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
SelfAttentionConfig,
|
SelfAttentionConfig,
|
||||||
VerticalConv,
|
VerticalConv,
|
||||||
|
|||||||
98
src/batdetect2/models/config.py
Normal file
98
src/batdetect2/models/config.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.models.bottleneck import (
|
||||||
|
DEFAULT_BOTTLENECK_CONFIG,
|
||||||
|
BottleneckConfig,
|
||||||
|
)
|
||||||
|
from batdetect2.models.decoder import (
|
||||||
|
DEFAULT_DECODER_CONFIG,
|
||||||
|
DecoderConfig,
|
||||||
|
)
|
||||||
|
from batdetect2.models.encoder import (
|
||||||
|
DEFAULT_ENCODER_CONFIG,
|
||||||
|
EncoderConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BackboneConfig",
|
||||||
|
"load_backbone_config",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class BackboneConfig(BaseConfig):
|
||||||
|
"""Configuration for the Encoder-Decoder Backbone network.
|
||||||
|
|
||||||
|
Aggregates configurations for the encoder, bottleneck, and decoder
|
||||||
|
components, along with defining the input and final output dimensions
|
||||||
|
for the complete backbone.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
input_height : int, default=128
|
||||||
|
Expected height (frequency bins) of the input spectrograms to the
|
||||||
|
backbone. Must be positive.
|
||||||
|
in_channels : int, default=1
|
||||||
|
Expected number of channels in the input spectrograms (e.g., 1 for
|
||||||
|
mono). Must be positive.
|
||||||
|
encoder : EncoderConfig, optional
|
||||||
|
Configuration for the encoder. If None or omitted,
|
||||||
|
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
|
||||||
|
encoder module) will be used.
|
||||||
|
bottleneck : BottleneckConfig, optional
|
||||||
|
Configuration for the bottleneck layer connecting encoder and decoder.
|
||||||
|
If None or omitted, the default bottleneck configuration will be used.
|
||||||
|
decoder : DecoderConfig, optional
|
||||||
|
Configuration for the decoder. If None or omitted,
|
||||||
|
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
|
||||||
|
decoder module) will be used.
|
||||||
|
out_channels : int, default=32
|
||||||
|
Desired number of channels in the final feature map output by the
|
||||||
|
backbone. Must be positive.
|
||||||
|
"""
|
||||||
|
|
||||||
|
input_height: int = 128
|
||||||
|
in_channels: int = 1
|
||||||
|
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
|
||||||
|
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
|
||||||
|
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||||
|
out_channels: int = 32
|
||||||
|
|
||||||
|
|
||||||
|
def load_backbone_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> BackboneConfig:
|
||||||
|
"""Load the backbone configuration from a file.
|
||||||
|
|
||||||
|
Reads a configuration file (YAML) and validates it against the
|
||||||
|
`BackboneConfig` schema, potentially extracting data from a nested field.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : PathLike
|
||||||
|
Path to the configuration file.
|
||||||
|
field : str, optional
|
||||||
|
Dot-separated path to a nested section within the file containing the
|
||||||
|
backbone configuration (e.g., "model.backbone"). If None, the entire
|
||||||
|
file content is used.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
BackboneConfig
|
||||||
|
The loaded and validated backbone configuration object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file path does not exist.
|
||||||
|
yaml.YAMLError
|
||||||
|
If the file content is not valid YAML.
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the loaded config data does not conform to `BackboneConfig`.
|
||||||
|
KeyError, TypeError
|
||||||
|
If `field` specifies an invalid path.
|
||||||
|
"""
|
||||||
|
return load_config(path, schema=BackboneConfig, field=field)
|
||||||
@ -24,7 +24,7 @@ import torch
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvUpConfig,
|
FreqCoordConvUpConfig,
|
||||||
|
|||||||
@ -26,7 +26,7 @@ import torch
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvDownConfig,
|
FreqCoordConvDownConfig,
|
||||||
|
|||||||
@ -5,8 +5,9 @@ import torch
|
|||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.audio import build_audio_loader
|
||||||
from batdetect2.plotting.common import plot_spectrogram
|
from batdetect2.plotting.common import plot_spectrogram
|
||||||
from batdetect2.preprocess import build_audio_loader, build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol
|
from batdetect2.typing import AudioLoader, PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@ -6,10 +6,8 @@ from soundevent import data, plot
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
from soundevent.plot.tags import TagColorMapper
|
from soundevent.plot.tags import TagColorMapper
|
||||||
|
|
||||||
from batdetect2.plotting.clip_predictions import plot_prediction
|
|
||||||
from batdetect2.plotting.clips import AudioLoader, plot_clip
|
from batdetect2.plotting.clips import AudioLoader, plot_clip
|
||||||
from batdetect2.preprocess import PreprocessorProtocol
|
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
|
||||||
from batdetect2.typing.evaluate import MatchEvaluation
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plot_matches",
|
"plot_matches",
|
||||||
@ -30,7 +28,7 @@ DEFAULT_PREDICTION_LINE_STYLE = "--"
|
|||||||
|
|
||||||
|
|
||||||
def plot_matches(
|
def plot_matches(
|
||||||
matches: List[data.Match],
|
matches: List[MatchEvaluation],
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
@ -44,8 +42,7 @@ def plot_matches(
|
|||||||
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
||||||
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
||||||
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
||||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
||||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
ax = plot_clip(
|
ax = plot_clip(
|
||||||
clip,
|
clip,
|
||||||
@ -61,52 +58,48 @@ def plot_matches(
|
|||||||
color_mapper = TagColorMapper()
|
color_mapper = TagColorMapper()
|
||||||
|
|
||||||
for match in matches:
|
for match in matches:
|
||||||
if match.source is None and match.target is not None:
|
if match.is_cross_trigger():
|
||||||
plot.plot_annotation(
|
plot_cross_trigger_match(
|
||||||
annotation=match.target,
|
match,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.004,
|
fill=fill,
|
||||||
freq_offset=2_000,
|
add_points=add_points,
|
||||||
|
add_spectrogram=False,
|
||||||
|
use_score=True,
|
||||||
|
color=cross_trigger_color,
|
||||||
|
add_text=False,
|
||||||
|
)
|
||||||
|
elif match.is_true_positive():
|
||||||
|
plot_true_positive_match(
|
||||||
|
match,
|
||||||
|
ax=ax,
|
||||||
|
fill=fill,
|
||||||
|
add_spectrogram=False,
|
||||||
|
use_score=True,
|
||||||
|
add_points=add_points,
|
||||||
|
color=true_positive_color,
|
||||||
|
add_text=False,
|
||||||
|
)
|
||||||
|
elif match.is_false_negative():
|
||||||
|
plot_false_negative_match(
|
||||||
|
match,
|
||||||
|
ax=ax,
|
||||||
|
fill=fill,
|
||||||
|
add_spectrogram=False,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
|
||||||
color=false_negative_color,
|
color=false_negative_color,
|
||||||
color_mapper=color_mapper,
|
add_text=False,
|
||||||
linestyle=annotation_linestyle,
|
|
||||||
)
|
)
|
||||||
elif match.target is None and match.source is not None:
|
elif match.is_false_positive:
|
||||||
plot_prediction(
|
plot_false_positive_match(
|
||||||
prediction=match.source,
|
match,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.004,
|
fill=fill,
|
||||||
freq_offset=2_000,
|
add_spectrogram=False,
|
||||||
|
use_score=True,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
|
||||||
color=false_positive_color,
|
color=false_positive_color,
|
||||||
color_mapper=color_mapper,
|
add_text=False,
|
||||||
linestyle=prediction_linestyle,
|
|
||||||
)
|
|
||||||
elif match.target is not None and match.source is not None:
|
|
||||||
plot.plot_annotation(
|
|
||||||
annotation=match.target,
|
|
||||||
ax=ax,
|
|
||||||
time_offset=0.004,
|
|
||||||
freq_offset=2_000,
|
|
||||||
add_points=add_points,
|
|
||||||
facecolor="none" if not fill else None,
|
|
||||||
color=true_positive_color,
|
|
||||||
color_mapper=color_mapper,
|
|
||||||
linestyle=annotation_linestyle,
|
|
||||||
)
|
|
||||||
plot_prediction(
|
|
||||||
prediction=match.source,
|
|
||||||
ax=ax,
|
|
||||||
time_offset=0.004,
|
|
||||||
freq_offset=2_000,
|
|
||||||
add_points=add_points,
|
|
||||||
facecolor="none" if not fill else None,
|
|
||||||
color=true_positive_color,
|
|
||||||
color_mapper=color_mapper,
|
|
||||||
linestyle=prediction_linestyle,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
@ -122,6 +115,9 @@ def plot_false_positive_match(
|
|||||||
ax: Optional[Axes] = None,
|
ax: Optional[Axes] = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
|
use_score: bool = True,
|
||||||
|
add_spectrogram: bool = True,
|
||||||
|
add_text: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
@ -142,34 +138,36 @@ def plot_false_positive_match(
|
|||||||
recording=match.clip.recording,
|
recording=match.clip.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot_clip(
|
if add_spectrogram:
|
||||||
clip,
|
ax = plot_clip(
|
||||||
audio_loader=audio_loader,
|
clip,
|
||||||
preprocessor=preprocessor,
|
audio_loader=audio_loader,
|
||||||
figsize=figsize,
|
preprocessor=preprocessor,
|
||||||
ax=ax,
|
figsize=figsize,
|
||||||
audio_dir=audio_dir,
|
ax=ax,
|
||||||
spec_cmap=spec_cmap,
|
audio_dir=audio_dir,
|
||||||
)
|
spec_cmap=spec_cmap,
|
||||||
|
)
|
||||||
|
|
||||||
plot.plot_geometry(
|
ax = plot.plot_geometry(
|
||||||
match.pred_geometry,
|
match.pred_geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=1,
|
alpha=match.pred_score if use_score else 1,
|
||||||
color=color,
|
color=color,
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.text(
|
if add_text:
|
||||||
start_time,
|
plt.text(
|
||||||
high_freq,
|
start_time,
|
||||||
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
high_freq,
|
||||||
va="top",
|
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||||
ha="right",
|
va="top",
|
||||||
color=color,
|
ha="right",
|
||||||
fontsize=fontsize,
|
color=color,
|
||||||
)
|
fontsize=fontsize,
|
||||||
|
)
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
@ -182,7 +180,9 @@ def plot_false_negative_match(
|
|||||||
ax: Optional[Axes] = None,
|
ax: Optional[Axes] = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
|
add_spectrogram: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
|
add_text: bool = True,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
||||||
@ -204,17 +204,18 @@ def plot_false_negative_match(
|
|||||||
recording=sound_event.recording,
|
recording=sound_event.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot_clip(
|
if add_spectrogram:
|
||||||
clip,
|
ax = plot_clip(
|
||||||
audio_loader=audio_loader,
|
clip,
|
||||||
preprocessor=preprocessor,
|
audio_loader=audio_loader,
|
||||||
figsize=figsize,
|
preprocessor=preprocessor,
|
||||||
ax=ax,
|
figsize=figsize,
|
||||||
audio_dir=audio_dir,
|
ax=ax,
|
||||||
spec_cmap=spec_cmap,
|
audio_dir=audio_dir,
|
||||||
)
|
spec_cmap=spec_cmap,
|
||||||
|
)
|
||||||
|
|
||||||
plot.plot_annotation(
|
ax = plot.plot_annotation(
|
||||||
match.sound_event_annotation,
|
match.sound_event_annotation,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.001,
|
time_offset=0.001,
|
||||||
@ -225,15 +226,16 @@ def plot_false_negative_match(
|
|||||||
color=color,
|
color=color,
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.text(
|
if add_text:
|
||||||
start_time,
|
plt.text(
|
||||||
high_freq,
|
start_time,
|
||||||
f"False Negative \nClass: {match.gt_class} ",
|
high_freq,
|
||||||
va="top",
|
f"False Negative \nClass: {match.gt_class} ",
|
||||||
ha="right",
|
va="top",
|
||||||
color=color,
|
ha="right",
|
||||||
fontsize=fontsize,
|
color=color,
|
||||||
)
|
fontsize=fontsize,
|
||||||
|
)
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
@ -246,7 +248,10 @@ def plot_true_positive_match(
|
|||||||
ax: Optional[Axes] = None,
|
ax: Optional[Axes] = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
|
use_score: bool = True,
|
||||||
|
add_spectrogram: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
|
add_text: bool = True,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
||||||
@ -270,17 +275,18 @@ def plot_true_positive_match(
|
|||||||
recording=sound_event.recording,
|
recording=sound_event.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot_clip(
|
if add_spectrogram:
|
||||||
clip,
|
ax = plot_clip(
|
||||||
audio_loader=audio_loader,
|
clip,
|
||||||
preprocessor=preprocessor,
|
audio_loader=audio_loader,
|
||||||
figsize=figsize,
|
preprocessor=preprocessor,
|
||||||
ax=ax,
|
figsize=figsize,
|
||||||
audio_dir=audio_dir,
|
ax=ax,
|
||||||
spec_cmap=spec_cmap,
|
audio_dir=audio_dir,
|
||||||
)
|
spec_cmap=spec_cmap,
|
||||||
|
)
|
||||||
|
|
||||||
plot.plot_annotation(
|
ax = plot.plot_annotation(
|
||||||
match.sound_event_annotation,
|
match.sound_event_annotation,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.001,
|
time_offset=0.001,
|
||||||
@ -297,20 +303,21 @@ def plot_true_positive_match(
|
|||||||
ax=ax,
|
ax=ax,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=1,
|
alpha=match.pred_score if use_score else 1,
|
||||||
color=color,
|
color=color,
|
||||||
linestyle=prediction_linestyle,
|
linestyle=prediction_linestyle,
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.text(
|
if add_text:
|
||||||
start_time,
|
plt.text(
|
||||||
high_freq,
|
start_time,
|
||||||
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
high_freq,
|
||||||
va="top",
|
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||||
ha="right",
|
va="top",
|
||||||
color=color,
|
ha="right",
|
||||||
fontsize=fontsize,
|
color=color,
|
||||||
)
|
fontsize=fontsize,
|
||||||
|
)
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
@ -323,7 +330,10 @@ def plot_cross_trigger_match(
|
|||||||
ax: Optional[Axes] = None,
|
ax: Optional[Axes] = None,
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
|
use_score: bool = True,
|
||||||
|
add_spectrogram: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
|
add_text: bool = True,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
||||||
@ -347,17 +357,18 @@ def plot_cross_trigger_match(
|
|||||||
recording=sound_event.recording,
|
recording=sound_event.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot_clip(
|
if add_spectrogram:
|
||||||
clip,
|
ax = plot_clip(
|
||||||
audio_loader=audio_loader,
|
clip,
|
||||||
preprocessor=preprocessor,
|
audio_loader=audio_loader,
|
||||||
figsize=figsize,
|
preprocessor=preprocessor,
|
||||||
ax=ax,
|
figsize=figsize,
|
||||||
audio_dir=audio_dir,
|
ax=ax,
|
||||||
spec_cmap=spec_cmap,
|
audio_dir=audio_dir,
|
||||||
)
|
spec_cmap=spec_cmap,
|
||||||
|
)
|
||||||
|
|
||||||
plot.plot_annotation(
|
ax = plot.plot_annotation(
|
||||||
match.sound_event_annotation,
|
match.sound_event_annotation,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.001,
|
time_offset=0.001,
|
||||||
@ -369,24 +380,25 @@ def plot_cross_trigger_match(
|
|||||||
linestyle=annotation_linestyle,
|
linestyle=annotation_linestyle,
|
||||||
)
|
)
|
||||||
|
|
||||||
plot.plot_geometry(
|
ax = plot.plot_geometry(
|
||||||
match.pred_geometry,
|
match.pred_geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=1,
|
alpha=match.pred_score if use_score else 1,
|
||||||
color=color,
|
color=color,
|
||||||
linestyle=prediction_linestyle,
|
linestyle=prediction_linestyle,
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.text(
|
if add_text:
|
||||||
start_time,
|
plt.text(
|
||||||
high_freq,
|
start_time,
|
||||||
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
high_freq,
|
||||||
va="top",
|
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||||
ha="right",
|
va="top",
|
||||||
color=color,
|
ha="right",
|
||||||
fontsize=fontsize,
|
color=color,
|
||||||
)
|
fontsize=fontsize,
|
||||||
|
)
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|||||||
@ -1,307 +1,25 @@
|
|||||||
"""Main entry point for the BatDetect2 Postprocessing pipeline."""
|
"""Main entry point for the BatDetect2 Postprocessing pipeline."""
|
||||||
|
|
||||||
from typing import List, Optional
|
from batdetect2.postprocess.config import (
|
||||||
|
PostprocessConfig,
|
||||||
import torch
|
load_postprocess_config,
|
||||||
from loguru import logger
|
)
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
from batdetect2.postprocess.decoding import (
|
from batdetect2.postprocess.decoding import (
|
||||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
|
||||||
convert_raw_prediction_to_sound_event_prediction,
|
|
||||||
convert_raw_predictions_to_clip_prediction,
|
convert_raw_predictions_to_clip_prediction,
|
||||||
to_raw_predictions,
|
to_raw_predictions,
|
||||||
)
|
)
|
||||||
from batdetect2.postprocess.extraction import extract_prediction_tensor
|
from batdetect2.postprocess.nms import non_max_suppression
|
||||||
from batdetect2.postprocess.nms import (
|
from batdetect2.postprocess.postprocessor import (
|
||||||
NMS_KERNEL_SIZE,
|
Postprocessor,
|
||||||
non_max_suppression,
|
build_postprocessor,
|
||||||
)
|
)
|
||||||
from batdetect2.postprocess.remapping import map_detection_to_clip
|
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
|
||||||
from batdetect2.typing import ModelOutput
|
|
||||||
from batdetect2.typing.postprocess import (
|
|
||||||
BatDetect2Prediction,
|
|
||||||
DetectionsTensor,
|
|
||||||
PostprocessorProtocol,
|
|
||||||
RawPrediction,
|
|
||||||
)
|
|
||||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
|
||||||
"DEFAULT_DETECTION_THRESHOLD",
|
|
||||||
"MAX_FREQ",
|
|
||||||
"MIN_FREQ",
|
|
||||||
"ModelOutput",
|
|
||||||
"NMS_KERNEL_SIZE",
|
|
||||||
"PostprocessConfig",
|
"PostprocessConfig",
|
||||||
"Postprocessor",
|
"Postprocessor",
|
||||||
"TOP_K_PER_SEC",
|
|
||||||
"build_postprocessor",
|
"build_postprocessor",
|
||||||
"convert_raw_predictions_to_clip_prediction",
|
"convert_raw_predictions_to_clip_prediction",
|
||||||
"to_raw_predictions",
|
"to_raw_predictions",
|
||||||
"load_postprocess_config",
|
"load_postprocess_config",
|
||||||
"non_max_suppression",
|
"non_max_suppression",
|
||||||
]
|
]
|
||||||
|
|
||||||
DEFAULT_DETECTION_THRESHOLD = 0.01
|
|
||||||
|
|
||||||
|
|
||||||
TOP_K_PER_SEC = 100
|
|
||||||
|
|
||||||
|
|
||||||
class PostprocessConfig(BaseConfig):
|
|
||||||
"""Configuration settings for the postprocessing pipeline.
|
|
||||||
|
|
||||||
Defines tunable parameters that control how raw model outputs are
|
|
||||||
converted into final detections.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
nms_kernel_size : int, default=NMS_KERNEL_SIZE
|
|
||||||
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
|
|
||||||
Used to suppress weaker detections near stronger peaks. Must be
|
|
||||||
positive.
|
|
||||||
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
|
|
||||||
Minimum confidence score from the detection heatmap required to
|
|
||||||
consider a point as a potential detection. Must be >= 0.
|
|
||||||
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
|
|
||||||
Minimum confidence score for a specific class prediction to be included
|
|
||||||
in the decoded tags for a detection. Must be >= 0.
|
|
||||||
top_k_per_sec : int, default=TOP_K_PER_SEC
|
|
||||||
Desired maximum number of detections per second of audio. Used by
|
|
||||||
`get_max_detections` to calculate an absolute limit based on clip
|
|
||||||
duration before applying `extract_detections_from_array`. Must be
|
|
||||||
positive.
|
|
||||||
"""
|
|
||||||
|
|
||||||
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
|
||||||
detection_threshold: float = Field(
|
|
||||||
default=DEFAULT_DETECTION_THRESHOLD,
|
|
||||||
ge=0,
|
|
||||||
)
|
|
||||||
classification_threshold: float = Field(
|
|
||||||
default=DEFAULT_CLASSIFICATION_THRESHOLD,
|
|
||||||
ge=0,
|
|
||||||
)
|
|
||||||
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
|
||||||
|
|
||||||
|
|
||||||
def load_postprocess_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
) -> PostprocessConfig:
|
|
||||||
"""Load the postprocessing configuration from a file.
|
|
||||||
|
|
||||||
Reads a configuration file (YAML) and validates it against the
|
|
||||||
`PostprocessConfig` schema, potentially extracting data from a nested
|
|
||||||
field.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : PathLike
|
|
||||||
Path to the configuration file.
|
|
||||||
field : str, optional
|
|
||||||
Dot-separated path to a nested section within the file containing the
|
|
||||||
postprocessing configuration (e.g., "inference.postprocessing").
|
|
||||||
If None, the entire file content is used.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
PostprocessConfig
|
|
||||||
The loaded and validated postprocessing configuration object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
If the config file path does not exist.
|
|
||||||
yaml.YAMLError
|
|
||||||
If the file content is not valid YAML.
|
|
||||||
pydantic.ValidationError
|
|
||||||
If the loaded configuration data does not conform to the
|
|
||||||
`PostprocessConfig` schema.
|
|
||||||
KeyError, TypeError
|
|
||||||
If `field` specifies an invalid path within the loaded data.
|
|
||||||
"""
|
|
||||||
return load_config(path, schema=PostprocessConfig, field=field)
|
|
||||||
|
|
||||||
|
|
||||||
def build_postprocessor(
|
|
||||||
preprocessor: PreprocessorProtocol,
|
|
||||||
config: Optional[PostprocessConfig] = None,
|
|
||||||
) -> PostprocessorProtocol:
|
|
||||||
"""Factory function to build the standard postprocessor."""
|
|
||||||
config = config or PostprocessConfig()
|
|
||||||
logger.opt(lazy=True).debug(
|
|
||||||
"Building postprocessor with config: \n{}",
|
|
||||||
lambda: config.to_yaml_string(),
|
|
||||||
)
|
|
||||||
return Postprocessor(
|
|
||||||
samplerate=preprocessor.output_samplerate,
|
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
top_k_per_sec=config.top_k_per_sec,
|
|
||||||
detection_threshold=config.detection_threshold,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
|
||||||
"""Standard implementation of the postprocessing pipeline."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
samplerate: float,
|
|
||||||
min_freq: float,
|
|
||||||
max_freq: float,
|
|
||||||
top_k_per_sec: int = 200,
|
|
||||||
detection_threshold: float = 0.01,
|
|
||||||
):
|
|
||||||
"""Initialize the Postprocessor."""
|
|
||||||
super().__init__()
|
|
||||||
self.samplerate = samplerate
|
|
||||||
self.min_freq = min_freq
|
|
||||||
self.max_freq = max_freq
|
|
||||||
self.top_k_per_sec = top_k_per_sec
|
|
||||||
self.detection_threshold = detection_threshold
|
|
||||||
|
|
||||||
def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
|
|
||||||
width = output.detection_probs.shape[-1]
|
|
||||||
duration = width / self.samplerate
|
|
||||||
max_detections = int(self.top_k_per_sec * duration)
|
|
||||||
detections = extract_prediction_tensor(
|
|
||||||
output,
|
|
||||||
max_detections=max_detections,
|
|
||||||
threshold=self.detection_threshold,
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
map_detection_to_clip(
|
|
||||||
detection,
|
|
||||||
start_time=0,
|
|
||||||
end_time=duration,
|
|
||||||
min_freq=self.min_freq,
|
|
||||||
max_freq=self.max_freq,
|
|
||||||
)
|
|
||||||
for detection in detections
|
|
||||||
]
|
|
||||||
|
|
||||||
def get_detections(
|
|
||||||
self,
|
|
||||||
output: ModelOutput,
|
|
||||||
start_times: Optional[List[float]] = None,
|
|
||||||
) -> List[DetectionsTensor]:
|
|
||||||
width = output.detection_probs.shape[-1]
|
|
||||||
duration = width / self.samplerate
|
|
||||||
max_detections = int(self.top_k_per_sec * duration)
|
|
||||||
|
|
||||||
detections = extract_prediction_tensor(
|
|
||||||
output,
|
|
||||||
max_detections=max_detections,
|
|
||||||
threshold=self.detection_threshold,
|
|
||||||
)
|
|
||||||
|
|
||||||
if start_times is None:
|
|
||||||
return detections
|
|
||||||
|
|
||||||
width = output.detection_probs.shape[-1]
|
|
||||||
duration = width / self.samplerate
|
|
||||||
return [
|
|
||||||
map_detection_to_clip(
|
|
||||||
detection,
|
|
||||||
start_time=start_time,
|
|
||||||
end_time=start_time + duration,
|
|
||||||
min_freq=self.min_freq,
|
|
||||||
max_freq=self.max_freq,
|
|
||||||
)
|
|
||||||
for detection, start_time in zip(detections, start_times)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_predictions(
|
|
||||||
output: ModelOutput,
|
|
||||||
targets: TargetProtocol,
|
|
||||||
postprocessor: PostprocessorProtocol,
|
|
||||||
start_times: Optional[List[float]] = None,
|
|
||||||
) -> List[List[RawPrediction]]:
|
|
||||||
"""Extract intermediate RawPrediction objects for a batch."""
|
|
||||||
detections = postprocessor.get_detections(output, start_times)
|
|
||||||
return [
|
|
||||||
to_raw_predictions(detection.numpy(), targets=targets)
|
|
||||||
for detection in detections
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_sound_event_predictions(
|
|
||||||
output: ModelOutput,
|
|
||||||
targets: TargetProtocol,
|
|
||||||
postprocessor: PostprocessorProtocol,
|
|
||||||
clips: List[data.Clip],
|
|
||||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
|
||||||
) -> List[List[BatDetect2Prediction]]:
|
|
||||||
raw_predictions = get_raw_predictions(
|
|
||||||
output,
|
|
||||||
targets=targets,
|
|
||||||
postprocessor=postprocessor,
|
|
||||||
start_times=[clip.start_time for clip in clips],
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
[
|
|
||||||
BatDetect2Prediction(
|
|
||||||
raw=raw,
|
|
||||||
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
|
|
||||||
raw,
|
|
||||||
recording=clip.recording,
|
|
||||||
targets=targets,
|
|
||||||
classification_threshold=classification_threshold,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
for raw in predictions
|
|
||||||
]
|
|
||||||
for predictions, clip in zip(raw_predictions, clips)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def get_predictions(
|
|
||||||
output: ModelOutput,
|
|
||||||
clips: List[data.Clip],
|
|
||||||
targets: TargetProtocol,
|
|
||||||
postprocessor: PostprocessorProtocol,
|
|
||||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
|
||||||
) -> List[data.ClipPrediction]:
|
|
||||||
"""Perform the full postprocessing pipeline for a batch.
|
|
||||||
|
|
||||||
Takes raw model output and corresponding clips, applies the entire
|
|
||||||
configured chain (NMS, remapping, extraction, geometry recovery, class
|
|
||||||
decoding), producing final `soundevent.data.ClipPrediction` objects.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
output : ModelOutput
|
|
||||||
Raw output from the neural network model for a batch.
|
|
||||||
clips : List[data.Clip]
|
|
||||||
List of `soundevent.data.Clip` objects corresponding to the batch.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[data.ClipPrediction]
|
|
||||||
List containing one `ClipPrediction` object for each input clip,
|
|
||||||
populated with `SoundEventPrediction` objects.
|
|
||||||
"""
|
|
||||||
raw_predictions = get_raw_predictions(
|
|
||||||
output,
|
|
||||||
targets=targets,
|
|
||||||
postprocessor=postprocessor,
|
|
||||||
start_times=[clip.start_time for clip in clips],
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
convert_raw_predictions_to_clip_prediction(
|
|
||||||
prediction,
|
|
||||||
clip,
|
|
||||||
targets=targets,
|
|
||||||
classification_threshold=classification_threshold,
|
|
||||||
)
|
|
||||||
for prediction, clip in zip(raw_predictions, clips)
|
|
||||||
]
|
|
||||||
|
|||||||
94
src/batdetect2/postprocess/config.py
Normal file
94
src/batdetect2/postprocess/config.py
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD
|
||||||
|
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PostprocessConfig",
|
||||||
|
"load_postprocess_config",
|
||||||
|
]
|
||||||
|
|
||||||
|
DEFAULT_DETECTION_THRESHOLD = 0.01
|
||||||
|
|
||||||
|
|
||||||
|
TOP_K_PER_SEC = 100
|
||||||
|
|
||||||
|
|
||||||
|
class PostprocessConfig(BaseConfig):
|
||||||
|
"""Configuration settings for the postprocessing pipeline.
|
||||||
|
|
||||||
|
Defines tunable parameters that control how raw model outputs are
|
||||||
|
converted into final detections.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
nms_kernel_size : int, default=NMS_KERNEL_SIZE
|
||||||
|
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
|
||||||
|
Used to suppress weaker detections near stronger peaks. Must be
|
||||||
|
positive.
|
||||||
|
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
|
||||||
|
Minimum confidence score from the detection heatmap required to
|
||||||
|
consider a point as a potential detection. Must be >= 0.
|
||||||
|
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
|
||||||
|
Minimum confidence score for a specific class prediction to be included
|
||||||
|
in the decoded tags for a detection. Must be >= 0.
|
||||||
|
top_k_per_sec : int, default=TOP_K_PER_SEC
|
||||||
|
Desired maximum number of detections per second of audio. Used by
|
||||||
|
`get_max_detections` to calculate an absolute limit based on clip
|
||||||
|
duration before applying `extract_detections_from_array`. Must be
|
||||||
|
positive.
|
||||||
|
"""
|
||||||
|
|
||||||
|
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
||||||
|
detection_threshold: float = Field(
|
||||||
|
default=DEFAULT_DETECTION_THRESHOLD,
|
||||||
|
ge=0,
|
||||||
|
)
|
||||||
|
classification_threshold: float = Field(
|
||||||
|
default=DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
|
ge=0,
|
||||||
|
)
|
||||||
|
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
||||||
|
|
||||||
|
|
||||||
|
def load_postprocess_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> PostprocessConfig:
|
||||||
|
"""Load the postprocessing configuration from a file.
|
||||||
|
|
||||||
|
Reads a configuration file (YAML) and validates it against the
|
||||||
|
`PostprocessConfig` schema, potentially extracting data from a nested
|
||||||
|
field.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : PathLike
|
||||||
|
Path to the configuration file.
|
||||||
|
field : str, optional
|
||||||
|
Dot-separated path to a nested section within the file containing the
|
||||||
|
postprocessing configuration (e.g., "inference.postprocessing").
|
||||||
|
If None, the entire file content is used.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
PostprocessConfig
|
||||||
|
The loaded and validated postprocessing configuration object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file path does not exist.
|
||||||
|
yaml.YAMLError
|
||||||
|
If the file content is not valid YAML.
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the loaded configuration data does not conform to the
|
||||||
|
`PostprocessConfig` schema.
|
||||||
|
KeyError, TypeError
|
||||||
|
If `field` specifies an invalid path within the loaded data.
|
||||||
|
"""
|
||||||
|
return load_config(path, schema=PostprocessConfig, field=field)
|
||||||
@ -6,7 +6,7 @@ import numpy as np
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.typing.postprocess import (
|
from batdetect2.typing.postprocess import (
|
||||||
DetectionsArray,
|
ClipDetectionsArray,
|
||||||
RawPrediction,
|
RawPrediction,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
@ -28,7 +28,7 @@ decoding.
|
|||||||
|
|
||||||
|
|
||||||
def to_raw_predictions(
|
def to_raw_predictions(
|
||||||
detections: DetectionsArray,
|
detections: ClipDetectionsArray,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
) -> List[RawPrediction]:
|
) -> List[RawPrediction]:
|
||||||
predictions = []
|
predictions = []
|
||||||
|
|||||||
@ -15,32 +15,25 @@ precise time-frequency location of each detection. The final output aggregates
|
|||||||
all extracted information into a structured `xarray.Dataset`.
|
all extracted information into a structured `xarray.Dataset`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
|
from batdetect2.typing.postprocess import ClipDetectionsTensor
|
||||||
from batdetect2.typing.postprocess import (
|
|
||||||
DetectionsTensor,
|
|
||||||
ModelOutput,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"extract_prediction_tensor",
|
"extract_detection_peaks",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def extract_prediction_tensor(
|
def extract_detection_peaks(
|
||||||
output: ModelOutput,
|
detection_heatmap: torch.Tensor,
|
||||||
|
size_heatmap: torch.Tensor,
|
||||||
|
feature_heatmap: torch.Tensor,
|
||||||
|
classification_heatmap: torch.Tensor,
|
||||||
max_detections: int = 200,
|
max_detections: int = 200,
|
||||||
threshold: Optional[float] = None,
|
threshold: Optional[float] = None,
|
||||||
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
) -> List[ClipDetectionsTensor]:
|
||||||
) -> List[DetectionsTensor]:
|
|
||||||
detection_heatmap = non_max_suppression(
|
|
||||||
output.detection_probs.detach(),
|
|
||||||
kernel_size=nms_kernel_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
height = detection_heatmap.shape[-2]
|
height = detection_heatmap.shape[-2]
|
||||||
width = detection_heatmap.shape[-1]
|
width = detection_heatmap.shape[-1]
|
||||||
|
|
||||||
@ -53,9 +46,9 @@ def extract_prediction_tensor(
|
|||||||
freqs = freqs.flatten().to(detection_heatmap.device)
|
freqs = freqs.flatten().to(detection_heatmap.device)
|
||||||
times = times.flatten().to(detection_heatmap.device)
|
times = times.flatten().to(detection_heatmap.device)
|
||||||
|
|
||||||
output_size_preds = output.size_preds.detach()
|
output_size_preds = size_heatmap.detach()
|
||||||
output_features = output.features.detach()
|
output_features = feature_heatmap.detach()
|
||||||
output_class_probs = output.class_probs.detach()
|
output_class_probs = classification_heatmap.detach()
|
||||||
|
|
||||||
predictions = []
|
predictions = []
|
||||||
for idx, item in enumerate(detection_heatmap):
|
for idx, item in enumerate(detection_heatmap):
|
||||||
@ -65,23 +58,25 @@ def extract_prediction_tensor(
|
|||||||
detection_scores = item.take(indices)
|
detection_scores = item.take(indices)
|
||||||
detection_freqs = freqs.take(indices)
|
detection_freqs = freqs.take(indices)
|
||||||
detection_times = times.take(indices)
|
detection_times = times.take(indices)
|
||||||
sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
|
|
||||||
features = output_features[idx, :, detection_freqs, detection_times].T
|
|
||||||
class_scores = output_class_probs[
|
|
||||||
idx, :, detection_freqs, detection_times
|
|
||||||
].T
|
|
||||||
|
|
||||||
if threshold is not None:
|
if threshold is not None:
|
||||||
mask = detection_scores >= threshold
|
mask = detection_scores >= threshold
|
||||||
|
|
||||||
detection_scores = detection_scores[mask]
|
detection_scores = detection_scores[mask]
|
||||||
sizes = sizes[mask]
|
|
||||||
detection_times = detection_times[mask]
|
detection_times = detection_times[mask]
|
||||||
detection_freqs = detection_freqs[mask]
|
detection_freqs = detection_freqs[mask]
|
||||||
features = features[mask]
|
|
||||||
class_scores = class_scores[mask]
|
sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
|
||||||
|
features = output_features[idx, :, detection_freqs, detection_times].T
|
||||||
|
class_scores = output_class_probs[
|
||||||
|
idx,
|
||||||
|
:,
|
||||||
|
detection_freqs,
|
||||||
|
detection_times,
|
||||||
|
].T
|
||||||
|
|
||||||
predictions.append(
|
predictions.append(
|
||||||
DetectionsTensor(
|
ClipDetectionsTensor(
|
||||||
scores=detection_scores,
|
scores=detection_scores,
|
||||||
sizes=sizes,
|
sizes=sizes,
|
||||||
features=features,
|
features=features,
|
||||||
|
|||||||
100
src/batdetect2/postprocess/postprocessor.py
Normal file
100
src/batdetect2/postprocess/postprocessor.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from batdetect2.postprocess.config import (
|
||||||
|
PostprocessConfig,
|
||||||
|
)
|
||||||
|
from batdetect2.postprocess.extraction import extract_detection_peaks
|
||||||
|
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
|
||||||
|
from batdetect2.postprocess.remapping import map_detection_to_clip
|
||||||
|
from batdetect2.typing import ModelOutput
|
||||||
|
from batdetect2.typing.postprocess import (
|
||||||
|
ClipDetectionsTensor,
|
||||||
|
PostprocessorProtocol,
|
||||||
|
)
|
||||||
|
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"build_postprocessor",
|
||||||
|
"Postprocessor",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_postprocessor(
|
||||||
|
preprocessor: PreprocessorProtocol,
|
||||||
|
config: Optional[PostprocessConfig] = None,
|
||||||
|
) -> PostprocessorProtocol:
|
||||||
|
"""Factory function to build the standard postprocessor."""
|
||||||
|
config = config or PostprocessConfig()
|
||||||
|
logger.opt(lazy=True).debug(
|
||||||
|
"Building postprocessor with config: \n{}",
|
||||||
|
lambda: config.to_yaml_string(),
|
||||||
|
)
|
||||||
|
return Postprocessor(
|
||||||
|
samplerate=preprocessor.output_samplerate,
|
||||||
|
min_freq=preprocessor.min_freq,
|
||||||
|
max_freq=preprocessor.max_freq,
|
||||||
|
top_k_per_sec=config.top_k_per_sec,
|
||||||
|
detection_threshold=config.detection_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
||||||
|
"""Standard implementation of the postprocessing pipeline."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
samplerate: float,
|
||||||
|
min_freq: float,
|
||||||
|
max_freq: float,
|
||||||
|
top_k_per_sec: int = 200,
|
||||||
|
detection_threshold: float = 0.01,
|
||||||
|
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
|
||||||
|
):
|
||||||
|
"""Initialize the Postprocessor."""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.output_samplerate = samplerate
|
||||||
|
self.min_freq = min_freq
|
||||||
|
self.max_freq = max_freq
|
||||||
|
self.top_k_per_sec = top_k_per_sec
|
||||||
|
self.detection_threshold = detection_threshold
|
||||||
|
self.nms_kernel_size = nms_kernel_size
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
output: ModelOutput,
|
||||||
|
start_times: Optional[List[float]] = None,
|
||||||
|
) -> List[ClipDetectionsTensor]:
|
||||||
|
detection_heatmap = non_max_suppression(
|
||||||
|
output.detection_probs.detach(),
|
||||||
|
kernel_size=self.nms_kernel_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
width = output.detection_probs.shape[-1]
|
||||||
|
duration = width / self.output_samplerate
|
||||||
|
max_detections = int(self.top_k_per_sec * duration)
|
||||||
|
detections = extract_detection_peaks(
|
||||||
|
detection_heatmap,
|
||||||
|
size_heatmap=output.size_preds,
|
||||||
|
feature_heatmap=output.features,
|
||||||
|
classification_heatmap=output.class_probs,
|
||||||
|
max_detections=max_detections,
|
||||||
|
threshold=self.detection_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
if start_times is None:
|
||||||
|
start_times = [0 for _ in range(len(detections))]
|
||||||
|
|
||||||
|
return [
|
||||||
|
map_detection_to_clip(
|
||||||
|
detection,
|
||||||
|
start_time=0,
|
||||||
|
end_time=duration,
|
||||||
|
min_freq=self.min_freq,
|
||||||
|
max_freq=self.max_freq,
|
||||||
|
)
|
||||||
|
for detection in detections
|
||||||
|
]
|
||||||
@ -20,7 +20,7 @@ import xarray as xr
|
|||||||
from soundevent.arrays import Dimensions
|
from soundevent.arrays import Dimensions
|
||||||
|
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
from batdetect2.typing.postprocess import DetectionsTensor
|
from batdetect2.typing.postprocess import ClipDetectionsTensor
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"features_to_xarray",
|
"features_to_xarray",
|
||||||
@ -31,15 +31,15 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
def map_detection_to_clip(
|
def map_detection_to_clip(
|
||||||
detections: DetectionsTensor,
|
detections: ClipDetectionsTensor,
|
||||||
start_time: float,
|
start_time: float,
|
||||||
end_time: float,
|
end_time: float,
|
||||||
min_freq: float,
|
min_freq: float,
|
||||||
max_freq: float,
|
max_freq: float,
|
||||||
) -> DetectionsTensor:
|
) -> ClipDetectionsTensor:
|
||||||
duration = end_time - start_time
|
duration = end_time - start_time
|
||||||
bandwidth = max_freq - min_freq
|
bandwidth = max_freq - min_freq
|
||||||
return DetectionsTensor(
|
return ClipDetectionsTensor(
|
||||||
scores=detections.scores,
|
scores=detections.scores,
|
||||||
sizes=detections.sizes,
|
sizes=detections.sizes,
|
||||||
features=detections.features,
|
features=detections.features,
|
||||||
|
|||||||
@ -1,176 +1,19 @@
|
|||||||
"""Main entry point for the BatDetect2 Preprocessing subsystem.
|
"""Main entry point for the BatDetect2 preprocessing subsystem."""
|
||||||
|
|
||||||
This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline
|
from batdetect2.audio import TARGET_SAMPLERATE_HZ
|
||||||
for converting raw audio input (from files or data objects) into processed
|
from batdetect2.preprocess.config import (
|
||||||
spectrograms suitable for input to BatDetect2 models. This ensures consistent
|
PreprocessingConfig,
|
||||||
data handling between model training and inference.
|
load_preprocessing_config,
|
||||||
|
|
||||||
The preprocessing pipeline consists of two main stages, configured via nested
|
|
||||||
data structures:
|
|
||||||
1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial
|
|
||||||
processing like resampling, duration adjustment, centering, and scaling.
|
|
||||||
Configured via `AudioConfig`.
|
|
||||||
2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from
|
|
||||||
the processed waveform using STFT, followed by frequency cropping, optional
|
|
||||||
PCEN, amplitude scaling (dB, power, linear), optional denoising, optional
|
|
||||||
resizing, and optional peak normalization. Configured via
|
|
||||||
`SpectrogramConfig`.
|
|
||||||
|
|
||||||
This module provides the primary interface:
|
|
||||||
|
|
||||||
- `PreprocessingConfig`: A unified configuration object holding `AudioConfig`
|
|
||||||
and `SpectrogramConfig`.
|
|
||||||
- `load_preprocessing_config`: Function to load the unified configuration.
|
|
||||||
- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline.
|
|
||||||
- `StandardPreprocessor`: The default implementation of the `Preprocessor`.
|
|
||||||
- `build_preprocessor`: A factory function to create a `StandardPreprocessor`
|
|
||||||
instance from a `PreprocessingConfig`.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent.data import PathLike
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
from batdetect2.preprocess.audio import (
|
|
||||||
DEFAULT_DURATION,
|
|
||||||
SCALE_RAW_AUDIO,
|
|
||||||
TARGET_SAMPLERATE_HZ,
|
|
||||||
AudioConfig,
|
|
||||||
ResampleConfig,
|
|
||||||
build_audio_loader,
|
|
||||||
build_audio_pipeline,
|
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess.spectrogram import (
|
from batdetect2.preprocess.preprocessor import Preprocessor, build_preprocessor
|
||||||
MAX_FREQ,
|
from batdetect2.preprocess.spectrogram import MAX_FREQ, MIN_FREQ
|
||||||
MIN_FREQ,
|
|
||||||
FrequencyConfig,
|
|
||||||
PcenConfig,
|
|
||||||
SpectrogramConfig,
|
|
||||||
SpectrogramPipeline,
|
|
||||||
STFTConfig,
|
|
||||||
_spec_params_from_config,
|
|
||||||
build_spectrogram_builder,
|
|
||||||
build_spectrogram_pipeline,
|
|
||||||
)
|
|
||||||
from batdetect2.typing import PreprocessorProtocol
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AudioConfig",
|
|
||||||
"DEFAULT_DURATION",
|
|
||||||
"FrequencyConfig",
|
|
||||||
"MAX_FREQ",
|
"MAX_FREQ",
|
||||||
"MIN_FREQ",
|
"MIN_FREQ",
|
||||||
"PcenConfig",
|
|
||||||
"PreprocessingConfig",
|
"PreprocessingConfig",
|
||||||
"ResampleConfig",
|
"Preprocessor",
|
||||||
"SCALE_RAW_AUDIO",
|
|
||||||
"STFTConfig",
|
|
||||||
"SpectrogramConfig",
|
|
||||||
"TARGET_SAMPLERATE_HZ",
|
"TARGET_SAMPLERATE_HZ",
|
||||||
"build_audio_loader",
|
"build_preprocessor",
|
||||||
"build_spectrogram_builder",
|
|
||||||
"load_preprocessing_config",
|
"load_preprocessing_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class PreprocessingConfig(BaseConfig):
|
|
||||||
"""Unified configuration for the audio preprocessing pipeline.
|
|
||||||
|
|
||||||
Aggregates the configuration for both the initial audio processing stage
|
|
||||||
and the subsequent spectrogram generation stage.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
audio : AudioConfig
|
|
||||||
Configuration settings for the audio loading and initial waveform
|
|
||||||
processing steps (e.g., resampling, duration adjustment, scaling).
|
|
||||||
Defaults to default `AudioConfig` settings if omitted.
|
|
||||||
spectrogram : SpectrogramConfig
|
|
||||||
Configuration settings for the spectrogram generation process
|
|
||||||
(e.g., STFT parameters, frequency cropping, scaling, denoising,
|
|
||||||
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
|
|
||||||
"""
|
|
||||||
|
|
||||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
|
||||||
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
|
|
||||||
|
|
||||||
|
|
||||||
def load_preprocessing_config(
|
|
||||||
path: PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
) -> PreprocessingConfig:
|
|
||||||
return load_config(path, schema=PreprocessingConfig, field=field)
|
|
||||||
|
|
||||||
|
|
||||||
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
|
|
||||||
"""Standard implementation of the `Preprocessor` protocol."""
|
|
||||||
|
|
||||||
input_samplerate: int
|
|
||||||
output_samplerate: float
|
|
||||||
|
|
||||||
max_freq: float
|
|
||||||
min_freq: float
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
audio_pipeline: torch.nn.Module,
|
|
||||||
spectrogram_pipeline: SpectrogramPipeline,
|
|
||||||
input_samplerate: int,
|
|
||||||
output_samplerate: float,
|
|
||||||
max_freq: float,
|
|
||||||
min_freq: float,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
self.audio_pipeline = audio_pipeline
|
|
||||||
self.spectrogram_pipeline = spectrogram_pipeline
|
|
||||||
|
|
||||||
self.max_freq = max_freq
|
|
||||||
self.min_freq = min_freq
|
|
||||||
|
|
||||||
self.input_samplerate = input_samplerate
|
|
||||||
self.output_samplerate = output_samplerate
|
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
|
||||||
wav = self.audio_pipeline(wav)
|
|
||||||
return self.spectrogram_pipeline(wav)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_output_samplerate(config: PreprocessingConfig) -> float:
|
|
||||||
samplerate = config.audio.samplerate
|
|
||||||
_, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft)
|
|
||||||
factor = config.spectrogram.size.resize_factor
|
|
||||||
return samplerate * factor / hop_size
|
|
||||||
|
|
||||||
|
|
||||||
def build_preprocessor(
|
|
||||||
config: Optional[PreprocessingConfig] = None,
|
|
||||||
) -> PreprocessorProtocol:
|
|
||||||
"""Factory function to build the standard preprocessor from configuration."""
|
|
||||||
config = config or PreprocessingConfig()
|
|
||||||
logger.opt(lazy=True).debug(
|
|
||||||
"Building preprocessor with config: \n{}",
|
|
||||||
lambda: config.to_yaml_string(),
|
|
||||||
)
|
|
||||||
|
|
||||||
samplerate = config.audio.samplerate
|
|
||||||
|
|
||||||
min_freq = config.spectrogram.frequencies.min_freq
|
|
||||||
max_freq = config.spectrogram.frequencies.max_freq
|
|
||||||
|
|
||||||
output_samplerate = compute_output_samplerate(config)
|
|
||||||
|
|
||||||
return StandardPreprocessor(
|
|
||||||
audio_pipeline=build_audio_pipeline(config.audio),
|
|
||||||
spectrogram_pipeline=build_spectrogram_pipeline(
|
|
||||||
samplerate, config.spectrogram
|
|
||||||
),
|
|
||||||
input_samplerate=samplerate,
|
|
||||||
output_samplerate=output_samplerate,
|
|
||||||
min_freq=min_freq,
|
|
||||||
max_freq=max_freq,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,307 +1,57 @@
|
|||||||
"""Handles loading and initial preprocessing of audio waveforms."""
|
from typing import Annotated, Literal, Union
|
||||||
|
|
||||||
from typing import Annotated, List, Literal, Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
from numpy.typing import DTypeLike
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from scipy.signal import resample, resample_poly
|
|
||||||
from soundevent import audio, data
|
|
||||||
from soundfile import LibsndfileError
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.audio import TARGET_SAMPLERATE_HZ
|
||||||
from batdetect2.preprocess.common import CenterTensor, PeakNormalize
|
from batdetect2.core import BaseConfig, Registry
|
||||||
from batdetect2.typing import AudioLoader
|
from batdetect2.preprocess.common import center_tensor, peak_normalize
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ResampleConfig",
|
"CenterAudioConfig",
|
||||||
"AudioConfig",
|
"ScaleAudioConfig",
|
||||||
"SoundEventAudioLoader",
|
"FixDurationConfig",
|
||||||
"build_audio_loader",
|
"build_audio_transform",
|
||||||
"load_file_audio",
|
|
||||||
"load_recording_audio",
|
|
||||||
"load_clip_audio",
|
|
||||||
"resample_audio",
|
|
||||||
"TARGET_SAMPLERATE_HZ",
|
|
||||||
"SCALE_RAW_AUDIO",
|
|
||||||
"DEFAULT_DURATION",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
TARGET_SAMPLERATE_HZ = 256_000
|
|
||||||
"""Default target sample rate in Hz used if resampling is enabled."""
|
|
||||||
|
|
||||||
SCALE_RAW_AUDIO = False
|
audio_transforms: Registry[torch.nn.Module, [int]] = Registry(
|
||||||
"""Default setting for whether to perform peak normalization."""
|
"audio_transform"
|
||||||
|
)
|
||||||
DEFAULT_DURATION = None
|
|
||||||
"""Default setting for target audio duration in seconds."""
|
|
||||||
|
|
||||||
|
|
||||||
class ResampleConfig(BaseConfig):
|
|
||||||
"""Configuration for audio resampling.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
samplerate : int, default=256000
|
|
||||||
The target sample rate in Hz to resample the audio to. Must be > 0.
|
|
||||||
method : str, default="poly"
|
|
||||||
The resampling algorithm to use. Options:
|
|
||||||
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
|
|
||||||
Generally fast.
|
|
||||||
- "fourier": Resampling via Fourier method using
|
|
||||||
`scipy.signal.resample`. May handle non-integer
|
|
||||||
resampling factors differently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
enabled: bool = True
|
|
||||||
method: str = "poly"
|
|
||||||
|
|
||||||
|
|
||||||
class SoundEventAudioLoader:
|
|
||||||
"""Concrete implementation of the `AudioLoader`."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
|
||||||
config: Optional[ResampleConfig] = None,
|
|
||||||
):
|
|
||||||
self.samplerate = samplerate
|
|
||||||
self.config = config or ResampleConfig()
|
|
||||||
|
|
||||||
def load_file(
|
|
||||||
self,
|
|
||||||
path: data.PathLike,
|
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Load and preprocess audio directly from a file path."""
|
|
||||||
return load_file_audio(
|
|
||||||
path,
|
|
||||||
samplerate=self.samplerate,
|
|
||||||
config=self.config,
|
|
||||||
audio_dir=audio_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_recording(
|
|
||||||
self,
|
|
||||||
recording: data.Recording,
|
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Load and preprocess the entire audio for a Recording object."""
|
|
||||||
return load_recording_audio(
|
|
||||||
recording,
|
|
||||||
samplerate=self.samplerate,
|
|
||||||
config=self.config,
|
|
||||||
audio_dir=audio_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_clip(
|
|
||||||
self,
|
|
||||||
clip: data.Clip,
|
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Load and preprocess the audio segment defined by a Clip object."""
|
|
||||||
return load_clip_audio(
|
|
||||||
clip,
|
|
||||||
samplerate=self.samplerate,
|
|
||||||
config=self.config,
|
|
||||||
audio_dir=audio_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_file_audio(
|
|
||||||
path: data.PathLike,
|
|
||||||
samplerate: Optional[int] = None,
|
|
||||||
config: Optional[ResampleConfig] = None,
|
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
|
||||||
dtype: DTypeLike = np.float32, # type: ignore
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Load and preprocess audio from a file path using specified config."""
|
|
||||||
try:
|
|
||||||
recording = data.Recording.from_file(path)
|
|
||||||
except LibsndfileError as e:
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"Could not load the recording at path: {path}. Error: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
return load_recording_audio(
|
|
||||||
recording,
|
|
||||||
samplerate=samplerate,
|
|
||||||
config=config,
|
|
||||||
dtype=dtype,
|
|
||||||
audio_dir=audio_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_recording_audio(
|
|
||||||
recording: data.Recording,
|
|
||||||
samplerate: Optional[int] = None,
|
|
||||||
config: Optional[ResampleConfig] = None,
|
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
|
||||||
dtype: DTypeLike = np.float32, # type: ignore
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Load and preprocess the entire audio content of a recording using config."""
|
|
||||||
clip = data.Clip(
|
|
||||||
recording=recording,
|
|
||||||
start_time=0,
|
|
||||||
end_time=recording.duration,
|
|
||||||
)
|
|
||||||
return load_clip_audio(
|
|
||||||
clip,
|
|
||||||
samplerate=samplerate,
|
|
||||||
config=config,
|
|
||||||
dtype=dtype,
|
|
||||||
audio_dir=audio_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_clip_audio(
|
|
||||||
clip: data.Clip,
|
|
||||||
samplerate: Optional[int] = None,
|
|
||||||
config: Optional[ResampleConfig] = None,
|
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
|
||||||
dtype: DTypeLike = np.float32, # type: ignore
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Load and preprocess a specific audio clip segment based on config."""
|
|
||||||
try:
|
|
||||||
wav = (
|
|
||||||
audio.load_clip(clip, audio_dir=audio_dir)
|
|
||||||
.sel(channel=0)
|
|
||||||
.astype(dtype)
|
|
||||||
)
|
|
||||||
except LibsndfileError as e:
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"Could not load the recording at path: {clip.recording.path}. "
|
|
||||||
f"Error: {e}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
if not config or not config.enabled or samplerate is None:
|
|
||||||
return wav.data.astype(dtype)
|
|
||||||
|
|
||||||
sr = int(1 / wav.time.attrs["step"])
|
|
||||||
return resample_audio(
|
|
||||||
wav.data,
|
|
||||||
sr=sr,
|
|
||||||
samplerate=samplerate,
|
|
||||||
method=config.method,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def resample_audio(
|
|
||||||
wav: np.ndarray,
|
|
||||||
sr: int,
|
|
||||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
|
||||||
method: str = "poly",
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Resample an audio waveform DataArray to a target sample rate."""
|
|
||||||
if sr == samplerate:
|
|
||||||
return wav
|
|
||||||
|
|
||||||
if method == "poly":
|
|
||||||
return resample_audio_poly(
|
|
||||||
wav,
|
|
||||||
sr_orig=sr,
|
|
||||||
sr_new=samplerate,
|
|
||||||
)
|
|
||||||
elif method == "fourier":
|
|
||||||
return resample_audio_fourier(
|
|
||||||
wav,
|
|
||||||
sr_orig=sr,
|
|
||||||
sr_new=samplerate,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Resampling method '{method}' not implemented"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def resample_audio_poly(
|
|
||||||
array: np.ndarray,
|
|
||||||
sr_orig: int,
|
|
||||||
sr_new: int,
|
|
||||||
axis: int = -1,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Resample a numpy array using `scipy.signal.resample_poly`.
|
|
||||||
|
|
||||||
This method is often preferred for signals when the ratio of new
|
|
||||||
to old sample rates can be expressed as a rational number. It uses
|
|
||||||
polyphase filtering.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
array : np.ndarray
|
|
||||||
The input array to resample.
|
|
||||||
sr_orig : int
|
|
||||||
The original sample rate in Hz.
|
|
||||||
sr_new : int
|
|
||||||
The target sample rate in Hz.
|
|
||||||
axis : int, default=-1
|
|
||||||
The axis of `array` along which to resample.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
np.ndarray
|
|
||||||
The array resampled to the target sample rate.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If sample rates are not positive.
|
|
||||||
"""
|
|
||||||
gcd = np.gcd(sr_orig, sr_new)
|
|
||||||
return resample_poly(
|
|
||||||
array,
|
|
||||||
sr_new // gcd,
|
|
||||||
sr_orig // gcd,
|
|
||||||
axis=axis,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def resample_audio_fourier(
|
|
||||||
array: np.ndarray,
|
|
||||||
sr_orig: int,
|
|
||||||
sr_new: int,
|
|
||||||
axis: int = -1,
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""Resample a numpy array using `scipy.signal.resample`.
|
|
||||||
|
|
||||||
This method uses FFTs to resample the signal.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
array : np.ndarray
|
|
||||||
The input array to resample.
|
|
||||||
num : int
|
|
||||||
The desired number of samples in the output array along `axis`.
|
|
||||||
axis : int, default=-1
|
|
||||||
The axis of `array` along which to resample.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
np.ndarray
|
|
||||||
The array resampled to have `num` samples along `axis`.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If `num` is negative.
|
|
||||||
"""
|
|
||||||
ratio = sr_new / sr_orig
|
|
||||||
return resample( # type: ignore
|
|
||||||
array,
|
|
||||||
int(array.shape[axis] * ratio),
|
|
||||||
axis=axis,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CenterAudioConfig(BaseConfig):
|
class CenterAudioConfig(BaseConfig):
|
||||||
name: Literal["center_audio"] = "center_audio"
|
name: Literal["center_audio"] = "center_audio"
|
||||||
|
|
||||||
|
|
||||||
|
class CenterAudio(torch.nn.Module):
|
||||||
|
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
return center_tensor(wav)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: CenterAudioConfig, samplerate: int):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
|
audio_transforms.register(CenterAudioConfig, CenterAudio)
|
||||||
|
|
||||||
|
|
||||||
class ScaleAudioConfig(BaseConfig):
|
class ScaleAudioConfig(BaseConfig):
|
||||||
name: Literal["scale_audio"] = "scale_audio"
|
name: Literal["scale_audio"] = "scale_audio"
|
||||||
|
|
||||||
|
|
||||||
|
class ScaleAudio(torch.nn.Module):
|
||||||
|
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
return peak_normalize(wav)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: ScaleAudioConfig, samplerate: int):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
|
audio_transforms.register(ScaleAudioConfig, ScaleAudio)
|
||||||
|
|
||||||
|
|
||||||
class FixDurationConfig(BaseConfig):
|
class FixDurationConfig(BaseConfig):
|
||||||
name: Literal["fix_duration"] = "fix_duration"
|
name: Literal["fix_duration"] = "fix_duration"
|
||||||
duration: float = 0.5
|
duration: float = 0.5
|
||||||
@ -325,6 +75,12 @@ class FixDuration(torch.nn.Module):
|
|||||||
|
|
||||||
return torch.nn.functional.pad(wav, (0, self.length - length))
|
return torch.nn.functional.pad(wav, (0, self.length - length))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: FixDurationConfig, samplerate: int):
|
||||||
|
return cls(samplerate=samplerate, duration=config.duration)
|
||||||
|
|
||||||
|
|
||||||
|
audio_transforms.register(FixDurationConfig, FixDuration)
|
||||||
|
|
||||||
AudioTransform = Annotated[
|
AudioTransform = Annotated[
|
||||||
Union[
|
Union[
|
||||||
@ -336,47 +92,8 @@ AudioTransform = Annotated[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class AudioConfig(BaseConfig):
|
def build_audio_transform(
|
||||||
"""Configuration for loading and initial audio preprocessing."""
|
|
||||||
|
|
||||||
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
|
||||||
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
|
|
||||||
transforms: List[AudioTransform] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
def build_audio_loader(
|
|
||||||
config: Optional[AudioConfig] = None,
|
|
||||||
) -> AudioLoader:
|
|
||||||
"""Factory function to create an AudioLoader based on configuration."""
|
|
||||||
config = config or AudioConfig()
|
|
||||||
return SoundEventAudioLoader(
|
|
||||||
samplerate=config.samplerate,
|
|
||||||
config=config.resample,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_audio_transform_step(
|
|
||||||
config: AudioTransform,
|
config: AudioTransform,
|
||||||
samplerate: int,
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
if config.name == "fix_duration":
|
return audio_transforms.build(config, samplerate)
|
||||||
return FixDuration(samplerate=samplerate, duration=config.duration)
|
|
||||||
|
|
||||||
if config.name == "scale_audio":
|
|
||||||
return PeakNormalize()
|
|
||||||
|
|
||||||
if config.name == "center_audio":
|
|
||||||
return CenterTensor()
|
|
||||||
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Audio preprocessing step {config.name} not implemented"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_audio_pipeline(config: AudioConfig) -> torch.nn.Module:
|
|
||||||
return torch.nn.Sequential(
|
|
||||||
*[
|
|
||||||
build_audio_transform_step(step, samplerate=config.samplerate)
|
|
||||||
for step in config.transforms
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,24 +1,22 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CenterTensor",
|
"center_tensor",
|
||||||
"PeakNormalize",
|
"peak_normalize",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class CenterTensor(torch.nn.Module):
|
def center_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
def forward(self, wav: torch.Tensor):
|
return tensor - tensor.mean()
|
||||||
return wav - wav.mean()
|
|
||||||
|
|
||||||
|
|
||||||
class PeakNormalize(torch.nn.Module):
|
def peak_normalize(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
def forward(self, wav: torch.Tensor):
|
max_value = tensor.abs().min()
|
||||||
max_value = wav.abs().min()
|
|
||||||
|
|
||||||
denominator = torch.where(
|
denominator = torch.where(
|
||||||
max_value == 0,
|
max_value == 0,
|
||||||
torch.tensor(1.0, device=wav.device, dtype=wav.dtype),
|
torch.tensor(1.0, device=tensor.device, dtype=tensor.dtype),
|
||||||
max_value,
|
max_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
return wav / denominator
|
return tensor / denominator
|
||||||
|
|||||||
62
src/batdetect2/preprocess/config.py
Normal file
62
src/batdetect2/preprocess/config.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.preprocess.audio import AudioTransform
|
||||||
|
from batdetect2.preprocess.spectrogram import (
|
||||||
|
FrequencyConfig,
|
||||||
|
PcenConfig,
|
||||||
|
ResizeConfig,
|
||||||
|
SpectralMeanSubstractionConfig,
|
||||||
|
SpectrogramTransform,
|
||||||
|
STFTConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"load_preprocessing_config",
|
||||||
|
"AudioTransform",
|
||||||
|
"PreprocessingConfig",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessingConfig(BaseConfig):
|
||||||
|
"""Unified configuration for the audio preprocessing pipeline.
|
||||||
|
|
||||||
|
Aggregates the configuration for both the initial audio processing stage
|
||||||
|
and the subsequent spectrogram generation stage.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
audio : AudioConfig
|
||||||
|
Configuration settings for the audio loading and initial waveform
|
||||||
|
processing steps (e.g., resampling, duration adjustment, scaling).
|
||||||
|
Defaults to default `AudioConfig` settings if omitted.
|
||||||
|
spectrogram : SpectrogramConfig
|
||||||
|
Configuration settings for the spectrogram generation process
|
||||||
|
(e.g., STFT parameters, frequency cropping, scaling, denoising,
|
||||||
|
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
audio_transforms: List[AudioTransform] = Field(default_factory=list)
|
||||||
|
|
||||||
|
spectrogram_transforms: List[SpectrogramTransform] = Field(
|
||||||
|
default_factory=lambda: [
|
||||||
|
PcenConfig(),
|
||||||
|
SpectralMeanSubstractionConfig(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
stft: STFTConfig = Field(default_factory=STFTConfig)
|
||||||
|
|
||||||
|
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
|
||||||
|
|
||||||
|
size: ResizeConfig = Field(default_factory=ResizeConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def load_preprocessing_config(
|
||||||
|
path: PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> PreprocessingConfig:
|
||||||
|
return load_config(path, schema=PreprocessingConfig, field=field)
|
||||||
114
src/batdetect2/preprocess/preprocessor.py
Normal file
114
src/batdetect2/preprocess/preprocessor.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from batdetect2.audio import TARGET_SAMPLERATE_HZ
|
||||||
|
from batdetect2.preprocess.audio import build_audio_transform
|
||||||
|
from batdetect2.preprocess.config import PreprocessingConfig
|
||||||
|
from batdetect2.preprocess.spectrogram import (
|
||||||
|
_spec_params_from_config,
|
||||||
|
build_spectrogram_builder,
|
||||||
|
build_spectrogram_crop,
|
||||||
|
build_spectrogram_resizer,
|
||||||
|
build_spectrogram_transform,
|
||||||
|
)
|
||||||
|
from batdetect2.typing import PreprocessorProtocol
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Preprocessor",
|
||||||
|
"build_preprocessor",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class Preprocessor(torch.nn.Module, PreprocessorProtocol):
|
||||||
|
"""Standard implementation of the `Preprocessor` protocol."""
|
||||||
|
|
||||||
|
input_samplerate: int
|
||||||
|
output_samplerate: float
|
||||||
|
|
||||||
|
max_freq: float
|
||||||
|
min_freq: float
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PreprocessingConfig,
|
||||||
|
input_samplerate: int,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.audio_transforms = torch.nn.Sequential(
|
||||||
|
*[
|
||||||
|
build_audio_transform(step, samplerate=input_samplerate)
|
||||||
|
for step in config.audio_transforms
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.spectrogram_transforms = torch.nn.Sequential(
|
||||||
|
*[
|
||||||
|
build_spectrogram_transform(step, samplerate=input_samplerate)
|
||||||
|
for step in config.spectrogram_transforms
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.spectrogram_builder = build_spectrogram_builder(
|
||||||
|
config.stft,
|
||||||
|
samplerate=input_samplerate,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.spectrogram_crop = build_spectrogram_crop(
|
||||||
|
config.frequencies,
|
||||||
|
stft=config.stft,
|
||||||
|
samplerate=input_samplerate,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.spectrogram_resizer = build_spectrogram_resizer(config.size)
|
||||||
|
|
||||||
|
self.min_freq = config.frequencies.min_freq
|
||||||
|
self.max_freq = config.frequencies.max_freq
|
||||||
|
|
||||||
|
self.input_samplerate = input_samplerate
|
||||||
|
self.output_samplerate = compute_output_samplerate(
|
||||||
|
config,
|
||||||
|
input_samplerate=input_samplerate,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
wav = self.audio_transforms(wav)
|
||||||
|
spec = self.spectrogram_builder(wav)
|
||||||
|
return self.process_spectrogram(spec)
|
||||||
|
|
||||||
|
def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.spectrogram_builder(wav)
|
||||||
|
|
||||||
|
def process_audio(self, wav: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self(wav)
|
||||||
|
|
||||||
|
def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
spec = self.spectrogram_crop(spec)
|
||||||
|
spec = self.spectrogram_transforms(spec)
|
||||||
|
return self.spectrogram_resizer(spec)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_output_samplerate(
|
||||||
|
config: PreprocessingConfig,
|
||||||
|
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
|
) -> float:
|
||||||
|
_, hop_size = _spec_params_from_config(
|
||||||
|
config.stft, samplerate=input_samplerate
|
||||||
|
)
|
||||||
|
factor = config.size.resize_factor
|
||||||
|
return input_samplerate * factor / hop_size
|
||||||
|
|
||||||
|
|
||||||
|
def build_preprocessor(
|
||||||
|
config: Optional[PreprocessingConfig] = None,
|
||||||
|
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
|
) -> PreprocessorProtocol:
|
||||||
|
"""Factory function to build the standard preprocessor from configuration."""
|
||||||
|
config = config or PreprocessingConfig()
|
||||||
|
logger.opt(lazy=True).debug(
|
||||||
|
"Building preprocessor with config: \n{}",
|
||||||
|
lambda: config.to_yaml_string(),
|
||||||
|
)
|
||||||
|
return Preprocessor(config=config, input_samplerate=input_samplerate)
|
||||||
@ -1,31 +1,21 @@
|
|||||||
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
||||||
|
|
||||||
from typing import (
|
from typing import Annotated, Callable, Literal, Optional, Union
|
||||||
Annotated,
|
|
||||||
Callable,
|
|
||||||
List,
|
|
||||||
Literal,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.audio import TARGET_SAMPLERATE_HZ
|
||||||
from batdetect2.preprocess.common import PeakNormalize
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
from batdetect2.core.registries import Registry
|
||||||
|
from batdetect2.preprocess.common import peak_normalize
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"STFTConfig",
|
"STFTConfig",
|
||||||
"FrequencyConfig",
|
"build_spectrogram_transform",
|
||||||
"PcenConfig",
|
|
||||||
"SpectrogramConfig",
|
|
||||||
"build_spectrogram_builder",
|
"build_spectrogram_builder",
|
||||||
"MIN_FREQ",
|
|
||||||
"MAX_FREQ",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -60,6 +50,20 @@ class STFTConfig(BaseConfig):
|
|||||||
window_fn: str = "hann"
|
window_fn: str = "hann"
|
||||||
|
|
||||||
|
|
||||||
|
def build_spectrogram_builder(
|
||||||
|
config: STFTConfig,
|
||||||
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
|
) -> torch.nn.Module:
|
||||||
|
n_fft, hop_length = _spec_params_from_config(config, samplerate=samplerate)
|
||||||
|
return torchaudio.transforms.Spectrogram(
|
||||||
|
n_fft=n_fft,
|
||||||
|
hop_length=hop_length,
|
||||||
|
window_fn=get_spectrogram_window(config.window_fn),
|
||||||
|
center=True,
|
||||||
|
power=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
|
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
|
||||||
if name == "hann":
|
if name == "hann":
|
||||||
return torch.hann_window
|
return torch.hann_window
|
||||||
@ -81,24 +85,31 @@ def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _spec_params_from_config(samplerate: int, conf: STFTConfig):
|
def _spec_params_from_config(
|
||||||
n_fft = int(samplerate * conf.window_duration)
|
config: STFTConfig,
|
||||||
hop_length = int(n_fft * (1 - conf.window_overlap))
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
|
):
|
||||||
|
n_fft = int(samplerate * config.window_duration)
|
||||||
|
hop_length = int(n_fft * (1 - config.window_overlap))
|
||||||
return n_fft, hop_length
|
return n_fft, hop_length
|
||||||
|
|
||||||
|
|
||||||
def build_spectrogram_builder(
|
def _frequency_to_index(
|
||||||
samplerate: int,
|
freq: float,
|
||||||
conf: STFTConfig,
|
n_fft: int,
|
||||||
) -> torch.nn.Module:
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
n_fft, hop_length = _spec_params_from_config(samplerate, conf)
|
) -> Optional[int]:
|
||||||
return torchaudio.transforms.Spectrogram(
|
alpha = freq * 2 / samplerate
|
||||||
n_fft=n_fft,
|
height = np.floor(n_fft / 2) + 1
|
||||||
hop_length=hop_length,
|
index = int(np.floor(alpha * height))
|
||||||
window_fn=get_spectrogram_window(conf.window_fn),
|
|
||||||
center=True,
|
if index <= 0:
|
||||||
power=1,
|
return None
|
||||||
)
|
|
||||||
|
if index >= height:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return index
|
||||||
|
|
||||||
|
|
||||||
class FrequencyConfig(BaseConfig):
|
class FrequencyConfig(BaseConfig):
|
||||||
@ -114,36 +125,36 @@ class FrequencyConfig(BaseConfig):
|
|||||||
Frequencies below this value will be cropped. Must be >= 0.
|
Frequencies below this value will be cropped. Must be >= 0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
max_freq: int = Field(default=120_000, ge=0)
|
max_freq: int = Field(default=MAX_FREQ, ge=0)
|
||||||
min_freq: int = Field(default=10_000, ge=0)
|
min_freq: int = Field(default=MIN_FREQ, ge=0)
|
||||||
|
|
||||||
|
|
||||||
def _frequency_to_index(
|
class FrequencyCrop(torch.nn.Module):
|
||||||
freq: float,
|
|
||||||
samplerate: int,
|
|
||||||
n_fft: int,
|
|
||||||
) -> Optional[int]:
|
|
||||||
alpha = freq * 2 / samplerate
|
|
||||||
height = np.floor(n_fft / 2) + 1
|
|
||||||
index = int(np.floor(alpha * height))
|
|
||||||
|
|
||||||
if index <= 0:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if index >= height:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return index
|
|
||||||
|
|
||||||
|
|
||||||
class FrequencyClip(torch.nn.Module):
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
low_index: Optional[int] = None,
|
samplerate: int,
|
||||||
high_index: Optional[int] = None,
|
n_fft: int,
|
||||||
|
min_freq: Optional[int] = None,
|
||||||
|
max_freq: Optional[int] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.n_fft = n_fft
|
||||||
|
self.samplerate = samplerate
|
||||||
|
self.min_freq = min_freq
|
||||||
|
self.max_freq = max_freq
|
||||||
|
|
||||||
|
low_index = None
|
||||||
|
if min_freq is not None:
|
||||||
|
low_index = _frequency_to_index(
|
||||||
|
min_freq, self.samplerate, self.n_fft
|
||||||
|
)
|
||||||
self.low_index = low_index
|
self.low_index = low_index
|
||||||
|
|
||||||
|
high_index = None
|
||||||
|
if max_freq is not None:
|
||||||
|
high_index = _frequency_to_index(
|
||||||
|
max_freq, self.samplerate, self.n_fft
|
||||||
|
)
|
||||||
self.high_index = high_index
|
self.high_index = high_index
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
@ -164,6 +175,62 @@ class FrequencyClip(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_spectrogram_crop(
|
||||||
|
config: FrequencyConfig,
|
||||||
|
stft: Optional[STFTConfig] = None,
|
||||||
|
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
|
) -> torch.nn.Module:
|
||||||
|
stft = stft or STFTConfig()
|
||||||
|
n_fft, _ = _spec_params_from_config(stft, samplerate=samplerate)
|
||||||
|
return FrequencyCrop(
|
||||||
|
samplerate=samplerate,
|
||||||
|
n_fft=n_fft,
|
||||||
|
min_freq=config.min_freq,
|
||||||
|
max_freq=config.max_freq,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeConfig(BaseConfig):
|
||||||
|
name: Literal["resize_spec"] = "resize_spec"
|
||||||
|
height: int = 128
|
||||||
|
resize_factor: float = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
class ResizeSpec(torch.nn.Module):
|
||||||
|
def __init__(self, height: int, time_factor: float):
|
||||||
|
super().__init__()
|
||||||
|
self.height = height
|
||||||
|
self.time_factor = time_factor
|
||||||
|
|
||||||
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
current_length = spec.shape[-1]
|
||||||
|
target_length = int(self.time_factor * current_length)
|
||||||
|
|
||||||
|
original_ndim = spec.ndim
|
||||||
|
while spec.ndim < 4:
|
||||||
|
spec = spec.unsqueeze(0)
|
||||||
|
|
||||||
|
resized = torch.nn.functional.interpolate(
|
||||||
|
spec,
|
||||||
|
size=(self.height, target_length),
|
||||||
|
mode="bilinear",
|
||||||
|
)
|
||||||
|
|
||||||
|
while resized.ndim != original_ndim:
|
||||||
|
resized = resized.squeeze(0)
|
||||||
|
|
||||||
|
return resized
|
||||||
|
|
||||||
|
|
||||||
|
def build_spectrogram_resizer(config: ResizeConfig) -> torch.nn.Module:
|
||||||
|
return ResizeSpec(height=config.height, time_factor=config.resize_factor)
|
||||||
|
|
||||||
|
|
||||||
|
spectrogram_transforms: Registry[torch.nn.Module, [int]] = Registry(
|
||||||
|
"spectrogram_transform"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PcenConfig(BaseConfig):
|
class PcenConfig(BaseConfig):
|
||||||
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
|
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
|
||||||
|
|
||||||
@ -182,7 +249,7 @@ class PCEN(torch.nn.Module):
|
|||||||
bias: float = 2.0,
|
bias: float = 2.0,
|
||||||
power: float = 0.5,
|
power: float = 0.5,
|
||||||
eps: float = 1e-6,
|
eps: float = 1e-6,
|
||||||
dtype=torch.float64,
|
dtype=torch.float32,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.smoothing_constant = smoothing_constant
|
self.smoothing_constant = smoothing_constant
|
||||||
@ -218,6 +285,19 @@ class PCEN(torch.nn.Module):
|
|||||||
* torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
|
* torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
|
||||||
).to(spec.dtype)
|
).to(spec.dtype)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: PcenConfig, samplerate: int):
|
||||||
|
smooth = _compute_smoothing_constant(samplerate, config.time_constant)
|
||||||
|
return cls(
|
||||||
|
smoothing_constant=smooth,
|
||||||
|
gain=config.gain,
|
||||||
|
bias=config.bias,
|
||||||
|
power=config.power,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
spectrogram_transforms.register(PcenConfig, PCEN)
|
||||||
|
|
||||||
|
|
||||||
def _compute_smoothing_constant(
|
def _compute_smoothing_constant(
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
@ -241,16 +321,26 @@ class ToPower(torch.nn.Module):
|
|||||||
return spec**2
|
return spec**2
|
||||||
|
|
||||||
|
|
||||||
def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module:
|
_scalers = {
|
||||||
if conf.scale == "db":
|
"db": torchaudio.transforms.AmplitudeToDB,
|
||||||
return torchaudio.transforms.AmplitudeToDB()
|
"power": ToPower,
|
||||||
|
}
|
||||||
|
|
||||||
if conf.scale == "power":
|
|
||||||
return ToPower()
|
|
||||||
|
|
||||||
raise NotImplementedError(
|
class ScaleAmplitude(torch.nn.Module):
|
||||||
f"Amplitude scaling {conf.scale} not implemented"
|
def __init__(self, scale: Literal["power", "db"]):
|
||||||
)
|
self.scale = scale
|
||||||
|
self.scaler = _scalers[scale]()
|
||||||
|
|
||||||
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
return self.scaler(spec)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: ScaleAmplitudeConfig, samplerate: int):
|
||||||
|
return cls(scale=config.scale)
|
||||||
|
|
||||||
|
|
||||||
|
spectrogram_transforms.register(ScaleAmplitudeConfig, ScaleAmplitude)
|
||||||
|
|
||||||
|
|
||||||
class SpectralMeanSubstractionConfig(BaseConfig):
|
class SpectralMeanSubstractionConfig(BaseConfig):
|
||||||
@ -262,43 +352,36 @@ class SpectralMeanSubstraction(torch.nn.Module):
|
|||||||
mean = spec.mean(-1, keepdim=True)
|
mean = spec.mean(-1, keepdim=True)
|
||||||
return (spec - mean).clamp(min=0)
|
return (spec - mean).clamp(min=0)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
class ResizeConfig(BaseConfig):
|
def from_config(
|
||||||
name: Literal["resize_spec"] = "resize_spec"
|
cls,
|
||||||
height: int = 128
|
config: SpectralMeanSubstractionConfig,
|
||||||
resize_factor: float = 0.5
|
samplerate: int,
|
||||||
|
):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
class ResizeSpec(torch.nn.Module):
|
spectrogram_transforms.register(
|
||||||
def __init__(self, height: int, time_factor: float):
|
SpectralMeanSubstractionConfig,
|
||||||
super().__init__()
|
SpectralMeanSubstraction,
|
||||||
self.height = height
|
)
|
||||||
self.time_factor = time_factor
|
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
|
||||||
current_length = spec.shape[-1]
|
|
||||||
target_length = int(self.time_factor * current_length)
|
|
||||||
|
|
||||||
original_ndim = spec.ndim
|
|
||||||
while spec.ndim < 4:
|
|
||||||
spec = spec.unsqueeze(0)
|
|
||||||
|
|
||||||
resized = torch.nn.functional.interpolate(
|
|
||||||
spec,
|
|
||||||
size=(self.height, target_length),
|
|
||||||
mode="bilinear",
|
|
||||||
)
|
|
||||||
|
|
||||||
while resized.ndim != original_ndim:
|
|
||||||
resized = resized.squeeze(0)
|
|
||||||
|
|
||||||
return resized
|
|
||||||
|
|
||||||
|
|
||||||
class PeakNormalizeConfig(BaseConfig):
|
class PeakNormalizeConfig(BaseConfig):
|
||||||
name: Literal["peak_normalize"] = "peak_normalize"
|
name: Literal["peak_normalize"] = "peak_normalize"
|
||||||
|
|
||||||
|
|
||||||
|
class PeakNormalize(torch.nn.Module):
|
||||||
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
|
return peak_normalize(spec)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, config: PeakNormalizeConfig, samplerate: int):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
|
||||||
|
spectrogram_transforms.register(PeakNormalizeConfig, PeakNormalize)
|
||||||
|
|
||||||
SpectrogramTransform = Annotated[
|
SpectrogramTransform = Annotated[
|
||||||
Union[
|
Union[
|
||||||
PcenConfig,
|
PcenConfig,
|
||||||
@ -310,114 +393,8 @@ SpectrogramTransform = Annotated[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class SpectrogramConfig(BaseConfig):
|
|
||||||
stft: STFTConfig = Field(default_factory=STFTConfig)
|
|
||||||
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
|
|
||||||
size: ResizeConfig = Field(default_factory=ResizeConfig)
|
|
||||||
transforms: Sequence[SpectrogramTransform] = Field(
|
|
||||||
default_factory=lambda: [
|
|
||||||
PcenConfig(),
|
|
||||||
SpectralMeanSubstractionConfig(),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_spectrogram_transform_step(
|
|
||||||
step: SpectrogramTransform,
|
|
||||||
samplerate: int,
|
|
||||||
) -> torch.nn.Module:
|
|
||||||
if step.name == "pcen":
|
|
||||||
return PCEN(
|
|
||||||
smoothing_constant=_compute_smoothing_constant(
|
|
||||||
samplerate=samplerate,
|
|
||||||
time_constant=step.time_constant,
|
|
||||||
),
|
|
||||||
gain=step.gain,
|
|
||||||
bias=step.bias,
|
|
||||||
power=step.power,
|
|
||||||
)
|
|
||||||
|
|
||||||
if step.name == "scale_amplitude":
|
|
||||||
return _build_amplitude_scaler(step)
|
|
||||||
|
|
||||||
if step.name == "spectral_mean_substraction":
|
|
||||||
return SpectralMeanSubstraction()
|
|
||||||
|
|
||||||
if step.name == "peak_normalize":
|
|
||||||
return PeakNormalize()
|
|
||||||
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Spectrogram preprocessing step {step.name} not implemented"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_spectrogram_transform(
|
def build_spectrogram_transform(
|
||||||
|
config: SpectrogramTransform,
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
conf: SpectrogramConfig,
|
|
||||||
) -> torch.nn.Module:
|
) -> torch.nn.Module:
|
||||||
return torch.nn.Sequential(
|
return spectrogram_transforms.build(config, samplerate)
|
||||||
*[
|
|
||||||
_build_spectrogram_transform_step(step, samplerate=samplerate)
|
|
||||||
for step in conf.transforms
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SpectrogramPipeline(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
spec_builder: torch.nn.Module,
|
|
||||||
freq_cutter: torch.nn.Module,
|
|
||||||
transforms: torch.nn.Module,
|
|
||||||
resizer: torch.nn.Module,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.spec_builder = spec_builder
|
|
||||||
self.freq_cutter = freq_cutter
|
|
||||||
self.transforms = transforms
|
|
||||||
self.resizer = resizer
|
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
|
||||||
spec = self.spec_builder(wav)
|
|
||||||
spec = self.freq_cutter(spec)
|
|
||||||
spec = self.transforms(spec)
|
|
||||||
return self.resizer(spec)
|
|
||||||
|
|
||||||
def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.spec_builder(wav)
|
|
||||||
|
|
||||||
def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.freq_cutter(spec)
|
|
||||||
|
|
||||||
def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.transforms(spec)
|
|
||||||
|
|
||||||
def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
|
|
||||||
return self.resizer(spec)
|
|
||||||
|
|
||||||
|
|
||||||
def build_spectrogram_pipeline(
|
|
||||||
samplerate: int,
|
|
||||||
conf: SpectrogramConfig,
|
|
||||||
) -> SpectrogramPipeline:
|
|
||||||
spec_builder = build_spectrogram_builder(samplerate, conf.stft)
|
|
||||||
n_fft, _ = _spec_params_from_config(samplerate, conf.stft)
|
|
||||||
cutter = FrequencyClip(
|
|
||||||
low_index=_frequency_to_index(
|
|
||||||
conf.frequencies.min_freq, samplerate, n_fft
|
|
||||||
),
|
|
||||||
high_index=_frequency_to_index(
|
|
||||||
conf.frequencies.max_freq, samplerate, n_fft
|
|
||||||
),
|
|
||||||
)
|
|
||||||
transforms = build_spectrogram_transform(samplerate, conf)
|
|
||||||
resizer = ResizeSpec(
|
|
||||||
height=conf.size.height,
|
|
||||||
time_factor=conf.size.resize_factor,
|
|
||||||
)
|
|
||||||
return SpectrogramPipeline(
|
|
||||||
spec_builder=spec_builder,
|
|
||||||
freq_cutter=cutter,
|
|
||||||
transforms=transforms,
|
|
||||||
resizer=resizer,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -1,17 +1,6 @@
|
|||||||
"""BatDetect2 Target Definition system."""
|
"""BatDetect2 Target Definition system."""
|
||||||
|
|
||||||
from collections import Counter
|
|
||||||
from typing import Iterable, List, Optional, Tuple
|
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field, field_validator
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
|
||||||
from batdetect2.data.conditions import build_sound_event_condition
|
|
||||||
from batdetect2.targets.classes import (
|
from batdetect2.targets.classes import (
|
||||||
DEFAULT_CLASSES,
|
|
||||||
DEFAULT_DETECTION_CLASS,
|
|
||||||
SoundEventDecoder,
|
SoundEventDecoder,
|
||||||
SoundEventEncoder,
|
SoundEventEncoder,
|
||||||
TargetClassConfig,
|
TargetClassConfig,
|
||||||
@ -19,23 +8,29 @@ from batdetect2.targets.classes import (
|
|||||||
build_sound_event_encoder,
|
build_sound_event_encoder,
|
||||||
get_class_names_from_config,
|
get_class_names_from_config,
|
||||||
)
|
)
|
||||||
|
from batdetect2.targets.config import TargetConfig, load_target_config
|
||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.rois import (
|
||||||
AnchorBBoxMapperConfig,
|
AnchorBBoxMapperConfig,
|
||||||
ROIMapperConfig,
|
ROIMapperConfig,
|
||||||
ROITargetMapper,
|
ROITargetMapper,
|
||||||
build_roi_mapper,
|
build_roi_mapper,
|
||||||
)
|
)
|
||||||
|
from batdetect2.targets.targets import (
|
||||||
|
Targets,
|
||||||
|
build_targets,
|
||||||
|
iterate_encoded_sound_events,
|
||||||
|
load_targets,
|
||||||
|
)
|
||||||
from batdetect2.targets.terms import (
|
from batdetect2.targets.terms import (
|
||||||
call_type,
|
call_type,
|
||||||
data_source,
|
data_source,
|
||||||
generic_class,
|
generic_class,
|
||||||
individual,
|
individual,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.targets import Position, Size, TargetProtocol
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnchorBBoxMapperConfig",
|
"AnchorBBoxMapperConfig",
|
||||||
"DEFAULT_TARGET_CONFIG",
|
"ROIMapperConfig",
|
||||||
"ROITargetMapper",
|
"ROITargetMapper",
|
||||||
"SoundEventDecoder",
|
"SoundEventDecoder",
|
||||||
"SoundEventEncoder",
|
"SoundEventEncoder",
|
||||||
@ -45,365 +40,13 @@ __all__ = [
|
|||||||
"build_roi_mapper",
|
"build_roi_mapper",
|
||||||
"build_sound_event_decoder",
|
"build_sound_event_decoder",
|
||||||
"build_sound_event_encoder",
|
"build_sound_event_encoder",
|
||||||
|
"build_targets",
|
||||||
"call_type",
|
"call_type",
|
||||||
"data_source",
|
"data_source",
|
||||||
"generic_class",
|
"generic_class",
|
||||||
"get_class_names_from_config",
|
"get_class_names_from_config",
|
||||||
"individual",
|
"individual",
|
||||||
|
"iterate_encoded_sound_events",
|
||||||
"load_target_config",
|
"load_target_config",
|
||||||
|
"load_targets",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TargetConfig(BaseConfig):
|
|
||||||
detection_target: TargetClassConfig = Field(
|
|
||||||
default=DEFAULT_DETECTION_CLASS
|
|
||||||
)
|
|
||||||
|
|
||||||
classification_targets: List[TargetClassConfig] = Field(
|
|
||||||
default_factory=lambda: DEFAULT_CLASSES
|
|
||||||
)
|
|
||||||
|
|
||||||
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
|
||||||
|
|
||||||
@field_validator("classification_targets")
|
|
||||||
def check_unique_class_names(cls, v: List[TargetClassConfig]):
|
|
||||||
"""Ensure all defined class names are unique."""
|
|
||||||
names = [c.name for c in v]
|
|
||||||
|
|
||||||
if len(names) != len(set(names)):
|
|
||||||
name_counts = Counter(names)
|
|
||||||
duplicates = [
|
|
||||||
name for name, count in name_counts.items() if count > 1
|
|
||||||
]
|
|
||||||
raise ValueError(
|
|
||||||
"Class names must be unique. Found duplicates: "
|
|
||||||
f"{', '.join(duplicates)}"
|
|
||||||
)
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
def load_target_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
) -> TargetConfig:
|
|
||||||
"""Load the unified target configuration from a file.
|
|
||||||
|
|
||||||
Reads a configuration file (typically YAML) and validates it against the
|
|
||||||
`TargetConfig` schema, potentially extracting data from a nested field.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : data.PathLike
|
|
||||||
Path to the configuration file.
|
|
||||||
field : str, optional
|
|
||||||
Dot-separated path to a nested section within the file containing the
|
|
||||||
target configuration. If None, the entire file content is used.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
TargetConfig
|
|
||||||
The loaded and validated unified target configuration object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
If the config file path does not exist.
|
|
||||||
yaml.YAMLError
|
|
||||||
If the file content is not valid YAML.
|
|
||||||
pydantic.ValidationError
|
|
||||||
If the loaded configuration data does not conform to the
|
|
||||||
`TargetConfig` schema (including validation within nested configs
|
|
||||||
like `ClassesConfig`).
|
|
||||||
KeyError, TypeError
|
|
||||||
If `field` specifies an invalid path within the loaded data.
|
|
||||||
"""
|
|
||||||
return load_config(path=path, schema=TargetConfig, field=field)
|
|
||||||
|
|
||||||
|
|
||||||
class Targets(TargetProtocol):
|
|
||||||
"""Encapsulates the complete configured target definition pipeline.
|
|
||||||
|
|
||||||
This class implements the `TargetProtocol`, holding the configured
|
|
||||||
functions for filtering, transforming, encoding (tags to class name),
|
|
||||||
decoding (class name to tags), and mapping ROIs (geometry to position/size
|
|
||||||
and back). It provides a high-level interface to apply these steps and
|
|
||||||
access relevant metadata like class names and dimension names.
|
|
||||||
|
|
||||||
Instances are typically created using the `build_targets` factory function
|
|
||||||
or the `load_targets` convenience loader.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
class_names : List[str]
|
|
||||||
An ordered list of the unique names of the specific target classes
|
|
||||||
defined in the configuration.
|
|
||||||
generic_class_tags : List[data.Tag]
|
|
||||||
A list of `soundevent.data.Tag` objects representing the configured
|
|
||||||
generic class category (used when no specific class matches).
|
|
||||||
dimension_names : List[str]
|
|
||||||
The names of the size dimensions handled by the ROI mapper
|
|
||||||
(e.g., ['width', 'height']).
|
|
||||||
"""
|
|
||||||
|
|
||||||
class_names: List[str]
|
|
||||||
detection_class_tags: List[data.Tag]
|
|
||||||
dimension_names: List[str]
|
|
||||||
detection_class_name: str
|
|
||||||
|
|
||||||
def __init__(self, config: TargetConfig):
|
|
||||||
"""Initialize the Targets object."""
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
self._filter_fn = build_sound_event_condition(
|
|
||||||
config.detection_target.match_if
|
|
||||||
)
|
|
||||||
self._encode_fn = build_sound_event_encoder(
|
|
||||||
config.classification_targets
|
|
||||||
)
|
|
||||||
self._decode_fn = build_sound_event_decoder(
|
|
||||||
config.classification_targets
|
|
||||||
)
|
|
||||||
|
|
||||||
self._roi_mapper = build_roi_mapper(config.roi)
|
|
||||||
|
|
||||||
self.dimension_names = self._roi_mapper.dimension_names
|
|
||||||
|
|
||||||
self.class_names = get_class_names_from_config(
|
|
||||||
config.classification_targets
|
|
||||||
)
|
|
||||||
|
|
||||||
self.detection_class_name = config.detection_target.name
|
|
||||||
self.detection_class_tags = config.detection_target.assign_tags
|
|
||||||
|
|
||||||
self._roi_mapper_overrides = {
|
|
||||||
class_config.name: build_roi_mapper(class_config.roi)
|
|
||||||
for class_config in config.classification_targets
|
|
||||||
if class_config.roi is not None
|
|
||||||
}
|
|
||||||
|
|
||||||
for class_name in self._roi_mapper_overrides:
|
|
||||||
if class_name not in self.class_names:
|
|
||||||
# TODO: improve this warning
|
|
||||||
logger.warning(
|
|
||||||
"The ROI mapper overrides contains a class ({class_name}) "
|
|
||||||
"not present in the class names.",
|
|
||||||
class_name=class_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
|
||||||
"""Apply the configured filter to a sound event annotation.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event : data.SoundEventAnnotation
|
|
||||||
The annotation to filter.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
bool
|
|
||||||
True if the annotation should be kept (passes the filter),
|
|
||||||
False otherwise. If no filter was configured, always returns True.
|
|
||||||
"""
|
|
||||||
return self._filter_fn(sound_event)
|
|
||||||
|
|
||||||
def encode_class(
|
|
||||||
self, sound_event: data.SoundEventAnnotation
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""Encode a sound event annotation to its target class name.
|
|
||||||
|
|
||||||
Applies the configured class definition rules (including priority)
|
|
||||||
to determine the specific class name for the annotation.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event : data.SoundEventAnnotation
|
|
||||||
The annotation to encode. Note: This should typically be called
|
|
||||||
*after* applying any transformations via the `transform` method.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
str or None
|
|
||||||
The name of the matched target class, or None if the annotation
|
|
||||||
does not match any specific class rule (i.e., it belongs to the
|
|
||||||
generic category).
|
|
||||||
"""
|
|
||||||
return self._encode_fn(sound_event)
|
|
||||||
|
|
||||||
def decode_class(self, class_label: str) -> List[data.Tag]:
|
|
||||||
"""Decode a predicted class name back into representative tags.
|
|
||||||
|
|
||||||
Uses the configured mapping (based on `TargetClass.output_tags` or
|
|
||||||
`TargetClass.tags`) to convert a class name string into a list of
|
|
||||||
`soundevent.data.Tag` objects.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
class_label : str
|
|
||||||
The class name to decode.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[data.Tag]
|
|
||||||
The list of tags corresponding to the input class name.
|
|
||||||
"""
|
|
||||||
return self._decode_fn(class_label)
|
|
||||||
|
|
||||||
def encode_roi(
|
|
||||||
self, sound_event: data.SoundEventAnnotation
|
|
||||||
) -> tuple[Position, Size]:
|
|
||||||
"""Extract the target reference position from the annotation's roi.
|
|
||||||
|
|
||||||
Delegates to the internal ROI mapper's `get_roi_position` method.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_event : data.SoundEventAnnotation
|
|
||||||
The annotation containing the geometry (ROI).
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Tuple[float, float]
|
|
||||||
The reference position `(time, frequency)`.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If the annotation lacks geometry.
|
|
||||||
"""
|
|
||||||
class_name = self.encode_class(sound_event)
|
|
||||||
|
|
||||||
if class_name in self._roi_mapper_overrides:
|
|
||||||
return self._roi_mapper_overrides[class_name].encode(
|
|
||||||
sound_event.sound_event
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._roi_mapper.encode(sound_event.sound_event)
|
|
||||||
|
|
||||||
def decode_roi(
|
|
||||||
self,
|
|
||||||
position: Position,
|
|
||||||
size: Size,
|
|
||||||
class_name: Optional[str] = None,
|
|
||||||
) -> data.Geometry:
|
|
||||||
"""Recover an approximate geometric ROI from a position and dimensions.
|
|
||||||
|
|
||||||
Delegates to the internal ROI mapper's `recover_roi` method, which
|
|
||||||
un-scales the dimensions and reconstructs the geometry (typically a
|
|
||||||
`BoundingBox`).
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
pos : Tuple[float, float]
|
|
||||||
The reference position `(time, frequency)`.
|
|
||||||
dims : np.ndarray
|
|
||||||
NumPy array with size dimensions (e.g., from model prediction),
|
|
||||||
matching the order in `self.dimension_names`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
data.Geometry
|
|
||||||
The reconstructed geometry (typically `BoundingBox`).
|
|
||||||
"""
|
|
||||||
if class_name in self._roi_mapper_overrides:
|
|
||||||
return self._roi_mapper_overrides[class_name].decode(
|
|
||||||
position,
|
|
||||||
size,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self._roi_mapper.decode(position, size)
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
|
||||||
classification_targets=DEFAULT_CLASSES,
|
|
||||||
detection_target=DEFAULT_DETECTION_CLASS,
|
|
||||||
roi=AnchorBBoxMapperConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
|
||||||
"""Build a Targets object from a loaded TargetConfig.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
config : TargetConfig
|
|
||||||
The loaded and validated unified target configuration object.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Targets
|
|
||||||
An initialized `Targets` object ready for use.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
KeyError
|
|
||||||
If term keys or derivation function keys specified in the `config`
|
|
||||||
are not found in their respective registries.
|
|
||||||
ImportError, AttributeError, TypeError
|
|
||||||
If dynamic import of a derivation function fails (when configured).
|
|
||||||
"""
|
|
||||||
config = config or DEFAULT_TARGET_CONFIG
|
|
||||||
logger.opt(lazy=True).debug(
|
|
||||||
"Building targets with config: \n{}",
|
|
||||||
lambda: config.to_yaml_string(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return Targets(config=config)
|
|
||||||
|
|
||||||
|
|
||||||
def load_targets(
|
|
||||||
config_path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
) -> Targets:
|
|
||||||
"""Load a Targets object directly from a configuration file.
|
|
||||||
|
|
||||||
This convenience factory method loads the `TargetConfig` from the
|
|
||||||
specified file path and then calls `Targets.from_config` to build
|
|
||||||
the fully initialized `Targets` object.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
config_path : data.PathLike
|
|
||||||
Path to the configuration file (e.g., YAML).
|
|
||||||
field : str, optional
|
|
||||||
Dot-separated path to a nested section within the file containing
|
|
||||||
the target configuration. If None, the entire file content is used.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Targets
|
|
||||||
An initialized `Targets` object ready for use.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
|
|
||||||
TypeError
|
|
||||||
Errors raised during file loading, validation, or extraction via
|
|
||||||
`load_target_config`.
|
|
||||||
KeyError, ImportError, AttributeError, TypeError
|
|
||||||
Errors raised during the build process by `Targets.from_config`
|
|
||||||
(e.g., missing keys in registries, failed imports).
|
|
||||||
"""
|
|
||||||
config = load_target_config(
|
|
||||||
config_path,
|
|
||||||
field=field,
|
|
||||||
)
|
|
||||||
return build_targets(config)
|
|
||||||
|
|
||||||
|
|
||||||
def iterate_encoded_sound_events(
|
|
||||||
sound_events: Iterable[data.SoundEventAnnotation],
|
|
||||||
targets: TargetProtocol,
|
|
||||||
) -> Iterable[Tuple[Optional[str], Position, Size]]:
|
|
||||||
for sound_event in sound_events:
|
|
||||||
if not targets.filter(sound_event):
|
|
||||||
continue
|
|
||||||
|
|
||||||
geometry = sound_event.sound_event.geometry
|
|
||||||
|
|
||||||
if geometry is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
class_name = targets.encode_class(sound_event)
|
|
||||||
position, size = targets.encode_roi(sound_event)
|
|
||||||
|
|
||||||
yield class_name, position, size
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from typing import Dict, List, Optional
|
|||||||
from pydantic import Field, PrivateAttr, computed_field, model_validator
|
from pydantic import Field, PrivateAttr, computed_field, model_validator
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data.conditions import (
|
from batdetect2.data.conditions import (
|
||||||
AllOfConfig,
|
AllOfConfig,
|
||||||
HasAllTagsConfig,
|
HasAllTagsConfig,
|
||||||
|
|||||||
84
src/batdetect2/targets/config.py
Normal file
84
src/batdetect2/targets/config.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
from collections import Counter
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from pydantic import Field, field_validator
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.targets.classes import (
|
||||||
|
DEFAULT_CLASSES,
|
||||||
|
DEFAULT_DETECTION_CLASS,
|
||||||
|
TargetClassConfig,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TargetConfig",
|
||||||
|
"load_target_config",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class TargetConfig(BaseConfig):
|
||||||
|
detection_target: TargetClassConfig = Field(
|
||||||
|
default=DEFAULT_DETECTION_CLASS
|
||||||
|
)
|
||||||
|
|
||||||
|
classification_targets: List[TargetClassConfig] = Field(
|
||||||
|
default_factory=lambda: DEFAULT_CLASSES
|
||||||
|
)
|
||||||
|
|
||||||
|
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
||||||
|
|
||||||
|
@field_validator("classification_targets")
|
||||||
|
def check_unique_class_names(cls, v: List[TargetClassConfig]):
|
||||||
|
"""Ensure all defined class names are unique."""
|
||||||
|
names = [c.name for c in v]
|
||||||
|
|
||||||
|
if len(names) != len(set(names)):
|
||||||
|
name_counts = Counter(names)
|
||||||
|
duplicates = [
|
||||||
|
name for name, count in name_counts.items() if count > 1
|
||||||
|
]
|
||||||
|
raise ValueError(
|
||||||
|
"Class names must be unique. Found duplicates: "
|
||||||
|
f"{', '.join(duplicates)}"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
def load_target_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> TargetConfig:
|
||||||
|
"""Load the unified target configuration from a file.
|
||||||
|
|
||||||
|
Reads a configuration file (typically YAML) and validates it against the
|
||||||
|
`TargetConfig` schema, potentially extracting data from a nested field.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : data.PathLike
|
||||||
|
Path to the configuration file.
|
||||||
|
field : str, optional
|
||||||
|
Dot-separated path to a nested section within the file containing the
|
||||||
|
target configuration. If None, the entire file content is used.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
TargetConfig
|
||||||
|
The loaded and validated unified target configuration object.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file path does not exist.
|
||||||
|
yaml.YAMLError
|
||||||
|
If the file content is not valid YAML.
|
||||||
|
pydantic.ValidationError
|
||||||
|
If the loaded configuration data does not conform to the
|
||||||
|
`TargetConfig` schema (including validation within nested configs
|
||||||
|
like `ClassesConfig`).
|
||||||
|
KeyError, TypeError
|
||||||
|
If `field` specifies an invalid path within the loaded data.
|
||||||
|
"""
|
||||||
|
return load_config(path=path, schema=TargetConfig, field=field)
|
||||||
@ -26,12 +26,17 @@ import numpy as np
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
|
from batdetect2.core.arrays import spec_to_xarray
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.preprocess.audio import build_audio_loader
|
from batdetect2.typing import (
|
||||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
AudioLoader,
|
||||||
from batdetect2.typing.targets import Position, ROITargetMapper, Size
|
Position,
|
||||||
from batdetect2.utils.arrays import spec_to_xarray
|
PreprocessorProtocol,
|
||||||
|
ROITargetMapper,
|
||||||
|
Size,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Anchor",
|
"Anchor",
|
||||||
@ -260,6 +265,7 @@ class PeakEnergyBBoxMapperConfig(BaseConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
|
name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
|
||||||
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
preprocessing: PreprocessingConfig = Field(
|
preprocessing: PreprocessingConfig = Field(
|
||||||
default_factory=PreprocessingConfig
|
default_factory=PreprocessingConfig
|
||||||
)
|
)
|
||||||
@ -451,8 +457,11 @@ def build_roi_mapper(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if config.name == "peak_energy_bbox":
|
if config.name == "peak_energy_bbox":
|
||||||
preprocessor = build_preprocessor(config.preprocessing)
|
audio_loader = build_audio_loader(config=config.audio)
|
||||||
audio_loader = build_audio_loader(config.preprocessing.audio)
|
preprocessor = build_preprocessor(
|
||||||
|
config.preprocessing,
|
||||||
|
input_samplerate=audio_loader.samplerate,
|
||||||
|
)
|
||||||
return PeakEnergyBBoxMapper(
|
return PeakEnergyBBoxMapper(
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
|
|||||||
308
src/batdetect2/targets/targets.py
Normal file
308
src/batdetect2/targets/targets.py
Normal file
@ -0,0 +1,308 @@
|
|||||||
|
from typing import Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.data.conditions import build_sound_event_condition
|
||||||
|
from batdetect2.targets.classes import (
|
||||||
|
DEFAULT_CLASSES,
|
||||||
|
DEFAULT_DETECTION_CLASS,
|
||||||
|
build_sound_event_decoder,
|
||||||
|
build_sound_event_encoder,
|
||||||
|
get_class_names_from_config,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.config import TargetConfig, load_target_config
|
||||||
|
from batdetect2.targets.rois import (
|
||||||
|
AnchorBBoxMapperConfig,
|
||||||
|
build_roi_mapper,
|
||||||
|
)
|
||||||
|
from batdetect2.typing.targets import Position, Size, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
class Targets(TargetProtocol):
|
||||||
|
"""Encapsulates the complete configured target definition pipeline.
|
||||||
|
|
||||||
|
This class implements the `TargetProtocol`, holding the configured
|
||||||
|
functions for filtering, transforming, encoding (tags to class name),
|
||||||
|
decoding (class name to tags), and mapping ROIs (geometry to position/size
|
||||||
|
and back). It provides a high-level interface to apply these steps and
|
||||||
|
access relevant metadata like class names and dimension names.
|
||||||
|
|
||||||
|
Instances are typically created using the `build_targets` factory function
|
||||||
|
or the `load_targets` convenience loader.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
class_names : List[str]
|
||||||
|
An ordered list of the unique names of the specific target classes
|
||||||
|
defined in the configuration.
|
||||||
|
generic_class_tags : List[data.Tag]
|
||||||
|
A list of `soundevent.data.Tag` objects representing the configured
|
||||||
|
generic class category (used when no specific class matches).
|
||||||
|
dimension_names : List[str]
|
||||||
|
The names of the size dimensions handled by the ROI mapper
|
||||||
|
(e.g., ['width', 'height']).
|
||||||
|
"""
|
||||||
|
|
||||||
|
class_names: List[str]
|
||||||
|
detection_class_tags: List[data.Tag]
|
||||||
|
dimension_names: List[str]
|
||||||
|
detection_class_name: str
|
||||||
|
|
||||||
|
def __init__(self, config: TargetConfig):
|
||||||
|
"""Initialize the Targets object."""
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self._filter_fn = build_sound_event_condition(
|
||||||
|
config.detection_target.match_if
|
||||||
|
)
|
||||||
|
self._encode_fn = build_sound_event_encoder(
|
||||||
|
config.classification_targets
|
||||||
|
)
|
||||||
|
self._decode_fn = build_sound_event_decoder(
|
||||||
|
config.classification_targets
|
||||||
|
)
|
||||||
|
|
||||||
|
self._roi_mapper = build_roi_mapper(config.roi)
|
||||||
|
|
||||||
|
self.dimension_names = self._roi_mapper.dimension_names
|
||||||
|
|
||||||
|
self.class_names = get_class_names_from_config(
|
||||||
|
config.classification_targets
|
||||||
|
)
|
||||||
|
|
||||||
|
self.detection_class_name = config.detection_target.name
|
||||||
|
self.detection_class_tags = config.detection_target.assign_tags
|
||||||
|
|
||||||
|
self._roi_mapper_overrides = {
|
||||||
|
class_config.name: build_roi_mapper(class_config.roi)
|
||||||
|
for class_config in config.classification_targets
|
||||||
|
if class_config.roi is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
for class_name in self._roi_mapper_overrides:
|
||||||
|
if class_name not in self.class_names:
|
||||||
|
# TODO: improve this warning
|
||||||
|
logger.warning(
|
||||||
|
"The ROI mapper overrides contains a class ({class_name}) "
|
||||||
|
"not present in the class names.",
|
||||||
|
class_name=class_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||||
|
"""Apply the configured filter to a sound event annotation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event : data.SoundEventAnnotation
|
||||||
|
The annotation to filter.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True if the annotation should be kept (passes the filter),
|
||||||
|
False otherwise. If no filter was configured, always returns True.
|
||||||
|
"""
|
||||||
|
return self._filter_fn(sound_event)
|
||||||
|
|
||||||
|
def encode_class(
|
||||||
|
self, sound_event: data.SoundEventAnnotation
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Encode a sound event annotation to its target class name.
|
||||||
|
|
||||||
|
Applies the configured class definition rules (including priority)
|
||||||
|
to determine the specific class name for the annotation.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event : data.SoundEventAnnotation
|
||||||
|
The annotation to encode. Note: This should typically be called
|
||||||
|
*after* applying any transformations via the `transform` method.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str or None
|
||||||
|
The name of the matched target class, or None if the annotation
|
||||||
|
does not match any specific class rule (i.e., it belongs to the
|
||||||
|
generic category).
|
||||||
|
"""
|
||||||
|
return self._encode_fn(sound_event)
|
||||||
|
|
||||||
|
def decode_class(self, class_label: str) -> List[data.Tag]:
|
||||||
|
"""Decode a predicted class name back into representative tags.
|
||||||
|
|
||||||
|
Uses the configured mapping (based on `TargetClass.output_tags` or
|
||||||
|
`TargetClass.tags`) to convert a class name string into a list of
|
||||||
|
`soundevent.data.Tag` objects.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
class_label : str
|
||||||
|
The class name to decode.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[data.Tag]
|
||||||
|
The list of tags corresponding to the input class name.
|
||||||
|
"""
|
||||||
|
return self._decode_fn(class_label)
|
||||||
|
|
||||||
|
def encode_roi(
|
||||||
|
self, sound_event: data.SoundEventAnnotation
|
||||||
|
) -> tuple[Position, Size]:
|
||||||
|
"""Extract the target reference position from the annotation's roi.
|
||||||
|
|
||||||
|
Delegates to the internal ROI mapper's `get_roi_position` method.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
sound_event : data.SoundEventAnnotation
|
||||||
|
The annotation containing the geometry (ROI).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tuple[float, float]
|
||||||
|
The reference position `(time, frequency)`.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
ValueError
|
||||||
|
If the annotation lacks geometry.
|
||||||
|
"""
|
||||||
|
class_name = self.encode_class(sound_event)
|
||||||
|
|
||||||
|
if class_name in self._roi_mapper_overrides:
|
||||||
|
return self._roi_mapper_overrides[class_name].encode(
|
||||||
|
sound_event.sound_event
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._roi_mapper.encode(sound_event.sound_event)
|
||||||
|
|
||||||
|
def decode_roi(
|
||||||
|
self,
|
||||||
|
position: Position,
|
||||||
|
size: Size,
|
||||||
|
class_name: Optional[str] = None,
|
||||||
|
) -> data.Geometry:
|
||||||
|
"""Recover an approximate geometric ROI from a position and dimensions.
|
||||||
|
|
||||||
|
Delegates to the internal ROI mapper's `recover_roi` method, which
|
||||||
|
un-scales the dimensions and reconstructs the geometry (typically a
|
||||||
|
`BoundingBox`).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
pos : Tuple[float, float]
|
||||||
|
The reference position `(time, frequency)`.
|
||||||
|
dims : np.ndarray
|
||||||
|
NumPy array with size dimensions (e.g., from model prediction),
|
||||||
|
matching the order in `self.dimension_names`.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
data.Geometry
|
||||||
|
The reconstructed geometry (typically `BoundingBox`).
|
||||||
|
"""
|
||||||
|
if class_name in self._roi_mapper_overrides:
|
||||||
|
return self._roi_mapper_overrides[class_name].decode(
|
||||||
|
position,
|
||||||
|
size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._roi_mapper.decode(position, size)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||||
|
classification_targets=DEFAULT_CLASSES,
|
||||||
|
detection_target=DEFAULT_DETECTION_CLASS,
|
||||||
|
roi=AnchorBBoxMapperConfig(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
|
||||||
|
"""Build a Targets object from a loaded TargetConfig.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config : TargetConfig
|
||||||
|
The loaded and validated unified target configuration object.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Targets
|
||||||
|
An initialized `Targets` object ready for use.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
KeyError
|
||||||
|
If term keys or derivation function keys specified in the `config`
|
||||||
|
are not found in their respective registries.
|
||||||
|
ImportError, AttributeError, TypeError
|
||||||
|
If dynamic import of a derivation function fails (when configured).
|
||||||
|
"""
|
||||||
|
config = config or DEFAULT_TARGET_CONFIG
|
||||||
|
logger.opt(lazy=True).debug(
|
||||||
|
"Building targets with config: \n{}",
|
||||||
|
lambda: config.to_yaml_string(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return Targets(config=config)
|
||||||
|
|
||||||
|
|
||||||
|
def load_targets(
|
||||||
|
config_path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> Targets:
|
||||||
|
"""Load a Targets object directly from a configuration file.
|
||||||
|
|
||||||
|
This convenience factory method loads the `TargetConfig` from the
|
||||||
|
specified file path and then calls `Targets.from_config` to build
|
||||||
|
the fully initialized `Targets` object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
config_path : data.PathLike
|
||||||
|
Path to the configuration file (e.g., YAML).
|
||||||
|
field : str, optional
|
||||||
|
Dot-separated path to a nested section within the file containing
|
||||||
|
the target configuration. If None, the entire file content is used.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Targets
|
||||||
|
An initialized `Targets` object ready for use.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
|
||||||
|
TypeError
|
||||||
|
Errors raised during file loading, validation, or extraction via
|
||||||
|
`load_target_config`.
|
||||||
|
KeyError, ImportError, AttributeError, TypeError
|
||||||
|
Errors raised during the build process by `Targets.from_config`
|
||||||
|
(e.g., missing keys in registries, failed imports).
|
||||||
|
"""
|
||||||
|
config = load_target_config(
|
||||||
|
config_path,
|
||||||
|
field=field,
|
||||||
|
)
|
||||||
|
return build_targets(config)
|
||||||
|
|
||||||
|
|
||||||
|
def iterate_encoded_sound_events(
|
||||||
|
sound_events: Iterable[data.SoundEventAnnotation],
|
||||||
|
targets: TargetProtocol,
|
||||||
|
) -> Iterable[Tuple[Optional[str], Position, Size]]:
|
||||||
|
for sound_event in sound_events:
|
||||||
|
if not targets.filter(sound_event):
|
||||||
|
continue
|
||||||
|
|
||||||
|
geometry = sound_event.sound_event.geometry
|
||||||
|
|
||||||
|
if geometry is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
class_name = targets.encode_class(sound_event)
|
||||||
|
position, size = targets.encode_roi(sound_event)
|
||||||
|
|
||||||
|
yield class_name, position, size
|
||||||
@ -14,12 +14,9 @@ from batdetect2.train.augmentations import (
|
|||||||
scale_volume,
|
scale_volume,
|
||||||
warp_spectrogram,
|
warp_spectrogram,
|
||||||
)
|
)
|
||||||
from batdetect2.train.clips import build_clipper, select_subclip
|
|
||||||
from batdetect2.train.config import (
|
from batdetect2.train.config import (
|
||||||
FullTrainingConfig,
|
|
||||||
PLTrainerConfig,
|
PLTrainerConfig,
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
load_full_training_config,
|
|
||||||
load_train_config,
|
load_train_config,
|
||||||
)
|
)
|
||||||
from batdetect2.train.dataset import (
|
from batdetect2.train.dataset import (
|
||||||
@ -48,7 +45,6 @@ __all__ = [
|
|||||||
"DetectionLossConfig",
|
"DetectionLossConfig",
|
||||||
"EchoAugmentationConfig",
|
"EchoAugmentationConfig",
|
||||||
"FrequencyMaskAugmentationConfig",
|
"FrequencyMaskAugmentationConfig",
|
||||||
"FullTrainingConfig",
|
|
||||||
"LossConfig",
|
"LossConfig",
|
||||||
"LossFunction",
|
"LossFunction",
|
||||||
"PLTrainerConfig",
|
"PLTrainerConfig",
|
||||||
@ -64,21 +60,18 @@ __all__ = [
|
|||||||
"add_echo",
|
"add_echo",
|
||||||
"build_augmentations",
|
"build_augmentations",
|
||||||
"build_clip_labeler",
|
"build_clip_labeler",
|
||||||
"build_clipper",
|
|
||||||
"build_loss",
|
"build_loss",
|
||||||
"build_train_dataset",
|
"build_train_dataset",
|
||||||
"build_train_loader",
|
"build_train_loader",
|
||||||
"build_trainer",
|
"build_trainer",
|
||||||
"build_val_dataset",
|
"build_val_dataset",
|
||||||
"build_val_loader",
|
"build_val_loader",
|
||||||
"load_full_training_config",
|
|
||||||
"load_label_config",
|
"load_label_config",
|
||||||
"load_train_config",
|
"load_train_config",
|
||||||
"mask_frequency",
|
"mask_frequency",
|
||||||
"mask_time",
|
"mask_time",
|
||||||
"mix_audio",
|
"mix_audio",
|
||||||
"scale_volume",
|
"scale_volume",
|
||||||
"select_subclip",
|
|
||||||
"train",
|
"train",
|
||||||
"warp_spectrogram",
|
"warp_spectrogram",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -11,11 +11,10 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.geometry import scale_geometry, shift_geometry
|
from soundevent.geometry import scale_geometry, shift_geometry
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.audio.clips import get_subclip_annotation
|
||||||
from batdetect2.train.clips import get_subclip_annotation
|
from batdetect2.core.arrays import adjust_width
|
||||||
from batdetect2.typing import Augmentation
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
from batdetect2.typing import AudioLoader, Augmentation
|
||||||
from batdetect2.utils.arrays import adjust_width
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AugmentationConfig",
|
"AugmentationConfig",
|
||||||
|
|||||||
@ -5,19 +5,21 @@ from lightning.pytorch.callbacks import Callback
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.evaluate import Evaluator
|
from batdetect2.logging import get_image_logger
|
||||||
from batdetect2.postprocess import get_raw_predictions
|
from batdetect2.postprocess import to_raw_predictions
|
||||||
from batdetect2.train.dataset import ValidationDataset
|
from batdetect2.train.dataset import ValidationDataset
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.train.logging import get_image_plotter
|
from batdetect2.typing import (
|
||||||
from batdetect2.typing.evaluate import ClipEvaluation
|
ClipEvaluation,
|
||||||
from batdetect2.typing.models import ModelOutput
|
EvaluatorProtocol,
|
||||||
from batdetect2.typing.postprocess import RawPrediction
|
ModelOutput,
|
||||||
from batdetect2.typing.train import TrainExample
|
RawPrediction,
|
||||||
|
TrainExample,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ValidationMetrics(Callback):
|
class ValidationMetrics(Callback):
|
||||||
def __init__(self, evaluator: Evaluator):
|
def __init__(self, evaluator: EvaluatorProtocol):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
@ -32,12 +34,12 @@ class ValidationMetrics(Callback):
|
|||||||
assert isinstance(dataset, ValidationDataset)
|
assert isinstance(dataset, ValidationDataset)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
def plot_examples(
|
def generate_plots(
|
||||||
self,
|
self,
|
||||||
pl_module: LightningModule,
|
pl_module: LightningModule,
|
||||||
evaluated_clips: List[ClipEvaluation],
|
evaluated_clips: List[ClipEvaluation],
|
||||||
):
|
):
|
||||||
plotter = get_image_plotter(pl_module.logger) # type: ignore
|
plotter = get_image_logger(pl_module.logger) # type: ignore
|
||||||
|
|
||||||
if plotter is None:
|
if plotter is None:
|
||||||
return
|
return
|
||||||
@ -64,7 +66,7 @@ class ValidationMetrics(Callback):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.log_metrics(pl_module, clip_evaluations)
|
self.log_metrics(pl_module, clip_evaluations)
|
||||||
self.plot_examples(pl_module, clip_evaluations)
|
self.generate_plots(pl_module, clip_evaluations)
|
||||||
|
|
||||||
return super().on_validation_epoch_end(trainer, pl_module)
|
return super().on_validation_epoch_end(trainer, pl_module)
|
||||||
|
|
||||||
@ -86,8 +88,7 @@ class ValidationMetrics(Callback):
|
|||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
dataloader_idx: int = 0,
|
dataloader_idx: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
postprocessor = pl_module.model.postprocessor
|
model = pl_module.model
|
||||||
targets = pl_module.model.targets
|
|
||||||
dataset = self.get_dataset(trainer)
|
dataset = self.get_dataset(trainer)
|
||||||
|
|
||||||
clip_annotations = [
|
clip_annotations = [
|
||||||
@ -95,15 +96,14 @@ class ValidationMetrics(Callback):
|
|||||||
for example_idx in batch.idx
|
for example_idx in batch.idx
|
||||||
]
|
]
|
||||||
|
|
||||||
predictions = get_raw_predictions(
|
clip_detections = model.postprocessor(
|
||||||
outputs,
|
outputs,
|
||||||
start_times=[
|
start_times=[ca.clip.start_time for ca in clip_annotations],
|
||||||
clip_annotation.clip.start_time
|
|
||||||
for clip_annotation in clip_annotations
|
|
||||||
],
|
|
||||||
targets=targets,
|
|
||||||
postprocessor=postprocessor,
|
|
||||||
)
|
)
|
||||||
|
predictions = [
|
||||||
|
to_raw_predictions(clip_dets.numpy(), targets=model.targets)
|
||||||
|
for clip_dets in clip_detections
|
||||||
|
]
|
||||||
|
|
||||||
self._clip_annotations.extend(clip_annotations)
|
self._clip_annotations.extend(clip_annotations)
|
||||||
self._predictions.extend(predictions)
|
self._predictions.extend(predictions)
|
||||||
|
|||||||
@ -3,27 +3,16 @@ from typing import Optional, Union
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate import EvaluationConfig
|
from batdetect2.evaluate.config import EvaluationConfig
|
||||||
from batdetect2.models import ModelConfig
|
from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
|
||||||
DEFAULT_AUGMENTATION_CONFIG,
|
|
||||||
AugmentationsConfig,
|
|
||||||
)
|
|
||||||
from batdetect2.train.clips import (
|
|
||||||
ClipConfig,
|
|
||||||
PaddedClipConfig,
|
|
||||||
RandomClipConfig,
|
|
||||||
)
|
|
||||||
from batdetect2.train.labels import LabelConfig
|
from batdetect2.train.labels import LabelConfig
|
||||||
from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
|
|
||||||
from batdetect2.train.losses import LossConfig
|
from batdetect2.train.losses import LossConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"load_train_config",
|
"load_train_config",
|
||||||
"FullTrainingConfig",
|
|
||||||
"load_full_training_config",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -48,30 +37,6 @@ class PLTrainerConfig(BaseConfig):
|
|||||||
val_check_interval: Optional[Union[int, float]] = None
|
val_check_interval: Optional[Union[int, float]] = None
|
||||||
|
|
||||||
|
|
||||||
class ValLoaderConfig(BaseConfig):
|
|
||||||
num_workers: int = 0
|
|
||||||
|
|
||||||
clipping_strategy: ClipConfig = Field(
|
|
||||||
default_factory=lambda: PaddedClipConfig()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TrainLoaderConfig(BaseConfig):
|
|
||||||
num_workers: int = 0
|
|
||||||
|
|
||||||
batch_size: int = 8
|
|
||||||
|
|
||||||
shuffle: bool = False
|
|
||||||
|
|
||||||
augmentations: AugmentationsConfig = Field(
|
|
||||||
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
|
|
||||||
)
|
|
||||||
|
|
||||||
clipping_strategy: ClipConfig = Field(
|
|
||||||
default_factory=lambda: PaddedClipConfig()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizerConfig(BaseConfig):
|
class OptimizerConfig(BaseConfig):
|
||||||
learning_rate: float = 1e-3
|
learning_rate: float = 1e-3
|
||||||
t_max: int = 100
|
t_max: int = 100
|
||||||
@ -80,13 +45,12 @@ class OptimizerConfig(BaseConfig):
|
|||||||
class TrainingConfig(BaseConfig):
|
class TrainingConfig(BaseConfig):
|
||||||
train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
|
train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
|
||||||
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
|
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
|
||||||
|
|
||||||
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
||||||
loss: LossConfig = Field(default_factory=LossConfig)
|
loss: LossConfig = Field(default_factory=LossConfig)
|
||||||
cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
|
|
||||||
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
|
||||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||||
|
validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
||||||
|
|
||||||
|
|
||||||
def load_train_config(
|
def load_train_config(
|
||||||
@ -94,18 +58,3 @@ def load_train_config(
|
|||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
) -> TrainingConfig:
|
) -> TrainingConfig:
|
||||||
return load_config(path, schema=TrainingConfig, field=field)
|
return load_config(path, schema=TrainingConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
class FullTrainingConfig(ModelConfig):
|
|
||||||
"""Full training configuration."""
|
|
||||||
|
|
||||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
|
||||||
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
|
||||||
|
|
||||||
|
|
||||||
def load_full_training_config(
|
|
||||||
path: data.PathLike,
|
|
||||||
field: Optional[str] = None,
|
|
||||||
) -> FullTrainingConfig:
|
|
||||||
"""Load the full training configuration."""
|
|
||||||
return load_config(path, schema=FullTrainingConfig, field=field)
|
|
||||||
|
|||||||
@ -2,22 +2,30 @@ from typing import List, Optional, Sequence
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
from batdetect2.plotting.clips import build_audio_loader
|
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
|
||||||
|
from batdetect2.audio.clips import PaddedClipConfig
|
||||||
|
from batdetect2.core import BaseConfig
|
||||||
|
from batdetect2.core.arrays import adjust_width
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
|
DEFAULT_AUGMENTATION_CONFIG,
|
||||||
|
AugmentationsConfig,
|
||||||
RandomAudioSource,
|
RandomAudioSource,
|
||||||
build_augmentations,
|
build_augmentations,
|
||||||
)
|
)
|
||||||
from batdetect2.train.clips import build_clipper
|
|
||||||
from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
|
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.typing import ClipperProtocol, TrainExample
|
from batdetect2.typing import (
|
||||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
AudioLoader,
|
||||||
from batdetect2.typing.train import Augmentation, ClipLabeller
|
Augmentation,
|
||||||
from batdetect2.utils.arrays import adjust_width
|
ClipLabeller,
|
||||||
|
ClipperProtocol,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
TrainExample,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingDataset",
|
"TrainingDataset",
|
||||||
@ -139,6 +147,22 @@ class ValidationDataset(Dataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainLoaderConfig(BaseConfig):
|
||||||
|
num_workers: int = 0
|
||||||
|
|
||||||
|
batch_size: int = 8
|
||||||
|
|
||||||
|
shuffle: bool = False
|
||||||
|
|
||||||
|
augmentations: AugmentationsConfig = Field(
|
||||||
|
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
|
||||||
|
)
|
||||||
|
|
||||||
|
clipping_strategy: ClipConfig = Field(
|
||||||
|
default_factory=lambda: PaddedClipConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_train_loader(
|
def build_train_loader(
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
@ -173,6 +197,14 @@ def build_train_loader(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ValLoaderConfig(BaseConfig):
|
||||||
|
num_workers: int = 0
|
||||||
|
|
||||||
|
clipping_strategy: ClipConfig = Field(
|
||||||
|
default_factory=lambda: PaddedClipConfig()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_val_loader(
|
def build_val_loader(
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
|
|||||||
@ -13,14 +13,10 @@ import torch
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import ClipLabeller, Heatmaps, TargetProtocol
|
||||||
ClipLabeller,
|
|
||||||
Heatmaps,
|
|
||||||
TargetProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LabelConfig",
|
"LabelConfig",
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Tuple
|
from typing import TYPE_CHECKING, Optional, Tuple
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
import torch
|
import torch
|
||||||
@ -6,11 +6,17 @@ from soundevent.data import PathLike
|
|||||||
from torch.optim.adam import Adam
|
from torch.optim.adam import Adam
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
|
from batdetect2.audio import TARGET_SAMPLERATE_HZ
|
||||||
from batdetect2.models import Model, build_model
|
from batdetect2.models import Model, build_model
|
||||||
from batdetect2.train.config import FullTrainingConfig
|
from batdetect2.plotting.clips import build_preprocessor
|
||||||
|
from batdetect2.postprocess import build_postprocessor
|
||||||
|
from batdetect2.targets.targets import build_targets
|
||||||
from batdetect2.train.losses import build_loss
|
from batdetect2.train.losses import build_loss
|
||||||
from batdetect2.typing import ModelOutput, TrainExample
|
from batdetect2.typing import ModelOutput, TrainExample
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from batdetect2.config import BatDetect2Config
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
]
|
]
|
||||||
@ -21,7 +27,8 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: FullTrainingConfig,
|
config: "BatDetect2Config",
|
||||||
|
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||||
learning_rate: float = 0.001,
|
learning_rate: float = 0.001,
|
||||||
t_max: int = 100,
|
t_max: int = 100,
|
||||||
model: Optional[Model] = None,
|
model: Optional[Model] = None,
|
||||||
@ -31,6 +38,7 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
self.save_hyperparameters(logger=False)
|
self.save_hyperparameters(logger=False)
|
||||||
|
|
||||||
|
self.input_samplerate = input_samplerate
|
||||||
self.config = config
|
self.config = config
|
||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.t_max = t_max
|
self.t_max = t_max
|
||||||
@ -39,7 +47,23 @@ class TrainingModule(L.LightningModule):
|
|||||||
loss = build_loss(self.config.train.loss)
|
loss = build_loss(self.config.train.loss)
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
model = build_model(self.config)
|
targets = build_targets(self.config.targets)
|
||||||
|
|
||||||
|
preprocessor = build_preprocessor(
|
||||||
|
config=self.config.preprocess,
|
||||||
|
input_samplerate=self.input_samplerate,
|
||||||
|
)
|
||||||
|
|
||||||
|
postprocessor = build_postprocessor(
|
||||||
|
preprocessor, config=self.config.postprocess
|
||||||
|
)
|
||||||
|
|
||||||
|
model = build_model(
|
||||||
|
config=self.config.model,
|
||||||
|
targets=targets,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
)
|
||||||
|
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -74,16 +98,18 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
def load_model_from_checkpoint(
|
def load_model_from_checkpoint(
|
||||||
path: PathLike,
|
path: PathLike,
|
||||||
) -> Tuple[Model, FullTrainingConfig]:
|
) -> Tuple[Model, "BatDetect2Config"]:
|
||||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||||
return module.model, module.config
|
return module.model, module.config
|
||||||
|
|
||||||
|
|
||||||
def build_training_module(
|
def build_training_module(
|
||||||
config: Optional[FullTrainingConfig] = None,
|
config: Optional["BatDetect2Config"] = None,
|
||||||
t_max: int = 200,
|
t_max: int = 200,
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
config = config or FullTrainingConfig()
|
from batdetect2.config import BatDetect2Config
|
||||||
|
|
||||||
|
config = config or BatDetect2Config()
|
||||||
return TrainingModule(
|
return TrainingModule(
|
||||||
config=config,
|
config=config,
|
||||||
learning_rate=config.train.optimizer.learning_rate,
|
learning_rate=config.train.optimizer.learning_rate,
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from loguru import logger
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
|
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|||||||
@ -1,29 +1,31 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from lightning import Trainer, seed_everything
|
from lightning import Trainer, seed_everything
|
||||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.audio import build_audio_loader
|
||||||
from batdetect2.evaluate.evaluator import build_evaluator
|
from batdetect2.evaluate.evaluator import build_evaluator
|
||||||
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
|
from batdetect2.logging import build_logger
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.train.callbacks import ValidationMetrics
|
from batdetect2.train.callbacks import ValidationMetrics
|
||||||
from batdetect2.train.config import (
|
|
||||||
FullTrainingConfig,
|
|
||||||
)
|
|
||||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.train.lightning import TrainingModule, build_training_module
|
from batdetect2.train.lightning import build_training_module
|
||||||
from batdetect2.train.logging import build_logger
|
|
||||||
from batdetect2.typing import (
|
if TYPE_CHECKING:
|
||||||
TargetProtocol,
|
from batdetect2.config import BatDetect2Config
|
||||||
)
|
from batdetect2.typing import (
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
AudioLoader,
|
||||||
from batdetect2.typing.train import ClipLabeller
|
ClipLabeller,
|
||||||
|
EvaluatorProtocol,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"build_trainer",
|
"build_trainer",
|
||||||
@ -36,13 +38,12 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
|||||||
def train(
|
def train(
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
||||||
|
targets: Optional["TargetProtocol"] = None,
|
||||||
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
|
audio_loader: Optional["AudioLoader"] = None,
|
||||||
|
labeller: Optional["ClipLabeller"] = None,
|
||||||
|
config: Optional["BatDetect2Config"] = None,
|
||||||
trainer: Optional[Trainer] = None,
|
trainer: Optional[Trainer] = None,
|
||||||
targets: Optional[TargetProtocol] = None,
|
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
|
||||||
labeller: Optional[ClipLabeller] = None,
|
|
||||||
config: Optional[FullTrainingConfig] = None,
|
|
||||||
model_path: Optional[data.PathLike] = None,
|
|
||||||
train_workers: Optional[int] = None,
|
train_workers: Optional[int] = None,
|
||||||
val_workers: Optional[int] = None,
|
val_workers: Optional[int] = None,
|
||||||
checkpoint_dir: Optional[Path] = None,
|
checkpoint_dir: Optional[Path] = None,
|
||||||
@ -51,17 +52,20 @@ def train(
|
|||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
):
|
):
|
||||||
|
from batdetect2.config import BatDetect2Config
|
||||||
|
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
seed_everything(seed)
|
seed_everything(seed)
|
||||||
|
|
||||||
config = config or FullTrainingConfig()
|
config = config or BatDetect2Config()
|
||||||
|
|
||||||
targets = targets or build_targets(config.targets)
|
targets = targets or build_targets(config=config.targets)
|
||||||
|
|
||||||
preprocessor = preprocessor or build_preprocessor(config.preprocess)
|
audio_loader = audio_loader or build_audio_loader(config=config.audio)
|
||||||
|
|
||||||
audio_loader = audio_loader or build_audio_loader(
|
preprocessor = preprocessor or build_preprocessor(
|
||||||
config=config.preprocess.audio
|
input_samplerate=audio_loader.samplerate,
|
||||||
|
config=config.preprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
labeller = labeller or build_clip_labeler(
|
labeller = labeller or build_clip_labeler(
|
||||||
@ -93,18 +97,15 @@ def train(
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_path is not None:
|
module = build_training_module(
|
||||||
logger.debug("Loading model from: {path}", path=model_path)
|
config,
|
||||||
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
t_max=config.train.optimizer.t_max * len(train_dataloader),
|
||||||
else:
|
)
|
||||||
module = build_training_module(
|
|
||||||
config,
|
|
||||||
t_max=config.train.optimizer.t_max * len(train_dataloader),
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer = trainer or build_trainer(
|
trainer = trainer or build_trainer(
|
||||||
config,
|
config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
evaluator=build_evaluator(config.train.validation, targets=targets),
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
@ -121,8 +122,8 @@ def train(
|
|||||||
|
|
||||||
|
|
||||||
def build_trainer_callbacks(
|
def build_trainer_callbacks(
|
||||||
targets: TargetProtocol,
|
targets: "TargetProtocol",
|
||||||
config: FullTrainingConfig,
|
evaluator: Optional["EvaluatorProtocol"] = None,
|
||||||
checkpoint_dir: Optional[Path] = None,
|
checkpoint_dir: Optional[Path] = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
@ -136,13 +137,12 @@ def build_trainer_callbacks(
|
|||||||
if run_name is not None:
|
if run_name is not None:
|
||||||
checkpoint_dir = checkpoint_dir / run_name
|
checkpoint_dir = checkpoint_dir / run_name
|
||||||
|
|
||||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
evaluator = evaluator or build_evaluator(targets=targets)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ModelCheckpoint(
|
ModelCheckpoint(
|
||||||
dirpath=str(checkpoint_dir),
|
dirpath=str(checkpoint_dir),
|
||||||
save_top_k=1,
|
save_top_k=1,
|
||||||
filename="best-{epoch:02d}-{val_loss:.0f}",
|
|
||||||
monitor="total_loss/val",
|
monitor="total_loss/val",
|
||||||
),
|
),
|
||||||
ValidationMetrics(evaluator),
|
ValidationMetrics(evaluator),
|
||||||
@ -150,8 +150,9 @@ def build_trainer_callbacks(
|
|||||||
|
|
||||||
|
|
||||||
def build_trainer(
|
def build_trainer(
|
||||||
conf: FullTrainingConfig,
|
conf: "BatDetect2Config",
|
||||||
targets: TargetProtocol,
|
targets: "TargetProtocol",
|
||||||
|
evaluator: Optional["EvaluatorProtocol"] = None,
|
||||||
checkpoint_dir: Optional[Path] = None,
|
checkpoint_dir: Optional[Path] = None,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Optional[Path] = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
@ -181,7 +182,7 @@ def build_trainer(
|
|||||||
logger=train_logger,
|
logger=train_logger,
|
||||||
callbacks=build_trainer_callbacks(
|
callbacks=build_trainer_callbacks(
|
||||||
targets,
|
targets,
|
||||||
config=conf,
|
evaluator=evaluator,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
|
|||||||
@ -1,4 +1,10 @@
|
|||||||
from batdetect2.typing.evaluate import MatchEvaluation, MetricsProtocol
|
from batdetect2.typing.evaluate import (
|
||||||
|
ClipEvaluation,
|
||||||
|
EvaluatorProtocol,
|
||||||
|
MatchEvaluation,
|
||||||
|
MetricsProtocol,
|
||||||
|
PlotterProtocol,
|
||||||
|
)
|
||||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||||
from batdetect2.typing.postprocess import (
|
from batdetect2.typing.postprocess import (
|
||||||
BatDetect2Prediction,
|
BatDetect2Prediction,
|
||||||
@ -9,10 +15,10 @@ from batdetect2.typing.postprocess import (
|
|||||||
from batdetect2.typing.preprocess import (
|
from batdetect2.typing.preprocess import (
|
||||||
AudioLoader,
|
AudioLoader,
|
||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
SpectrogramBuilder,
|
|
||||||
)
|
)
|
||||||
from batdetect2.typing.targets import (
|
from batdetect2.typing.targets import (
|
||||||
Position,
|
Position,
|
||||||
|
ROITargetMapper,
|
||||||
Size,
|
Size,
|
||||||
SoundEventDecoder,
|
SoundEventDecoder,
|
||||||
SoundEventEncoder,
|
SoundEventEncoder,
|
||||||
@ -34,6 +40,7 @@ __all__ = [
|
|||||||
"Augmentation",
|
"Augmentation",
|
||||||
"BackboneModel",
|
"BackboneModel",
|
||||||
"BatDetect2Prediction",
|
"BatDetect2Prediction",
|
||||||
|
"ClipEvaluation",
|
||||||
"ClipLabeller",
|
"ClipLabeller",
|
||||||
"ClipperProtocol",
|
"ClipperProtocol",
|
||||||
"DetectionModel",
|
"DetectionModel",
|
||||||
@ -44,15 +51,17 @@ __all__ = [
|
|||||||
"MatchEvaluation",
|
"MatchEvaluation",
|
||||||
"MetricsProtocol",
|
"MetricsProtocol",
|
||||||
"ModelOutput",
|
"ModelOutput",
|
||||||
|
"PlotterProtocol",
|
||||||
"Position",
|
"Position",
|
||||||
"PostprocessorProtocol",
|
"PostprocessorProtocol",
|
||||||
"PreprocessorProtocol",
|
"PreprocessorProtocol",
|
||||||
|
"ROITargetMapper",
|
||||||
"RawPrediction",
|
"RawPrediction",
|
||||||
"Size",
|
"Size",
|
||||||
"SoundEventDecoder",
|
"SoundEventDecoder",
|
||||||
"SoundEventEncoder",
|
"SoundEventEncoder",
|
||||||
"SoundEventFilter",
|
"SoundEventFilter",
|
||||||
"SpectrogramBuilder",
|
|
||||||
"TargetProtocol",
|
"TargetProtocol",
|
||||||
"TrainExample",
|
"TrainExample",
|
||||||
|
"EvaluatorProtocol",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -14,7 +14,11 @@ from typing import (
|
|||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.typing.postprocess import RawPrediction
|
||||||
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"EvaluatorProtocol",
|
||||||
"MetricsProtocol",
|
"MetricsProtocol",
|
||||||
"MatchEvaluation",
|
"MatchEvaluation",
|
||||||
]
|
]
|
||||||
@ -50,6 +54,26 @@ class MatchEvaluation:
|
|||||||
|
|
||||||
return self.pred_class_scores[pred_class]
|
return self.pred_class_scores[pred_class]
|
||||||
|
|
||||||
|
def is_true_positive(self, threshold: float = 0) -> bool:
|
||||||
|
return (
|
||||||
|
self.gt_det
|
||||||
|
and self.pred_score > threshold
|
||||||
|
and self.gt_class == self.pred_class
|
||||||
|
)
|
||||||
|
|
||||||
|
def is_false_positive(self, threshold: float = 0) -> bool:
|
||||||
|
return self.gt_det is None and self.pred_score > threshold
|
||||||
|
|
||||||
|
def is_false_negative(self, threshold: float = 0) -> bool:
|
||||||
|
return self.gt_det and self.pred_score <= threshold
|
||||||
|
|
||||||
|
def is_cross_trigger(self, threshold: float = 0) -> bool:
|
||||||
|
return (
|
||||||
|
self.gt_det
|
||||||
|
and self.pred_score > threshold
|
||||||
|
and self.gt_class != self.pred_class
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ClipEvaluation:
|
class ClipEvaluation:
|
||||||
@ -87,3 +111,21 @@ class PlotterProtocol(Protocol):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
) -> Iterable[Tuple[str, Figure]]: ...
|
) -> Iterable[Tuple[str, Figure]]: ...
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluatorProtocol(Protocol):
|
||||||
|
targets: TargetProtocol
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
self,
|
||||||
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
|
predictions: Sequence[Sequence[RawPrediction]],
|
||||||
|
) -> List[ClipEvaluation]: ...
|
||||||
|
|
||||||
|
def compute_metrics(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Dict[str, float]: ...
|
||||||
|
|
||||||
|
def generate_plots(
|
||||||
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
|
) -> Iterable[Tuple[str, Figure]]: ...
|
||||||
|
|||||||
@ -12,7 +12,7 @@ system that deal with model predictions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, NamedTuple, Optional, Protocol
|
from typing import List, NamedTuple, Optional, Protocol, Sequence
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -47,15 +47,13 @@ class GeometryDecoder(Protocol):
|
|||||||
|
|
||||||
|
|
||||||
class RawPrediction(NamedTuple):
|
class RawPrediction(NamedTuple):
|
||||||
"""Intermediate representation of a single detected sound event."""
|
|
||||||
|
|
||||||
geometry: data.Geometry
|
geometry: data.Geometry
|
||||||
detection_score: float
|
detection_score: float
|
||||||
class_scores: np.ndarray
|
class_scores: np.ndarray
|
||||||
features: np.ndarray
|
features: np.ndarray
|
||||||
|
|
||||||
|
|
||||||
class DetectionsArray(NamedTuple):
|
class ClipDetectionsArray(NamedTuple):
|
||||||
scores: np.ndarray
|
scores: np.ndarray
|
||||||
sizes: np.ndarray
|
sizes: np.ndarray
|
||||||
class_scores: np.ndarray
|
class_scores: np.ndarray
|
||||||
@ -64,7 +62,7 @@ class DetectionsArray(NamedTuple):
|
|||||||
features: np.ndarray
|
features: np.ndarray
|
||||||
|
|
||||||
|
|
||||||
class DetectionsTensor(NamedTuple):
|
class ClipDetectionsTensor(NamedTuple):
|
||||||
scores: torch.Tensor
|
scores: torch.Tensor
|
||||||
sizes: torch.Tensor
|
sizes: torch.Tensor
|
||||||
class_scores: torch.Tensor
|
class_scores: torch.Tensor
|
||||||
@ -72,8 +70,8 @@ class DetectionsTensor(NamedTuple):
|
|||||||
frequencies: torch.Tensor
|
frequencies: torch.Tensor
|
||||||
features: torch.Tensor
|
features: torch.Tensor
|
||||||
|
|
||||||
def numpy(self) -> DetectionsArray:
|
def numpy(self) -> ClipDetectionsArray:
|
||||||
return DetectionsArray(
|
return ClipDetectionsArray(
|
||||||
scores=self.scores.detach().cpu().numpy(),
|
scores=self.scores.detach().cpu().numpy(),
|
||||||
sizes=self.sizes.detach().cpu().numpy(),
|
sizes=self.sizes.detach().cpu().numpy(),
|
||||||
class_scores=self.class_scores.detach().cpu().numpy(),
|
class_scores=self.class_scores.detach().cpu().numpy(),
|
||||||
@ -92,10 +90,8 @@ class BatDetect2Prediction:
|
|||||||
class PostprocessorProtocol(Protocol):
|
class PostprocessorProtocol(Protocol):
|
||||||
"""Protocol defining the interface for the full postprocessing pipeline."""
|
"""Protocol defining the interface for the full postprocessing pipeline."""
|
||||||
|
|
||||||
def __call__(self, output: ModelOutput) -> List[DetectionsTensor]: ...
|
def __call__(
|
||||||
|
|
||||||
def get_detections(
|
|
||||||
self,
|
self,
|
||||||
output: ModelOutput,
|
output: ModelOutput,
|
||||||
start_times: Optional[List[float]] = None,
|
start_times: Optional[Sequence[float]] = None,
|
||||||
) -> List[DetectionsTensor]: ...
|
) -> List[ClipDetectionsTensor]: ...
|
||||||
|
|||||||
@ -32,6 +32,8 @@ class AudioLoader(Protocol):
|
|||||||
allows for different loading strategies or implementations.
|
allows for different loading strategies or implementations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
samplerate: int
|
||||||
|
|
||||||
def load_file(
|
def load_file(
|
||||||
self,
|
self,
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
@ -125,22 +127,6 @@ class SpectrogramBuilder(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class AudioPipeline(Protocol):
|
|
||||||
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
|
|
||||||
|
|
||||||
|
|
||||||
class SpectrogramPipeline(Protocol):
|
|
||||||
def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...
|
|
||||||
|
|
||||||
def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor: ...
|
|
||||||
|
|
||||||
def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
|
|
||||||
|
|
||||||
def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
|
|
||||||
|
|
||||||
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
|
|
||||||
|
|
||||||
|
|
||||||
class PreprocessorProtocol(Protocol):
|
class PreprocessorProtocol(Protocol):
|
||||||
"""Defines a high-level interface for the complete preprocessing pipeline."""
|
"""Defines a high-level interface for the complete preprocessing pipeline."""
|
||||||
|
|
||||||
@ -152,11 +138,13 @@ class PreprocessorProtocol(Protocol):
|
|||||||
|
|
||||||
output_samplerate: float
|
output_samplerate: float
|
||||||
|
|
||||||
audio_pipeline: AudioPipeline
|
|
||||||
|
|
||||||
spectrogram_pipeline: SpectrogramPipeline
|
|
||||||
|
|
||||||
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
|
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
def process_audio(self, wav: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
def process_numpy(self, wav: np.ndarray) -> np.ndarray:
|
def process_numpy(self, wav: np.ndarray) -> np.ndarray:
|
||||||
return self(torch.tensor(wav)).numpy()
|
return self(torch.tensor(wav)).numpy()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user