From 51d0a49da9d8ef9381cff46911fe6c01cf16766a Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 12 Aug 2025 17:47:17 +0100 Subject: [PATCH] Improve performance of postprocessing code --- src/batdetect2/postprocess/__init__.py | 10 ++-- src/batdetect2/postprocess/decoding.py | 71 +++++++++++++------------- src/batdetect2/postprocess/types.py | 5 +- src/batdetect2/train/callbacks.py | 65 +++++++++++++++-------- 4 files changed, 87 insertions(+), 64 deletions(-) diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py index eb7b19d..8dcc53b 100644 --- a/src/batdetect2/postprocess/__init__.py +++ b/src/batdetect2/postprocess/__init__.py @@ -543,7 +543,9 @@ class Postprocessor(PostprocessorProtocol): ] def get_sound_event_predictions( - self, output: ModelOutput, clips: List[data.Clip] + self, + output: ModelOutput, + clips: List[data.Clip], ) -> List[List[BatDetect2Prediction]]: raw_predictions = self.get_raw_predictions(output, clips) return [ @@ -553,8 +555,7 @@ class Postprocessor(PostprocessorProtocol): sound_event_prediction=convert_raw_prediction_to_sound_event_prediction( raw, recording=clip.recording, - sound_event_decoder=self.targets.decode_class, - generic_class_tags=self.targets.generic_class_tags, + targets=self.targets, classification_threshold=self.config.classification_threshold, ), ) @@ -590,8 +591,7 @@ class Postprocessor(PostprocessorProtocol): convert_raw_predictions_to_clip_prediction( prediction, clip, - sound_event_decoder=self.targets.decode_class, - generic_class_tags=self.targets.generic_class_tags, + targets=self.targets, classification_threshold=self.config.classification_threshold, ) for prediction, clip in zip(raw_predictions, clips) diff --git a/src/batdetect2/postprocess/decoding.py b/src/batdetect2/postprocess/decoding.py index 5237d6c..599f34c 100644 --- a/src/batdetect2/postprocess/decoding.py +++ b/src/batdetect2/postprocess/decoding.py @@ -33,8 +33,7 @@ import xarray as xr from soundevent import data from batdetect2.postprocess.types import GeometryDecoder, RawPrediction -from batdetect2.targets.classes import SoundEventDecoder -from batdetect2.utils.arrays import iterate_over_array +from batdetect2.targets.types import TargetProtocol __all__ = [ "convert_xr_dataset_to_raw_prediction", @@ -92,25 +91,30 @@ def convert_xr_dataset_to_raw_prediction( """ detections = [] - for det_num in range(detection_dataset.sizes["detection"]): - det_info = detection_dataset.sel(detection=det_num) + categories = detection_dataset.category.values - highest_scoring_class = det_info.coords["category"][ - det_info["classes"].argmax() - ].item() + for score, class_scores, time, freq, dims, feats in zip( + detection_dataset["scores"].values, + detection_dataset["classes"].values, + detection_dataset["time"].values, + detection_dataset["frequency"].values, + detection_dataset["dimensions"].values, + detection_dataset["features"].values, + ): + highest_scoring_class = categories[class_scores.argmax()] geom = geometry_decoder( - (det_info.time, det_info.frequency), - det_info.dimensions, + (time, freq), + dims, class_name=highest_scoring_class, ) detections.append( RawPrediction( - detection_score=det_info.scores, + detection_score=score, geometry=geom, - class_scores=det_info.classes, - features=det_info.features, + class_scores=class_scores, + features=feats, ) ) @@ -120,8 +124,7 @@ def convert_xr_dataset_to_raw_prediction( def convert_raw_predictions_to_clip_prediction( raw_predictions: List[RawPrediction], clip: data.Clip, - sound_event_decoder: SoundEventDecoder, - generic_class_tags: List[data.Tag], + targets: TargetProtocol, classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, top_class_only: bool = False, ) -> data.ClipPrediction: @@ -160,8 +163,7 @@ def convert_raw_predictions_to_clip_prediction( convert_raw_prediction_to_sound_event_prediction( prediction, recording=clip.recording, - sound_event_decoder=sound_event_decoder, - generic_class_tags=generic_class_tags, + targets=targets, classification_threshold=classification_threshold, top_class_only=top_class_only, ) @@ -173,8 +175,7 @@ def convert_raw_predictions_to_clip_prediction( def convert_raw_prediction_to_sound_event_prediction( raw_prediction: RawPrediction, recording: data.Recording, - sound_event_decoder: SoundEventDecoder, - generic_class_tags: List[data.Tag], + targets: TargetProtocol, classification_threshold: Optional[ float ] = DEFAULT_CLASSIFICATION_THRESHOLD, @@ -251,11 +252,11 @@ def convert_raw_prediction_to_sound_event_prediction( tags = [ *get_generic_tags( raw_prediction.detection_score, - generic_class_tags=generic_class_tags, + generic_class_tags=targets.generic_class_tags, ), *get_class_tags( raw_prediction.class_scores, - sound_event_decoder, + targets=targets, top_class_only=top_class_only, threshold=classification_threshold, ), @@ -297,7 +298,7 @@ def get_generic_tags( ] -def get_prediction_features(features: xr.DataArray) -> List[data.Feature]: +def get_prediction_features(features: np.ndarray) -> List[data.Feature]: """Convert an extracted feature vector DataArray into soundevent Features. Parameters @@ -320,19 +321,19 @@ def get_prediction_features(features: xr.DataArray) -> List[data.Feature]: return [ data.Feature( term=data.Term( - name=f"batdetect2:{feat_name}", - label=feat_name, + name=f"batdetect2:f{index}", + label=f"BatDetect Feature {index}", definition="Automatically extracted features by BatDetect2", ), value=value, ) - for feat_name, value in iterate_over_array(features) + for index, value in enumerate(features) ] def get_class_tags( - class_scores: xr.DataArray, - sound_event_decoder: SoundEventDecoder, + class_scores: np.ndarray, + targets: TargetProtocol, top_class_only: bool = False, threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD, ) -> List[data.PredictedTag]: @@ -367,11 +368,13 @@ def get_class_tags( """ tags = [] - if threshold is not None: - class_scores = class_scores.where(class_scores > threshold, drop=True) + for class_name, score in _iterate_sorted( + class_scores, targets.class_names + ): + if threshold is not None and score < threshold: + continue - for class_name, score in _iterate_sorted(class_scores): - class_tags = sound_event_decoder(class_name) + class_tags = targets.decode_class(class_name) for tag in class_tags: tags.append( @@ -387,9 +390,7 @@ def get_class_tags( return tags -def _iterate_sorted(array: xr.DataArray): - dim_name = array.dims[0] - coords = array.coords[dim_name].values - indices = np.argsort(-array.values) +def _iterate_sorted(array: np.ndarray, class_names: List[str]): + indices = np.argsort(-array) for index in indices: - yield str(coords[index]), float(array.values[index]) + yield str(class_names[index]), float(array[index]) diff --git a/src/batdetect2/postprocess/types.py b/src/batdetect2/postprocess/types.py index a15e4d8..533b3e1 100644 --- a/src/batdetect2/postprocess/types.py +++ b/src/batdetect2/postprocess/types.py @@ -14,6 +14,7 @@ system that deal with model predictions. from dataclasses import dataclass from typing import List, NamedTuple, Optional, Protocol +import numpy as np import xarray as xr from soundevent import data @@ -72,8 +73,8 @@ class RawPrediction(NamedTuple): geometry: data.Geometry detection_score: float - class_scores: xr.DataArray - features: xr.DataArray + class_scores: np.ndarray + features: np.ndarray @dataclass diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index e9f3730..aeeb819 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -15,7 +15,10 @@ from batdetect2.evaluate.match import ( ) from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol from batdetect2.plotting.evaluation import plot_example_gallery -from batdetect2.postprocess.types import BatDetect2Prediction +from batdetect2.postprocess.types import ( + BatDetect2Prediction, + PostprocessorProtocol, +) from batdetect2.targets.types import TargetProtocol from batdetect2.train.dataset import LabeledDataset, TrainExample from batdetect2.train.lightning import TrainingModule @@ -114,33 +117,51 @@ class ValidationMetrics(Callback): batch_idx: int, dataloader_idx: int = 0, ) -> None: - dataset = self.get_dataset(trainer) - - clip_annotations = [ - _get_subclip( - dataset.get_clip_annotation(example_id), - start_time=start_time.item(), - end_time=end_time.item(), + self._matches.extend( + _get_batch_clips_and_predictions( + batch, + outputs, + dataset=self.get_dataset(trainer), + postprocessor=pl_module.postprocessor, targets=pl_module.targets, ) - for example_id, start_time, end_time in zip( - batch.idx, - batch.start_time, - batch.end_time, - ) - ] - - clips = [clip_annotation.clip for clip_annotation in clip_annotations] - - raw_predictions = pl_module.postprocessor.get_sound_event_predictions( - outputs, - clips, ) + +def _get_batch_clips_and_predictions( + batch: TrainExample, + outputs: ModelOutput, + dataset: LabeledDataset, + postprocessor: PostprocessorProtocol, + targets: TargetProtocol, +) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]: + clip_annotations = [ + _get_subclip( + dataset.get_clip_annotation(example_id), + start_time=start_time.item(), + end_time=end_time.item(), + targets=targets, + ) + for example_id, start_time, end_time in zip( + batch.idx, + batch.start_time, + batch.end_time, + ) + ] + + clips = [clip_annotation.clip for clip_annotation in clip_annotations] + + raw_predictions = postprocessor.get_sound_event_predictions( + outputs, + clips, + ) + + return [ + (clip_annotation, clip_predictions) for clip_annotation, clip_predictions in zip( clip_annotations, raw_predictions - ): - self._matches.append((clip_annotation, clip_predictions)) + ) + ] def _match_all_collected_examples(