mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add parquet format for outputs
This commit is contained in:
parent
72278d75ec
commit
9c72537ddd
@ -8,6 +8,7 @@ from batdetect2.data.predictions.base import (
|
|||||||
prediction_formatters,
|
prediction_formatters,
|
||||||
)
|
)
|
||||||
from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig
|
from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig
|
||||||
|
from batdetect2.data.predictions.parquet import ParquetOutputConfig
|
||||||
from batdetect2.data.predictions.raw import RawOutputConfig
|
from batdetect2.data.predictions.raw import RawOutputConfig
|
||||||
from batdetect2.data.predictions.soundevent import SoundEventOutputConfig
|
from batdetect2.data.predictions.soundevent import SoundEventOutputConfig
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.typing import TargetProtocol
|
||||||
@ -16,6 +17,7 @@ __all__ = [
|
|||||||
"build_output_formatter",
|
"build_output_formatter",
|
||||||
"get_output_formatter",
|
"get_output_formatter",
|
||||||
"BatDetect2OutputConfig",
|
"BatDetect2OutputConfig",
|
||||||
|
"ParquetOutputConfig",
|
||||||
"RawOutputConfig",
|
"RawOutputConfig",
|
||||||
"SoundEventOutputConfig",
|
"SoundEventOutputConfig",
|
||||||
]
|
]
|
||||||
@ -24,6 +26,7 @@ __all__ = [
|
|||||||
OutputFormatConfig = Annotated[
|
OutputFormatConfig = Annotated[
|
||||||
Union[
|
Union[
|
||||||
BatDetect2OutputConfig,
|
BatDetect2OutputConfig,
|
||||||
|
ParquetOutputConfig,
|
||||||
SoundEventOutputConfig,
|
SoundEventOutputConfig,
|
||||||
RawOutputConfig,
|
RawOutputConfig,
|
||||||
],
|
],
|
||||||
|
|||||||
194
src/batdetect2/data/predictions/parquet.py
Normal file
194
src/batdetect2/data/predictions/parquet.py
Normal file
@ -0,0 +1,194 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Literal, Optional, Sequence
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from loguru import logger
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.core import BaseConfig
|
||||||
|
from batdetect2.data.predictions.base import (
|
||||||
|
make_path_relative,
|
||||||
|
prediction_formatters,
|
||||||
|
)
|
||||||
|
from batdetect2.typing import (
|
||||||
|
BatDetect2Prediction,
|
||||||
|
OutputFormatterProtocol,
|
||||||
|
RawPrediction,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ParquetOutputConfig(BaseConfig):
|
||||||
|
name: Literal["parquet"] = "parquet"
|
||||||
|
|
||||||
|
include_class_scores: bool = True
|
||||||
|
include_features: bool = True
|
||||||
|
include_geometry: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ParquetFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
targets: TargetProtocol,
|
||||||
|
include_class_scores: bool = True,
|
||||||
|
include_features: bool = True,
|
||||||
|
include_geometry: bool = True,
|
||||||
|
):
|
||||||
|
self.targets = targets
|
||||||
|
self.include_class_scores = include_class_scores
|
||||||
|
self.include_features = include_features
|
||||||
|
self.include_geometry = include_geometry
|
||||||
|
|
||||||
|
def format(
|
||||||
|
self,
|
||||||
|
predictions: Sequence[BatDetect2Prediction],
|
||||||
|
) -> List[BatDetect2Prediction]:
|
||||||
|
return list(predictions)
|
||||||
|
|
||||||
|
def save(
|
||||||
|
self,
|
||||||
|
predictions: Sequence[BatDetect2Prediction],
|
||||||
|
path: data.PathLike,
|
||||||
|
audio_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> None:
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
|
if not path.parent.exists():
|
||||||
|
path.parent.mkdir(parents=True)
|
||||||
|
|
||||||
|
# Ensure the file has .parquet extension if it's a file path
|
||||||
|
if path.suffix != ".parquet":
|
||||||
|
# If it's a directory, we might want to save as a partitioned dataset or a single file inside
|
||||||
|
# For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet'
|
||||||
|
if path.is_dir() or not path.suffix:
|
||||||
|
path = path / "predictions.parquet"
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for prediction in predictions:
|
||||||
|
clip = prediction.clip
|
||||||
|
recording = clip.recording
|
||||||
|
|
||||||
|
if audio_dir is not None:
|
||||||
|
recording = recording.model_copy(
|
||||||
|
update=dict(path=make_path_relative(recording.path, audio_dir))
|
||||||
|
)
|
||||||
|
|
||||||
|
recording_json = recording.model_dump_json(exclude_none=True)
|
||||||
|
|
||||||
|
for pred in prediction.predictions:
|
||||||
|
row = {
|
||||||
|
"clip_uuid": str(clip.uuid),
|
||||||
|
"clip_start_time": clip.start_time,
|
||||||
|
"clip_end_time": clip.end_time,
|
||||||
|
"recording_info": recording_json,
|
||||||
|
"detection_score": pred.detection_score,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.include_geometry:
|
||||||
|
# Store geometry as [start_time, low_freq, end_time, high_freq]
|
||||||
|
start_time, low_freq, end_time, high_freq = compute_bounds(
|
||||||
|
pred.geometry
|
||||||
|
)
|
||||||
|
row["start_time"] = start_time
|
||||||
|
row["low_freq"] = low_freq
|
||||||
|
row["end_time"] = end_time
|
||||||
|
row["high_freq"] = high_freq
|
||||||
|
|
||||||
|
# Store full geometry as JSON
|
||||||
|
row["geometry"] = pred.geometry.model_dump_json()
|
||||||
|
|
||||||
|
if self.include_class_scores:
|
||||||
|
row["class_scores"] = pred.class_scores.tolist()
|
||||||
|
|
||||||
|
if self.include_features:
|
||||||
|
row["features"] = pred.features.tolist()
|
||||||
|
|
||||||
|
rows.append(row)
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
logger.warning("No predictions to save.")
|
||||||
|
return
|
||||||
|
|
||||||
|
df = pd.DataFrame(rows)
|
||||||
|
logger.info(f"Saving {len(df)} predictions to {path}")
|
||||||
|
df.to_parquet(path, index=False)
|
||||||
|
|
||||||
|
def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
|
||||||
|
path = Path(path)
|
||||||
|
if path.is_dir():
|
||||||
|
# Try to find parquet files
|
||||||
|
files = list(path.glob("*.parquet"))
|
||||||
|
if not files:
|
||||||
|
return []
|
||||||
|
# Read all and concatenate
|
||||||
|
dfs = [pd.read_parquet(f) for f in files]
|
||||||
|
df = pd.concat(dfs, ignore_index=True)
|
||||||
|
else:
|
||||||
|
df = pd.read_parquet(path)
|
||||||
|
|
||||||
|
predictions_by_clip = {}
|
||||||
|
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
clip_uuid = row["clip_uuid"]
|
||||||
|
|
||||||
|
if clip_uuid not in predictions_by_clip:
|
||||||
|
recording = data.Recording.model_validate_json(row["recording_info"])
|
||||||
|
clip = data.Clip(
|
||||||
|
uuid=UUID(clip_uuid),
|
||||||
|
recording=recording,
|
||||||
|
start_time=row["clip_start_time"],
|
||||||
|
end_time=row["clip_end_time"],
|
||||||
|
)
|
||||||
|
predictions_by_clip[clip_uuid] = {
|
||||||
|
"clip": clip,
|
||||||
|
"preds": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# Reconstruct geometry
|
||||||
|
if "geometry" in row and row["geometry"]:
|
||||||
|
geometry = data.geometry_validate(row["geometry"])
|
||||||
|
else:
|
||||||
|
geometry = data.BoundingBox.model_construct(
|
||||||
|
coordinates=[
|
||||||
|
row["start_time"],
|
||||||
|
row["low_freq"],
|
||||||
|
row["end_time"],
|
||||||
|
row["high_freq"]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
class_scores = np.array(row["class_scores"]) if "class_scores" in row else np.zeros(len(self.targets.class_names))
|
||||||
|
features = np.array(row["features"]) if "features" in row else np.zeros(0)
|
||||||
|
|
||||||
|
pred = RawPrediction(
|
||||||
|
geometry=geometry,
|
||||||
|
detection_score=row["detection_score"],
|
||||||
|
class_scores=class_scores,
|
||||||
|
features=features,
|
||||||
|
)
|
||||||
|
predictions_by_clip[clip_uuid]["preds"].append(pred)
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for clip_data in predictions_by_clip.values():
|
||||||
|
results.append(
|
||||||
|
BatDetect2Prediction(
|
||||||
|
clip=clip_data["clip"],
|
||||||
|
predictions=clip_data["preds"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
@prediction_formatters.register(ParquetOutputConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: ParquetOutputConfig, targets: TargetProtocol):
|
||||||
|
return ParquetFormatter(
|
||||||
|
targets,
|
||||||
|
include_class_scores=config.include_class_scores,
|
||||||
|
include_features=config.include_features,
|
||||||
|
include_geometry=config.include_geometry,
|
||||||
|
)
|
||||||
Loading…
Reference in New Issue
Block a user