From dbd2d30ead9c9ba4a2b10f7fd976f29884f729e7 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 19 Nov 2025 13:54:57 +0000 Subject: [PATCH] save predictions as individual nc files for speed --- src/batdetect2/data/predictions/raw.py | 334 +++++++++---------- tests/conftest.py | 6 +- tests/test_data/test_predictions/__init__.py | 0 tests/test_data/test_predictions/test_raw.py | 64 ++++ 4 files changed, 231 insertions(+), 173 deletions(-) create mode 100644 tests/test_data/test_predictions/__init__.py create mode 100644 tests/test_data/test_predictions/test_raw.py diff --git a/src/batdetect2/data/predictions/raw.py b/src/batdetect2/data/predictions/raw.py index 416c74b..08cd3d3 100644 --- a/src/batdetect2/data/predictions/raw.py +++ b/src/batdetect2/data/predictions/raw.py @@ -57,187 +57,181 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]): path: data.PathLike, audio_dir: Optional[data.PathLike] = None, ) -> None: - num_features = 0 - - tree = xr.DataTree() - - for prediction in predictions: - clip = prediction.clip - recording = clip.recording - - if audio_dir is not None: - recording = recording.model_copy( - update=dict( - path=make_path_relative(recording.path, audio_dir) - ) - ) - - clip_data = defaultdict(list) - - for pred in prediction.predictions: - detection_id = str(uuid4()) - - clip_data["detection_id"].append(detection_id) - clip_data["detection_score"].append(pred.detection_score) - - start_time, low_freq, end_time, high_freq = compute_bounds( - pred.geometry - ) - - clip_data["start_time"].append(start_time) - clip_data["end_time"].append(end_time) - clip_data["low_freq"].append(low_freq) - clip_data["high_freq"].append(high_freq) - - clip_data["geometry"].append(pred.geometry.model_dump_json()) - - top_class_index = int(np.argmax(pred.class_scores)) - top_class_score = float(pred.class_scores[top_class_index]) - top_class = self.targets.class_names[top_class_index] - - clip_data["top_class"].append(top_class) - clip_data["top_class_score"].append(top_class_score) - - clip_data["class_scores"].append(pred.class_scores) - clip_data["features"].append(pred.features) - - num_features = len(pred.features) - - data_vars = { - "score": (["detection"], clip_data["detection_score"]), - "start_time": (["detection"], clip_data["start_time"]), - "end_time": (["detection"], clip_data["end_time"]), - "low_freq": (["detection"], clip_data["low_freq"]), - "high_freq": (["detection"], clip_data["high_freq"]), - "top_class": (["detection"], clip_data["top_class"]), - "top_class_score": ( - ["detection"], - clip_data["top_class_score"], - ), - } - - coords = { - "detection": ("detection", clip_data["detection_id"]), - "clip_start": clip.start_time, - "clip_end": clip.end_time, - "clip_id": str(clip.uuid), - } - - if self.include_class_scores: - data_vars["class_scores"] = ( - ["detection", "classes"], - clip_data["class_scores"], - ) - coords["classes"] = ("classes", self.targets.class_names) - - if self.include_features: - data_vars["features"] = ( - ["detection", "feature"], - clip_data["features"], - ) - coords["feature"] = ("feature", np.arange(num_features)) - - if self.include_geometry: - data_vars["geometry"] = (["detection"], clip_data["geometry"]) - - dataset = xr.Dataset( - data_vars=data_vars, - coords=coords, - attrs={ - "recording": recording.model_dump_json(exclude_none=True), - }, - ) - - tree = tree.assign( - { - str(clip.uuid): xr.DataTree( - dataset=dataset, - name=str(clip.uuid), - ) - } - ) - path = Path(path) - if not path.suffix == ".nc": - path = Path(path).with_suffix(".nc") + if not path.exists(): + path.mkdir(parents=True) - tree.to_netcdf(path) + for prediction in predictions: + logger.debug(f"Saving clip predictions {prediction.clip.uuid}") + clip = prediction.clip + dataset = self.pred_to_xr(prediction, audio_dir) + dataset.to_netcdf(path / f"{clip.uuid}.nc") def load(self, path: data.PathLike) -> List[BatDetect2Prediction]: path = Path(path) - - root = xr.load_datatree(path) - + files = list(path.glob("*.nc")) predictions: List[BatDetect2Prediction] = [] - for _, clip_data in root.items(): - logger.debug(f"Loading clip {clip_data.clip_id.item()}") - recording = data.Recording.model_validate_json( - clip_data.attrs["recording"] - ) - - clip_id = clip_data.clip_id.item() - clip = data.Clip( - recording=recording, - uuid=UUID(clip_id), - start_time=clip_data.clip_start, - end_time=clip_data.clip_end, - ) - - sound_events = [] - - for detection in clip_data.coords["detection"]: - detection_data = clip_data.sel(detection=detection) - score = detection_data.score.item() - - if "geometry" in clip_data and self.parse_full_geometry: - geometry = data.geometry_validate( - detection_data.geometry.item() - ) - else: - start_time = detection_data.start_time - end_time = detection_data.end_time - low_freq = detection_data.low_freq - high_freq = detection_data.high_freq - geometry = data.BoundingBox.model_construct( - coordinates=[start_time, low_freq, end_time, high_freq] - ) - - if "class_scores" in detection_data: - class_scores = detection_data.class_scores.data - else: - class_scores = np.zeros(len(self.targets.class_names)) - class_index = self.targets.class_names.index( - detection_data.top_class.item() - ) - class_scores[class_index] = ( - detection_data.top_class_score.item() - ) - - if "features" in detection_data: - features = detection_data.features.data - else: - features = np.zeros(0) - - sound_events.append( - RawPrediction( - geometry=geometry, - detection_score=score, - class_scores=class_scores, - features=features, - ) - ) - - predictions.append( - BatDetect2Prediction( - clip=clip, - predictions=sound_events, - ) - ) + for filepath in files: + logger.debug(f"Loading clip predictions {filepath}") + clip_data = xr.load_dataset(filepath) + prediction = self.pred_from_xr(clip_data) + predictions.append(prediction) return predictions + def pred_to_xr( + self, + prediction: BatDetect2Prediction, + audio_dir: Optional[data.PathLike] = None, + ) -> xr.Dataset: + clip = prediction.clip + recording = clip.recording + num_features = 0 + + if audio_dir is not None: + recording = recording.model_copy( + update=dict(path=make_path_relative(recording.path, audio_dir)) + ) + + data = defaultdict(list) + + for pred in prediction.predictions: + detection_id = str(uuid4()) + + data["detection_id"].append(detection_id) + data["detection_score"].append(pred.detection_score) + + start_time, low_freq, end_time, high_freq = compute_bounds( + pred.geometry + ) + + data["start_time"].append(start_time) + data["end_time"].append(end_time) + data["low_freq"].append(low_freq) + data["high_freq"].append(high_freq) + + data["geometry"].append(pred.geometry.model_dump_json()) + + top_class_index = int(np.argmax(pred.class_scores)) + top_class_score = float(pred.class_scores[top_class_index]) + top_class = self.targets.class_names[top_class_index] + + data["top_class"].append(top_class) + data["top_class_score"].append(top_class_score) + + data["class_scores"].append(pred.class_scores) + data["features"].append(pred.features) + + num_features = len(pred.features) + + data_vars = { + "score": (["detection"], data["detection_score"]), + "start_time": (["detection"], data["start_time"]), + "end_time": (["detection"], data["end_time"]), + "low_freq": (["detection"], data["low_freq"]), + "high_freq": (["detection"], data["high_freq"]), + "top_class": (["detection"], data["top_class"]), + "top_class_score": (["detection"], data["top_class_score"]), + } + + coords = { + "detection": ("detection", data["detection_id"]), + "clip_start": clip.start_time, + "clip_end": clip.end_time, + "clip_id": str(clip.uuid), + } + + if self.include_class_scores: + class_scores = np.stack(data["class_scores"], axis=0) + data_vars["class_scores"] = ( + ["detection", "classes"], + class_scores, + ) + coords["classes"] = ("classes", self.targets.class_names) + + if self.include_features: + features = np.stack(data["features"], axis=0) + data_vars["features"] = (["detection", "feature"], features) + coords["feature"] = ("feature", np.arange(num_features)) + + if self.include_geometry: + data_vars["geometry"] = (["detection"], data["geometry"]) + + return xr.Dataset( + data_vars=data_vars, + coords=coords, + attrs={ + "recording": recording.model_dump_json(exclude_none=True), + }, + ) + + def pred_from_xr(self, dataset: xr.Dataset) -> BatDetect2Prediction: + clip_data = dataset + clip_id = clip_data.clip_id.item() + + recording = data.Recording.model_validate_json( + clip_data.attrs["recording"] + ) + + clip_id = clip_data.clip_id.item() + clip = data.Clip( + recording=recording, + uuid=UUID(clip_id), + start_time=clip_data.clip_start, + end_time=clip_data.clip_end, + ) + + sound_events = [] + + for detection in clip_data.coords["detection"]: + detection_data = clip_data.sel(detection=detection) + score = detection_data.score.item() + + if "geometry" in clip_data and self.parse_full_geometry: + geometry = data.geometry_validate( + detection_data.geometry.item() + ) + else: + start_time = detection_data.start_time + end_time = detection_data.end_time + low_freq = detection_data.low_freq + high_freq = detection_data.high_freq + geometry = data.BoundingBox.model_construct( + coordinates=[start_time, low_freq, end_time, high_freq] + ) + + if "class_scores" in detection_data: + class_scores = detection_data.class_scores.data + else: + class_scores = np.zeros(len(self.targets.class_names)) + class_index = self.targets.class_names.index( + detection_data.top_class.item() + ) + class_scores[class_index] = ( + detection_data.top_class_score.item() + ) + + if "features" in detection_data: + features = detection_data.features.data + else: + features = np.zeros(0) + + sound_events.append( + RawPrediction( + geometry=geometry, + detection_score=score, + class_scores=class_scores, + features=features, + ) + ) + + return BatDetect2Prediction( + clip=clip, + predictions=sound_events, + ) + @prediction_formatters.register(RawOutputConfig) @staticmethod def from_config(config: RawOutputConfig, targets: TargetProtocol): diff --git a/tests/conftest.py b/tests/conftest.py index 5f5e32b..43922a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,17 +9,17 @@ import soundfile as sf from scipy import signal from soundevent import data, terms +from batdetect2.audio import build_audio_loader +from batdetect2.audio.clips import build_clipper from batdetect2.data import DatasetConfig, load_dataset from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations from batdetect2.preprocess import build_preprocessor -from batdetect2.preprocess.audio import build_audio_loader from batdetect2.targets import ( TargetConfig, build_targets, call_type, ) from batdetect2.targets.classes import TargetClassConfig -from batdetect2.train.clips import build_clipper from batdetect2.train.labels import build_clip_labeler from batdetect2.typing import ( ClipLabeller, @@ -442,7 +442,7 @@ def example_annotations( ) -> List[data.ClipAnnotation]: annotations = load_dataset(example_dataset) assert len(annotations) == 3 - return annotations + return list(annotations) @pytest.fixture diff --git a/tests/test_data/test_predictions/__init__.py b/tests/test_data/test_predictions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_data/test_predictions/test_raw.py b/tests/test_data/test_predictions/test_raw.py new file mode 100644 index 0000000..cdfcafe --- /dev/null +++ b/tests/test_data/test_predictions/test_raw.py @@ -0,0 +1,64 @@ +from pathlib import Path + +import numpy as np +import pytest +from soundevent import data + +from batdetect2.data.predictions import RawOutputConfig, build_output_formatter +from batdetect2.typing import ( + BatDetect2Prediction, + RawPrediction, + TargetProtocol, +) + + +@pytest.fixture +def sample_formatter(sample_targets: TargetProtocol): + return build_output_formatter( + config=RawOutputConfig(), + targets=sample_targets, + ) + + +def test_roundtrip( + sample_formatter, + clip: data.Clip, + sample_targets: TargetProtocol, + tmp_path: Path, +): + detections = [ + RawPrediction( + geometry=data.BoundingBox( + coordinates=list(np.random.uniform(size=[4])) + ), + detection_score=0.5, + class_scores=np.random.uniform( + size=len(sample_targets.class_names) + ), + features=np.random.uniform(size=32), + ) + for _ in range(10) + ] + + prediction = BatDetect2Prediction(clip=clip, predictions=detections) + + path = tmp_path / "predictions" + + sample_formatter.save(predictions=[prediction], path=path) + + recovered = sample_formatter.load(path=path) + + assert len(recovered) == 1 + assert recovered[0].clip == prediction.clip + + for recovered_prediction, detection in zip( + recovered[0].predictions, detections + ): + assert ( + recovered_prediction.detection_score == detection.detection_score + ) + assert ( + recovered_prediction.class_scores == detection.class_scores + ).all() + assert (recovered_prediction.features == detection.features).all() + assert recovered_prediction.geometry == detection.geometry