Improve performance of postprocessing code

mbsantiago 2025-08-12 17:47:17 +01:00
parent b997a122f1
commit 51d0a49da9
4 changed files with 87 additions and 64 deletions
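The main change: raw-prediction decoding no longer calls xarray's `.sel()` once per detection, instead pulling each variable's backing numpy array out of the dataset a single time and iterating with `zip`. In support of that, `RawPrediction` stores `np.ndarray` fields rather than `xr.DataArray`, the tag-decoding helpers take the whole `targets: TargetProtocol` object in place of separate decoder and tag arguments, and the validation-metrics callback is refactored around a standalone helper.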

View File

@@ -543,7 +543,9 @@ class Postprocessor(PostprocessorProtocol):
         ]

     def get_sound_event_predictions(
-        self, output: ModelOutput, clips: List[data.Clip]
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
     ) -> List[List[BatDetect2Prediction]]:
         raw_predictions = self.get_raw_predictions(output, clips)
         return [
@@ -553,8 +555,7 @@ class Postprocessor(PostprocessorProtocol):
                     sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
                         raw,
                         recording=clip.recording,
-                        sound_event_decoder=self.targets.decode_class,
-                        generic_class_tags=self.targets.generic_class_tags,
+                        targets=self.targets,
                         classification_threshold=self.config.classification_threshold,
                     ),
                 )
@@ -590,8 +591,7 @@ class Postprocessor(PostprocessorProtocol):
             convert_raw_predictions_to_clip_prediction(
                 prediction,
                 clip,
-                sound_event_decoder=self.targets.decode_class,
-                generic_class_tags=self.targets.generic_class_tags,
+                targets=self.targets,
                 classification_threshold=self.config.classification_threshold,
             )
             for prediction, clip in zip(raw_predictions, clips)
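A note on the signature change running through this commit: the conversion helpers now receive the whole targets object instead of a separate `sound_event_decoder` callable and `generic_class_tags` list. Judging only from the usages visible in these diffs, the slice of `TargetProtocol` the postprocessing code relies on looks roughly like the sketch below (the member names are the ones actually used here; the return type of `decode_class` is an assumption):

    from typing import List, Protocol

    from soundevent import data


    class TargetProtocol(Protocol):
        """Members of the targets object touched by this commit."""

        class_names: List[str]               # indexed by _iterate_sorted
        generic_class_tags: List[data.Tag]   # forwarded to get_generic_tags

        def decode_class(self, class_name: str) -> List[data.Tag]:
            """Map a predicted class name back to soundevent tags."""
            ...

Passing one object keeps the three call sites in sync and avoids unpacking `decode_class` and `generic_class_tags` separately at every level.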

View File

@@ -33,8 +33,7 @@ import xarray as xr
 from soundevent import data

 from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
-from batdetect2.targets.classes import SoundEventDecoder
-from batdetect2.utils.arrays import iterate_over_array
+from batdetect2.targets.types import TargetProtocol

 __all__ = [
     "convert_xr_dataset_to_raw_prediction",
@@ -92,25 +91,30 @@ def convert_xr_dataset_to_raw_prediction(
     """
     detections = []

-    for det_num in range(detection_dataset.sizes["detection"]):
-        det_info = detection_dataset.sel(detection=det_num)
-
-        highest_scoring_class = det_info.coords["category"][
-            det_info["classes"].argmax()
-        ].item()
+    categories = detection_dataset.category.values
+
+    for score, class_scores, time, freq, dims, feats in zip(
+        detection_dataset["scores"].values,
+        detection_dataset["classes"].values,
+        detection_dataset["time"].values,
+        detection_dataset["frequency"].values,
+        detection_dataset["dimensions"].values,
+        detection_dataset["features"].values,
+    ):
+        highest_scoring_class = categories[class_scores.argmax()]

         geom = geometry_decoder(
-            (det_info.time, det_info.frequency),
-            det_info.dimensions,
+            (time, freq),
+            dims,
             class_name=highest_scoring_class,
         )

         detections.append(
             RawPrediction(
-                detection_score=det_info.scores,
+                detection_score=score,
                 geometry=geom,
-                class_scores=det_info.classes,
-                features=det_info.features,
+                class_scores=class_scores,
+                features=feats,
             )
         )
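This hunk is the heart of the performance fix. `Dataset.sel()` constructs a new labelled view per detection, which dominates runtime once detections number in the hundreds; extracting each variable's backing numpy array once and zipping turns the loop into plain tuple unpacking. A self-contained illustration of the effect (synthetic data; variable and coordinate names mirror the diff, sizes and values are made up):

    import timeit

    import numpy as np
    import xarray as xr

    # Stand-in for the detection dataset; a "detection" coordinate is
    # assumed here so that .sel() works on the synthetic data.
    n_det, n_cls = 2_000, 20
    ds = xr.Dataset(
        {
            "scores": ("detection", np.random.rand(n_det)),
            "classes": (("detection", "category"), np.random.rand(n_det, n_cls)),
        },
        coords={
            "detection": np.arange(n_det),
            "category": [f"class_{i}" for i in range(n_cls)],
        },
    )

    def per_detection():
        # Old approach: one labelled .sel() per detection.
        return [
            ds.sel(detection=i)["classes"].argmax().item()
            for i in range(ds.sizes["detection"])
        ]

    def bulk():
        # New approach: pull the backing numpy array out once.
        return [int(row.argmax()) for row in ds["classes"].values]

    assert per_detection() == bulk()
    print("per-detection:", timeit.timeit(per_detection, number=1))
    print("bulk:         ", timeit.timeit(bulk, number=1))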
@@ -120,8 +124,7 @@ def convert_xr_dataset_to_raw_prediction(
 def convert_raw_predictions_to_clip_prediction(
     raw_predictions: List[RawPrediction],
     clip: data.Clip,
-    sound_event_decoder: SoundEventDecoder,
-    generic_class_tags: List[data.Tag],
+    targets: TargetProtocol,
     classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
     top_class_only: bool = False,
 ) -> data.ClipPrediction:
@@ -160,8 +163,7 @@ def convert_raw_predictions_to_clip_prediction(
             convert_raw_prediction_to_sound_event_prediction(
                 prediction,
                 recording=clip.recording,
-                sound_event_decoder=sound_event_decoder,
-                generic_class_tags=generic_class_tags,
+                targets=targets,
                 classification_threshold=classification_threshold,
                 top_class_only=top_class_only,
             )
@@ -173,8 +175,7 @@ def convert_raw_predictions_to_clip_prediction(
 def convert_raw_prediction_to_sound_event_prediction(
     raw_prediction: RawPrediction,
     recording: data.Recording,
-    sound_event_decoder: SoundEventDecoder,
-    generic_class_tags: List[data.Tag],
+    targets: TargetProtocol,
     classification_threshold: Optional[
         float
     ] = DEFAULT_CLASSIFICATION_THRESHOLD,
@@ -251,11 +252,11 @@ def convert_raw_prediction_to_sound_event_prediction(
     tags = [
         *get_generic_tags(
             raw_prediction.detection_score,
-            generic_class_tags=generic_class_tags,
+            generic_class_tags=targets.generic_class_tags,
         ),
         *get_class_tags(
             raw_prediction.class_scores,
-            sound_event_decoder,
+            targets=targets,
             top_class_only=top_class_only,
             threshold=classification_threshold,
         ),
@@ -297,7 +298,7 @@ def get_generic_tags(
     ]


-def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
+def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
     """Convert an extracted feature vector DataArray into soundevent Features.

     Parameters
@@ -320,19 +321,19 @@ def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
     return [
         data.Feature(
             term=data.Term(
-                name=f"batdetect2:{feat_name}",
-                label=feat_name,
+                name=f"batdetect2:f{index}",
+                label=f"BatDetect Feature {index}",
                 definition="Automatically extracted features by BatDetect2",
             ),
             value=value,
         )
-        for feat_name, value in iterate_over_array(features)
+        for index, value in enumerate(features)
     ]


 def get_class_tags(
-    class_scores: xr.DataArray,
-    sound_event_decoder: SoundEventDecoder,
+    class_scores: np.ndarray,
+    targets: TargetProtocol,
     top_class_only: bool = False,
     threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
 ) -> List[data.PredictedTag]:
@@ -367,11 +368,13 @@ def get_class_tags(
     """
     tags = []

-    if threshold is not None:
-        class_scores = class_scores.where(class_scores > threshold, drop=True)
-
-    for class_name, score in _iterate_sorted(class_scores):
-        class_tags = sound_event_decoder(class_name)
+    for class_name, score in _iterate_sorted(
+        class_scores, targets.class_names
+    ):
+        if threshold is not None and score < threshold:
+            continue
+
+        class_tags = targets.decode_class(class_name)

         for tag in class_tags:
             tags.append(
@@ -387,9 +390,7 @@ def get_class_tags(
     return tags


-def _iterate_sorted(array: xr.DataArray):
-    dim_name = array.dims[0]
-    coords = array.coords[dim_name].values
-    indices = np.argsort(-array.values)
+def _iterate_sorted(array: np.ndarray, class_names: List[str]):
+    indices = np.argsort(-array)
     for index in indices:
-        yield str(coords[index]), float(array.values[index])
+        yield str(class_names[index]), float(array[index])
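Two behavioural footnotes on this file, beyond the raw speed-up. First, `get_prediction_features` now names features positionally (`batdetect2:f0`, `batdetect2:f1`, ...) instead of using coordinate labels, so any downstream consumer keyed on the old term names will see different identifiers. Second, the classification threshold moved from a `where(..., drop=True)` pre-filter to an in-loop check, and the comparison flipped from strictly-greater to `score < threshold: continue`, so scores exactly equal to the threshold are now kept; likely immaterial for float scores, but worth knowing. A behaviour sketch of the rewritten `_iterate_sorted` (illustrative names):

    import numpy as np

    scores = np.array([0.10, 0.70, 0.20])
    class_names = ["class_a", "class_b", "class_c"]

    # np.argsort(-scores) yields indices ordered by descending score,
    # exactly what the rewritten helper iterates over.
    for index in np.argsort(-scores):
        print(class_names[index], float(scores[index]))
    # class_b 0.7
    # class_c 0.2
    # class_a 0.1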

View File

@@ -14,6 +14,7 @@ system that deal with model predictions.
 from dataclasses import dataclass
 from typing import List, NamedTuple, Optional, Protocol

+import numpy as np
 import xarray as xr
 from soundevent import data

@@ -72,8 +73,8 @@ class RawPrediction(NamedTuple):

     geometry: data.Geometry
     detection_score: float
-    class_scores: xr.DataArray
-    features: xr.DataArray
+    class_scores: np.ndarray
+    features: np.ndarray


 @dataclass
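This is the companion change to the decoding module: `RawPrediction` now carries plain `np.ndarray` fields, so assembling a prediction no longer requires constructing per-detection `xr.DataArray` objects with their index machinery. The `xarray` import is kept, presumably for the other prediction types in this module.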
@dataclass @dataclass

View File

@@ -15,7 +15,10 @@ from batdetect2.evaluate.match import (
 )
 from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
 from batdetect2.plotting.evaluation import plot_example_gallery
-from batdetect2.postprocess.types import BatDetect2Prediction
+from batdetect2.postprocess.types import (
+    BatDetect2Prediction,
+    PostprocessorProtocol,
+)
 from batdetect2.targets.types import TargetProtocol
 from batdetect2.train.dataset import LabeledDataset, TrainExample
 from batdetect2.train.lightning import TrainingModule
@@ -114,33 +117,51 @@ class ValidationMetrics(Callback):
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        dataset = self.get_dataset(trainer)
-
-        clip_annotations = [
-            _get_subclip(
-                dataset.get_clip_annotation(example_id),
-                start_time=start_time.item(),
-                end_time=end_time.item(),
-                targets=pl_module.targets,
-            )
-            for example_id, start_time, end_time in zip(
-                batch.idx,
-                batch.start_time,
-                batch.end_time,
-            )
-        ]
-
-        clips = [clip_annotation.clip for clip_annotation in clip_annotations]
-
-        raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
-            outputs,
-            clips,
-        )
-
-        for clip_annotation, clip_predictions in zip(
-            clip_annotations, raw_predictions
-        ):
-            self._matches.append((clip_annotation, clip_predictions))
+        self._matches.extend(
+            _get_batch_clips_and_predictions(
+                batch,
+                outputs,
+                dataset=self.get_dataset(trainer),
+                postprocessor=pl_module.postprocessor,
+                targets=pl_module.targets,
+            )
+        )
+
+
+def _get_batch_clips_and_predictions(
+    batch: TrainExample,
+    outputs: ModelOutput,
+    dataset: LabeledDataset,
+    postprocessor: PostprocessorProtocol,
+    targets: TargetProtocol,
+) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
+    clip_annotations = [
+        _get_subclip(
+            dataset.get_clip_annotation(example_id),
+            start_time=start_time.item(),
+            end_time=end_time.item(),
+            targets=targets,
+        )
+        for example_id, start_time, end_time in zip(
+            batch.idx,
+            batch.start_time,
+            batch.end_time,
+        )
+    ]
+
+    clips = [clip_annotation.clip for clip_annotation in clip_annotations]
+
+    raw_predictions = postprocessor.get_sound_event_predictions(
+        outputs,
+        clips,
+    )
+
+    return [
+        (clip_annotation, clip_predictions)
+        for clip_annotation, clip_predictions in zip(
+            clip_annotations, raw_predictions
+        )
+    ]


 def _match_all_collected_examples(
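The callback change is a refactor rather than an optimisation: `on_validation_batch_end` now delegates to the module-level `_get_batch_clips_and_predictions`, which takes `dataset`, `postprocessor`, and `targets` explicitly and can therefore be exercised without a `Trainer` or `LightningModule`. The new return annotation implies the file's `typing` import includes `Tuple` and that `ModelOutput` is imported, neither of which is visible in these hunks.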