Compare commits

...

9 Commits

Author      SHA1        Date                        Message
mbsantiago  4cd983a2c2  2025-09-18 09:44:27 +01:00  Better train cli arg names
mbsantiago  e65df81db2  2025-09-18 09:28:21 +01:00  Evaluate using Lightning too to handle device changes
mbsantiago  6c25787123  2025-09-18 09:27:40 +01:00  Logging is not just for training
mbsantiago  8c80402f08  2025-09-18 09:27:24 +01:00  Move clips and audio to dedicated module
mbsantiago  b81a882b58  2025-09-17 10:30:24 +01:00  Add metrics and plots
mbsantiago  6e217380f2  2025-09-17 10:29:30 +01:00  Moved example target config to independent file
mbsantiago  957c0735d2  2025-09-16 19:39:30 +01:00  Starting new API
mbsantiago  bbb96b33a2  2025-09-16 18:57:56 +01:00  Config restructuring
mbsantiago  7d6cba5465  2025-09-16 13:38:38 +01:00  Restructuring
75 changed files with 3679 additions and 2197 deletions


@ -1,66 +1,27 @@
targets:
detection_target:
name: bat
match_if:
name: all_of
conditions:
- name: has_tag
tag: { key: event, value: Echolocation }
- name: not
condition:
name: has_tag
tag: { key: class, value: Unknown }
assign_tags:
- key: class
value: Bat
classification_targets:
- name: myomys
tags:
- key: class
value: Myotis mystacinus
- name: pippip
tags:
- key: class
value: Pipistrellus pipistrellus
- name: eptser
tags:
- key: class
value: Eptesicus serotinus
- name: rhifer
tags:
- key: class
value: Rhinolophus ferrumequinum
roi:
name: anchor_bbox
anchor: top-left
audio:
samplerate: 256000
resample:
enabled: True
method: "poly"
preprocess:
audio:
samplerate: 256000
resample:
enabled: True
method: "poly"
spectrogram:
stft:
window_duration: 0.002
window_overlap: 0.75
window_fn: hann
frequencies:
max_freq: 120000
min_freq: 10000
size:
height: 128
resize_factor: 0.5
transforms:
- name: pcen
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
- name: spectral_mean_substraction
stft:
window_duration: 0.002
window_overlap: 0.75
window_fn: hann
frequencies:
max_freq: 120000
min_freq: 10000
size:
height: 128
resize_factor: 0.5
spectrogram_transforms:
- name: pcen
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
- name: spectral_mean_substraction
postprocess:
nms_kernel_size: 9
@ -102,23 +63,57 @@ model:
out_channels: 32
train:
learning_rate: 0.001
t_max: 100
optimizer:
learning_rate: 0.001
t_max: 100
labels:
sigma: 3
trainer:
max_epochs: 5
max_epochs: 10
check_val_every_n_epoch: 5
train_loader:
batch_size: 8
num_workers: 2
shuffle: True
clipping_strategy:
name: random_subclip
duration: 0.256
augmentations:
enabled: true
audio:
- name: mix_audio
probability: 0.2
min_weight: 0.3
max_weight: 0.7
- name: add_echo
probability: 0.2
max_delay: 0.005
min_weight: 0.0
max_weight: 1.0
spectrogram:
- name: scale_volume
probability: 0.2
min_scaling: 0.0
max_scaling: 2.0
- name: warp
probability: 0.2
delta: 0.04
- name: mask_time
probability: 0.2
max_perc: 0.05
max_masks: 3
- name: mask_freq
probability: 0.2
max_perc: 0.10
max_masks: 3
val_loader:
num_workers: 2
clipping_strategy:
@ -142,31 +137,28 @@ train:
logger:
name: csv
augmentations:
enabled: true
audio:
- name: mix_audio
probability: 0.2
min_weight: 0.3
max_weight: 0.7
- name: add_echo
probability: 0.2
max_delay: 0.005
min_weight: 0.0
max_weight: 1.0
spectrogram:
- name: scale_volume
probability: 0.2
min_scaling: 0.0
max_scaling: 2.0
- name: warp
probability: 0.2
delta: 0.04
- name: mask_time
probability: 0.2
max_perc: 0.05
max_masks: 3
- name: mask_freq
probability: 0.2
max_perc: 0.10
max_masks: 3
validation:
metrics:
- name: detection_ap
- name: detection_roc_auc
- name: classification_ap
- name: classification_roc_auc
- name: top_class_ap
- name: classification_balanced_accuracy
- name: clip_ap
- name: clip_roc_auc
evaluation:
match_strategy:
name: start_time_match
distance_threshold: 0.01
metrics:
- name: classification_ap
- name: detection_ap
plots:
- name: example_gallery
- name: example_clip
- name: detection_pr_curve
- name: classification_pr_curves
- name: detection_roc_curve
- name: classification_roc_curves

example_data/targets.yaml (new file, 36 lines)

@ -0,0 +1,36 @@
detection_target:
name: bat
match_if:
name: all_of
conditions:
- name: has_tag
tag: { key: event, value: Echolocation }
- name: not
condition:
name: has_tag
tag: { key: class, value: Unknown }
assign_tags:
- key: class
value: Bat
classification_targets:
- name: myomys
tags:
- key: class
value: Myotis mystacinus
- name: pippip
tags:
- key: class
value: Pipistrellus pipistrellus
- name: eptser
tags:
- key: class
value: Eptesicus serotinus
- name: rhifer
tags:
- key: class
value: Rhinolophus ferrumequinum
roi:
name: anchor_bbox
anchor: top-left
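
A minimal sketch (not from this changeset) of loading the targets file above with the helpers referenced elsewhere in this diff; the printed class names are assumed from the YAML:

    from batdetect2.targets import load_target_config
    from batdetect2.targets.targets import build_targets

    # Parse the YAML above into a targets config and build the target definitions.
    targets_config = load_target_config("example_data/targets.yaml")
    targets = build_targets(config=targets_config)
    print(targets.class_names)  # presumably ["myomys", "pippip", "eptser", "rhifer"]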

src/batdetect2/api/base.py (new file, 179 lines)

@ -0,0 +1,179 @@
from pathlib import Path
from typing import Optional, Sequence
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.evaluate import build_evaluator, evaluate
from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.train import train
from batdetect2.train.lightning import load_model_from_checkpoint
from batdetect2.typing import (
AudioLoader,
EvaluatorProtocol,
PostprocessorProtocol,
PreprocessorProtocol,
TargetProtocol,
)
class BatDetect2API:
def __init__(
self,
config: BatDetect2Config,
targets: TargetProtocol,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
evaluator: EvaluatorProtocol,
model: Model,
):
self.config = config
self.targets = targets
self.audio_loader = audio_loader
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.evaluator = evaluator
self.model = model
self.model.eval()
def train(
self,
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
train_workers: Optional[int] = None,
val_workers: Optional[int] = None,
checkpoint_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
seed: Optional[int] = None,
):
train(
train_annotations=train_annotations,
val_annotations=val_annotations,
targets=self.targets,
config=self.config,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
train_workers=train_workers,
val_workers=val_workers,
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
seed=seed,
)
return self
def evaluate(
self,
test_annotations: Sequence[data.ClipAnnotation],
num_workers: Optional[int] = None,
output_dir: data.PathLike = ".",
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
):
return evaluate(
self.model,
test_annotations,
targets=self.targets,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
config=self.config,
num_workers=num_workers,
output_dir=output_dir,
experiment_name=experiment_name,
run_name=run_name,
)
@classmethod
def from_config(cls, config: BatDetect2Config):
targets = build_targets(config=config.targets)
audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
)
postprocessor = build_postprocessor(
preprocessor,
config=config.postprocess,
)
evaluator = build_evaluator(
config=config.evaluation,
targets=targets,
)
# NOTE: Better to have a separate instance of
# preprocessor and postprocessor as these may be moved
# to another device.
model = build_model(
config=config.model,
targets=targets,
preprocessor=build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
),
postprocessor=build_postprocessor(
preprocessor,
config=config.postprocess,
),
)
return cls(
config=config,
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,
postprocessor=postprocessor,
evaluator=evaluator,
model=model,
)
@classmethod
def from_checkpoint(
cls,
path: data.PathLike,
config: Optional[BatDetect2Config] = None,
):
model, stored_config = load_model_from_checkpoint(path)
config = config or stored_config
targets = build_targets(config=config.targets)
audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
)
postprocessor = build_postprocessor(
preprocessor,
config=config.postprocess,
)
evaluator = build_evaluator(
config=config.evaluation,
targets=targets,
)
return cls(
config=config,
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,
postprocessor=postprocessor,
evaluator=evaluator,
model=model,
)
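
A rough usage sketch (not part of the diff) of the new API class, based only on the signatures above; the config path, checkpoint path, and annotation sequences are placeholders:

    from batdetect2.api.base import BatDetect2API
    from batdetect2.config import load_full_config

    config = load_full_config("config.yaml")  # placeholder path
    api = BatDetect2API.from_config(config)
    # api = BatDetect2API.from_checkpoint("model.ckpt")  # or restore from a checkpoint
    api.train(train_annotations, val_annotations=val_annotations)
    results = api.evaluate(test_annotations, output_dir="outputs/evaluation")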


@ -0,0 +1,16 @@
from batdetect2.audio.clips import ClipConfig, build_clipper
from batdetect2.audio.loader import (
TARGET_SAMPLERATE_HZ,
AudioConfig,
SoundEventAudioLoader,
build_audio_loader,
)
__all__ = [
"TARGET_SAMPLERATE_HZ",
"AudioConfig",
"SoundEventAudioLoader",
"build_audio_loader",
"ClipConfig",
"build_clipper",
]


@ -6,14 +6,19 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds, intervals_overlap
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipperProtocol
DEFAULT_TRAIN_CLIP_DURATION = 0.256
DEFAULT_MAX_EMPTY_CLIP = 0.1
__all__ = [
"build_clipper",
"ClipConfig",
]
clipper_registry: Registry[ClipperProtocol, []] = Registry("clipper")


@ -0,0 +1,295 @@
from typing import Optional
import numpy as np
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError
from batdetect2.core import BaseConfig
from batdetect2.typing import AudioLoader
__all__ = [
"SoundEventAudioLoader",
"build_audio_loader",
"load_file_audio",
"load_recording_audio",
"load_clip_audio",
"resample_audio",
]
TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""
class ResampleConfig(BaseConfig):
"""Configuration for audio resampling.
Attributes
----------
enabled : bool, default=True
Whether resampling is applied when loading audio; the target sample rate
itself is set by `AudioConfig.samplerate`.
method : str, default="poly"
The resampling algorithm to use. Options:
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
Generally fast.
- "fourier": Resampling via Fourier method using
`scipy.signal.resample`. May handle non-integer
resampling factors differently.
"""
enabled: bool = True
method: str = "poly"
class AudioConfig(BaseConfig):
"""Configuration for loading and initial audio preprocessing."""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
resample: ResampleConfig = Field(default_factory=ResampleConfig)
def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
"""Factory function to create an AudioLoader based on configuration."""
config = config or AudioConfig()
return SoundEventAudioLoader(
samplerate=config.samplerate,
config=config.resample,
)
class SoundEventAudioLoader(AudioLoader):
"""Concrete implementation of the `AudioLoader`."""
def __init__(
self,
samplerate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ResampleConfig] = None,
):
self.samplerate = samplerate
self.config = config or ResampleConfig()
def load_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess audio directly from a file path."""
return load_file_audio(
path,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object."""
return load_recording_audio(
recording,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object."""
return load_clip_audio(
clip,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_file_audio(
path: data.PathLike,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess audio from a file path using specified config."""
try:
recording = data.Recording.from_file(path)
except LibsndfileError as e:
raise FileNotFoundError(
f"Could not load the recording at path: {path}. Error: {e}"
) from e
return load_recording_audio(
recording,
samplerate=samplerate,
config=config,
dtype=dtype,
audio_dir=audio_dir,
)
def load_recording_audio(
recording: data.Recording,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess the entire audio content of a recording using config."""
clip = data.Clip(
recording=recording,
start_time=0,
end_time=recording.duration,
)
return load_clip_audio(
clip,
samplerate=samplerate,
config=config,
dtype=dtype,
audio_dir=audio_dir,
)
def load_clip_audio(
clip: data.Clip,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess a specific audio clip segment based on config."""
try:
wav = (
audio.load_clip(clip, audio_dir=audio_dir)
.sel(channel=0)
.astype(dtype)
)
except LibsndfileError as e:
raise FileNotFoundError(
f"Could not load the recording at path: {clip.recording.path}. "
f"Error: {e}"
) from e
if not config or not config.enabled or samplerate is None:
return wav.data.astype(dtype)
sr = int(1 / wav.time.attrs["step"])
return resample_audio(
wav.data,
sr=sr,
samplerate=samplerate,
method=config.method,
)
def resample_audio(
wav: np.ndarray,
sr: int,
samplerate: int = TARGET_SAMPLERATE_HZ,
method: str = "poly",
) -> np.ndarray:
"""Resample an audio waveform DataArray to a target sample rate."""
if sr == samplerate:
return wav
if method == "poly":
return resample_audio_poly(
wav,
sr_orig=sr,
sr_new=samplerate,
)
elif method == "fourier":
return resample_audio_fourier(
wav,
sr_orig=sr,
sr_new=samplerate,
)
else:
raise NotImplementedError(
f"Resampling method '{method}' not implemented"
)
def resample_audio_poly(
array: np.ndarray,
sr_orig: int,
sr_new: int,
axis: int = -1,
) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample_poly`.
This method is often preferred for signals when the ratio of new
to old sample rates can be expressed as a rational number. It uses
polyphase filtering.
Parameters
----------
array : np.ndarray
The input array to resample.
sr_orig : int
The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1
The axis of `array` along which to resample.
Returns
-------
np.ndarray
The array resampled to the target sample rate.
Raises
------
ValueError
If sample rates are not positive.
"""
gcd = np.gcd(sr_orig, sr_new)
return resample_poly(
array,
sr_new // gcd,
sr_orig // gcd,
axis=axis,
)
def resample_audio_fourier(
array: np.ndarray,
sr_orig: int,
sr_new: int,
axis: int = -1,
) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample`.
This method uses FFTs to resample the signal.
Parameters
----------
array : np.ndarray
The input array to resample.
sr_orig : int
The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1
The axis of `array` along which to resample.
Returns
-------
np.ndarray
The array resampled to the target sample rate along `axis`.
Raises
------
ValueError
If the sample rates are not positive.
"""
ratio = sr_new / sr_orig
return resample( # type: ignore
array,
int(array.shape[axis] * ratio),
axis=axis,
)
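
A short sketch (not part of the diff) of the loader defined above; the WAV path is a placeholder:

    from batdetect2.audio.loader import (
        AudioConfig,
        ResampleConfig,
        build_audio_loader,
    )

    config = AudioConfig(
        samplerate=256_000,
        resample=ResampleConfig(enabled=True, method="poly"),
    )
    loader = build_audio_loader(config)
    wav = loader.load_file("recording.wav")  # mono numpy array, resampled to 256 kHz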


@ -1,10 +1,15 @@
import os
import click
from batdetect2 import api
from batdetect2.cli.base import cli
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import ProcessingConfiguration
from batdetect2.utils.detector_utils import save_results_to_file
DEFAULT_MODEL_PATH = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"models",
"checkpoints",
"Net2DFast_UK_same.pth.tar",
)
@cli.command()
@ -74,6 +79,9 @@ def detect(
Input files should be short in duration e.g. < 30 seconds.
"""
from batdetect2 import api
from batdetect2.utils.detector_utils import save_results_to_file
click.echo(f"Loading model: {args['model_path']}")
model, params = api.load_model(args["model_path"])
@ -123,7 +131,7 @@ def detect(
click.echo(f" {err}")
def print_config(config: ProcessingConfiguration):
def print_config(config):
"""Print the processing configuration."""
click.echo("\nProcessing Configuration:")
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")


@ -4,7 +4,6 @@ from typing import Optional
import click
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
__all__ = ["data"]
@ -33,6 +32,8 @@ def summary(
field: Optional[str] = None,
base_dir: Optional[Path] = None,
):
from batdetect2.data import load_dataset_from_config
base_dir = base_dir or Path.cwd()
dataset = load_dataset_from_config(
dataset_config,


@ -6,18 +6,21 @@ import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.train.lightning import load_model_from_checkpoint
__all__ = ["evaluate_command"]
DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--output-dir", type=click.Path())
@click.option("--workers", type=int)
@click.option("--config", "config_path", type=click.Path())
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
@click.option("--experiment-name", type=str)
@click.option("--run-name", type=str)
@click.option("--workers", "num_workers", type=int)
@click.option(
"-v",
"--verbose",
@ -27,10 +30,17 @@ __all__ = ["evaluate_command"]
def evaluate_command(
model_path: Path,
test_dataset: Path,
output_dir: Optional[Path] = None,
workers: Optional[int] = None,
config_path: Optional[Path],
output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: Optional[int] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
verbose: int = 0,
):
from batdetect2.api.base import BatDetect2API
from batdetect2.config import load_full_config
from batdetect2.data import load_dataset_from_config
logger.remove()
if verbose == 0:
log_level = "WARNING"
@ -48,16 +58,16 @@ def evaluate_command(
num_annotations=len(test_annotations),
)
model, train_config = load_model_from_checkpoint(model_path)
config = None
if config_path is not None:
config = load_full_config(config_path)
df, results = evaluate(
model,
api = BatDetect2API.from_checkpoint(model_path, config=config)
api.evaluate(
test_annotations,
config=train_config,
num_workers=workers,
num_workers=num_workers,
output_dir=output_dir,
experiment_name=experiment_name,
run_name=run_name,
)
print(results)
if output_dir:
df.to_csv(output_dir / "results.csv")


@ -6,13 +6,6 @@ import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
from batdetect2.train import (
FullTrainingConfig,
load_full_training_config,
train,
)
__all__ = ["train_command"]
@ -20,8 +13,8 @@ __all__ = ["train_command"]
@cli.command(name="train")
@click.argument("train_dataset", type=click.Path(exists=True))
@click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True))
@click.option("--targets", type=click.Path(exists=True))
@click.option("--model", "model_path", type=click.Path(exists=True))
@click.option("--targets", "targets_config", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@ -44,7 +37,7 @@ def train_command(
ckpt_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
config: Optional[Path] = None,
targets: Optional[Path] = None,
targets_config: Optional[Path] = None,
config_field: Optional[str] = None,
seed: Optional[int] = None,
train_workers: int = 0,
@ -53,6 +46,14 @@ def train_command(
run_name: Optional[str] = None,
verbose: int = 0,
):
from batdetect2.api.base import BatDetect2API
from batdetect2.config import (
BatDetect2Config,
load_full_config,
)
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
logger.remove()
if verbose == 0:
log_level = "WARNING"
@ -61,21 +62,20 @@ def train_command(
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating training process...")
logger.info("Loading training configuration...")
logger.info("Loading configuration...")
conf = (
load_full_training_config(config, field=config_field)
load_full_config(config, field=config_field)
if config is not None
else FullTrainingConfig()
else BatDetect2Config()
)
if targets is not None:
if targets_config is not None:
logger.info("Loading targets configuration...")
targets_config = load_target_config(targets)
conf = conf.model_copy(update=dict(targets=targets_config))
conf = conf.model_copy(
update=dict(targets=load_target_config(targets_config))
)
logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config(train_dataset)
@ -95,16 +95,20 @@ def train_command(
logger.debug("No validation directory provided.")
logger.info("Configuration and data loaded. Starting training...")
train(
if model_path is None:
api = BatDetect2API.from_config(conf)
else:
api = BatDetect2API.from_checkpoint(model_path)
return api.train(
train_annotations=train_annotations,
val_annotations=val_annotations,
config=conf,
model_path=model_path,
train_workers=train_workers,
val_workers=val_workers,
experiment_name=experiment_name,
log_dir=log_dir,
checkpoint_dir=ckpt_dir,
seed=seed,
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
seed=seed,
)


@ -11,7 +11,6 @@ from soundevent import data
from soundevent.geometry import compute_bounds
from soundevent.types import ClassMapper
from batdetect2.targets.terms import get_term_from_key
from batdetect2.types import (
Annotation,
AudioLoaderAnnotationGroup,
@ -173,18 +172,9 @@ def annotation_to_sound_event_annotation(
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
sound_event=sound_event,
tags=[
data.Tag(
term=get_term_from_key(label_key),
value=annotation["class"],
),
data.Tag(
term=get_term_from_key(event_key),
value=annotation["event"],
),
data.Tag(
term=get_term_from_key(individual_key),
value=str(annotation["individual"]),
),
data.Tag(key=label_key, value=annotation["class"]),
data.Tag(key=event_key, value=annotation["event"]),
data.Tag(key=individual_key, value=str(annotation["individual"])),
],
)
@ -219,17 +209,11 @@ def annotation_to_sound_event_prediction(
tags=[
data.PredictedTag(
score=annotation["class_prob"],
tag=data.Tag(
term=get_term_from_key(label_key),
value=annotation["class"],
),
tag=data.Tag(key=label_key, value=annotation["class"]),
),
data.PredictedTag(
score=annotation["det_prob"],
tag=data.Tag(
term=get_term_from_key(event_key),
value=annotation["event"],
),
tag=data.Tag(key=event_key, value=annotation["event"]),
),
],
)

src/batdetect2/config.py (new file, 40 lines)

@ -0,0 +1,40 @@
from typing import Literal, Optional
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.audio import AudioConfig
from batdetect2.core import BaseConfig
from batdetect2.core.configs import load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.models.config import BackboneConfig
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.targets.config import TargetConfig
from batdetect2.train.config import TrainingConfig
__all__ = [
"BatDetect2Config",
"load_full_config",
]
class BatDetect2Config(BaseConfig):
config_version: Literal["v1"] = "v1"
train: TrainingConfig = Field(default_factory=TrainingConfig)
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
audio: AudioConfig = Field(default_factory=AudioConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
def load_full_config(
path: PathLike,
field: Optional[str] = None,
) -> BatDetect2Config:
return load_config(path, schema=BatDetect2Config, field=field)
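
A sketch (not part of the diff) of constructing the new top-level config; the path is a placeholder and the `field` argument is assumed to select a sub-key of the YAML file:

    from batdetect2.config import BatDetect2Config, load_full_config

    config = BatDetect2Config()               # every sub-config falls back to its defaults
    config = load_full_config("config.yaml")  # or parse a full YAML file
    # config = load_full_config("config.yaml", field="batdetect2")  # assumed sub-key lookup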


@ -0,0 +1,8 @@
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry
__all__ = [
"BaseConfig",
"load_config",
"Registry",
]


@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
and serialization capabilities.
"""
model_config = ConfigDict(extra="ignore")
model_config = ConfigDict(extra="forbid")
def to_yaml_string(
self,


@ -1,7 +1,13 @@
import sys
from typing import Generic, Protocol, Type, TypeVar
from pydantic import BaseModel
from typing_extensions import ParamSpec
from typing_extensions import assert_type
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
__all__ = [
"Registry",
@ -39,7 +45,6 @@ class Registry(Generic[T_Type, P_Type]):
config_cls: Type[T_Config],
logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
) -> None:
"""A decorator factory to register a new item."""
fields = config_cls.model_fields
if "name" not in fields:


@ -18,7 +18,7 @@ from uuid import uuid5
from pydantic import Field
from soundevent import data, io
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [


@ -33,7 +33,7 @@ from loguru import logger
from pydantic import Field, ValidationError
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.legacy import (
FileAnnotation,
file_annotation_to_clip,


@ -1,6 +1,6 @@
from pathlib import Path
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
__all__ = [
"AnnotatedDataset",


@ -5,8 +5,8 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]


@ -25,7 +25,7 @@ from loguru import logger
from pydantic import Field
from soundevent import data, io
from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.data.annotations import (
AnnotatedDataset,
AnnotationFormats,


@ -4,8 +4,8 @@ from typing import Annotated, Dict, List, Literal, Optional, Union
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.data.conditions import (
SoundEventCondition,
SoundEventConditionConfig,


@ -1,9 +1,11 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
__all__ = [
"EvaluationConfig",
"load_evaluation_config",
"evaluate",
"Evaluator",
"build_evaluator",
]


@ -4,8 +4,8 @@ from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.typing.evaluate import AffinityFunction
affinity_functions: Registry[AffinityFunction, []] = Registry(


@ -3,14 +3,15 @@ from typing import List, Optional
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
from batdetect2.evaluate.metrics import (
ClassificationAPConfig,
DetectionAPConfig,
MetricConfig,
)
from batdetect2.evaluate.plots import ExampleGalleryConfig, PlotConfig
from batdetect2.evaluate.plots import PlotConfig
from batdetect2.logging import CSVLoggerConfig, LoggerConfig
__all__ = [
"EvaluationConfig",
@ -20,18 +21,15 @@ __all__ = [
class EvaluationConfig(BaseConfig):
ignore_start_end: float = 0.01
match: MatchConfig = Field(default_factory=StartTimeMatchConfig)
match_strategy: MatchConfig = Field(default_factory=StartTimeMatchConfig)
metrics: List[MetricConfig] = Field(
default_factory=lambda: [
DetectionAPConfig(),
ClassificationAPConfig(),
]
)
plots: List[PlotConfig] = Field(
default_factory=lambda: [
ExampleGalleryConfig(),
]
)
plots: List[PlotConfig] = Field(default_factory=list)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
def load_evaluation_config(


@ -0,0 +1,144 @@
from typing import List, NamedTuple, Optional, Sequence
import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import (
AudioLoader,
ClipperProtocol,
PreprocessorProtocol,
)
__all__ = [
"TestDataset",
"build_test_dataset",
"build_test_loader",
]
class TestExample(NamedTuple):
spec: torch.Tensor
idx: torch.Tensor
start_time: torch.Tensor
end_time: torch.Tensor
class TestDataset(Dataset[TestExample]):
clip_annotations: List[data.ClipAnnotation]
def __init__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
clipper: Optional[ClipperProtocol] = None,
audio_dir: Optional[data.PathLike] = None,
):
self.clip_annotations = list(clip_annotations)
self.clipper = clipper
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.audio_dir = audio_dir
def __len__(self):
return len(self.clip_annotations)
def __getitem__(self, idx: int) -> TestExample:
clip_annotation = self.clip_annotations[idx]
if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation)
clip = clip_annotation.clip
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
wav_tensor = torch.tensor(wav).unsqueeze(0)
spectrogram = self.preprocessor(wav_tensor)
return TestExample(
spec=spectrogram,
idx=torch.tensor(idx),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)
class TestLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
def build_test_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TestLoaderConfig] = None,
num_workers: Optional[int] = None,
) -> DataLoader[TestExample]:
logger.info("Building test data loader...")
config = config or TestLoaderConfig()
logger.opt(lazy=True).debug(
"Test data loader config: \n{config}",
config=lambda: config.to_yaml_string(exclude_none=True),
)
test_dataset = build_test_dataset(
clip_annotations,
audio_loader=audio_loader,
preprocessor=preprocessor,
config=config,
)
num_workers = num_workers or config.num_workers
return DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=num_workers,
collate_fn=_collate_fn,
)
def build_test_dataset(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TestLoaderConfig] = None,
) -> TestDataset:
logger.info("Building training dataset...")
config = config or TestLoaderConfig()
clipper = build_clipper(config=config.clipping_strategy)
if audio_loader is None:
audio_loader = build_audio_loader()
if preprocessor is None:
preprocessor = build_preprocessor()
return TestDataset(
clip_annotations,
audio_loader=audio_loader,
clipper=clipper,
preprocessor=preprocessor,
)
def _collate_fn(batch: List[TestExample]) -> TestExample:
max_width = max(item.spec.shape[-1] for item in batch)
return TestExample(
spec=torch.stack(
[adjust_width(item.spec, max_width) for item in batch]
),
idx=torch.stack([item.idx for item in batch]),
start_time=torch.stack([item.start_time for item in batch]),
end_time=torch.stack([item.end_time for item in batch]),
)
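
A sketch (not part of the diff) of driving the test loader above; `test_annotations` stands in for a sequence of soundevent ClipAnnotation objects:

    from batdetect2.evaluate.dataset import build_test_loader

    loader = build_test_loader(test_annotations, num_workers=2)
    for batch in loader:
        # Each batch is a TestExample with padded spectrograms and clip times.
        print(batch.spec.shape, batch.start_time, batch.end_time)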


@ -1,92 +1,68 @@
from typing import List, Optional, Tuple
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Sequence
import pandas as pd
from lightning import Trainer
from soundevent import data
from batdetect2.evaluate.dataframe import extract_matches_dataframe
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger
from batdetect2.models import Model
from batdetect2.plotting.clips import build_audio_loader
from batdetect2.postprocess import get_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.train import build_val_loader
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
PreprocessorProtocol,
TargetProtocol,
)
DEFAULT_OUTPUT_DIR: Path = Path("outputs") / "evaluations"
def evaluate(
model: Model,
test_annotations: List[data.ClipAnnotation],
config: Optional[FullTrainingConfig] = None,
test_annotations: Sequence[data.ClipAnnotation],
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
num_workers: Optional[int] = None,
) -> Tuple[pd.DataFrame, dict]:
config = config or FullTrainingConfig()
output_dir: data.PathLike = DEFAULT_OUTPUT_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
):
from batdetect2.config import BatDetect2Config
audio_loader = build_audio_loader(config.preprocess.audio)
config = config or BatDetect2Config()
preprocessor = build_preprocessor(config.preprocess)
audio_loader = audio_loader or build_audio_loader()
targets = build_targets(config.targets)
labeller = build_clip_labeler(
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
config=config.train.labels,
preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate,
)
loader = build_val_loader(
targets = targets or build_targets()
loader = build_test_loader(
test_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
config=config.train.val_loader,
num_workers=num_workers,
)
dataset: ValidationDataset = loader.dataset # type: ignore
evaluator = build_evaluator(config=config.evaluation, targets=targets)
clip_annotations = []
predictions = []
evaluator = build_evaluator(config=config.evaluation)
for batch in loader:
outputs = model.detector(batch.spec)
clip_annotations = [
dataset.clip_annotations[int(example_idx)]
for example_idx in batch.idx
]
predictions = get_raw_predictions(
outputs,
start_times=[
clip_annotation.clip.start_time
for clip_annotation in clip_annotations
],
targets=targets,
postprocessor=model.postprocessor,
)
clip_annotations.extend(clip_annotations)
predictions.extend(predictions)
matches = evaluator.evaluate(clip_annotations, predictions)
df = extract_matches_dataframe(matches)
metrics = [
DetectionAP(),
ClassificationAP(class_names=targets.class_names),
]
results = {
name: value
for metric in metrics
for name, value in metric(matches).items()
}
return df, results
logger = build_logger(
config.evaluation.logger,
log_dir=Path(output_dir),
experiment_name=experiment_name,
run_name=run_name,
)
module = EvaluationModule(model, evaluator)
trainer = Trainer(logger=logger, enable_checkpointing=False)
return trainer.test(module, loader)
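
A sketch (not part of the diff) of the Lightning-backed `evaluate` entry point above, driven from a checkpoint; the paths and annotation list are placeholders:

    from batdetect2.evaluate.evaluate import evaluate
    from batdetect2.train.lightning import load_model_from_checkpoint

    model, stored_config = load_model_from_checkpoint("checkpoints/last.ckpt")
    results = evaluate(
        model,
        test_annotations,
        config=stored_config,
        output_dir="outputs/evaluation",
    )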


@ -11,6 +11,7 @@ from batdetect2.evaluate.plots import build_plotter
from batdetect2.targets import build_targets
from batdetect2.typing.evaluate import (
ClipEvaluation,
EvaluatorProtocol,
MatcherProtocol,
MetricsProtocol,
PlotterProtocol,
@ -135,10 +136,10 @@ def build_evaluator(
matcher: Optional[MatcherProtocol] = None,
plots: Optional[List[PlotterProtocol]] = None,
metrics: Optional[List[MetricsProtocol]] = None,
) -> Evaluator:
) -> EvaluatorProtocol:
config = config or EvaluationConfig()
targets = targets or build_targets()
matcher = matcher or build_matcher(config.match)
matcher = matcher or build_matcher(config.match_strategy)
if metrics is None:
metrics = [
@ -147,7 +148,10 @@ def build_evaluator(
]
if plots is None:
plots = [build_plotter(config) for config in config.plots]
plots = [
build_plotter(config, targets.class_names)
for config in config.plots
]
return Evaluator(
config=config,


@ -0,0 +1,86 @@
from typing import Sequence
from lightning import LightningModule
from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.tables import FullEvaluationTable
from batdetect2.logging import get_image_logger, get_table_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import ClipEvaluation, EvaluatorProtocol
class EvaluationModule(LightningModule):
def __init__(
self,
model: Model,
evaluator: EvaluatorProtocol,
):
super().__init__()
self.model = model
self.evaluator = evaluator
self.clip_evaluations = []
def test_step(self, batch: TestExample):
dataset = self.get_dataset()
clip_annotations = [
dataset.clip_annotations[int(example_idx)]
for example_idx in batch.idx
]
outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(
outputs,
start_times=[ca.clip.start_time for ca in clip_annotations],
)
predictions = [
to_raw_predictions(
clip_dets.numpy(),
targets=self.evaluator.targets,
)
for clip_dets in clip_detections
]
self.clip_evaluations.extend(
self.evaluator.evaluate(clip_annotations, predictions)
)
def on_test_epoch_start(self):
self.clip_evaluations = []
def on_test_epoch_end(self):
self.log_metrics(self.clip_evaluations)
self.plot_examples(self.clip_evaluations)
self.log_table(self.clip_evaluations)
def log_table(self, evaluated_clips: Sequence[ClipEvaluation]):
table_logger = get_table_logger(self.logger) # type: ignore
if table_logger is None:
return
df = FullEvaluationTable()(evaluated_clips)
table_logger("full_evaluation", df, 0)
def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]):
plotter = get_image_logger(self.logger) # type: ignore
if plotter is None:
return
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
plotter(figure_name, fig, self.global_step)
def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]):
metrics = self.evaluator.compute_metrics(evaluated_clips)
self.log_dict(metrics)
def get_dataset(self) -> TestDataset:
dataloaders = self.trainer.test_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, TestDataset)
return dataset


@ -8,8 +8,8 @@ from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.affinity import (
AffinityConfig,
GeometricIOUConfig,
@ -111,7 +111,7 @@ def match(
class StartTimeMatchConfig(BaseConfig):
name: Literal["start_time"] = "start_time"
name: Literal["start_time_match"] = "start_time_match"
distance_threshold: float = 0.01


@ -1,3 +1,4 @@
from collections import defaultdict
from collections.abc import Callable, Mapping
from typing import (
Annotated,
@ -12,13 +13,10 @@ from typing import (
import numpy as np
from pydantic import Field
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from sklearn import metrics, preprocessing
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.typing import MetricsProtocol
from batdetect2.typing.evaluate import ClipEvaluation
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipEvaluation, MetricsProtocol
__all__ = ["DetectionAP", "ClassificationAP"]
@ -26,57 +24,18 @@ __all__ = ["DetectionAP", "ClassificationAP"]
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"]
APImplementation = Literal["sklearn", "pascal_voc"]
class DetectionAPConfig(BaseConfig):
name: Literal["detection_ap"] = "detection_ap"
implementation: AveragePrecisionImplementation = "pascal_voc"
def pascal_voc_average_precision(y_true, y_score) -> float:
y_true = np.array(y_true)
y_score = np.array(y_score)
sort_ind = np.argsort(y_score)[::-1]
y_true_sorted = y_true[sort_ind]
num_positives = y_true.sum()
false_pos_c = np.cumsum(1 - y_true_sorted)
true_pos_c = np.cumsum(y_true_sorted)
recall = true_pos_c / num_positives
precision = true_pos_c / np.maximum(
true_pos_c + false_pos_c,
np.finfo(np.float64).eps,
)
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
# pascal 12 way
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for ii in range(mprec.shape[0] - 2, -1, -1):
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
return ave_prec
_ap_impl_mapping: Mapping[
AveragePrecisionImplementation, Callable[[Any, Any], float]
] = {
"sklearn": metrics.average_precision_score,
"pascal_voc": pascal_voc_average_precision,
}
ap_implementation: APImplementation = "pascal_voc"
class DetectionAP(MetricsProtocol):
def __init__(
self,
implementation: AveragePrecisionImplementation = "pascal_voc",
implementation: APImplementation = "pascal_voc",
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
@ -96,14 +55,43 @@ class DetectionAP(MetricsProtocol):
@classmethod
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
return cls(implementation=config.implementation)
return cls(implementation=config.ap_implementation)
metrics_registry.register(DetectionAPConfig, DetectionAP)
class DetectionROCAUCConfig(BaseConfig):
name: Literal["detection_roc_auc"] = "detection_roc_auc"
class DetectionROCAUC(MetricsProtocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
score = float(metrics.roc_auc_score(y_true, y_score))
return {"detection_ROC_AUC": score}
@classmethod
def from_config(
cls, config: DetectionROCAUCConfig, class_names: List[str]
):
return cls()
metrics_registry.register(DetectionROCAUCConfig, DetectionROCAUC)
class ClassificationAPConfig(BaseConfig):
name: Literal["classification_ap"] = "classification_ap"
ap_implementation: APImplementation = "pascal_voc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
@ -112,7 +100,7 @@ class ClassificationAP(MetricsProtocol):
def __init__(
self,
class_names: List[str],
implementation: AveragePrecisionImplementation = "pascal_voc",
implementation: APImplementation = "pascal_voc",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
@ -163,7 +151,7 @@ class ClassificationAP(MetricsProtocol):
)
)
y_true = label_binarize(y_true, classes=self.class_names)
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
class_scores = {}
@ -193,6 +181,7 @@ class ClassificationAP(MetricsProtocol):
):
return cls(
class_names,
implementation=config.ap_implementation,
include=config.include,
exclude=config.exclude,
)
@ -201,11 +190,523 @@ class ClassificationAP(MetricsProtocol):
metrics_registry.register(ClassificationAPConfig, ClassificationAP)
class ClassificationROCAUCConfig(BaseConfig):
name: Literal["classification_roc_auc"] = "classification_roc_auc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationROCAUC(MetricsProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = preprocessing.label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
class_scores[class_name] = float(class_roc_auc)
mean_roc_auc = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"classification_macro_average_ROC_AUC": float(mean_roc_auc),
**{
f"classification_ROC_AUC/{class_name}": class_scores[
class_name
]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClassificationROCAUCConfig,
class_names: List[str],
):
return cls(
class_names,
include=config.include,
exclude=config.exclude,
)
metrics_registry.register(ClassificationROCAUCConfig, ClassificationROCAUC)
class TopClassAPConfig(BaseConfig):
name: Literal["top_class_ap"] = "top_class_ap"
ap_implementation: APImplementation = "pascal_voc"
class TopClassAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation = "pascal_voc",
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
top_class = match.pred_class
y_true.append(top_class == match.gt_class)
y_score.append(match.pred_class_score)
score = float(self.metric(y_true, y_score))
return {"top_class_AP": score}
@classmethod
def from_config(cls, config: TopClassAPConfig, class_names: List[str]):
return cls(implementation=config.ap_implementation)
metrics_registry.register(TopClassAPConfig, TopClassAP)
class ClassificationBalancedAccuracyConfig(BaseConfig):
name: Literal["classification_balanced_accuracy"] = (
"classification_balanced_accuracy"
)
class ClassificationBalancedAccuracy(MetricsProtocol):
def __init__(self, class_names: List[str]):
self.class_names = class_names
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
top_class = match.pred_class
# Focus on matches
if match.gt_class is None or top_class is None:
continue
y_true.append(self.class_names.index(match.gt_class))
y_pred.append(self.class_names.index(top_class))
score = float(metrics.balanced_accuracy_score(y_true, y_pred))
return {"classification_balanced_accuracy": score}
@classmethod
def from_config(
cls,
config: ClassificationBalancedAccuracyConfig,
class_names: List[str],
):
return cls(class_names)
metrics_registry.register(
ClassificationBalancedAccuracyConfig,
ClassificationBalancedAccuracy,
)
class ClipDetectionAPConfig(BaseConfig):
name: Literal["clip_detection_ap"] = "clip_detection_ap"
ap_implementation: APImplementation = "pascal_voc"
class ClipDetectionAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
clip_det = []
clip_scores = []
for match in clip_eval.matches:
clip_det.append(match.gt_det)
clip_scores.append(match.pred_score)
y_true.append(any(clip_det))
y_score.append(max(clip_scores or [0]))
return {"clip_detection_ap": self.metric(y_true, y_score)}
@classmethod
def from_config(
cls,
config: ClipDetectionAPConfig,
class_names: List[str],
):
return cls(implementation=config.ap_implementation)
metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP)
class ClipDetectionROCAUCConfig(BaseConfig):
name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc"
class ClipDetectionROCAUC(MetricsProtocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
clip_det = []
clip_scores = []
for match in clip_eval.matches:
clip_det.append(match.gt_det)
clip_scores.append(match.pred_score)
y_true.append(any(clip_det))
y_score.append(max(clip_scores or [0]))
return {
"clip_detection_ap": float(metrics.roc_auc_score(y_true, y_score))
}
@classmethod
def from_config(
cls,
config: ClipDetectionROCAUCConfig,
class_names: List[str],
):
return cls()
metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC)
class ClipMulticlassAPConfig(BaseConfig):
name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap"
ap_implementation: APImplementation = "pascal_voc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClipMulticlassAP(MetricsProtocol):
def __init__(
self,
class_names: List[str],
implementation: APImplementation,
include: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
clip_classes = set()
clip_scores = defaultdict(list)
for match in clip_eval.matches:
if match.gt_class is not None:
clip_classes.add(match.gt_class)
for class_name, score in match.pred_class_scores.items():
clip_scores[class_name].append(score)
y_true.append(clip_classes)
y_pred.append(
np.array(
[
# Get max score for each class
max(clip_scores.get(class_name, [0]))
for class_name in self.class_names
]
)
)
y_true = preprocessing.MultiLabelBinarizer(
classes=self.class_names
).fit_transform(y_true)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_ap = self.metric(y_true_class, y_pred_class)
class_scores[class_name] = float(class_ap)
mean_ap = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"clip_multiclass_mAP": float(mean_ap),
**{
f"clip_multiclass_AP/{class_name}": class_scores[class_name]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls, config: ClipMulticlassAPConfig, class_names: List[str]
):
return cls(
implementation=config.ap_implementation,
include=config.include,
exclude=config.exclude,
class_names=class_names,
)
metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP)
class ClipMulticlassROCAUCConfig(BaseConfig):
name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClipMulticlassROCAUC(MetricsProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
clip_classes = set()
clip_scores = defaultdict(list)
for match in clip_eval.matches:
if match.gt_class is not None:
clip_classes.add(match.gt_class)
for class_name, score in match.pred_class_scores.items():
clip_scores[class_name].append(score)
y_true.append(clip_classes)
y_pred.append(
np.array(
[
# Get maximum score for each class
max(clip_scores.get(class_name, [0]))
for class_name in self.class_names
]
)
)
y_true = preprocessing.MultiLabelBinarizer(
classes=self.class_names
).fit_transform(y_true)
y_pred = np.stack(y_pred)
class_scores = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
class_roc_auc = metrics.roc_auc_score(y_true_class, y_pred_class)
class_scores[class_name] = float(class_roc_auc)
mean_roc_auc = np.mean(
[value for value in class_scores.values() if value != 0]
)
return {
"clip_multiclass_macro_ROC_AUC": float(mean_roc_auc),
**{
f"clip_multiclass_ROC_AUC/{class_name}": class_scores[
class_name
]
for class_name in self.selected
},
}
@classmethod
def from_config(
cls,
config: ClipMulticlassROCAUCConfig,
class_names: List[str],
):
return cls(
include=config.include,
exclude=config.exclude,
class_names=class_names,
)
metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC)
MetricConfig = Annotated[
Union[ClassificationAPConfig, DetectionAPConfig],
Union[
DetectionAPConfig,
DetectionROCAUCConfig,
ClassificationAPConfig,
ClassificationROCAUCConfig,
TopClassAPConfig,
ClassificationBalancedAccuracyConfig,
ClipDetectionAPConfig,
ClipDetectionROCAUCConfig,
ClipMulticlassAPConfig,
ClipMulticlassROCAUCConfig,
],
Field(discriminator="name"),
]
def build_metric(config: MetricConfig, class_names: List[str]):
return metrics_registry.build(config, class_names)
def pascal_voc_average_precision(y_true, y_score) -> float:
y_true = np.array(y_true)
y_score = np.array(y_score)
sort_ind = np.argsort(y_score)[::-1]
y_true_sorted = y_true[sort_ind]
num_positives = y_true.sum()
false_pos_c = np.cumsum(1 - y_true_sorted)
true_pos_c = np.cumsum(y_true_sorted)
recall = true_pos_c / num_positives
precision = true_pos_c / np.maximum(
true_pos_c + false_pos_c,
np.finfo(np.float64).eps,
)
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
# pascal 12 way
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for ii in range(mprec.shape[0] - 2, -1, -1):
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
return ave_prec
_ap_impl_mapping: Mapping[APImplementation, Callable[[Any, Any], float]] = {
"sklearn": metrics.average_precision_score,
"pascal_voc": pascal_voc_average_precision,
}
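A small worked example (values assumed for illustration) of the all-point interpolated AP computed by pascal_voc_average_precision above.

y_true = [1, 0, 1, 1, 0]
y_score = [0.9, 0.8, 0.7, 0.3, 0.1]
# Recall increases at 1/3, 2/3 and 1 with interpolated precisions 1.0, 0.75
# and 0.75, so AP = (1/3) * (1.0 + 0.75 + 0.75) ~= 0.833.
ap = pascal_voc_average_precision(y_true, y_score)
print(round(ap, 3))  # 0.833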

View File

@ -4,20 +4,24 @@ from dataclasses import dataclass, field
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pydantic import Field
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import BaseConfig, Registry
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.matches import plot_matches
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing.evaluate import (
from batdetect2.typing import (
AudioLoader,
ClipEvaluation,
MatchEvaluation,
PlotterProtocol,
PreprocessorProtocol,
)
from batdetect2.typing.preprocess import AudioLoader
__all__ = [
"build_plotter",
@ -26,12 +30,13 @@ __all__ = [
]
plots_registry: Registry[PlotterProtocol, []] = Registry("plot")
plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot")
class ExampleGalleryConfig(BaseConfig):
name: Literal["example_gallery"] = "example_gallery"
examples_per_class: int = 5
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
@ -87,9 +92,12 @@ class ExampleGallery(PlotterProtocol):
plt.close(fig)
@classmethod
def from_config(cls, config: ExampleGalleryConfig):
preprocessor = build_preprocessor(config.preprocessing)
audio_loader = build_audio_loader(config.preprocessing.audio)
def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]):
audio_loader = build_audio_loader(config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return cls(
examples_per_class=config.examples_per_class,
preprocessor=preprocessor,
@ -100,13 +108,402 @@ class ExampleGallery(PlotterProtocol):
plots_registry.register(ExampleGalleryConfig, ExampleGallery)
class ClipEvaluationPlotConfig(BaseConfig):
name: Literal["example_clip"] = "example_clip"
num_plots: int = 5
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class PlotClipEvaluation(PlotterProtocol):
def __init__(
self,
num_plots: int = 3,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
):
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.num_plots = num_plots
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
examples = random.sample(
clip_evaluations,
k=min(self.num_plots, len(clip_evaluations)),
)
for index, clip_evaluation in enumerate(examples):
fig, ax = plt.subplots()
plot_matches(
clip_evaluation.matches,
clip=clip_evaluation.clip,
audio_loader=self.audio_loader,
ax=ax,
)
yield f"clip_evaluation/example_{index}", fig
plt.close(fig)
@classmethod
def from_config(
cls,
config: ClipEvaluationPlotConfig,
class_names: List[str],
):
audio_loader = build_audio_loader(config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return cls(
num_plots=config.num_plots,
preprocessor=preprocessor,
audio_loader=audio_loader,
)
plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation)
class DetectionPRCurveConfig(BaseConfig):
name: Literal["detection_pr_curve"] = "detection_pr_curve"
class DetectionPRCurve(PlotterProtocol):
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
fig, ax = plt.subplots()
ax.plot(recall, precision, label="Detector")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend()
yield "detection_pr_curve", fig
@classmethod
def from_config(
cls,
config: DetectionPRCurveConfig,
class_names: List[str],
):
return cls()
plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve)
class ClassificationPRCurvesConfig(BaseConfig):
name: Literal["classification_pr_curves"] = "classification_pr_curves"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationPRCurves(PlotterProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
fig, ax = plt.subplots(figsize=(10, 10))
for class_index, class_name in enumerate(self.class_names):
if class_name not in self.selected:
continue
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
precision, recall, _ = metrics.precision_recall_curve(
y_true_class,
y_pred_class,
)
ax.plot(recall, precision, label=class_name)
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
yield "classification_pr_curve", fig
@classmethod
def from_config(
cls,
config: ClassificationPRCurvesConfig,
class_names: List[str],
):
return cls(
class_names=class_names,
include=config.include,
exclude=config.exclude,
)
plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves)
class DetectionROCCurveConfig(BaseConfig):
name: Literal["detection_roc_curve"] = "detection_roc_curve"
class DetectionROCCurve(PlotterProtocol):
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
fig, ax = plt.subplots()
ax.plot(fpr, tpr, label="Detection")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend()
yield "detection_roc_curve", fig
@classmethod
def from_config(
cls,
config: DetectionROCCurveConfig,
class_names: List[str],
):
return cls()
plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve)
class ClassificationROCCurvesConfig(BaseConfig):
name: Literal["classification_roc_curves"] = "classification_roc_curves"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationROCCurves(PlotterProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
fig, ax = plt.subplots(figsize=(10, 10))
for class_index, class_name in enumerate(self.class_names):
if class_name not in self.selected:
continue
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
fpr, tpr, _ = metrics.roc_curve(
y_true_class,
y_pred_class,
)
ax.plot(fpr, tpr, label=class_name)
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
yield "classification_roc_curve", fig
@classmethod
def from_config(
cls,
config: ClassificationROCCurvesConfig,
class_names: List[str],
):
return cls(
class_names=class_names,
include=config.include,
exclude=config.exclude,
)
plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)
class ConfusionMatrixConfig(BaseConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
background_class: str = "noise"
class ConfusionMatrix(PlotterProtocol):
def __init__(self, background_class: str, class_names: List[str]):
self.background_class = background_class
self.class_names = class_names
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else self.background_class
)
top_class = match.pred_class
y_pred.append(
top_class
if top_class is not None
else self.background_class
)
display = metrics.ConfusionMatrixDisplay.from_predictions(
y_true,
y_pred,
labels=[*self.class_names, self.background_class],
)
yield "confusion_matrix", display.figure_
@classmethod
def from_config(
cls,
config: ConfusionMatrixConfig,
class_names: List[str],
):
return cls(
background_class=config.background_class,
class_names=class_names,
)
plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix)
PlotConfig = Annotated[
Union[ExampleGalleryConfig,], Field(discriminator="name")
Union[
ExampleGalleryConfig,
ClipEvaluationPlotConfig,
DetectionPRCurveConfig,
ClassificationPRCurvesConfig,
DetectionROCCurveConfig,
ClassificationROCCurvesConfig,
ConfusionMatrixConfig,
],
Field(discriminator="name"),
]
def build_plotter(config: PlotConfig) -> PlotterProtocol:
return plots_registry.build(config)
def build_plotter(
config: PlotConfig, class_names: List[str]
) -> PlotterProtocol:
return plots_registry.build(config, class_names)
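A hedged sketch of the plotter factory; the class names are hypothetical and the plotting loop is commented because it needs ClipEvaluation objects.

config = DetectionPRCurveConfig()
plotter = build_plotter(config, class_names=["pippip", "myomys"])
# for name, fig in plotter(clip_evaluations):
#     ...  # each plotter yields (name, matplotlib Figure) pairs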
@dataclass

View File

@ -1,18 +1,49 @@
from typing import List
from typing import Annotated, Callable, Literal, Sequence, Union
import pandas as pd
from pydantic import Field
from soundevent.geometry import compute_bounds
from batdetect2.typing.evaluate import ClipEvaluation
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipEvaluation
EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame]
def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame:
tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
"evaluation_table"
)
class FullEvaluationTableConfig(BaseConfig):
name: Literal["full_evaluation"] = "full_evaluation"
class FullEvaluationTable:
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> pd.DataFrame:
return extract_matches_dataframe(clip_evaluations)
@classmethod
def from_config(cls, config: FullEvaluationTableConfig):
return cls()
tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable)
def extract_matches_dataframe(
clip_evaluations: Sequence[ClipEvaluation],
) -> pd.DataFrame:
data = []
for clip_evaluation in clip_evaluations:
for match in clip_evaluation.matches:
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
pred_start_time = pred_low_freq = pred_end_time = (
pred_high_freq
) = None
sound_event_annotation = match.sound_event_annotation
@ -24,9 +55,12 @@ def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.Data
)
if match.pred_geometry is not None:
pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
compute_bounds(match.pred_geometry)
)
(
pred_start_time,
pred_low_freq,
pred_end_time,
pred_high_freq,
) = compute_bounds(match.pred_geometry)
data.append(
{
@ -61,3 +95,14 @@ def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.Data
df = pd.DataFrame(data)
df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore
return df
EvaluationTableConfig = Annotated[
Union[FullEvaluationTableConfig,], Field(discriminator="name")
]
def build_table_generator(
config: EvaluationTableConfig,
) -> EvaluationTableGenerator:
return tables_registry.build(config)

View File

View File

@ -1,4 +1,6 @@
import io
from collections.abc import Callable
from functools import partial
from pathlib import Path
from typing import (
Annotated,
@ -13,12 +15,19 @@ from typing import (
)
import numpy as np
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
import pandas as pd
from lightning.pytorch.loggers import (
CSVLogger,
Logger,
MLFlowLogger,
TensorBoardLogger,
)
from loguru import logger
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
@ -48,7 +57,7 @@ class TensorBoardLoggerConfig(BaseLoggerConfig):
class MLFlowLoggerConfig(BaseLoggerConfig):
name: Literal["mlflow"] = "mlflow"
tracking_uri: Optional[str] = None
tracking_uri: Optional[str] = "http://localhost:5000"
tags: Optional[dict[str, Any]] = None
log_model: bool = False
@ -152,6 +161,9 @@ def create_tensorboard_logger(
name = run_name
if name is None:
name = experiment_name
if run_name is not None and experiment_name is not None:
name = str(Path(experiment_name) / run_name)
@ -231,18 +243,18 @@ def build_logger(
)
def get_image_plotter(logger: Logger):
PlotLogger = Callable[[str, Figure, int], None]
def get_image_logger(logger: Logger) -> Optional[PlotLogger]:
if isinstance(logger, TensorBoardLogger):
def plot_figure(name, figure, step):
return logger.experiment.add_figure(name, figure, step)
return plot_figure
return logger.experiment.add_figure
if isinstance(logger, MLFlowLogger):
def plot_figure(name, figure, step):
image = _convert_figure_to_image(figure)
image = _convert_figure_to_array(figure)
name = name.replace("/", "_")
return logger.experiment.log_image(
logger.run_id,
image,
@ -252,8 +264,51 @@ def get_image_plotter(logger: Logger):
return plot_figure
if isinstance(logger, CSVLogger):
return partial(save_figure, dir=Path(logger.log_dir))
def _convert_figure_to_image(figure):
TableLogger = Callable[[str, pd.DataFrame, int], None]
def get_table_logger(logger: Logger) -> Optional[TableLogger]:
if isinstance(logger, TensorBoardLogger):
return partial(save_table, dir=Path(logger.log_dir))
if isinstance(logger, MLFlowLogger):
def log_table(name: str, df: pd.DataFrame, step: int):
return logger.experiment.log_table(
logger.run_id,
data=df,
artifact_file=f"{name}_step_{step}.json",
)
return log_table
if isinstance(logger, CSVLogger):
return partial(save_table, dir=Path(logger.log_dir))
def save_table(name: str, df: pd.DataFrame, step: int, dir: Path) -> None:
path = dir / "tables" / f"{name}_step_{step}.csv"
if not path.parent.exists():
path.parent.mkdir(parents=True)
df.to_csv(path, index=False)
def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
path = dir / "plots" / f"{name}_step_{step}.png"
if not path.parent.exists():
path.parent.mkdir(parents=True)
fig.savefig(path, transparent=True, bbox_inches="tight")
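A minimal sketch of the logger helpers above with a CSVLogger; the figure, table, and log directory are illustrative assumptions.

import matplotlib.pyplot as plt
import pandas as pd
from lightning.pytorch.loggers import CSVLogger

csv_logger = CSVLogger("outputs/logs")
fig, _ = plt.subplots()
df = pd.DataFrame({"score": [0.9, 0.2]})  # hypothetical evaluation table

plot_logger = get_image_logger(csv_logger)
if plot_logger is not None:
    plot_logger("detection_pr_curve", fig, 0)  # saved under <log_dir>/plots

table_logger = get_table_logger(csv_logger)
if table_logger is not None:
    table_logger("full_evaluation", df, 0)  # saved under <log_dir>/tables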
def _convert_figure_to_array(figure: Figure) -> np.ndarray:
with io.BytesIO() as buff:
figure.savefig(buff, format="raw")
buff.seek(0)

View File

@ -29,15 +29,10 @@ provided here.
from typing import List, Optional
import torch
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.backbones import (
Backbone,
BackboneConfig,
build_backbone,
load_backbone_config,
)
from batdetect2.models.blocks import (
ConvConfig,
@ -51,6 +46,10 @@ from batdetect2.models.bottleneck import (
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.config import (
BackboneConfig,
load_backbone_config,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
@ -63,12 +62,12 @@ from batdetect2.models.encoder import (
build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import (
DetectionsTensor,
ClipDetectionsTensor,
PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
@ -99,20 +98,10 @@ __all__ = [
"build_detector",
"load_backbone_config",
"Model",
"ModelConfig",
"build_model",
]
class ModelConfig(BaseConfig):
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
class Model(torch.nn.Module):
detector: DetectionModel
preprocessor: PreprocessorProtocol
@ -125,47 +114,38 @@ class Model(torch.nn.Module):
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
config: ModelConfig,
):
super().__init__()
self.detector = detector
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.targets = targets
self.config = config
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
spec = self.preprocessor(wav)
outputs = self.detector(spec)
return self.postprocessor(outputs)
def build_model(config: Optional[ModelConfig] = None):
config = config or ModelConfig()
targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config=config.preprocess)
postprocessor = build_postprocessor(
def build_model(
config: Optional[BackboneConfig] = None,
targets: Optional[TargetProtocol] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
postprocessor: Optional[PostprocessorProtocol] = None,
):
config = config or BackboneConfig()
targets = targets or build_targets()
preprocessor = preprocessor or build_preprocessor()
postprocessor = postprocessor or build_postprocessor(
preprocessor=preprocessor,
config=config.postprocess,
)
detector = build_detector(
num_classes=len(targets.class_names),
config=config.model,
config=config,
)
return Model(
config=config,
detector=detector,
postprocessor=postprocessor,
preprocessor=preprocessor,
targets=targets,
)
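A hedged sketch of the new compositional build_model API shown above; the waveform shape expected by the preprocessor is not given in this diff, so the forward call is left commented.

model = build_model()  # defaults for backbone, targets, pre- and postprocessor
# Applying the module to a raw-audio batch yields one ClipDetectionsTensor per
# clip in the batch:
# detections = model(wav)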
def load_model_config(
path: PathLike, field: Optional[str] = None
) -> ModelConfig:
return load_config(path, schema=ModelConfig, field=field)

View File

@ -18,37 +18,20 @@ automatic padding to handle input sizes not perfectly divisible by the
network's total downsampling factor.
"""
from typing import Optional, Tuple
from typing import Tuple
import torch
import torch.nn.functional as F
from soundevent import data
from torch import nn
from batdetect2.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
Decoder,
DecoderConfig,
build_decoder,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
Encoder,
EncoderConfig,
build_encoder,
)
from batdetect2.models.bottleneck import build_bottleneck
from batdetect2.models.config import BackboneConfig
from batdetect2.models.decoder import Decoder, build_decoder
from batdetect2.models.encoder import Encoder, build_encoder
from batdetect2.typing.models import BackboneModel
__all__ = [
"Backbone",
"BackboneConfig",
"load_backbone_config",
"build_backbone",
]
@ -161,82 +144,6 @@ class Backbone(BackboneModel):
return x
class BackboneConfig(BaseConfig):
"""Configuration for the Encoder-Decoder Backbone network.
Aggregates configurations for the encoder, bottleneck, and decoder
components, along with defining the input and final output dimensions
for the complete backbone.
Attributes
----------
input_height : int, default=128
Expected height (frequency bins) of the input spectrograms to the
backbone. Must be positive.
in_channels : int, default=1
Expected number of channels in the input spectrograms (e.g., 1 for
mono). Must be positive.
encoder : EncoderConfig, optional
Configuration for the encoder. If None or omitted,
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
encoder module) will be used.
bottleneck : BottleneckConfig, optional
Configuration for the bottleneck layer connecting encoder and decoder.
If None or omitted, the default bottleneck configuration will be used.
decoder : DecoderConfig, optional
Configuration for the decoder. If None or omitted,
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
decoder module) will be used.
out_channels : int, default=32
Desired number of channels in the final feature map output by the
backbone. Must be positive.
"""
input_height: int = 128
in_channels: int = 1
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32
def load_backbone_config(
path: data.PathLike,
field: Optional[str] = None,
) -> BackboneConfig:
"""Load the backbone configuration from a file.
Reads a configuration file (YAML) and validates it against the
`BackboneConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
backbone configuration (e.g., "model.backbone"). If None, the entire
file content is used.
Returns
-------
BackboneConfig
The loaded and validated backbone configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to `BackboneConfig`.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=BackboneConfig, field=field)
def build_backbone(config: BackboneConfig) -> BackboneModel:
"""Factory function to build a Backbone from configuration.

View File

@ -34,7 +34,7 @@ import torch.nn.functional as F
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
__all__ = [
"ConvBlock",

View File

@ -20,7 +20,7 @@ import torch
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
SelfAttentionConfig,
VerticalConv,

View File

@ -0,0 +1,98 @@
from typing import Optional
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
)
__all__ = [
"BackboneConfig",
"load_backbone_config",
]
class BackboneConfig(BaseConfig):
"""Configuration for the Encoder-Decoder Backbone network.
Aggregates configurations for the encoder, bottleneck, and decoder
components, along with defining the input and final output dimensions
for the complete backbone.
Attributes
----------
input_height : int, default=128
Expected height (frequency bins) of the input spectrograms to the
backbone. Must be positive.
in_channels : int, default=1
Expected number of channels in the input spectrograms (e.g., 1 for
mono). Must be positive.
encoder : EncoderConfig, optional
Configuration for the encoder. If None or omitted,
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
encoder module) will be used.
bottleneck : BottleneckConfig, optional
Configuration for the bottleneck layer connecting encoder and decoder.
If None or omitted, the default bottleneck configuration will be used.
decoder : DecoderConfig, optional
Configuration for the decoder. If None or omitted,
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
decoder module) will be used.
out_channels : int, default=32
Desired number of channels in the final feature map output by the
backbone. Must be positive.
"""
input_height: int = 128
in_channels: int = 1
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32
def load_backbone_config(
path: data.PathLike,
field: Optional[str] = None,
) -> BackboneConfig:
"""Load the backbone configuration from a file.
Reads a configuration file (YAML) and validates it against the
`BackboneConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
backbone configuration (e.g., "model.backbone"). If None, the entire
file content is used.
Returns
-------
BackboneConfig
The loaded and validated backbone configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to `BackboneConfig`.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=BackboneConfig, field=field)
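A minimal sketch of loading a nested backbone section, assuming a config file that places the settings under a top-level "model" key; the file name and values are illustrative.

from pathlib import Path

Path("example_config.yaml").write_text(
    "model:\n"
    "  input_height: 128\n"
    "  out_channels: 32\n"
)
backbone_config = load_backbone_config("example_config.yaml", field="model")
assert backbone_config.out_channels == 32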

View File

@ -24,7 +24,7 @@ import torch
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
ConvConfig,
FreqCoordConvUpConfig,

View File

@ -26,7 +26,7 @@ import torch
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
ConvConfig,
FreqCoordConvDownConfig,

View File

@ -5,8 +5,9 @@ import torch
from matplotlib.axes import Axes
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.plotting.common import plot_spectrogram
from batdetect2.preprocess import build_audio_loader, build_preprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol
__all__ = [

View File

@ -6,10 +6,8 @@ from soundevent import data, plot
from soundevent.geometry import compute_bounds
from soundevent.plot.tags import TagColorMapper
from batdetect2.plotting.clip_predictions import plot_prediction
from batdetect2.plotting.clips import AudioLoader, plot_clip
from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
__all__ = [
"plot_matches",
@ -30,7 +28,7 @@ DEFAULT_PREDICTION_LINE_STYLE = "--"
def plot_matches(
matches: List[data.Match],
matches: List[MatchEvaluation],
clip: data.Clip,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
@ -44,8 +42,7 @@ def plot_matches(
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR,
) -> Axes:
ax = plot_clip(
clip,
@ -61,52 +58,48 @@ def plot_matches(
color_mapper = TagColorMapper()
for match in matches:
if match.source is None and match.target is not None:
plot.plot_annotation(
annotation=match.target,
if match.is_cross_trigger():
plot_cross_trigger_match(
match,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
fill=fill,
add_points=add_points,
add_spectrogram=False,
use_score=True,
color=cross_trigger_color,
add_text=False,
)
elif match.is_true_positive():
plot_true_positive_match(
match,
ax=ax,
fill=fill,
add_spectrogram=False,
use_score=True,
add_points=add_points,
color=true_positive_color,
add_text=False,
)
elif match.is_false_negative():
plot_false_negative_match(
match,
ax=ax,
fill=fill,
add_spectrogram=False,
add_points=add_points,
facecolor="none" if not fill else None,
color=false_negative_color,
color_mapper=color_mapper,
linestyle=annotation_linestyle,
add_text=False,
)
elif match.target is None and match.source is not None:
plot_prediction(
prediction=match.source,
elif match.is_false_positive():
plot_false_positive_match(
match,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
fill=fill,
add_spectrogram=False,
use_score=True,
add_points=add_points,
facecolor="none" if not fill else None,
color=false_positive_color,
color_mapper=color_mapper,
linestyle=prediction_linestyle,
)
elif match.target is not None and match.source is not None:
plot.plot_annotation(
annotation=match.target,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=true_positive_color,
color_mapper=color_mapper,
linestyle=annotation_linestyle,
)
plot_prediction(
prediction=match.source,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=true_positive_color,
color_mapper=color_mapper,
linestyle=prediction_linestyle,
add_text=False,
)
else:
continue
@ -122,6 +115,9 @@ def plot_false_positive_match(
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
use_score: bool = True,
add_spectrogram: bool = True,
add_text: bool = True,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
@ -142,34 +138,36 @@ def plot_false_positive_match(
recording=match.clip.recording,
)
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
if add_spectrogram:
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
plot.plot_geometry(
ax = plot.plot_geometry(
match.pred_geometry,
ax=ax,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
alpha=match.pred_score if use_score else 1,
color=color,
)
plt.text(
start_time,
high_freq,
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_text:
plt.text(
start_time,
high_freq,
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
@ -182,7 +180,9 @@ def plot_false_negative_match(
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_spectrogram: bool = True,
add_points: bool = False,
add_text: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
@ -204,17 +204,18 @@ def plot_false_negative_match(
recording=sound_event.recording,
)
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
if add_spectrogram:
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
ax = plot.plot_annotation(
match.sound_event_annotation,
ax=ax,
time_offset=0.001,
@ -225,15 +226,16 @@ def plot_false_negative_match(
color=color,
)
plt.text(
start_time,
high_freq,
f"False Negative \nClass: {match.gt_class} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_text:
plt.text(
start_time,
high_freq,
f"False Negative \nClass: {match.gt_class} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
@ -246,7 +248,10 @@ def plot_true_positive_match(
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
use_score: bool = True,
add_spectrogram: bool = True,
add_points: bool = False,
add_text: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
@ -270,17 +275,18 @@ def plot_true_positive_match(
recording=sound_event.recording,
)
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
if add_spectrogram:
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
ax = plot.plot_annotation(
match.sound_event_annotation,
ax=ax,
time_offset=0.001,
@ -297,20 +303,21 @@ def plot_true_positive_match(
ax=ax,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
alpha=match.pred_score if use_score else 1,
color=color,
linestyle=prediction_linestyle,
)
plt.text(
start_time,
high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_text:
plt.text(
start_time,
high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
@ -323,7 +330,10 @@ def plot_cross_trigger_match(
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
use_score: bool = True,
add_spectrogram: bool = True,
add_points: bool = False,
add_text: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
@ -347,17 +357,18 @@ def plot_cross_trigger_match(
recording=sound_event.recording,
)
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
if add_spectrogram:
ax = plot_clip(
clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
ax = plot.plot_annotation(
match.sound_event_annotation,
ax=ax,
time_offset=0.001,
@ -369,24 +380,25 @@ def plot_cross_trigger_match(
linestyle=annotation_linestyle,
)
plot.plot_geometry(
ax = plot.plot_geometry(
match.pred_geometry,
ax=ax,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
alpha=match.pred_score if use_score else 1,
color=color,
linestyle=prediction_linestyle,
)
plt.text(
start_time,
high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_text:
plt.text(
start_time,
high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax

View File

@ -1,307 +1,25 @@
"""Main entry point for the BatDetect2 Postprocessing pipeline."""
from typing import List, Optional
import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.postprocess.config import (
PostprocessConfig,
load_postprocess_config,
)
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction,
to_raw_predictions,
)
from batdetect2.postprocess.extraction import extract_prediction_tensor
from batdetect2.postprocess.nms import (
NMS_KERNEL_SIZE,
non_max_suppression,
from batdetect2.postprocess.nms import non_max_suppression
from batdetect2.postprocess.postprocessor import (
Postprocessor,
build_postprocessor,
)
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
DetectionsTensor,
PostprocessorProtocol,
RawPrediction,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"DEFAULT_CLASSIFICATION_THRESHOLD",
"DEFAULT_DETECTION_THRESHOLD",
"MAX_FREQ",
"MIN_FREQ",
"ModelOutput",
"NMS_KERNEL_SIZE",
"PostprocessConfig",
"Postprocessor",
"TOP_K_PER_SEC",
"build_postprocessor",
"convert_raw_predictions_to_clip_prediction",
"to_raw_predictions",
"load_postprocess_config",
"non_max_suppression",
]
DEFAULT_DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 100
class PostprocessConfig(BaseConfig):
"""Configuration settings for the postprocessing pipeline.
Defines tunable parameters that control how raw model outputs are
converted into final detections.
Attributes
----------
nms_kernel_size : int, default=NMS_KERNEL_SIZE
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
Used to suppress weaker detections near stronger peaks. Must be
positive.
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
Minimum confidence score from the detection heatmap required to
consider a point as a potential detection. Must be >= 0.
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
Minimum confidence score for a specific class prediction to be included
in the decoded tags for a detection. Must be >= 0.
top_k_per_sec : int, default=TOP_K_PER_SEC
Desired maximum number of detections per second of audio. Used by
`get_max_detections` to calculate an absolute limit based on clip
duration before applying `extract_detections_from_array`. Must be
positive.
"""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
detection_threshold: float = Field(
default=DEFAULT_DETECTION_THRESHOLD,
ge=0,
)
classification_threshold: float = Field(
default=DEFAULT_CLASSIFICATION_THRESHOLD,
ge=0,
)
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PostprocessConfig:
"""Load the postprocessing configuration from a file.
Reads a configuration file (YAML) and validates it against the
`PostprocessConfig` schema, potentially extracting data from a nested
field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
postprocessing configuration (e.g., "inference.postprocessing").
If None, the entire file content is used.
Returns
-------
PostprocessConfig
The loaded and validated postprocessing configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data does not conform to the
`PostprocessConfig` schema.
KeyError, TypeError
If `field` specifies an invalid path within the loaded data.
"""
return load_config(path, schema=PostprocessConfig, field=field)
def build_postprocessor(
preprocessor: PreprocessorProtocol,
config: Optional[PostprocessConfig] = None,
) -> PostprocessorProtocol:
"""Factory function to build the standard postprocessor."""
config = config or PostprocessConfig()
logger.opt(lazy=True).debug(
"Building postprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
return Postprocessor(
samplerate=preprocessor.output_samplerate,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
top_k_per_sec=config.top_k_per_sec,
detection_threshold=config.detection_threshold,
)
class Postprocessor(torch.nn.Module, PostprocessorProtocol):
"""Standard implementation of the postprocessing pipeline."""
def __init__(
self,
samplerate: float,
min_freq: float,
max_freq: float,
top_k_per_sec: int = 200,
detection_threshold: float = 0.01,
):
"""Initialize the Postprocessor."""
super().__init__()
self.samplerate = samplerate
self.min_freq = min_freq
self.max_freq = max_freq
self.top_k_per_sec = top_k_per_sec
self.detection_threshold = detection_threshold
def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
max_detections=max_detections,
threshold=self.detection_threshold,
)
return [
map_detection_to_clip(
detection,
start_time=0,
end_time=duration,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection in detections
]
def get_detections(
self,
output: ModelOutput,
start_times: Optional[List[float]] = None,
) -> List[DetectionsTensor]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
max_detections=max_detections,
threshold=self.detection_threshold,
)
if start_times is None:
return detections
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
return [
map_detection_to_clip(
detection,
start_time=start_time,
end_time=start_time + duration,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection, start_time in zip(detections, start_times)
]
def get_raw_predictions(
output: ModelOutput,
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
start_times: Optional[List[float]] = None,
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch."""
detections = postprocessor.get_detections(output, start_times)
return [
to_raw_predictions(detection.numpy(), targets=targets)
for detection in detections
]
def get_sound_event_predictions(
output: ModelOutput,
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
clips: List[data.Clip],
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[List[BatDetect2Prediction]]:
raw_predictions = get_raw_predictions(
output,
targets=targets,
postprocessor=postprocessor,
start_times=[clip.start_time for clip in clips],
)
return [
[
BatDetect2Prediction(
raw=raw,
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
targets=targets,
classification_threshold=classification_threshold,
),
)
for raw in predictions
]
for predictions, clip in zip(raw_predictions, clips)
]
def get_predictions(
output: ModelOutput,
clips: List[data.Clip],
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output and corresponding clips, applies the entire
configured chain (NMS, remapping, extraction, geometry recovery, class
decoding), producing final `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[data.ClipPrediction]
List containing one `ClipPrediction` object for each input clip,
populated with `SoundEventPrediction` objects.
"""
raw_predictions = get_raw_predictions(
output,
targets=targets,
postprocessor=postprocessor,
start_times=[clip.start_time for clip in clips],
)
return [
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
targets=targets,
classification_threshold=classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)
]

View File

@ -0,0 +1,94 @@
from typing import Optional
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE
__all__ = [
"PostprocessConfig",
"load_postprocess_config",
]
DEFAULT_DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 100
class PostprocessConfig(BaseConfig):
"""Configuration settings for the postprocessing pipeline.
Defines tunable parameters that control how raw model outputs are
converted into final detections.
Attributes
----------
nms_kernel_size : int, default=NMS_KERNEL_SIZE
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
Used to suppress weaker detections near stronger peaks. Must be
positive.
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
Minimum confidence score from the detection heatmap required to
consider a point as a potential detection. Must be >= 0.
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
Minimum confidence score for a specific class prediction to be included
in the decoded tags for a detection. Must be >= 0.
top_k_per_sec : int, default=TOP_K_PER_SEC
Desired maximum number of detections per second of audio. Used by
`get_max_detections` to calculate an absolute limit based on clip
duration before applying `extract_detections_from_array`. Must be
positive.
"""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
detection_threshold: float = Field(
default=DEFAULT_DETECTION_THRESHOLD,
ge=0,
)
classification_threshold: float = Field(
default=DEFAULT_CLASSIFICATION_THRESHOLD,
ge=0,
)
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PostprocessConfig:
"""Load the postprocessing configuration from a file.
Reads a configuration file (YAML) and validates it against the
`PostprocessConfig` schema, potentially extracting data from a nested
field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
postprocessing configuration (e.g., "inference.postprocessing").
If None, the entire file content is used.
Returns
-------
PostprocessConfig
The loaded and validated postprocessing configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data does not conform to the
`PostprocessConfig` schema.
KeyError, TypeError
If `field` specifies an invalid path within the loaded data.
"""
return load_config(path, schema=PostprocessConfig, field=field)
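A hedged sketch of constructing the config directly; the values below are illustrative, not the package defaults.

config = PostprocessConfig(detection_threshold=0.1, top_k_per_sec=50)
print(config.to_yaml_string())  # BaseConfig models round-trip to YAML

# Values violating the Field constraints are rejected at construction time:
# PostprocessConfig(nms_kernel_size=0)  # raises a pydantic ValidationError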

View File

@ -6,7 +6,7 @@ import numpy as np
from soundevent import data
from batdetect2.typing.postprocess import (
DetectionsArray,
ClipDetectionsArray,
RawPrediction,
)
from batdetect2.typing.targets import TargetProtocol
@ -28,7 +28,7 @@ decoding.
def to_raw_predictions(
detections: DetectionsArray,
detections: ClipDetectionsArray,
targets: TargetProtocol,
) -> List[RawPrediction]:
predictions = []

View File

@ -15,32 +15,25 @@ precise time-frequency location of each detection. The final output aggregates
all extracted information into a structured `xarray.Dataset`.
"""
from typing import List, Optional, Tuple, Union
from typing import List, Optional
import torch
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
from batdetect2.typing.postprocess import (
DetectionsTensor,
ModelOutput,
)
from batdetect2.typing.postprocess import ClipDetectionsTensor
__all__ = [
"extract_prediction_tensor",
"extract_detection_peaks",
]
def extract_prediction_tensor(
output: ModelOutput,
def extract_detection_peaks(
detection_heatmap: torch.Tensor,
size_heatmap: torch.Tensor,
feature_heatmap: torch.Tensor,
classification_heatmap: torch.Tensor,
max_detections: int = 200,
threshold: Optional[float] = None,
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> List[DetectionsTensor]:
detection_heatmap = non_max_suppression(
output.detection_probs.detach(),
kernel_size=nms_kernel_size,
)
) -> List[ClipDetectionsTensor]:
height = detection_heatmap.shape[-2]
width = detection_heatmap.shape[-1]
@ -53,9 +46,9 @@ def extract_prediction_tensor(
freqs = freqs.flatten().to(detection_heatmap.device)
times = times.flatten().to(detection_heatmap.device)
output_size_preds = output.size_preds.detach()
output_features = output.features.detach()
output_class_probs = output.class_probs.detach()
output_size_preds = size_heatmap.detach()
output_features = feature_heatmap.detach()
output_class_probs = classification_heatmap.detach()
predictions = []
for idx, item in enumerate(detection_heatmap):
@ -65,23 +58,25 @@ def extract_prediction_tensor(
detection_scores = item.take(indices)
detection_freqs = freqs.take(indices)
detection_times = times.take(indices)
sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
features = output_features[idx, :, detection_freqs, detection_times].T
class_scores = output_class_probs[
idx, :, detection_freqs, detection_times
].T
if threshold is not None:
mask = detection_scores >= threshold
detection_scores = detection_scores[mask]
sizes = sizes[mask]
detection_times = detection_times[mask]
detection_freqs = detection_freqs[mask]
features = features[mask]
class_scores = class_scores[mask]
sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
features = output_features[idx, :, detection_freqs, detection_times].T
class_scores = output_class_probs[
idx,
:,
detection_freqs,
detection_times,
].T
predictions.append(
DetectionsTensor(
ClipDetectionsTensor(
scores=detection_scores,
sizes=sizes,
features=features,

View File

@ -0,0 +1,100 @@
from typing import List, Optional, Tuple, Union
import torch
from loguru import logger
from batdetect2.postprocess.config import (
PostprocessConfig,
)
from batdetect2.postprocess.extraction import extract_detection_peaks
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
ClipDetectionsTensor,
PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
__all__ = [
"build_postprocessor",
"Postprocessor",
]
def build_postprocessor(
preprocessor: PreprocessorProtocol,
config: Optional[PostprocessConfig] = None,
) -> PostprocessorProtocol:
"""Factory function to build the standard postprocessor."""
config = config or PostprocessConfig()
logger.opt(lazy=True).debug(
"Building postprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
return Postprocessor(
samplerate=preprocessor.output_samplerate,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
top_k_per_sec=config.top_k_per_sec,
detection_threshold=config.detection_threshold,
)
class Postprocessor(torch.nn.Module, PostprocessorProtocol):
"""Standard implementation of the postprocessing pipeline."""
def __init__(
self,
samplerate: float,
min_freq: float,
max_freq: float,
top_k_per_sec: int = 200,
detection_threshold: float = 0.01,
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
):
"""Initialize the Postprocessor."""
super().__init__()
self.output_samplerate = samplerate
self.min_freq = min_freq
self.max_freq = max_freq
self.top_k_per_sec = top_k_per_sec
self.detection_threshold = detection_threshold
self.nms_kernel_size = nms_kernel_size
def forward(
self,
output: ModelOutput,
start_times: Optional[List[float]] = None,
) -> List[ClipDetectionsTensor]:
detection_heatmap = non_max_suppression(
output.detection_probs.detach(),
kernel_size=self.nms_kernel_size,
)
width = output.detection_probs.shape[-1]
duration = width / self.output_samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_detection_peaks(
detection_heatmap,
size_heatmap=output.size_preds,
feature_heatmap=output.features,
classification_heatmap=output.class_probs,
max_detections=max_detections,
threshold=self.detection_threshold,
)
if start_times is None:
start_times = [0 for _ in range(len(detections))]
return [
map_detection_to_clip(
detection,
start_time=start_time,
end_time=start_time + duration,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection, start_time in zip(detections, start_times)
]

View File

@ -20,7 +20,7 @@ import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing.postprocess import DetectionsTensor
from batdetect2.typing.postprocess import ClipDetectionsTensor
__all__ = [
"features_to_xarray",
@ -31,15 +31,15 @@ __all__ = [
def map_detection_to_clip(
detections: DetectionsTensor,
detections: ClipDetectionsTensor,
start_time: float,
end_time: float,
min_freq: float,
max_freq: float,
) -> DetectionsTensor:
) -> ClipDetectionsTensor:
duration = end_time - start_time
bandwidth = max_freq - min_freq
return DetectionsTensor(
return ClipDetectionsTensor(
scores=detections.scores,
sizes=detections.sizes,
features=detections.features,

View File

@ -1,176 +1,19 @@
"""Main entry point for the BatDetect2 Preprocessing subsystem.
"""Main entry point for the BatDetect2 preprocessing subsystem."""
This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline
for converting raw audio input (from files or data objects) into processed
spectrograms suitable for input to BatDetect2 models. This ensures consistent
data handling between model training and inference.
The preprocessing pipeline consists of two main stages, configured via nested
data structures:
1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial
processing like resampling, duration adjustment, centering, and scaling.
Configured via `AudioConfig`.
2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from
the processed waveform using STFT, followed by frequency cropping, optional
PCEN, amplitude scaling (dB, power, linear), optional denoising, optional
resizing, and optional peak normalization. Configured via
`SpectrogramConfig`.
This module provides the primary interface:
- `PreprocessingConfig`: A unified configuration object holding `AudioConfig`
and `SpectrogramConfig`.
- `load_preprocessing_config`: Function to load the unified configuration.
- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline.
- `StandardPreprocessor`: The default implementation of the `Preprocessor`.
- `build_preprocessor`: A factory function to create a `StandardPreprocessor`
instance from a `PreprocessingConfig`.
"""
from typing import Optional
import torch
from loguru import logger
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
DEFAULT_DURATION,
SCALE_RAW_AUDIO,
TARGET_SAMPLERATE_HZ,
AudioConfig,
ResampleConfig,
build_audio_loader,
build_audio_pipeline,
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.preprocess.config import (
PreprocessingConfig,
load_preprocessing_config,
)
from batdetect2.preprocess.spectrogram import (
MAX_FREQ,
MIN_FREQ,
FrequencyConfig,
PcenConfig,
SpectrogramConfig,
SpectrogramPipeline,
STFTConfig,
_spec_params_from_config,
build_spectrogram_builder,
build_spectrogram_pipeline,
)
from batdetect2.typing import PreprocessorProtocol
from batdetect2.preprocess.preprocessor import Preprocessor, build_preprocessor
from batdetect2.preprocess.spectrogram import MAX_FREQ, MIN_FREQ
__all__ = [
"AudioConfig",
"DEFAULT_DURATION",
"FrequencyConfig",
"MAX_FREQ",
"MIN_FREQ",
"PcenConfig",
"PreprocessingConfig",
"ResampleConfig",
"SCALE_RAW_AUDIO",
"STFTConfig",
"SpectrogramConfig",
"Preprocessor",
"TARGET_SAMPLERATE_HZ",
"build_audio_loader",
"build_spectrogram_builder",
"build_preprocessor",
"load_preprocessing_config",
]
class PreprocessingConfig(BaseConfig):
"""Unified configuration for the audio preprocessing pipeline.
Aggregates the configuration for both the initial audio processing stage
and the subsequent spectrogram generation stage.
Attributes
----------
audio : AudioConfig
Configuration settings for the audio loading and initial waveform
processing steps (e.g., resampling, duration adjustment, scaling).
Defaults to default `AudioConfig` settings if omitted.
spectrogram : SpectrogramConfig
Configuration settings for the spectrogram generation process
(e.g., STFT parameters, frequency cropping, scaling, denoising,
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
"""
audio: AudioConfig = Field(default_factory=AudioConfig)
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
def load_preprocessing_config(
path: PathLike,
field: Optional[str] = None,
) -> PreprocessingConfig:
return load_config(path, schema=PreprocessingConfig, field=field)
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
"""Standard implementation of the `Preprocessor` protocol."""
input_samplerate: int
output_samplerate: float
max_freq: float
min_freq: float
def __init__(
self,
audio_pipeline: torch.nn.Module,
spectrogram_pipeline: SpectrogramPipeline,
input_samplerate: int,
output_samplerate: float,
max_freq: float,
min_freq: float,
) -> None:
super().__init__()
self.audio_pipeline = audio_pipeline
self.spectrogram_pipeline = spectrogram_pipeline
self.max_freq = max_freq
self.min_freq = min_freq
self.input_samplerate = input_samplerate
self.output_samplerate = output_samplerate
def forward(self, wav: torch.Tensor) -> torch.Tensor:
wav = self.audio_pipeline(wav)
return self.spectrogram_pipeline(wav)
def compute_output_samplerate(config: PreprocessingConfig) -> float:
samplerate = config.audio.samplerate
_, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft)
factor = config.spectrogram.size.resize_factor
return samplerate * factor / hop_size
def build_preprocessor(
config: Optional[PreprocessingConfig] = None,
) -> PreprocessorProtocol:
"""Factory function to build the standard preprocessor from configuration."""
config = config or PreprocessingConfig()
logger.opt(lazy=True).debug(
"Building preprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
samplerate = config.audio.samplerate
min_freq = config.spectrogram.frequencies.min_freq
max_freq = config.spectrogram.frequencies.max_freq
output_samplerate = compute_output_samplerate(config)
return StandardPreprocessor(
audio_pipeline=build_audio_pipeline(config.audio),
spectrogram_pipeline=build_spectrogram_pipeline(
samplerate, config.spectrogram
),
input_samplerate=samplerate,
output_samplerate=output_samplerate,
min_freq=min_freq,
max_freq=max_freq,
)
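For reference, a minimal sketch of the slimmed-down package-level API. The file name and the `preprocess` field are illustrative assumptions, not part of the change:

from batdetect2.preprocess import build_preprocessor, load_preprocessing_config

# Load the preprocessing section from a YAML file (hypothetical path and field)
# and build the torch preprocessing module from it.
config = load_preprocessing_config("config.yaml", field="preprocess")
preprocessor = build_preprocessor(config)
print(preprocessor.min_freq, preprocessor.max_freq, preprocessor.output_samplerate)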

View File

@ -1,307 +1,57 @@
"""Handles loading and initial preprocessing of audio waveforms."""
from typing import Annotated, Literal, Union
from typing import Annotated, List, Literal, Optional, Union
import numpy as np
import torch
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.common import CenterTensor, PeakNormalize
from batdetect2.typing import AudioLoader
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.core import BaseConfig, Registry
from batdetect2.preprocess.common import center_tensor, peak_normalize
__all__ = [
"ResampleConfig",
"AudioConfig",
"SoundEventAudioLoader",
"build_audio_loader",
"load_file_audio",
"load_recording_audio",
"load_clip_audio",
"resample_audio",
"TARGET_SAMPLERATE_HZ",
"SCALE_RAW_AUDIO",
"DEFAULT_DURATION",
"CenterAudioConfig",
"ScaleAudioConfig",
"FixDurationConfig",
"build_audio_transform",
]
TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""
SCALE_RAW_AUDIO = False
"""Default setting for whether to perform peak normalization."""
DEFAULT_DURATION = None
"""Default setting for target audio duration in seconds."""
class ResampleConfig(BaseConfig):
"""Configuration for audio resampling.
Attributes
----------
enabled : bool, default=True
Whether resampling to the target sample rate is performed at all.
method : str, default="poly"
The resampling algorithm to use. Options:
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
Generally fast.
- "fourier": Resampling via Fourier method using
`scipy.signal.resample`. May handle non-integer
resampling factors differently.
"""
enabled: bool = True
method: str = "poly"
class SoundEventAudioLoader:
"""Concrete implementation of the `AudioLoader`."""
def __init__(
self,
samplerate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ResampleConfig] = None,
):
self.samplerate = samplerate
self.config = config or ResampleConfig()
def load_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess audio directly from a file path."""
return load_file_audio(
path,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object."""
return load_recording_audio(
recording,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object."""
return load_clip_audio(
clip,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_file_audio(
path: data.PathLike,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess audio from a file path using specified config."""
try:
recording = data.Recording.from_file(path)
except LibsndfileError as e:
raise FileNotFoundError(
f"Could not load the recording at path: {path}. Error: {e}"
) from e
return load_recording_audio(
recording,
samplerate=samplerate,
config=config,
dtype=dtype,
audio_dir=audio_dir,
)
def load_recording_audio(
recording: data.Recording,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess the entire audio content of a recording using config."""
clip = data.Clip(
recording=recording,
start_time=0,
end_time=recording.duration,
)
return load_clip_audio(
clip,
samplerate=samplerate,
config=config,
dtype=dtype,
audio_dir=audio_dir,
)
def load_clip_audio(
clip: data.Clip,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess a specific audio clip segment based on config."""
try:
wav = (
audio.load_clip(clip, audio_dir=audio_dir)
.sel(channel=0)
.astype(dtype)
)
except LibsndfileError as e:
raise FileNotFoundError(
f"Could not load the recording at path: {clip.recording.path}. "
f"Error: {e}"
) from e
if not config or not config.enabled or samplerate is None:
return wav.data.astype(dtype)
sr = int(1 / wav.time.attrs["step"])
return resample_audio(
wav.data,
sr=sr,
samplerate=samplerate,
method=config.method,
)
def resample_audio(
wav: np.ndarray,
sr: int,
samplerate: int = TARGET_SAMPLERATE_HZ,
method: str = "poly",
) -> np.ndarray:
"""Resample an audio waveform DataArray to a target sample rate."""
if sr == samplerate:
return wav
if method == "poly":
return resample_audio_poly(
wav,
sr_orig=sr,
sr_new=samplerate,
)
elif method == "fourier":
return resample_audio_fourier(
wav,
sr_orig=sr,
sr_new=samplerate,
)
else:
raise NotImplementedError(
f"Resampling method '{method}' not implemented"
)
def resample_audio_poly(
array: np.ndarray,
sr_orig: int,
sr_new: int,
axis: int = -1,
) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample_poly`.
This method is often preferred for signals when the ratio of new
to old sample rates can be expressed as a rational number. It uses
polyphase filtering.
Parameters
----------
array : np.ndarray
The input array to resample.
sr_orig : int
The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1
The axis of `array` along which to resample.
Returns
-------
np.ndarray
The array resampled to the target sample rate.
Raises
------
ValueError
If sample rates are not positive.
"""
gcd = np.gcd(sr_orig, sr_new)
return resample_poly(
array,
sr_new // gcd,
sr_orig // gcd,
axis=axis,
)
def resample_audio_fourier(
array: np.ndarray,
sr_orig: int,
sr_new: int,
axis: int = -1,
) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample`.
This method uses FFTs to resample the signal.
Parameters
----------
array : np.ndarray
The input array to resample.
sr_orig : int
The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1
The axis of `array` along which to resample.
Returns
-------
np.ndarray
The array resampled to the target sample rate along `axis`.
Raises
------
ValueError
If the sample rates are not positive.
"""
ratio = sr_new / sr_orig
return resample( # type: ignore
array,
int(array.shape[axis] * ratio),
axis=axis,
)
audio_transforms: Registry[torch.nn.Module, [int]] = Registry(
"audio_transform"
)
class CenterAudioConfig(BaseConfig):
name: Literal["center_audio"] = "center_audio"
class CenterAudio(torch.nn.Module):
def forward(self, wav: torch.Tensor) -> torch.Tensor:
return center_tensor(wav)
@classmethod
def from_config(cls, config: CenterAudioConfig, samplerate: int):
return cls()
audio_transforms.register(CenterAudioConfig, CenterAudio)
class ScaleAudioConfig(BaseConfig):
name: Literal["scale_audio"] = "scale_audio"
class ScaleAudio(torch.nn.Module):
def forward(self, wav: torch.Tensor) -> torch.Tensor:
return peak_normalize(wav)
@classmethod
def from_config(cls, config: ScaleAudioConfig, samplerate: int):
return cls()
audio_transforms.register(ScaleAudioConfig, ScaleAudio)
class FixDurationConfig(BaseConfig):
name: Literal["fix_duration"] = "fix_duration"
duration: float = 0.5
@ -325,6 +75,12 @@ class FixDuration(torch.nn.Module):
return torch.nn.functional.pad(wav, (0, self.length - length))
@classmethod
def from_config(cls, config: FixDurationConfig, samplerate: int):
return cls(samplerate=samplerate, duration=config.duration)
audio_transforms.register(FixDurationConfig, FixDuration)
AudioTransform = Annotated[
Union[
@ -336,47 +92,8 @@ AudioTransform = Annotated[
]
class AudioConfig(BaseConfig):
"""Configuration for loading and initial audio preprocessing."""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
transforms: List[AudioTransform] = Field(default_factory=list)
def build_audio_loader(
config: Optional[AudioConfig] = None,
) -> AudioLoader:
"""Factory function to create an AudioLoader based on configuration."""
config = config or AudioConfig()
return SoundEventAudioLoader(
samplerate=config.samplerate,
config=config.resample,
)
def build_audio_transform_step(
def build_audio_transform(
config: AudioTransform,
samplerate: int,
samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module:
if config.name == "fix_duration":
return FixDuration(samplerate=samplerate, duration=config.duration)
if config.name == "scale_audio":
return PeakNormalize()
if config.name == "center_audio":
return CenterTensor()
raise NotImplementedError(
f"Audio preprocessing step {config.name} not implemented"
)
def build_audio_pipeline(config: AudioConfig) -> torch.nn.Module:
return torch.nn.Sequential(
*[
build_audio_transform_step(step, samplerate=config.samplerate)
for step in config.transforms
]
)
return audio_transforms.build(config, samplerate)
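The audio transforms are now resolved through a registry keyed by their config type. A minimal sketch of chaining them, assuming a 256 kHz input rate; the configs used are the ones registered in this file:

import torch

from batdetect2.preprocess.audio import (
    CenterAudioConfig,
    FixDurationConfig,
    build_audio_transform,
)

# One torch module per config entry, chained in order (mirrors how the
# Preprocessor consumes config.audio_transforms).
steps = [CenterAudioConfig(), FixDurationConfig(duration=0.5)]
pipeline = torch.nn.Sequential(
    *[build_audio_transform(step, samplerate=256_000) for step in steps]
)

wav = torch.randn(256_000)  # one second of audio at 256 kHz (illustrative)
out = pipeline(wav)         # mean-centred, then adjusted to 0.5 s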

View File

@ -1,24 +1,22 @@
import torch
__all__ = [
"CenterTensor",
"PeakNormalize",
"center_tensor",
"peak_normalize",
]
class CenterTensor(torch.nn.Module):
def forward(self, wav: torch.Tensor):
return wav - wav.mean()
def center_tensor(tensor: torch.Tensor) -> torch.Tensor:
return tensor - tensor.mean()
class PeakNormalize(torch.nn.Module):
def forward(self, wav: torch.Tensor):
max_value = wav.abs().max()
def peak_normalize(tensor: torch.Tensor) -> torch.Tensor:
max_value = tensor.abs().max()
denominator = torch.where(
max_value == 0,
torch.tensor(1.0, device=wav.device, dtype=wav.dtype),
max_value,
)
denominator = torch.where(
max_value == 0,
torch.tensor(1.0, device=tensor.device, dtype=tensor.dtype),
max_value,
)
return wav / denominator
return tensor / denominator

View File

@ -0,0 +1,62 @@
from typing import List, Optional
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import AudioTransform
from batdetect2.preprocess.spectrogram import (
FrequencyConfig,
PcenConfig,
ResizeConfig,
SpectralMeanSubstractionConfig,
SpectrogramTransform,
STFTConfig,
)
__all__ = [
"load_preprocessing_config",
"AudioTransform",
"PreprocessingConfig",
]
class PreprocessingConfig(BaseConfig):
"""Unified configuration for the audio preprocessing pipeline.
Aggregates the configuration for both the initial audio processing stage
and the subsequent spectrogram generation stage.
Attributes
----------
audio_transforms : List[AudioTransform]
Waveform transforms applied before the STFT (e.g., centering, scaling,
duration fixing). Empty by default.
spectrogram_transforms : List[SpectrogramTransform]
Transforms applied to the spectrogram after frequency cropping.
Defaults to PCEN followed by spectral mean subtraction.
stft : STFTConfig
STFT parameters (window duration, overlap, window function).
frequencies : FrequencyConfig
Frequency range kept after the STFT.
size : ResizeConfig
Output height and time resize factor of the spectrogram.
"""
audio_transforms: List[AudioTransform] = Field(default_factory=list)
spectrogram_transforms: List[SpectrogramTransform] = Field(
default_factory=lambda: [
PcenConfig(),
SpectralMeanSubstractionConfig(),
]
)
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
size: ResizeConfig = Field(default_factory=ResizeConfig)
def load_preprocessing_config(
path: PathLike,
field: Optional[str] = None,
) -> PreprocessingConfig:
return load_config(path, schema=PreprocessingConfig, field=field)
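As a quick illustration, the flattened config can be built with its defaults and inspected; `to_yaml_string` is assumed from `BaseConfig`, as used elsewhere in this changeset:

from batdetect2.preprocess.config import PreprocessingConfig

# Defaults: no audio transforms, PCEN followed by spectral mean subtraction,
# default STFT, frequency range, and resize settings.
config = PreprocessingConfig()
print(config.to_yaml_string())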

View File

@ -0,0 +1,114 @@
from typing import Optional
import torch
from loguru import logger
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.preprocess.audio import build_audio_transform
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.spectrogram import (
_spec_params_from_config,
build_spectrogram_builder,
build_spectrogram_crop,
build_spectrogram_resizer,
build_spectrogram_transform,
)
from batdetect2.typing import PreprocessorProtocol
__all__ = [
"Preprocessor",
"build_preprocessor",
]
class Preprocessor(torch.nn.Module, PreprocessorProtocol):
"""Standard implementation of the `Preprocessor` protocol."""
input_samplerate: int
output_samplerate: float
max_freq: float
min_freq: float
def __init__(
self,
config: PreprocessingConfig,
input_samplerate: int,
) -> None:
super().__init__()
self.audio_transforms = torch.nn.Sequential(
*[
build_audio_transform(step, samplerate=input_samplerate)
for step in config.audio_transforms
]
)
self.spectrogram_transforms = torch.nn.Sequential(
*[
build_spectrogram_transform(step, samplerate=input_samplerate)
for step in config.spectrogram_transforms
]
)
self.spectrogram_builder = build_spectrogram_builder(
config.stft,
samplerate=input_samplerate,
)
self.spectrogram_crop = build_spectrogram_crop(
config.frequencies,
stft=config.stft,
samplerate=input_samplerate,
)
self.spectrogram_resizer = build_spectrogram_resizer(config.size)
self.min_freq = config.frequencies.min_freq
self.max_freq = config.frequencies.max_freq
self.input_samplerate = input_samplerate
self.output_samplerate = compute_output_samplerate(
config,
input_samplerate=input_samplerate,
)
def forward(self, wav: torch.Tensor) -> torch.Tensor:
wav = self.audio_transforms(wav)
spec = self.spectrogram_builder(wav)
return self.process_spectrogram(spec)
def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
return self.spectrogram_builder(wav)
def process_audio(self, wav: torch.Tensor) -> torch.Tensor:
return self(wav)
def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
spec = self.spectrogram_crop(spec)
spec = self.spectrogram_transforms(spec)
return self.spectrogram_resizer(spec)
def compute_output_samplerate(
config: PreprocessingConfig,
input_samplerate: int = TARGET_SAMPLERATE_HZ,
) -> float:
_, hop_size = _spec_params_from_config(
config.stft, samplerate=input_samplerate
)
factor = config.size.resize_factor
return input_samplerate * factor / hop_size
def build_preprocessor(
config: Optional[PreprocessingConfig] = None,
input_samplerate: int = TARGET_SAMPLERATE_HZ,
) -> PreprocessorProtocol:
"""Factory function to build the standard preprocessor from configuration."""
config = config or PreprocessingConfig()
logger.opt(lazy=True).debug(
"Building preprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
return Preprocessor(config=config, input_samplerate=input_samplerate)
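A worked sketch of `compute_output_samplerate`, assuming the usual defaults (256 kHz input, 2 ms STFT window, 75% overlap, 0.5 time resize factor); the arithmetic mirrors `_spec_params_from_config`:

# Assumed values, for illustration only.
samplerate = 256_000
window_duration = 0.002     # seconds
window_overlap = 0.75
resize_factor = 0.5

n_fft = int(samplerate * window_duration)       # 512 samples per frame
hop_length = int(n_fft * (1 - window_overlap))  # 128 samples between frames
output_samplerate = samplerate * resize_factor / hop_length
print(output_samplerate)                        # 1000.0 spectrogram columns per second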

View File

@ -1,31 +1,21 @@
"""Computes spectrograms from audio waveforms with configurable parameters."""
from typing import (
Annotated,
Callable,
List,
Literal,
Optional,
Sequence,
Union,
)
from typing import Annotated, Callable, Literal, Optional, Union
import numpy as np
import torch
import torchaudio
from pydantic import Field
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.common import PeakNormalize
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.preprocess.common import peak_normalize
__all__ = [
"STFTConfig",
"FrequencyConfig",
"PcenConfig",
"SpectrogramConfig",
"build_spectrogram_transform",
"build_spectrogram_builder",
"MIN_FREQ",
"MAX_FREQ",
]
@ -60,6 +50,20 @@ class STFTConfig(BaseConfig):
window_fn: str = "hann"
def build_spectrogram_builder(
config: STFTConfig,
samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module:
n_fft, hop_length = _spec_params_from_config(config, samplerate=samplerate)
return torchaudio.transforms.Spectrogram(
n_fft=n_fft,
hop_length=hop_length,
window_fn=get_spectrogram_window(config.window_fn),
center=True,
power=1,
)
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
if name == "hann":
return torch.hann_window
@ -81,24 +85,31 @@ def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
)
def _spec_params_from_config(samplerate: int, conf: STFTConfig):
n_fft = int(samplerate * conf.window_duration)
hop_length = int(n_fft * (1 - conf.window_overlap))
def _spec_params_from_config(
config: STFTConfig,
samplerate: int = TARGET_SAMPLERATE_HZ,
):
n_fft = int(samplerate * config.window_duration)
hop_length = int(n_fft * (1 - config.window_overlap))
return n_fft, hop_length
def build_spectrogram_builder(
samplerate: int,
conf: STFTConfig,
) -> torch.nn.Module:
n_fft, hop_length = _spec_params_from_config(samplerate, conf)
return torchaudio.transforms.Spectrogram(
n_fft=n_fft,
hop_length=hop_length,
window_fn=get_spectrogram_window(conf.window_fn),
center=True,
power=1,
)
def _frequency_to_index(
freq: float,
n_fft: int,
samplerate: int = TARGET_SAMPLERATE_HZ,
) -> Optional[int]:
alpha = freq * 2 / samplerate
height = np.floor(n_fft / 2) + 1
index = int(np.floor(alpha * height))
if index <= 0:
return None
if index >= height:
return None
return index
class FrequencyConfig(BaseConfig):
@ -114,36 +125,36 @@ class FrequencyConfig(BaseConfig):
Frequencies below this value will be cropped. Must be >= 0.
"""
max_freq: int = Field(default=120_000, ge=0)
min_freq: int = Field(default=10_000, ge=0)
max_freq: int = Field(default=MAX_FREQ, ge=0)
min_freq: int = Field(default=MIN_FREQ, ge=0)
def _frequency_to_index(
freq: float,
samplerate: int,
n_fft: int,
) -> Optional[int]:
alpha = freq * 2 / samplerate
height = np.floor(n_fft / 2) + 1
index = int(np.floor(alpha * height))
if index <= 0:
return None
if index >= height:
return None
return index
class FrequencyClip(torch.nn.Module):
class FrequencyCrop(torch.nn.Module):
def __init__(
self,
low_index: Optional[int] = None,
high_index: Optional[int] = None,
samplerate: int,
n_fft: int,
min_freq: Optional[int] = None,
max_freq: Optional[int] = None,
):
super().__init__()
self.n_fft = n_fft
self.samplerate = samplerate
self.min_freq = min_freq
self.max_freq = max_freq
low_index = None
if min_freq is not None:
low_index = _frequency_to_index(
min_freq, n_fft=self.n_fft, samplerate=self.samplerate
)
self.low_index = low_index
high_index = None
if max_freq is not None:
high_index = _frequency_to_index(
max_freq, n_fft=self.n_fft, samplerate=self.samplerate
)
self.high_index = high_index
def forward(self, spec: torch.Tensor) -> torch.Tensor:
@ -164,6 +175,62 @@ class FrequencyClip(torch.nn.Module):
)
def build_spectrogram_crop(
config: FrequencyConfig,
stft: Optional[STFTConfig] = None,
samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module:
stft = stft or STFTConfig()
n_fft, _ = _spec_params_from_config(stft, samplerate=samplerate)
return FrequencyCrop(
samplerate=samplerate,
n_fft=n_fft,
min_freq=config.min_freq,
max_freq=config.max_freq,
)
class ResizeConfig(BaseConfig):
name: Literal["resize_spec"] = "resize_spec"
height: int = 128
resize_factor: float = 0.5
class ResizeSpec(torch.nn.Module):
def __init__(self, height: int, time_factor: float):
super().__init__()
self.height = height
self.time_factor = time_factor
def forward(self, spec: torch.Tensor) -> torch.Tensor:
current_length = spec.shape[-1]
target_length = int(self.time_factor * current_length)
original_ndim = spec.ndim
while spec.ndim < 4:
spec = spec.unsqueeze(0)
resized = torch.nn.functional.interpolate(
spec,
size=(self.height, target_length),
mode="bilinear",
)
while resized.ndim != original_ndim:
resized = resized.squeeze(0)
return resized
def build_spectrogram_resizer(config: ResizeConfig) -> torch.nn.Module:
return ResizeSpec(height=config.height, time_factor=config.resize_factor)
spectrogram_transforms: Registry[torch.nn.Module, [int]] = Registry(
"spectrogram_transform"
)
class PcenConfig(BaseConfig):
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
@ -182,7 +249,7 @@ class PCEN(torch.nn.Module):
bias: float = 2.0,
power: float = 0.5,
eps: float = 1e-6,
dtype=torch.float64,
dtype=torch.float32,
):
super().__init__()
self.smoothing_constant = smoothing_constant
@ -218,6 +285,19 @@ class PCEN(torch.nn.Module):
* torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
).to(spec.dtype)
@classmethod
def from_config(cls, config: PcenConfig, samplerate: int):
smooth = _compute_smoothing_constant(samplerate, config.time_constant)
return cls(
smoothing_constant=smooth,
gain=config.gain,
bias=config.bias,
power=config.power,
)
spectrogram_transforms.register(PcenConfig, PCEN)
def _compute_smoothing_constant(
samplerate: int,
@ -241,16 +321,26 @@ class ToPower(torch.nn.Module):
return spec**2
def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module:
if conf.scale == "db":
return torchaudio.transforms.AmplitudeToDB()
_scalers = {
"db": torchaudio.transforms.AmplitudeToDB,
"power": ToPower,
}
if conf.scale == "power":
return ToPower()
raise NotImplementedError(
f"Amplitude scaling {conf.scale} not implemented"
)
class ScaleAmplitude(torch.nn.Module):
def __init__(self, scale: Literal["power", "db"]):
super().__init__()
self.scale = scale
self.scaler = _scalers[scale]()
def forward(self, spec: torch.Tensor) -> torch.Tensor:
return self.scaler(spec)
@classmethod
def from_config(cls, config: ScaleAmplitudeConfig, samplerate: int):
return cls(scale=config.scale)
spectrogram_transforms.register(ScaleAmplitudeConfig, ScaleAmplitude)
class SpectralMeanSubstractionConfig(BaseConfig):
@ -262,43 +352,36 @@ class SpectralMeanSubstraction(torch.nn.Module):
mean = spec.mean(-1, keepdim=True)
return (spec - mean).clamp(min=0)
class ResizeConfig(BaseConfig):
name: Literal["resize_spec"] = "resize_spec"
height: int = 128
resize_factor: float = 0.5
@classmethod
def from_config(
cls,
config: SpectralMeanSubstractionConfig,
samplerate: int,
):
return cls()
class ResizeSpec(torch.nn.Module):
def __init__(self, height: int, time_factor: float):
super().__init__()
self.height = height
self.time_factor = time_factor
def forward(self, spec: torch.Tensor) -> torch.Tensor:
current_length = spec.shape[-1]
target_length = int(self.time_factor * current_length)
original_ndim = spec.ndim
while spec.ndim < 4:
spec = spec.unsqueeze(0)
resized = torch.nn.functional.interpolate(
spec,
size=(self.height, target_length),
mode="bilinear",
)
while resized.ndim != original_ndim:
resized = resized.squeeze(0)
return resized
spectrogram_transforms.register(
SpectralMeanSubstractionConfig,
SpectralMeanSubstraction,
)
class PeakNormalizeConfig(BaseConfig):
name: Literal["peak_normalize"] = "peak_normalize"
class PeakNormalize(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor:
return peak_normalize(spec)
@classmethod
def from_config(cls, config: PeakNormalizeConfig, samplerate: int):
return cls()
spectrogram_transforms.register(PeakNormalizeConfig, PeakNormalize)
SpectrogramTransform = Annotated[
Union[
PcenConfig,
@ -310,114 +393,8 @@ SpectrogramTransform = Annotated[
]
class SpectrogramConfig(BaseConfig):
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
size: ResizeConfig = Field(default_factory=ResizeConfig)
transforms: Sequence[SpectrogramTransform] = Field(
default_factory=lambda: [
PcenConfig(),
SpectralMeanSubstractionConfig(),
]
)
def _build_spectrogram_transform_step(
step: SpectrogramTransform,
samplerate: int,
) -> torch.nn.Module:
if step.name == "pcen":
return PCEN(
smoothing_constant=_compute_smoothing_constant(
samplerate=samplerate,
time_constant=step.time_constant,
),
gain=step.gain,
bias=step.bias,
power=step.power,
)
if step.name == "scale_amplitude":
return _build_amplitude_scaler(step)
if step.name == "spectral_mean_substraction":
return SpectralMeanSubstraction()
if step.name == "peak_normalize":
return PeakNormalize()
raise NotImplementedError(
f"Spectrogram preprocessing step {step.name} not implemented"
)
def build_spectrogram_transform(
config: SpectrogramTransform,
samplerate: int,
conf: SpectrogramConfig,
) -> torch.nn.Module:
return torch.nn.Sequential(
*[
_build_spectrogram_transform_step(step, samplerate=samplerate)
for step in conf.transforms
]
)
class SpectrogramPipeline(torch.nn.Module):
def __init__(
self,
spec_builder: torch.nn.Module,
freq_cutter: torch.nn.Module,
transforms: torch.nn.Module,
resizer: torch.nn.Module,
):
super().__init__()
self.spec_builder = spec_builder
self.freq_cutter = freq_cutter
self.transforms = transforms
self.resizer = resizer
def forward(self, wav: torch.Tensor) -> torch.Tensor:
spec = self.spec_builder(wav)
spec = self.freq_cutter(spec)
spec = self.transforms(spec)
return self.resizer(spec)
def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
return self.spec_builder(wav)
def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor:
return self.freq_cutter(spec)
def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
return self.transforms(spec)
def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
return self.resizer(spec)
def build_spectrogram_pipeline(
samplerate: int,
conf: SpectrogramConfig,
) -> SpectrogramPipeline:
spec_builder = build_spectrogram_builder(samplerate, conf.stft)
n_fft, _ = _spec_params_from_config(samplerate, conf.stft)
cutter = FrequencyClip(
low_index=_frequency_to_index(
conf.frequencies.min_freq, samplerate, n_fft
),
high_index=_frequency_to_index(
conf.frequencies.max_freq, samplerate, n_fft
),
)
transforms = build_spectrogram_transform(samplerate, conf)
resizer = ResizeSpec(
height=conf.size.height,
time_factor=conf.size.resize_factor,
)
return SpectrogramPipeline(
spec_builder=spec_builder,
freq_cutter=cutter,
transforms=transforms,
resizer=resizer,
)
return spectrogram_transforms.build(config, samplerate)
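The frequency cropping hinges on `_frequency_to_index`, which maps a frequency to a one-sided STFT bin. A short worked sketch with assumed values (256 kHz input, 512-point FFT) shows the bins kept for a 10-120 kHz band:

import numpy as np

samplerate = 256_000  # assumed input rate
n_fft = 512           # e.g. a 2 ms window at 256 kHz

def freq_to_index(freq: float) -> int:
    # Same arithmetic as _frequency_to_index, minus the boundary checks.
    height = np.floor(n_fft / 2) + 1  # 257 one-sided frequency bins
    return int(np.floor(freq * 2 / samplerate * height))

print(freq_to_index(10_000), freq_to_index(120_000))  # 20 and 240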

View File

@ -1,17 +1,6 @@
"""BatDetect2 Target Definition system."""
from collections import Counter
from typing import Iterable, List, Optional, Tuple
from loguru import logger
from pydantic import Field, field_validator
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.conditions import build_sound_event_condition
from batdetect2.targets.classes import (
DEFAULT_CLASSES,
DEFAULT_DETECTION_CLASS,
SoundEventDecoder,
SoundEventEncoder,
TargetClassConfig,
@ -19,23 +8,29 @@ from batdetect2.targets.classes import (
build_sound_event_encoder,
get_class_names_from_config,
)
from batdetect2.targets.config import TargetConfig, load_target_config
from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
ROIMapperConfig,
ROITargetMapper,
build_roi_mapper,
)
from batdetect2.targets.targets import (
Targets,
build_targets,
iterate_encoded_sound_events,
load_targets,
)
from batdetect2.targets.terms import (
call_type,
data_source,
generic_class,
individual,
)
from batdetect2.typing.targets import Position, Size, TargetProtocol
__all__ = [
"AnchorBBoxMapperConfig",
"DEFAULT_TARGET_CONFIG",
"ROIMapperConfig",
"ROITargetMapper",
"SoundEventDecoder",
"SoundEventEncoder",
@ -45,365 +40,13 @@ __all__ = [
"build_roi_mapper",
"build_sound_event_decoder",
"build_sound_event_encoder",
"build_targets",
"call_type",
"data_source",
"generic_class",
"get_class_names_from_config",
"individual",
"iterate_encoded_sound_events",
"load_target_config",
"load_targets",
]
class TargetConfig(BaseConfig):
detection_target: TargetClassConfig = Field(
default=DEFAULT_DETECTION_CLASS
)
classification_targets: List[TargetClassConfig] = Field(
default_factory=lambda: DEFAULT_CLASSES
)
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
@field_validator("classification_targets")
def check_unique_class_names(cls, v: List[TargetClassConfig]):
"""Ensure all defined class names are unique."""
names = [c.name for c in v]
if len(names) != len(set(names)):
name_counts = Counter(names)
duplicates = [
name for name, count in name_counts.items() if count > 1
]
raise ValueError(
"Class names must be unique. Found duplicates: "
f"{', '.join(duplicates)}"
)
return v
def load_target_config(
path: data.PathLike,
field: Optional[str] = None,
) -> TargetConfig:
"""Load the unified target configuration from a file.
Reads a configuration file (typically YAML) and validates it against the
`TargetConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : data.PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
target configuration. If None, the entire file content is used.
Returns
-------
TargetConfig
The loaded and validated unified target configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data does not conform to the
`TargetConfig` schema (including validation within nested configs
like `ClassesConfig`).
KeyError, TypeError
If `field` specifies an invalid path within the loaded data.
"""
return load_config(path=path, schema=TargetConfig, field=field)
class Targets(TargetProtocol):
"""Encapsulates the complete configured target definition pipeline.
This class implements the `TargetProtocol`, holding the configured
functions for filtering, transforming, encoding (tags to class name),
decoding (class name to tags), and mapping ROIs (geometry to position/size
and back). It provides a high-level interface to apply these steps and
access relevant metadata like class names and dimension names.
Instances are typically created using the `build_targets` factory function
or the `load_targets` convenience loader.
Attributes
----------
class_names : List[str]
An ordered list of the unique names of the specific target classes
defined in the configuration.
generic_class_tags : List[data.Tag]
A list of `soundevent.data.Tag` objects representing the configured
generic class category (used when no specific class matches).
dimension_names : List[str]
The names of the size dimensions handled by the ROI mapper
(e.g., ['width', 'height']).
"""
class_names: List[str]
detection_class_tags: List[data.Tag]
dimension_names: List[str]
detection_class_name: str
def __init__(self, config: TargetConfig):
"""Initialize the Targets object."""
self.config = config
self._filter_fn = build_sound_event_condition(
config.detection_target.match_if
)
self._encode_fn = build_sound_event_encoder(
config.classification_targets
)
self._decode_fn = build_sound_event_decoder(
config.classification_targets
)
self._roi_mapper = build_roi_mapper(config.roi)
self.dimension_names = self._roi_mapper.dimension_names
self.class_names = get_class_names_from_config(
config.classification_targets
)
self.detection_class_name = config.detection_target.name
self.detection_class_tags = config.detection_target.assign_tags
self._roi_mapper_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classification_targets
if class_config.roi is not None
}
for class_name in self._roi_mapper_overrides:
if class_name not in self.class_names:
# TODO: improve this warning
logger.warning(
"The ROI mapper overrides contains a class ({class_name}) "
"not present in the class names.",
class_name=class_name,
)
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the configured filter to a sound event annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation to filter.
Returns
-------
bool
True if the annotation should be kept (passes the filter),
False otherwise. If no filter was configured, always returns True.
"""
return self._filter_fn(sound_event)
def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
"""Encode a sound event annotation to its target class name.
Applies the configured class definition rules (including priority)
to determine the specific class name for the annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation to encode. Note: This should typically be called
*after* applying any transformations via the `transform` method.
Returns
-------
str or None
The name of the matched target class, or None if the annotation
does not match any specific class rule (i.e., it belongs to the
generic category).
"""
return self._encode_fn(sound_event)
def decode_class(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags.
Uses the configured mapping (based on `TargetClass.output_tags` or
`TargetClass.tags`) to convert a class name string into a list of
`soundevent.data.Tag` objects.
Parameters
----------
class_label : str
The class name to decode.
Returns
-------
List[data.Tag]
The list of tags corresponding to the input class name.
"""
return self._decode_fn(class_label)
def encode_roi(
self, sound_event: data.SoundEventAnnotation
) -> tuple[Position, Size]:
"""Extract the target reference position from the annotation's roi.
Delegates to the internal ROI mapper's `get_roi_position` method.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
Tuple[float, float]
The reference position `(time, frequency)`.
Raises
------
ValueError
If the annotation lacks geometry.
"""
class_name = self.encode_class(sound_event)
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].encode(
sound_event.sound_event
)
return self._roi_mapper.encode(sound_event.sound_event)
def decode_roi(
self,
position: Position,
size: Size,
class_name: Optional[str] = None,
) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions.
Delegates to the internal ROI mapper's `recover_roi` method, which
un-scales the dimensions and reconstructs the geometry (typically a
`BoundingBox`).
Parameters
----------
pos : Tuple[float, float]
The reference position `(time, frequency)`.
dims : np.ndarray
NumPy array with size dimensions (e.g., from model prediction),
matching the order in `self.dimension_names`.
Returns
-------
data.Geometry
The reconstructed geometry (typically `BoundingBox`).
"""
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].decode(
position,
size,
)
return self._roi_mapper.decode(position, size)
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
classification_targets=DEFAULT_CLASSES,
detection_target=DEFAULT_DETECTION_CLASS,
roi=AnchorBBoxMapperConfig(),
)
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
Parameters
----------
config : TargetConfig
The loaded and validated unified target configuration object.
Returns
-------
Targets
An initialized `Targets` object ready for use.
Raises
------
KeyError
If term keys or derivation function keys specified in the `config`
are not found in their respective registries.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function fails (when configured).
"""
config = config or DEFAULT_TARGET_CONFIG
logger.opt(lazy=True).debug(
"Building targets with config: \n{}",
lambda: config.to_yaml_string(),
)
return Targets(config=config)
def load_targets(
config_path: data.PathLike,
field: Optional[str] = None,
) -> Targets:
"""Load a Targets object directly from a configuration file.
This convenience factory method loads the `TargetConfig` from the
specified file path and then calls `Targets.from_config` to build
the fully initialized `Targets` object.
Parameters
----------
config_path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing
the target configuration. If None, the entire file content is used.
Returns
-------
Targets
An initialized `Targets` object ready for use.
Raises
------
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
TypeError
Errors raised during file loading, validation, or extraction via
`load_target_config`.
KeyError, ImportError, AttributeError, TypeError
Errors raised during the build process by `Targets.from_config`
(e.g., missing keys in registries, failed imports).
"""
config = load_target_config(
config_path,
field=field,
)
return build_targets(config)
def iterate_encoded_sound_events(
sound_events: Iterable[data.SoundEventAnnotation],
targets: TargetProtocol,
) -> Iterable[Tuple[Optional[str], Position, Size]]:
for sound_event in sound_events:
if not targets.filter(sound_event):
continue
geometry = sound_event.sound_event.geometry
if geometry is None:
continue
class_name = targets.encode_class(sound_event)
position, size = targets.encode_roi(sound_event)
yield class_name, position, size

View File

@ -3,7 +3,7 @@ from typing import Dict, List, Optional
from pydantic import Field, PrivateAttr, computed_field, model_validator
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.data.conditions import (
AllOfConfig,
HasAllTagsConfig,

View File

@ -0,0 +1,84 @@
from collections import Counter
from typing import List, Optional
from pydantic import Field, field_validator
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.targets.classes import (
DEFAULT_CLASSES,
DEFAULT_DETECTION_CLASS,
TargetClassConfig,
)
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
__all__ = [
"TargetConfig",
"load_target_config",
]
class TargetConfig(BaseConfig):
detection_target: TargetClassConfig = Field(
default=DEFAULT_DETECTION_CLASS
)
classification_targets: List[TargetClassConfig] = Field(
default_factory=lambda: DEFAULT_CLASSES
)
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
@field_validator("classification_targets")
def check_unique_class_names(cls, v: List[TargetClassConfig]):
"""Ensure all defined class names are unique."""
names = [c.name for c in v]
if len(names) != len(set(names)):
name_counts = Counter(names)
duplicates = [
name for name, count in name_counts.items() if count > 1
]
raise ValueError(
"Class names must be unique. Found duplicates: "
f"{', '.join(duplicates)}"
)
return v
def load_target_config(
path: data.PathLike,
field: Optional[str] = None,
) -> TargetConfig:
"""Load the unified target configuration from a file.
Reads a configuration file (typically YAML) and validates it against the
`TargetConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : data.PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
target configuration. If None, the entire file content is used.
Returns
-------
TargetConfig
The loaded and validated unified target configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data does not conform to the
`TargetConfig` schema (including validation within nested configs
like `ClassesConfig`).
KeyError, TypeError
If `field` specifies an invalid path within the loaded data.
"""
return load_config(path=path, schema=TargetConfig, field=field)
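A minimal usage sketch; the file name and the `targets` field are assumptions about how the configuration file is laid out:

from batdetect2.targets.config import load_target_config

# Load only the target section from a larger YAML file (hypothetical path/field).
config = load_target_config("config.yaml", field="targets")
print([target.name for target in config.classification_targets])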

View File

@ -26,12 +26,17 @@ import numpy as np
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core.arrays import spec_to_xarray
from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.targets import Position, ROITargetMapper, Size
from batdetect2.utils.arrays import spec_to_xarray
from batdetect2.typing import (
AudioLoader,
Position,
PreprocessorProtocol,
ROITargetMapper,
Size,
)
__all__ = [
"Anchor",
@ -260,6 +265,7 @@ class PeakEnergyBBoxMapperConfig(BaseConfig):
"""
name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
@ -451,8 +457,11 @@ def build_roi_mapper(
)
if config.name == "peak_energy_bbox":
preprocessor = build_preprocessor(config.preprocessing)
audio_loader = build_audio_loader(config.preprocessing.audio)
audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return PeakEnergyBBoxMapper(
preprocessor=preprocessor,
audio_loader=audio_loader,

View File

@ -0,0 +1,308 @@
from typing import Iterable, List, Optional, Tuple
from loguru import logger
from soundevent import data
from batdetect2.data.conditions import build_sound_event_condition
from batdetect2.targets.classes import (
DEFAULT_CLASSES,
DEFAULT_DETECTION_CLASS,
build_sound_event_decoder,
build_sound_event_encoder,
get_class_names_from_config,
)
from batdetect2.targets.config import TargetConfig, load_target_config
from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
build_roi_mapper,
)
from batdetect2.typing.targets import Position, Size, TargetProtocol
class Targets(TargetProtocol):
"""Encapsulates the complete configured target definition pipeline.
This class implements the `TargetProtocol`, holding the configured
functions for filtering, transforming, encoding (tags to class name),
decoding (class name to tags), and mapping ROIs (geometry to position/size
and back). It provides a high-level interface to apply these steps and
access relevant metadata like class names and dimension names.
Instances are typically created using the `build_targets` factory function
or the `load_targets` convenience loader.
Attributes
----------
class_names : List[str]
An ordered list of the unique names of the specific target classes
defined in the configuration.
detection_class_tags : List[data.Tag]
A list of `soundevent.data.Tag` objects assigned to the generic
detection class (used when no specific class matches).
detection_class_name : str
The name of the generic detection class.
dimension_names : List[str]
The names of the size dimensions handled by the ROI mapper
(e.g., ['width', 'height']).
"""
class_names: List[str]
detection_class_tags: List[data.Tag]
dimension_names: List[str]
detection_class_name: str
def __init__(self, config: TargetConfig):
"""Initialize the Targets object."""
self.config = config
self._filter_fn = build_sound_event_condition(
config.detection_target.match_if
)
self._encode_fn = build_sound_event_encoder(
config.classification_targets
)
self._decode_fn = build_sound_event_decoder(
config.classification_targets
)
self._roi_mapper = build_roi_mapper(config.roi)
self.dimension_names = self._roi_mapper.dimension_names
self.class_names = get_class_names_from_config(
config.classification_targets
)
self.detection_class_name = config.detection_target.name
self.detection_class_tags = config.detection_target.assign_tags
self._roi_mapper_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classification_targets
if class_config.roi is not None
}
for class_name in self._roi_mapper_overrides:
if class_name not in self.class_names:
# TODO: improve this warning
logger.warning(
"The ROI mapper overrides contains a class ({class_name}) "
"not present in the class names.",
class_name=class_name,
)
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the configured filter to a sound event annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation to filter.
Returns
-------
bool
True if the annotation should be kept (passes the filter),
False otherwise. If no filter was configured, always returns True.
"""
return self._filter_fn(sound_event)
def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
"""Encode a sound event annotation to its target class name.
Applies the configured class definition rules (including priority)
to determine the specific class name for the annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation to encode. Note: This should typically be called
*after* applying any transformations via the `transform` method.
Returns
-------
str or None
The name of the matched target class, or None if the annotation
does not match any specific class rule (i.e., it belongs to the
generic category).
"""
return self._encode_fn(sound_event)
def decode_class(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags.
Uses the configured mapping (based on `TargetClass.output_tags` or
`TargetClass.tags`) to convert a class name string into a list of
`soundevent.data.Tag` objects.
Parameters
----------
class_label : str
The class name to decode.
Returns
-------
List[data.Tag]
The list of tags corresponding to the input class name.
"""
return self._decode_fn(class_label)
def encode_roi(
self, sound_event: data.SoundEventAnnotation
) -> tuple[Position, Size]:
"""Extract the target reference position from the annotation's roi.
Delegates to the internal ROI mapper's `get_roi_position` method.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
Tuple[float, float]
The reference position `(time, frequency)`.
Raises
------
ValueError
If the annotation lacks geometry.
"""
class_name = self.encode_class(sound_event)
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].encode(
sound_event.sound_event
)
return self._roi_mapper.encode(sound_event.sound_event)
def decode_roi(
self,
position: Position,
size: Size,
class_name: Optional[str] = None,
) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions.
Delegates to the internal ROI mapper's `decode` method, which un-scales
the dimensions and reconstructs the geometry (typically a `BoundingBox`).
Parameters
----------
position : Position
The reference position `(time, frequency)`.
size : Size
Size dimensions (e.g., from model prediction), matching the order in
`self.dimension_names`.
class_name : str, optional
Predicted class name, used to select a per-class ROI mapper override
when one is configured.
Returns
-------
data.Geometry
The reconstructed geometry (typically `BoundingBox`).
"""
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].decode(
position,
size,
)
return self._roi_mapper.decode(position, size)
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
classification_targets=DEFAULT_CLASSES,
detection_target=DEFAULT_DETECTION_CLASS,
roi=AnchorBBoxMapperConfig(),
)
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
Parameters
----------
config : TargetConfig
The loaded and validated unified target configuration object.
Returns
-------
Targets
An initialized `Targets` object ready for use.
Raises
------
KeyError
If term keys or derivation function keys specified in the `config`
are not found in their respective registries.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function fails (when configured).
"""
config = config or DEFAULT_TARGET_CONFIG
logger.opt(lazy=True).debug(
"Building targets with config: \n{}",
lambda: config.to_yaml_string(),
)
return Targets(config=config)
def load_targets(
config_path: data.PathLike,
field: Optional[str] = None,
) -> Targets:
"""Load a Targets object directly from a configuration file.
This convenience factory method loads the `TargetConfig` from the
specified file path and then calls `build_targets` to build
the fully initialized `Targets` object.
Parameters
----------
config_path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing
the target configuration. If None, the entire file content is used.
Returns
-------
Targets
An initialized `Targets` object ready for use.
Raises
------
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
TypeError
Errors raised during file loading, validation, or extraction via
`load_target_config`.
KeyError, ImportError, AttributeError, TypeError
Errors raised during the build process by `build_targets`
(e.g., missing keys in registries, failed imports).
"""
config = load_target_config(
config_path,
field=field,
)
return build_targets(config)
def iterate_encoded_sound_events(
sound_events: Iterable[data.SoundEventAnnotation],
targets: TargetProtocol,
) -> Iterable[Tuple[Optional[str], Position, Size]]:
for sound_event in sound_events:
if not targets.filter(sound_event):
continue
geometry = sound_event.sound_event.geometry
if geometry is None:
continue
class_name = targets.encode_class(sound_event)
position, size = targets.encode_roi(sound_event)
yield class_name, position, size
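As a hedged sketch of the public surface, building the default targets exposes the metadata used downstream; no configuration file is needed because `build_targets` falls back to `DEFAULT_TARGET_CONFIG`:

from batdetect2.targets.targets import build_targets

targets = build_targets()            # uses DEFAULT_TARGET_CONFIG
print(targets.detection_class_name)  # generic detection class name
print(targets.class_names)           # classification class names, in order
print(targets.dimension_names)       # size dimensions handled by the ROI mapper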

View File

@ -14,12 +14,9 @@ from batdetect2.train.augmentations import (
scale_volume,
warp_spectrogram,
)
from batdetect2.train.clips import build_clipper, select_subclip
from batdetect2.train.config import (
FullTrainingConfig,
PLTrainerConfig,
TrainingConfig,
load_full_training_config,
load_train_config,
)
from batdetect2.train.dataset import (
@ -48,7 +45,6 @@ __all__ = [
"DetectionLossConfig",
"EchoAugmentationConfig",
"FrequencyMaskAugmentationConfig",
"FullTrainingConfig",
"LossConfig",
"LossFunction",
"PLTrainerConfig",
@ -64,21 +60,18 @@ __all__ = [
"add_echo",
"build_augmentations",
"build_clip_labeler",
"build_clipper",
"build_loss",
"build_train_dataset",
"build_train_loader",
"build_trainer",
"build_val_dataset",
"build_val_loader",
"load_full_training_config",
"load_label_config",
"load_train_config",
"mask_frequency",
"mask_time",
"mix_audio",
"scale_volume",
"select_subclip",
"train",
"warp_spectrogram",
]

View File

@ -11,11 +11,10 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import scale_geometry, shift_geometry
from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.clips import get_subclip_annotation
from batdetect2.typing import Augmentation
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.utils.arrays import adjust_width
from batdetect2.audio.clips import get_subclip_annotation
from batdetect2.core.arrays import adjust_width
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.typing import AudioLoader, Augmentation
__all__ = [
"AugmentationConfig",

View File

@ -5,19 +5,21 @@ from lightning.pytorch.callbacks import Callback
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate import Evaluator
from batdetect2.postprocess import get_raw_predictions
from batdetect2.logging import get_image_logger
from batdetect2.postprocess import to_raw_predictions
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import get_image_plotter
from batdetect2.typing.evaluate import ClipEvaluation
from batdetect2.typing.models import ModelOutput
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.train import TrainExample
from batdetect2.typing import (
ClipEvaluation,
EvaluatorProtocol,
ModelOutput,
RawPrediction,
TrainExample,
)
class ValidationMetrics(Callback):
def __init__(self, evaluator: Evaluator):
def __init__(self, evaluator: EvaluatorProtocol):
super().__init__()
self.evaluator = evaluator
@ -32,12 +34,12 @@ class ValidationMetrics(Callback):
assert isinstance(dataset, ValidationDataset)
return dataset
def plot_examples(
def generate_plots(
self,
pl_module: LightningModule,
evaluated_clips: List[ClipEvaluation],
):
plotter = get_image_plotter(pl_module.logger) # type: ignore
plotter = get_image_logger(pl_module.logger) # type: ignore
if plotter is None:
return
@ -64,7 +66,7 @@ class ValidationMetrics(Callback):
)
self.log_metrics(pl_module, clip_evaluations)
self.plot_examples(pl_module, clip_evaluations)
self.generate_plots(pl_module, clip_evaluations)
return super().on_validation_epoch_end(trainer, pl_module)
@ -86,8 +88,7 @@ class ValidationMetrics(Callback):
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
postprocessor = pl_module.model.postprocessor
targets = pl_module.model.targets
model = pl_module.model
dataset = self.get_dataset(trainer)
clip_annotations = [
@ -95,15 +96,14 @@ class ValidationMetrics(Callback):
for example_idx in batch.idx
]
predictions = get_raw_predictions(
clip_detections = model.postprocessor(
outputs,
start_times=[
clip_annotation.clip.start_time
for clip_annotation in clip_annotations
],
targets=targets,
postprocessor=postprocessor,
start_times=[ca.clip.start_time for ca in clip_annotations],
)
predictions = [
to_raw_predictions(clip_dets.numpy(), targets=model.targets)
for clip_dets in clip_detections
]
self._clip_annotations.extend(clip_annotations)
self._predictions.extend(predictions)

View File

@ -3,27 +3,16 @@ from typing import Optional, Union
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.models import ModelConfig
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
)
from batdetect2.train.clips import (
ClipConfig,
PaddedClipConfig,
RandomClipConfig,
)
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import LabelConfig
from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
from batdetect2.train.losses import LossConfig
__all__ = [
"TrainingConfig",
"load_train_config",
"FullTrainingConfig",
"load_full_training_config",
]
@ -48,30 +37,6 @@ class PLTrainerConfig(BaseConfig):
val_check_interval: Optional[Union[int, float]] = None
class ValLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
class TrainLoaderConfig(BaseConfig):
num_workers: int = 0
batch_size: int = 8
shuffle: bool = False
augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
)
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
class OptimizerConfig(BaseConfig):
learning_rate: float = 1e-3
t_max: int = 100
@ -80,13 +45,12 @@ class OptimizerConfig(BaseConfig):
class TrainingConfig(BaseConfig):
train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
loss: LossConfig = Field(default_factory=LossConfig)
cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)
validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
def load_train_config(
@ -94,18 +58,3 @@ def load_train_config(
field: Optional[str] = None,
) -> TrainingConfig:
return load_config(path, schema=TrainingConfig, field=field)
class FullTrainingConfig(ModelConfig):
"""Full training configuration."""
train: TrainingConfig = Field(default_factory=TrainingConfig)
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
def load_full_training_config(
path: data.PathLike,
field: Optional[str] = None,
) -> FullTrainingConfig:
"""Load the full training configuration."""
return load_config(path, schema=FullTrainingConfig, field=field)
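With the loader configs moved out and FullTrainingConfig dropped, the training section reduces to the fields shown above. A minimal sketch of constructing and loading it, assuming the module path batdetect2.train.config; "config.yaml" and the "train" field name are illustrative placeholders.

from batdetect2.train.config import TrainingConfig, load_train_config

# Every field has a default_factory, so the bare constructor is valid.
defaults = TrainingConfig()
print(defaults.optimizer.learning_rate, defaults.optimizer.t_max)  # 0.001 100

# Load only the `train` section of a larger configuration file.
train_config = load_train_config("config.yaml", field="train")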

View File

@ -2,22 +2,30 @@ from typing import List, Optional, Sequence
import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader, Dataset
from batdetect2.plotting.clips import build_audio_loader
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
RandomAudioSource,
build_augmentations,
)
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import build_clip_labeler
from batdetect2.typing import ClipperProtocol, TrainExample
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.train import Augmentation, ClipLabeller
from batdetect2.utils.arrays import adjust_width
from batdetect2.typing import (
AudioLoader,
Augmentation,
ClipLabeller,
ClipperProtocol,
PreprocessorProtocol,
TrainExample,
)
__all__ = [
"TrainingDataset",
@ -139,6 +147,22 @@ class ValidationDataset(Dataset):
)
class TrainLoaderConfig(BaseConfig):
num_workers: int = 0
batch_size: int = 8
shuffle: bool = False
augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
)
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
def build_train_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
@ -173,6 +197,14 @@ def build_train_loader(
)
class ValLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
def build_val_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
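The loader configs now live next to the datasets and loader builders they parametrise. A small sketch of constructing them with explicit values, assuming the import paths shown in the hunks above; the numbers are illustrative.

from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig

train_loader_config = TrainLoaderConfig(
    batch_size=8,
    num_workers=2,
    shuffle=True,
    clipping_strategy=PaddedClipConfig(),  # default clipping strategy
)
val_loader_config = ValLoaderConfig(num_workers=2)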

View File

@ -13,14 +13,10 @@ import torch
from loguru import logger
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets import build_targets, iterate_encoded_sound_events
from batdetect2.typing import (
ClipLabeller,
Heatmaps,
TargetProtocol,
)
from batdetect2.typing import ClipLabeller, Heatmaps, TargetProtocol
__all__ = [
"LabelConfig",

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple
import lightning as L
import torch
@ -6,11 +6,17 @@ from soundevent.data import PathLike
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.models import Model, build_model
from batdetect2.train.config import FullTrainingConfig
from batdetect2.plotting.clips import build_preprocessor
from batdetect2.postprocess import build_postprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.train.losses import build_loss
from batdetect2.typing import ModelOutput, TrainExample
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
__all__ = [
"TrainingModule",
]
@ -21,7 +27,8 @@ class TrainingModule(L.LightningModule):
def __init__(
self,
config: FullTrainingConfig,
config: "BatDetect2Config",
input_samplerate: int = TARGET_SAMPLERATE_HZ,
learning_rate: float = 0.001,
t_max: int = 100,
model: Optional[Model] = None,
@ -31,6 +38,7 @@ class TrainingModule(L.LightningModule):
self.save_hyperparameters(logger=False)
self.input_samplerate = input_samplerate
self.config = config
self.learning_rate = learning_rate
self.t_max = t_max
@ -39,7 +47,23 @@ class TrainingModule(L.LightningModule):
loss = build_loss(self.config.train.loss)
if model is None:
model = build_model(self.config)
targets = build_targets(self.config.targets)
preprocessor = build_preprocessor(
config=self.config.preprocess,
input_samplerate=self.input_samplerate,
)
postprocessor = build_postprocessor(
preprocessor, config=self.config.postprocess
)
model = build_model(
config=self.config.model,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
self.loss = loss
self.model = model
@ -74,16 +98,18 @@ class TrainingModule(L.LightningModule):
def load_model_from_checkpoint(
path: PathLike,
) -> Tuple[Model, FullTrainingConfig]:
) -> Tuple[Model, "BatDetect2Config"]:
module = TrainingModule.load_from_checkpoint(path) # type: ignore
return module.model, module.config
def build_training_module(
config: Optional[FullTrainingConfig] = None,
config: Optional["BatDetect2Config"] = None,
t_max: int = 200,
) -> TrainingModule:
config = config or FullTrainingConfig()
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config()
return TrainingModule(
config=config,
learning_rate=config.train.optimizer.learning_rate,
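Under the new layout the Lightning module is handed the whole BatDetect2Config and builds its own targets, preprocessor, postprocessor and model, while the learning rate and scheduler horizon are passed in explicitly. A minimal sketch, assuming BatDetect2Config() is constructible from defaults; the values are illustrative.

from batdetect2.config import BatDetect2Config
from batdetect2.train.lightning import TrainingModule, build_training_module

config = BatDetect2Config()

# Convenience builder: reads the hyperparameters from config.train.optimizer.
module = build_training_module(config, t_max=200)

# Or construct the module directly with explicit hyperparameters.
module = TrainingModule(
    config=config,
    learning_rate=config.train.optimizer.learning_rate,
    t_max=config.train.optimizer.t_max,
)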

View File

@ -27,7 +27,7 @@ from loguru import logger
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
__all__ = [

View File

@ -1,29 +1,31 @@
from collections.abc import Sequence
from pathlib import Path
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from loguru import logger
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.logging import build_logger
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.config import (
FullTrainingConfig,
)
from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import TrainingModule, build_training_module
from batdetect2.train.logging import build_logger
from batdetect2.typing import (
TargetProtocol,
)
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.typing.train import ClipLabeller
from batdetect2.train.lightning import build_training_module
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
ClipLabeller,
EvaluatorProtocol,
PreprocessorProtocol,
TargetProtocol,
)
__all__ = [
"build_trainer",
@ -36,13 +38,12 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
def train(
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
targets: Optional["TargetProtocol"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
labeller: Optional["ClipLabeller"] = None,
config: Optional["BatDetect2Config"] = None,
trainer: Optional[Trainer] = None,
targets: Optional[TargetProtocol] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
config: Optional[FullTrainingConfig] = None,
model_path: Optional[data.PathLike] = None,
train_workers: Optional[int] = None,
val_workers: Optional[int] = None,
checkpoint_dir: Optional[Path] = None,
@ -51,17 +52,20 @@ def train(
run_name: Optional[str] = None,
seed: Optional[int] = None,
):
from batdetect2.config import BatDetect2Config
if seed is not None:
seed_everything(seed)
config = config or FullTrainingConfig()
config = config or BatDetect2Config()
targets = targets or build_targets(config.targets)
targets = targets or build_targets(config=config.targets)
preprocessor = preprocessor or build_preprocessor(config.preprocess)
audio_loader = audio_loader or build_audio_loader(config=config.audio)
audio_loader = audio_loader or build_audio_loader(
config=config.preprocess.audio
preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
)
labeller = labeller or build_clip_labeler(
@ -93,18 +97,15 @@ def train(
else None
)
if model_path is not None:
logger.debug("Loading model from: {path}", path=model_path)
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else:
module = build_training_module(
config,
t_max=config.train.optimizer.t_max * len(train_dataloader),
)
module = build_training_module(
config,
t_max=config.train.optimizer.t_max * len(train_dataloader),
)
trainer = trainer or build_trainer(
config,
targets=targets,
evaluator=build_evaluator(config.train.validation, targets=targets),
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
experiment_name=experiment_name,
@ -121,8 +122,8 @@ def train(
def build_trainer_callbacks(
targets: TargetProtocol,
config: FullTrainingConfig,
targets: "TargetProtocol",
evaluator: Optional["EvaluatorProtocol"] = None,
checkpoint_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
@ -136,13 +137,12 @@ def build_trainer_callbacks(
if run_name is not None:
checkpoint_dir = checkpoint_dir / run_name
evaluator = build_evaluator(config=config.evaluation, targets=targets)
evaluator = evaluator or build_evaluator(targets=targets)
return [
ModelCheckpoint(
dirpath=str(checkpoint_dir),
save_top_k=1,
filename="best-{epoch:02d}-{val_loss:.0f}",
monitor="total_loss/val",
),
ValidationMetrics(evaluator),
@ -150,8 +150,9 @@ def build_trainer_callbacks(
def build_trainer(
conf: FullTrainingConfig,
targets: TargetProtocol,
conf: "BatDetect2Config",
targets: "TargetProtocol",
evaluator: Optional["EvaluatorProtocol"] = None,
checkpoint_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
@ -181,7 +182,7 @@ def build_trainer(
logger=train_logger,
callbacks=build_trainer_callbacks(
targets,
config=conf,
evaluator=evaluator,
checkpoint_dir=checkpoint_dir,
experiment_name=experiment_name,
run_name=run_name,
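Putting the pieces together, the revised entry point only needs annotated clips and, optionally, a BatDetect2Config; targets, preprocessor, audio loader, labeller, module and trainer are all built internally when not supplied. A sketch under the assumption that train() is importable from batdetect2.train.train; train_anns and val_anns are placeholders for your own soundevent ClipAnnotation sequences.

from batdetect2.config import BatDetect2Config
from batdetect2.train.train import train  # import path assumed

# train_anns / val_anns: sequences of soundevent.data.ClipAnnotation (placeholders).
train(
    train_annotations=train_anns,
    val_annotations=val_anns,
    config=BatDetect2Config(),
    seed=42,
)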

View File

@ -1,4 +1,10 @@
from batdetect2.typing.evaluate import MatchEvaluation, MetricsProtocol
from batdetect2.typing.evaluate import (
ClipEvaluation,
EvaluatorProtocol,
MatchEvaluation,
MetricsProtocol,
PlotterProtocol,
)
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
@ -9,10 +15,10 @@ from batdetect2.typing.postprocess import (
from batdetect2.typing.preprocess import (
AudioLoader,
PreprocessorProtocol,
SpectrogramBuilder,
)
from batdetect2.typing.targets import (
Position,
ROITargetMapper,
Size,
SoundEventDecoder,
SoundEventEncoder,
@ -34,6 +40,7 @@ __all__ = [
"Augmentation",
"BackboneModel",
"BatDetect2Prediction",
"ClipEvaluation",
"ClipLabeller",
"ClipperProtocol",
"DetectionModel",
@ -44,15 +51,17 @@ __all__ = [
"MatchEvaluation",
"MetricsProtocol",
"ModelOutput",
"PlotterProtocol",
"Position",
"PostprocessorProtocol",
"PreprocessorProtocol",
"ROITargetMapper",
"RawPrediction",
"Size",
"SoundEventDecoder",
"SoundEventEncoder",
"SoundEventFilter",
"SpectrogramBuilder",
"TargetProtocol",
"TrainExample",
"EvaluatorProtocol",
]
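With these re-exports, downstream code can type against the evaluation interfaces from the top-level typing package instead of reaching into batdetect2.typing.evaluate. A short sketch; evaluate_and_summarise is a hypothetical helper, not part of the codebase.

from typing import Dict, Sequence

from soundevent import data

from batdetect2.typing import ClipEvaluation, EvaluatorProtocol, RawPrediction


def evaluate_and_summarise(
    evaluator: EvaluatorProtocol,
    clip_annotations: Sequence[data.ClipAnnotation],
    predictions: Sequence[Sequence[RawPrediction]],
) -> Dict[str, float]:
    # Match predictions against annotations, then reduce them to scalar metrics.
    evaluations: Sequence[ClipEvaluation] = evaluator.evaluate(
        clip_annotations, predictions
    )
    return evaluator.compute_metrics(evaluations)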

View File

@ -14,7 +14,11 @@ from typing import (
from matplotlib.figure import Figure
from soundevent import data
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"EvaluatorProtocol",
"MetricsProtocol",
"MatchEvaluation",
]
@ -50,6 +54,26 @@ class MatchEvaluation:
return self.pred_class_scores[pred_class]
def is_true_positive(self, threshold: float = 0) -> bool:
return (
self.gt_det
and self.pred_score > threshold
and self.gt_class == self.pred_class
)
def is_false_positive(self, threshold: float = 0) -> bool:
return self.gt_det is None and self.pred_score > threshold
def is_false_negative(self, threshold: float = 0) -> bool:
return self.gt_det and self.pred_score <= threshold
def is_cross_trigger(self, threshold: float = 0) -> bool:
return (
self.gt_det
and self.pred_score > threshold
and self.gt_class != self.pred_class
)
@dataclass
class ClipEvaluation:
@ -87,3 +111,21 @@ class PlotterProtocol(Protocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Iterable[Tuple[str, Figure]]: ...
class EvaluatorProtocol(Protocol):
targets: TargetProtocol
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> List[ClipEvaluation]: ...
def compute_metrics(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]: ...
def generate_plots(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Iterable[Tuple[str, Figure]]: ...
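The new MatchEvaluation helpers make it straightforward to tally detection outcomes at a chosen score threshold, and EvaluatorProtocol pins down the evaluate / compute_metrics / generate_plots surface used by the validation callback. A small sketch of the former, assuming matches is a sequence of MatchEvaluation objects produced elsewhere; tally_matches is a hypothetical helper name.

from typing import Dict, Sequence

from batdetect2.typing import MatchEvaluation


def tally_matches(
    matches: Sequence[MatchEvaluation], threshold: float = 0.5
) -> Dict[str, int]:
    # Count each outcome category independently at the given threshold.
    return {
        "true_positives": sum(1 for m in matches if m.is_true_positive(threshold)),
        "false_positives": sum(1 for m in matches if m.is_false_positive(threshold)),
        "false_negatives": sum(1 for m in matches if m.is_false_negative(threshold)),
        "cross_triggers": sum(1 for m in matches if m.is_cross_trigger(threshold)),
    }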

View File

@ -12,7 +12,7 @@ system that deal with model predictions.
"""
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol
from typing import List, NamedTuple, Optional, Protocol, Sequence
import numpy as np
import torch
@ -47,15 +47,13 @@ class GeometryDecoder(Protocol):
class RawPrediction(NamedTuple):
"""Intermediate representation of a single detected sound event."""
geometry: data.Geometry
detection_score: float
class_scores: np.ndarray
features: np.ndarray
class DetectionsArray(NamedTuple):
class ClipDetectionsArray(NamedTuple):
scores: np.ndarray
sizes: np.ndarray
class_scores: np.ndarray
@ -64,7 +62,7 @@ class DetectionsArray(NamedTuple):
features: np.ndarray
class DetectionsTensor(NamedTuple):
class ClipDetectionsTensor(NamedTuple):
scores: torch.Tensor
sizes: torch.Tensor
class_scores: torch.Tensor
@ -72,8 +70,8 @@ class DetectionsTensor(NamedTuple):
frequencies: torch.Tensor
features: torch.Tensor
def numpy(self) -> DetectionsArray:
return DetectionsArray(
def numpy(self) -> ClipDetectionsArray:
return ClipDetectionsArray(
scores=self.scores.detach().cpu().numpy(),
sizes=self.sizes.detach().cpu().numpy(),
class_scores=self.class_scores.detach().cpu().numpy(),
@ -92,10 +90,8 @@ class BatDetect2Prediction:
class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline."""
def __call__(self, output: ModelOutput) -> List[DetectionsTensor]: ...
def get_detections(
def __call__(
self,
output: ModelOutput,
start_times: Optional[List[float]] = None,
) -> List[DetectionsTensor]: ...
start_times: Optional[Sequence[float]] = None,
) -> List[ClipDetectionsTensor]: ...
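The postprocessor is now itself the callable, taking optional per-clip start times and returning one ClipDetectionsTensor per clip, each convertible to NumPy via .numpy(). A minimal sketch of consuming it; postprocessor and output are placeholders for a built postprocessor and a ModelOutput from a forward pass.

from typing import List, Sequence

from batdetect2.typing.postprocess import ClipDetectionsArray


def detections_to_numpy(
    postprocessor, output, start_times: Sequence[float]
) -> List[ClipDetectionsArray]:
    # One ClipDetectionsTensor per clip in the batch.
    clip_detections = postprocessor(output, start_times=start_times)
    # .numpy() detaches every field and moves it to the CPU as NumPy arrays.
    return [clip_dets.numpy() for clip_dets in clip_detections]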

View File

@ -32,6 +32,8 @@ class AudioLoader(Protocol):
allows for different loading strategies or implementations.
"""
samplerate: int
def load_file(
self,
path: data.PathLike,
@ -125,22 +127,6 @@ class SpectrogramBuilder(Protocol):
...
class AudioPipeline(Protocol):
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
class SpectrogramPipeline(Protocol):
def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...
def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor: ...
def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
class PreprocessorProtocol(Protocol):
"""Defines a high-level interface for the complete preprocessing pipeline."""
@ -152,11 +138,13 @@ class PreprocessorProtocol(Protocol):
output_samplerate: float
audio_pipeline: AudioPipeline
spectrogram_pipeline: SpectrogramPipeline
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...
def process_audio(self, wav: torch.Tensor) -> torch.Tensor: ...
def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
def process_numpy(self, wav: np.ndarray) -> np.ndarray:
return self(torch.tensor(wav)).numpy()
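The preprocessing protocol now exposes the waveform and spectrogram stages separately alongside the end-to-end call, plus the process_numpy convenience wrapper defined above. A sketch of the call surface, assuming preprocessor is an object built elsewhere that satisfies the protocol; chaining the fine-grained methods as shown is illustrative rather than a guaranteed equivalent of the end-to-end call.

import numpy as np
import torch


def inspect_preprocessor(preprocessor, wav: np.ndarray) -> None:
    # NumPy convenience wrapper: the tensor round-trip defined on the protocol above.
    spec = preprocessor.process_numpy(wav)
    print("spectrogram shape:", spec.shape)

    # Fine-grained stages exposed by the protocol.
    audio = preprocessor.process_audio(torch.tensor(wav))    # waveform-level steps
    raw_spec = preprocessor.generate_spectrogram(audio)      # waveform -> spectrogram
    processed = preprocessor.process_spectrogram(raw_spec)   # spectrogram-level steps
    print("processed spectrogram shape:", tuple(processed.shape))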