Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-09 16:59:33 +01:00)

Compare commits: 9 commits, 60e922d565 ... 4cd983a2c2
| Author | SHA1 | Date |
|---|---|---|
|  | 4cd983a2c2 |  |
|  | e65df81db2 |  |
|  | 6c25787123 |  |
|  | 8c80402f08 |  |
|  | b81a882b58 |  |
|  | 6e217380f2 |  |
|  | 957c0735d2 |  |
|  | bbb96b33a2 |  |
|  | 7d6cba5465 |  |
@@ -1,66 +1,27 @@
targets:
  detection_target:
    name: bat
    match_if:
      name: all_of
      conditions:
        - name: has_tag
          tag: { key: event, value: Echolocation }
        - name: not
          condition:
            name: has_tag
            tag: { key: class, value: Unknown }
    assign_tags:
      - key: class
        value: Bat

  classification_targets:
    - name: myomys
      tags:
        - key: class
          value: Myotis mystacinus
    - name: pippip
      tags:
        - key: class
          value: Pipistrellus pipistrellus
    - name: eptser
      tags:
        - key: class
          value: Eptesicus serotinus
    - name: rhifer
      tags:
        - key: class
          value: Rhinolophus ferrumequinum

  roi:
    name: anchor_bbox
    anchor: top-left

audio:
  samplerate: 256000
  resample:
    enabled: True
    method: "poly"

preprocess:
  audio:
    samplerate: 256000
    resample:
      enabled: True
      method: "poly"

  spectrogram:
    stft:
      window_duration: 0.002
      window_overlap: 0.75
      window_fn: hann
    frequencies:
      max_freq: 120000
      min_freq: 10000
    size:
      height: 128
      resize_factor: 0.5
    transforms:
      - name: pcen
        time_constant: 0.1
        gain: 0.98
        bias: 2
        power: 0.5
      - name: spectral_mean_substraction
  stft:
    window_duration: 0.002
    window_overlap: 0.75
    window_fn: hann
  frequencies:
    max_freq: 120000
    min_freq: 10000
  size:
    height: 128
    resize_factor: 0.5
  spectrogram_transforms:
    - name: pcen
      time_constant: 0.1
      gain: 0.98
      bias: 2
      power: 0.5
    - name: spectral_mean_substraction

postprocess:
  nms_kernel_size: 9
@@ -102,23 +63,57 @@ model:
  out_channels: 32

train:
  learning_rate: 0.001
  t_max: 100
  optimizer:
    learning_rate: 0.001
    t_max: 100

  labels:
    sigma: 3

  trainer:
    max_epochs: 5
    max_epochs: 10
    check_val_every_n_epoch: 5

  train_loader:
    batch_size: 8

    num_workers: 2

    shuffle: True

    clipping_strategy:
      name: random_subclip
      duration: 0.256

    augmentations:
      enabled: true
      audio:
        - name: mix_audio
          probability: 0.2
          min_weight: 0.3
          max_weight: 0.7
        - name: add_echo
          probability: 0.2
          max_delay: 0.005
          min_weight: 0.0
          max_weight: 1.0
      spectrogram:
        - name: scale_volume
          probability: 0.2
          min_scaling: 0.0
          max_scaling: 2.0
        - name: warp
          probability: 0.2
          delta: 0.04
        - name: mask_time
          probability: 0.2
          max_perc: 0.05
          max_masks: 3
        - name: mask_freq
          probability: 0.2
          max_perc: 0.10
          max_masks: 3

  val_loader:
    num_workers: 2
    clipping_strategy:
@@ -142,31 +137,28 @@ train:
  logger:
    name: csv

  augmentations:
    enabled: true
    audio:
      - name: mix_audio
        probability: 0.2
        min_weight: 0.3
        max_weight: 0.7
      - name: add_echo
        probability: 0.2
        max_delay: 0.005
        min_weight: 0.0
        max_weight: 1.0
    spectrogram:
      - name: scale_volume
        probability: 0.2
        min_scaling: 0.0
        max_scaling: 2.0
      - name: warp
        probability: 0.2
        delta: 0.04
      - name: mask_time
        probability: 0.2
        max_perc: 0.05
        max_masks: 3
      - name: mask_freq
        probability: 0.2
        max_perc: 0.10
        max_masks: 3
  validation:
    metrics:
      - name: detection_ap
      - name: detection_roc_auc
      - name: classification_ap
      - name: classification_roc_auc
      - name: top_class_ap
      - name: classification_balanced_accuracy
      - name: clip_ap
      - name: clip_roc_auc

evaluation:
  match_strategy:
    name: start_time_match
    distance_threshold: 0.01
  metrics:
    - name: classification_ap
    - name: detection_ap
  plots:
    - name: example_gallery
    - name: example_clip
    - name: detection_pr_curve
    - name: classification_pr_curves
    - name: detection_roc_curve
    - name: classification_roc_curves
example_data/targets.yaml (new file, 36 lines)
@@ -0,0 +1,36 @@
detection_target:
  name: bat
  match_if:
    name: all_of
    conditions:
      - name: has_tag
        tag: { key: event, value: Echolocation }
      - name: not
        condition:
          name: has_tag
          tag: { key: class, value: Unknown }
  assign_tags:
    - key: class
      value: Bat

classification_targets:
  - name: myomys
    tags:
      - key: class
        value: Myotis mystacinus
  - name: pippip
    tags:
      - key: class
        value: Pipistrellus pipistrellus
  - name: eptser
    tags:
      - key: class
        value: Eptesicus serotinus
  - name: rhifer
    tags:
      - key: class
        value: Rhinolophus ferrumequinum

roi:
  name: anchor_bbox
  anchor: top-left
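This file is the standalone targets definition that the example config previously embedded under `targets:`. A minimal sketch of consuming it, using the `load_target_config` and `build_targets` helpers that the updated CLI and API modules import:

```python
from batdetect2.targets import load_target_config
from batdetect2.targets.targets import build_targets

# Parse example_data/targets.yaml into a TargetConfig and build the
# target mapper used by training and evaluation.
target_config = load_target_config("example_data/targets.yaml")
targets = build_targets(config=target_config)

# The classification target names defined above become the model classes.
print(targets.class_names)  # expected: myomys, pippip, eptser, rhifer
```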
src/batdetect2/api/base.py (new file, 179 lines)
@@ -0,0 +1,179 @@
from pathlib import Path
from typing import Optional, Sequence

from soundevent import data

from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.evaluate import build_evaluator, evaluate
from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.train import train
from batdetect2.train.lightning import load_model_from_checkpoint
from batdetect2.typing import (
    AudioLoader,
    EvaluatorProtocol,
    PostprocessorProtocol,
    PreprocessorProtocol,
    TargetProtocol,
)


class BatDetect2API:
    def __init__(
        self,
        config: BatDetect2Config,
        targets: TargetProtocol,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        postprocessor: PostprocessorProtocol,
        evaluator: EvaluatorProtocol,
        model: Model,
    ):
        self.config = config
        self.targets = targets
        self.audio_loader = audio_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        self.evaluator = evaluator
        self.model = model

        self.model.eval()

    def train(
        self,
        train_annotations: Sequence[data.ClipAnnotation],
        val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
        train_workers: Optional[int] = None,
        val_workers: Optional[int] = None,
        checkpoint_dir: Optional[Path] = None,
        log_dir: Optional[Path] = None,
        experiment_name: Optional[str] = None,
        run_name: Optional[str] = None,
        seed: Optional[int] = None,
    ):
        train(
            train_annotations=train_annotations,
            val_annotations=val_annotations,
            targets=self.targets,
            config=self.config,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            train_workers=train_workers,
            val_workers=val_workers,
            checkpoint_dir=checkpoint_dir,
            log_dir=log_dir,
            experiment_name=experiment_name,
            run_name=run_name,
            seed=seed,
        )
        return self

    def evaluate(
        self,
        test_annotations: Sequence[data.ClipAnnotation],
        num_workers: Optional[int] = None,
        output_dir: data.PathLike = ".",
        experiment_name: Optional[str] = None,
        run_name: Optional[str] = None,
    ):
        return evaluate(
            self.model,
            test_annotations,
            targets=self.targets,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            config=self.config,
            num_workers=num_workers,
            output_dir=output_dir,
            experiment_name=experiment_name,
            run_name=run_name,
        )

    @classmethod
    def from_config(cls, config: BatDetect2Config):
        targets = build_targets(config=config.targets)

        audio_loader = build_audio_loader(config=config.audio)

        preprocessor = build_preprocessor(
            input_samplerate=audio_loader.samplerate,
            config=config.preprocess,
        )

        postprocessor = build_postprocessor(
            preprocessor,
            config=config.postprocess,
        )

        evaluator = build_evaluator(
            config=config.evaluation,
            targets=targets,
        )

        # NOTE: Better to have a separate instance of
        # preprocessor and postprocessor as these may be moved
        # to another device.
        model = build_model(
            config=config.model,
            targets=targets,
            preprocessor=build_preprocessor(
                input_samplerate=audio_loader.samplerate,
                config=config.preprocess,
            ),
            postprocessor=build_postprocessor(
                preprocessor,
                config=config.postprocess,
            ),
        )

        return cls(
            config=config,
            targets=targets,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            evaluator=evaluator,
            model=model,
        )

    @classmethod
    def from_checkpoint(
        cls,
        path: data.PathLike,
        config: Optional[BatDetect2Config] = None,
    ):
        model, stored_config = load_model_from_checkpoint(path)

        config = config or stored_config

        targets = build_targets(config=config.targets)

        audio_loader = build_audio_loader(config=config.audio)

        preprocessor = build_preprocessor(
            input_samplerate=audio_loader.samplerate,
            config=config.preprocess,
        )

        postprocessor = build_postprocessor(
            preprocessor,
            config=config.postprocess,
        )

        evaluator = build_evaluator(
            config=config.evaluation,
            targets=targets,
        )

        return cls(
            config=config,
            targets=targets,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            evaluator=evaluator,
            model=model,
        )
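The class above is the new single entry point added by this changeset. A minimal usage sketch, assuming a YAML file compatible with `load_full_config` and dataset descriptors loadable with `load_dataset_from_config` (the file names here are hypothetical):

```python
from batdetect2.api.base import BatDetect2API
from batdetect2.config import load_full_config
from batdetect2.data import load_dataset_from_config

# Hypothetical paths; any config matching BatDetect2Config works.
config = load_full_config("config.yaml")
train_annotations = load_dataset_from_config("train_dataset.yaml")
test_annotations = load_dataset_from_config("test_dataset.yaml")

api = BatDetect2API.from_config(config)
api.train(train_annotations=train_annotations, seed=42)
api.evaluate(test_annotations, output_dir="outputs/evaluation")

# Or resume from a checkpoint, as the updated CLI commands do:
api = BatDetect2API.from_checkpoint("checkpoint.ckpt")
```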
src/batdetect2/audio/__init__.py (new file, 16 lines)
@@ -0,0 +1,16 @@
from batdetect2.audio.clips import ClipConfig, build_clipper
from batdetect2.audio.loader import (
    TARGET_SAMPLERATE_HZ,
    AudioConfig,
    SoundEventAudioLoader,
    build_audio_loader,
)

__all__ = [
    "TARGET_SAMPLERATE_HZ",
    "AudioConfig",
    "SoundEventAudioLoader",
    "build_audio_loader",
    "ClipConfig",
    "build_clipper",
]
@@ -6,14 +6,19 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds, intervals_overlap

from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipperProtocol

DEFAULT_TRAIN_CLIP_DURATION = 0.256
DEFAULT_MAX_EMPTY_CLIP = 0.1


__all__ = [
    "build_clipper",
    "ClipConfig",
]


clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")
src/batdetect2/audio/loader.py (new file, 295 lines)
@@ -0,0 +1,295 @@
from typing import Optional

import numpy as np
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError

from batdetect2.core import BaseConfig
from batdetect2.typing import AudioLoader

__all__ = [
    "SoundEventAudioLoader",
    "build_audio_loader",
    "load_file_audio",
    "load_recording_audio",
    "load_clip_audio",
    "resample_audio",
]

TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""


class ResampleConfig(BaseConfig):
    """Configuration for audio resampling.

    Attributes
    ----------
    enabled : bool, default=True
        Whether to resample loaded audio to the target sample rate.
    method : str, default="poly"
        The resampling algorithm to use. Options:
        - "poly": Polyphase resampling using `scipy.signal.resample_poly`.
          Generally fast.
        - "fourier": Resampling via Fourier method using
          `scipy.signal.resample`. May handle non-integer
          resampling factors differently.
    """

    enabled: bool = True
    method: str = "poly"
class AudioConfig(BaseConfig):
    """Configuration for loading and initial audio preprocessing."""

    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
    resample: ResampleConfig = Field(default_factory=ResampleConfig)


def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
    """Factory function to create an AudioLoader based on configuration."""
    config = config or AudioConfig()
    return SoundEventAudioLoader(
        samplerate=config.samplerate,
        config=config.resample,
    )


class SoundEventAudioLoader(AudioLoader):
    """Concrete implementation of the `AudioLoader`."""

    def __init__(
        self,
        samplerate: int = TARGET_SAMPLERATE_HZ,
        config: Optional[ResampleConfig] = None,
    ):
        self.samplerate = samplerate
        self.config = config or ResampleConfig()

    def load_file(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess audio directly from a file path."""
        return load_file_audio(
            path,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )

    def load_recording(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess the entire audio for a Recording object."""
        return load_recording_audio(
            recording,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )

    def load_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess the audio segment defined by a Clip object."""
        return load_clip_audio(
            clip,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )


def load_file_audio(
    path: data.PathLike,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess audio from a file path using specified config."""
    try:
        recording = data.Recording.from_file(path)
    except LibsndfileError as e:
        raise FileNotFoundError(
            f"Could not load the recording at path: {path}. Error: {e}"
        ) from e

    return load_recording_audio(
        recording,
        samplerate=samplerate,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )


def load_recording_audio(
    recording: data.Recording,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess the entire audio content of a recording using config."""
    clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    return load_clip_audio(
        clip,
        samplerate=samplerate,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )


def load_clip_audio(
    clip: data.Clip,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess a specific audio clip segment based on config."""
    try:
        wav = (
            audio.load_clip(clip, audio_dir=audio_dir)
            .sel(channel=0)
            .astype(dtype)
        )
    except LibsndfileError as e:
        raise FileNotFoundError(
            f"Could not load the recording at path: {clip.recording.path}. "
            f"Error: {e}"
        ) from e

    if not config or not config.enabled or samplerate is None:
        return wav.data.astype(dtype)

    sr = int(1 / wav.time.attrs["step"])
    return resample_audio(
        wav.data,
        sr=sr,
        samplerate=samplerate,
        method=config.method,
    )


def resample_audio(
    wav: np.ndarray,
    sr: int,
    samplerate: int = TARGET_SAMPLERATE_HZ,
    method: str = "poly",
) -> np.ndarray:
    """Resample an audio waveform to a target sample rate."""
    if sr == samplerate:
        return wav

    if method == "poly":
        return resample_audio_poly(
            wav,
            sr_orig=sr,
            sr_new=samplerate,
        )
    elif method == "fourier":
        return resample_audio_fourier(
            wav,
            sr_orig=sr,
            sr_new=samplerate,
        )
    else:
        raise NotImplementedError(
            f"Resampling method '{method}' not implemented"
        )


def resample_audio_poly(
    array: np.ndarray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    """Resample a numpy array using `scipy.signal.resample_poly`.

    This method is often preferred for signals when the ratio of new
    to old sample rates can be expressed as a rational number. It uses
    polyphase filtering.

    Parameters
    ----------
    array : np.ndarray
        The input array to resample.
    sr_orig : int
        The original sample rate in Hz.
    sr_new : int
        The target sample rate in Hz.
    axis : int, default=-1
        The axis of `array` along which to resample.

    Returns
    -------
    np.ndarray
        The array resampled to the target sample rate.

    Raises
    ------
    ValueError
        If sample rates are not positive.
    """
    gcd = np.gcd(sr_orig, sr_new)
    return resample_poly(
        array,
        sr_new // gcd,
        sr_orig // gcd,
        axis=axis,
    )
def resample_audio_fourier(
    array: np.ndarray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    """Resample a numpy array using `scipy.signal.resample`.

    This method uses FFTs to resample the signal.

    Parameters
    ----------
    array : np.ndarray
        The input array to resample.
    sr_orig : int
        The original sample rate in Hz.
    sr_new : int
        The target sample rate in Hz.
    axis : int, default=-1
        The axis of `array` along which to resample.

    Returns
    -------
    np.ndarray
        The array resampled to `int(array.shape[axis] * sr_new / sr_orig)`
        samples along `axis`.
    """
    ratio = sr_new / sr_orig
    return resample(  # type: ignore
        array,
        int(array.shape[axis] * ratio),
        axis=axis,
    )
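A quick self-contained check of the polyphase path above: reducing the rate ratio by its gcd keeps the `resample_poly` up/down factors small. This sketch uses scipy directly rather than the module's wrappers:

```python
import numpy as np
from scipy.signal import resample_poly

sr_orig, sr_new = 384_000, 256_000        # e.g. recorder rate -> model rate
gcd = np.gcd(sr_orig, sr_new)             # 128_000
up, down = sr_new // gcd, sr_orig // gcd  # resample by 2/3

wav = np.random.randn(sr_orig)            # one second of noise
resampled = resample_poly(wav, up, down)
print(len(resampled))                     # 256000 samples
```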
@@ -1,10 +1,15 @@
import os

import click

from batdetect2 import api
from batdetect2.cli.base import cli
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import ProcessingConfiguration
from batdetect2.utils.detector_utils import save_results_to_file

DEFAULT_MODEL_PATH = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "models",
    "checkpoints",
    "Net2DFast_UK_same.pth.tar",
)


@cli.command()

@@ -74,6 +79,9 @@ def detect(

    Input files should be short in duration e.g. < 30 seconds.
    """
    from batdetect2 import api
    from batdetect2.utils.detector_utils import save_results_to_file

    click.echo(f"Loading model: {args['model_path']}")
    model, params = api.load_model(args["model_path"])

@@ -123,7 +131,7 @@ def detect(
    click.echo(f" {err}")


def print_config(config: ProcessingConfiguration):
def print_config(config):
    """Print the processing configuration."""
    click.echo("\nProcessing Configuration:")
    click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
@@ -4,7 +4,6 @@ from typing import Optional

import click

from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config

__all__ = ["data"]

@@ -33,6 +32,8 @@ def summary(
    field: Optional[str] = None,
    base_dir: Optional[Path] = None,
):
    from batdetect2.data import load_dataset_from_config

    base_dir = base_dir or Path.cwd()
    dataset = load_dataset_from_config(
        dataset_config,
@@ -6,18 +6,21 @@ import click
from loguru import logger

from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.train.lightning import load_model_from_checkpoint

__all__ = ["evaluate_command"]


DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"


@cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--output-dir", type=click.Path())
@click.option("--workers", type=int)
@click.option("--config", "config_path", type=click.Path())
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
@click.option("--experiment-name", type=str)
@click.option("--run-name", type=str)
@click.option("--workers", "num_workers", type=int)
@click.option(
    "-v",
    "--verbose",

@@ -27,10 +30,17 @@ __all__ = ["evaluate_command"]
def evaluate_command(
    model_path: Path,
    test_dataset: Path,
    output_dir: Optional[Path] = None,
    workers: Optional[int] = None,
    config_path: Optional[Path],
    output_dir: Path = DEFAULT_OUTPUT_DIR,
    num_workers: Optional[int] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
    verbose: int = 0,
):
    from batdetect2.api.base import BatDetect2API
    from batdetect2.config import load_full_config
    from batdetect2.data import load_dataset_from_config

    logger.remove()
    if verbose == 0:
        log_level = "WARNING"

@@ -48,16 +58,16 @@ def evaluate_command(
        num_annotations=len(test_annotations),
    )

    model, train_config = load_model_from_checkpoint(model_path)
    config = None
    if config_path is not None:
        config = load_full_config(config_path)

    df, results = evaluate(
        model,
    api = BatDetect2API.from_checkpoint(model_path, config=config)

    api.evaluate(
        test_annotations,
        config=train_config,
        num_workers=workers,
        num_workers=num_workers,
        output_dir=output_dir,
        experiment_name=experiment_name,
        run_name=run_name,
    )

    print(results)

    if output_dir:
        df.to_csv(output_dir / "results.csv")
@@ -6,13 +6,6 @@ import click
from loguru import logger

from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
from batdetect2.train import (
    FullTrainingConfig,
    load_full_training_config,
    train,
)

__all__ = ["train_command"]

@@ -20,8 +13,8 @@ __all__ = ["train_command"]
@cli.command(name="train")
@click.argument("train_dataset", type=click.Path(exists=True))
@click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True))
@click.option("--targets", type=click.Path(exists=True))
@click.option("--model", "model_path", type=click.Path(exists=True))
@click.option("--targets", "targets_config", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))

@@ -44,7 +37,7 @@ def train_command(
    ckpt_dir: Optional[Path] = None,
    log_dir: Optional[Path] = None,
    config: Optional[Path] = None,
    targets: Optional[Path] = None,
    targets_config: Optional[Path] = None,
    config_field: Optional[str] = None,
    seed: Optional[int] = None,
    train_workers: int = 0,

@@ -53,6 +46,14 @@ def train_command(
    run_name: Optional[str] = None,
    verbose: int = 0,
):
    from batdetect2.api.base import BatDetect2API
    from batdetect2.config import (
        BatDetect2Config,
        load_full_config,
    )
    from batdetect2.data import load_dataset_from_config
    from batdetect2.targets import load_target_config

    logger.remove()
    if verbose == 0:
        log_level = "WARNING"

@@ -61,21 +62,20 @@ def train_command(
    else:
        log_level = "DEBUG"
    logger.add(sys.stderr, level=log_level)

    logger.info("Initiating training process...")

    logger.info("Loading training configuration...")

    logger.info("Loading configuration...")
    conf = (
        load_full_training_config(config, field=config_field)
        load_full_config(config, field=config_field)
        if config is not None
        else FullTrainingConfig()
        else BatDetect2Config()
    )

    if targets is not None:
    if targets_config is not None:
        logger.info("Loading targets configuration...")
        targets_config = load_target_config(targets)
        conf = conf.model_copy(update=dict(targets=targets_config))
        conf = conf.model_copy(
            update=dict(targets=load_target_config(targets_config))
        )

    logger.info("Loading training dataset...")
    train_annotations = load_dataset_from_config(train_dataset)

@@ -95,16 +95,20 @@ def train_command(
        logger.debug("No validation directory provided.")

    logger.info("Configuration and data loaded. Starting training...")
    train(

    if model_path is None:
        api = BatDetect2API.from_config(conf)
    else:
        api = BatDetect2API.from_checkpoint(model_path)

    return api.train(
        train_annotations=train_annotations,
        val_annotations=val_annotations,
        config=conf,
        model_path=model_path,
        train_workers=train_workers,
        val_workers=val_workers,
        experiment_name=experiment_name,
        log_dir=log_dir,
        checkpoint_dir=ckpt_dir,
        seed=seed,
        log_dir=log_dir,
        experiment_name=experiment_name,
        run_name=run_name,
        seed=seed,
    )
@@ -11,7 +11,6 @@ from soundevent import data
from soundevent.geometry import compute_bounds
from soundevent.types import ClassMapper

from batdetect2.targets.terms import get_term_from_key
from batdetect2.types import (
    Annotation,
    AudioLoaderAnnotationGroup,

@@ -173,18 +172,9 @@ def annotation_to_sound_event_annotation(
        uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
        sound_event=sound_event,
        tags=[
            data.Tag(
                term=get_term_from_key(label_key),
                value=annotation["class"],
            ),
            data.Tag(
                term=get_term_from_key(event_key),
                value=annotation["event"],
            ),
            data.Tag(
                term=get_term_from_key(individual_key),
                value=str(annotation["individual"]),
            ),
            data.Tag(key=label_key, value=annotation["class"]),
            data.Tag(key=event_key, value=annotation["event"]),
            data.Tag(key=individual_key, value=str(annotation["individual"])),
        ],
    )

@@ -219,17 +209,11 @@ def annotation_to_sound_event_prediction(
        tags=[
            data.PredictedTag(
                score=annotation["class_prob"],
                tag=data.Tag(
                    term=get_term_from_key(label_key),
                    value=annotation["class"],
                ),
                tag=data.Tag(key=label_key, value=annotation["class"]),
            ),
            data.PredictedTag(
                score=annotation["det_prob"],
                tag=data.Tag(
                    term=get_term_from_key(event_key),
                    value=annotation["event"],
                ),
                tag=data.Tag(key=event_key, value=annotation["event"]),
            ),
        ],
    )
src/batdetect2/config.py (new file, 40 lines)
@@ -0,0 +1,40 @@
from typing import Literal, Optional

from pydantic import Field
from soundevent.data import PathLike

from batdetect2.audio import AudioConfig
from batdetect2.core import BaseConfig
from batdetect2.core.configs import load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.models.config import BackboneConfig
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.targets.config import TargetConfig
from batdetect2.train.config import TrainingConfig

__all__ = [
    "BatDetect2Config",
    "load_full_config",
]


class BatDetect2Config(BaseConfig):
    config_version: Literal["v1"] = "v1"

    train: TrainingConfig = Field(default_factory=TrainingConfig)
    evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
    model: BackboneConfig = Field(default_factory=BackboneConfig)
    preprocess: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
    audio: AudioConfig = Field(default_factory=AudioConfig)
    targets: TargetConfig = Field(default_factory=TargetConfig)


def load_full_config(
    path: PathLike,
    field: Optional[str] = None,
) -> BatDetect2Config:
    return load_config(path, schema=BatDetect2Config, field=field)
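With this file, one `BatDetect2Config` root ties together the targets, audio, preprocessing, model, train, and evaluation sections seen in the example config at the top of this compare. A small loading sketch; the file name and the nested `field` key are hypothetical:

```python
from batdetect2.config import load_full_config

config = load_full_config("config.yaml")  # hypothetical path
print(config.audio.samplerate)            # 256000 with the example config

# `field` selects a sub-key when the config is nested inside a larger
# YAML document, e.g. under a top-level "batdetect2:" section.
nested = load_full_config("experiment.yaml", field="batdetect2")
```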
src/batdetect2/core/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry

__all__ = [
    "BaseConfig",
    "load_config",
    "Registry",
]
@@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
    and serialization capabilities.
    """

    model_config = ConfigDict(extra="ignore")
    model_config = ConfigDict(extra="forbid")

    def to_yaml_string(
        self,
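Switching `BaseConfig` from `extra="ignore"` to `extra="forbid"` means misspelled or stale keys in a YAML config now raise a validation error instead of being silently dropped. A self-contained pydantic illustration of the difference:

```python
from pydantic import BaseModel, ConfigDict, ValidationError

class StrictConfig(BaseModel):
    model_config = ConfigDict(extra="forbid")
    samplerate: int = 256_000

StrictConfig(samplerate=192_000)  # fine
try:
    StrictConfig(samplerate=192_000, sample_rate=192_000)  # typo key
except ValidationError as err:
    print(err.errors()[0]["type"])  # "extra_forbidden"
```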
@@ -1,7 +1,13 @@
import sys
from typing import Generic, Protocol, Type, TypeVar

from pydantic import BaseModel
from typing_extensions import ParamSpec
from typing_extensions import assert_type

if sys.version_info >= (3, 10):
    from typing import ParamSpec
else:
    from typing_extensions import ParamSpec

__all__ = [
    "Registry",
@@ -39,7 +45,6 @@ class Registry(Generic[T_Type, P_Type]):
        config_cls: Type[T_Config],
        logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
    ) -> None:
        """A decorator factory to register a new item."""
        fields = config_cls.model_fields

        if "name" not in fields:
@@ -18,7 +18,7 @@ from uuid import uuid5
from pydantic import Field
from soundevent import data, io

from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.types import AnnotatedDataset

__all__ = [

@@ -33,7 +33,7 @@ from loguru import logger
from pydantic import Field, ValidationError
from soundevent import data

from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.legacy import (
    FileAnnotation,
    file_annotation_to_clip,

@@ -1,6 +1,6 @@
from pathlib import Path

from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig

__all__ = [
    "AnnotatedDataset",

@@ -5,8 +5,8 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds

from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry

SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]

@@ -25,7 +25,7 @@ from loguru import logger
from pydantic import Field
from soundevent import data, io

from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.data.annotations import (
    AnnotatedDataset,
    AnnotationFormats,

@@ -4,8 +4,8 @@ from typing import Annotated, Dict, List, Literal, Optional, Union
from pydantic import Field
from soundevent import data

from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.data.conditions import (
    SoundEventCondition,
    SoundEventConditionConfig,

@@ -1,9 +1,11 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator

__all__ = [
    "EvaluationConfig",
    "load_evaluation_config",
    "evaluate",
    "Evaluator",
    "build_evaluator",
]

@@ -4,8 +4,8 @@ from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity

from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.typing.evaluate import AffinityFunction

affinity_functions: Registry[AffinityFunction, []] = Registry(

@@ -3,14 +3,15 @@ from typing import List, Optional
from pydantic import Field
from soundevent import data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
from batdetect2.evaluate.metrics import (
    ClassificationAPConfig,
    DetectionAPConfig,
    MetricConfig,
)
from batdetect2.evaluate.plots import ExampleGalleryConfig, PlotConfig
from batdetect2.evaluate.plots import PlotConfig
from batdetect2.logging import CSVLoggerConfig, LoggerConfig

__all__ = [
    "EvaluationConfig",

@@ -20,18 +21,15 @@ __all__ = [

class EvaluationConfig(BaseConfig):
    ignore_start_end: float = 0.01
    match: MatchConfig = Field(default_factory=StartTimeMatchConfig)
    match_strategy: MatchConfig = Field(default_factory=StartTimeMatchConfig)
    metrics: List[MetricConfig] = Field(
        default_factory=lambda: [
            DetectionAPConfig(),
            ClassificationAPConfig(),
        ]
    )
    plots: List[PlotConfig] = Field(
        default_factory=lambda: [
            ExampleGalleryConfig(),
        ]
    )
    plots: List[PlotConfig] = Field(default_factory=list)
    logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)


def load_evaluation_config(
src/batdetect2/evaluate/dataset.py (new file, 144 lines)
@@ -0,0 +1,144 @@
from typing import List, NamedTuple, Optional, Sequence

import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader, Dataset

from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import (
    AudioLoader,
    ClipperProtocol,
    PreprocessorProtocol,
)

__all__ = [
    "TestDataset",
    "build_test_dataset",
    "build_test_loader",
]


class TestExample(NamedTuple):
    spec: torch.Tensor
    idx: torch.Tensor
    start_time: torch.Tensor
    end_time: torch.Tensor


class TestDataset(Dataset[TestExample]):
    clip_annotations: List[data.ClipAnnotation]

    def __init__(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        clipper: Optional[ClipperProtocol] = None,
        audio_dir: Optional[data.PathLike] = None,
    ):
        self.clip_annotations = list(clip_annotations)
        self.clipper = clipper
        self.preprocessor = preprocessor
        self.audio_loader = audio_loader
        self.audio_dir = audio_dir

    def __len__(self):
        return len(self.clip_annotations)

    def __getitem__(self, idx: int) -> TestExample:
        clip_annotation = self.clip_annotations[idx]

        if self.clipper is not None:
            clip_annotation = self.clipper(clip_annotation)

        clip = clip_annotation.clip
        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
        wav_tensor = torch.tensor(wav).unsqueeze(0)
        spectrogram = self.preprocessor(wav_tensor)
        return TestExample(
            spec=spectrogram,
            idx=torch.tensor(idx),
            start_time=torch.tensor(clip.start_time),
            end_time=torch.tensor(clip.end_time),
        )


class TestLoaderConfig(BaseConfig):
    num_workers: int = 0
    clipping_strategy: ClipConfig = Field(
        default_factory=lambda: PaddedClipConfig()
    )


def build_test_loader(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[TestLoaderConfig] = None,
    num_workers: Optional[int] = None,
) -> DataLoader[TestExample]:
    logger.info("Building test data loader...")
    config = config or TestLoaderConfig()
    logger.opt(lazy=True).debug(
        "Test data loader config: \n{config}",
        config=lambda: config.to_yaml_string(exclude_none=True),
    )

    test_dataset = build_test_dataset(
        clip_annotations,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        config=config,
    )

    num_workers = num_workers or config.num_workers
    return DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=_collate_fn,
    )


def build_test_dataset(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[TestLoaderConfig] = None,
) -> TestDataset:
    logger.info("Building test dataset...")
    config = config or TestLoaderConfig()

    clipper = build_clipper(config=config.clipping_strategy)

    if audio_loader is None:
        audio_loader = build_audio_loader()

    if preprocessor is None:
        preprocessor = build_preprocessor()

    return TestDataset(
        clip_annotations,
        audio_loader=audio_loader,
        clipper=clipper,
        preprocessor=preprocessor,
    )


def _collate_fn(batch: List[TestExample]) -> TestExample:
    max_width = max(item.spec.shape[-1] for item in batch)
    return TestExample(
        spec=torch.stack(
            [adjust_width(item.spec, max_width) for item in batch]
        ),
        idx=torch.stack([item.idx for item in batch]),
        start_time=torch.stack([item.start_time for item in batch]),
        end_time=torch.stack([item.end_time for item in batch]),
    )
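`_collate_fn` pads every spectrogram in a batch to the widest one before stacking. A self-contained sketch of the same idea, with `torch.nn.functional.pad` standing in for the repo's `adjust_width` helper:

```python
import torch
import torch.nn.functional as F

def pad_to_width(spec: torch.Tensor, width: int) -> torch.Tensor:
    # Right-pad the last (time) axis with zeros up to `width`.
    return F.pad(spec, (0, width - spec.shape[-1]))

specs = [torch.rand(1, 128, 90), torch.rand(1, 128, 128)]
max_width = max(s.shape[-1] for s in specs)
batch = torch.stack([pad_to_width(s, max_width) for s in specs])
print(batch.shape)  # torch.Size([2, 1, 128, 128])
```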
@@ -1,92 +1,68 @@
from typing import List, Optional, Tuple
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Sequence

import pandas as pd
from lightning import Trainer
from soundevent import data

from batdetect2.evaluate.dataframe import extract_matches_dataframe
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger
from batdetect2.models import Model
from batdetect2.plotting.clips import build_audio_loader
from batdetect2.postprocess import get_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.train import build_val_loader

if TYPE_CHECKING:
    from batdetect2.config import BatDetect2Config
    from batdetect2.typing import (
        AudioLoader,
        PreprocessorProtocol,
        TargetProtocol,
    )

DEFAULT_OUTPUT_DIR: Path = Path("outputs") / "evaluations"


def evaluate(
    model: Model,
    test_annotations: List[data.ClipAnnotation],
    config: Optional[FullTrainingConfig] = None,
    test_annotations: Sequence[data.ClipAnnotation],
    targets: Optional["TargetProtocol"] = None,
    audio_loader: Optional["AudioLoader"] = None,
    preprocessor: Optional["PreprocessorProtocol"] = None,
    config: Optional["BatDetect2Config"] = None,
    num_workers: Optional[int] = None,
) -> Tuple[pd.DataFrame, dict]:
    config = config or FullTrainingConfig()
    output_dir: data.PathLike = DEFAULT_OUTPUT_DIR,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
):
    from batdetect2.config import BatDetect2Config

    audio_loader = build_audio_loader(config.preprocess.audio)
    config = config or BatDetect2Config()

    preprocessor = build_preprocessor(config.preprocess)
    audio_loader = audio_loader or build_audio_loader()

    targets = build_targets(config.targets)

    labeller = build_clip_labeler(
        targets,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        config=config.train.labels,
    preprocessor = preprocessor or build_preprocessor(
        input_samplerate=audio_loader.samplerate,
    )

    loader = build_val_loader(
    targets = targets or build_targets()

    loader = build_test_loader(
        test_annotations,
        audio_loader=audio_loader,
        labeller=labeller,
        preprocessor=preprocessor,
        config=config.train.val_loader,
        num_workers=num_workers,
    )

    dataset: ValidationDataset = loader.dataset  # type: ignore
    evaluator = build_evaluator(config=config.evaluation, targets=targets)

    clip_annotations = []
    predictions = []

    evaluator = build_evaluator(config=config.evaluation)

    for batch in loader:
        outputs = model.detector(batch.spec)

        clip_annotations = [
            dataset.clip_annotations[int(example_idx)]
            for example_idx in batch.idx
        ]

        predictions = get_raw_predictions(
            outputs,
            start_times=[
                clip_annotation.clip.start_time
                for clip_annotation in clip_annotations
            ],
            targets=targets,
            postprocessor=model.postprocessor,
        )

        clip_annotations.extend(clip_annotations)
        predictions.extend(predictions)

    matches = evaluator.evaluate(clip_annotations, predictions)
    df = extract_matches_dataframe(matches)

    metrics = [
        DetectionAP(),
        ClassificationAP(class_names=targets.class_names),
    ]

    results = {
        name: value
        for metric in metrics
        for name, value in metric(matches).items()
    }

    return df, results
    logger = build_logger(
        config.evaluation.logger,
        log_dir=Path(output_dir),
        experiment_name=experiment_name,
        run_name=run_name,
    )
    module = EvaluationModule(model, evaluator)
    trainer = Trainer(logger=logger, enable_checkpointing=False)
    return trainer.test(module, loader)
@@ -11,6 +11,7 @@ from batdetect2.evaluate.plots import build_plotter
from batdetect2.targets import build_targets
from batdetect2.typing.evaluate import (
    ClipEvaluation,
    EvaluatorProtocol,
    MatcherProtocol,
    MetricsProtocol,
    PlotterProtocol,

@@ -135,10 +136,10 @@ def build_evaluator(
    matcher: Optional[MatcherProtocol] = None,
    plots: Optional[List[PlotterProtocol]] = None,
    metrics: Optional[List[MetricsProtocol]] = None,
) -> Evaluator:
) -> EvaluatorProtocol:
    config = config or EvaluationConfig()
    targets = targets or build_targets()
    matcher = matcher or build_matcher(config.match)
    matcher = matcher or build_matcher(config.match_strategy)

    if metrics is None:
        metrics = [

@@ -147,7 +148,10 @@ def build_evaluator(
        ]

    if plots is None:
        plots = [build_plotter(config) for config in config.plots]
        plots = [
            build_plotter(config, targets.class_names)
            for config in config.plots
        ]

    return Evaluator(
        config=config,
src/batdetect2/evaluate/lightning.py (new file, 86 lines)
@@ -0,0 +1,86 @@
from typing import Sequence

from lightning import LightningModule
from torch.utils.data import DataLoader

from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.tables import FullEvaluationTable
from batdetect2.logging import get_image_logger, get_table_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import ClipEvaluation, EvaluatorProtocol


class EvaluationModule(LightningModule):
    def __init__(
        self,
        model: Model,
        evaluator: EvaluatorProtocol,
    ):
        super().__init__()

        self.model = model
        self.evaluator = evaluator

        self.clip_evaluations = []

    def test_step(self, batch: TestExample):
        dataset = self.get_dataset()
        clip_annotations = [
            dataset.clip_annotations[int(example_idx)]
            for example_idx in batch.idx
        ]

        outputs = self.model.detector(batch.spec)
        clip_detections = self.model.postprocessor(
            outputs,
            start_times=[ca.clip.start_time for ca in clip_annotations],
        )
        predictions = [
            to_raw_predictions(
                clip_dets.numpy(),
                targets=self.evaluator.targets,
            )
            for clip_dets in clip_detections
        ]

        self.clip_evaluations.extend(
            self.evaluator.evaluate(clip_annotations, predictions)
        )

    def on_test_epoch_start(self):
        self.clip_evaluations = []

    def on_test_epoch_end(self):
        self.log_metrics(self.clip_evaluations)
        self.plot_examples(self.clip_evaluations)
        self.log_table(self.clip_evaluations)

    def log_table(self, evaluated_clips: Sequence[ClipEvaluation]):
        table_logger = get_table_logger(self.logger)  # type: ignore

        if table_logger is None:
            return

        df = FullEvaluationTable()(evaluated_clips)
        table_logger("full_evaluation", df, 0)

    def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]):
        plotter = get_image_logger(self.logger)  # type: ignore

        if plotter is None:
            return

        for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
            plotter(figure_name, fig, self.global_step)

    def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]):
        metrics = self.evaluator.compute_metrics(evaluated_clips)
        self.log_dict(metrics)

    def get_dataset(self) -> TestDataset:
        dataloaders = self.trainer.test_dataloaders
        assert isinstance(dataloaders, DataLoader)
        dataset = dataloaders.dataset
        assert isinstance(dataset, TestDataset)
        return dataset
@@ -8,8 +8,8 @@ from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds

from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.affinity import (
    AffinityConfig,
    GeometricIOUConfig,

@@ -111,7 +111,7 @@ def match(


class StartTimeMatchConfig(BaseConfig):
    name: Literal["start_time"] = "start_time"
    name: Literal["start_time_match"] = "start_time_match"
    distance_threshold: float = 0.01
@@ -1,3 +1,4 @@
from collections import defaultdict
from collections.abc import Callable, Mapping
from typing import (
    Annotated,

@@ -12,13 +13,10 @@ from typing import (

import numpy as np
from pydantic import Field
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from sklearn import metrics, preprocessing

from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.typing import MetricsProtocol
from batdetect2.typing.evaluate import ClipEvaluation
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipEvaluation, MetricsProtocol

__all__ = ["DetectionAP", "ClassificationAP"]

@@ -26,57 +24,18 @@ __all__ = ["DetectionAP", "ClassificationAP"]
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")


AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"]
APImplementation = Literal["sklearn", "pascal_voc"]


class DetectionAPConfig(BaseConfig):
    name: Literal["detection_ap"] = "detection_ap"
    implementation: AveragePrecisionImplementation = "pascal_voc"


def pascal_voc_average_precision(y_true, y_score) -> float:
    y_true = np.array(y_true)
    y_score = np.array(y_score)

    sort_ind = np.argsort(y_score)[::-1]
    y_true_sorted = y_true[sort_ind]

    num_positives = y_true.sum()
    false_pos_c = np.cumsum(1 - y_true_sorted)
    true_pos_c = np.cumsum(y_true_sorted)

    recall = true_pos_c / num_positives
    precision = true_pos_c / np.maximum(
        true_pos_c + false_pos_c,
        np.finfo(np.float64).eps,
    )

    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0

    # pascal 12 way
    mprec = np.hstack((0, precision, 0))
    mrec = np.hstack((0, recall, 1))
    for ii in range(mprec.shape[0] - 2, -1, -1):
        mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
    inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
    ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()

    return ave_prec
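A tiny worked example of the Pascal VOC interpolation above (self-contained; it mirrors the function body rather than importing it): with `y_true = [1, 0, 1]` and scores `[0.9, 0.8, 0.7]`, precision over the ranked list is `[1, 0.5, 2/3]`, recall is `[0.5, 0.5, 1.0]`, and the interpolated AP comes out to 0.5 * 1 + 0.5 * 2/3, about 0.833:

```python
import numpy as np

# Toy detections: two positives, one negative, ranked by score.
y_true = np.array([1, 0, 1])
y_score = np.array([0.9, 0.8, 0.7])

order = np.argsort(y_score)[::-1]
tp = np.cumsum(y_true[order])            # [1, 1, 2]
fp = np.cumsum(1 - y_true[order])        # [0, 1, 1]
recall = tp / y_true.sum()               # [0.5, 0.5, 1.0]
precision = tp / (tp + fp)               # [1.0, 0.5, 0.667]

# Pascal VOC 2012 interpolation: make precision monotone from the right,
# then sum precision over the recall steps.
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for i in range(len(mprec) - 2, -1, -1):
    mprec[i] = max(mprec[i], mprec[i + 1])
steps = np.where(mrec[1:] != mrec[:-1])[0] + 1
ap = ((mrec[steps] - mrec[steps - 1]) * mprec[steps]).sum()
print(round(ap, 3))  # 0.833
```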

_ap_impl_mapping: Mapping[
    AveragePrecisionImplementation, Callable[[Any, Any], float]
] = {
    "sklearn": metrics.average_precision_score,
    "pascal_voc": pascal_voc_average_precision,
}
    ap_implementation: APImplementation = "pascal_voc"


class DetectionAP(MetricsProtocol):
    def __init__(
        self,
        implementation: AveragePrecisionImplementation = "pascal_voc",
        implementation: APImplementation = "pascal_voc",
    ):
        self.implementation = implementation
        self.metric = _ap_impl_mapping[self.implementation]

@@ -96,14 +55,43 @@ class DetectionAP(MetricsProtocol):

    @classmethod
    def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
        return cls(implementation=config.implementation)
        return cls(implementation=config.ap_implementation)


metrics_registry.register(DetectionAPConfig, DetectionAP)


class DetectionROCAUCConfig(BaseConfig):
    name: Literal["detection_roc_auc"] = "detection_roc_auc"


class DetectionROCAUC(MetricsProtocol):
    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true, y_score = zip(
            *[
                (match.gt_det, match.pred_score)
                for clip_eval in clip_evaluations
                for match in clip_eval.matches
            ]
        )
        score = float(metrics.roc_auc_score(y_true, y_score))
        return {"detection_ROC_AUC": score}

    @classmethod
    def from_config(
        cls, config: DetectionROCAUCConfig, class_names: List[str]
    ):
        return cls()


metrics_registry.register(DetectionROCAUCConfig, DetectionROCAUC)


class ClassificationAPConfig(BaseConfig):
    name: Literal["classification_ap"] = "classification_ap"
    ap_implementation: APImplementation = "pascal_voc"
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None

@@ -112,7 +100,7 @@ class ClassificationAP(MetricsProtocol):
    def __init__(
        self,
        class_names: List[str],
        implementation: AveragePrecisionImplementation = "pascal_voc",
        implementation: APImplementation = "pascal_voc",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):

@@ -163,7 +151,7 @@ class ClassificationAP(MetricsProtocol):
            )
        )

        y_true = label_binarize(y_true, classes=self.class_names)
        y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
        y_pred = np.stack(y_pred)

        class_scores = {}

@@ -193,6 +181,7 @@ class ClassificationAP(MetricsProtocol):
    ):
        return cls(
            class_names,
            implementation=config.ap_implementation,
            include=config.include,
            exclude=config.exclude,
        )

@@ -201,11 +190,523 @@ class ClassificationAP(MetricsProtocol):

metrics_registry.register(ClassificationAPConfig, ClassificationAP)


class ClassificationROCAUCConfig(BaseConfig):
    name: Literal["classification_roc_auc"] = "classification_roc_auc"
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


class ClassificationROCAUC(MetricsProtocol):
    def __init__(
        self,
        class_names: List[str],
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        self.class_names = class_names
        self.selected = class_names

        if include is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name in include
            ]

        if exclude is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name not in exclude
            ]

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true = []
        y_pred = []

        for clip_eval in clip_evaluations:
            for match in clip_eval.matches:
                # Ignore generic unclassified targets
                if match.gt_det and match.gt_class is None:
                    continue

                y_true.append(
                    match.gt_class
                    if match.gt_class is not None
                    else "__NONE__"
                )

                y_pred.append(
                    np.array(
                        [
                            match.pred_class_scores.get(name, 0)
                            for name in self.class_names
                        ]
                    )
                )

        y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
        y_pred = np.stack(y_pred)

        class_scores = {}
        for class_index, class_name in enumerate(self.class_names):
            y_true_class = y_true[:, class_index]
            y_pred_class = y_pred[:, class_index]
            class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
            class_scores[class_name] = float(class_roc_auc)

        mean_roc_auc = np.mean(
            [value for value in class_scores.values() if value != 0]
        )

        return {
            "classification_macro_average_ROC_AUC": float(mean_roc_auc),
|
||||
**{
|
||||
f"classification_ROC_AUC/{class_name}": class_scores[
|
||||
class_name
|
||||
]
|
||||
for class_name in self.selected
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ClassificationROCAUCConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls(
|
||||
class_names,
|
||||
include=config.include,
|
||||
exclude=config.exclude,
|
||||
)
|
||||
|
||||
|
||||
metrics_registry.register(ClassificationROCAUCConfig, ClassificationROCAUC)
|
||||
|
||||
|
||||
class TopClassAPConfig(BaseConfig):
|
||||
name: Literal["top_class_ap"] = "top_class_ap"
|
||||
ap_implementation: APImplementation = "pascal_voc"
|
||||
|
||||
|
||||
class TopClassAP(MetricsProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
implementation: APImplementation = "pascal_voc",
|
||||
):
|
||||
self.implementation = implementation
|
||||
self.metric = _ap_impl_mapping[self.implementation]
|
||||
|
||||
def __call__(
|
||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||
) -> Dict[str, float]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for match in clip_eval.matches:
|
||||
# Ignore generic unclassified targets
|
||||
if match.gt_det and match.gt_class is None:
|
||||
continue
|
||||
|
||||
top_class = match.pred_class
|
||||
|
||||
y_true.append(top_class == match.gt_class)
|
||||
y_score.append(match.pred_class_score)
|
||||
|
||||
score = float(self.metric(y_true, y_score))
|
||||
return {"top_class_AP": score}
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: TopClassAPConfig, class_names: List[str]):
|
||||
return cls(implementation=config.ap_implementation)
|
||||
|
||||
|
||||
metrics_registry.register(TopClassAPConfig, TopClassAP)
|
||||
|
||||
|
||||
class ClassificationBalancedAccuracyConfig(BaseConfig):
|
||||
name: Literal["classification_balanced_accuracy"] = (
|
||||
"classification_balanced_accuracy"
|
||||
)
|
||||
|
||||
|
||||
class ClassificationBalancedAccuracy(MetricsProtocol):
|
||||
def __init__(self, class_names: List[str]):
|
||||
self.class_names = class_names
|
||||
|
||||
def __call__(
|
||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||
) -> Dict[str, float]:
|
||||
y_true = []
|
||||
y_pred = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for match in clip_eval.matches:
|
||||
top_class = match.pred_class
|
||||
|
||||
# Focus on matches
|
||||
if match.gt_class is None or top_class is None:
|
||||
continue
|
||||
|
||||
y_true.append(self.class_names.index(match.gt_class))
|
||||
y_pred.append(self.class_names.index(top_class))
|
||||
|
||||
score = float(metrics.balanced_accuracy_score(y_true, y_pred))
|
||||
return {"classification_balanced_accuracy": score}
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ClassificationBalancedAccuracyConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls(class_names)
|
||||
|
||||
|
||||
metrics_registry.register(
|
||||
ClassificationBalancedAccuracyConfig,
|
||||
ClassificationBalancedAccuracy,
|
||||
)
|
||||
|
||||
|
||||
class ClipDetectionAPConfig(BaseConfig):
|
||||
name: Literal["clip_detection_ap"] = "clip_detection_ap"
|
||||
ap_implementation: APImplementation = "pascal_voc"
|
||||
|
||||
|
||||
class ClipDetectionAP(MetricsProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
implementation: APImplementation,
|
||||
):
|
||||
self.implementation = implementation
|
||||
self.metric = _ap_impl_mapping[self.implementation]
|
||||
|
||||
def __call__(
|
||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||
) -> Dict[str, float]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
clip_det = []
|
||||
clip_scores = []
|
||||
|
||||
for match in clip_eval.matches:
|
||||
clip_det.append(match.gt_det)
|
||||
clip_scores.append(match.pred_score)
|
||||
|
||||
y_true.append(any(clip_det))
|
||||
y_score.append(max(clip_scores or [0]))
|
||||
|
||||
return {"clip_detection_ap": self.metric(y_true, y_score)}
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ClipDetectionAPConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls(implementation=config.ap_implementation)
|
||||
|
||||
|
||||
metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP)
|
||||
|
||||
|
||||
class ClipDetectionROCAUCConfig(BaseConfig):
|
||||
name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc"
|
||||
|
||||
|
||||
class ClipDetectionROCAUC(MetricsProtocol):
|
||||
def __call__(
|
||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||
) -> Dict[str, float]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
clip_det = []
|
||||
clip_scores = []
|
||||
|
||||
for match in clip_eval.matches:
|
||||
clip_det.append(match.gt_det)
|
||||
clip_scores.append(match.pred_score)
|
||||
|
||||
y_true.append(any(clip_det))
|
||||
y_score.append(max(clip_scores or [0]))
|
||||
|
||||
        return {
            "clip_detection_roc_auc": float(
                metrics.roc_auc_score(y_true, y_score)
            )
        }

    @classmethod
    def from_config(
        cls,
        config: ClipDetectionROCAUCConfig,
        class_names: List[str],
    ):
        return cls()


metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC)


class ClipMulticlassAPConfig(BaseConfig):
    name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap"
    ap_implementation: APImplementation = "pascal_voc"
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


class ClipMulticlassAP(MetricsProtocol):
    def __init__(
        self,
        class_names: List[str],
        implementation: APImplementation,
        include: Optional[Sequence[str]] = None,
        exclude: Optional[Sequence[str]] = None,
    ):
        self.implementation = implementation
        self.metric = _ap_impl_mapping[self.implementation]
        self.class_names = class_names

        self.selected = class_names

        if include is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name in include
            ]

        if exclude is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name not in exclude
            ]

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true = []
        y_pred = []

        for clip_eval in clip_evaluations:
            clip_classes = set()
            clip_scores = defaultdict(list)

            for match in clip_eval.matches:
                if match.gt_class is not None:
                    clip_classes.add(match.gt_class)

                for class_name, score in match.pred_class_scores.items():
                    clip_scores[class_name].append(score)

            y_true.append(clip_classes)
            y_pred.append(
                np.array(
                    [
                        # Get max score for each class
                        max(clip_scores.get(class_name, [0]))
                        for class_name in self.class_names
                    ]
                )
            )

        y_true = preprocessing.MultiLabelBinarizer(
            classes=self.class_names
        ).fit_transform(y_true)
        y_pred = np.stack(y_pred)

        class_scores = {}
        for class_index, class_name in enumerate(self.class_names):
            y_true_class = y_true[:, class_index]
            y_pred_class = y_pred[:, class_index]
            class_ap = self.metric(y_true_class, y_pred_class)
            class_scores[class_name] = float(class_ap)

        mean_ap = np.mean(
            [value for value in class_scores.values() if value != 0]
        )
        return {
            "clip_multiclass_mAP": float(mean_ap),
            **{
                f"clip_multiclass_AP/{class_name}": class_scores[class_name]
                for class_name in self.selected
            },
        }

    @classmethod
    def from_config(
        cls, config: ClipMulticlassAPConfig, class_names: List[str]
    ):
        return cls(
            implementation=config.ap_implementation,
            include=config.include,
            exclude=config.exclude,
            class_names=class_names,
        )


metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP)
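
# Note (editor sketch): MultiLabelBinarizer turns each clip's set of
# ground-truth classes into the binary indicator matrix the per-class
# scorers above expect. A standalone illustration (class names are
# hypothetical):
#
#     from sklearn.preprocessing import MultiLabelBinarizer
#
#     clip_classes = [{"pippip"}, set(), {"pippip", "myomys"}]
#     MultiLabelBinarizer(
#         classes=["pippip", "myomys"]
#     ).fit_transform(clip_classes)
#     # -> array([[1, 0], [0, 0], [1, 1]])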


class ClipMulticlassROCAUCConfig(BaseConfig):
    name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc"
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


class ClipMulticlassROCAUC(MetricsProtocol):
    def __init__(
        self,
        class_names: List[str],
        include: Optional[Sequence[str]] = None,
        exclude: Optional[Sequence[str]] = None,
    ):
        self.class_names = class_names
        self.selected = class_names

        if include is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name in include
            ]

        if exclude is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name not in exclude
            ]

    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]:
        y_true = []
        y_pred = []

        for clip_eval in clip_evaluations:
            clip_classes = set()
            clip_scores = defaultdict(list)

            for match in clip_eval.matches:
                if match.gt_class is not None:
                    clip_classes.add(match.gt_class)

                for class_name, score in match.pred_class_scores.items():
                    clip_scores[class_name].append(score)

            y_true.append(clip_classes)
            y_pred.append(
                np.array(
                    [
                        # Get maximum score for each class
                        max(clip_scores.get(class_name, [0]))
                        for class_name in self.class_names
                    ]
                )
            )

        y_true = preprocessing.MultiLabelBinarizer(
            classes=self.class_names
        ).fit_transform(y_true)
        y_pred = np.stack(y_pred)

        class_scores = {}
        for class_index, class_name in enumerate(self.class_names):
            y_true_class = y_true[:, class_index]
            y_pred_class = y_pred[:, class_index]
            class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
            class_scores[class_name] = float(class_roc_auc)

        mean_roc_auc = np.mean(
            [value for value in class_scores.values() if value != 0]
        )
        return {
            "clip_multiclass_macro_ROC_AUC": float(mean_roc_auc),
            **{
                f"clip_multiclass_ROC_AUC/{class_name}": class_scores[
                    class_name
                ]
                for class_name in self.selected
            },
        }

    @classmethod
    def from_config(
        cls,
        config: ClipMulticlassROCAUCConfig,
        class_names: List[str],
    ):
        return cls(
            include=config.include,
            exclude=config.exclude,
            class_names=class_names,
        )


metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC)

MetricConfig = Annotated[
    Union[ClassificationAPConfig, DetectionAPConfig],
    Union[
        DetectionAPConfig,
        DetectionROCAUCConfig,
        ClassificationAPConfig,
        ClassificationROCAUCConfig,
        TopClassAPConfig,
        ClassificationBalancedAccuracyConfig,
        ClipDetectionAPConfig,
        ClipDetectionROCAUCConfig,
        ClipMulticlassAPConfig,
        ClipMulticlassROCAUCConfig,
    ],
    Field(discriminator="name"),
]


def build_metric(config: MetricConfig, class_names: List[str]):
    return metrics_registry.build(config, class_names)
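
# Note (editor sketch): thanks to the Literal `name` discriminators, a
# metric can be built directly from a validated config object. A hedged
# usage sketch (class names are illustrative):
#
#     config = ClipDetectionAPConfig(ap_implementation="sklearn")
#     metric = build_metric(config, class_names=["pippip", "myomys"])
#     scores = metric(clip_evaluations)  # e.g. {"clip_detection_ap": 0.87}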


def pascal_voc_average_precision(y_true, y_score) -> float:
    y_true = np.array(y_true)
    y_score = np.array(y_score)

    sort_ind = np.argsort(y_score)[::-1]
    y_true_sorted = y_true[sort_ind]

    num_positives = y_true.sum()
    false_pos_c = np.cumsum(1 - y_true_sorted)
    true_pos_c = np.cumsum(y_true_sorted)

    recall = true_pos_c / num_positives
    precision = true_pos_c / np.maximum(
        true_pos_c + false_pos_c,
        np.finfo(np.float64).eps,
    )

    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0

    # PASCAL VOC 2012 interpolation: monotonic precision envelope
    mprec = np.hstack((0, precision, 0))
    mrec = np.hstack((0, recall, 1))
    for ii in range(mprec.shape[0] - 2, -1, -1):
        mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
    inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
    ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()

    return ave_prec


_ap_impl_mapping: Mapping[APImplementation, Callable[[Any, Any], float]] = {
    "sklearn": metrics.average_precision_score,
    "pascal_voc": pascal_voc_average_precision,
}
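
# Worked example: with y_true = [1, 0, 1] and scores [0.9, 0.8, 0.7], the
# recall steps are 0.5 and 1.0 with envelope precisions 1.0 and 2/3, so
# AP = 0.5 * 1.0 + 0.5 * (2 / 3) = 5 / 6:
#
#     ap = pascal_voc_average_precision([1, 0, 1], [0.9, 0.8, 0.7])
#     assert abs(ap - 5 / 6) < 1e-9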

@@ -4,20 +4,24 @@ from dataclasses import dataclass, field
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pydantic import Field
from sklearn import metrics
from sklearn.preprocessing import label_binarize

from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import BaseConfig, Registry
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.matches import plot_matches
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing.evaluate import (
from batdetect2.typing import (
    AudioLoader,
    ClipEvaluation,
    MatchEvaluation,
    PlotterProtocol,
    PreprocessorProtocol,
)
from batdetect2.typing.preprocess import AudioLoader

__all__ = [
    "build_plotter",
@@ -26,12 +30,13 @@ __all__ = [
]


plots_registry: Registry[PlotterProtocol, []] = Registry("plot")
plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot")


class ExampleGalleryConfig(BaseConfig):
    name: Literal["example_gallery"] = "example_gallery"
    examples_per_class: int = 5
    audio: AudioConfig = Field(default_factory=AudioConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
@@ -87,9 +92,12 @@ class ExampleGallery(PlotterProtocol):
        plt.close(fig)

    @classmethod
    def from_config(cls, config: ExampleGalleryConfig):
        preprocessor = build_preprocessor(config.preprocessing)
        audio_loader = build_audio_loader(config.preprocessing.audio)
    def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]):
        audio_loader = build_audio_loader(config.audio)
        preprocessor = build_preprocessor(
            config.preprocessing,
            input_samplerate=audio_loader.samplerate,
        )
        return cls(
            examples_per_class=config.examples_per_class,
            preprocessor=preprocessor,
@@ -100,13 +108,402 @@ class ExampleGallery(PlotterProtocol):
plots_registry.register(ExampleGalleryConfig, ExampleGallery)


class ClipEvaluationPlotConfig(BaseConfig):
    name: Literal["example_clip"] = "example_clip"
    num_plots: int = 5
    audio: AudioConfig = Field(default_factory=AudioConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )


class PlotClipEvaluation(PlotterProtocol):
    def __init__(
        self,
        num_plots: int = 3,
        preprocessor: Optional[PreprocessorProtocol] = None,
        audio_loader: Optional[AudioLoader] = None,
    ):
        self.preprocessor = preprocessor
        self.audio_loader = audio_loader
        self.num_plots = num_plots

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        examples = random.sample(
            clip_evaluations,
            k=min(self.num_plots, len(clip_evaluations)),
        )

        for index, clip_evaluation in enumerate(examples):
            fig, ax = plt.subplots()
            plot_matches(
                clip_evaluation.matches,
                clip=clip_evaluation.clip,
                audio_loader=self.audio_loader,
                ax=ax,
            )
            yield f"clip_evaluation/example_{index}", fig
            plt.close(fig)

    @classmethod
    def from_config(
        cls,
        config: ClipEvaluationPlotConfig,
        class_names: List[str],
    ):
        audio_loader = build_audio_loader(config.audio)
        preprocessor = build_preprocessor(
            config.preprocessing,
            input_samplerate=audio_loader.samplerate,
        )
        return cls(
            num_plots=config.num_plots,
            preprocessor=preprocessor,
            audio_loader=audio_loader,
        )


plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation)


class DetectionPRCurveConfig(BaseConfig):
    name: Literal["detection_pr_curve"] = "detection_pr_curve"


class DetectionPRCurve(PlotterProtocol):
    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        y_true, y_score = zip(
            *[
                (match.gt_det, match.pred_score)
                for clip_eval in clip_evaluations
                for match in clip_eval.matches
            ]
        )
        precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
        fig, ax = plt.subplots()

        ax.plot(recall, precision, label="Detector")
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.legend()

        yield "detection_pr_curve", fig

    @classmethod
    def from_config(
        cls,
        config: DetectionPRCurveConfig,
        class_names: List[str],
    ):
        return cls()


plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve)


class ClassificationPRCurvesConfig(BaseConfig):
    name: Literal["classification_pr_curves"] = "classification_pr_curves"
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


class ClassificationPRCurves(PlotterProtocol):
    def __init__(
        self,
        class_names: List[str],
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        self.class_names = class_names
        self.selected = class_names

        if include is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name in include
            ]

        if exclude is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name not in exclude
            ]

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        y_true = []
        y_pred = []

        for clip_eval in clip_evaluations:
            for match in clip_eval.matches:
                # Ignore generic unclassified targets
                if match.gt_det and match.gt_class is None:
                    continue

                y_true.append(
                    match.gt_class
                    if match.gt_class is not None
                    else "__NONE__"
                )

                y_pred.append(
                    np.array(
                        [
                            match.pred_class_scores.get(name, 0)
                            for name in self.class_names
                        ]
                    )
                )

        y_true = label_binarize(y_true, classes=self.class_names)
        y_pred = np.stack(y_pred)

        fig, ax = plt.subplots(figsize=(10, 10))
        for class_index, class_name in enumerate(self.class_names):
            if class_name not in self.selected:
                continue

            y_true_class = y_true[:, class_index]
            y_pred_class = y_pred[:, class_index]
            precision, recall, _ = metrics.precision_recall_curve(
                y_true_class,
                y_pred_class,
            )
            ax.plot(recall, precision, label=class_name)

        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.legend(
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            borderaxespad=0.0,
        )

        yield "classification_pr_curve", fig

    @classmethod
    def from_config(
        cls,
        config: ClassificationPRCurvesConfig,
        class_names: List[str],
    ):
        return cls(
            class_names=class_names,
            include=config.include,
            exclude=config.exclude,
        )


plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves)


class DetectionROCCurveConfig(BaseConfig):
    name: Literal["detection_roc_curve"] = "detection_roc_curve"


class DetectionROCCurve(PlotterProtocol):
    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        y_true, y_score = zip(
            *[
                (match.gt_det, match.pred_score)
                for clip_eval in clip_evaluations
                for match in clip_eval.matches
            ]
        )
        fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
        fig, ax = plt.subplots()

        ax.plot(fpr, tpr, label="Detection")
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.legend()

        yield "detection_roc_curve", fig

    @classmethod
    def from_config(
        cls,
        config: DetectionROCCurveConfig,
        class_names: List[str],
    ):
        return cls()


plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve)


class ClassificationROCCurvesConfig(BaseConfig):
    name: Literal["classification_roc_curves"] = "classification_roc_curves"
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


class ClassificationROCCurves(PlotterProtocol):
    def __init__(
        self,
        class_names: List[str],
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        self.class_names = class_names
        self.selected = class_names

        if include is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name in include
            ]

        if exclude is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name not in exclude
            ]

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        y_true = []
        y_pred = []

        for clip_eval in clip_evaluations:
            for match in clip_eval.matches:
                # Ignore generic unclassified targets
                if match.gt_det and match.gt_class is None:
                    continue

                y_true.append(
                    match.gt_class
                    if match.gt_class is not None
                    else "__NONE__"
                )

                y_pred.append(
                    np.array(
                        [
                            match.pred_class_scores.get(name, 0)
                            for name in self.class_names
                        ]
                    )
                )

        y_true = label_binarize(y_true, classes=self.class_names)
        y_pred = np.stack(y_pred)

        fig, ax = plt.subplots(figsize=(10, 10))
        for class_index, class_name in enumerate(self.class_names):
            if class_name not in self.selected:
                continue

            y_true_class = y_true[:, class_index]
            y_pred_class = y_pred[:, class_index]
            fpr, tpr, _ = metrics.roc_curve(
                y_true_class,
                y_pred_class,
            )
            ax.plot(fpr, tpr, label=class_name)

        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.legend(
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            borderaxespad=0.0,
        )

        yield "classification_roc_curve", fig

    @classmethod
    def from_config(
        cls,
        config: ClassificationROCCurvesConfig,
        class_names: List[str],
    ):
        return cls(
            class_names=class_names,
            include=config.include,
            exclude=config.exclude,
        )


plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)


class ConfusionMatrixConfig(BaseConfig):
    name: Literal["confusion_matrix"] = "confusion_matrix"
    background_class: str = "noise"


class ConfusionMatrix(PlotterProtocol):
    def __init__(self, background_class: str, class_names: List[str]):
        self.background_class = background_class
        self.class_names = class_names

    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
        y_true = []
        y_pred = []

        for clip_eval in clip_evaluations:
            for match in clip_eval.matches:
                # Ignore generic unclassified targets
                if match.gt_det and match.gt_class is None:
                    continue

                y_true.append(
                    match.gt_class
                    if match.gt_class is not None
                    else self.background_class
                )

                top_class = match.pred_class
                y_pred.append(
                    top_class
                    if top_class is not None
                    else self.background_class
                )

        display = metrics.ConfusionMatrixDisplay.from_predictions(
            y_true,
            y_pred,
            labels=[*self.class_names, self.background_class],
        )

        yield "confusion_matrix", display.figure_

    @classmethod
    def from_config(
        cls,
        config: ConfusionMatrixConfig,
        class_names: List[str],
    ):
        return cls(
            background_class=config.background_class,
            class_names=class_names,
        )


plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix)


PlotConfig = Annotated[
    Union[ExampleGalleryConfig,], Field(discriminator="name")
    Union[
        ExampleGalleryConfig,
        ClipEvaluationPlotConfig,
        DetectionPRCurveConfig,
        ClassificationPRCurvesConfig,
        DetectionROCCurveConfig,
        ClassificationROCCurvesConfig,
        ConfusionMatrixConfig,
    ],
    Field(discriminator="name"),
]


def build_plotter(config: PlotConfig) -> PlotterProtocol:
    return plots_registry.build(config)
def build_plotter(
    config: PlotConfig, class_names: List[str]
) -> PlotterProtocol:
    return plots_registry.build(config, class_names)
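
# Note (editor sketch): plotters are generators of (name, figure) pairs,
# so callers can stream figures straight into whatever logger is active.
# Hedged usage sketch:
#
#     plotter = build_plotter(DetectionPRCurveConfig(), class_names=["pippip"])
#     for plot_name, figure in plotter(clip_evaluations):
#         ...  # e.g. hand `figure` to an image logger under `plot_name`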


@dataclass

@@ -1,18 +1,49 @@
from typing import List
from typing import Annotated, Callable, Literal, Sequence, Union

import pandas as pd
from pydantic import Field
from soundevent.geometry import compute_bounds

from batdetect2.typing.evaluate import ClipEvaluation
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipEvaluation

EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame]


def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame:
tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
    "evaluation_table"
)


class FullEvaluationTableConfig(BaseConfig):
    name: Literal["full_evaluation"] = "full_evaluation"


class FullEvaluationTable:
    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> pd.DataFrame:
        return extract_matches_dataframe(clip_evaluations)

    @classmethod
    def from_config(cls, config: FullEvaluationTableConfig):
        return cls()


tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable)


def extract_matches_dataframe(
    clip_evaluations: Sequence[ClipEvaluation],
) -> pd.DataFrame:
    data = []

    for clip_evaluation in clip_evaluations:
        for match in clip_evaluation.matches:
            gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
            pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
            pred_start_time = pred_low_freq = pred_end_time = (
                pred_high_freq
            ) = None

            sound_event_annotation = match.sound_event_annotation

@@ -24,9 +55,12 @@ def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame:
                )

            if match.pred_geometry is not None:
                pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
                    compute_bounds(match.pred_geometry)
                )
                (
                    pred_start_time,
                    pred_low_freq,
                    pred_end_time,
                    pred_high_freq,
                ) = compute_bounds(match.pred_geometry)

            data.append(
                {
@@ -61,3 +95,14 @@ def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame:
    df = pd.DataFrame(data)
    df.columns = pd.MultiIndex.from_tuples(df.columns)  # type: ignore
    return df


EvaluationTableConfig = Annotated[
    Union[FullEvaluationTableConfig,], Field(discriminator="name")
]


def build_table_generator(
    config: EvaluationTableConfig,
) -> EvaluationTableGenerator:
    return tables_registry.build(config)
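
# Note (editor sketch): the generated table uses two-level columns (the
# tuples passed to pd.MultiIndex.from_tuples above), keeping ground-truth
# and prediction bounds grouped. Hedged usage sketch:
#
#     generator = build_table_generator(FullEvaluationTableConfig())
#     df = generator(clip_evaluations)
#     df.to_csv("matches.csv", index=False)  # or hand it to a table logger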

src/batdetect2/inference/__init__.py (new, empty file)
@@ -1,4 +1,6 @@
import io
from collections.abc import Callable
from functools import partial
from pathlib import Path
from typing import (
    Annotated,
@@ -13,12 +15,19 @@ from typing import (
)

import numpy as np
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
import pandas as pd
from lightning.pytorch.loggers import (
    CSVLogger,
    Logger,
    MLFlowLogger,
    TensorBoardLogger,
)
from loguru import logger
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data

from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig

DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"

@@ -48,7 +57,7 @@ class TensorBoardLoggerConfig(BaseLoggerConfig):

class MLFlowLoggerConfig(BaseLoggerConfig):
    name: Literal["mlflow"] = "mlflow"
    tracking_uri: Optional[str] = None
    tracking_uri: Optional[str] = "http://localhost:5000"
    tags: Optional[dict[str, Any]] = None
    log_model: bool = False

@@ -152,6 +161,9 @@ def create_tensorboard_logger(

    name = run_name

    if name is None:
        name = experiment_name

    if run_name is not None and experiment_name is not None:
        name = str(Path(experiment_name) / run_name)

@@ -231,18 +243,18 @@ def build_logger(
    )


def get_image_plotter(logger: Logger):
PlotLogger = Callable[[str, Figure, int], None]


def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
    if isinstance(logger, TensorBoardLogger):

        def plot_figure(name, figure, step):
            return logger.experiment.add_figure(name, figure, step)

        return plot_figure
        return logger.experiment.add_figure

    if isinstance(logger, MLFlowLogger):

        def plot_figure(name, figure, step):
            image = _convert_figure_to_image(figure)
            image = _convert_figure_to_array(figure)
            name = name.replace("/", "_")
            return logger.experiment.log_image(
                logger.run_id,
                image,
@@ -252,8 +264,51 @@ def get_image_plotter(logger: Logger):

        return plot_figure

    if isinstance(logger, CSVLogger):
        return partial(save_figure, dir=Path(logger.log_dir))


def _convert_figure_to_image(figure):

TableLogger = Callable[[str, pd.DataFrame, int], None]


def get_table_logger(logger: Logger) -> Optional[TableLogger]:
    if isinstance(logger, TensorBoardLogger):
        return partial(save_table, dir=Path(logger.log_dir))

    if isinstance(logger, MLFlowLogger):

        def plot_figure(name: str, df: pd.DataFrame, step: int):
            return logger.experiment.log_table(
                logger.run_id,
                data=df,
                artifact_file=f"{name}_step_{step}.json",
            )

        return plot_figure

    if isinstance(logger, CSVLogger):
        return partial(save_table, dir=Path(logger.log_dir))


def save_table(name: str, df: pd.DataFrame, step: int, dir: Path) -> None:
    path = dir / "tables" / f"{name}_step_{step}.csv"

    if not path.parent.exists():
        path.parent.mkdir(parents=True)

    df.to_csv(path, index=False)


def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
    path = dir / "plots" / f"{name}_step_{step}.png"

    if not path.parent.exists():
        path.parent.mkdir(parents=True)

    fig.savefig(path, transparent=True, bbox_inches="tight")
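
# Note (editor sketch): get_image_logger/get_table_logger give the
# training loop one uniform callable per backend; for CSVLogger both fall
# back to plain files under the logger's log_dir. Hedged sketch:
#
#     csv_logger = CSVLogger("outputs/logs", name="demo")
#     log_figure = get_image_logger(csv_logger)  # writes under plots/
#     if log_figure is not None:
#         log_figure("detection_pr_curve", fig, 0)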


def _convert_figure_to_array(figure: Figure) -> np.ndarray:
    with io.BytesIO() as buff:
        figure.savefig(buff, format="raw")
        buff.seek(0)
@@ -29,15 +29,10 @@ provided here.
from typing import List, Optional

import torch
from pydantic import Field
from soundevent.data import PathLike

from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.backbones import (
    Backbone,
    BackboneConfig,
    build_backbone,
    load_backbone_config,
)
from batdetect2.models.blocks import (
    ConvConfig,
@@ -51,6 +46,10 @@ from batdetect2.models.bottleneck import (
    BottleneckConfig,
    build_bottleneck,
)
from batdetect2.models.config import (
    BackboneConfig,
    load_backbone_config,
)
from batdetect2.models.decoder import (
    DEFAULT_DECODER_CONFIG,
    DecoderConfig,
@@ -63,12 +62,12 @@ from batdetect2.models.encoder import (
    build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import (
    DetectionsTensor,
    ClipDetectionsTensor,
    PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
@@ -99,20 +98,10 @@ __all__ = [
    "build_detector",
    "load_backbone_config",
    "Model",
    "ModelConfig",
    "build_model",
]


class ModelConfig(BaseConfig):
    model: BackboneConfig = Field(default_factory=BackboneConfig)
    preprocess: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
    targets: TargetConfig = Field(default_factory=TargetConfig)


class Model(torch.nn.Module):
    detector: DetectionModel
    preprocessor: PreprocessorProtocol
@@ -125,47 +114,38 @@ class Model(torch.nn.Module):
        preprocessor: PreprocessorProtocol,
        postprocessor: PostprocessorProtocol,
        targets: TargetProtocol,
        config: ModelConfig,
    ):
        super().__init__()
        self.detector = detector
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        self.targets = targets
        self.config = config

    def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
    def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
        spec = self.preprocessor(wav)
        outputs = self.detector(spec)
        return self.postprocessor(outputs)


def build_model(config: Optional[ModelConfig] = None):
    config = config or ModelConfig()

    targets = build_targets(config=config.targets)

    preprocessor = build_preprocessor(config=config.preprocess)

    postprocessor = build_postprocessor(
def build_model(
    config: Optional[BackboneConfig] = None,
    targets: Optional[TargetProtocol] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    postprocessor: Optional[PostprocessorProtocol] = None,
):
    config = config or BackboneConfig()
    targets = targets or build_targets()
    preprocessor = preprocessor or build_preprocessor()
    postprocessor = postprocessor or build_postprocessor(
        preprocessor=preprocessor,
        config=config.postprocess,
    )

    detector = build_detector(
        num_classes=len(targets.class_names),
        config=config.model,
        config=config,
    )
    return Model(
        config=config,
        detector=detector,
        postprocessor=postprocessor,
        preprocessor=preprocessor,
        targets=targets,
    )
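
# Note (editor sketch): with every component now optional, build_model
# composes defaults on demand. Hedged usage:
#
#     model = build_model()  # default backbone, targets, pre/postprocessors
#     detections = model(wav)  # wav: torch.Tensor of raw audio samples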


def load_model_config(
    path: PathLike, field: Optional[str] = None
) -> ModelConfig:
    return load_config(path, schema=ModelConfig, field=field)

@@ -18,37 +18,20 @@ automatic padding to handle input sizes not perfectly divisible by the
network's total downsampling factor.
"""

from typing import Optional, Tuple
from typing import Tuple

import torch
import torch.nn.functional as F
from soundevent import data
from torch import nn

from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
    DEFAULT_BOTTLENECK_CONFIG,
    BottleneckConfig,
    build_bottleneck,
)
from batdetect2.models.decoder import (
    DEFAULT_DECODER_CONFIG,
    Decoder,
    DecoderConfig,
    build_decoder,
)
from batdetect2.models.encoder import (
    DEFAULT_ENCODER_CONFIG,
    Encoder,
    EncoderConfig,
    build_encoder,
)
from batdetect2.models.bottleneck import build_bottleneck
from batdetect2.models.config import BackboneConfig
from batdetect2.models.decoder import Decoder, build_decoder
from batdetect2.models.encoder import Encoder, build_encoder
from batdetect2.typing.models import BackboneModel

__all__ = [
    "Backbone",
    "BackboneConfig",
    "load_backbone_config",
    "build_backbone",
]

@@ -161,82 +144,6 @@ class Backbone(BackboneModel):
        return x


class BackboneConfig(BaseConfig):
    """Configuration for the Encoder-Decoder Backbone network.

    Aggregates configurations for the encoder, bottleneck, and decoder
    components, along with defining the input and final output dimensions
    for the complete backbone.

    Attributes
    ----------
    input_height : int, default=128
        Expected height (frequency bins) of the input spectrograms to the
        backbone. Must be positive.
    in_channels : int, default=1
        Expected number of channels in the input spectrograms (e.g., 1 for
        mono). Must be positive.
    encoder : EncoderConfig, optional
        Configuration for the encoder. If None or omitted,
        the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
        encoder module) will be used.
    bottleneck : BottleneckConfig, optional
        Configuration for the bottleneck layer connecting encoder and decoder.
        If None or omitted, the default bottleneck configuration will be used.
    decoder : DecoderConfig, optional
        Configuration for the decoder. If None or omitted,
        the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
        decoder module) will be used.
    out_channels : int, default=32
        Desired number of channels in the final feature map output by the
        backbone. Must be positive.
    """

    input_height: int = 128
    in_channels: int = 1
    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
    out_channels: int = 32


def load_backbone_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> BackboneConfig:
    """Load the backbone configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `BackboneConfig` schema, potentially extracting data from a nested field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        backbone configuration (e.g., "model.backbone"). If None, the entire
        file content is used.

    Returns
    -------
    BackboneConfig
        The loaded and validated backbone configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded config data does not conform to `BackboneConfig`.
    KeyError, TypeError
        If `field` specifies an invalid path.
    """
    return load_config(path, schema=BackboneConfig, field=field)


def build_backbone(config: BackboneConfig) -> BackboneModel:
    """Factory function to build a Backbone from configuration.


@@ -34,7 +34,7 @@ import torch.nn.functional as F
from pydantic import Field
from torch import nn

from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig

__all__ = [
    "ConvBlock",

@@ -20,7 +20,7 @@ import torch
from pydantic import Field
from torch import nn

from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
    SelfAttentionConfig,
    VerticalConv,

src/batdetect2/models/config.py (new file, 98 lines)
@@ -0,0 +1,98 @@
from typing import Optional

from soundevent import data

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
    DEFAULT_BOTTLENECK_CONFIG,
    BottleneckConfig,
)
from batdetect2.models.decoder import (
    DEFAULT_DECODER_CONFIG,
    DecoderConfig,
)
from batdetect2.models.encoder import (
    DEFAULT_ENCODER_CONFIG,
    EncoderConfig,
)

__all__ = [
    "BackboneConfig",
    "load_backbone_config",
]


class BackboneConfig(BaseConfig):
    """Configuration for the Encoder-Decoder Backbone network.

    Aggregates configurations for the encoder, bottleneck, and decoder
    components, along with defining the input and final output dimensions
    for the complete backbone.

    Attributes
    ----------
    input_height : int, default=128
        Expected height (frequency bins) of the input spectrograms to the
        backbone. Must be positive.
    in_channels : int, default=1
        Expected number of channels in the input spectrograms (e.g., 1 for
        mono). Must be positive.
    encoder : EncoderConfig, optional
        Configuration for the encoder. If None or omitted,
        the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
        encoder module) will be used.
    bottleneck : BottleneckConfig, optional
        Configuration for the bottleneck layer connecting encoder and decoder.
        If None or omitted, the default bottleneck configuration will be used.
    decoder : DecoderConfig, optional
        Configuration for the decoder. If None or omitted,
        the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
        decoder module) will be used.
    out_channels : int, default=32
        Desired number of channels in the final feature map output by the
        backbone. Must be positive.
    """

    input_height: int = 128
    in_channels: int = 1
    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
    out_channels: int = 32


def load_backbone_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> BackboneConfig:
    """Load the backbone configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `BackboneConfig` schema, potentially extracting data from a nested field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        backbone configuration (e.g., "model.backbone"). If None, the entire
        file content is used.

    Returns
    -------
    BackboneConfig
        The loaded and validated backbone configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded config data does not conform to `BackboneConfig`.
    KeyError, TypeError
        If `field` specifies an invalid path.
    """
    return load_config(path, schema=BackboneConfig, field=field)
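
# Note (editor sketch): a YAML file whose nested section matches
# BackboneConfig can be loaded with a dotted field path (file name and
# field are illustrative):
#
#     backbone_config = load_backbone_config("config.yaml", field="model")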
@@ -24,7 +24,7 @@ import torch
from pydantic import Field
from torch import nn

from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
    ConvConfig,
    FreqCoordConvUpConfig,

@@ -26,7 +26,7 @@ import torch
from pydantic import Field
from torch import nn

from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
    ConvConfig,
    FreqCoordConvDownConfig,

@@ -5,8 +5,9 @@ import torch
from matplotlib.axes import Axes
from soundevent import data

from batdetect2.audio import build_audio_loader
from batdetect2.plotting.common import plot_spectrogram
from batdetect2.preprocess import build_audio_loader, build_preprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol

__all__ = [

@@ -6,10 +6,8 @@ from soundevent import data, plot
from soundevent.geometry import compute_bounds
from soundevent.plot.tags import TagColorMapper

from batdetect2.plotting.clip_predictions import plot_prediction
from batdetect2.plotting.clips import AudioLoader, plot_clip
from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol

__all__ = [
    "plot_matches",
@@ -30,7 +28,7 @@ DEFAULT_PREDICTION_LINE_STYLE = "--"


def plot_matches(
    matches: List[data.Match],
    matches: List[MatchEvaluation],
    clip: data.Clip,
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
@@ -44,8 +42,7 @@ def plot_matches(
    false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
    false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
    true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
    annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
    prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
    cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR,
) -> Axes:
    ax = plot_clip(
        clip,
@@ -61,52 +58,48 @@ def plot_matches(
    color_mapper = TagColorMapper()

    for match in matches:
        if match.source is None and match.target is not None:
            plot.plot_annotation(
                annotation=match.target,
        if match.is_cross_trigger():
            plot_cross_trigger_match(
                match,
                ax=ax,
                time_offset=0.004,
                freq_offset=2_000,
                fill=fill,
                add_points=add_points,
                add_spectrogram=False,
                use_score=True,
                color=cross_trigger_color,
                add_text=False,
            )
        elif match.is_true_positive():
            plot_true_positive_match(
                match,
                ax=ax,
                fill=fill,
                add_spectrogram=False,
                use_score=True,
                add_points=add_points,
                color=true_positive_color,
                add_text=False,
            )
        elif match.is_false_negative():
            plot_false_negative_match(
                match,
                ax=ax,
                fill=fill,
                add_spectrogram=False,
                add_points=add_points,
                facecolor="none" if not fill else None,
                color=false_negative_color,
                color_mapper=color_mapper,
                linestyle=annotation_linestyle,
                add_text=False,
            )
        elif match.target is None and match.source is not None:
            plot_prediction(
                prediction=match.source,
        elif match.is_false_positive():
|
||||
plot_false_positive_match(
|
||||
match,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
fill=fill,
|
||||
add_spectrogram=False,
|
||||
use_score=True,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=false_positive_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=prediction_linestyle,
|
||||
)
|
||||
elif match.target is not None and match.source is not None:
|
||||
plot.plot_annotation(
|
||||
annotation=match.target,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=true_positive_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=annotation_linestyle,
|
||||
)
|
||||
plot_prediction(
|
||||
prediction=match.source,
|
||||
ax=ax,
|
||||
time_offset=0.004,
|
||||
freq_offset=2_000,
|
||||
add_points=add_points,
|
||||
facecolor="none" if not fill else None,
|
||||
color=true_positive_color,
|
||||
color_mapper=color_mapper,
|
||||
linestyle=prediction_linestyle,
|
||||
add_text=False,
|
||||
)
|
||||
else:
|
||||
continue
|
||||
@ -122,6 +115,9 @@ def plot_false_positive_match(
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    duration: float = DEFAULT_DURATION,
    use_score: bool = True,
    add_spectrogram: bool = True,
    add_text: bool = True,
    add_points: bool = False,
    fill: bool = False,
    spec_cmap: str = "gray",
@ -142,34 +138,36 @@ def plot_false_positive_match(
        recording=match.clip.recording,
    )

    ax = plot_clip(
        clip,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        spec_cmap=spec_cmap,
    )
    if add_spectrogram:
        ax = plot_clip(
            clip,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            figsize=figsize,
            ax=ax,
            audio_dir=audio_dir,
            spec_cmap=spec_cmap,
        )

    plot.plot_geometry(
    ax = plot.plot_geometry(
        match.pred_geometry,
        ax=ax,
        add_points=add_points,
        facecolor="none" if not fill else None,
        alpha=1,
        alpha=match.pred_score if use_score else 1,
        color=color,
    )

    plt.text(
        start_time,
        high_freq,
        f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
        va="top",
        ha="right",
        color=color,
        fontsize=fontsize,
    )
    if add_text:
        plt.text(
            start_time,
            high_freq,
            f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
            va="top",
            ha="right",
            color=color,
            fontsize=fontsize,
        )

    return ax

@ -182,7 +180,9 @@ def plot_false_negative_match(
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    duration: float = DEFAULT_DURATION,
    add_spectrogram: bool = True,
    add_points: bool = False,
    add_text: bool = True,
    fill: bool = False,
    spec_cmap: str = "gray",
    color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
@ -204,17 +204,18 @@ def plot_false_negative_match(
        recording=sound_event.recording,
    )

    ax = plot_clip(
        clip,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        spec_cmap=spec_cmap,
    )
    if add_spectrogram:
        ax = plot_clip(
            clip,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            figsize=figsize,
            ax=ax,
            audio_dir=audio_dir,
            spec_cmap=spec_cmap,
        )

    plot.plot_annotation(
    ax = plot.plot_annotation(
        match.sound_event_annotation,
        ax=ax,
        time_offset=0.001,
@ -225,15 +226,16 @@ def plot_false_negative_match(
        color=color,
    )

    plt.text(
        start_time,
        high_freq,
        f"False Negative \nClass: {match.gt_class} ",
        va="top",
        ha="right",
        color=color,
        fontsize=fontsize,
    )
    if add_text:
        plt.text(
            start_time,
            high_freq,
            f"False Negative \nClass: {match.gt_class} ",
            va="top",
            ha="right",
            color=color,
            fontsize=fontsize,
        )

    return ax

@ -246,7 +248,10 @@ def plot_true_positive_match(
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    duration: float = DEFAULT_DURATION,
    use_score: bool = True,
    add_spectrogram: bool = True,
    add_points: bool = False,
    add_text: bool = True,
    fill: bool = False,
    spec_cmap: str = "gray",
    color: str = DEFAULT_TRUE_POSITIVE_COLOR,
@ -270,17 +275,18 @@ def plot_true_positive_match(
        recording=sound_event.recording,
    )

    ax = plot_clip(
        clip,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        spec_cmap=spec_cmap,
    )
    if add_spectrogram:
        ax = plot_clip(
            clip,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            figsize=figsize,
            ax=ax,
            audio_dir=audio_dir,
            spec_cmap=spec_cmap,
        )

    plot.plot_annotation(
    ax = plot.plot_annotation(
        match.sound_event_annotation,
        ax=ax,
        time_offset=0.001,
@ -297,20 +303,21 @@ def plot_true_positive_match(
        ax=ax,
        add_points=add_points,
        facecolor="none" if not fill else None,
        alpha=1,
        alpha=match.pred_score if use_score else 1,
        color=color,
        linestyle=prediction_linestyle,
    )

    plt.text(
        start_time,
        high_freq,
        f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
        va="top",
        ha="right",
        color=color,
        fontsize=fontsize,
    )
    if add_text:
        plt.text(
            start_time,
            high_freq,
            f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
            va="top",
            ha="right",
            color=color,
            fontsize=fontsize,
        )

    return ax

@ -323,7 +330,10 @@ def plot_cross_trigger_match(
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    duration: float = DEFAULT_DURATION,
    use_score: bool = True,
    add_spectrogram: bool = True,
    add_points: bool = False,
    add_text: bool = True,
    fill: bool = False,
    spec_cmap: str = "gray",
    color: str = DEFAULT_CROSS_TRIGGER_COLOR,
@ -347,17 +357,18 @@ def plot_cross_trigger_match(
        recording=sound_event.recording,
    )

    ax = plot_clip(
        clip,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        figsize=figsize,
        ax=ax,
        audio_dir=audio_dir,
        spec_cmap=spec_cmap,
    )
    if add_spectrogram:
        ax = plot_clip(
            clip,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            figsize=figsize,
            ax=ax,
            audio_dir=audio_dir,
            spec_cmap=spec_cmap,
        )

    plot.plot_annotation(
    ax = plot.plot_annotation(
        match.sound_event_annotation,
        ax=ax,
        time_offset=0.001,
@ -369,24 +380,25 @@ def plot_cross_trigger_match(
        linestyle=annotation_linestyle,
    )

    plot.plot_geometry(
    ax = plot.plot_geometry(
        match.pred_geometry,
        ax=ax,
        add_points=add_points,
        facecolor="none" if not fill else None,
        alpha=1,
        alpha=match.pred_score if use_score else 1,
        color=color,
        linestyle=prediction_linestyle,
    )

    plt.text(
        start_time,
        high_freq,
        f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
        va="top",
        ha="right",
        color=color,
        fontsize=fontsize,
    )
    if add_text:
        plt.text(
            start_time,
            high_freq,
            f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
            va="top",
            ha="right",
            color=color,
            fontsize=fontsize,
        )

    return ax
@ -1,307 +1,25 @@
"""Main entry point for the BatDetect2 Postprocessing pipeline."""

from typing import List, Optional

import torch
from loguru import logger
from pydantic import Field
from soundevent import data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.postprocess.config import (
    PostprocessConfig,
    load_postprocess_config,
)
from batdetect2.postprocess.decoding import (
    DEFAULT_CLASSIFICATION_THRESHOLD,
    convert_raw_prediction_to_sound_event_prediction,
    convert_raw_predictions_to_clip_prediction,
    to_raw_predictions,
)
from batdetect2.postprocess.extraction import extract_prediction_tensor
from batdetect2.postprocess.nms import (
    NMS_KERNEL_SIZE,
    non_max_suppression,
from batdetect2.postprocess.nms import non_max_suppression
from batdetect2.postprocess.postprocessor import (
    Postprocessor,
    build_postprocessor,
)
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
    BatDetect2Prediction,
    DetectionsTensor,
    PostprocessorProtocol,
    RawPrediction,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol

__all__ = [
    "DEFAULT_CLASSIFICATION_THRESHOLD",
    "DEFAULT_DETECTION_THRESHOLD",
    "MAX_FREQ",
    "MIN_FREQ",
    "ModelOutput",
    "NMS_KERNEL_SIZE",
    "PostprocessConfig",
    "Postprocessor",
    "TOP_K_PER_SEC",
    "build_postprocessor",
    "convert_raw_predictions_to_clip_prediction",
    "to_raw_predictions",
    "load_postprocess_config",
    "non_max_suppression",
]

DEFAULT_DETECTION_THRESHOLD = 0.01


TOP_K_PER_SEC = 100


class PostprocessConfig(BaseConfig):
    """Configuration settings for the postprocessing pipeline.

    Defines tunable parameters that control how raw model outputs are
    converted into final detections.

    Attributes
    ----------
    nms_kernel_size : int, default=NMS_KERNEL_SIZE
        Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
        Used to suppress weaker detections near stronger peaks. Must be
        positive.
    detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
        Minimum confidence score from the detection heatmap required to
        consider a point as a potential detection. Must be >= 0.
    classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
        Minimum confidence score for a specific class prediction to be included
        in the decoded tags for a detection. Must be >= 0.
    top_k_per_sec : int, default=TOP_K_PER_SEC
        Desired maximum number of detections per second of audio. Used by
        `get_max_detections` to calculate an absolute limit based on clip
        duration before applying `extract_detections_from_array`. Must be
        positive.
    """

    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
    detection_threshold: float = Field(
        default=DEFAULT_DETECTION_THRESHOLD,
        ge=0,
    )
    classification_threshold: float = Field(
        default=DEFAULT_CLASSIFICATION_THRESHOLD,
        ge=0,
    )
    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)


def load_postprocess_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> PostprocessConfig:
    """Load the postprocessing configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `PostprocessConfig` schema, potentially extracting data from a nested
    field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        postprocessing configuration (e.g., "inference.postprocessing").
        If None, the entire file content is used.

    Returns
    -------
    PostprocessConfig
        The loaded and validated postprocessing configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded configuration data does not conform to the
        `PostprocessConfig` schema.
    KeyError, TypeError
        If `field` specifies an invalid path within the loaded data.
    """
    return load_config(path, schema=PostprocessConfig, field=field)


def build_postprocessor(
    preprocessor: PreprocessorProtocol,
    config: Optional[PostprocessConfig] = None,
) -> PostprocessorProtocol:
    """Factory function to build the standard postprocessor."""
    config = config or PostprocessConfig()
    logger.opt(lazy=True).debug(
        "Building postprocessor with config: \n{}",
        lambda: config.to_yaml_string(),
    )
    return Postprocessor(
        samplerate=preprocessor.output_samplerate,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        top_k_per_sec=config.top_k_per_sec,
        detection_threshold=config.detection_threshold,
    )


class Postprocessor(torch.nn.Module, PostprocessorProtocol):
    """Standard implementation of the postprocessing pipeline."""

    def __init__(
        self,
        samplerate: float,
        min_freq: float,
        max_freq: float,
        top_k_per_sec: int = 200,
        detection_threshold: float = 0.01,
    ):
        """Initialize the Postprocessor."""
        super().__init__()
        self.samplerate = samplerate
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.top_k_per_sec = top_k_per_sec
        self.detection_threshold = detection_threshold

    def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
        width = output.detection_probs.shape[-1]
        duration = width / self.samplerate
        max_detections = int(self.top_k_per_sec * duration)
        detections = extract_prediction_tensor(
            output,
            max_detections=max_detections,
            threshold=self.detection_threshold,
        )
        return [
            map_detection_to_clip(
                detection,
                start_time=0,
                end_time=duration,
                min_freq=self.min_freq,
                max_freq=self.max_freq,
            )
            for detection in detections
        ]
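The `top_k_per_sec` cap scales with clip length: the heatmap width divided by the output samplerate gives the clip duration in seconds. A worked example with assumed illustrative values:

    # Worked example of the max_detections computation above
    width = 256                            # frames in the detection heatmap
    samplerate = 1000.0                    # output frames per second
    duration = width / samplerate          # 0.256 s
    max_detections = int(100 * duration)   # top_k_per_sec=100 -> 25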
    def get_detections(
        self,
        output: ModelOutput,
        start_times: Optional[List[float]] = None,
    ) -> List[DetectionsTensor]:
        width = output.detection_probs.shape[-1]
        duration = width / self.samplerate
        max_detections = int(self.top_k_per_sec * duration)

        detections = extract_prediction_tensor(
            output,
            max_detections=max_detections,
            threshold=self.detection_threshold,
        )

        if start_times is None:
            return detections

        width = output.detection_probs.shape[-1]
        duration = width / self.samplerate
        return [
            map_detection_to_clip(
                detection,
                start_time=start_time,
                end_time=start_time + duration,
                min_freq=self.min_freq,
                max_freq=self.max_freq,
            )
            for detection, start_time in zip(detections, start_times)
        ]


def get_raw_predictions(
    output: ModelOutput,
    targets: TargetProtocol,
    postprocessor: PostprocessorProtocol,
    start_times: Optional[List[float]] = None,
) -> List[List[RawPrediction]]:
    """Extract intermediate RawPrediction objects for a batch."""
    detections = postprocessor.get_detections(output, start_times)
    return [
        to_raw_predictions(detection.numpy(), targets=targets)
        for detection in detections
    ]


def get_sound_event_predictions(
    output: ModelOutput,
    targets: TargetProtocol,
    postprocessor: PostprocessorProtocol,
    clips: List[data.Clip],
    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[List[BatDetect2Prediction]]:
    raw_predictions = get_raw_predictions(
        output,
        targets=targets,
        postprocessor=postprocessor,
        start_times=[clip.start_time for clip in clips],
    )
    return [
        [
            BatDetect2Prediction(
                raw=raw,
                sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
                    raw,
                    recording=clip.recording,
                    targets=targets,
                    classification_threshold=classification_threshold,
                ),
            )
            for raw in predictions
        ]
        for predictions, clip in zip(raw_predictions, clips)
    ]


def get_predictions(
    output: ModelOutput,
    clips: List[data.Clip],
    targets: TargetProtocol,
    postprocessor: PostprocessorProtocol,
    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.ClipPrediction]:
    """Perform the full postprocessing pipeline for a batch.

    Takes raw model output and corresponding clips, applies the entire
    configured chain (NMS, remapping, extraction, geometry recovery, class
    decoding), producing final `soundevent.data.ClipPrediction` objects.

    Parameters
    ----------
    output : ModelOutput
        Raw output from the neural network model for a batch.
    clips : List[data.Clip]
        List of `soundevent.data.Clip` objects corresponding to the batch.

    Returns
    -------
    List[data.ClipPrediction]
        List containing one `ClipPrediction` object for each input clip,
        populated with `SoundEventPrediction` objects.
    """
    raw_predictions = get_raw_predictions(
        output,
        targets=targets,
        postprocessor=postprocessor,
        start_times=[clip.start_time for clip in clips],
    )
    return [
        convert_raw_predictions_to_clip_prediction(
            prediction,
            clip,
            targets=targets,
            classification_threshold=classification_threshold,
        )
        for prediction, clip in zip(raw_predictions, clips)
    ]
src/batdetect2/postprocess/config.py (new file, 94 lines)
@ -0,0 +1,94 @@
from typing import Optional

from pydantic import Field
from soundevent import data

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE

__all__ = [
    "PostprocessConfig",
    "load_postprocess_config",
]

DEFAULT_DETECTION_THRESHOLD = 0.01


TOP_K_PER_SEC = 100


class PostprocessConfig(BaseConfig):
    """Configuration settings for the postprocessing pipeline.

    Defines tunable parameters that control how raw model outputs are
    converted into final detections.

    Attributes
    ----------
    nms_kernel_size : int, default=NMS_KERNEL_SIZE
        Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
        Used to suppress weaker detections near stronger peaks. Must be
        positive.
    detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
        Minimum confidence score from the detection heatmap required to
        consider a point as a potential detection. Must be >= 0.
    classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
        Minimum confidence score for a specific class prediction to be included
        in the decoded tags for a detection. Must be >= 0.
    top_k_per_sec : int, default=TOP_K_PER_SEC
        Desired maximum number of detections per second of audio. Used by
        `get_max_detections` to calculate an absolute limit based on clip
        duration before applying `extract_detections_from_array`. Must be
        positive.
    """

    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
    detection_threshold: float = Field(
        default=DEFAULT_DETECTION_THRESHOLD,
        ge=0,
    )
    classification_threshold: float = Field(
        default=DEFAULT_CLASSIFICATION_THRESHOLD,
        ge=0,
    )
    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)


def load_postprocess_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> PostprocessConfig:
    """Load the postprocessing configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `PostprocessConfig` schema, potentially extracting data from a nested
    field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        postprocessing configuration (e.g., "inference.postprocessing").
        If None, the entire file content is used.

    Returns
    -------
    PostprocessConfig
        The loaded and validated postprocessing configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded configuration data does not conform to the
        `PostprocessConfig` schema.
    KeyError, TypeError
        If `field` specifies an invalid path within the loaded data.
    """
    return load_config(path, schema=PostprocessConfig, field=field)
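A minimal sketch of exercising this module end to end, assuming a YAML file whose postprocessing settings live under a nested `postprocess` section; the file name and field name are assumptions:

    # Sketch: load a nested postprocessing section from a YAML config file.
    from batdetect2.postprocess.config import load_postprocess_config

    config = load_postprocess_config("config.yaml", field="postprocess")
    print(config.nms_kernel_size, config.detection_threshold)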
@ -6,7 +6,7 @@ import numpy as np
from soundevent import data

from batdetect2.typing.postprocess import (
    DetectionsArray,
    ClipDetectionsArray,
    RawPrediction,
)
from batdetect2.typing.targets import TargetProtocol
@ -28,7 +28,7 @@ decoding.


def to_raw_predictions(
    detections: DetectionsArray,
    detections: ClipDetectionsArray,
    targets: TargetProtocol,
) -> List[RawPrediction]:
    predictions = []
@ -15,32 +15,25 @@ precise time-frequency location of each detection. The final output aggregates
all extracted information into a structured `xarray.Dataset`.
"""

from typing import List, Optional, Tuple, Union
from typing import List, Optional

import torch

from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
from batdetect2.typing.postprocess import (
    DetectionsTensor,
    ModelOutput,
)
from batdetect2.typing.postprocess import ClipDetectionsTensor

__all__ = [
    "extract_prediction_tensor",
    "extract_detection_peaks",
]


def extract_prediction_tensor(
    output: ModelOutput,
def extract_detection_peaks(
    detection_heatmap: torch.Tensor,
    size_heatmap: torch.Tensor,
    feature_heatmap: torch.Tensor,
    classification_heatmap: torch.Tensor,
    max_detections: int = 200,
    threshold: Optional[float] = None,
    nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> List[DetectionsTensor]:
    detection_heatmap = non_max_suppression(
        output.detection_probs.detach(),
        kernel_size=nms_kernel_size,
    )

) -> List[ClipDetectionsTensor]:
    height = detection_heatmap.shape[-2]
    width = detection_heatmap.shape[-1]

@ -53,9 +46,9 @@ def extract_prediction_tensor(
    freqs = freqs.flatten().to(detection_heatmap.device)
    times = times.flatten().to(detection_heatmap.device)

    output_size_preds = output.size_preds.detach()
    output_features = output.features.detach()
    output_class_probs = output.class_probs.detach()
    output_size_preds = size_heatmap.detach()
    output_features = feature_heatmap.detach()
    output_class_probs = classification_heatmap.detach()

    predictions = []
    for idx, item in enumerate(detection_heatmap):
@ -65,23 +58,25 @@ def extract_prediction_tensor(
        detection_scores = item.take(indices)
        detection_freqs = freqs.take(indices)
        detection_times = times.take(indices)
        sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
        features = output_features[idx, :, detection_freqs, detection_times].T
        class_scores = output_class_probs[
            idx, :, detection_freqs, detection_times
        ].T

        if threshold is not None:
            mask = detection_scores >= threshold

            detection_scores = detection_scores[mask]
            sizes = sizes[mask]
            detection_times = detection_times[mask]
            detection_freqs = detection_freqs[mask]
            features = features[mask]
            class_scores = class_scores[mask]

        sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
        features = output_features[idx, :, detection_freqs, detection_times].T
        class_scores = output_class_probs[
            idx,
            :,
            detection_freqs,
            detection_times,
        ].T

        predictions.append(
            DetectionsTensor(
            ClipDetectionsTensor(
                scores=detection_scores,
                sizes=sizes,
                features=features,
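The core of `extract_detection_peaks` is flattening each suppressed heatmap and keeping the k strongest responses. A self-contained sketch of that step; the tensor shape and helper name are assumptions for illustration:

    import torch

    def topk_peaks(heatmap: torch.Tensor, k: int):
        """Return scores and (row, col) coordinates of the k largest cells."""
        width = heatmap.shape[-1]
        flat = heatmap.flatten()
        scores, indices = torch.topk(flat, k=min(k, flat.numel()))
        rows = indices // width   # frequency-bin index
        cols = indices % width    # time-frame index
        return scores, rows, cols

    scores, rows, cols = topk_peaks(torch.rand(128, 256), k=25)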
src/batdetect2/postprocess/postprocessor.py (new file, 100 lines)
@ -0,0 +1,100 @@
from typing import List, Optional, Tuple, Union

import torch
from loguru import logger

from batdetect2.postprocess.config import (
    PostprocessConfig,
)
from batdetect2.postprocess.extraction import extract_detection_peaks
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
    ClipDetectionsTensor,
    PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol

__all__ = [
    "build_postprocessor",
    "Postprocessor",
]


def build_postprocessor(
    preprocessor: PreprocessorProtocol,
    config: Optional[PostprocessConfig] = None,
) -> PostprocessorProtocol:
    """Factory function to build the standard postprocessor."""
    config = config or PostprocessConfig()
    logger.opt(lazy=True).debug(
        "Building postprocessor with config: \n{}",
        lambda: config.to_yaml_string(),
    )
    return Postprocessor(
        samplerate=preprocessor.output_samplerate,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        top_k_per_sec=config.top_k_per_sec,
        detection_threshold=config.detection_threshold,
    )


class Postprocessor(torch.nn.Module, PostprocessorProtocol):
    """Standard implementation of the postprocessing pipeline."""

    def __init__(
        self,
        samplerate: float,
        min_freq: float,
        max_freq: float,
        top_k_per_sec: int = 200,
        detection_threshold: float = 0.01,
        nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
    ):
        """Initialize the Postprocessor."""
        super().__init__()

        self.output_samplerate = samplerate
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.top_k_per_sec = top_k_per_sec
        self.detection_threshold = detection_threshold
        self.nms_kernel_size = nms_kernel_size

    def forward(
        self,
        output: ModelOutput,
        start_times: Optional[List[float]] = None,
    ) -> List[ClipDetectionsTensor]:
        detection_heatmap = non_max_suppression(
            output.detection_probs.detach(),
            kernel_size=self.nms_kernel_size,
        )

        width = output.detection_probs.shape[-1]
        duration = width / self.output_samplerate
        max_detections = int(self.top_k_per_sec * duration)
        detections = extract_detection_peaks(
            detection_heatmap,
            size_heatmap=output.size_preds,
            feature_heatmap=output.features,
            classification_heatmap=output.class_probs,
            max_detections=max_detections,
            threshold=self.detection_threshold,
        )

        if start_times is None:
            start_times = [0 for _ in range(len(detections))]

        return [
            map_detection_to_clip(
                detection,
                # offset each detection by its clip's start time
                start_time=start_time,
                end_time=start_time + duration,
                min_freq=self.min_freq,
                max_freq=self.max_freq,
            )
            for detection, start_time in zip(detections, start_times)
        ]
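`non_max_suppression` itself is not shown in this changeset, but heatmap NMS of this kind is commonly implemented with a max-pooling trick: a cell survives only if it equals the maximum of its neighborhood. A minimal sketch under that assumption (the function name and shapes are illustrative):

    import torch
    import torch.nn.functional as F

    def nms_sketch(heatmap: torch.Tensor, kernel_size: int = 9) -> torch.Tensor:
        """Zero out every cell that is not the maximum of its neighborhood."""
        pad = kernel_size // 2
        pooled = F.max_pool2d(heatmap, kernel_size, stride=1, padding=pad)
        return heatmap * (pooled == heatmap)

    suppressed = nms_sketch(torch.rand(1, 1, 128, 256))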
@ -20,7 +20,7 @@ import xarray as xr
from soundevent.arrays import Dimensions

from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing.postprocess import DetectionsTensor
from batdetect2.typing.postprocess import ClipDetectionsTensor

__all__ = [
    "features_to_xarray",
@ -31,15 +31,15 @@ __all__ = [


def map_detection_to_clip(
    detections: DetectionsTensor,
    detections: ClipDetectionsTensor,
    start_time: float,
    end_time: float,
    min_freq: float,
    max_freq: float,
) -> DetectionsTensor:
) -> ClipDetectionsTensor:
    duration = end_time - start_time
    bandwidth = max_freq - min_freq
    return DetectionsTensor(
    return ClipDetectionsTensor(
        scores=detections.scores,
        sizes=detections.sizes,
        features=detections.features,
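Numerically, `map_detection_to_clip` rescales grid-relative coordinates into absolute clip coordinates. Assuming times and frequencies are stored as fractions in [0, 1], a detection at relative time 0.5 and relative frequency 0.25 in a clip starting at 2.0 s with duration 0.256 s, min_freq 10 000 Hz, and max_freq 120 000 Hz lands at 2.0 + 0.5 * 0.256 = 2.128 s and 10 000 + 0.25 * 110 000 = 37 500 Hz.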
@ -1,176 +1,19 @@
"""Main entry point for the BatDetect2 Preprocessing subsystem.
"""Main entry point for the BatDetect2 preprocessing subsystem."""

This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline
for converting raw audio input (from files or data objects) into processed
spectrograms suitable for input to BatDetect2 models. This ensures consistent
data handling between model training and inference.

The preprocessing pipeline consists of two main stages, configured via nested
data structures:
1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial
   processing like resampling, duration adjustment, centering, and scaling.
   Configured via `AudioConfig`.
2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from
   the processed waveform using STFT, followed by frequency cropping, optional
   PCEN, amplitude scaling (dB, power, linear), optional denoising, optional
   resizing, and optional peak normalization. Configured via
   `SpectrogramConfig`.

This module provides the primary interface:

- `PreprocessingConfig`: A unified configuration object holding `AudioConfig`
  and `SpectrogramConfig`.
- `load_preprocessing_config`: Function to load the unified configuration.
- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline.
- `StandardPreprocessor`: The default implementation of the `Preprocessor`.
- `build_preprocessor`: A factory function to create a `StandardPreprocessor`
  instance from a `PreprocessingConfig`.

"""

from typing import Optional

import torch
from loguru import logger
from pydantic import Field
from soundevent.data import PathLike

from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
    DEFAULT_DURATION,
    SCALE_RAW_AUDIO,
    TARGET_SAMPLERATE_HZ,
    AudioConfig,
    ResampleConfig,
    build_audio_loader,
    build_audio_pipeline,
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.preprocess.config import (
    PreprocessingConfig,
    load_preprocessing_config,
)
from batdetect2.preprocess.spectrogram import (
    MAX_FREQ,
    MIN_FREQ,
    FrequencyConfig,
    PcenConfig,
    SpectrogramConfig,
    SpectrogramPipeline,
    STFTConfig,
    _spec_params_from_config,
    build_spectrogram_builder,
    build_spectrogram_pipeline,
)
from batdetect2.typing import PreprocessorProtocol
from batdetect2.preprocess.preprocessor import Preprocessor, build_preprocessor
from batdetect2.preprocess.spectrogram import MAX_FREQ, MIN_FREQ

__all__ = [
    "AudioConfig",
    "DEFAULT_DURATION",
    "FrequencyConfig",
    "MAX_FREQ",
    "MIN_FREQ",
    "PcenConfig",
    "PreprocessingConfig",
    "ResampleConfig",
    "SCALE_RAW_AUDIO",
    "STFTConfig",
    "SpectrogramConfig",
    "Preprocessor",
    "TARGET_SAMPLERATE_HZ",
    "build_audio_loader",
    "build_spectrogram_builder",
    "build_preprocessor",
    "load_preprocessing_config",
]


class PreprocessingConfig(BaseConfig):
    """Unified configuration for the audio preprocessing pipeline.

    Aggregates the configuration for both the initial audio processing stage
    and the subsequent spectrogram generation stage.

    Attributes
    ----------
    audio : AudioConfig
        Configuration settings for the audio loading and initial waveform
        processing steps (e.g., resampling, duration adjustment, scaling).
        Defaults to default `AudioConfig` settings if omitted.
    spectrogram : SpectrogramConfig
        Configuration settings for the spectrogram generation process
        (e.g., STFT parameters, frequency cropping, scaling, denoising,
        resizing). Defaults to default `SpectrogramConfig` settings if omitted.
    """

    audio: AudioConfig = Field(default_factory=AudioConfig)
    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)


def load_preprocessing_config(
    path: PathLike,
    field: Optional[str] = None,
) -> PreprocessingConfig:
    return load_config(path, schema=PreprocessingConfig, field=field)


class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
    """Standard implementation of the `Preprocessor` protocol."""

    input_samplerate: int
    output_samplerate: float

    max_freq: float
    min_freq: float

    def __init__(
        self,
        audio_pipeline: torch.nn.Module,
        spectrogram_pipeline: SpectrogramPipeline,
        input_samplerate: int,
        output_samplerate: float,
        max_freq: float,
        min_freq: float,
    ) -> None:
        super().__init__()
        self.audio_pipeline = audio_pipeline
        self.spectrogram_pipeline = spectrogram_pipeline

        self.max_freq = max_freq
        self.min_freq = min_freq

        self.input_samplerate = input_samplerate
        self.output_samplerate = output_samplerate

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        wav = self.audio_pipeline(wav)
        return self.spectrogram_pipeline(wav)


def compute_output_samplerate(config: PreprocessingConfig) -> float:
    samplerate = config.audio.samplerate
    _, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft)
    factor = config.spectrogram.size.resize_factor
    return samplerate * factor / hop_size


def build_preprocessor(
    config: Optional[PreprocessingConfig] = None,
) -> PreprocessorProtocol:
    """Factory function to build the standard preprocessor from configuration."""
    config = config or PreprocessingConfig()
    logger.opt(lazy=True).debug(
        "Building preprocessor with config: \n{}",
        lambda: config.to_yaml_string(),
    )

    samplerate = config.audio.samplerate

    min_freq = config.spectrogram.frequencies.min_freq
    max_freq = config.spectrogram.frequencies.max_freq

    output_samplerate = compute_output_samplerate(config)

    return StandardPreprocessor(
        audio_pipeline=build_audio_pipeline(config.audio),
        spectrogram_pipeline=build_spectrogram_pipeline(
            samplerate, config.spectrogram
        ),
        input_samplerate=samplerate,
        output_samplerate=output_samplerate,
        min_freq=min_freq,
        max_freq=max_freq,
    )
@ -1,307 +1,57 @@
"""Handles loading and initial preprocessing of audio waveforms."""
from typing import Annotated, Literal, Union

from typing import Annotated, List, Literal, Optional, Union

import numpy as np
import torch
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError

from batdetect2.configs import BaseConfig
from batdetect2.preprocess.common import CenterTensor, PeakNormalize
from batdetect2.typing import AudioLoader
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.core import BaseConfig, Registry
from batdetect2.preprocess.common import center_tensor, peak_normalize

__all__ = [
    "ResampleConfig",
    "AudioConfig",
    "SoundEventAudioLoader",
    "build_audio_loader",
    "load_file_audio",
    "load_recording_audio",
    "load_clip_audio",
    "resample_audio",
    "TARGET_SAMPLERATE_HZ",
    "SCALE_RAW_AUDIO",
    "DEFAULT_DURATION",
    "CenterAudioConfig",
    "ScaleAudioConfig",
    "FixDurationConfig",
    "build_audio_transform",
]

TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""

SCALE_RAW_AUDIO = False
"""Default setting for whether to perform peak normalization."""

DEFAULT_DURATION = None
"""Default setting for target audio duration in seconds."""


class ResampleConfig(BaseConfig):
    """Configuration for audio resampling.

    Attributes
    ----------
    samplerate : int, default=256000
        The target sample rate in Hz to resample the audio to. Must be > 0.
    method : str, default="poly"
        The resampling algorithm to use. Options:
        - "poly": Polyphase resampling using `scipy.signal.resample_poly`.
          Generally fast.
        - "fourier": Resampling via Fourier method using
          `scipy.signal.resample`. May handle non-integer
          resampling factors differently.
    """

    enabled: bool = True
    method: str = "poly"


class SoundEventAudioLoader:
    """Concrete implementation of the `AudioLoader`."""

    def __init__(
        self,
        samplerate: int = TARGET_SAMPLERATE_HZ,
        config: Optional[ResampleConfig] = None,
    ):
        self.samplerate = samplerate
        self.config = config or ResampleConfig()

    def load_file(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess audio directly from a file path."""
        return load_file_audio(
            path,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )

    def load_recording(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess the entire audio for a Recording object."""
        return load_recording_audio(
            recording,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )

    def load_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess the audio segment defined by a Clip object."""
        return load_clip_audio(
            clip,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )


def load_file_audio(
    path: data.PathLike,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess audio from a file path using specified config."""
    try:
        recording = data.Recording.from_file(path)
    except LibsndfileError as e:
        raise FileNotFoundError(
            f"Could not load the recording at path: {path}. Error: {e}"
        ) from e

    return load_recording_audio(
        recording,
        samplerate=samplerate,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )


def load_recording_audio(
    recording: data.Recording,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess the entire audio content of a recording using config."""
    clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    return load_clip_audio(
        clip,
        samplerate=samplerate,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )


def load_clip_audio(
    clip: data.Clip,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess a specific audio clip segment based on config."""
    try:
        wav = (
            audio.load_clip(clip, audio_dir=audio_dir)
            .sel(channel=0)
            .astype(dtype)
        )
    except LibsndfileError as e:
        raise FileNotFoundError(
            f"Could not load the recording at path: {clip.recording.path}. "
            f"Error: {e}"
        ) from e

    if not config or not config.enabled or samplerate is None:
        return wav.data.astype(dtype)

    sr = int(1 / wav.time.attrs["step"])
    return resample_audio(
        wav.data,
        sr=sr,
        samplerate=samplerate,
        method=config.method,
    )


def resample_audio(
    wav: np.ndarray,
    sr: int,
    samplerate: int = TARGET_SAMPLERATE_HZ,
    method: str = "poly",
) -> np.ndarray:
    """Resample an audio waveform DataArray to a target sample rate."""
    if sr == samplerate:
        return wav

    if method == "poly":
        return resample_audio_poly(
            wav,
            sr_orig=sr,
            sr_new=samplerate,
        )
    elif method == "fourier":
        return resample_audio_fourier(
            wav,
            sr_orig=sr,
            sr_new=samplerate,
        )
    else:
        raise NotImplementedError(
            f"Resampling method '{method}' not implemented"
        )


def resample_audio_poly(
    array: np.ndarray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    """Resample a numpy array using `scipy.signal.resample_poly`.

    This method is often preferred for signals when the ratio of new
    to old sample rates can be expressed as a rational number. It uses
    polyphase filtering.

    Parameters
    ----------
    array : np.ndarray
        The input array to resample.
    sr_orig : int
        The original sample rate in Hz.
    sr_new : int
        The target sample rate in Hz.
    axis : int, default=-1
        The axis of `array` along which to resample.

    Returns
    -------
    np.ndarray
        The array resampled to the target sample rate.

    Raises
    ------
    ValueError
        If sample rates are not positive.
    """
    gcd = np.gcd(sr_orig, sr_new)
    return resample_poly(
        array,
        sr_new // gcd,
        sr_orig // gcd,
        axis=axis,
    )
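Reducing the rate pair by their greatest common divisor gives `scipy` the smallest equivalent up/down factors. Worked example:

    import numpy as np

    sr_orig, sr_new = 44_100, 256_000
    gcd = np.gcd(sr_orig, sr_new)             # 100
    up, down = sr_new // gcd, sr_orig // gcd  # 2560, 441
    # resample_poly(array, 2560, 441) upsamples 44.1 kHz audio to 256 kHz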
def resample_audio_fourier(
    array: np.ndarray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    """Resample a numpy array using `scipy.signal.resample`.

    This method uses FFTs to resample the signal.

    Parameters
    ----------
    array : np.ndarray
        The input array to resample.
    sr_orig : int
        The original sample rate in Hz.
    sr_new : int
        The target sample rate in Hz.
    axis : int, default=-1
        The axis of `array` along which to resample.

    Returns
    -------
    np.ndarray
        The array resampled to the target sample rate along `axis`.

    Raises
    ------
    ValueError
        If the computed number of output samples is negative.
    """
    ratio = sr_new / sr_orig
    return resample(  # type: ignore
        array,
        int(array.shape[axis] * ratio),
        axis=axis,
    )
audio_transforms: Registry[torch.nn.Module, [int]] = Registry(
    "audio_transform"
)


class CenterAudioConfig(BaseConfig):
    name: Literal["center_audio"] = "center_audio"


class CenterAudio(torch.nn.Module):
    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        return center_tensor(wav)

    @classmethod
    def from_config(cls, config: CenterAudioConfig, samplerate: int):
        return cls()


audio_transforms.register(CenterAudioConfig, CenterAudio)


class ScaleAudioConfig(BaseConfig):
    name: Literal["scale_audio"] = "scale_audio"


class ScaleAudio(torch.nn.Module):
    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        return peak_normalize(wav)

    @classmethod
    def from_config(cls, config: ScaleAudioConfig, samplerate: int):
        return cls()


audio_transforms.register(ScaleAudioConfig, ScaleAudio)


class FixDurationConfig(BaseConfig):
    name: Literal["fix_duration"] = "fix_duration"
    duration: float = 0.5
@ -325,6 +75,12 @@ class FixDuration(torch.nn.Module):

        return torch.nn.functional.pad(wav, (0, self.length - length))

    @classmethod
    def from_config(cls, config: FixDurationConfig, samplerate: int):
        return cls(samplerate=samplerate, duration=config.duration)


audio_transforms.register(FixDurationConfig, FixDuration)

AudioTransform = Annotated[
    Union[
@ -336,47 +92,8 @@ AudioTransform = Annotated[
]


class AudioConfig(BaseConfig):
    """Configuration for loading and initial audio preprocessing."""

    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
    transforms: List[AudioTransform] = Field(default_factory=list)


def build_audio_loader(
    config: Optional[AudioConfig] = None,
) -> AudioLoader:
    """Factory function to create an AudioLoader based on configuration."""
    config = config or AudioConfig()
    return SoundEventAudioLoader(
        samplerate=config.samplerate,
        config=config.resample,
    )


def build_audio_transform_step(
def build_audio_transform(
    config: AudioTransform,
    samplerate: int,
    samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module:
    if config.name == "fix_duration":
        return FixDuration(samplerate=samplerate, duration=config.duration)

    if config.name == "scale_audio":
        return PeakNormalize()

    if config.name == "center_audio":
        return CenterTensor()

    raise NotImplementedError(
        f"Audio preprocessing step {config.name} not implemented"
    )


def build_audio_pipeline(config: AudioConfig) -> torch.nn.Module:
    return torch.nn.Sequential(
        *[
            build_audio_transform_step(step, samplerate=config.samplerate)
            for step in config.transforms
        ]
    )
    return audio_transforms.build(config, samplerate)
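The `Registry` pattern above decouples config schemas from the modules they build: each transform registers a (ConfigClass, ModuleClass) pair, and `build_audio_transform` dispatches on the config's `name` literal rather than on a hard-coded if-chain. A sketch of extending it under the same conventions; the gain transform is hypothetical and `BaseConfig` / `audio_transforms` are assumed to be the ones defined above:

    import torch
    from typing import Literal

    class GainAudioConfig(BaseConfig):
        name: Literal["gain_audio"] = "gain_audio"
        factor: float = 2.0

    class GainAudio(torch.nn.Module):
        def __init__(self, factor: float):
            super().__init__()
            self.factor = factor

        def forward(self, wav: torch.Tensor) -> torch.Tensor:
            # Scale the waveform by a constant gain factor.
            return wav * self.factor

        @classmethod
        def from_config(cls, config: GainAudioConfig, samplerate: int):
            return cls(factor=config.factor)

    audio_transforms.register(GainAudioConfig, GainAudio)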
@ -1,24 +1,22 @@
import torch

__all__ = [
    "CenterTensor",
    "PeakNormalize",
    "center_tensor",
    "peak_normalize",
]


class CenterTensor(torch.nn.Module):
    def forward(self, wav: torch.Tensor):
        return wav - wav.mean()
def center_tensor(tensor: torch.Tensor) -> torch.Tensor:
    return tensor - tensor.mean()


class PeakNormalize(torch.nn.Module):
    def forward(self, wav: torch.Tensor):
        max_value = wav.abs().max()
def peak_normalize(tensor: torch.Tensor) -> torch.Tensor:
    # peak = largest absolute sample; the guard below keeps silent
    # (all-zero) input unchanged instead of dividing by zero
    max_value = tensor.abs().max()

        denominator = torch.where(
            max_value == 0,
            torch.tensor(1.0, device=wav.device, dtype=wav.dtype),
            max_value,
        )
    denominator = torch.where(
        max_value == 0,
        torch.tensor(1.0, device=tensor.device, dtype=tensor.dtype),
        max_value,
    )

        return wav / denominator
    return tensor / denominator
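For example, `peak_normalize(torch.tensor([0.5, -2.0, 1.0]))` divides by the peak 2.0 and yields `[0.25, -1.0, 0.5]`, while an all-zero tensor passes through untouched thanks to the `torch.where` guard.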
src/batdetect2/preprocess/config.py (new file, 62 lines)
@ -0,0 +1,62 @@
from typing import List, Optional

from pydantic import Field
from soundevent.data import PathLike

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import AudioTransform
from batdetect2.preprocess.spectrogram import (
    FrequencyConfig,
    PcenConfig,
    ResizeConfig,
    SpectralMeanSubstractionConfig,
    SpectrogramTransform,
    STFTConfig,
)

__all__ = [
    "load_preprocessing_config",
    "AudioTransform",
    "PreprocessingConfig",
]


class PreprocessingConfig(BaseConfig):
    """Unified configuration for the audio preprocessing pipeline.

    Aggregates the configuration for both the initial audio processing stage
    and the subsequent spectrogram generation stage.

    Attributes
    ----------
    audio_transforms : List[AudioTransform]
        Waveform transforms applied in order before the STFT
        (e.g., centering, scaling, duration fixing). Empty by default.
    spectrogram_transforms : List[SpectrogramTransform]
        Transforms applied to the spectrogram after frequency cropping
        (PCEN followed by spectral mean subtraction by default).
    stft : STFTConfig
        STFT window duration, overlap, and window function settings.
    frequencies : FrequencyConfig
        Frequency range to keep when cropping the spectrogram.
    size : ResizeConfig
        Output height and time resize factor for the final spectrogram.
    """

    audio_transforms: List[AudioTransform] = Field(default_factory=list)

    spectrogram_transforms: List[SpectrogramTransform] = Field(
        default_factory=lambda: [
            PcenConfig(),
            SpectralMeanSubstractionConfig(),
        ]
    )

    stft: STFTConfig = Field(default_factory=STFTConfig)

    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)

    size: ResizeConfig = Field(default_factory=ResizeConfig)


def load_preprocessing_config(
    path: PathLike,
    field: Optional[str] = None,
) -> PreprocessingConfig:
    return load_config(path, schema=PreprocessingConfig, field=field)
src/batdetect2/preprocess/preprocessor.py (new file, 114 lines)
@ -0,0 +1,114 @@
from typing import Optional

import torch
from loguru import logger

from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.preprocess.audio import build_audio_transform
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.spectrogram import (
    _spec_params_from_config,
    build_spectrogram_builder,
    build_spectrogram_crop,
    build_spectrogram_resizer,
    build_spectrogram_transform,
)
from batdetect2.typing import PreprocessorProtocol

__all__ = [
    "Preprocessor",
    "build_preprocessor",
]


class Preprocessor(torch.nn.Module, PreprocessorProtocol):
    """Standard implementation of the `Preprocessor` protocol."""

    input_samplerate: int
    output_samplerate: float

    max_freq: float
    min_freq: float

    def __init__(
        self,
        config: PreprocessingConfig,
        input_samplerate: int,
    ) -> None:
        super().__init__()

        self.audio_transforms = torch.nn.Sequential(
            *[
                build_audio_transform(step, samplerate=input_samplerate)
                for step in config.audio_transforms
            ]
        )

        self.spectrogram_transforms = torch.nn.Sequential(
            *[
                build_spectrogram_transform(step, samplerate=input_samplerate)
                for step in config.spectrogram_transforms
            ]
        )

        self.spectrogram_builder = build_spectrogram_builder(
            config.stft,
            samplerate=input_samplerate,
        )

        self.spectrogram_crop = build_spectrogram_crop(
            config.frequencies,
            stft=config.stft,
            samplerate=input_samplerate,
        )

        self.spectrogram_resizer = build_spectrogram_resizer(config.size)

        self.min_freq = config.frequencies.min_freq
        self.max_freq = config.frequencies.max_freq

        self.input_samplerate = input_samplerate
        self.output_samplerate = compute_output_samplerate(
            config,
            input_samplerate=input_samplerate,
        )

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        wav = self.audio_transforms(wav)
        spec = self.spectrogram_builder(wav)
        return self.process_spectrogram(spec)

    def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
        return self.spectrogram_builder(wav)

    def process_audio(self, wav: torch.Tensor) -> torch.Tensor:
        return self(wav)

    def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
        spec = self.spectrogram_crop(spec)
        spec = self.spectrogram_transforms(spec)
        return self.spectrogram_resizer(spec)


def compute_output_samplerate(
    config: PreprocessingConfig,
    input_samplerate: int = TARGET_SAMPLERATE_HZ,
) -> float:
    _, hop_size = _spec_params_from_config(
        config.stft, samplerate=input_samplerate
    )
    factor = config.size.resize_factor
    return input_samplerate * factor / hop_size


def build_preprocessor(
    config: Optional[PreprocessingConfig] = None,
    input_samplerate: int = TARGET_SAMPLERATE_HZ,
) -> PreprocessorProtocol:
    """Factory function to build the standard preprocessor from configuration."""
    config = config or PreprocessingConfig()
    logger.opt(lazy=True).debug(
        "Building preprocessor with config: \n{}",
        lambda: config.to_yaml_string(),
    )
    return Preprocessor(config=config, input_samplerate=input_samplerate)
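`compute_output_samplerate` gives the frame rate of the final spectrogram, i.e. the effective samplerate the postprocessor sees. Worked example with the defaults used throughout this changeset:

    samplerate = 256_000                   # Hz, audio input
    n_fft = int(samplerate * 0.002)        # 2 ms window -> 512 samples
    hop = int(n_fft * (1 - 0.75))          # 75% overlap -> 128 samples
    output_rate = samplerate * 0.5 / hop   # resize_factor 0.5 -> 1000.0 frames/s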
@ -1,31 +1,21 @@
|
||||
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
||||
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
from typing import Annotated, Callable, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess.common import PeakNormalize
|
||||
from batdetect2.audio import TARGET_SAMPLERATE_HZ
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.preprocess.common import peak_normalize
|
||||
|
||||
__all__ = [
|
||||
"STFTConfig",
|
||||
"FrequencyConfig",
|
||||
"PcenConfig",
|
||||
"SpectrogramConfig",
|
||||
"build_spectrogram_transform",
|
||||
"build_spectrogram_builder",
|
||||
"MIN_FREQ",
|
||||
"MAX_FREQ",
|
||||
]
|
||||
|
||||
|
||||
@ -60,6 +50,20 @@ class STFTConfig(BaseConfig):
|
||||
window_fn: str = "hann"
|
||||
|
||||
|
||||
def build_spectrogram_builder(
|
||||
config: STFTConfig,
|
||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
) -> torch.nn.Module:
|
||||
n_fft, hop_length = _spec_params_from_config(config, samplerate=samplerate)
|
||||
return torchaudio.transforms.Spectrogram(
|
||||
n_fft=n_fft,
|
||||
hop_length=hop_length,
|
||||
window_fn=get_spectrogram_window(config.window_fn),
|
||||
center=True,
|
||||
power=1,
|
||||
)
|
||||
|
||||
|
||||
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
|
||||
if name == "hann":
|
||||
return torch.hann_window
|
||||
@ -81,24 +85,31 @@ def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
|
||||
)
|
||||
|
||||
|
||||
def _spec_params_from_config(samplerate: int, conf: STFTConfig):
|
||||
n_fft = int(samplerate * conf.window_duration)
|
||||
hop_length = int(n_fft * (1 - conf.window_overlap))
|
||||
def _spec_params_from_config(
|
||||
config: STFTConfig,
|
||||
samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
):
|
||||
n_fft = int(samplerate * config.window_duration)
|
||||
hop_length = int(n_fft * (1 - config.window_overlap))
|
||||
return n_fft, hop_length
|
||||

def build_spectrogram_builder(
    samplerate: int,
    conf: STFTConfig,
) -> torch.nn.Module:
    n_fft, hop_length = _spec_params_from_config(samplerate, conf)
    return torchaudio.transforms.Spectrogram(
        n_fft=n_fft,
        hop_length=hop_length,
        window_fn=get_spectrogram_window(conf.window_fn),
        center=True,
        power=1,
    )
def _frequency_to_index(
    freq: float,
    n_fft: int,
    samplerate: int = TARGET_SAMPLERATE_HZ,
) -> Optional[int]:
    alpha = freq * 2 / samplerate
    height = np.floor(n_fft / 2) + 1
    index = int(np.floor(alpha * height))

    if index <= 0:
        return None

    if index >= height:
        return None

    return index
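A worked pass through `_frequency_to_index` with those same values (illustrative only):

# freq = 10_000 Hz, n_fft = 512, samplerate = 256_000 Hz
# alpha  = 10_000 * 2 / 256_000   -> 0.078125
# height = floor(512 / 2) + 1     -> 257 frequency bins
# index  = floor(0.078125 * 257)  -> 20, so the crop starts at bin 20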

class FrequencyConfig(BaseConfig):
@ -114,36 +125,36 @@ class FrequencyConfig(BaseConfig):
    Frequencies below this value will be cropped. Must be >= 0.
    """

    max_freq: int = Field(default=120_000, ge=0)
    min_freq: int = Field(default=10_000, ge=0)
    max_freq: int = Field(default=MAX_FREQ, ge=0)
    min_freq: int = Field(default=MIN_FREQ, ge=0)


def _frequency_to_index(
    freq: float,
    samplerate: int,
    n_fft: int,
) -> Optional[int]:
    alpha = freq * 2 / samplerate
    height = np.floor(n_fft / 2) + 1
    index = int(np.floor(alpha * height))

    if index <= 0:
        return None

    if index >= height:
        return None

    return index


class FrequencyClip(torch.nn.Module):
class FrequencyCrop(torch.nn.Module):
    def __init__(
        self,
        low_index: Optional[int] = None,
        high_index: Optional[int] = None,
        samplerate: int,
        n_fft: int,
        min_freq: Optional[int] = None,
        max_freq: Optional[int] = None,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.samplerate = samplerate
        self.min_freq = min_freq
        self.max_freq = max_freq

        low_index = None
        if min_freq is not None:
            low_index = _frequency_to_index(
                min_freq, n_fft=self.n_fft, samplerate=self.samplerate
            )
        self.low_index = low_index

        high_index = None
        if max_freq is not None:
            high_index = _frequency_to_index(
                max_freq, n_fft=self.n_fft, samplerate=self.samplerate
            )
        self.high_index = high_index

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
@ -164,6 +175,62 @@ class FrequencyClip(torch.nn.Module):
    )

def build_spectrogram_crop(
    config: FrequencyConfig,
    stft: Optional[STFTConfig] = None,
    samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module:
    stft = stft or STFTConfig()
    n_fft, _ = _spec_params_from_config(stft, samplerate=samplerate)
    return FrequencyCrop(
        samplerate=samplerate,
        n_fft=n_fft,
        min_freq=config.min_freq,
        max_freq=config.max_freq,
    )


class ResizeConfig(BaseConfig):
    name: Literal["resize_spec"] = "resize_spec"
    height: int = 128
    resize_factor: float = 0.5


class ResizeSpec(torch.nn.Module):
    def __init__(self, height: int, time_factor: float):
        super().__init__()
        self.height = height
        self.time_factor = time_factor

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        current_length = spec.shape[-1]
        target_length = int(self.time_factor * current_length)

        original_ndim = spec.ndim
        while spec.ndim < 4:
            spec = spec.unsqueeze(0)

        resized = torch.nn.functional.interpolate(
            spec,
            size=(self.height, target_length),
            mode="bilinear",
        )

        while resized.ndim != original_ndim:
            resized = resized.squeeze(0)

        return resized


def build_spectrogram_resizer(config: ResizeConfig) -> torch.nn.Module:
    return ResizeSpec(height=config.height, time_factor=config.resize_factor)
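A shape sketch for `ResizeSpec` (tensor sizes are illustrative):

spec = torch.rand(257, 2000)                       # (freq_bins, time_frames)
resizer = ResizeSpec(height=128, time_factor=0.5)
out = resizer(spec)                                # torch.Size([128, 1000])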

spectrogram_transforms: Registry[torch.nn.Module, [int]] = Registry(
    "spectrogram_transform"
)

class PcenConfig(BaseConfig):
    """Configuration for Per-Channel Energy Normalization (PCEN)."""

@ -182,7 +249,7 @@ class PCEN(torch.nn.Module):
        bias: float = 2.0,
        power: float = 0.5,
        eps: float = 1e-6,
        dtype=torch.float64,
        dtype=torch.float32,
    ):
        super().__init__()
        self.smoothing_constant = smoothing_constant
@ -218,6 +285,19 @@ class PCEN(torch.nn.Module):
            * torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
        ).to(spec.dtype)

    @classmethod
    def from_config(cls, config: PcenConfig, samplerate: int):
        smooth = _compute_smoothing_constant(samplerate, config.time_constant)
        return cls(
            smoothing_constant=smooth,
            gain=config.gain,
            bias=config.bias,
            power=config.power,
        )


spectrogram_transforms.register(PcenConfig, PCEN)
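Each `register` call pairs a config class with the module that consumes it, so construction becomes a dispatch on config type. A hedged sketch, assuming `Registry.build` looks up the registered class and invokes its `from_config`:

config = PcenConfig()  # time_constant, gain, bias, power
pcen = spectrogram_transforms.build(config, 256_000)
# under that assumption, equivalent to PCEN.from_config(config, 256_000)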

def _compute_smoothing_constant(
    samplerate: int,
@ -241,16 +321,26 @@ class ToPower(torch.nn.Module):
        return spec**2


def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module:
    if conf.scale == "db":
        return torchaudio.transforms.AmplitudeToDB()
_scalers = {
    "db": torchaudio.transforms.AmplitudeToDB,
    "power": ToPower,
}

    if conf.scale == "power":
        return ToPower()

    raise NotImplementedError(
        f"Amplitude scaling {conf.scale} not implemented"
    )
class ScaleAmplitude(torch.nn.Module):
    def __init__(self, scale: Literal["power", "db"]):
        super().__init__()
        self.scale = scale
        self.scaler = _scalers[scale]()

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        return self.scaler(spec)

    @classmethod
    def from_config(cls, config: ScaleAmplitudeConfig, samplerate: int):
        return cls(scale=config.scale)


spectrogram_transforms.register(ScaleAmplitudeConfig, ScaleAmplitude)

class SpectralMeanSubstractionConfig(BaseConfig):
@ -262,43 +352,36 @@ class SpectralMeanSubstraction(torch.nn.Module):
        mean = spec.mean(-1, keepdim=True)
        return (spec - mean).clamp(min=0)


class ResizeConfig(BaseConfig):
    name: Literal["resize_spec"] = "resize_spec"
    height: int = 128
    resize_factor: float = 0.5
    @classmethod
    def from_config(
        cls,
        config: SpectralMeanSubstractionConfig,
        samplerate: int,
    ):
        return cls()


class ResizeSpec(torch.nn.Module):
    def __init__(self, height: int, time_factor: float):
        super().__init__()
        self.height = height
        self.time_factor = time_factor

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        current_length = spec.shape[-1]
        target_length = int(self.time_factor * current_length)

        original_ndim = spec.ndim
        while spec.ndim < 4:
            spec = spec.unsqueeze(0)

        resized = torch.nn.functional.interpolate(
            spec,
            size=(self.height, target_length),
            mode="bilinear",
        )

        while resized.ndim != original_ndim:
            resized = resized.squeeze(0)

        return resized
spectrogram_transforms.register(
    SpectralMeanSubstractionConfig,
    SpectralMeanSubstraction,
)


class PeakNormalizeConfig(BaseConfig):
    name: Literal["peak_normalize"] = "peak_normalize"


class PeakNormalize(torch.nn.Module):
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        return peak_normalize(spec)

    @classmethod
    def from_config(cls, config: PeakNormalizeConfig, samplerate: int):
        return cls()


spectrogram_transforms.register(PeakNormalizeConfig, PeakNormalize)
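A behavior sketch for `SpectralMeanSubstraction` with toy values:

spec = torch.tensor([[1.0, 3.0], [2.0, 2.0]])
out = SpectralMeanSubstraction()(spec)
# per-row means are [2.0, 2.0], so out == [[0.0, 1.0], [0.0, 0.0]]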

SpectrogramTransform = Annotated[
    Union[
        PcenConfig,
@ -310,114 +393,8 @@ SpectrogramTransform = Annotated[
]


class SpectrogramConfig(BaseConfig):
    stft: STFTConfig = Field(default_factory=STFTConfig)
    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
    size: ResizeConfig = Field(default_factory=ResizeConfig)
    transforms: Sequence[SpectrogramTransform] = Field(
        default_factory=lambda: [
            PcenConfig(),
            SpectralMeanSubstractionConfig(),
        ]
    )


def _build_spectrogram_transform_step(
    step: SpectrogramTransform,
    samplerate: int,
) -> torch.nn.Module:
    if step.name == "pcen":
        return PCEN(
            smoothing_constant=_compute_smoothing_constant(
                samplerate=samplerate,
                time_constant=step.time_constant,
            ),
            gain=step.gain,
            bias=step.bias,
            power=step.power,
        )

    if step.name == "scale_amplitude":
        return _build_amplitude_scaler(step)

    if step.name == "spectral_mean_substraction":
        return SpectralMeanSubstraction()

    if step.name == "peak_normalize":
        return PeakNormalize()

    raise NotImplementedError(
        f"Spectrogram preprocessing step {step.name} not implemented"
    )


def build_spectrogram_transform(
    config: SpectrogramTransform,
    samplerate: int,
    conf: SpectrogramConfig,
) -> torch.nn.Module:
    return torch.nn.Sequential(
        *[
            _build_spectrogram_transform_step(step, samplerate=samplerate)
            for step in conf.transforms
        ]
    )


class SpectrogramPipeline(torch.nn.Module):
    def __init__(
        self,
        spec_builder: torch.nn.Module,
        freq_cutter: torch.nn.Module,
        transforms: torch.nn.Module,
        resizer: torch.nn.Module,
    ):
        super().__init__()
        self.spec_builder = spec_builder
        self.freq_cutter = freq_cutter
        self.transforms = transforms
        self.resizer = resizer

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        spec = self.spec_builder(wav)
        spec = self.freq_cutter(spec)
        spec = self.transforms(spec)
        return self.resizer(spec)

    def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
        return self.spec_builder(wav)

    def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor:
        return self.freq_cutter(spec)

    def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
        return self.transforms(spec)

    def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
        return self.resizer(spec)


def build_spectrogram_pipeline(
    samplerate: int,
    conf: SpectrogramConfig,
) -> SpectrogramPipeline:
    spec_builder = build_spectrogram_builder(samplerate, conf.stft)
    n_fft, _ = _spec_params_from_config(samplerate, conf.stft)
    cutter = FrequencyClip(
        low_index=_frequency_to_index(
            conf.frequencies.min_freq, samplerate, n_fft
        ),
        high_index=_frequency_to_index(
            conf.frequencies.max_freq, samplerate, n_fft
        ),
    )
    transforms = build_spectrogram_transform(samplerate, conf)
    resizer = ResizeSpec(
        height=conf.size.height,
        time_factor=conf.size.resize_factor,
    )
    return SpectrogramPipeline(
        spec_builder=spec_builder,
        freq_cutter=cutter,
        transforms=transforms,
        resizer=resizer,
    )
    return spectrogram_transforms.build(config, samplerate)
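With this hunk `build_spectrogram_transform` collapses to a registry lookup; composing several configured steps might look like this (a sketch using names from this diff):

steps = [PcenConfig(), SpectralMeanSubstractionConfig()]
pipeline = torch.nn.Sequential(
    *[build_spectrogram_transform(step, samplerate=256_000) for step in steps]
)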
@ -1,17 +1,6 @@
"""BatDetect2 Target Definition system."""

from collections import Counter
from typing import Iterable, List, Optional, Tuple

from loguru import logger
from pydantic import Field, field_validator
from soundevent import data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.conditions import build_sound_event_condition
from batdetect2.targets.classes import (
    DEFAULT_CLASSES,
    DEFAULT_DETECTION_CLASS,
    SoundEventDecoder,
    SoundEventEncoder,
    TargetClassConfig,
@ -19,23 +8,29 @@ from batdetect2.targets.classes import (
    build_sound_event_encoder,
    get_class_names_from_config,
)
from batdetect2.targets.config import TargetConfig, load_target_config
from batdetect2.targets.rois import (
    AnchorBBoxMapperConfig,
    ROIMapperConfig,
    ROITargetMapper,
    build_roi_mapper,
)
from batdetect2.targets.targets import (
    Targets,
    build_targets,
    iterate_encoded_sound_events,
    load_targets,
)
from batdetect2.targets.terms import (
    call_type,
    data_source,
    generic_class,
    individual,
)
from batdetect2.typing.targets import Position, Size, TargetProtocol

__all__ = [
    "AnchorBBoxMapperConfig",
    "DEFAULT_TARGET_CONFIG",
    "ROIMapperConfig",
    "ROITargetMapper",
    "SoundEventDecoder",
    "SoundEventEncoder",
@ -45,365 +40,13 @@ __all__ = [
    "build_roi_mapper",
    "build_sound_event_decoder",
    "build_sound_event_encoder",
    "build_targets",
    "call_type",
    "data_source",
    "generic_class",
    "get_class_names_from_config",
    "individual",
    "iterate_encoded_sound_events",
    "load_target_config",
    "load_targets",
]

class TargetConfig(BaseConfig):
    detection_target: TargetClassConfig = Field(
        default=DEFAULT_DETECTION_CLASS
    )

    classification_targets: List[TargetClassConfig] = Field(
        default_factory=lambda: DEFAULT_CLASSES
    )

    roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)

    @field_validator("classification_targets")
    def check_unique_class_names(cls, v: List[TargetClassConfig]):
        """Ensure all defined class names are unique."""
        names = [c.name for c in v]

        if len(names) != len(set(names)):
            name_counts = Counter(names)
            duplicates = [
                name for name, count in name_counts.items() if count > 1
            ]
            raise ValueError(
                "Class names must be unique. Found duplicates: "
                f"{', '.join(duplicates)}"
            )
        return v


def load_target_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> TargetConfig:
    """Load the unified target configuration from a file.

    Reads a configuration file (typically YAML) and validates it against the
    `TargetConfig` schema, potentially extracting data from a nested field.

    Parameters
    ----------
    path : data.PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        target configuration. If None, the entire file content is used.

    Returns
    -------
    TargetConfig
        The loaded and validated unified target configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded configuration data does not conform to the
        `TargetConfig` schema (including validation within nested configs
        like `ClassesConfig`).
    KeyError, TypeError
        If `field` specifies an invalid path within the loaded data.
    """
    return load_config(path=path, schema=TargetConfig, field=field)

class Targets(TargetProtocol):
    """Encapsulates the complete configured target definition pipeline.

    This class implements the `TargetProtocol`, holding the configured
    functions for filtering, transforming, encoding (tags to class name),
    decoding (class name to tags), and mapping ROIs (geometry to position/size
    and back). It provides a high-level interface to apply these steps and
    access relevant metadata like class names and dimension names.

    Instances are typically created using the `build_targets` factory function
    or the `load_targets` convenience loader.

    Attributes
    ----------
    class_names : List[str]
        An ordered list of the unique names of the specific target classes
        defined in the configuration.
    generic_class_tags : List[data.Tag]
        A list of `soundevent.data.Tag` objects representing the configured
        generic class category (used when no specific class matches).
    dimension_names : List[str]
        The names of the size dimensions handled by the ROI mapper
        (e.g., ['width', 'height']).
    """

    class_names: List[str]
    detection_class_tags: List[data.Tag]
    dimension_names: List[str]
    detection_class_name: str

    def __init__(self, config: TargetConfig):
        """Initialize the Targets object."""
        self.config = config

        self._filter_fn = build_sound_event_condition(
            config.detection_target.match_if
        )
        self._encode_fn = build_sound_event_encoder(
            config.classification_targets
        )
        self._decode_fn = build_sound_event_decoder(
            config.classification_targets
        )

        self._roi_mapper = build_roi_mapper(config.roi)

        self.dimension_names = self._roi_mapper.dimension_names

        self.class_names = get_class_names_from_config(
            config.classification_targets
        )

        self.detection_class_name = config.detection_target.name
        self.detection_class_tags = config.detection_target.assign_tags

        self._roi_mapper_overrides = {
            class_config.name: build_roi_mapper(class_config.roi)
            for class_config in config.classification_targets
            if class_config.roi is not None
        }

        for class_name in self._roi_mapper_overrides:
            if class_name not in self.class_names:
                # TODO: improve this warning
                logger.warning(
                    "The ROI mapper overrides contain a class ({class_name}) "
                    "not present in the class names.",
                    class_name=class_name,
                )

    def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
        """Apply the configured filter to a sound event annotation.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation to filter.

        Returns
        -------
        bool
            True if the annotation should be kept (passes the filter),
            False otherwise. If no filter was configured, always returns True.
        """
        return self._filter_fn(sound_event)

    def encode_class(
        self, sound_event: data.SoundEventAnnotation
    ) -> Optional[str]:
        """Encode a sound event annotation to its target class name.

        Applies the configured class definition rules (including priority)
        to determine the specific class name for the annotation.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation to encode. Note: This should typically be called
            *after* applying any transformations via the `transform` method.

        Returns
        -------
        str or None
            The name of the matched target class, or None if the annotation
            does not match any specific class rule (i.e., it belongs to the
            generic category).
        """
        return self._encode_fn(sound_event)

    def decode_class(self, class_label: str) -> List[data.Tag]:
        """Decode a predicted class name back into representative tags.

        Uses the configured mapping (based on `TargetClass.output_tags` or
        `TargetClass.tags`) to convert a class name string into a list of
        `soundevent.data.Tag` objects.

        Parameters
        ----------
        class_label : str
            The class name to decode.

        Returns
        -------
        List[data.Tag]
            The list of tags corresponding to the input class name.
        """
        return self._decode_fn(class_label)

    def encode_roi(
        self, sound_event: data.SoundEventAnnotation
    ) -> tuple[Position, Size]:
        """Extract the target reference position from the annotation's roi.

        Delegates to the internal ROI mapper's `get_roi_position` method.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation containing the geometry (ROI).

        Returns
        -------
        Tuple[float, float]
            The reference position `(time, frequency)`.

        Raises
        ------
        ValueError
            If the annotation lacks geometry.
        """
        class_name = self.encode_class(sound_event)

        if class_name in self._roi_mapper_overrides:
            return self._roi_mapper_overrides[class_name].encode(
                sound_event.sound_event
            )

        return self._roi_mapper.encode(sound_event.sound_event)

    def decode_roi(
        self,
        position: Position,
        size: Size,
        class_name: Optional[str] = None,
    ) -> data.Geometry:
        """Recover an approximate geometric ROI from a position and dimensions.

        Delegates to the internal ROI mapper's `recover_roi` method, which
        un-scales the dimensions and reconstructs the geometry (typically a
        `BoundingBox`).

        Parameters
        ----------
        position : Tuple[float, float]
            The reference position `(time, frequency)`.
        size : np.ndarray
            NumPy array with size dimensions (e.g., from model prediction),
            matching the order in `self.dimension_names`.

        Returns
        -------
        data.Geometry
            The reconstructed geometry (typically `BoundingBox`).
        """
        if class_name in self._roi_mapper_overrides:
            return self._roi_mapper_overrides[class_name].decode(
                position,
                size,
            )

        return self._roi_mapper.decode(position, size)


DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
    classification_targets=DEFAULT_CLASSES,
    detection_target=DEFAULT_DETECTION_CLASS,
    roi=AnchorBBoxMapperConfig(),
)


def build_targets(config: Optional[TargetConfig] = None) -> Targets:
    """Build a Targets object from a loaded TargetConfig.

    Parameters
    ----------
    config : TargetConfig
        The loaded and validated unified target configuration object.

    Returns
    -------
    Targets
        An initialized `Targets` object ready for use.

    Raises
    ------
    KeyError
        If term keys or derivation function keys specified in the `config`
        are not found in their respective registries.
    ImportError, AttributeError, TypeError
        If dynamic import of a derivation function fails (when configured).
    """
    config = config or DEFAULT_TARGET_CONFIG
    logger.opt(lazy=True).debug(
        "Building targets with config: \n{}",
        lambda: config.to_yaml_string(),
    )

    return Targets(config=config)


def load_targets(
    config_path: data.PathLike,
    field: Optional[str] = None,
) -> Targets:
    """Load a Targets object directly from a configuration file.

    This convenience factory loads the `TargetConfig` from the
    specified file path and then calls `build_targets` to build
    the fully initialized `Targets` object.

    Parameters
    ----------
    config_path : data.PathLike
        Path to the configuration file (e.g., YAML).
    field : str, optional
        Dot-separated path to a nested section within the file containing
        the target configuration. If None, the entire file content is used.

    Returns
    -------
    Targets
        An initialized `Targets` object ready for use.

    Raises
    ------
    FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
    TypeError
        Errors raised during file loading, validation, or extraction via
        `load_target_config`.
    KeyError, ImportError, AttributeError, TypeError
        Errors raised during the build process by `build_targets`
        (e.g., missing keys in registries, failed imports).
    """
    config = load_target_config(
        config_path,
        field=field,
    )
    return build_targets(config)


def iterate_encoded_sound_events(
    sound_events: Iterable[data.SoundEventAnnotation],
    targets: TargetProtocol,
) -> Iterable[Tuple[Optional[str], Position, Size]]:
    for sound_event in sound_events:
        if not targets.filter(sound_event):
            continue

        geometry = sound_event.sound_event.geometry

        if geometry is None:
            continue

        class_name = targets.encode_class(sound_event)
        position, size = targets.encode_roi(sound_event)

        yield class_name, position, size
@ -3,7 +3,7 @@ from typing import Dict, List, Optional
from pydantic import Field, PrivateAttr, computed_field, model_validator
from soundevent import data

from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.data.conditions import (
    AllOfConfig,
    HasAllTagsConfig,
src/batdetect2/targets/config.py (new file, 84 lines)
@ -0,0 +1,84 @@
from collections import Counter
from typing import List, Optional

from pydantic import Field, field_validator
from soundevent import data

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.targets.classes import (
    DEFAULT_CLASSES,
    DEFAULT_DETECTION_CLASS,
    TargetClassConfig,
)
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig

__all__ = [
    "TargetConfig",
    "load_target_config",
]


class TargetConfig(BaseConfig):
    detection_target: TargetClassConfig = Field(
        default=DEFAULT_DETECTION_CLASS
    )

    classification_targets: List[TargetClassConfig] = Field(
        default_factory=lambda: DEFAULT_CLASSES
    )

    roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)

    @field_validator("classification_targets")
    def check_unique_class_names(cls, v: List[TargetClassConfig]):
        """Ensure all defined class names are unique."""
        names = [c.name for c in v]

        if len(names) != len(set(names)):
            name_counts = Counter(names)
            duplicates = [
                name for name, count in name_counts.items() if count > 1
            ]
            raise ValueError(
                "Class names must be unique. Found duplicates: "
                f"{', '.join(duplicates)}"
            )
        return v


def load_target_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> TargetConfig:
    """Load the unified target configuration from a file.

    Reads a configuration file (typically YAML) and validates it against the
    `TargetConfig` schema, potentially extracting data from a nested field.

    Parameters
    ----------
    path : data.PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        target configuration. If None, the entire file content is used.

    Returns
    -------
    TargetConfig
        The loaded and validated unified target configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded configuration data does not conform to the
        `TargetConfig` schema (including validation within nested configs
        like `ClassesConfig`).
    KeyError, TypeError
        If `field` specifies an invalid path within the loaded data.
    """
    return load_config(path=path, schema=TargetConfig, field=field)
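A hedged usage sketch for the new module (the file path and the `targets` field are hypothetical):

from batdetect2.targets.config import load_target_config

config = load_target_config("config.yaml", field="targets")
print([target.name for target in config.classification_targets])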
@ -26,12 +26,17 @@ import numpy as np
from pydantic import Field
from soundevent import data

from batdetect2.configs import BaseConfig
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core.arrays import spec_to_xarray
from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.targets import Position, ROITargetMapper, Size
from batdetect2.utils.arrays import spec_to_xarray
from batdetect2.typing import (
    AudioLoader,
    Position,
    PreprocessorProtocol,
    ROITargetMapper,
    Size,
)

__all__ = [
    "Anchor",
@ -260,6 +265,7 @@ class PeakEnergyBBoxMapperConfig(BaseConfig):
    """

    name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
    audio: AudioConfig = Field(default_factory=AudioConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
@ -451,8 +457,11 @@ def build_roi_mapper(
    )

    if config.name == "peak_energy_bbox":
        preprocessor = build_preprocessor(config.preprocessing)
        audio_loader = build_audio_loader(config.preprocessing.audio)
        audio_loader = build_audio_loader(config=config.audio)
        preprocessor = build_preprocessor(
            config.preprocessing,
            input_samplerate=audio_loader.samplerate,
        )
        return PeakEnergyBBoxMapper(
            preprocessor=preprocessor,
            audio_loader=audio_loader,
src/batdetect2/targets/targets.py
Normal file
308
src/batdetect2/targets/targets.py
Normal file
@ -0,0 +1,308 @@
|
from typing import Iterable, List, Optional, Tuple

from loguru import logger
from soundevent import data

from batdetect2.data.conditions import build_sound_event_condition
from batdetect2.targets.classes import (
    DEFAULT_CLASSES,
    DEFAULT_DETECTION_CLASS,
    build_sound_event_decoder,
    build_sound_event_encoder,
    get_class_names_from_config,
)
from batdetect2.targets.config import TargetConfig, load_target_config
from batdetect2.targets.rois import (
    AnchorBBoxMapperConfig,
    build_roi_mapper,
)
from batdetect2.typing.targets import Position, Size, TargetProtocol


class Targets(TargetProtocol):
    """Encapsulates the complete configured target definition pipeline.

    This class implements the `TargetProtocol`, holding the configured
    functions for filtering, transforming, encoding (tags to class name),
    decoding (class name to tags), and mapping ROIs (geometry to position/size
    and back). It provides a high-level interface to apply these steps and
    access relevant metadata like class names and dimension names.

    Instances are typically created using the `build_targets` factory function
    or the `load_targets` convenience loader.

    Attributes
    ----------
    class_names : List[str]
        An ordered list of the unique names of the specific target classes
        defined in the configuration.
    generic_class_tags : List[data.Tag]
        A list of `soundevent.data.Tag` objects representing the configured
        generic class category (used when no specific class matches).
    dimension_names : List[str]
        The names of the size dimensions handled by the ROI mapper
        (e.g., ['width', 'height']).
    """

    class_names: List[str]
    detection_class_tags: List[data.Tag]
    dimension_names: List[str]
    detection_class_name: str

    def __init__(self, config: TargetConfig):
        """Initialize the Targets object."""
        self.config = config

        self._filter_fn = build_sound_event_condition(
            config.detection_target.match_if
        )
        self._encode_fn = build_sound_event_encoder(
            config.classification_targets
        )
        self._decode_fn = build_sound_event_decoder(
            config.classification_targets
        )

        self._roi_mapper = build_roi_mapper(config.roi)

        self.dimension_names = self._roi_mapper.dimension_names

        self.class_names = get_class_names_from_config(
            config.classification_targets
        )

        self.detection_class_name = config.detection_target.name
        self.detection_class_tags = config.detection_target.assign_tags

        self._roi_mapper_overrides = {
            class_config.name: build_roi_mapper(class_config.roi)
            for class_config in config.classification_targets
            if class_config.roi is not None
        }

        for class_name in self._roi_mapper_overrides:
            if class_name not in self.class_names:
                # TODO: improve this warning
                logger.warning(
                    "The ROI mapper overrides contain a class ({class_name}) "
                    "not present in the class names.",
                    class_name=class_name,
                )

    def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
        """Apply the configured filter to a sound event annotation.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation to filter.

        Returns
        -------
        bool
            True if the annotation should be kept (passes the filter),
            False otherwise. If no filter was configured, always returns True.
        """
        return self._filter_fn(sound_event)

    def encode_class(
        self, sound_event: data.SoundEventAnnotation
    ) -> Optional[str]:
        """Encode a sound event annotation to its target class name.

        Applies the configured class definition rules (including priority)
        to determine the specific class name for the annotation.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation to encode. Note: This should typically be called
            *after* applying any transformations via the `transform` method.

        Returns
        -------
        str or None
            The name of the matched target class, or None if the annotation
            does not match any specific class rule (i.e., it belongs to the
            generic category).
        """
        return self._encode_fn(sound_event)

    def decode_class(self, class_label: str) -> List[data.Tag]:
        """Decode a predicted class name back into representative tags.

        Uses the configured mapping (based on `TargetClass.output_tags` or
        `TargetClass.tags`) to convert a class name string into a list of
        `soundevent.data.Tag` objects.

        Parameters
        ----------
        class_label : str
            The class name to decode.

        Returns
        -------
        List[data.Tag]
            The list of tags corresponding to the input class name.
        """
        return self._decode_fn(class_label)

    def encode_roi(
        self, sound_event: data.SoundEventAnnotation
    ) -> tuple[Position, Size]:
        """Extract the target reference position from the annotation's roi.

        Delegates to the internal ROI mapper's `get_roi_position` method.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation containing the geometry (ROI).

        Returns
        -------
        Tuple[float, float]
            The reference position `(time, frequency)`.

        Raises
        ------
        ValueError
            If the annotation lacks geometry.
        """
        class_name = self.encode_class(sound_event)

        if class_name in self._roi_mapper_overrides:
            return self._roi_mapper_overrides[class_name].encode(
                sound_event.sound_event
            )

        return self._roi_mapper.encode(sound_event.sound_event)

    def decode_roi(
        self,
        position: Position,
        size: Size,
        class_name: Optional[str] = None,
    ) -> data.Geometry:
        """Recover an approximate geometric ROI from a position and dimensions.

        Delegates to the internal ROI mapper's `recover_roi` method, which
        un-scales the dimensions and reconstructs the geometry (typically a
        `BoundingBox`).

        Parameters
        ----------
        position : Tuple[float, float]
            The reference position `(time, frequency)`.
        size : np.ndarray
            NumPy array with size dimensions (e.g., from model prediction),
            matching the order in `self.dimension_names`.

        Returns
        -------
        data.Geometry
            The reconstructed geometry (typically `BoundingBox`).
        """
        if class_name in self._roi_mapper_overrides:
            return self._roi_mapper_overrides[class_name].decode(
                position,
                size,
            )

        return self._roi_mapper.decode(position, size)


DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
    classification_targets=DEFAULT_CLASSES,
    detection_target=DEFAULT_DETECTION_CLASS,
    roi=AnchorBBoxMapperConfig(),
)


def build_targets(config: Optional[TargetConfig] = None) -> Targets:
    """Build a Targets object from a loaded TargetConfig.

    Parameters
    ----------
    config : TargetConfig
        The loaded and validated unified target configuration object.

    Returns
    -------
    Targets
        An initialized `Targets` object ready for use.

    Raises
    ------
    KeyError
        If term keys or derivation function keys specified in the `config`
        are not found in their respective registries.
    ImportError, AttributeError, TypeError
        If dynamic import of a derivation function fails (when configured).
    """
    config = config or DEFAULT_TARGET_CONFIG
    logger.opt(lazy=True).debug(
        "Building targets with config: \n{}",
        lambda: config.to_yaml_string(),
    )

    return Targets(config=config)


def load_targets(
    config_path: data.PathLike,
    field: Optional[str] = None,
) -> Targets:
    """Load a Targets object directly from a configuration file.

    This convenience factory loads the `TargetConfig` from the
    specified file path and then calls `build_targets` to build
    the fully initialized `Targets` object.

    Parameters
    ----------
    config_path : data.PathLike
        Path to the configuration file (e.g., YAML).
    field : str, optional
        Dot-separated path to a nested section within the file containing
        the target configuration. If None, the entire file content is used.

    Returns
    -------
    Targets
        An initialized `Targets` object ready for use.

    Raises
    ------
    FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
    TypeError
        Errors raised during file loading, validation, or extraction via
        `load_target_config`.
    KeyError, ImportError, AttributeError, TypeError
        Errors raised during the build process by `build_targets`
        (e.g., missing keys in registries, failed imports).
    """
    config = load_target_config(
        config_path,
        field=field,
    )
    return build_targets(config)


def iterate_encoded_sound_events(
    sound_events: Iterable[data.SoundEventAnnotation],
    targets: TargetProtocol,
) -> Iterable[Tuple[Optional[str], Position, Size]]:
    for sound_event in sound_events:
        if not targets.filter(sound_event):
            continue

        geometry = sound_event.sound_event.geometry

        if geometry is None:
            continue

        class_name = targets.encode_class(sound_event)
        position, size = targets.encode_roi(sound_event)

        yield class_name, position, size
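A usage sketch for the encoding helper above (`clip_annotation` is hypothetical):

targets = build_targets()  # falls back to DEFAULT_TARGET_CONFIG
for class_name, position, size in iterate_encoded_sound_events(
    clip_annotation.sound_events,
    targets=targets,
):
    ...  # feed (class_name, position, size) into heatmap labelling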
@ -14,12 +14,9 @@ from batdetect2.train.augmentations import (
    scale_volume,
    warp_spectrogram,
)
from batdetect2.train.clips import build_clipper, select_subclip
from batdetect2.train.config import (
    FullTrainingConfig,
    PLTrainerConfig,
    TrainingConfig,
    load_full_training_config,
    load_train_config,
)
from batdetect2.train.dataset import (
@ -48,7 +45,6 @@ __all__ = [
    "DetectionLossConfig",
    "EchoAugmentationConfig",
    "FrequencyMaskAugmentationConfig",
    "FullTrainingConfig",
    "LossConfig",
    "LossFunction",
    "PLTrainerConfig",
@ -64,21 +60,18 @@ __all__ = [
    "add_echo",
    "build_augmentations",
    "build_clip_labeler",
    "build_clipper",
    "build_loss",
    "build_train_dataset",
    "build_train_loader",
    "build_trainer",
    "build_val_dataset",
    "build_val_loader",
    "load_full_training_config",
    "load_label_config",
    "load_train_config",
    "mask_frequency",
    "mask_time",
    "mix_audio",
    "scale_volume",
    "select_subclip",
    "train",
    "warp_spectrogram",
]
@ -11,11 +11,10 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import scale_geometry, shift_geometry

from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.clips import get_subclip_annotation
from batdetect2.typing import Augmentation
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.utils.arrays import adjust_width
from batdetect2.audio.clips import get_subclip_annotation
from batdetect2.core.arrays import adjust_width
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.typing import AudioLoader, Augmentation

__all__ = [
    "AugmentationConfig",
@ -5,19 +5,21 @@ from lightning.pytorch.callbacks import Callback
from soundevent import data
from torch.utils.data import DataLoader

from batdetect2.evaluate import Evaluator
from batdetect2.postprocess import get_raw_predictions
from batdetect2.logging import get_image_logger
from batdetect2.postprocess import to_raw_predictions
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import get_image_plotter
from batdetect2.typing.evaluate import ClipEvaluation
from batdetect2.typing.models import ModelOutput
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.train import TrainExample
from batdetect2.typing import (
    ClipEvaluation,
    EvaluatorProtocol,
    ModelOutput,
    RawPrediction,
    TrainExample,
)


class ValidationMetrics(Callback):
    def __init__(self, evaluator: Evaluator):
    def __init__(self, evaluator: EvaluatorProtocol):
        super().__init__()

        self.evaluator = evaluator
@ -32,12 +34,12 @@ class ValidationMetrics(Callback):
        assert isinstance(dataset, ValidationDataset)
        return dataset

    def plot_examples(
    def generate_plots(
        self,
        pl_module: LightningModule,
        evaluated_clips: List[ClipEvaluation],
    ):
        plotter = get_image_plotter(pl_module.logger)  # type: ignore
        plotter = get_image_logger(pl_module.logger)  # type: ignore

        if plotter is None:
            return
@ -64,7 +66,7 @@ class ValidationMetrics(Callback):
        )

        self.log_metrics(pl_module, clip_evaluations)
        self.plot_examples(pl_module, clip_evaluations)
        self.generate_plots(pl_module, clip_evaluations)

        return super().on_validation_epoch_end(trainer, pl_module)

@ -86,8 +88,7 @@ class ValidationMetrics(Callback):
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        postprocessor = pl_module.model.postprocessor
        targets = pl_module.model.targets
        model = pl_module.model
        dataset = self.get_dataset(trainer)

        clip_annotations = [
@ -95,15 +96,14 @@ class ValidationMetrics(Callback):
            for example_idx in batch.idx
        ]

        predictions = get_raw_predictions(
        clip_detections = model.postprocessor(
            outputs,
            start_times=[
                clip_annotation.clip.start_time
                for clip_annotation in clip_annotations
            ],
            targets=targets,
            postprocessor=postprocessor,
            start_times=[ca.clip.start_time for ca in clip_annotations],
        )
        predictions = [
            to_raw_predictions(clip_dets.numpy(), targets=model.targets)
            for clip_dets in clip_detections
        ]

        self._clip_annotations.extend(clip_annotations)
        self._predictions.extend(predictions)
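The callback now postprocesses in two explicit stages; a minimal sketch of that flow, assuming the call shapes shown in this hunk:

clip_detections = model.postprocessor(outputs, start_times=[0.0])
predictions = [
    to_raw_predictions(dets.numpy(), targets=model.targets)
    for dets in clip_detections
]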
@ -3,27 +3,16 @@ from typing import Optional, Union
from pydantic import Field
from soundevent import data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.models import ModelConfig
from batdetect2.train.augmentations import (
    DEFAULT_AUGMENTATION_CONFIG,
    AugmentationsConfig,
)
from batdetect2.train.clips import (
    ClipConfig,
    PaddedClipConfig,
    RandomClipConfig,
)
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import LabelConfig
from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
from batdetect2.train.losses import LossConfig

__all__ = [
    "TrainingConfig",
    "load_train_config",
    "FullTrainingConfig",
    "load_full_training_config",
]


@ -48,30 +37,6 @@ class PLTrainerConfig(BaseConfig):
    val_check_interval: Optional[Union[int, float]] = None


class ValLoaderConfig(BaseConfig):
    num_workers: int = 0

    clipping_strategy: ClipConfig = Field(
        default_factory=lambda: PaddedClipConfig()
    )


class TrainLoaderConfig(BaseConfig):
    num_workers: int = 0

    batch_size: int = 8

    shuffle: bool = False

    augmentations: AugmentationsConfig = Field(
        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
    )

    clipping_strategy: ClipConfig = Field(
        default_factory=lambda: PaddedClipConfig()
    )


class OptimizerConfig(BaseConfig):
    learning_rate: float = 1e-3
    t_max: int = 100
@ -80,13 +45,12 @@ class OptimizerConfig(BaseConfig):
class TrainingConfig(BaseConfig):
    train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
    val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)

    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
    loss: LossConfig = Field(default_factory=LossConfig)
    cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
    trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
    logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
    logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
    labels: LabelConfig = Field(default_factory=LabelConfig)
    validation: EvaluationConfig = Field(default_factory=EvaluationConfig)


def load_train_config(
@ -94,18 +58,3 @@ def load_train_config(
    field: Optional[str] = None,
) -> TrainingConfig:
    return load_config(path, schema=TrainingConfig, field=field)


class FullTrainingConfig(ModelConfig):
    """Full training configuration."""

    train: TrainingConfig = Field(default_factory=TrainingConfig)
    evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)


def load_full_training_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> FullTrainingConfig:
    """Load the full training configuration."""
    return load_config(path, schema=FullTrainingConfig, field=field)
@ -2,22 +2,30 @@ from typing import List, Optional, Sequence

import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader, Dataset

from batdetect2.plotting.clips import build_audio_loader
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.train.augmentations import (
    DEFAULT_AUGMENTATION_CONFIG,
    AugmentationsConfig,
    RandomAudioSource,
    build_augmentations,
)
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import build_clip_labeler
from batdetect2.typing import ClipperProtocol, TrainExample
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.train import Augmentation, ClipLabeller
from batdetect2.utils.arrays import adjust_width
from batdetect2.typing import (
    AudioLoader,
    Augmentation,
    ClipLabeller,
    ClipperProtocol,
    PreprocessorProtocol,
    TrainExample,
)

__all__ = [
    "TrainingDataset",
@ -139,6 +147,22 @@ class ValidationDataset(Dataset):
    )


class TrainLoaderConfig(BaseConfig):
    num_workers: int = 0

    batch_size: int = 8

    shuffle: bool = False

    augmentations: AugmentationsConfig = Field(
        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
    )

    clipping_strategy: ClipConfig = Field(
        default_factory=lambda: PaddedClipConfig()
    )


def build_train_loader(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
@ -173,6 +197,14 @@ def build_train_loader(
    )


class ValLoaderConfig(BaseConfig):
    num_workers: int = 0

    clipping_strategy: ClipConfig = Field(
        default_factory=lambda: PaddedClipConfig()
    )


def build_val_loader(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
@ -13,14 +13,10 @@ import torch
from loguru import logger
from soundevent import data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets import build_targets, iterate_encoded_sound_events
from batdetect2.typing import (
    ClipLabeller,
    Heatmaps,
    TargetProtocol,
)
from batdetect2.typing import ClipLabeller, Heatmaps, TargetProtocol

__all__ = [
    "LabelConfig",
@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple

import lightning as L
import torch
@ -6,11 +6,17 @@ from soundevent.data import PathLike
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.models import Model, build_model
from batdetect2.train.config import FullTrainingConfig
from batdetect2.plotting.clips import build_preprocessor
from batdetect2.postprocess import build_postprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.train.losses import build_loss
from batdetect2.typing import ModelOutput, TrainExample

if TYPE_CHECKING:
    from batdetect2.config import BatDetect2Config

__all__ = [
    "TrainingModule",
]
@ -21,7 +27,8 @@ class TrainingModule(L.LightningModule):

    def __init__(
        self,
        config: FullTrainingConfig,
        config: "BatDetect2Config",
        input_samplerate: int = TARGET_SAMPLERATE_HZ,
        learning_rate: float = 0.001,
        t_max: int = 100,
        model: Optional[Model] = None,
@ -31,6 +38,7 @@ class TrainingModule(L.LightningModule):

        self.save_hyperparameters(logger=False)

        self.input_samplerate = input_samplerate
        self.config = config
        self.learning_rate = learning_rate
        self.t_max = t_max
@ -39,7 +47,23 @@ class TrainingModule(L.LightningModule):
        loss = build_loss(self.config.train.loss)

        if model is None:
            model = build_model(self.config)
            targets = build_targets(self.config.targets)

            preprocessor = build_preprocessor(
                config=self.config.preprocess,
                input_samplerate=self.input_samplerate,
            )

            postprocessor = build_postprocessor(
                preprocessor, config=self.config.postprocess
            )

            model = build_model(
                config=self.config.model,
                targets=targets,
                preprocessor=preprocessor,
                postprocessor=postprocessor,
            )

        self.loss = loss
        self.model = model
@ -74,16 +98,18 @@ class TrainingModule(L.LightningModule):

def load_model_from_checkpoint(
    path: PathLike,
) -> Tuple[Model, FullTrainingConfig]:
) -> Tuple[Model, "BatDetect2Config"]:
    module = TrainingModule.load_from_checkpoint(path)  # type: ignore
    return module.model, module.config


def build_training_module(
    config: Optional[FullTrainingConfig] = None,
    config: Optional["BatDetect2Config"] = None,
    t_max: int = 200,
) -> TrainingModule:
    config = config or FullTrainingConfig()
    from batdetect2.config import BatDetect2Config

    config = config or BatDetect2Config()
    return TrainingModule(
        config=config,
        learning_rate=config.train.optimizer.learning_rate,
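A hedged sketch of the new construction path (the checkpoint path is hypothetical):

module = build_training_module()  # builds targets, pre/postprocessor and model from BatDetect2Config()
model, config = load_model_from_checkpoint("outputs/checkpoints/last.ckpt")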
@ -27,7 +27,7 @@ from loguru import logger
from pydantic import Field
from torch import nn

from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample

__all__ = [
@@ -1,29 +1,31 @@
from collections.abc import Sequence
from pathlib import Path
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional

from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from loguru import logger
from soundevent import data

from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.logging import build_logger
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.config import (
    FullTrainingConfig,
)
from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import TrainingModule, build_training_module
from batdetect2.train.logging import build_logger
from batdetect2.typing import (
    TargetProtocol,
)
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.typing.train import ClipLabeller
from batdetect2.train.lightning import build_training_module

if TYPE_CHECKING:
    from batdetect2.config import BatDetect2Config
    from batdetect2.typing import (
        AudioLoader,
        ClipLabeller,
        EvaluatorProtocol,
        PreprocessorProtocol,
        TargetProtocol,
    )

__all__ = [
    "build_trainer",
@@ -36,13 +38,12 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
def train(
    train_annotations: Sequence[data.ClipAnnotation],
    val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
    targets: Optional["TargetProtocol"] = None,
    preprocessor: Optional["PreprocessorProtocol"] = None,
    audio_loader: Optional["AudioLoader"] = None,
    labeller: Optional["ClipLabeller"] = None,
    config: Optional["BatDetect2Config"] = None,
    trainer: Optional[Trainer] = None,
    targets: Optional[TargetProtocol] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    audio_loader: Optional[AudioLoader] = None,
    labeller: Optional[ClipLabeller] = None,
    config: Optional[FullTrainingConfig] = None,
    model_path: Optional[data.PathLike] = None,
    train_workers: Optional[int] = None,
    val_workers: Optional[int] = None,
    checkpoint_dir: Optional[Path] = None,
@@ -51,17 +52,20 @@ def train(
    run_name: Optional[str] = None,
    seed: Optional[int] = None,
):
    from batdetect2.config import BatDetect2Config

    if seed is not None:
        seed_everything(seed)

    config = config or FullTrainingConfig()
    config = config or BatDetect2Config()

    targets = targets or build_targets(config.targets)
    targets = targets or build_targets(config=config.targets)

    preprocessor = preprocessor or build_preprocessor(config.preprocess)
    audio_loader = audio_loader or build_audio_loader(config=config.audio)

    audio_loader = audio_loader or build_audio_loader(
        config=config.preprocess.audio
    )
    preprocessor = preprocessor or build_preprocessor(
        input_samplerate=audio_loader.samplerate,
        config=config.preprocess,
    )

    labeller = labeller or build_clip_labeler(
@@ -93,18 +97,15 @@ def train(
        else None
    )

    if model_path is not None:
        logger.debug("Loading model from: {path}", path=model_path)
        module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
    else:
        module = build_training_module(
            config,
            t_max=config.train.optimizer.t_max * len(train_dataloader),
        )
    module = build_training_module(
        config,
        t_max=config.train.optimizer.t_max * len(train_dataloader),
    )

    trainer = trainer or build_trainer(
        config,
        targets=targets,
        evaluator=build_evaluator(config.train.validation, targets=targets),
        checkpoint_dir=checkpoint_dir,
        log_dir=log_dir,
        experiment_name=experiment_name,
@@ -121,8 +122,8 @@ def train(


def build_trainer_callbacks(
    targets: TargetProtocol,
    config: FullTrainingConfig,
    targets: "TargetProtocol",
    evaluator: Optional["EvaluatorProtocol"] = None,
    checkpoint_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
@@ -136,13 +137,12 @@ def build_trainer_callbacks(
    if run_name is not None:
        checkpoint_dir = checkpoint_dir / run_name

    evaluator = build_evaluator(config=config.evaluation, targets=targets)
    evaluator = evaluator or build_evaluator(targets=targets)

    return [
        ModelCheckpoint(
            dirpath=str(checkpoint_dir),
            save_top_k=1,
            filename="best-{epoch:02d}-{val_loss:.0f}",
            monitor="total_loss/val",
        ),
        ValidationMetrics(evaluator),
@@ -150,8 +150,9 @@ def build_trainer_callbacks(


def build_trainer(
    conf: FullTrainingConfig,
    targets: TargetProtocol,
    conf: "BatDetect2Config",
    targets: "TargetProtocol",
    evaluator: Optional["EvaluatorProtocol"] = None,
    checkpoint_dir: Optional[Path] = None,
    log_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
@@ -181,7 +182,7 @@ def build_trainer(
        logger=train_logger,
        callbacks=build_trainer_callbacks(
            targets,
            config=conf,
            evaluator=evaluator,
            checkpoint_dir=checkpoint_dir,
            experiment_name=experiment_name,
            run_name=run_name,

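Read together, these train() changes make the audio loader the source of truth for the sample rate: it is built first, and its samplerate feeds build_preprocessor. A minimal call of the updated entry point might look like this (a sketch; the module path batdetect2.train.train is assumed, and the empty annotation lists are placeholders for real soundevent ClipAnnotation objects):

from typing import List

from soundevent import data

from batdetect2.config import BatDetect2Config
from batdetect2.train.train import train

train_annotations: List[data.ClipAnnotation] = []  # fill with real annotations
val_annotations: List[data.ClipAnnotation] = []

train(
    train_annotations,
    val_annotations=val_annotations,
    config=BatDetect2Config(),
    seed=42,  # forwarded to lightning's seed_everything
)
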
@@ -1,4 +1,10 @@
from batdetect2.typing.evaluate import MatchEvaluation, MetricsProtocol
from batdetect2.typing.evaluate import (
    ClipEvaluation,
    EvaluatorProtocol,
    MatchEvaluation,
    MetricsProtocol,
    PlotterProtocol,
)
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.postprocess import (
    BatDetect2Prediction,
@@ -9,10 +15,10 @@ from batdetect2.typing.postprocess import (
from batdetect2.typing.preprocess import (
    AudioLoader,
    PreprocessorProtocol,
    SpectrogramBuilder,
)
from batdetect2.typing.targets import (
    Position,
    ROITargetMapper,
    Size,
    SoundEventDecoder,
    SoundEventEncoder,
@@ -34,6 +40,7 @@ __all__ = [
    "Augmentation",
    "BackboneModel",
    "BatDetect2Prediction",
    "ClipEvaluation",
    "ClipLabeller",
    "ClipperProtocol",
    "DetectionModel",
@@ -44,15 +51,17 @@ __all__ = [
    "MatchEvaluation",
    "MetricsProtocol",
    "ModelOutput",
    "PlotterProtocol",
    "Position",
    "PostprocessorProtocol",
    "PreprocessorProtocol",
    "ROITargetMapper",
    "RawPrediction",
    "Size",
    "SoundEventDecoder",
    "SoundEventEncoder",
    "SoundEventFilter",
    "SpectrogramBuilder",
    "TargetProtocol",
    "TrainExample",
    "EvaluatorProtocol",
]

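The widened re-export surface pairs with the TYPE_CHECKING pattern used elsewhere in this diff: downstream code can annotate against the evaluation protocols without any runtime import of batdetect2.typing. A sketch, relying only on the PlotterProtocol signature shown later in this diff:

from typing import TYPE_CHECKING, List, Tuple

if TYPE_CHECKING:
    from matplotlib.figure import Figure

    from batdetect2.typing import ClipEvaluation, PlotterProtocol


def collect_plots(
    plotter: "PlotterProtocol",
    evaluations: "List[ClipEvaluation]",
) -> "List[Tuple[str, Figure]]":
    # PlotterProtocol.__call__ yields (name, Figure) pairs.
    return list(plotter(evaluations))
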
@@ -14,7 +14,11 @@ from typing import (
from matplotlib.figure import Figure
from soundevent import data

from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol

__all__ = [
    "EvaluatorProtocol",
    "MetricsProtocol",
    "MatchEvaluation",
]
@@ -50,6 +54,26 @@ class MatchEvaluation:

        return self.pred_class_scores[pred_class]

    def is_true_positive(self, threshold: float = 0) -> bool:
        return (
            self.gt_det
            and self.pred_score > threshold
            and self.gt_class == self.pred_class
        )

    def is_false_positive(self, threshold: float = 0) -> bool:
        return self.gt_det is None and self.pred_score > threshold

    def is_false_negative(self, threshold: float = 0) -> bool:
        return self.gt_det and self.pred_score <= threshold

    def is_cross_trigger(self, threshold: float = 0) -> bool:
        return (
            self.gt_det
            and self.pred_score > threshold
            and self.gt_class != self.pred_class
        )

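At a fixed score threshold, the four predicates above split matches into disjoint outcomes: a true positive has a ground-truth match and an agreeing class, a cross-trigger has a match but a disagreeing class, a false positive has no ground-truth match, and a false negative is a match whose score fell at or below the threshold. A tally over a batch of matches might look like this (a sketch; only the methods shown above are relied on):

from collections import Counter
from typing import Sequence

from batdetect2.typing.evaluate import MatchEvaluation


def tally(matches: Sequence[MatchEvaluation], threshold: float = 0.5) -> Counter:
    counts: Counter = Counter()
    for match in matches:
        if match.is_true_positive(threshold):
            counts["tp"] += 1
        elif match.is_cross_trigger(threshold):
            counts["cross_trigger"] += 1
        elif match.is_false_positive(threshold):
            counts["fp"] += 1
        elif match.is_false_negative(threshold):
            counts["fn"] += 1
    return counts
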
@dataclass
class ClipEvaluation:
@@ -87,3 +111,21 @@ class PlotterProtocol(Protocol):
    def __call__(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Iterable[Tuple[str, Figure]]: ...


class EvaluatorProtocol(Protocol):
    targets: TargetProtocol

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[ClipEvaluation]: ...

    def compute_metrics(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Dict[str, float]: ...

    def generate_plots(
        self, clip_evaluations: Sequence[ClipEvaluation]
    ) -> Iterable[Tuple[str, Figure]]: ...

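An object satisfying this protocol is driven in three steps: match predictions to annotations, reduce the matches to scalar metrics, then render plots. A consumer sketch (the evaluator instance itself would come from build_evaluator, as shown earlier in the diff; the output path is hypothetical):

from typing import Sequence

from soundevent import data

from batdetect2.typing import EvaluatorProtocol, RawPrediction


def run_evaluation(
    evaluator: EvaluatorProtocol,
    annotations: Sequence[data.ClipAnnotation],
    predictions: Sequence[Sequence[RawPrediction]],
) -> None:
    evaluations = evaluator.evaluate(annotations, predictions)
    metrics = evaluator.compute_metrics(evaluations)
    for name, value in metrics.items():
        print(f"{name}: {value:.3f}")
    for name, figure in evaluator.generate_plots(evaluations):
        figure.savefig(f"{name}.png")  # hypothetical output path
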
@@ -12,7 +12,7 @@ system that deal with model predictions.
"""

from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol
from typing import List, NamedTuple, Optional, Protocol, Sequence

import numpy as np
import torch
@@ -47,15 +47,13 @@ class GeometryDecoder(Protocol):


class RawPrediction(NamedTuple):
    """Intermediate representation of a single detected sound event."""

    geometry: data.Geometry
    detection_score: float
    class_scores: np.ndarray
    features: np.ndarray


class DetectionsArray(NamedTuple):
class ClipDetectionsArray(NamedTuple):
    scores: np.ndarray
    sizes: np.ndarray
    class_scores: np.ndarray
@@ -64,7 +62,7 @@ class DetectionsArray(NamedTuple):
    features: np.ndarray


class DetectionsTensor(NamedTuple):
class ClipDetectionsTensor(NamedTuple):
    scores: torch.Tensor
    sizes: torch.Tensor
    class_scores: torch.Tensor
@@ -72,8 +70,8 @@ class DetectionsTensor(NamedTuple):
    frequencies: torch.Tensor
    features: torch.Tensor

    def numpy(self) -> DetectionsArray:
        return DetectionsArray(
    def numpy(self) -> ClipDetectionsArray:
        return ClipDetectionsArray(
            scores=self.scores.detach().cpu().numpy(),
            sizes=self.sizes.detach().cpu().numpy(),
            class_scores=self.class_scores.detach().cpu().numpy(),
@@ -92,10 +90,8 @@ class BatDetect2Prediction:
class PostprocessorProtocol(Protocol):
    """Protocol defining the interface for the full postprocessing pipeline."""

    def __call__(self, output: ModelOutput) -> List[DetectionsTensor]: ...

    def get_detections(
    def __call__(
        self,
        output: ModelOutput,
        start_times: Optional[List[float]] = None,
    ) -> List[DetectionsTensor]: ...
        start_times: Optional[Sequence[float]] = None,
    ) -> List[ClipDetectionsTensor]: ...

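The rename to ClipDetectionsArray/ClipDetectionsTensor makes the per-clip scope explicit: the postprocessor now returns one tensor bundle per clip, and .numpy() detaches that clip's fields to CPU arrays. A helper typed against the new __call__ signature might look like this (a sketch; the postprocessor and model output are supplied by the caller):

from typing import List, Optional, Sequence

from batdetect2.typing import ModelOutput, PostprocessorProtocol
from batdetect2.typing.postprocess import ClipDetectionsArray


def detections_as_numpy(
    postprocessor: PostprocessorProtocol,
    output: ModelOutput,
    start_times: Optional[Sequence[float]] = None,
) -> List[ClipDetectionsArray]:
    # One ClipDetectionsTensor per clip in the batch; convert each to
    # detached CPU numpy arrays via the NamedTuple's numpy() method.
    return [det.numpy() for det in postprocessor(output, start_times=start_times)]
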
@@ -32,6 +32,8 @@ class AudioLoader(Protocol):
    allows for different loading strategies or implementations.
    """

    samplerate: int

    def load_file(
        self,
        path: data.PathLike,
@@ -125,22 +127,6 @@ class SpectrogramBuilder(Protocol):
        ...


class AudioPipeline(Protocol):
    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...


class SpectrogramPipeline(Protocol):
    def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...

    def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor: ...

    def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...

    def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...

    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...


class PreprocessorProtocol(Protocol):
    """Defines a high-level interface for the complete preprocessing pipeline."""

@@ -152,11 +138,13 @@ class PreprocessorProtocol(Protocol):

    output_samplerate: float

    audio_pipeline: AudioPipeline

    spectrogram_pipeline: SpectrogramPipeline

    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...

    def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...

    def process_audio(self, wav: torch.Tensor) -> torch.Tensor: ...

    def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...

    def process_numpy(self, wav: np.ndarray) -> np.ndarray:
        return self(torch.tensor(wav)).numpy()

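With process_numpy defined on the protocol itself, any conforming preprocessor gets a numpy entry point for free once __call__ is implemented. A toy stand-in satisfying the surface shown above (a sketch only; the real preprocessor comes from build_preprocessor, and the spectrogram math here is illustrative):

import numpy as np
import torch


class ToyPreprocessor:
    output_samplerate: float = 256000.0

    def process_audio(self, wav: torch.Tensor) -> torch.Tensor:
        # Peak-normalize the waveform.
        return wav / wav.abs().max().clamp(min=1e-8)

    def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
        window = torch.hann_window(512)
        return torch.stft(wav, n_fft=512, window=window, return_complex=True).abs()

    def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
        return torch.log1p(spec)

    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
        return self.process_spectrogram(
            self.generate_spectrogram(self.process_audio(wav))
        )

    def process_numpy(self, wav: np.ndarray) -> np.ndarray:
        # Mirrors the protocol's default: tensor in, numpy out.
        return self(torch.tensor(wav)).numpy()


spec = ToyPreprocessor().process_numpy(np.random.randn(25600).astype("float32"))
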