From 5b9a5a968f31e618ed2c8947cdad040ee46abc90 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sun, 31 Aug 2025 22:57:02 +0100 Subject: [PATCH] Refactor eval code --- src/batdetect2/evaluate/__init__.py | 6 +- src/batdetect2/evaluate/match.py | 53 +++++-- src/batdetect2/models/__init__.py | 4 +- src/batdetect2/plotting/matches.py | 56 +++---- src/batdetect2/postprocess/__init__.py | 18 +-- src/batdetect2/postprocess/decoding.py | 20 +-- src/batdetect2/postprocess/extraction.py | 9 +- src/batdetect2/postprocess/remapping.py | 8 +- src/batdetect2/targets/__init__.py | 23 ++- src/batdetect2/train/callbacks.py | 180 ++++------------------- src/batdetect2/train/dataset.py | 49 +++++- src/batdetect2/train/labels.py | 122 ++------------- src/batdetect2/train/logging.py | 36 ++++- src/batdetect2/train/train.py | 8 +- src/batdetect2/typing/evaluate.py | 6 +- src/batdetect2/typing/postprocess.py | 25 +++- src/batdetect2/typing/targets.py | 4 +- 17 files changed, 273 insertions(+), 354 deletions(-) diff --git a/src/batdetect2/evaluate/__init__.py b/src/batdetect2/evaluate/__init__.py index 9bcaebf..ecd2812 100644 --- a/src/batdetect2/evaluate/__init__.py +++ b/src/batdetect2/evaluate/__init__.py @@ -2,14 +2,10 @@ from batdetect2.evaluate.config import ( EvaluationConfig, load_evaluation_config, ) -from batdetect2.evaluate.match import ( - match_predictions_and_annotations, - match_sound_events_and_raw_predictions, -) +from batdetect2.evaluate.match import match_predictions_and_annotations __all__ = [ "EvaluationConfig", "load_evaluation_config", "match_predictions_and_annotations", - "match_sound_events_and_raw_predictions", ] diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index 44e372e..15f3a34 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -3,6 +3,7 @@ from dataclasses import dataclass, field from typing import List, Literal, Optional, Protocol, Tuple import numpy as np +from loguru import logger from soundevent import data from soundevent.evaluation import compute_affinity from soundevent.evaluation import match_geometries as optimal_match @@ -10,10 +11,10 @@ from soundevent.geometry import compute_bounds from batdetect2.configs import BaseConfig from batdetect2.typing import ( - BatDetect2Prediction, MatchEvaluation, TargetProtocol, ) +from batdetect2.typing.postprocess import RawPrediction MatchingStrategy = Literal["greedy", "optimal"] """The type of matching algorithm to use: 'greedy' or 'optimal'.""" @@ -274,7 +275,7 @@ def greedy_match( def match_sound_events_and_raw_predictions( clip_annotation: data.ClipAnnotation, - raw_predictions: List[BatDetect2Prediction], + raw_predictions: List[RawPrediction], targets: TargetProtocol, config: Optional[MatchConfig] = None, ) -> List[MatchEvaluation]: @@ -294,12 +295,11 @@ def match_sound_events_and_raw_predictions( ] predicted_geometries = [ - raw_prediction.raw.geometry for raw_prediction in raw_predictions + raw_prediction.geometry for raw_prediction in raw_predictions ] scores = [ - raw_prediction.raw.detection_score - for raw_prediction in raw_predictions + raw_prediction.detection_score for raw_prediction in raw_predictions ] matches = [] @@ -320,14 +320,20 @@ def match_sound_events_and_raw_predictions( gt_det = target is not None gt_class = targets.encode_class(target) if target is not None else None - pred_score = float(prediction.raw.detection_score) if prediction else 0 + pred_score = float(prediction.detection_score) if prediction else 0 + + pred_geometry = ( + 
+        predicted_geometries[source_idx]
+        if source_idx is not None
+        else None
+    )
 
     class_scores = (
         {
             str(class_name): float(score)
             for class_name, score in zip(
                 targets.class_names,
-                prediction.raw.class_scores,
+                prediction.class_scores,
             )
         }
         if prediction is not None
@@ -336,17 +342,14 @@
 
     matches.append(
         MatchEvaluation(
-            match=data.Match(
-                source=None
-                if prediction is None
-                else prediction.sound_event_prediction,
-                target=target,
-                affinity=affinity,
-            ),
+            clip=clip_annotation.clip,
+            sound_event_annotation=target,
             gt_det=gt_det,
             gt_class=gt_class,
             pred_score=pred_score,
             pred_class_scores=class_scores,
+            pred_geometry=pred_geometry,
+            affinity=affinity,
         )
     )
 
@@ -418,6 +421,28 @@
     return matches
 
 
+def match_all_predictions(
+    clip_annotations: List[data.ClipAnnotation],
+    predictions: List[List[RawPrediction]],
+    targets: TargetProtocol,
+    config: Optional[MatchConfig] = None,
+) -> List[MatchEvaluation]:
+    logger.info("Matching all annotations and predictions...")
+    return [
+        match
+        for clip_annotation, raw_predictions in zip(
+            clip_annotations,
+            predictions,
+        )
+        for match in match_sound_events_and_raw_predictions(
+            clip_annotation,
+            raw_predictions,
+            targets=targets,
+            config=config,
+        )
+    ]
+
+
 @dataclass
 class ClassExamples:
     false_positives: List[MatchEvaluation] = field(default_factory=list)
diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py
index 5d35512..1e2fe18 100644
--- a/src/batdetect2/models/__init__.py
+++ b/src/batdetect2/models/__init__.py
@@ -68,7 +68,7 @@
 from batdetect2.postprocess import PostprocessConfig, build_postprocessor
 from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
 from batdetect2.targets import TargetConfig, build_targets
 from batdetect2.typing.models import DetectionModel
-from batdetect2.typing.postprocess import Detections, PostprocessorProtocol
+from batdetect2.typing.postprocess import DetectionsTensor, PostprocessorProtocol
 from batdetect2.typing.preprocess import PreprocessorProtocol
 from batdetect2.typing.targets import TargetProtocol
@@ -122,7 +122,7 @@
         self.targets = targets
         self.save_hyperparameters()
 
-    def forward(self, wav: torch.Tensor) -> List[Detections]:
+    def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
         spec = self.preprocessor(wav)
         outputs = self.detector(spec)
         return self.postprocessor(outputs)
diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py
index 58216fd..c584bea 100644
--- a/src/batdetect2/plotting/matches.py
+++ b/src/batdetect2/plotting/matches.py
@@ -124,25 +124,21 @@
     add_points: bool = False,
     fill: bool = False,
     spec_cmap: str = "gray",
-    time_offset: float = 0,
     color: str = DEFAULT_FALSE_POSITIVE_COLOR,
     fontsize: Union[float, str] = "small",
 ) -> Axes:
-    assert match.match.source is not None
-    assert match.match.target is None
-    sound_event = match.match.source.sound_event
-    geometry = sound_event.geometry
-    assert geometry is not None
+    assert match.pred_geometry is not None
+    assert match.sound_event_annotation is None
 
-    start_time, _, _, high_freq = compute_bounds(geometry)
+    start_time, _, _, high_freq = compute_bounds(match.pred_geometry)
 
     clip = data.Clip(
         start_time=max(start_time - duration / 2, 0),
         end_time=min(
             start_time + duration / 2,
-            sound_event.recording.duration,
+            match.clip.end_time,
         ),
-        recording=sound_event.recording,
+        recording=match.clip.recording,
     )
 
     ax = 
plot_clip( @@ -154,11 +150,9 @@ def plot_false_positive_match( spec_cmap=spec_cmap, ) - plot_prediction( - match.match.source, + plot.plot_geometry( + match.pred_geometry, ax=ax, - time_offset=time_offset, - freq_offset=2_000, add_points=add_points, facecolor="none" if not fill else None, alpha=1, @@ -191,9 +185,9 @@ def plot_false_negative_match( color: str = DEFAULT_FALSE_NEGATIVE_COLOR, fontsize: Union[float, str] = "small", ) -> Axes: - assert match.match.source is None - assert match.match.target is not None - sound_event = match.match.target.sound_event + assert match.pred_geometry is None + assert match.sound_event_annotation is not None + sound_event = match.sound_event_annotation.sound_event geometry = sound_event.geometry assert geometry is not None @@ -217,7 +211,7 @@ def plot_false_negative_match( ) plot.plot_annotation( - match.match.target, + match.sound_event_annotation, ax=ax, time_offset=0.001, freq_offset=2_000, @@ -255,9 +249,9 @@ def plot_true_positive_match( annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, ) -> Axes: - assert match.match.source is not None - assert match.match.target is not None - sound_event = match.match.target.sound_event + assert match.sound_event_annotation is not None + assert match.pred_geometry is not None + sound_event = match.sound_event_annotation.sound_event geometry = sound_event.geometry assert geometry is not None @@ -281,7 +275,7 @@ def plot_true_positive_match( ) plot.plot_annotation( - match.match.target, + match.sound_event_annotation, ax=ax, time_offset=0.001, freq_offset=2_000, @@ -292,11 +286,9 @@ def plot_true_positive_match( linestyle=annotation_linestyle, ) - plot_prediction( - match.match.source, + plot.plot_geometry( + match.pred_geometry, ax=ax, - time_offset=0.001, - freq_offset=2_000, add_points=add_points, facecolor="none" if not fill else None, alpha=1, @@ -332,9 +324,9 @@ def plot_cross_trigger_match( annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, ) -> Axes: - assert match.match.source is not None - assert match.match.target is not None - sound_event = match.match.source.sound_event + assert match.sound_event_annotation is not None + assert match.pred_geometry is not None + sound_event = match.sound_event_annotation.sound_event geometry = sound_event.geometry assert geometry is not None @@ -358,7 +350,7 @@ def plot_cross_trigger_match( ) plot.plot_annotation( - match.match.target, + match.sound_event_annotation, ax=ax, time_offset=0.001, freq_offset=2_000, @@ -369,11 +361,9 @@ def plot_cross_trigger_match( linestyle=annotation_linestyle, ) - plot_prediction( - match.match.source, + plot.plot_geometry( + match.pred_geometry, ax=ax, - time_offset=0.001, - freq_offset=2_000, add_points=add_points, facecolor="none" if not fill else None, alpha=1, diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py index ba77c1a..27e45df 100644 --- a/src/batdetect2/postprocess/__init__.py +++ b/src/batdetect2/postprocess/__init__.py @@ -10,9 +10,9 @@ from soundevent import data from batdetect2.configs import BaseConfig, load_config from batdetect2.postprocess.decoding import ( DEFAULT_CLASSIFICATION_THRESHOLD, - convert_detections_to_raw_predictions, convert_raw_prediction_to_sound_event_prediction, convert_raw_predictions_to_clip_prediction, + to_raw_predictions, ) from batdetect2.postprocess.extraction import extract_prediction_tensor from 
batdetect2.postprocess.nms import ( @@ -24,7 +24,8 @@ from batdetect2.preprocess import MAX_FREQ, MIN_FREQ from batdetect2.typing import ModelOutput from batdetect2.typing.postprocess import ( BatDetect2Prediction, - Detections, + DetectionsArray, + DetectionsTensor, PostprocessorProtocol, RawPrediction, ) @@ -43,7 +44,7 @@ __all__ = [ "TOP_K_PER_SEC", "build_postprocessor", "convert_raw_predictions_to_clip_prediction", - "convert_detections_to_raw_predictions", + "to_raw_predictions", "load_postprocess_config", "non_max_suppression", ] @@ -168,7 +169,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): self.top_k_per_sec = top_k_per_sec self.detection_threshold = detection_threshold - def forward(self, output: ModelOutput) -> List[Detections]: + def forward(self, output: ModelOutput) -> List[DetectionsTensor]: width = output.detection_probs.shape[-1] duration = width / self.samplerate max_detections = int(self.top_k_per_sec * duration) @@ -192,7 +193,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): self, output: ModelOutput, clips: Optional[List[data.Clip]] = None, - ) -> List[Detections]: + ) -> List[DetectionsTensor]: width = output.detection_probs.shape[-1] duration = width / self.samplerate max_detections = int(self.top_k_per_sec * duration) @@ -245,11 +246,8 @@ def get_raw_predictions( """ detections = postprocessor.get_detections(output, clips) return [ - convert_detections_to_raw_predictions( - dataset, - targets=targets, - ) - for dataset in detections + to_raw_predictions(detection.numpy(), targets=targets) + for detection in detections ] diff --git a/src/batdetect2/postprocess/decoding.py b/src/batdetect2/postprocess/decoding.py index c1a283c..9180d34 100644 --- a/src/batdetect2/postprocess/decoding.py +++ b/src/batdetect2/postprocess/decoding.py @@ -6,13 +6,13 @@ import numpy as np from soundevent import data from batdetect2.typing.postprocess import ( - Detections, + DetectionsArray, RawPrediction, ) from batdetect2.typing.targets import TargetProtocol __all__ = [ - "convert_detections_to_raw_predictions", + "to_raw_predictions", "convert_raw_predictions_to_clip_prediction", "convert_raw_prediction_to_sound_event_prediction", "DEFAULT_CLASSIFICATION_THRESHOLD", @@ -27,19 +27,19 @@ decoding. 
""" -def convert_detections_to_raw_predictions( - detections: Detections, +def to_raw_predictions( + detections: DetectionsArray, targets: TargetProtocol, ) -> List[RawPrediction]: predictions = [] for score, class_scores, time, freq, dims, feats in zip( - detections.scores.cpu().numpy(), - detections.class_scores.cpu().numpy(), - detections.times.cpu().numpy(), - detections.frequencies.cpu().numpy(), - detections.sizes.cpu().numpy(), - detections.features.cpu().numpy(), + detections.scores, + detections.class_scores, + detections.times, + detections.frequencies, + detections.sizes, + detections.features, ): highest_scoring_class = targets.class_names[class_scores.argmax()] diff --git a/src/batdetect2/postprocess/extraction.py b/src/batdetect2/postprocess/extraction.py index f0635e3..29e981b 100644 --- a/src/batdetect2/postprocess/extraction.py +++ b/src/batdetect2/postprocess/extraction.py @@ -20,7 +20,10 @@ from typing import List, Optional, Tuple, Union import torch from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression -from batdetect2.typing.postprocess import Detections, ModelOutput +from batdetect2.typing.postprocess import ( + DetectionsTensor, + ModelOutput, +) __all__ = [ "extract_prediction_tensor", @@ -32,7 +35,7 @@ def extract_prediction_tensor( max_detections: int = 200, threshold: Optional[float] = None, nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE, -) -> List[Detections]: +) -> List[DetectionsTensor]: detection_heatmap = non_max_suppression( output.detection_probs.detach(), kernel_size=nms_kernel_size, @@ -78,7 +81,7 @@ def extract_prediction_tensor( class_scores = class_scores[mask] predictions.append( - Detections( + DetectionsTensor( scores=detection_scores, sizes=sizes, features=features, diff --git a/src/batdetect2/postprocess/remapping.py b/src/batdetect2/postprocess/remapping.py index 7eea516..0fc5db7 100644 --- a/src/batdetect2/postprocess/remapping.py +++ b/src/batdetect2/postprocess/remapping.py @@ -20,7 +20,7 @@ import xarray as xr from soundevent.arrays import Dimensions from batdetect2.preprocess import MAX_FREQ, MIN_FREQ -from batdetect2.typing.postprocess import Detections +from batdetect2.typing.postprocess import DetectionsTensor __all__ = [ "features_to_xarray", @@ -31,15 +31,15 @@ __all__ = [ def map_detection_to_clip( - detections: Detections, + detections: DetectionsTensor, start_time: float, end_time: float, min_freq: float, max_freq: float, -) -> Detections: +) -> DetectionsTensor: duration = end_time - start_time bandwidth = max_freq - min_freq - return Detections( + return DetectionsTensor( scores=detections.scores, sizes=detections.sizes, features=detections.features, diff --git a/src/batdetect2/targets/__init__.py b/src/batdetect2/targets/__init__.py index c976f71..2114bcf 100644 --- a/src/batdetect2/targets/__init__.py +++ b/src/batdetect2/targets/__init__.py @@ -21,7 +21,7 @@ configured processing steps. The main way to create a functional `Targets` object is via the `build_targets` or `load_targets` functions. 
""" -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple from loguru import logger from pydantic import Field @@ -675,3 +675,24 @@ def load_targets( term_registry=term_registry, derivation_registry=derivation_registry, ) + + +def iterate_encoded_sound_events( + sound_events: Iterable[data.SoundEventAnnotation], + targets: TargetProtocol, +) -> Iterable[Tuple[Optional[str], Position, Size]]: + for sound_event in sound_events: + if not targets.filter(sound_event): + continue + + geometry = sound_event.sound_event.geometry + + if geometry is None: + continue + + sound_event = targets.transform(sound_event) + + class_name = targets.encode_class(sound_event) + position, size = targets.encode_roi(sound_event) + + yield class_name, position, size diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 2195615..ae6636b 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -1,32 +1,26 @@ -import io -from typing import List, Optional, Tuple +from typing import List, Optional -import numpy as np from lightning import LightningModule, Trainer from lightning.pytorch.callbacks import Callback -from lightning.pytorch.loggers import Logger, TensorBoardLogger -from lightning.pytorch.loggers.mlflow import MLFlowLogger -from loguru import logger from soundevent import data from torch.utils.data import DataLoader from batdetect2.evaluate.match import ( MatchConfig, - match_sound_events_and_raw_predictions, + match_all_predictions, ) -from batdetect2.models import Model from batdetect2.plotting.evaluation import plot_example_gallery -from batdetect2.postprocess import get_sound_event_predictions -from batdetect2.train.dataset import TrainingDataset +from batdetect2.postprocess import get_raw_predictions +from batdetect2.train.dataset import ValidationDataset from batdetect2.train.lightning import TrainingModule +from batdetect2.train.logging import get_image_plotter from batdetect2.typing import ( - BatDetect2Prediction, MatchEvaluation, MetricsProtocol, - ModelOutput, - TargetProtocol, - TrainExample, ) +from batdetect2.typing.models import ModelOutput +from batdetect2.typing.postprocess import RawPrediction +from batdetect2.typing.train import TrainExample class ValidationMetrics(Callback): @@ -45,15 +39,14 @@ class ValidationMetrics(Callback): self.metrics = metrics self.plot = plot - self._matches: List[ - Tuple[data.ClipAnnotation, List[BatDetect2Prediction]] - ] = [] + self._clip_annotations: List[data.ClipAnnotation] = [] + self._predictions: List[List[RawPrediction]] = [] - def get_dataset(self, trainer: Trainer) -> TrainingDataset: + def get_dataset(self, trainer: Trainer) -> ValidationDataset: dataloaders = trainer.val_dataloaders assert isinstance(dataloaders, DataLoader) dataset = dataloaders.dataset - assert isinstance(dataset, TrainingDataset) + assert isinstance(dataset, ValidationDataset) return dataset def plot_examples( @@ -61,7 +54,7 @@ class ValidationMetrics(Callback): pl_module: LightningModule, matches: List[MatchEvaluation], ): - plotter = _get_image_plotter(pl_module.logger) # type: ignore + plotter = get_image_plotter(pl_module.logger) # type: ignore if plotter is None: return @@ -93,9 +86,10 @@ class ValidationMetrics(Callback): trainer: Trainer, pl_module: LightningModule, ) -> None: - matches = _match_all_collected_examples( - self._matches, - pl_module.model.targets, + matches = match_all_predictions( + self._clip_annotations, + self._predictions, + 
targets=pl_module.model.targets, config=self.match_config, ) @@ -123,133 +117,23 @@ class ValidationMetrics(Callback): batch_idx: int, dataloader_idx: int = 0, ) -> None: - self._matches.extend( - _get_batch_clips_and_predictions( - batch, - outputs, - dataset=self.get_dataset(trainer), - model=pl_module.model, - ) - ) + postprocessor = pl_module.model.postprocessor + targets = pl_module.model.targets + dataset = self.get_dataset(trainer) + clip_annotations = [ + dataset.clip_annotations[int(example_idx)] + for example_idx in batch.idx + ] -def _get_batch_clips_and_predictions( - batch: TrainExample, - outputs: ModelOutput, - dataset: TrainingDataset, - model: Model, -) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]: - clip_annotations = [ - _get_subclip( - dataset.clip_annotations[int(example_id)], - start_time=start_time.item(), - end_time=end_time.item(), - targets=model.targets, - ) - for example_id, start_time, end_time in zip( - batch.idx, - batch.start_time, - batch.end_time, - ) - ] - - clips = [clip_annotation.clip for clip_annotation in clip_annotations] - - raw_predictions = get_sound_event_predictions( - outputs, - clips, - targets=model.targets, - postprocessor=model.postprocessor - ) - - return [ - (clip_annotation, clip_predictions) - for clip_annotation, clip_predictions in zip( - clip_annotations, raw_predictions - ) - ] - - -def _match_all_collected_examples( - pre_matches: List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]], - targets: TargetProtocol, - config: Optional[MatchConfig] = None, -) -> List[MatchEvaluation]: - logger.info("Matching all annotations and predictions...") - return [ - match - for clip_annotation, raw_predictions in pre_matches - for match in match_sound_events_and_raw_predictions( - clip_annotation, - raw_predictions, + predictions = get_raw_predictions( + outputs, + clips=[ + clip_annotation.clip for clip_annotation in clip_annotations + ], targets=targets, - config=config, + postprocessor=postprocessor, ) - ] - -def _is_in_subclip( - sound_event_annotation: data.SoundEventAnnotation, - targets: TargetProtocol, - start_time: float, - end_time: float, -) -> bool: - (time, _), _ = targets.encode_roi(sound_event_annotation) - return start_time <= time <= end_time - - -def _get_subclip( - clip_annotation: data.ClipAnnotation, - start_time: float, - end_time: float, - targets: TargetProtocol, -) -> data.ClipAnnotation: - return data.ClipAnnotation( - clip=data.Clip( - recording=clip_annotation.clip.recording, - start_time=start_time, - end_time=end_time, - ), - sound_events=[ - sound_event_annotation - for sound_event_annotation in clip_annotation.sound_events - if _is_in_subclip( - sound_event_annotation, - targets, - start_time=start_time, - end_time=end_time, - ) - ], - ) - - -def _get_image_plotter(logger: Logger): - if isinstance(logger, TensorBoardLogger): - - def plot_figure(name, figure, step): - return logger.experiment.add_figure(name, figure, step) - - return plot_figure - - if isinstance(logger, MLFlowLogger): - - def plot_figure(name, figure, step): - image = _convert_figure_to_image(figure) - return logger.experiment.log_image( - run_id=logger.run_id, - image=image, - key=name, - step=step, - ) - - return plot_figure - - -def _convert_figure_to_image(figure): - with io.BytesIO() as buff: - figure.savefig(buff, format="raw") - buff.seek(0) - data = np.frombuffer(buff.getvalue(), dtype=np.uint8) - w, h = figure.canvas.get_width_height() - im = data.reshape((int(h), int(w), -1)) - return im + 
self._clip_annotations.extend(clip_annotations) + self._predictions.extend(predictions) diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index f7f62f5..4add71f 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -6,7 +6,10 @@ from torch.utils.data import Dataset from batdetect2.typing import ClipperProtocol, TrainExample from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol -from batdetect2.typing.train import Augmentation, ClipLabeller +from batdetect2.typing.train import ( + Augmentation, + ClipLabeller, +) __all__ = [ "TrainingDataset", @@ -75,3 +78,47 @@ class TrainingDataset(Dataset): start_time=torch.tensor(clip.start_time), end_time=torch.tensor(clip.end_time), ) + + +class ValidationDataset(Dataset): + def __init__( + self, + clip_annotations: Sequence[data.ClipAnnotation], + audio_loader: AudioLoader, + preprocessor: PreprocessorProtocol, + labeller: ClipLabeller, + audio_dir: Optional[data.PathLike] = None, + ): + self.clip_annotations = clip_annotations + self.labeller = labeller + self.preprocessor = preprocessor + self.audio_loader = audio_loader + self.audio_dir = audio_dir + + def __len__(self): + return len(self.clip_annotations) + + def __getitem__(self, idx) -> TrainExample: + clip_annotation = self.clip_annotations[idx] + clip = clip_annotation.clip + + wav = self.audio_loader.load_clip( + clip_annotation.clip, + audio_dir=self.audio_dir, + ) + + wav_tensor = torch.tensor(wav).unsqueeze(0) + + spectrogram = self.preprocessor(wav_tensor) + + heatmaps = self.labeller(clip_annotation, spectrogram) + + return TrainExample( + spec=spectrogram, + detection_heatmap=heatmaps.detection, + class_heatmap=heatmaps.classes, + size_heatmap=heatmaps.size, + idx=torch.tensor(idx), + start_time=torch.tensor(clip.start_time), + end_time=torch.tensor(clip.end_time), + ) diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py index dd42110..163e787 100644 --- a/src/batdetect2/train/labels.py +++ b/src/batdetect2/train/labels.py @@ -3,24 +3,6 @@ This module is responsible for creating the target labels used for training BatDetect2 models. It converts sound event annotations for an audio clip into the specific multi-channel heatmap formats required by the neural network. - -It uses a pre-configured object adhering to the `TargetProtocol` (from -`batdetect2.targets`) which encapsulates all the logic for filtering -annotations, transforming tags, encoding class names, and mapping annotation -geometry (ROIs) to target positions and sizes. This module then focuses on -rendering this information onto the heatmap grids. - -The pipeline generates three core outputs for a given spectrogram: -1. **Detection Heatmap**: Indicates presence/location of relevant sound events. -2. **Class Heatmap**: Indicates location and class identity for specifically - classified events. -3. **Size Heatmap**: Encodes the target dimensions (width, height) of events. - -The primary function generated by this module is a `ClipLabeller` (defined in -`.types`), which takes a `ClipAnnotation` object and its corresponding -spectrogram and returns the calculated `Heatmaps` tuple. The main configurable -parameter specific to this module is the Gaussian smoothing sigma (`sigma`) -defined in `LabelConfig`. 
""" from functools import partial @@ -32,6 +14,7 @@ from loguru import logger from soundevent import data from batdetect2.configs import BaseConfig, load_config +from batdetect2.targets import iterate_encoded_sound_events from batdetect2.typing import ( ClipLabeller, Heatmaps, @@ -56,9 +39,6 @@ class LabelConfig(BaseConfig): Attributes ---------- sigma : float, default=3.0 - The standard deviation (in pixels/bins) of the Gaussian kernel applied - to smooth the detection and class heatmaps. Larger values create more - diffuse targets. """ sigma: float = 2.0 @@ -70,28 +50,7 @@ def build_clip_labeler( max_freq: float, config: Optional[LabelConfig] = None, ) -> ClipLabeller: - """Construct the final clip labelling function. - - This factory function prepares the callable that will perform the - end-to-end heatmap generation for a given clip and spectrogram during - training data loading. It takes the fully configured `targets` object and - the `LabelConfig` and binds them to the `generate_clip_label` function. - - Parameters - ---------- - targets : TargetProtocol - An initialized object conforming to the `TargetProtocol`, providing all - necessary methods for filtering, transforming, encoding, and ROI - mapping. - config : LabelConfig - Configuration object containing heatmap generation parameters. - - Returns - ------- - ClipLabeller - A function that accepts a `data.ClipAnnotation` and `xr.DataArray` - (spectrogram) and returns the generated `Heatmaps`. - """ + """Construct the final clip labelling function.""" config = config or LabelConfig() logger.opt(lazy=True).debug( "Building clip labeler with config: \n{}", @@ -119,37 +78,10 @@ def generate_heatmaps( target_sigma: float = 3.0, dtype=torch.float32, ) -> Heatmaps: - """Generate training heatmaps for a single annotated clip. - - This function orchestrates the target generation process for one clip: - 1. Filters and transforms sound events using `targets.filter` and - `targets.transform`. - 2. Passes the resulting processed annotations, along with the spectrogram, - the `targets` object, and the Gaussian `sigma` from `config`, to the - core `generate_heatmaps` function. - - Parameters - ---------- - clip_annotation : data.ClipAnnotation - The complete annotation data for the audio clip, including the list - of `sound_events` to process. - spec : xr.DataArray - The spectrogram corresponding to the `clip_annotation`. Must have - 'time' and 'frequency' dimensions/coordinates. - targets : TargetProtocol - The fully configured target definition object, providing methods for - filtering, transformation, encoding, and ROI mapping. - config : LabelConfig - Configuration object providing heatmap parameters (primarily `sigma`). - - Returns - ------- - Heatmaps - A NamedTuple containing the generated 'detection', 'classes', and 'size' - heatmaps for this clip. - """ + """Generate training heatmaps for a single annotated clip.""" logger.debug( - "Will generate heatmaps for clip annotation {uuid} with {num} annotated sound events", + "Will generate heatmaps for clip annotation " + "{uuid} with {num} annotated sound events", uuid=clip_annotation.uuid, num=len(clip_annotation.sound_events), ) @@ -174,28 +106,10 @@ def generate_heatmaps( freqs = freqs.to(spec.device) times = times.to(spec.device) - for sound_event_annotation in clip_annotation.sound_events: - if not targets.filter(sound_event_annotation): - logger.debug( - "Sound event {sound_event} did not pass the filter. 
Tags: {tags}", - sound_event=sound_event_annotation, - tags=sound_event_annotation.tags, - ) - continue - - sound_event_annotation = targets.transform(sound_event_annotation) - - geom = sound_event_annotation.sound_event.geometry - if geom is None: - logger.debug( - "Skipping annotation %s: missing geometry.", - sound_event_annotation.uuid, - ) - continue - - # Get the position of the sound event - (time, frequency), size = targets.encode_roi(sound_event_annotation) - + for class_name, (time, frequency), size in iterate_encoded_sound_events( + clip_annotation.sound_events, + targets, + ): time_index = map_to_pixels(time, width, clip.start_time, clip.end_time) freq_index = map_to_pixels(frequency, height, min_freq, max_freq) @@ -206,9 +120,7 @@ def generate_heatmaps( or freq_index >= height ): logger.debug( - "Skipping annotation %s: position outside spectrogram. " - "Pos: %s", - sound_event_annotation.uuid, + "Skipping annotation: position outside spectrogram. Pos: %s", (time, frequency), ) continue @@ -222,20 +134,8 @@ def generate_heatmaps( ) size_heatmap[:, freq_index, time_index] = torch.tensor(size[:]) - # Get the class name of the sound event - try: - class_name = targets.encode_class(sound_event_annotation) - except ValueError as e: - logger.warning( - "Skipping annotation %s: Unexpected error while encoding " - "class name %s", - sound_event_annotation.uuid, - e, - ) - continue - + # If the label is None skip the sound event if class_name is None: - # If the label is None skip the sound event continue class_index = targets.class_names.index(class_name) diff --git a/src/batdetect2/train/logging.py b/src/batdetect2/train/logging.py index c4093c1..3684e01 100644 --- a/src/batdetect2/train/logging.py +++ b/src/batdetect2/train/logging.py @@ -1,6 +1,8 @@ +import io from typing import Annotated, Any, Literal, Optional, Union -from lightning.pytorch.loggers import Logger +import numpy as np +from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger from loguru import logger from pydantic import Field @@ -140,3 +142,35 @@ def build_logger(config: LoggerConfig) -> Logger: creation_func = LOGGER_FACTORY[logger_type] return creation_func(config) + + +def get_image_plotter(logger: Logger): + if isinstance(logger, TensorBoardLogger): + + def plot_figure(name, figure, step): + return logger.experiment.add_figure(name, figure, step) + + return plot_figure + + if isinstance(logger, MLFlowLogger): + + def plot_figure(name, figure, step): + image = _convert_figure_to_image(figure) + return logger.experiment.log_image( + run_id=logger.run_id, + image=image, + key=name, + step=step, + ) + + return plot_figure + + +def _convert_figure_to_image(figure): + with io.BytesIO() as buff: + figure.savefig(buff, format="raw") + buff.seek(0) + data = np.frombuffer(buff.getvalue(), dtype=np.uint8) + w, h = figure.canvas.get_width_height() + im = data.reshape((int(h), int(w), -1)) + return im diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 0178333..b1b05cc 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -24,9 +24,7 @@ from batdetect2.train.augmentations import ( from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.clips import build_clipper from batdetect2.train.config import FullTrainingConfig, TrainingConfig -from batdetect2.train.dataset import ( - TrainingDataset, -) +from batdetect2.train.dataset import TrainingDataset, ValidationDataset from batdetect2.train.labels import build_clip_labeler from 
batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
@@ -304,11 +302,11 @@
 def build_val_dataset(
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
-) -> TrainingDataset:
+) -> ValidationDataset:
     logger.info("Building validation dataset...")
     config = config or TrainingConfig()
-    return TrainingDataset(
+    return ValidationDataset(
         clip_annotations,
         audio_loader=audio_loader,
         labeller=labeller,
diff --git a/src/batdetect2/typing/evaluate.py b/src/batdetect2/typing/evaluate.py
index 2ef9206..06c53e2 100644
--- a/src/batdetect2/typing/evaluate.py
+++ b/src/batdetect2/typing/evaluate.py
@@ -11,13 +11,17 @@ __all__ = [
 
 @dataclass
 class MatchEvaluation:
-    match: data.Match
+    clip: data.Clip
+    sound_event_annotation: Optional[data.SoundEventAnnotation]
 
     gt_det: bool
     gt_class: Optional[str]
 
     pred_score: float
     pred_class_scores: Dict[str, float]
+    pred_geometry: Optional[data.Geometry]
+
+    affinity: float
 
     @property
     def pred_class(self) -> Optional[str]:
diff --git a/src/batdetect2/typing/postprocess.py b/src/batdetect2/typing/postprocess.py
index e876c3d..42afb4a 100644
--- a/src/batdetect2/typing/postprocess.py
+++ b/src/batdetect2/typing/postprocess.py
@@ -77,7 +77,16 @@ class RawPrediction(NamedTuple):
     features: np.ndarray
 
 
-class Detections(NamedTuple):
+class DetectionsArray(NamedTuple):
+    scores: np.ndarray
+    sizes: np.ndarray
+    class_scores: np.ndarray
+    times: np.ndarray
+    frequencies: np.ndarray
+    features: np.ndarray
+
+
+class DetectionsTensor(NamedTuple):
     scores: torch.Tensor
     sizes: torch.Tensor
     class_scores: torch.Tensor
@@ -85,6 +94,16 @@
     frequencies: torch.Tensor
     features: torch.Tensor
 
+    def numpy(self) -> DetectionsArray:
+        return DetectionsArray(
+            scores=self.scores.detach().cpu().numpy(),
+            sizes=self.sizes.detach().cpu().numpy(),
+            class_scores=self.class_scores.detach().cpu().numpy(),
+            times=self.times.detach().cpu().numpy(),
+            frequencies=self.frequencies.detach().cpu().numpy(),
+            features=self.features.detach().cpu().numpy(),
+        )
+
 
 @dataclass
 class BatDetect2Prediction:
@@ -95,10 +114,10 @@
 class PostprocessorProtocol(Protocol):
     """Protocol defining the interface for the full postprocessing pipeline."""
 
-    def __call__(self, output: ModelOutput) -> List[Detections]: ...
+    def __call__(self, output: ModelOutput) -> List[DetectionsTensor]: ...
 
     def get_detections(
         self,
         output: ModelOutput,
         clips: Optional[List[data.Clip]] = None,
-    ) -> List[Detections]: ...
+    ) -> List[DetectionsTensor]: ...
diff --git a/src/batdetect2/typing/targets.py b/src/batdetect2/typing/targets.py
index db74baf..2846a0e 100644
--- a/src/batdetect2/typing/targets.py
+++ b/src/batdetect2/typing/targets.py
@@ -12,8 +12,8 @@ that components responsible for these tasks can be interacted
 with consistently throughout BatDetect2.
 """
 
-from collections.abc import Callable, Iterable
-from typing import List, Optional, Protocol, Tuple
+from collections.abc import Callable
+from typing import List, Optional, Protocol
 
 import numpy as np
 from soundevent import data
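
Usage sketch (illustrative only, inferred from the signatures visible in this
patch; `model`, `wav`, and `clip_annotations` are hypothetical stand-ins for a
built `Model`, an audio tensor, and a list of `data.ClipAnnotation`):

    from batdetect2.evaluate.match import match_all_predictions
    from batdetect2.postprocess import get_raw_predictions

    # The detector emits a ModelOutput; the postprocessor reduces it to one
    # DetectionsTensor per clip, and DetectionsTensor.numpy() bridges to the
    # numpy-backed DetectionsArray that to_raw_predictions() decodes.
    output = model.detector(model.preprocessor(wav))

    predictions = get_raw_predictions(
        output,
        clips=[annotation.clip for annotation in clip_annotations],
        targets=model.targets,
        postprocessor=model.postprocessor,
    )

    # One flat list of MatchEvaluation across all clips; each entry carries
    # the clip, the matched annotation, the predicted geometry, and the
    # affinity directly instead of wrapping a soundevent data.Match.
    matches = match_all_predictions(
        clip_annotations,
        predictions,
        targets=model.targets,
    )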
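
The labelling side follows the same pattern (sketch under the same caveats;
`targets` is any TargetProtocol implementation and `clip_annotation` a
data.ClipAnnotation): iterate_encoded_sound_events centralises the
filter -> transform -> encode steps that generate_heatmaps previously
inlined, yielding ready-to-render triples:

    from batdetect2.targets import iterate_encoded_sound_events

    for class_name, (time, frequency), size in iterate_encoded_sound_events(
        clip_annotation.sound_events,
        targets,
    ):
        # class_name is None when an event passes the filter but maps to no
        # class; it still contributes to the detection and size heatmaps.
        print(class_name, time, frequency, size)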