From 9c72537dddb9dd26c38c873e6ad0c724d7e5c9de Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Mon, 8 Dec 2025 17:11:35 +0000
Subject: [PATCH] Add parquet format for outputs

---
 src/batdetect2/data/predictions/__init__.py |   3 +
 src/batdetect2/data/predictions/parquet.py  | 247 ++++++++++++++++++++
 2 files changed, 250 insertions(+)
 create mode 100644 src/batdetect2/data/predictions/parquet.py

diff --git a/src/batdetect2/data/predictions/__init__.py b/src/batdetect2/data/predictions/__init__.py
index 8c3cc1a..1006476 100644
--- a/src/batdetect2/data/predictions/__init__.py
+++ b/src/batdetect2/data/predictions/__init__.py
@@ -8,6 +8,7 @@ from batdetect2.data.predictions.base import (
     prediction_formatters,
 )
 from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig
+from batdetect2.data.predictions.parquet import ParquetOutputConfig
 from batdetect2.data.predictions.raw import RawOutputConfig
 from batdetect2.data.predictions.soundevent import SoundEventOutputConfig
 from batdetect2.typing import TargetProtocol
@@ -16,6 +17,7 @@ __all__ = [
     "build_output_formatter",
     "get_output_formatter",
     "BatDetect2OutputConfig",
+    "ParquetOutputConfig",
     "RawOutputConfig",
     "SoundEventOutputConfig",
 ]
