mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Improve performance of postprocessing code
This commit is contained in:
parent
b997a122f1
commit
51d0a49da9
@ -543,7 +543,9 @@ class Postprocessor(PostprocessorProtocol):
|
||||
]
|
||||
|
||||
def get_sound_event_predictions(
|
||||
self, output: ModelOutput, clips: List[data.Clip]
|
||||
self,
|
||||
output: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
) -> List[List[BatDetect2Prediction]]:
|
||||
raw_predictions = self.get_raw_predictions(output, clips)
|
||||
return [
|
||||
@ -553,8 +555,7 @@ class Postprocessor(PostprocessorProtocol):
|
||||
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
|
||||
raw,
|
||||
recording=clip.recording,
|
||||
sound_event_decoder=self.targets.decode_class,
|
||||
generic_class_tags=self.targets.generic_class_tags,
|
||||
targets=self.targets,
|
||||
classification_threshold=self.config.classification_threshold,
|
||||
),
|
||||
)
|
||||
@ -590,8 +591,7 @@ class Postprocessor(PostprocessorProtocol):
|
||||
convert_raw_predictions_to_clip_prediction(
|
||||
prediction,
|
||||
clip,
|
||||
sound_event_decoder=self.targets.decode_class,
|
||||
generic_class_tags=self.targets.generic_class_tags,
|
||||
targets=self.targets,
|
||||
classification_threshold=self.config.classification_threshold,
|
||||
)
|
||||
for prediction, clip in zip(raw_predictions, clips)
|
||||
|
||||
@ -33,8 +33,7 @@ import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
||||
from batdetect2.targets.classes import SoundEventDecoder
|
||||
from batdetect2.utils.arrays import iterate_over_array
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"convert_xr_dataset_to_raw_prediction",
|
||||
@ -92,25 +91,30 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
"""
|
||||
detections = []
|
||||
|
||||
for det_num in range(detection_dataset.sizes["detection"]):
|
||||
det_info = detection_dataset.sel(detection=det_num)
|
||||
categories = detection_dataset.category.values
|
||||
|
||||
highest_scoring_class = det_info.coords["category"][
|
||||
det_info["classes"].argmax()
|
||||
].item()
|
||||
for score, class_scores, time, freq, dims, feats in zip(
|
||||
detection_dataset["scores"].values,
|
||||
detection_dataset["classes"].values,
|
||||
detection_dataset["time"].values,
|
||||
detection_dataset["frequency"].values,
|
||||
detection_dataset["dimensions"].values,
|
||||
detection_dataset["features"].values,
|
||||
):
|
||||
highest_scoring_class = categories[class_scores.argmax()]
|
||||
|
||||
geom = geometry_decoder(
|
||||
(det_info.time, det_info.frequency),
|
||||
det_info.dimensions,
|
||||
(time, freq),
|
||||
dims,
|
||||
class_name=highest_scoring_class,
|
||||
)
|
||||
|
||||
detections.append(
|
||||
RawPrediction(
|
||||
detection_score=det_info.scores,
|
||||
detection_score=score,
|
||||
geometry=geom,
|
||||
class_scores=det_info.classes,
|
||||
features=det_info.features,
|
||||
class_scores=class_scores,
|
||||
features=feats,
|
||||
)
|
||||
)
|
||||
|
||||
@ -120,8 +124,7 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
def convert_raw_predictions_to_clip_prediction(
|
||||
raw_predictions: List[RawPrediction],
|
||||
clip: data.Clip,
|
||||
sound_event_decoder: SoundEventDecoder,
|
||||
generic_class_tags: List[data.Tag],
|
||||
targets: TargetProtocol,
|
||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only: bool = False,
|
||||
) -> data.ClipPrediction:
|
||||
@ -160,8 +163,7 @@ def convert_raw_predictions_to_clip_prediction(
|
||||
convert_raw_prediction_to_sound_event_prediction(
|
||||
prediction,
|
||||
recording=clip.recording,
|
||||
sound_event_decoder=sound_event_decoder,
|
||||
generic_class_tags=generic_class_tags,
|
||||
targets=targets,
|
||||
classification_threshold=classification_threshold,
|
||||
top_class_only=top_class_only,
|
||||
)
|
||||
@ -173,8 +175,7 @@ def convert_raw_predictions_to_clip_prediction(
|
||||
def convert_raw_prediction_to_sound_event_prediction(
|
||||
raw_prediction: RawPrediction,
|
||||
recording: data.Recording,
|
||||
sound_event_decoder: SoundEventDecoder,
|
||||
generic_class_tags: List[data.Tag],
|
||||
targets: TargetProtocol,
|
||||
classification_threshold: Optional[
|
||||
float
|
||||
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
@ -251,11 +252,11 @@ def convert_raw_prediction_to_sound_event_prediction(
|
||||
tags = [
|
||||
*get_generic_tags(
|
||||
raw_prediction.detection_score,
|
||||
generic_class_tags=generic_class_tags,
|
||||
generic_class_tags=targets.generic_class_tags,
|
||||
),
|
||||
*get_class_tags(
|
||||
raw_prediction.class_scores,
|
||||
sound_event_decoder,
|
||||
targets=targets,
|
||||
top_class_only=top_class_only,
|
||||
threshold=classification_threshold,
|
||||
),
|
||||
@ -297,7 +298,7 @@ def get_generic_tags(
|
||||
]
|
||||
|
||||
|
||||
def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
|
||||
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
|
||||
"""Convert an extracted feature vector DataArray into soundevent Features.
|
||||
|
||||
Parameters
|
||||
@ -320,19 +321,19 @@ def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
|
||||
return [
|
||||
data.Feature(
|
||||
term=data.Term(
|
||||
name=f"batdetect2:{feat_name}",
|
||||
label=feat_name,
|
||||
name=f"batdetect2:f{index}",
|
||||
label=f"BatDetect Feature {index}",
|
||||
definition="Automatically extracted features by BatDetect2",
|
||||
),
|
||||
value=value,
|
||||
)
|
||||
for feat_name, value in iterate_over_array(features)
|
||||
for index, value in enumerate(features)
|
||||
]
|
||||
|
||||
|
||||
def get_class_tags(
|
||||
class_scores: xr.DataArray,
|
||||
sound_event_decoder: SoundEventDecoder,
|
||||
class_scores: np.ndarray,
|
||||
targets: TargetProtocol,
|
||||
top_class_only: bool = False,
|
||||
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
) -> List[data.PredictedTag]:
|
||||
@ -367,11 +368,13 @@ def get_class_tags(
|
||||
"""
|
||||
tags = []
|
||||
|
||||
if threshold is not None:
|
||||
class_scores = class_scores.where(class_scores > threshold, drop=True)
|
||||
for class_name, score in _iterate_sorted(
|
||||
class_scores, targets.class_names
|
||||
):
|
||||
if threshold is not None and score < threshold:
|
||||
continue
|
||||
|
||||
for class_name, score in _iterate_sorted(class_scores):
|
||||
class_tags = sound_event_decoder(class_name)
|
||||
class_tags = targets.decode_class(class_name)
|
||||
|
||||
for tag in class_tags:
|
||||
tags.append(
|
||||
@ -387,9 +390,7 @@ def get_class_tags(
|
||||
return tags
|
||||
|
||||
|
||||
def _iterate_sorted(array: xr.DataArray):
|
||||
dim_name = array.dims[0]
|
||||
coords = array.coords[dim_name].values
|
||||
indices = np.argsort(-array.values)
|
||||
def _iterate_sorted(array: np.ndarray, class_names: List[str]):
|
||||
indices = np.argsort(-array)
|
||||
for index in indices:
|
||||
yield str(coords[index]), float(array.values[index])
|
||||
yield str(class_names[index]), float(array[index])
|
||||
|
||||
@ -14,6 +14,7 @@ system that deal with model predictions.
|
||||
from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Optional, Protocol
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
@ -72,8 +73,8 @@ class RawPrediction(NamedTuple):
|
||||
|
||||
geometry: data.Geometry
|
||||
detection_score: float
|
||||
class_scores: xr.DataArray
|
||||
features: xr.DataArray
|
||||
class_scores: np.ndarray
|
||||
features: np.ndarray
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@ -15,7 +15,10 @@ from batdetect2.evaluate.match import (
|
||||
)
|
||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||
from batdetect2.plotting.evaluation import plot_example_gallery
|
||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
||||
from batdetect2.postprocess.types import (
|
||||
BatDetect2Prediction,
|
||||
PostprocessorProtocol,
|
||||
)
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
@ -114,33 +117,51 @@ class ValidationMetrics(Callback):
|
||||
batch_idx: int,
|
||||
dataloader_idx: int = 0,
|
||||
) -> None:
|
||||
dataset = self.get_dataset(trainer)
|
||||
|
||||
clip_annotations = [
|
||||
_get_subclip(
|
||||
dataset.get_clip_annotation(example_id),
|
||||
start_time=start_time.item(),
|
||||
end_time=end_time.item(),
|
||||
self._matches.extend(
|
||||
_get_batch_clips_and_predictions(
|
||||
batch,
|
||||
outputs,
|
||||
dataset=self.get_dataset(trainer),
|
||||
postprocessor=pl_module.postprocessor,
|
||||
targets=pl_module.targets,
|
||||
)
|
||||
for example_id, start_time, end_time in zip(
|
||||
batch.idx,
|
||||
batch.start_time,
|
||||
batch.end_time,
|
||||
)
|
||||
]
|
||||
|
||||
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
||||
|
||||
raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
|
||||
outputs,
|
||||
clips,
|
||||
)
|
||||
|
||||
|
||||
def _get_batch_clips_and_predictions(
|
||||
batch: TrainExample,
|
||||
outputs: ModelOutput,
|
||||
dataset: LabeledDataset,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
targets: TargetProtocol,
|
||||
) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
|
||||
clip_annotations = [
|
||||
_get_subclip(
|
||||
dataset.get_clip_annotation(example_id),
|
||||
start_time=start_time.item(),
|
||||
end_time=end_time.item(),
|
||||
targets=targets,
|
||||
)
|
||||
for example_id, start_time, end_time in zip(
|
||||
batch.idx,
|
||||
batch.start_time,
|
||||
batch.end_time,
|
||||
)
|
||||
]
|
||||
|
||||
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
||||
|
||||
raw_predictions = postprocessor.get_sound_event_predictions(
|
||||
outputs,
|
||||
clips,
|
||||
)
|
||||
|
||||
return [
|
||||
(clip_annotation, clip_predictions)
|
||||
for clip_annotation, clip_predictions in zip(
|
||||
clip_annotations, raw_predictions
|
||||
):
|
||||
self._matches.append((clip_annotation, clip_predictions))
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def _match_all_collected_examples(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user