save predictions as individual nc files for speed

This commit is contained in:
mbsantiago 2025-11-19 13:54:57 +00:00
parent bdb9e18964
commit dbd2d30ead
4 changed files with 231 additions and 173 deletions

View File

@ -57,187 +57,181 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> None:
num_features = 0
tree = xr.DataTree()
for prediction in predictions:
clip = prediction.clip
recording = clip.recording
if audio_dir is not None:
recording = recording.model_copy(
update=dict(
path=make_path_relative(recording.path, audio_dir)
)
)
clip_data = defaultdict(list)
for pred in prediction.predictions:
detection_id = str(uuid4())
clip_data["detection_id"].append(detection_id)
clip_data["detection_score"].append(pred.detection_score)
start_time, low_freq, end_time, high_freq = compute_bounds(
pred.geometry
)
clip_data["start_time"].append(start_time)
clip_data["end_time"].append(end_time)
clip_data["low_freq"].append(low_freq)
clip_data["high_freq"].append(high_freq)
clip_data["geometry"].append(pred.geometry.model_dump_json())
top_class_index = int(np.argmax(pred.class_scores))
top_class_score = float(pred.class_scores[top_class_index])
top_class = self.targets.class_names[top_class_index]
clip_data["top_class"].append(top_class)
clip_data["top_class_score"].append(top_class_score)
clip_data["class_scores"].append(pred.class_scores)
clip_data["features"].append(pred.features)
num_features = len(pred.features)
data_vars = {
"score": (["detection"], clip_data["detection_score"]),
"start_time": (["detection"], clip_data["start_time"]),
"end_time": (["detection"], clip_data["end_time"]),
"low_freq": (["detection"], clip_data["low_freq"]),
"high_freq": (["detection"], clip_data["high_freq"]),
"top_class": (["detection"], clip_data["top_class"]),
"top_class_score": (
["detection"],
clip_data["top_class_score"],
),
}
coords = {
"detection": ("detection", clip_data["detection_id"]),
"clip_start": clip.start_time,
"clip_end": clip.end_time,
"clip_id": str(clip.uuid),
}
if self.include_class_scores:
data_vars["class_scores"] = (
["detection", "classes"],
clip_data["class_scores"],
)
coords["classes"] = ("classes", self.targets.class_names)
if self.include_features:
data_vars["features"] = (
["detection", "feature"],
clip_data["features"],
)
coords["feature"] = ("feature", np.arange(num_features))
if self.include_geometry:
data_vars["geometry"] = (["detection"], clip_data["geometry"])
dataset = xr.Dataset(
data_vars=data_vars,
coords=coords,
attrs={
"recording": recording.model_dump_json(exclude_none=True),
},
)
tree = tree.assign(
{
str(clip.uuid): xr.DataTree(
dataset=dataset,
name=str(clip.uuid),
)
}
)
path = Path(path)
if not path.suffix == ".nc":
path = Path(path).with_suffix(".nc")
if not path.exists():
path.mkdir(parents=True)
tree.to_netcdf(path)
for prediction in predictions:
logger.debug(f"Saving clip predictions {prediction.clip.uuid}")
clip = prediction.clip
dataset = self.pred_to_xr(prediction, audio_dir)
dataset.to_netcdf(path / f"{clip.uuid}.nc")
def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
path = Path(path)
root = xr.load_datatree(path)
files = list(path.glob("*.nc"))
predictions: List[BatDetect2Prediction] = []
for _, clip_data in root.items():
logger.debug(f"Loading clip {clip_data.clip_id.item()}")
recording = data.Recording.model_validate_json(
clip_data.attrs["recording"]
)
clip_id = clip_data.clip_id.item()
clip = data.Clip(
recording=recording,
uuid=UUID(clip_id),
start_time=clip_data.clip_start,
end_time=clip_data.clip_end,
)
sound_events = []
for detection in clip_data.coords["detection"]:
detection_data = clip_data.sel(detection=detection)
score = detection_data.score.item()
if "geometry" in clip_data and self.parse_full_geometry:
geometry = data.geometry_validate(
detection_data.geometry.item()
)
else:
start_time = detection_data.start_time
end_time = detection_data.end_time
low_freq = detection_data.low_freq
high_freq = detection_data.high_freq
geometry = data.BoundingBox.model_construct(
coordinates=[start_time, low_freq, end_time, high_freq]
)
if "class_scores" in detection_data:
class_scores = detection_data.class_scores.data
else:
class_scores = np.zeros(len(self.targets.class_names))
class_index = self.targets.class_names.index(
detection_data.top_class.item()
)
class_scores[class_index] = (
detection_data.top_class_score.item()
)
if "features" in detection_data:
features = detection_data.features.data
else:
features = np.zeros(0)
sound_events.append(
RawPrediction(
geometry=geometry,
detection_score=score,
class_scores=class_scores,
features=features,
)
)
predictions.append(
BatDetect2Prediction(
clip=clip,
predictions=sound_events,
)
)
for filepath in files:
logger.debug(f"Loading clip predictions {filepath}")
clip_data = xr.load_dataset(filepath)
prediction = self.pred_from_xr(clip_data)
predictions.append(prediction)
return predictions
def pred_to_xr(
self,
prediction: BatDetect2Prediction,
audio_dir: Optional[data.PathLike] = None,
) -> xr.Dataset:
clip = prediction.clip
recording = clip.recording
num_features = 0
if audio_dir is not None:
recording = recording.model_copy(
update=dict(path=make_path_relative(recording.path, audio_dir))
)
data = defaultdict(list)
for pred in prediction.predictions:
detection_id = str(uuid4())
data["detection_id"].append(detection_id)
data["detection_score"].append(pred.detection_score)
start_time, low_freq, end_time, high_freq = compute_bounds(
pred.geometry
)
data["start_time"].append(start_time)
data["end_time"].append(end_time)
data["low_freq"].append(low_freq)
data["high_freq"].append(high_freq)
data["geometry"].append(pred.geometry.model_dump_json())
top_class_index = int(np.argmax(pred.class_scores))
top_class_score = float(pred.class_scores[top_class_index])
top_class = self.targets.class_names[top_class_index]
data["top_class"].append(top_class)
data["top_class_score"].append(top_class_score)
data["class_scores"].append(pred.class_scores)
data["features"].append(pred.features)
num_features = len(pred.features)
data_vars = {
"score": (["detection"], data["detection_score"]),
"start_time": (["detection"], data["start_time"]),
"end_time": (["detection"], data["end_time"]),
"low_freq": (["detection"], data["low_freq"]),
"high_freq": (["detection"], data["high_freq"]),
"top_class": (["detection"], data["top_class"]),
"top_class_score": (["detection"], data["top_class_score"]),
}
coords = {
"detection": ("detection", data["detection_id"]),
"clip_start": clip.start_time,
"clip_end": clip.end_time,
"clip_id": str(clip.uuid),
}
if self.include_class_scores:
class_scores = np.stack(data["class_scores"], axis=0)
data_vars["class_scores"] = (
["detection", "classes"],
class_scores,
)
coords["classes"] = ("classes", self.targets.class_names)
if self.include_features:
features = np.stack(data["features"], axis=0)
data_vars["features"] = (["detection", "feature"], features)
coords["feature"] = ("feature", np.arange(num_features))
if self.include_geometry:
data_vars["geometry"] = (["detection"], data["geometry"])
return xr.Dataset(
data_vars=data_vars,
coords=coords,
attrs={
"recording": recording.model_dump_json(exclude_none=True),
},
)
def pred_from_xr(self, dataset: xr.Dataset) -> BatDetect2Prediction:
clip_data = dataset
clip_id = clip_data.clip_id.item()
recording = data.Recording.model_validate_json(
clip_data.attrs["recording"]
)
clip_id = clip_data.clip_id.item()
clip = data.Clip(
recording=recording,
uuid=UUID(clip_id),
start_time=clip_data.clip_start,
end_time=clip_data.clip_end,
)
sound_events = []
for detection in clip_data.coords["detection"]:
detection_data = clip_data.sel(detection=detection)
score = detection_data.score.item()
if "geometry" in clip_data and self.parse_full_geometry:
geometry = data.geometry_validate(
detection_data.geometry.item()
)
else:
start_time = detection_data.start_time
end_time = detection_data.end_time
low_freq = detection_data.low_freq
high_freq = detection_data.high_freq
geometry = data.BoundingBox.model_construct(
coordinates=[start_time, low_freq, end_time, high_freq]
)
if "class_scores" in detection_data:
class_scores = detection_data.class_scores.data
else:
class_scores = np.zeros(len(self.targets.class_names))
class_index = self.targets.class_names.index(
detection_data.top_class.item()
)
class_scores[class_index] = (
detection_data.top_class_score.item()
)
if "features" in detection_data:
features = detection_data.features.data
else:
features = np.zeros(0)
sound_events.append(
RawPrediction(
geometry=geometry,
detection_score=score,
class_scores=class_scores,
features=features,
)
)
return BatDetect2Prediction(
clip=clip,
predictions=sound_events,
)
@prediction_formatters.register(RawOutputConfig)
@staticmethod
def from_config(config: RawOutputConfig, targets: TargetProtocol):

