Improve performance of postprocessing code

mbsantiago 2025-08-12 17:47:17 +01:00
parent b997a122f1
commit 51d0a49da9
4 changed files with 87 additions and 64 deletions

View File

@@ -543,7 +543,9 @@ class Postprocessor(PostprocessorProtocol):
         ]

     def get_sound_event_predictions(
-        self, output: ModelOutput, clips: List[data.Clip]
+        self,
+        output: ModelOutput,
+        clips: List[data.Clip],
     ) -> List[List[BatDetect2Prediction]]:
         raw_predictions = self.get_raw_predictions(output, clips)
         return [
@@ -553,8 +555,7 @@ class Postprocessor(PostprocessorProtocol):
                 sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
                     raw,
                     recording=clip.recording,
-                    sound_event_decoder=self.targets.decode_class,
-                    generic_class_tags=self.targets.generic_class_tags,
+                    targets=self.targets,
                     classification_threshold=self.config.classification_threshold,
                 ),
             )
@@ -590,8 +591,7 @@ class Postprocessor(PostprocessorProtocol):
             convert_raw_predictions_to_clip_prediction(
                 prediction,
                 clip,
-                sound_event_decoder=self.targets.decode_class,
-                generic_class_tags=self.targets.generic_class_tags,
+                targets=self.targets,
                 classification_threshold=self.config.classification_threshold,
             )
             for prediction, clip in zip(raw_predictions, clips)
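
Note: the three hunks in this file collapse the `sound_event_decoder` / `generic_class_tags` argument pair into the single `targets` object. A minimal sketch of the protocol surface these call sites now rely on, inferred from the usages in this commit rather than copied from batdetect2/targets/types.py:

    from typing import List, Protocol

    from soundevent import data


    class TargetProtocol(Protocol):
        # Only the attributes and method the postprocessing code touches.
        class_names: List[str]
        generic_class_tags: List[data.Tag]

        def decode_class(self, class_name: str) -> List[data.Tag]: ...

Passing the whole object keeps these call sites stable if the target definition grows more fields.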

View File

@@ -33,8 +33,7 @@ import xarray as xr
 from soundevent import data

 from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
-from batdetect2.targets.classes import SoundEventDecoder
-from batdetect2.utils.arrays import iterate_over_array
+from batdetect2.targets.types import TargetProtocol

 __all__ = [
     "convert_xr_dataset_to_raw_prediction",
@@ -92,25 +91,30 @@ def convert_xr_dataset_to_raw_prediction(
     """
     detections = []

-    for det_num in range(detection_dataset.sizes["detection"]):
-        det_info = detection_dataset.sel(detection=det_num)
+    categories = detection_dataset.category.values

-        highest_scoring_class = det_info.coords["category"][
-            det_info["classes"].argmax()
-        ].item()
+    for score, class_scores, time, freq, dims, feats in zip(
+        detection_dataset["scores"].values,
+        detection_dataset["classes"].values,
+        detection_dataset["time"].values,
+        detection_dataset["frequency"].values,
+        detection_dataset["dimensions"].values,
+        detection_dataset["features"].values,
+    ):
+        highest_scoring_class = categories[class_scores.argmax()]

         geom = geometry_decoder(
-            (det_info.time, det_info.frequency),
-            det_info.dimensions,
+            (time, freq),
+            dims,
             class_name=highest_scoring_class,
         )

         detections.append(
             RawPrediction(
-                detection_score=det_info.scores,
+                detection_score=score,
                 geometry=geom,
-                class_scores=det_info.classes,
-                features=det_info.features,
+                class_scores=class_scores,
+                features=feats,
             )
         )
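
Note: this hunk carries most of the speed-up. `detection_dataset.sel(detection=det_num)` does label-based indexing into every variable of the Dataset on each iteration, while reading each variable's `.values` once and zipping the resulting numpy rows pays the xarray overhead a single time. A toy comparison (not from the repo; the sizes and variable set are made up) that makes the difference visible:

    import time

    import numpy as np
    import xarray as xr

    n = 2_000
    ds = xr.Dataset(
        {
            "scores": ("detection", np.random.rand(n)),
            "classes": (("detection", "category"), np.random.rand(n, 10)),
        },
        coords={
            "detection": np.arange(n),
            "category": [f"c{i}" for i in range(10)],
        },
    )

    t0 = time.perf_counter()
    slow = [ds.sel(detection=i)["scores"].item() for i in range(n)]  # per-item .sel()
    t1 = time.perf_counter()
    fast = [float(score) for score in ds["scores"].values]  # one .values read
    t2 = time.perf_counter()
    print(f".sel() loop: {t1 - t0:.2f}s  numpy loop: {t2 - t1:.4f}s")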
@@ -120,8 +124,7 @@ def convert_xr_dataset_to_raw_prediction(
 def convert_raw_predictions_to_clip_prediction(
     raw_predictions: List[RawPrediction],
     clip: data.Clip,
-    sound_event_decoder: SoundEventDecoder,
-    generic_class_tags: List[data.Tag],
+    targets: TargetProtocol,
     classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
     top_class_only: bool = False,
 ) -> data.ClipPrediction:
@@ -160,8 +163,7 @@ def convert_raw_predictions_to_clip_prediction(
         convert_raw_prediction_to_sound_event_prediction(
             prediction,
             recording=clip.recording,
-            sound_event_decoder=sound_event_decoder,
-            generic_class_tags=generic_class_tags,
+            targets=targets,
             classification_threshold=classification_threshold,
             top_class_only=top_class_only,
         )
@@ -173,8 +175,7 @@
 def convert_raw_prediction_to_sound_event_prediction(
     raw_prediction: RawPrediction,
     recording: data.Recording,
-    sound_event_decoder: SoundEventDecoder,
-    generic_class_tags: List[data.Tag],
+    targets: TargetProtocol,
     classification_threshold: Optional[
         float
     ] = DEFAULT_CLASSIFICATION_THRESHOLD,
@@ -251,11 +252,11 @@
     tags = [
         *get_generic_tags(
             raw_prediction.detection_score,
-            generic_class_tags=generic_class_tags,
+            generic_class_tags=targets.generic_class_tags,
         ),
         *get_class_tags(
             raw_prediction.class_scores,
-            sound_event_decoder,
+            targets=targets,
             top_class_only=top_class_only,
             threshold=classification_threshold,
         ),
@@ -297,7 +298,7 @@
     ]


-def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
+def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
     """Convert an extracted feature vector DataArray into soundevent Features.

     Parameters
@@ -320,19 +321,19 @@ def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
     return [
         data.Feature(
             term=data.Term(
-                name=f"batdetect2:{feat_name}",
-                label=feat_name,
+                name=f"batdetect2:f{index}",
+                label=f"BatDetect Feature {index}",
                 definition="Automatically extracted features by BatDetect2",
             ),
             value=value,
         )
-        for feat_name, value in iterate_over_array(features)
+        for index, value in enumerate(features)
     ]


 def get_class_tags(
-    class_scores: xr.DataArray,
-    sound_event_decoder: SoundEventDecoder,
+    class_scores: np.ndarray,
+    targets: TargetProtocol,
     top_class_only: bool = False,
     threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
 ) -> List[data.PredictedTag]:
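
Note: with `features` now a bare ndarray, the per-feature names that previously travelled on the DataArray coordinate (via `iterate_over_array`) are gone, so terms are synthesized from the position. A sketch of what `get_prediction_features` now emits (values are illustrative; `data.Feature` / `data.Term` are the soundevent types used in the hunk above):

    import numpy as np
    from soundevent import data

    features = np.array([0.12, 0.57])
    feats = [
        data.Feature(
            term=data.Term(
                name=f"batdetect2:f{index}",
                label=f"BatDetect Feature {index}",
                definition="Automatically extracted features by BatDetect2",
            ),
            value=float(value),
        )
        for index, value in enumerate(features)
    ]
    # feats[0].term.name == "batdetect2:f0", feats[1].term.name == "batdetect2:f1"

One consequence worth flagging in review: downstream consumers that matched on the old semantic feature names will now see positional names instead.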
@@ -367,11 +368,13 @@
     """
     tags = []

-    if threshold is not None:
-        class_scores = class_scores.where(class_scores > threshold, drop=True)
+    for class_name, score in _iterate_sorted(
+        class_scores, targets.class_names
+    ):
+        if threshold is not None and score < threshold:
+            continue

-    for class_name, score in _iterate_sorted(class_scores):
-        class_tags = sound_event_decoder(class_name)
+        class_tags = targets.decode_class(class_name)

         for tag in class_tags:
             tags.append(
@@ -387,9 +390,7 @@
     return tags


-def _iterate_sorted(array: xr.DataArray):
-    dim_name = array.dims[0]
-    coords = array.coords[dim_name].values
-    indices = np.argsort(-array.values)
+def _iterate_sorted(array: np.ndarray, class_names: List[str]):
+    indices = np.argsort(-array)
     for index in indices:
-        yield str(coords[index]), float(array.values[index])
+        yield str(class_names[index]), float(array[index])
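
Note: the rewritten helper reads well in isolation; the behavioural changes are that class names now arrive as an explicit argument instead of riding on the DataArray coordinate, and thresholding moved to the caller. The same logic, annotated for clarity and runnable as written:

    from typing import Iterator, List, Tuple

    import numpy as np


    def _iterate_sorted(
        array: np.ndarray, class_names: List[str]
    ) -> Iterator[Tuple[str, float]]:
        # argsort of the negated scores yields indices in descending score order
        for index in np.argsort(-array):
            yield str(class_names[index]), float(array[index])


    print(list(_iterate_sorted(np.array([0.2, 0.9, 0.5]), ["a", "b", "c"])))
    # [('b', 0.9), ('c', 0.5), ('a', 0.2)]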

View File

@@ -14,6 +14,7 @@ system that deal with model predictions.
 from dataclasses import dataclass
 from typing import List, NamedTuple, Optional, Protocol

+import numpy as np
 import xarray as xr
 from soundevent import data
@@ -72,8 +73,8 @@ class RawPrediction(NamedTuple):
     geometry: data.Geometry
     detection_score: float
-    class_scores: xr.DataArray
-    features: xr.DataArray
+    class_scores: np.ndarray
+    features: np.ndarray


 @dataclass
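
Note: switching `class_scores` and `features` from xr.DataArray to np.ndarray is what lets the postprocessing loops avoid xarray overhead end to end; the meaning of each position is now fixed by convention (class order from `targets.class_names`, feature order by index). A construction sketch (the bounding box values are illustrative, and the coordinate order is an assumption about soundevent's convention):

    import numpy as np
    from soundevent import data

    from batdetect2.postprocess.types import RawPrediction

    pred = RawPrediction(
        # Assumed coordinate order: [start_time, low_freq, end_time, high_freq].
        geometry=data.BoundingBox(coordinates=[0.10, 30_000.0, 0.15, 60_000.0]),
        detection_score=0.87,
        class_scores=np.array([0.1, 0.7, 0.2]),  # aligned with targets.class_names
        features=np.array([0.12, 0.57, 0.33]),   # positional feature vector
    )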

View File

@@ -15,7 +15,10 @@ from batdetect2.evaluate.match import (
 )
 from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
 from batdetect2.plotting.evaluation import plot_example_gallery
-from batdetect2.postprocess.types import BatDetect2Prediction
+from batdetect2.postprocess.types import (
+    BatDetect2Prediction,
+    PostprocessorProtocol,
+)
 from batdetect2.targets.types import TargetProtocol
 from batdetect2.train.dataset import LabeledDataset, TrainExample
 from batdetect2.train.lightning import TrainingModule
@@ -114,33 +117,51 @@ class ValidationMetrics(Callback):
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        dataset = self.get_dataset(trainer)
-
-        clip_annotations = [
-            _get_subclip(
-                dataset.get_clip_annotation(example_id),
-                start_time=start_time.item(),
-                end_time=end_time.item(),
+        self._matches.extend(
+            _get_batch_clips_and_predictions(
+                batch,
+                outputs,
+                dataset=self.get_dataset(trainer),
+                postprocessor=pl_module.postprocessor,
+                targets=pl_module.targets,
             )
-            for example_id, start_time, end_time in zip(
-                batch.idx,
-                batch.start_time,
-                batch.end_time,
-            )
-        ]
-
-        clips = [clip_annotation.clip for clip_annotation in clip_annotations]
-
-        raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
-            outputs,
-            clips,
-        )
-
+        )
+
+
+def _get_batch_clips_and_predictions(
+    batch: TrainExample,
+    outputs: ModelOutput,
+    dataset: LabeledDataset,
+    postprocessor: PostprocessorProtocol,
+    targets: TargetProtocol,
+) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
+    clip_annotations = [
+        _get_subclip(
+            dataset.get_clip_annotation(example_id),
+            start_time=start_time.item(),
+            end_time=end_time.item(),
+            targets=targets,
+        )
+        for example_id, start_time, end_time in zip(
+            batch.idx,
+            batch.start_time,
+            batch.end_time,
+        )
+    ]
+
+    clips = [clip_annotation.clip for clip_annotation in clip_annotations]
+
+    raw_predictions = postprocessor.get_sound_event_predictions(
+        outputs,
+        clips,
+    )
+
+    return [
+        (clip_annotation, clip_predictions)
         for clip_annotation, clip_predictions in zip(
             clip_annotations, raw_predictions
-        ):
-            self._matches.append((clip_annotation, clip_predictions))
+        )
+    ]


 def _match_all_collected_examples(
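
Note: unlike the earlier hunks, this last change is an extract-function refactor rather than a hot-path fix: the per-batch pairing of subclip annotations with their predictions moves into the module-level `_get_batch_clips_and_predictions`, and the inner append loop becomes a single `extend`. A type-level sketch of what the callback accumulates (the `Match` alias is introduced here purely for illustration; `BatDetect2Prediction` is the real import from the hunk above):

    from typing import List, Tuple

    from soundevent import data

    from batdetect2.postprocess.types import BatDetect2Prediction

    # One entry per validation clip: the ground-truth annotation next to the
    # model's sound event predictions for that clip.
    Match = Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
    _matches: List[Match] = []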