@@ -24,6 +26,7 @@ __all__ = [
 OutputFormatConfig = Annotated[
     Union[
         BatDetect2OutputConfig,
+        ParquetOutputConfig,
         SoundEventOutputConfig,
         RawOutputConfig,
     ],
diff --git a/src/batdetect2/data/predictions/parquet.py b/src/batdetect2/data/predictions/parquet.py
new file mode 100644
index 0000000..c7c1030
--- /dev/null
+++ b/src/batdetect2/data/predictions/parquet.py
@@ -0,0 +1,247 @@
+import json
+from pathlib import Path
+from typing import List, Literal, Optional, Sequence
+from uuid import UUID
+
+import numpy as np
+import pandas as pd
+from loguru import logger
+from soundevent import data
+from soundevent.geometry import compute_bounds
+
+from batdetect2.core import BaseConfig
+from batdetect2.data.predictions.base import (
+    make_path_relative,
+    prediction_formatters,
+)
+from batdetect2.typing import (
+    BatDetect2Prediction,
+    OutputFormatterProtocol,
+    RawPrediction,
+    TargetProtocol,
+)
+
+
+class ParquetOutputConfig(BaseConfig):
+    """Configuration for the parquet prediction output format."""
+
+    name: Literal["parquet"] = "parquet"
+
+    # Each flag toggles an optional group of columns in the saved table.
+    include_class_scores: bool = True
+    include_features: bool = True
+    include_geometry: bool = True
+
+
+class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
+    """Save and load predictions as a flat parquet table.
+
+    Each row holds one detection together with the clip and recording it
+    belongs to, so a table can be regrouped into per-clip predictions.
+    """
+
+    def __init__(
+        self,
+        targets: TargetProtocol,
+        include_class_scores: bool = True,
+        include_features: bool = True,
+        include_geometry: bool = True,
+    ):
+        self.targets = targets
+        self.include_class_scores = include_class_scores
+        self.include_features = include_features
+        self.include_geometry = include_geometry
+
+    def format(
+        self,
+        predictions: Sequence[BatDetect2Prediction],
+    ) -> List[BatDetect2Prediction]:
+        """Return the predictions unchanged, as a list."""
+        return list(predictions)
+
+    def save(
+        self,
+        predictions: Sequence[BatDetect2Prediction],
+        path: data.PathLike,
+        audio_dir: Optional[data.PathLike] = None,
+    ) -> None:
+        """Write one row per detection to a parquet file.
+
+        Nothing is written when there are no detections.
+
+        Parameters
+        ----------
+        predictions
+            Clip-level predictions to flatten into rows.
+        path
+            Destination file. An existing directory, or a path without a
+            suffix, is treated as a directory and the table is written to
+            ``predictions.parquet`` inside it.
+        audio_dir
+            If given, recording paths are made relative to it before
+            being stored.
+        """
+        path = Path(path)
+
+        if path.suffix != ".parquet" and (path.is_dir() or not path.suffix):
+            path = path / "predictions.parquet"
+
+        # Create the parent only after the destination is resolved, so that
+        # a not-yet-existing output directory is created as well.
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+        rows = []
+        for prediction in predictions:
+            clip = prediction.clip
+            recording = clip.recording
+
+            if audio_dir is not None:
+                relative_path = make_path_relative(recording.path, audio_dir)
+                recording = recording.model_copy(
+                    update=dict(path=relative_path)
+                )
+
+            recording_json = recording.model_dump_json(exclude_none=True)
+
+            for pred in prediction.predictions:
+                row = {
+                    "clip_uuid": str(clip.uuid),
+                    "clip_start_time": clip.start_time,
+                    "clip_end_time": clip.end_time,
+                    "recording_info": recording_json,
+                    "detection_score": pred.detection_score,
+                }
+
+                if self.include_geometry:
+                    # Bounds are (start_time, low_freq, end_time, high_freq)
+                    bounds = compute_bounds(pred.geometry)
+                    start_time, low_freq, end_time, high_freq = bounds
+                    row["start_time"] = start_time
+                    row["low_freq"] = low_freq
+                    row["end_time"] = end_time
+                    row["high_freq"] = high_freq
+
+                    # Store full geometry as JSON
+                    row["geometry"] = pred.geometry.model_dump_json()
+
+                if self.include_class_scores:
+                    row["class_scores"] = pred.class_scores.tolist()
+
+                if self.include_features:
+                    row["features"] = pred.features.tolist()
+
+                rows.append(row)
+
+        if not rows:
+            logger.warning("No predictions to save.")
+            return
+
+        df = pd.DataFrame(rows)
+        logger.info(f"Saving {len(df)} predictions to {path}")
+        df.to_parquet(path, index=False)
+
+    def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
+        """Read predictions back from a parquet file or directory.
+
+        Parameters
+        ----------
+        path
+            A parquet file, or a directory whose ``*.parquet`` files are
+            read in sorted order and concatenated.
+
+        Returns
+        -------
+        List[BatDetect2Prediction]
+            One entry per distinct clip UUID, in order of first appearance.
+            Empty if ``path`` is a directory without parquet files.
+
+        Raises
+        ------
+        KeyError
+            If a row has no geometry and the bounds columns are missing,
+            as in tables saved with ``include_geometry=False``.
+        """
+        path = Path(path)
+        if path.is_dir():
+            # Sort so the result does not depend on directory listing order.
+            files = sorted(path.glob("*.parquet"))
+            if not files:
+                return []
+            df = pd.concat(
+                [pd.read_parquet(f) for f in files],
+                ignore_index=True,
+            )
+        else:
+            df = pd.read_parquet(path)
+
+        # Column presence is a property of the table; check it only once.
+        has_geometry = "geometry" in df.columns
+        has_class_scores = "class_scores" in df.columns
+        has_features = "features" in df.columns
+
+        predictions_by_clip = {}
+
+        for _, row in df.iterrows():
+            clip_uuid = row["clip_uuid"]
+
+            if clip_uuid not in predictions_by_clip:
+                recording = data.Recording.model_validate_json(
+                    row["recording_info"]
+                )
+                clip = data.Clip(
+                    uuid=UUID(clip_uuid),
+                    recording=recording,
+                    start_time=row["clip_start_time"],
+                    end_time=row["clip_end_time"],
+                )
+                predictions_by_clip[clip_uuid] = {"clip": clip, "preds": []}
+
+            # Prefer the full geometry; fall back to the flat bounds.
+            if has_geometry and row["geometry"]:
+                geometry = data.geometry_validate(row["geometry"])
+            else:
+                geometry = data.BoundingBox.model_construct(
+                    coordinates=[
+                        row["start_time"],
+                        row["low_freq"],
+                        row["end_time"],
+                        row["high_freq"],
+                    ]
+                )
+
+            if has_class_scores:
+                class_scores = np.array(row["class_scores"])
+            else:
+                class_scores = np.zeros(len(self.targets.class_names))
+
+            if has_features:
+                features = np.array(row["features"])
+            else:
+                features = np.zeros(0)
+
+            pred = RawPrediction(
+                geometry=geometry,
+                detection_score=row["detection_score"],
+                class_scores=class_scores,
+                features=features,
+            )
+            predictions_by_clip[clip_uuid]["preds"].append(pred)
+
+        return [
+            BatDetect2Prediction(
+                clip=clip_data["clip"],
+                predictions=clip_data["preds"],
+            )
+            for clip_data in predictions_by_clip.values()
+        ]
+
+    @prediction_formatters.register(ParquetOutputConfig)
+    @staticmethod
+    def from_config(config: ParquetOutputConfig, targets: TargetProtocol):
+        """Build a formatter from its configuration."""
+        return ParquetFormatter(
+            targets,
+            include_class_scores=config.include_class_scores,
+            include_features=config.include_features,
+            include_geometry=config.include_geometry,
+        )