View File

@ -9,17 +9,17 @@ import soundfile as sf
from scipy import signal
from soundevent import data, terms
from batdetect2.audio import build_audio_loader
from batdetect2.audio.clips import build_clipper
from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.targets import (
TargetConfig,
build_targets,
call_type,
)
from batdetect2.targets.classes import TargetClassConfig
from batdetect2.train.clips import build_clipper
from batdetect2.train.labels import build_clip_labeler
from batdetect2.typing import (
ClipLabeller,
@ -442,7 +442,7 @@ def example_annotations(
) -> List[data.ClipAnnotation]:
annotations = load_dataset(example_dataset)
assert len(annotations) == 3
return annotations
return list(annotations)
@pytest.fixture

View File

@ -0,0 +1,64 @@
from pathlib import Path
import numpy as np
import pytest
from soundevent import data
from batdetect2.data.predictions import RawOutputConfig, build_output_formatter
from batdetect2.typing import (
BatDetect2Prediction,
RawPrediction,
TargetProtocol,
)
@pytest.fixture
def sample_formatter(sample_targets: TargetProtocol):
return build_output_formatter(
config=RawOutputConfig(),
targets=sample_targets,
)
def test_roundtrip(
sample_formatter,
clip: data.Clip,
sample_targets: TargetProtocol,
tmp_path: Path,
):
detections = [
RawPrediction(
geometry=data.BoundingBox(
coordinates=list(np.random.uniform(size=[4]))
),
detection_score=0.5,
class_scores=np.random.uniform(
size=len(sample_targets.class_names)
),
features=np.random.uniform(size=32),
)
for _ in range(10)
]
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
path = tmp_path / "predictions"
sample_formatter.save(predictions=[prediction], path=path)
recovered = sample_formatter.load(path=path)
assert len(recovered) == 1
assert recovered[0].clip == prediction.clip
for recovered_prediction, detection in zip(
recovered[0].predictions, detections
):
assert (
recovered_prediction.detection_score == detection.detection_score
)
assert (
recovered_prediction.class_scores == detection.class_scores
).all()
assert (recovered_prediction.features == detection.features).all()
assert recovered_prediction.geometry == detection.geometry