mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 17:29:34 +01:00
Improve performance of postprocessing code
This commit is contained in:
parent
b997a122f1
commit
51d0a49da9
@ -543,7 +543,9 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def get_sound_event_predictions(
|
def get_sound_event_predictions(
|
||||||
self, output: ModelOutput, clips: List[data.Clip]
|
self,
|
||||||
|
output: ModelOutput,
|
||||||
|
clips: List[data.Clip],
|
||||||
) -> List[List[BatDetect2Prediction]]:
|
) -> List[List[BatDetect2Prediction]]:
|
||||||
raw_predictions = self.get_raw_predictions(output, clips)
|
raw_predictions = self.get_raw_predictions(output, clips)
|
||||||
return [
|
return [
|
||||||
@ -553,8 +555,7 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
|
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
|
||||||
raw,
|
raw,
|
||||||
recording=clip.recording,
|
recording=clip.recording,
|
||||||
sound_event_decoder=self.targets.decode_class,
|
targets=self.targets,
|
||||||
generic_class_tags=self.targets.generic_class_tags,
|
|
||||||
classification_threshold=self.config.classification_threshold,
|
classification_threshold=self.config.classification_threshold,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -590,8 +591,7 @@ class Postprocessor(PostprocessorProtocol):
|
|||||||
convert_raw_predictions_to_clip_prediction(
|
convert_raw_predictions_to_clip_prediction(
|
||||||
prediction,
|
prediction,
|
||||||
clip,
|
clip,
|
||||||
sound_event_decoder=self.targets.decode_class,
|
targets=self.targets,
|
||||||
generic_class_tags=self.targets.generic_class_tags,
|
|
||||||
classification_threshold=self.config.classification_threshold,
|
classification_threshold=self.config.classification_threshold,
|
||||||
)
|
)
|
||||||
for prediction, clip in zip(raw_predictions, clips)
|
for prediction, clip in zip(raw_predictions, clips)
|
||||||
|
|||||||
@ -33,8 +33,7 @@ import xarray as xr
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
||||||
from batdetect2.targets.classes import SoundEventDecoder
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.utils.arrays import iterate_over_array
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"convert_xr_dataset_to_raw_prediction",
|
"convert_xr_dataset_to_raw_prediction",
|
||||||
@ -92,25 +91,30 @@ def convert_xr_dataset_to_raw_prediction(
|
|||||||
"""
|
"""
|
||||||
detections = []
|
detections = []
|
||||||
|
|
||||||
for det_num in range(detection_dataset.sizes["detection"]):
|
categories = detection_dataset.category.values
|
||||||
det_info = detection_dataset.sel(detection=det_num)
|
|
||||||
|
|
||||||
highest_scoring_class = det_info.coords["category"][
|
for score, class_scores, time, freq, dims, feats in zip(
|
||||||
det_info["classes"].argmax()
|
detection_dataset["scores"].values,
|
||||||
].item()
|
detection_dataset["classes"].values,
|
||||||
|
detection_dataset["time"].values,
|
||||||
|
detection_dataset["frequency"].values,
|
||||||
|
detection_dataset["dimensions"].values,
|
||||||
|
detection_dataset["features"].values,
|
||||||
|
):
|
||||||
|
highest_scoring_class = categories[class_scores.argmax()]
|
||||||
|
|
||||||
geom = geometry_decoder(
|
geom = geometry_decoder(
|
||||||
(det_info.time, det_info.frequency),
|
(time, freq),
|
||||||
det_info.dimensions,
|
dims,
|
||||||
class_name=highest_scoring_class,
|
class_name=highest_scoring_class,
|
||||||
)
|
)
|
||||||
|
|
||||||
detections.append(
|
detections.append(
|
||||||
RawPrediction(
|
RawPrediction(
|
||||||
detection_score=det_info.scores,
|
detection_score=score,
|
||||||
geometry=geom,
|
geometry=geom,
|
||||||
class_scores=det_info.classes,
|
class_scores=class_scores,
|
||||||
features=det_info.features,
|
features=feats,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -120,8 +124,7 @@ def convert_xr_dataset_to_raw_prediction(
|
|||||||
def convert_raw_predictions_to_clip_prediction(
|
def convert_raw_predictions_to_clip_prediction(
|
||||||
raw_predictions: List[RawPrediction],
|
raw_predictions: List[RawPrediction],
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
sound_event_decoder: SoundEventDecoder,
|
targets: TargetProtocol,
|
||||||
generic_class_tags: List[data.Tag],
|
|
||||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
top_class_only: bool = False,
|
top_class_only: bool = False,
|
||||||
) -> data.ClipPrediction:
|
) -> data.ClipPrediction:
|
||||||
@ -160,8 +163,7 @@ def convert_raw_predictions_to_clip_prediction(
|
|||||||
convert_raw_prediction_to_sound_event_prediction(
|
convert_raw_prediction_to_sound_event_prediction(
|
||||||
prediction,
|
prediction,
|
||||||
recording=clip.recording,
|
recording=clip.recording,
|
||||||
sound_event_decoder=sound_event_decoder,
|
targets=targets,
|
||||||
generic_class_tags=generic_class_tags,
|
|
||||||
classification_threshold=classification_threshold,
|
classification_threshold=classification_threshold,
|
||||||
top_class_only=top_class_only,
|
top_class_only=top_class_only,
|
||||||
)
|
)
|
||||||
@ -173,8 +175,7 @@ def convert_raw_predictions_to_clip_prediction(
|
|||||||
def convert_raw_prediction_to_sound_event_prediction(
|
def convert_raw_prediction_to_sound_event_prediction(
|
||||||
raw_prediction: RawPrediction,
|
raw_prediction: RawPrediction,
|
||||||
recording: data.Recording,
|
recording: data.Recording,
|
||||||
sound_event_decoder: SoundEventDecoder,
|
targets: TargetProtocol,
|
||||||
generic_class_tags: List[data.Tag],
|
|
||||||
classification_threshold: Optional[
|
classification_threshold: Optional[
|
||||||
float
|
float
|
||||||
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
@ -251,11 +252,11 @@ def convert_raw_prediction_to_sound_event_prediction(
|
|||||||
tags = [
|
tags = [
|
||||||
*get_generic_tags(
|
*get_generic_tags(
|
||||||
raw_prediction.detection_score,
|
raw_prediction.detection_score,
|
||||||
generic_class_tags=generic_class_tags,
|
generic_class_tags=targets.generic_class_tags,
|
||||||
),
|
),
|
||||||
*get_class_tags(
|
*get_class_tags(
|
||||||
raw_prediction.class_scores,
|
raw_prediction.class_scores,
|
||||||
sound_event_decoder,
|
targets=targets,
|
||||||
top_class_only=top_class_only,
|
top_class_only=top_class_only,
|
||||||
threshold=classification_threshold,
|
threshold=classification_threshold,
|
||||||
),
|
),
|
||||||
@ -297,7 +298,7 @@ def get_generic_tags(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
|
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
|
||||||
"""Convert an extracted feature vector DataArray into soundevent Features.
|
"""Convert an extracted feature vector DataArray into soundevent Features.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -320,19 +321,19 @@ def get_prediction_features(features: xr.DataArray) -> List[data.Feature]:
|
|||||||
return [
|
return [
|
||||||
data.Feature(
|
data.Feature(
|
||||||
term=data.Term(
|
term=data.Term(
|
||||||
name=f"batdetect2:{feat_name}",
|
name=f"batdetect2:f{index}",
|
||||||
label=feat_name,
|
label=f"BatDetect Feature {index}",
|
||||||
definition="Automatically extracted features by BatDetect2",
|
definition="Automatically extracted features by BatDetect2",
|
||||||
),
|
),
|
||||||
value=value,
|
value=value,
|
||||||
)
|
)
|
||||||
for feat_name, value in iterate_over_array(features)
|
for index, value in enumerate(features)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def get_class_tags(
|
def get_class_tags(
|
||||||
class_scores: xr.DataArray,
|
class_scores: np.ndarray,
|
||||||
sound_event_decoder: SoundEventDecoder,
|
targets: TargetProtocol,
|
||||||
top_class_only: bool = False,
|
top_class_only: bool = False,
|
||||||
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
threshold: Optional[float] = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
) -> List[data.PredictedTag]:
|
) -> List[data.PredictedTag]:
|
||||||
@ -367,11 +368,13 @@ def get_class_tags(
|
|||||||
"""
|
"""
|
||||||
tags = []
|
tags = []
|
||||||
|
|
||||||
if threshold is not None:
|
for class_name, score in _iterate_sorted(
|
||||||
class_scores = class_scores.where(class_scores > threshold, drop=True)
|
class_scores, targets.class_names
|
||||||
|
):
|
||||||
|
if threshold is not None and score < threshold:
|
||||||
|
continue
|
||||||
|
|
||||||
for class_name, score in _iterate_sorted(class_scores):
|
class_tags = targets.decode_class(class_name)
|
||||||
class_tags = sound_event_decoder(class_name)
|
|
||||||
|
|
||||||
for tag in class_tags:
|
for tag in class_tags:
|
||||||
tags.append(
|
tags.append(
|
||||||
@ -387,9 +390,7 @@ def get_class_tags(
|
|||||||
return tags
|
return tags
|
||||||
|
|
||||||
|
|
||||||
def _iterate_sorted(array: xr.DataArray):
|
def _iterate_sorted(array: np.ndarray, class_names: List[str]):
|
||||||
dim_name = array.dims[0]
|
indices = np.argsort(-array)
|
||||||
coords = array.coords[dim_name].values
|
|
||||||
indices = np.argsort(-array.values)
|
|
||||||
for index in indices:
|
for index in indices:
|
||||||
yield str(coords[index]), float(array.values[index])
|
yield str(class_names[index]), float(array[index])
|
||||||
|
|||||||
@ -14,6 +14,7 @@ system that deal with model predictions.
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, NamedTuple, Optional, Protocol
|
from typing import List, NamedTuple, Optional, Protocol
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -72,8 +73,8 @@ class RawPrediction(NamedTuple):
|
|||||||
|
|
||||||
geometry: data.Geometry
|
geometry: data.Geometry
|
||||||
detection_score: float
|
detection_score: float
|
||||||
class_scores: xr.DataArray
|
class_scores: np.ndarray
|
||||||
features: xr.DataArray
|
features: np.ndarray
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@ -15,7 +15,10 @@ from batdetect2.evaluate.match import (
|
|||||||
)
|
)
|
||||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||||
from batdetect2.plotting.evaluation import plot_example_gallery
|
from batdetect2.plotting.evaluation import plot_example_gallery
|
||||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
from batdetect2.postprocess.types import (
|
||||||
|
BatDetect2Prediction,
|
||||||
|
PostprocessorProtocol,
|
||||||
|
)
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
@ -114,14 +117,30 @@ class ValidationMetrics(Callback):
|
|||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
dataloader_idx: int = 0,
|
dataloader_idx: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
dataset = self.get_dataset(trainer)
|
self._matches.extend(
|
||||||
|
_get_batch_clips_and_predictions(
|
||||||
|
batch,
|
||||||
|
outputs,
|
||||||
|
dataset=self.get_dataset(trainer),
|
||||||
|
postprocessor=pl_module.postprocessor,
|
||||||
|
targets=pl_module.targets,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_batch_clips_and_predictions(
|
||||||
|
batch: TrainExample,
|
||||||
|
outputs: ModelOutput,
|
||||||
|
dataset: LabeledDataset,
|
||||||
|
postprocessor: PostprocessorProtocol,
|
||||||
|
targets: TargetProtocol,
|
||||||
|
) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
|
||||||
clip_annotations = [
|
clip_annotations = [
|
||||||
_get_subclip(
|
_get_subclip(
|
||||||
dataset.get_clip_annotation(example_id),
|
dataset.get_clip_annotation(example_id),
|
||||||
start_time=start_time.item(),
|
start_time=start_time.item(),
|
||||||
end_time=end_time.item(),
|
end_time=end_time.item(),
|
||||||
targets=pl_module.targets,
|
targets=targets,
|
||||||
)
|
)
|
||||||
for example_id, start_time, end_time in zip(
|
for example_id, start_time, end_time in zip(
|
||||||
batch.idx,
|
batch.idx,
|
||||||
@ -132,15 +151,17 @@ class ValidationMetrics(Callback):
|
|||||||
|
|
||||||
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
|
||||||
|
|
||||||
raw_predictions = pl_module.postprocessor.get_sound_event_predictions(
|
raw_predictions = postprocessor.get_sound_event_predictions(
|
||||||
outputs,
|
outputs,
|
||||||
clips,
|
clips,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
(clip_annotation, clip_predictions)
|
||||||
for clip_annotation, clip_predictions in zip(
|
for clip_annotation, clip_predictions in zip(
|
||||||
clip_annotations, raw_predictions
|
clip_annotations, raw_predictions
|
||||||
):
|
)
|
||||||
self._matches.append((clip_annotation, clip_predictions))
|
]
|
||||||
|
|
||||||
|
|
||||||
def _match_all_collected_examples(
|
def _match_all_collected_examples(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user