save predictions as individual nc files for speed

mbsantiago 2025-11-19 13:54:57 +00:00
parent bdb9e18964
commit dbd2d30ead
4 changed files with 231 additions and 173 deletions
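In short: instead of packing every clip's predictions into a single NetCDF DataTree, save() now writes one .nc file per clip (named after the clip UUID) into a target directory, and load() globs that directory back into predictions. A minimal usage sketch, assuming an already-built formatter (`formatter`) and a list of BatDetect2Prediction objects (`predictions`); both names are placeholders, not part of this diff:

    from pathlib import Path

    output_dir = Path("predictions")  # now a directory, not a single .nc file
    formatter.save(predictions=predictions, path=output_dir)

    # One file per clip, named <clip.uuid>.nc
    for nc_file in sorted(output_dir.glob("*.nc")):
        print(nc_file.name)

    # Loading scans the directory for *.nc and rebuilds one
    # BatDetect2Prediction per file via pred_from_xr.
    recovered = formatter.load(path=output_dir)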

View File

@@ -57,90 +57,109 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
         path: data.PathLike,
         audio_dir: Optional[data.PathLike] = None,
     ) -> None:
-        num_features = 0
-        tree = xr.DataTree()
+        path = Path(path)
+        if not path.exists():
+            path.mkdir(parents=True)

         for prediction in predictions:
+            logger.debug(f"Saving clip predictions {prediction.clip.uuid}")
+            clip = prediction.clip
+            dataset = self.pred_to_xr(prediction, audio_dir)
+            dataset.to_netcdf(path / f"{clip.uuid}.nc")
+
+    def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
+        path = Path(path)
+        files = list(path.glob("*.nc"))
+
+        predictions: List[BatDetect2Prediction] = []
+        for filepath in files:
+            logger.debug(f"Loading clip predictions {filepath}")
+            clip_data = xr.load_dataset(filepath)
+            prediction = self.pred_from_xr(clip_data)
+            predictions.append(prediction)
+
+        return predictions
+
+    def pred_to_xr(
+        self,
+        prediction: BatDetect2Prediction,
+        audio_dir: Optional[data.PathLike] = None,
+    ) -> xr.Dataset:
         clip = prediction.clip
         recording = clip.recording
+        num_features = 0

         if audio_dir is not None:
             recording = recording.model_copy(
-                update=dict(
-                    path=make_path_relative(recording.path, audio_dir)
-                )
+                update=dict(path=make_path_relative(recording.path, audio_dir))
             )

-        clip_data = defaultdict(list)
+        data = defaultdict(list)
         for pred in prediction.predictions:
             detection_id = str(uuid4())
-            clip_data["detection_id"].append(detection_id)
-            clip_data["detection_score"].append(pred.detection_score)
+            data["detection_id"].append(detection_id)
+            data["detection_score"].append(pred.detection_score)

             start_time, low_freq, end_time, high_freq = compute_bounds(
                 pred.geometry
             )
-            clip_data["start_time"].append(start_time)
-            clip_data["end_time"].append(end_time)
-            clip_data["low_freq"].append(low_freq)
-            clip_data["high_freq"].append(high_freq)
-            clip_data["geometry"].append(pred.geometry.model_dump_json())
+            data["start_time"].append(start_time)
+            data["end_time"].append(end_time)
+            data["low_freq"].append(low_freq)
+            data["high_freq"].append(high_freq)
+            data["geometry"].append(pred.geometry.model_dump_json())

             top_class_index = int(np.argmax(pred.class_scores))
             top_class_score = float(pred.class_scores[top_class_index])
             top_class = self.targets.class_names[top_class_index]
-            clip_data["top_class"].append(top_class)
-            clip_data["top_class_score"].append(top_class_score)
-            clip_data["class_scores"].append(pred.class_scores)
-            clip_data["features"].append(pred.features)
+            data["top_class"].append(top_class)
+            data["top_class_score"].append(top_class_score)
+            data["class_scores"].append(pred.class_scores)
+            data["features"].append(pred.features)
             num_features = len(pred.features)

         data_vars = {
-            "score": (["detection"], clip_data["detection_score"]),
-            "start_time": (["detection"], clip_data["start_time"]),
-            "end_time": (["detection"], clip_data["end_time"]),
-            "low_freq": (["detection"], clip_data["low_freq"]),
-            "high_freq": (["detection"], clip_data["high_freq"]),
-            "top_class": (["detection"], clip_data["top_class"]),
-            "top_class_score": (
-                ["detection"],
-                clip_data["top_class_score"],
-            ),
+            "score": (["detection"], data["detection_score"]),
+            "start_time": (["detection"], data["start_time"]),
+            "end_time": (["detection"], data["end_time"]),
+            "low_freq": (["detection"], data["low_freq"]),
+            "high_freq": (["detection"], data["high_freq"]),
+            "top_class": (["detection"], data["top_class"]),
+            "top_class_score": (["detection"], data["top_class_score"]),
         }

         coords = {
-            "detection": ("detection", clip_data["detection_id"]),
+            "detection": ("detection", data["detection_id"]),
             "clip_start": clip.start_time,
             "clip_end": clip.end_time,
             "clip_id": str(clip.uuid),
         }

         if self.include_class_scores:
+            class_scores = np.stack(data["class_scores"], axis=0)
             data_vars["class_scores"] = (
                 ["detection", "classes"],
-                clip_data["class_scores"],
+                class_scores,
             )
             coords["classes"] = ("classes", self.targets.class_names)

         if self.include_features:
-            data_vars["features"] = (
-                ["detection", "feature"],
-                clip_data["features"],
-            )
+            features = np.stack(data["features"], axis=0)
+            data_vars["features"] = (["detection", "feature"], features)
             coords["feature"] = ("feature", np.arange(num_features))

         if self.include_geometry:
-            data_vars["geometry"] = (["detection"], clip_data["geometry"])
+            data_vars["geometry"] = (["detection"], data["geometry"])

-        dataset = xr.Dataset(
+        return xr.Dataset(
             data_vars=data_vars,
             coords=coords,
             attrs={
@@ -148,31 +167,10 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
             },
         )

-        tree = tree.assign(
-            {
-                str(clip.uuid): xr.DataTree(
-                    dataset=dataset,
-                    name=str(clip.uuid),
-                )
-            }
-        )
-
-        path = Path(path)
-        if not path.suffix == ".nc":
-            path = Path(path).with_suffix(".nc")
-
-        tree.to_netcdf(path)
-
-    def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
-        path = Path(path)
-        root = xr.load_datatree(path)
-
-        predictions: List[BatDetect2Prediction] = []
-        for _, clip_data in root.items():
-            logger.debug(f"Loading clip {clip_data.clip_id.item()}")
+    def pred_from_xr(self, dataset: xr.Dataset) -> BatDetect2Prediction:
+        clip_data = dataset
+        clip_id = clip_data.clip_id.item()

         recording = data.Recording.model_validate_json(
             clip_data.attrs["recording"]
         )
@@ -229,14 +227,10 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
                 )
             )

-            predictions.append(
-                BatDetect2Prediction(
-                    clip=clip,
-                    predictions=sound_events,
-                )
-            )
-
-        return predictions
+        return BatDetect2Prediction(
+            clip=clip,
+            predictions=sound_events,
+        )

     @prediction_formatters.register(RawOutputConfig)
     @staticmethod
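
Each per-clip file is a plain xarray Dataset, so it can also be inspected directly without the formatter. A sketch of what to expect, assuming the include_class_scores / include_features / include_geometry options are enabled; the filename is a placeholder:

    import xarray as xr

    ds = xr.load_dataset("predictions/<clip-uuid>.nc")  # placeholder path

    ds.clip_id.item()        # clip UUID, also used as the filename
    ds["score"]              # detection scores along the "detection" coordinate
    ds["top_class"]          # best class name per detection
    ds["class_scores"]       # (detection, classes) -- only if include_class_scores
    ds["features"]           # (detection, feature)  -- only if include_features
    ds["geometry"]           # JSON-encoded geometry -- only if include_geometry
    ds.attrs["recording"]    # recording serialized as JSON (used by pred_from_xr)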

View File

@@ -9,17 +9,17 @@ import soundfile as sf
 from scipy import signal
 from soundevent import data, terms

+from batdetect2.audio import build_audio_loader
+from batdetect2.audio.clips import build_clipper
 from batdetect2.data import DatasetConfig, load_dataset
 from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
 from batdetect2.preprocess import build_preprocessor
-from batdetect2.preprocess.audio import build_audio_loader
 from batdetect2.targets import (
     TargetConfig,
     build_targets,
     call_type,
 )
 from batdetect2.targets.classes import TargetClassConfig
-from batdetect2.train.clips import build_clipper
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.typing import (
     ClipLabeller,
@@ -442,7 +442,7 @@ def example_annotations(
 ) -> List[data.ClipAnnotation]:
     annotations = load_dataset(example_dataset)
     assert len(annotations) == 3
-    return annotations
+    return list(annotations)


 @pytest.fixture

View File

@@ -0,0 +1,64 @@
+from pathlib import Path
+
+import numpy as np
+import pytest
+from soundevent import data
+
+from batdetect2.data.predictions import RawOutputConfig, build_output_formatter
+from batdetect2.typing import (
+    BatDetect2Prediction,
+    RawPrediction,
+    TargetProtocol,
+)
+
+
+@pytest.fixture
+def sample_formatter(sample_targets: TargetProtocol):
+    return build_output_formatter(
+        config=RawOutputConfig(),
+        targets=sample_targets,
+    )
+
+
+def test_roundtrip(
+    sample_formatter,
+    clip: data.Clip,
+    sample_targets: TargetProtocol,
+    tmp_path: Path,
+):
+    detections = [
+        RawPrediction(
+            geometry=data.BoundingBox(
+                coordinates=list(np.random.uniform(size=[4]))
+            ),
+            detection_score=0.5,
+            class_scores=np.random.uniform(
+                size=len(sample_targets.class_names)
+            ),
+            features=np.random.uniform(size=32),
+        )
+        for _ in range(10)
+    ]
+
+    prediction = BatDetect2Prediction(clip=clip, predictions=detections)
+
+    path = tmp_path / "predictions"
+    sample_formatter.save(predictions=[prediction], path=path)
+
+    recovered = sample_formatter.load(path=path)
+
+    assert len(recovered) == 1
+    assert recovered[0].clip == prediction.clip
+
+    for recovered_prediction, detection in zip(
+        recovered[0].predictions, detections
+    ):
+        assert (
+            recovered_prediction.detection_score == detection.detection_score
+        )
+        assert (
+            recovered_prediction.class_scores == detection.class_scores
+        ).all()
+        assert (recovered_prediction.features == detection.features).all()
+        assert recovered_prediction.geometry == detection.geometry