Refactor eval code

mbsantiago 2025-08-31 22:57:02 +01:00
parent 356be57f62
commit 5b9a5a968f
17 changed files with 273 additions and 354 deletions

View File

@@ -2,14 +2,10 @@ from batdetect2.evaluate.config import (
     EvaluationConfig,
     load_evaluation_config,
 )
-from batdetect2.evaluate.match import (
-    match_predictions_and_annotations,
-    match_sound_events_and_raw_predictions,
-)
+from batdetect2.evaluate.match import match_predictions_and_annotations

 __all__ = [
     "EvaluationConfig",
     "load_evaluation_config",
     "match_predictions_and_annotations",
-    "match_sound_events_and_raw_predictions",
 ]

View File

@@ -3,6 +3,7 @@ from dataclasses import dataclass, field
 from typing import List, Literal, Optional, Protocol, Tuple

 import numpy as np
+from loguru import logger
 from soundevent import data
 from soundevent.evaluation import compute_affinity
 from soundevent.evaluation import match_geometries as optimal_match
@@ -10,10 +11,10 @@ from soundevent.geometry import compute_bounds

 from batdetect2.configs import BaseConfig
 from batdetect2.typing import (
-    BatDetect2Prediction,
     MatchEvaluation,
     TargetProtocol,
 )
+from batdetect2.typing.postprocess import RawPrediction

 MatchingStrategy = Literal["greedy", "optimal"]
 """The type of matching algorithm to use: 'greedy' or 'optimal'."""
@@ -274,7 +275,7 @@ def greedy_match(

 def match_sound_events_and_raw_predictions(
     clip_annotation: data.ClipAnnotation,
-    raw_predictions: List[BatDetect2Prediction],
+    raw_predictions: List[RawPrediction],
     targets: TargetProtocol,
     config: Optional[MatchConfig] = None,
 ) -> List[MatchEvaluation]:
@@ -294,12 +295,11 @@ def match_sound_events_and_raw_predictions(
     ]

     predicted_geometries = [
-        raw_prediction.raw.geometry for raw_prediction in raw_predictions
+        raw_prediction.geometry for raw_prediction in raw_predictions
     ]

     scores = [
-        raw_prediction.raw.detection_score
-        for raw_prediction in raw_predictions
+        raw_prediction.detection_score for raw_prediction in raw_predictions
     ]

     matches = []
@@ -320,14 +320,20 @@ def match_sound_events_and_raw_predictions(
         gt_det = target is not None
         gt_class = targets.encode_class(target) if target is not None else None
-        pred_score = float(prediction.raw.detection_score) if prediction else 0
+        pred_score = float(prediction.detection_score) if prediction else 0
+        pred_geometry = (
+            predicted_geometries[source_idx]
+            if source_idx is not None
+            else None
+        )

         class_scores = (
             {
                 str(class_name): float(score)
                 for class_name, score in zip(
                     targets.class_names,
-                    prediction.raw.class_scores,
+                    prediction.class_scores,
                 )
             }
             if prediction is not None
@@ -336,17 +342,14 @@ def match_sound_events_and_raw_predictions(

         matches.append(
             MatchEvaluation(
-                match=data.Match(
-                    source=None
-                    if prediction is None
-                    else prediction.sound_event_prediction,
-                    target=target,
-                    affinity=affinity,
-                ),
+                clip=clip_annotation.clip,
+                sound_event_annotation=target,
                 gt_det=gt_det,
                 gt_class=gt_class,
                 pred_score=pred_score,
                 pred_class_scores=class_scores,
+                pred_geometry=pred_geometry,
+                affinity=affinity,
             )
         )
@@ -418,6 +421,28 @@ def match_predictions_and_annotations(
     return matches


+def match_all_predictions(
+    clip_annotations: List[data.ClipAnnotation],
+    predictions: List[List[RawPrediction]],
+    targets: TargetProtocol,
+    config: Optional[MatchConfig] = None,
+) -> List[MatchEvaluation]:
+    logger.info("Matching all annotations and predictions...")
+    return [
+        match
+        for clip_annotation, raw_predictions in zip(
+            clip_annotations,
+            predictions,
+        )
+        for match in match_sound_events_and_raw_predictions(
+            clip_annotation,
+            raw_predictions,
+            targets=targets,
+            config=config,
+        )
+    ]
+
+
 @dataclass
 class ClassExamples:
     false_positives: List[MatchEvaluation] = field(default_factory=list)
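
Usage sketch (not part of this commit): how the new match_all_predictions helper is meant to be called once per-clip predictions have been collected. The clip annotations, predictions, and targets object are assumed to come from the surrounding evaluation pipeline; only the names shown in the diff above are used.

from typing import List

from soundevent import data

from batdetect2.evaluate.match import match_all_predictions
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import RawPrediction


def summarise_matches(
    clip_annotations: List[data.ClipAnnotation],
    predictions: List[List[RawPrediction]],
    targets: TargetProtocol,
):
    # One flat list of MatchEvaluation objects across all clips.
    matches = match_all_predictions(
        clip_annotations,
        predictions,
        targets=targets,
    )
    # Each match now carries the clip, the matched annotation (or None),
    # the predicted geometry (or None), scores, and the affinity directly.
    return [(m.gt_class, m.pred_class, m.affinity) for m in matches]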

View File

@@ -68,7 +68,7 @@ from batdetect2.postprocess import PostprocessConfig, build_postprocessor
 from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
 from batdetect2.targets import TargetConfig, build_targets
 from batdetect2.typing.models import DetectionModel
-from batdetect2.typing.postprocess import Detections, PostprocessorProtocol
+from batdetect2.typing.postprocess import DetectionsArray, PostprocessorProtocol
 from batdetect2.typing.preprocess import PreprocessorProtocol
 from batdetect2.typing.targets import TargetProtocol
@@ -122,7 +122,7 @@ class Model(LightningModule):
         self.targets = targets
         self.save_hyperparameters()

-    def forward(self, wav: torch.Tensor) -> List[Detections]:
+    def forward(self, wav: torch.Tensor) -> List[DetectionsArray]:
         spec = self.preprocessor(wav)
         outputs = self.detector(spec)
         return self.postprocessor(outputs)

View File

@@ -124,25 +124,21 @@ def plot_false_positive_match(
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
-    time_offset: float = 0,
     color: str = DEFAULT_FALSE_POSITIVE_COLOR,
     fontsize: Union[float, str] = "small",
 ) -> Axes:
-    assert match.match.source is not None
-    assert match.match.target is None
-    sound_event = match.match.source.sound_event
-    geometry = sound_event.geometry
-    assert geometry is not None
-    start_time, _, _, high_freq = compute_bounds(geometry)
+    assert match.pred_geometry is not None
+    assert match.sound_event_annotation is None
+    start_time, _, _, high_freq = compute_bounds(match.pred_geometry)

     clip = data.Clip(
         start_time=max(start_time - duration / 2, 0),
         end_time=min(
             start_time + duration / 2,
-            sound_event.recording.duration,
+            match.clip.end_time,
         ),
-        recording=sound_event.recording,
+        recording=match.clip.recording,
     )

     ax = plot_clip(
@@ -154,11 +150,9 @@ def plot_false_positive_match(
         spec_cmap=spec_cmap,
     )

-    plot_prediction(
-        match.match.source,
+    plot.plot_geometry(
+        match.pred_geometry,
         ax=ax,
-        time_offset=time_offset,
-        freq_offset=2_000,
         add_points=add_points,
         facecolor="none" if not fill else None,
         alpha=1,
@@ -191,9 +185,9 @@ def plot_false_negative_match(
     color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
     fontsize: Union[float, str] = "small",
 ) -> Axes:
-    assert match.match.source is None
-    assert match.match.target is not None
-    sound_event = match.match.target.sound_event
+    assert match.pred_geometry is None
+    assert match.sound_event_annotation is not None
+    sound_event = match.sound_event_annotation.sound_event
     geometry = sound_event.geometry
     assert geometry is not None
@@ -217,7 +211,7 @@ def plot_false_negative_match(
     )

     plot.plot_annotation(
-        match.match.target,
+        match.sound_event_annotation,
         ax=ax,
         time_offset=0.001,
         freq_offset=2_000,
@@ -255,9 +249,9 @@ def plot_true_positive_match(
     annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
     prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
 ) -> Axes:
-    assert match.match.source is not None
-    assert match.match.target is not None
-    sound_event = match.match.target.sound_event
+    assert match.sound_event_annotation is not None
+    assert match.pred_geometry is not None
+    sound_event = match.sound_event_annotation.sound_event
     geometry = sound_event.geometry
     assert geometry is not None
@@ -281,7 +275,7 @@ def plot_true_positive_match(
     )

     plot.plot_annotation(
-        match.match.target,
+        match.sound_event_annotation,
         ax=ax,
         time_offset=0.001,
         freq_offset=2_000,
@@ -292,11 +286,9 @@ def plot_true_positive_match(
         linestyle=annotation_linestyle,
     )

-    plot_prediction(
-        match.match.source,
+    plot.plot_geometry(
+        match.pred_geometry,
         ax=ax,
-        time_offset=0.001,
-        freq_offset=2_000,
         add_points=add_points,
         facecolor="none" if not fill else None,
         alpha=1,
@@ -332,9 +324,9 @@ def plot_cross_trigger_match(
     annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
     prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
 ) -> Axes:
-    assert match.match.source is not None
-    assert match.match.target is not None
-    sound_event = match.match.source.sound_event
+    assert match.sound_event_annotation is not None
+    assert match.pred_geometry is not None
+    sound_event = match.sound_event_annotation.sound_event
     geometry = sound_event.geometry
     assert geometry is not None
@@ -358,7 +350,7 @@ def plot_cross_trigger_match(
     )

     plot.plot_annotation(
-        match.match.target,
+        match.sound_event_annotation,
         ax=ax,
         time_offset=0.001,
         freq_offset=2_000,
@@ -369,11 +361,9 @@ def plot_cross_trigger_match(
         linestyle=annotation_linestyle,
     )

-    plot_prediction(
-        match.match.source,
+    plot.plot_geometry(
+        match.pred_geometry,
         ax=ax,
-        time_offset=0.001,
-        freq_offset=2_000,
         add_points=add_points,
         facecolor="none" if not fill else None,
         alpha=1,

View File

@@ -10,9 +10,9 @@ from soundevent import data
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.postprocess.decoding import (
     DEFAULT_CLASSIFICATION_THRESHOLD,
-    convert_detections_to_raw_predictions,
     convert_raw_prediction_to_sound_event_prediction,
     convert_raw_predictions_to_clip_prediction,
+    to_raw_predictions,
 )
 from batdetect2.postprocess.extraction import extract_prediction_tensor
 from batdetect2.postprocess.nms import (
@@ -24,7 +24,8 @@ from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
 from batdetect2.typing import ModelOutput
 from batdetect2.typing.postprocess import (
     BatDetect2Prediction,
-    Detections,
+    DetectionsArray,
+    DetectionsTensor,
     PostprocessorProtocol,
     RawPrediction,
 )
@@ -43,7 +44,7 @@ __all__ = [
     "TOP_K_PER_SEC",
     "build_postprocessor",
     "convert_raw_predictions_to_clip_prediction",
-    "convert_detections_to_raw_predictions",
+    "to_raw_predictions",
     "load_postprocess_config",
     "non_max_suppression",
 ]
@@ -168,7 +169,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         self.top_k_per_sec = top_k_per_sec
         self.detection_threshold = detection_threshold

-    def forward(self, output: ModelOutput) -> List[Detections]:
+    def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
         width = output.detection_probs.shape[-1]
         duration = width / self.samplerate
         max_detections = int(self.top_k_per_sec * duration)
@@ -192,7 +193,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         self,
         output: ModelOutput,
         clips: Optional[List[data.Clip]] = None,
-    ) -> List[Detections]:
+    ) -> List[DetectionsTensor]:
         width = output.detection_probs.shape[-1]
         duration = width / self.samplerate
         max_detections = int(self.top_k_per_sec * duration)
@@ -245,11 +246,8 @@ def get_raw_predictions(
     """
     detections = postprocessor.get_detections(output, clips)
     return [
-        convert_detections_to_raw_predictions(
-            dataset,
-            targets=targets,
-        )
-        for dataset in detections
+        to_raw_predictions(detection.numpy(), targets=targets)
+        for detection in detections
     ]
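
Usage sketch (not part of this commit): a hedged illustration of the refactored call path above. Detections stay as DetectionsTensor inside the postprocessor and are converted to numpy arrays only at the decoding boundary via detection.numpy(). The output, clips, targets, and postprocessor values are assumed to be built elsewhere (for example via build_postprocessor and build_targets).

from typing import List, Optional

from soundevent import data

from batdetect2.postprocess import get_raw_predictions
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import PostprocessorProtocol, RawPrediction
from batdetect2.typing.targets import TargetProtocol


def decode_batch(
    output: ModelOutput,
    postprocessor: PostprocessorProtocol,
    targets: TargetProtocol,
    clips: Optional[List[data.Clip]] = None,
) -> List[List[RawPrediction]]:
    # Returns one list of RawPrediction per clip in the batch.
    return get_raw_predictions(
        output,
        clips=clips,
        targets=targets,
        postprocessor=postprocessor,
    )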

View File

@@ -6,13 +6,13 @@ import numpy as np
 from soundevent import data

 from batdetect2.typing.postprocess import (
-    Detections,
+    DetectionsArray,
     RawPrediction,
 )
 from batdetect2.typing.targets import TargetProtocol

 __all__ = [
-    "convert_detections_to_raw_predictions",
+    "to_raw_predictions",
     "convert_raw_predictions_to_clip_prediction",
     "convert_raw_prediction_to_sound_event_prediction",
     "DEFAULT_CLASSIFICATION_THRESHOLD",
@@ -27,19 +27,19 @@ decoding.
 """

-def convert_detections_to_raw_predictions(
-    detections: Detections,
+def to_raw_predictions(
+    detections: DetectionsArray,
     targets: TargetProtocol,
 ) -> List[RawPrediction]:
     predictions = []
     for score, class_scores, time, freq, dims, feats in zip(
-        detections.scores.cpu().numpy(),
-        detections.class_scores.cpu().numpy(),
-        detections.times.cpu().numpy(),
-        detections.frequencies.cpu().numpy(),
-        detections.sizes.cpu().numpy(),
-        detections.features.cpu().numpy(),
+        detections.scores,
+        detections.class_scores,
+        detections.times,
+        detections.frequencies,
+        detections.sizes,
+        detections.features,
     ):
         highest_scoring_class = targets.class_names[class_scores.argmax()]

View File

@@ -20,7 +20,10 @@ from typing import List, Optional, Tuple, Union
 import torch

 from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
-from batdetect2.typing.postprocess import Detections, ModelOutput
+from batdetect2.typing.postprocess import (
+    DetectionsTensor,
+    ModelOutput,
+)

 __all__ = [
     "extract_prediction_tensor",
@@ -32,7 +35,7 @@ def extract_prediction_tensor(
     max_detections: int = 200,
     threshold: Optional[float] = None,
     nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
-) -> List[Detections]:
+) -> List[DetectionsTensor]:
     detection_heatmap = non_max_suppression(
         output.detection_probs.detach(),
         kernel_size=nms_kernel_size,
@@ -78,7 +81,7 @@ def extract_prediction_tensor(
         class_scores = class_scores[mask]

         predictions.append(
-            Detections(
+            DetectionsTensor(
                 scores=detection_scores,
                 sizes=sizes,
                 features=features,

View File

@@ -20,7 +20,7 @@ import xarray as xr
 from soundevent.arrays import Dimensions

 from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
-from batdetect2.typing.postprocess import Detections
+from batdetect2.typing.postprocess import DetectionsTensor

 __all__ = [
     "features_to_xarray",
@@ -31,15 +31,15 @@ __all__ = [

 def map_detection_to_clip(
-    detections: Detections,
+    detections: DetectionsTensor,
     start_time: float,
     end_time: float,
     min_freq: float,
     max_freq: float,
-) -> Detections:
+) -> DetectionsTensor:
     duration = end_time - start_time
     bandwidth = max_freq - min_freq

-    return Detections(
+    return DetectionsTensor(
         scores=detections.scores,
         sizes=detections.sizes,
         features=detections.features,

View File

@@ -21,7 +21,7 @@ configured processing steps. The main way to create a functional `Targets`
 object is via the `build_targets` or `load_targets` functions.
 """

-from typing import List, Optional
+from typing import Iterable, List, Optional, Tuple

 from loguru import logger
 from pydantic import Field
@@ -675,3 +675,24 @@ def load_targets(
         term_registry=term_registry,
         derivation_registry=derivation_registry,
     )
+
+
+def iterate_encoded_sound_events(
+    sound_events: Iterable[data.SoundEventAnnotation],
+    targets: TargetProtocol,
+) -> Iterable[Tuple[Optional[str], Position, Size]]:
+    for sound_event in sound_events:
+        if not targets.filter(sound_event):
+            continue
+
+        geometry = sound_event.sound_event.geometry
+        if geometry is None:
+            continue
+
+        sound_event = targets.transform(sound_event)
+        class_name = targets.encode_class(sound_event)
+        position, size = targets.encode_roi(sound_event)
+        yield class_name, position, size
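
Usage sketch (not part of this commit): the new iterate_encoded_sound_events generator centralises filtering, tag transformation, class encoding, and ROI encoding; the labelling code below now consumes it, and other callers could do the same. The sound_events and targets values are assumed inputs.

from typing import Iterable

from soundevent import data

from batdetect2.targets import iterate_encoded_sound_events
from batdetect2.typing import TargetProtocol


def collect_encoded_targets(
    sound_events: Iterable[data.SoundEventAnnotation],
    targets: TargetProtocol,
):
    # Only annotations that pass the filter and carry a geometry are yielded.
    return [
        (class_name, position, size)
        for class_name, position, size in iterate_encoded_sound_events(
            sound_events,
            targets,
        )
    ]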

View File

@@ -1,32 +1,26 @@
-import io
-from typing import List, Optional, Tuple
+from typing import List, Optional

-import numpy as np
 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
-from lightning.pytorch.loggers import Logger, TensorBoardLogger
-from lightning.pytorch.loggers.mlflow import MLFlowLogger
-from loguru import logger
 from soundevent import data
 from torch.utils.data import DataLoader

 from batdetect2.evaluate.match import (
     MatchConfig,
-    match_sound_events_and_raw_predictions,
+    match_all_predictions,
 )
-from batdetect2.models import Model
 from batdetect2.plotting.evaluation import plot_example_gallery
-from batdetect2.postprocess import get_sound_event_predictions
-from batdetect2.train.dataset import TrainingDataset
+from batdetect2.postprocess import get_raw_predictions
+from batdetect2.train.dataset import ValidationDataset
 from batdetect2.train.lightning import TrainingModule
+from batdetect2.train.logging import get_image_plotter
 from batdetect2.typing import (
-    BatDetect2Prediction,
     MatchEvaluation,
     MetricsProtocol,
-    ModelOutput,
-    TargetProtocol,
-    TrainExample,
 )
+from batdetect2.typing.models import ModelOutput
+from batdetect2.typing.postprocess import RawPrediction
+from batdetect2.typing.train import TrainExample


 class ValidationMetrics(Callback):
@@ -45,15 +39,14 @@ class ValidationMetrics(Callback):
         self.metrics = metrics
         self.plot = plot
-        self._matches: List[
-            Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
-        ] = []
+        self._clip_annotations: List[data.ClipAnnotation] = []
+        self._predictions: List[List[RawPrediction]] = []

-    def get_dataset(self, trainer: Trainer) -> TrainingDataset:
+    def get_dataset(self, trainer: Trainer) -> ValidationDataset:
         dataloaders = trainer.val_dataloaders
         assert isinstance(dataloaders, DataLoader)
         dataset = dataloaders.dataset
-        assert isinstance(dataset, TrainingDataset)
+        assert isinstance(dataset, ValidationDataset)
         return dataset

     def plot_examples(
@@ -61,7 +54,7 @@ class ValidationMetrics(Callback):
         pl_module: LightningModule,
         matches: List[MatchEvaluation],
     ):
-        plotter = _get_image_plotter(pl_module.logger)  # type: ignore
+        plotter = get_image_plotter(pl_module.logger)  # type: ignore
         if plotter is None:
             return
@@ -93,9 +86,10 @@ class ValidationMetrics(Callback):
         trainer: Trainer,
         pl_module: LightningModule,
     ) -> None:
-        matches = _match_all_collected_examples(
-            self._matches,
-            pl_module.model.targets,
+        matches = match_all_predictions(
+            self._clip_annotations,
+            self._predictions,
+            targets=pl_module.model.targets,
             config=self.match_config,
         )
@@ -123,133 +117,23 @@ class ValidationMetrics(Callback):
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        self._matches.extend(
-            _get_batch_clips_and_predictions(
-                batch,
-                outputs,
-                dataset=self.get_dataset(trainer),
-                model=pl_module.model,
-            )
-        )
+        postprocessor = pl_module.model.postprocessor
+        targets = pl_module.model.targets
+        dataset = self.get_dataset(trainer)
+
+        clip_annotations = [
+            dataset.clip_annotations[int(example_idx)]
+            for example_idx in batch.idx
+        ]
+
+        predictions = get_raw_predictions(
+            outputs,
+            clips=[
+                clip_annotation.clip for clip_annotation in clip_annotations
+            ],
+            targets=targets,
+            postprocessor=postprocessor,
+        )
+
+        self._clip_annotations.extend(clip_annotations)
+        self._predictions.extend(predictions)
-
-
-def _get_batch_clips_and_predictions(
-    batch: TrainExample,
-    outputs: ModelOutput,
-    dataset: TrainingDataset,
-    model: Model,
-) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
-    clip_annotations = [
-        _get_subclip(
-            dataset.clip_annotations[int(example_id)],
-            start_time=start_time.item(),
-            end_time=end_time.item(),
-            targets=model.targets,
-        )
-        for example_id, start_time, end_time in zip(
-            batch.idx,
-            batch.start_time,
-            batch.end_time,
-        )
-    ]
-    clips = [clip_annotation.clip for clip_annotation in clip_annotations]
-    raw_predictions = get_sound_event_predictions(
-        outputs,
-        clips,
-        targets=model.targets,
-        postprocessor=model.postprocessor
-    )
-    return [
-        (clip_annotation, clip_predictions)
-        for clip_annotation, clip_predictions in zip(
-            clip_annotations, raw_predictions
-        )
-    ]
-
-
-def _match_all_collected_examples(
-    pre_matches: List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]],
-    targets: TargetProtocol,
-    config: Optional[MatchConfig] = None,
-) -> List[MatchEvaluation]:
-    logger.info("Matching all annotations and predictions...")
-    return [
-        match
-        for clip_annotation, raw_predictions in pre_matches
-        for match in match_sound_events_and_raw_predictions(
-            clip_annotation,
-            raw_predictions,
-            targets=targets,
-            config=config,
-        )
-    ]
-
-
-def _is_in_subclip(
-    sound_event_annotation: data.SoundEventAnnotation,
-    targets: TargetProtocol,
-    start_time: float,
-    end_time: float,
-) -> bool:
-    (time, _), _ = targets.encode_roi(sound_event_annotation)
-    return start_time <= time <= end_time
-
-
-def _get_subclip(
-    clip_annotation: data.ClipAnnotation,
-    start_time: float,
-    end_time: float,
-    targets: TargetProtocol,
-) -> data.ClipAnnotation:
-    return data.ClipAnnotation(
-        clip=data.Clip(
-            recording=clip_annotation.clip.recording,
-            start_time=start_time,
-            end_time=end_time,
-        ),
-        sound_events=[
-            sound_event_annotation
-            for sound_event_annotation in clip_annotation.sound_events
-            if _is_in_subclip(
-                sound_event_annotation,
-                targets,
-                start_time=start_time,
-                end_time=end_time,
-            )
-        ],
-    )
-
-
-def _get_image_plotter(logger: Logger):
-    if isinstance(logger, TensorBoardLogger):
-
-        def plot_figure(name, figure, step):
-            return logger.experiment.add_figure(name, figure, step)
-
-        return plot_figure
-
-    if isinstance(logger, MLFlowLogger):
-
-        def plot_figure(name, figure, step):
-            image = _convert_figure_to_image(figure)
-            return logger.experiment.log_image(
-                run_id=logger.run_id,
-                image=image,
-                key=name,
-                step=step,
-            )
-
-        return plot_figure
-
-
-def _convert_figure_to_image(figure):
-    with io.BytesIO() as buff:
-        figure.savefig(buff, format="raw")
-        buff.seek(0)
-        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
-        w, h = figure.canvas.get_width_height()
-        im = data.reshape((int(h), int(w), -1))
-    return im

View File

@@ -6,7 +6,10 @@ from torch.utils.data import Dataset

 from batdetect2.typing import ClipperProtocol, TrainExample
 from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
-from batdetect2.typing.train import Augmentation, ClipLabeller
+from batdetect2.typing.train import (
+    Augmentation,
+    ClipLabeller,
+)

 __all__ = [
     "TrainingDataset",
@@ -75,3 +78,47 @@ class TrainingDataset(Dataset):
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )
+
+
+class ValidationDataset(Dataset):
+    def __init__(
+        self,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        audio_loader: AudioLoader,
+        preprocessor: PreprocessorProtocol,
+        labeller: ClipLabeller,
+        audio_dir: Optional[data.PathLike] = None,
+    ):
+        self.clip_annotations = clip_annotations
+        self.labeller = labeller
+        self.preprocessor = preprocessor
+        self.audio_loader = audio_loader
+        self.audio_dir = audio_dir
+
+    def __len__(self):
+        return len(self.clip_annotations)
+
+    def __getitem__(self, idx) -> TrainExample:
+        clip_annotation = self.clip_annotations[idx]
+        clip = clip_annotation.clip
+
+        wav = self.audio_loader.load_clip(
+            clip_annotation.clip,
+            audio_dir=self.audio_dir,
+        )
+
+        wav_tensor = torch.tensor(wav).unsqueeze(0)
+        spectrogram = self.preprocessor(wav_tensor)
+        heatmaps = self.labeller(clip_annotation, spectrogram)
+
+        return TrainExample(
+            spec=spectrogram,
+            detection_heatmap=heatmaps.detection,
+            class_heatmap=heatmaps.classes,
+            size_heatmap=heatmaps.size,
+            idx=torch.tensor(idx),
+            start_time=torch.tensor(clip.start_time),
+            end_time=torch.tensor(clip.end_time),
+        )
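
Usage sketch (not part of this commit): hypothetical wiring for the new ValidationDataset; in the commit itself this construction lives in build_val_dataset (see the training hunk further below). Unlike TrainingDataset, each item appears to cover the whole annotated clip with no clipping or augmentation, so batch_size=1 is used here as a conservative assumption about variable clip lengths.

from typing import Sequence

from soundevent import data
from torch.utils.data import DataLoader

from batdetect2.train.dataset import ValidationDataset
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.train import ClipLabeller


def build_val_loader(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: AudioLoader,
    preprocessor: PreprocessorProtocol,
    labeller: ClipLabeller,
) -> DataLoader:
    dataset = ValidationDataset(
        clip_annotations,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        labeller=labeller,
    )
    # Full clips can differ in duration, so avoid batching them together.
    return DataLoader(dataset, batch_size=1, shuffle=False)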

View File

@@ -3,24 +3,6 @@
 This module is responsible for creating the target labels used for training
 BatDetect2 models. It converts sound event annotations for an audio clip into
 the specific multi-channel heatmap formats required by the neural network.
-It uses a pre-configured object adhering to the `TargetProtocol` (from
-`batdetect2.targets`) which encapsulates all the logic for filtering
-annotations, transforming tags, encoding class names, and mapping annotation
-geometry (ROIs) to target positions and sizes. This module then focuses on
-rendering this information onto the heatmap grids.
-The pipeline generates three core outputs for a given spectrogram:
-1. **Detection Heatmap**: Indicates presence/location of relevant sound events.
-2. **Class Heatmap**: Indicates location and class identity for specifically
-   classified events.
-3. **Size Heatmap**: Encodes the target dimensions (width, height) of events.
-The primary function generated by this module is a `ClipLabeller` (defined in
-`.types`), which takes a `ClipAnnotation` object and its corresponding
-spectrogram and returns the calculated `Heatmaps` tuple. The main configurable
-parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
-defined in `LabelConfig`.
 """

 from functools import partial
@@ -32,6 +14,7 @@ from loguru import logger
 from soundevent import data

 from batdetect2.configs import BaseConfig, load_config
+from batdetect2.targets import iterate_encoded_sound_events
 from batdetect2.typing import (
     ClipLabeller,
     Heatmaps,
@@ -56,9 +39,6 @@ class LabelConfig(BaseConfig):
     Attributes
     ----------
     sigma : float, default=3.0
-        The standard deviation (in pixels/bins) of the Gaussian kernel applied
-        to smooth the detection and class heatmaps. Larger values create more
-        diffuse targets.
     """

     sigma: float = 2.0
@@ -70,28 +50,7 @@ def build_clip_labeler(
     max_freq: float,
     config: Optional[LabelConfig] = None,
 ) -> ClipLabeller:
-    """Construct the final clip labelling function.
-    This factory function prepares the callable that will perform the
-    end-to-end heatmap generation for a given clip and spectrogram during
-    training data loading. It takes the fully configured `targets` object and
-    the `LabelConfig` and binds them to the `generate_clip_label` function.
-    Parameters
-    ----------
-    targets : TargetProtocol
-        An initialized object conforming to the `TargetProtocol`, providing all
-        necessary methods for filtering, transforming, encoding, and ROI
-        mapping.
-    config : LabelConfig
-        Configuration object containing heatmap generation parameters.
-    Returns
-    -------
-    ClipLabeller
-        A function that accepts a `data.ClipAnnotation` and `xr.DataArray`
-        (spectrogram) and returns the generated `Heatmaps`.
-    """
+    """Construct the final clip labelling function."""
     config = config or LabelConfig()
     logger.opt(lazy=True).debug(
         "Building clip labeler with config: \n{}",
@@ -119,37 +78,10 @@ def generate_heatmaps(
     target_sigma: float = 3.0,
     dtype=torch.float32,
 ) -> Heatmaps:
-    """Generate training heatmaps for a single annotated clip.
-    This function orchestrates the target generation process for one clip:
-    1. Filters and transforms sound events using `targets.filter` and
-       `targets.transform`.
-    2. Passes the resulting processed annotations, along with the spectrogram,
-       the `targets` object, and the Gaussian `sigma` from `config`, to the
-       core `generate_heatmaps` function.
-    Parameters
-    ----------
-    clip_annotation : data.ClipAnnotation
-        The complete annotation data for the audio clip, including the list
-        of `sound_events` to process.
-    spec : xr.DataArray
-        The spectrogram corresponding to the `clip_annotation`. Must have
-        'time' and 'frequency' dimensions/coordinates.
-    targets : TargetProtocol
-        The fully configured target definition object, providing methods for
-        filtering, transformation, encoding, and ROI mapping.
-    config : LabelConfig
-        Configuration object providing heatmap parameters (primarily `sigma`).
-    Returns
-    -------
-    Heatmaps
-        A NamedTuple containing the generated 'detection', 'classes', and 'size'
-        heatmaps for this clip.
-    """
+    """Generate training heatmaps for a single annotated clip."""
     logger.debug(
-        "Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events",
+        "Will generate heatmaps for clip annotation "
+        "{uuid} with {num} annotated sound events",
         uuid=clip_annotation.uuid,
         num=len(clip_annotation.sound_events),
     )
@@ -174,28 +106,10 @@ def generate_heatmaps(
     freqs = freqs.to(spec.device)
     times = times.to(spec.device)
-    for sound_event_annotation in clip_annotation.sound_events:
-        if not targets.filter(sound_event_annotation):
-            logger.debug(
-                "Sound event {sound_event} did not pass the filter. Tags: {tags}",
-                sound_event=sound_event_annotation,
-                tags=sound_event_annotation.tags,
-            )
-            continue
-        sound_event_annotation = targets.transform(sound_event_annotation)
-        geom = sound_event_annotation.sound_event.geometry
-        if geom is None:
-            logger.debug(
-                "Skipping annotation %s: missing geometry.",
-                sound_event_annotation.uuid,
-            )
-            continue
-        # Get the position of the sound event
-        (time, frequency), size = targets.encode_roi(sound_event_annotation)
+    for class_name, (time, frequency), size in iterate_encoded_sound_events(
+        clip_annotation.sound_events,
+        targets,
+    ):
         time_index = map_to_pixels(time, width, clip.start_time, clip.end_time)
         freq_index = map_to_pixels(frequency, height, min_freq, max_freq)
@@ -206,9 +120,7 @@ def generate_heatmaps(
             or freq_index >= height
         ):
             logger.debug(
-                "Skipping annotation %s: position outside spectrogram. "
-                "Pos: %s",
-                sound_event_annotation.uuid,
+                "Skipping annotation: position outside spectrogram. Pos: %s",
                 (time, frequency),
             )
             continue
@@ -222,20 +134,8 @@ def generate_heatmaps(
         )
         size_heatmap[:, freq_index, time_index] = torch.tensor(size[:])
-        # Get the class name of the sound event
-        try:
-            class_name = targets.encode_class(sound_event_annotation)
-        except ValueError as e:
-            logger.warning(
-                "Skipping annotation %s: Unexpected error while encoding "
-                "class name %s",
-                sound_event_annotation.uuid,
-                e,
-            )
-            continue
+        # If the label is None skip the sound event
         if class_name is None:
-            # If the label is None skip the sound event
             continue

         class_index = targets.class_names.index(class_name)

View File

@@ -1,6 +1,8 @@
+import io
 from typing import Annotated, Any, Literal, Optional, Union

-from lightning.pytorch.loggers import Logger
+import numpy as np
+from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
 from loguru import logger
 from pydantic import Field
@@ -140,3 +142,35 @@ def build_logger(config: LoggerConfig) -> Logger:
     creation_func = LOGGER_FACTORY[logger_type]
     return creation_func(config)
+
+
+def get_image_plotter(logger: Logger):
+    if isinstance(logger, TensorBoardLogger):
+
+        def plot_figure(name, figure, step):
+            return logger.experiment.add_figure(name, figure, step)
+
+        return plot_figure
+
+    if isinstance(logger, MLFlowLogger):
+
+        def plot_figure(name, figure, step):
+            image = _convert_figure_to_image(figure)
+            return logger.experiment.log_image(
+                run_id=logger.run_id,
+                image=image,
+                key=name,
+                step=step,
+            )
+
+        return plot_figure
+
+
+def _convert_figure_to_image(figure):
+    with io.BytesIO() as buff:
+        figure.savefig(buff, format="raw")
+        buff.seek(0)
+        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
+        w, h = figure.canvas.get_width_height()
+        im = data.reshape((int(h), int(w), -1))
+    return im
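
Usage sketch (not part of this commit): how the relocated get_image_plotter helper is used from a callback. The figure key "validation/example_gallery" is illustrative only; the helper returns None for logger backends other than TensorBoard or MLflow, so callers must check before calling the plotter.

from lightning.pytorch.loggers import Logger

from batdetect2.train.logging import get_image_plotter


def log_figure(pl_logger: Logger, figure, step: int) -> None:
    plotter = get_image_plotter(pl_logger)
    if plotter is None:
        # Unsupported logger backend: skip figure logging.
        return
    plotter("validation/example_gallery", figure, step)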

View File

@@ -24,9 +24,7 @@ from batdetect2.train.augmentations import (
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
 from batdetect2.train.config import FullTrainingConfig, TrainingConfig
-from batdetect2.train.dataset import (
-    TrainingDataset,
-)
+from batdetect2.train.dataset import TrainingDataset, ValidationDataset
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
@@ -304,11 +302,11 @@ def build_val_dataset(
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
-) -> TrainingDataset:
+) -> ValidationDataset:
     logger.info("Building validation dataset...")
     config = config or TrainingConfig()
-    return TrainingDataset(
+    return ValidationDataset(
         clip_annotations,
         audio_loader=audio_loader,
         labeller=labeller,

View File

@@ -11,13 +11,17 @@ __all__ = [

 @dataclass
 class MatchEvaluation:
-    match: data.Match
+    clip: data.Clip
+    sound_event_annotation: Optional[data.SoundEventAnnotation]
     gt_det: bool
     gt_class: Optional[str]
     pred_score: float
     pred_class_scores: Dict[str, float]
+    pred_geometry: Optional[data.Geometry]
+    affinity: float

     @property
     def pred_class(self) -> Optional[str]:

View File

@@ -77,7 +77,16 @@ class RawPrediction(NamedTuple):
     features: np.ndarray

-class Detections(NamedTuple):
+class DetectionsArray(NamedTuple):
+    scores: np.ndarray
+    sizes: np.ndarray
+    class_scores: np.ndarray
+    times: np.ndarray
+    frequencies: np.ndarray
+    features: np.ndarray
+
+
+class DetectionsTensor(NamedTuple):
     scores: torch.Tensor
     sizes: torch.Tensor
     class_scores: torch.Tensor
@@ -85,6 +94,16 @@ class Detections(NamedTuple):
     frequencies: torch.Tensor
     features: torch.Tensor

+    def numpy(self) -> DetectionsArray:
+        return DetectionsArray(
+            scores=self.scores.detach().numpy(),
+            sizes=self.sizes.detach().numpy(),
+            class_scores=self.class_scores.detach().numpy(),
+            times=self.times.detach().numpy(),
+            frequencies=self.frequencies.detach().numpy(),
+            features=self.features.detach().numpy(),
+        )
+

 @dataclass
 class BatDetect2Prediction:
@@ -95,10 +114,10 @@ class BatDetect2Prediction:

 class PostprocessorProtocol(Protocol):
     """Protocol defining the interface for the full postprocessing pipeline."""

-    def __call__(self, output: ModelOutput) -> List[Detections]: ...
+    def __call__(self, output: ModelOutput) -> List[DetectionsTensor]: ...

     def get_detections(
         self,
         output: ModelOutput,
         clips: Optional[List[data.Clip]] = None,
-    ) -> List[Detections]: ...
+    ) -> List[DetectionsTensor]: ...
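
Toy illustration (not part of this commit) of the new DetectionsTensor / DetectionsArray split; the shapes are made up. The numpy() bridge is what get_raw_predictions uses before handing detections to the numpy-only decoding code.

import torch

from batdetect2.typing.postprocess import DetectionsTensor

detections = DetectionsTensor(
    scores=torch.rand(5),
    sizes=torch.rand(5, 2),
    class_scores=torch.rand(5, 3),
    times=torch.rand(5),
    frequencies=torch.rand(5),
    features=torch.rand(5, 32),
)

# DetectionsArray with the same fields as numpy arrays, detached from the graph.
arrays = detections.numpy()
assert arrays.scores.shape == (5,)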

View File

@@ -12,8 +12,8 @@ that components responsible for these tasks can be interacted with consistently
 throughout BatDetect2.
 """

-from collections.abc import Callable, Iterable
-from typing import List, Optional, Protocol, Tuple
+from collections.abc import Callable
+from typing import List, Optional, Protocol

 import numpy as np
 from soundevent import data