mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
save predictions as individual nc files for speed
This commit is contained in:
parent
bdb9e18964
commit
dbd2d30ead
@ -57,187 +57,181 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
path: data.PathLike,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> None:
|
||||
num_features = 0
|
||||
|
||||
tree = xr.DataTree()
|
||||
|
||||
for prediction in predictions:
|
||||
clip = prediction.clip
|
||||
recording = clip.recording
|
||||
|
||||
if audio_dir is not None:
|
||||
recording = recording.model_copy(
|
||||
update=dict(
|
||||
path=make_path_relative(recording.path, audio_dir)
|
||||
)
|
||||
)
|
||||
|
||||
clip_data = defaultdict(list)
|
||||
|
||||
for pred in prediction.predictions:
|
||||
detection_id = str(uuid4())
|
||||
|
||||
clip_data["detection_id"].append(detection_id)
|
||||
clip_data["detection_score"].append(pred.detection_score)
|
||||
|
||||
start_time, low_freq, end_time, high_freq = compute_bounds(
|
||||
pred.geometry
|
||||
)
|
||||
|
||||
clip_data["start_time"].append(start_time)
|
||||
clip_data["end_time"].append(end_time)
|
||||
clip_data["low_freq"].append(low_freq)
|
||||
clip_data["high_freq"].append(high_freq)
|
||||
|
||||
clip_data["geometry"].append(pred.geometry.model_dump_json())
|
||||
|
||||
top_class_index = int(np.argmax(pred.class_scores))
|
||||
top_class_score = float(pred.class_scores[top_class_index])
|
||||
top_class = self.targets.class_names[top_class_index]
|
||||
|
||||
clip_data["top_class"].append(top_class)
|
||||
clip_data["top_class_score"].append(top_class_score)
|
||||
|
||||
clip_data["class_scores"].append(pred.class_scores)
|
||||
clip_data["features"].append(pred.features)
|
||||
|
||||
num_features = len(pred.features)
|
||||
|
||||
data_vars = {
|
||||
"score": (["detection"], clip_data["detection_score"]),
|
||||
"start_time": (["detection"], clip_data["start_time"]),
|
||||
"end_time": (["detection"], clip_data["end_time"]),
|
||||
"low_freq": (["detection"], clip_data["low_freq"]),
|
||||
"high_freq": (["detection"], clip_data["high_freq"]),
|
||||
"top_class": (["detection"], clip_data["top_class"]),
|
||||
"top_class_score": (
|
||||
["detection"],
|
||||
clip_data["top_class_score"],
|
||||
),
|
||||
}
|
||||
|
||||
coords = {
|
||||
"detection": ("detection", clip_data["detection_id"]),
|
||||
"clip_start": clip.start_time,
|
||||
"clip_end": clip.end_time,
|
||||
"clip_id": str(clip.uuid),
|
||||
}
|
||||
|
||||
if self.include_class_scores:
|
||||
data_vars["class_scores"] = (
|
||||
["detection", "classes"],
|
||||
clip_data["class_scores"],
|
||||
)
|
||||
coords["classes"] = ("classes", self.targets.class_names)
|
||||
|
||||
if self.include_features:
|
||||
data_vars["features"] = (
|
||||
["detection", "feature"],
|
||||
clip_data["features"],
|
||||
)
|
||||
coords["feature"] = ("feature", np.arange(num_features))
|
||||
|
||||
if self.include_geometry:
|
||||
data_vars["geometry"] = (["detection"], clip_data["geometry"])
|
||||
|
||||
dataset = xr.Dataset(
|
||||
data_vars=data_vars,
|
||||
coords=coords,
|
||||
attrs={
|
||||
"recording": recording.model_dump_json(exclude_none=True),
|
||||
},
|
||||
)
|
||||
|
||||
tree = tree.assign(
|
||||
{
|
||||
str(clip.uuid): xr.DataTree(
|
||||
dataset=dataset,
|
||||
name=str(clip.uuid),
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
path = Path(path)
|
||||
|
||||
if not path.suffix == ".nc":
|
||||
path = Path(path).with_suffix(".nc")
|
||||
if not path.exists():
|
||||
path.mkdir(parents=True)
|
||||
|
||||
tree.to_netcdf(path)
|
||||
for prediction in predictions:
|
||||
logger.debug(f"Saving clip predictions {prediction.clip.uuid}")
|
||||
clip = prediction.clip
|
||||
dataset = self.pred_to_xr(prediction, audio_dir)
|
||||
dataset.to_netcdf(path / f"{clip.uuid}.nc")
|
||||
|
||||
def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
|
||||
path = Path(path)
|
||||
|
||||
root = xr.load_datatree(path)
|
||||
|
||||
files = list(path.glob("*.nc"))
|
||||
predictions: List[BatDetect2Prediction] = []
|
||||
|
||||
for _, clip_data in root.items():
|
||||
logger.debug(f"Loading clip {clip_data.clip_id.item()}")
|
||||
recording = data.Recording.model_validate_json(
|
||||
clip_data.attrs["recording"]
|
||||
)
|
||||
|
||||
clip_id = clip_data.clip_id.item()
|
||||
clip = data.Clip(
|
||||
recording=recording,
|
||||
uuid=UUID(clip_id),
|
||||
start_time=clip_data.clip_start,
|
||||
end_time=clip_data.clip_end,
|
||||
)
|
||||
|
||||
sound_events = []
|
||||
|
||||
for detection in clip_data.coords["detection"]:
|
||||
detection_data = clip_data.sel(detection=detection)
|
||||
score = detection_data.score.item()
|
||||
|
||||
if "geometry" in clip_data and self.parse_full_geometry:
|
||||
geometry = data.geometry_validate(
|
||||
detection_data.geometry.item()
|
||||
)
|
||||
else:
|
||||
start_time = detection_data.start_time
|
||||
end_time = detection_data.end_time
|
||||
low_freq = detection_data.low_freq
|
||||
high_freq = detection_data.high_freq
|
||||
geometry = data.BoundingBox.model_construct(
|
||||
coordinates=[start_time, low_freq, end_time, high_freq]
|
||||
)
|
||||
|
||||
if "class_scores" in detection_data:
|
||||
class_scores = detection_data.class_scores.data
|
||||
else:
|
||||
class_scores = np.zeros(len(self.targets.class_names))
|
||||
class_index = self.targets.class_names.index(
|
||||
detection_data.top_class.item()
|
||||
)
|
||||
class_scores[class_index] = (
|
||||
detection_data.top_class_score.item()
|
||||
)
|
||||
|
||||
if "features" in detection_data:
|
||||
features = detection_data.features.data
|
||||
else:
|
||||
features = np.zeros(0)
|
||||
|
||||
sound_events.append(
|
||||
RawPrediction(
|
||||
geometry=geometry,
|
||||
detection_score=score,
|
||||
class_scores=class_scores,
|
||||
features=features,
|
||||
)
|
||||
)
|
||||
|
||||
predictions.append(
|
||||
BatDetect2Prediction(
|
||||
clip=clip,
|
||||
predictions=sound_events,
|
||||
)
|
||||
)
|
||||
for filepath in files:
|
||||
logger.debug(f"Loading clip predictions {filepath}")
|
||||
clip_data = xr.load_dataset(filepath)
|
||||
prediction = self.pred_from_xr(clip_data)
|
||||
predictions.append(prediction)
|
||||
|
||||
return predictions
|
||||
|
||||
def pred_to_xr(
|
||||
self,
|
||||
prediction: BatDetect2Prediction,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
) -> xr.Dataset:
|
||||
clip = prediction.clip
|
||||
recording = clip.recording
|
||||
num_features = 0
|
||||
|
||||
if audio_dir is not None:
|
||||
recording = recording.model_copy(
|
||||
update=dict(path=make_path_relative(recording.path, audio_dir))
|
||||
)
|
||||
|
||||
data = defaultdict(list)
|
||||
|
||||
for pred in prediction.predictions:
|
||||
detection_id = str(uuid4())
|
||||
|
||||
data["detection_id"].append(detection_id)
|
||||
data["detection_score"].append(pred.detection_score)
|
||||
|
||||
start_time, low_freq, end_time, high_freq = compute_bounds(
|
||||
pred.geometry
|
||||
)
|
||||
|
||||
data["start_time"].append(start_time)
|
||||
data["end_time"].append(end_time)
|
||||
data["low_freq"].append(low_freq)
|
||||
data["high_freq"].append(high_freq)
|
||||
|
||||
data["geometry"].append(pred.geometry.model_dump_json())
|
||||
|
||||
top_class_index = int(np.argmax(pred.class_scores))
|
||||
top_class_score = float(pred.class_scores[top_class_index])
|
||||
top_class = self.targets.class_names[top_class_index]
|
||||
|
||||
data["top_class"].append(top_class)
|
||||
data["top_class_score"].append(top_class_score)
|
||||
|
||||
data["class_scores"].append(pred.class_scores)
|
||||
data["features"].append(pred.features)
|
||||
|
||||
num_features = len(pred.features)
|
||||
|
||||
data_vars = {
|
||||
"score": (["detection"], data["detection_score"]),
|
||||
"start_time": (["detection"], data["start_time"]),
|
||||
"end_time": (["detection"], data["end_time"]),
|
||||
"low_freq": (["detection"], data["low_freq"]),
|
||||
"high_freq": (["detection"], data["high_freq"]),
|
||||
"top_class": (["detection"], data["top_class"]),
|
||||
"top_class_score": (["detection"], data["top_class_score"]),
|
||||
}
|
||||
|
||||
coords = {
|
||||
"detection": ("detection", data["detection_id"]),
|
||||
"clip_start": clip.start_time,
|
||||
"clip_end": clip.end_time,
|
||||
"clip_id": str(clip.uuid),
|
||||
}
|
||||
|
||||
if self.include_class_scores:
|
||||
class_scores = np.stack(data["class_scores"], axis=0)
|
||||
data_vars["class_scores"] = (
|
||||
["detection", "classes"],
|
||||
class_scores,
|
||||
)
|
||||
coords["classes"] = ("classes", self.targets.class_names)
|
||||
|
||||
if self.include_features:
|
||||
features = np.stack(data["features"], axis=0)
|
||||
data_vars["features"] = (["detection", "feature"], features)
|
||||
coords["feature"] = ("feature", np.arange(num_features))
|
||||
|
||||
if self.include_geometry:
|
||||
data_vars["geometry"] = (["detection"], data["geometry"])
|
||||
|
||||
return xr.Dataset(
|
||||
data_vars=data_vars,
|
||||
coords=coords,
|
||||
attrs={
|
||||
"recording": recording.model_dump_json(exclude_none=True),
|
||||
},
|
||||
)
|
||||
|
||||
def pred_from_xr(self, dataset: xr.Dataset) -> BatDetect2Prediction:
|
||||
clip_data = dataset
|
||||
clip_id = clip_data.clip_id.item()
|
||||
|
||||
recording = data.Recording.model_validate_json(
|
||||
clip_data.attrs["recording"]
|
||||
)
|
||||
|
||||
clip_id = clip_data.clip_id.item()
|
||||
clip = data.Clip(
|
||||
recording=recording,
|
||||
uuid=UUID(clip_id),
|
||||
start_time=clip_data.clip_start,
|
||||
end_time=clip_data.clip_end,
|
||||
)
|
||||
|
||||
sound_events = []
|
||||
|
||||
for detection in clip_data.coords["detection"]:
|
||||
detection_data = clip_data.sel(detection=detection)
|
||||
score = detection_data.score.item()
|
||||
|
||||
if "geometry" in clip_data and self.parse_full_geometry:
|
||||
geometry = data.geometry_validate(
|
||||
detection_data.geometry.item()
|
||||
)
|
||||
else:
|
||||
start_time = detection_data.start_time
|
||||
end_time = detection_data.end_time
|
||||
low_freq = detection_data.low_freq
|
||||
high_freq = detection_data.high_freq
|
||||
geometry = data.BoundingBox.model_construct(
|
||||
coordinates=[start_time, low_freq, end_time, high_freq]
|
||||
)
|
||||
|
||||
if "class_scores" in detection_data:
|
||||
class_scores = detection_data.class_scores.data
|
||||
else:
|
||||
class_scores = np.zeros(len(self.targets.class_names))
|
||||
class_index = self.targets.class_names.index(
|
||||
detection_data.top_class.item()
|
||||
)
|
||||
class_scores[class_index] = (
|
||||
detection_data.top_class_score.item()
|
||||
)
|
||||
|
||||
if "features" in detection_data:
|
||||
features = detection_data.features.data
|
||||
else:
|
||||
features = np.zeros(0)
|
||||
|
||||
sound_events.append(
|
||||
RawPrediction(
|
||||
geometry=geometry,
|
||||
detection_score=score,
|
||||
class_scores=class_scores,
|
||||
features=features,
|
||||
)
|
||||
)
|
||||
|
||||
return BatDetect2Prediction(
|
||||
clip=clip,
|
||||
predictions=sound_events,
|
||||
)
|
||||
|
||||
@prediction_formatters.register(RawOutputConfig)
|
||||
@staticmethod
|
||||
def from_config(config: RawOutputConfig, targets: TargetProtocol):
|
||||
|
||||
@ -9,17 +9,17 @@ import soundfile as sf
|
||||
from scipy import signal
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.audio.clips import build_clipper
|
||||
from batdetect2.data import DatasetConfig, load_dataset
|
||||
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.audio import build_audio_loader
|
||||
from batdetect2.targets import (
|
||||
TargetConfig,
|
||||
build_targets,
|
||||
call_type,
|
||||
)
|
||||
from batdetect2.targets.classes import TargetClassConfig
|
||||
from batdetect2.train.clips import build_clipper
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.typing import (
|
||||
ClipLabeller,
|
||||
@ -442,7 +442,7 @@ def example_annotations(
|
||||
) -> List[data.ClipAnnotation]:
|
||||
annotations = load_dataset(example_dataset)
|
||||
assert len(annotations) == 3
|
||||
return annotations
|
||||
return list(annotations)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
0
tests/test_data/test_predictions/__init__.py
Normal file
0
tests/test_data/test_predictions/__init__.py
Normal file
64
tests/test_data/test_predictions/test_raw.py
Normal file
64
tests/test_data/test_predictions/test_raw.py
Normal file
@ -0,0 +1,64 @@
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.predictions import RawOutputConfig, build_output_formatter
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
RawPrediction,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_formatter(sample_targets: TargetProtocol):
|
||||
return build_output_formatter(
|
||||
config=RawOutputConfig(),
|
||||
targets=sample_targets,
|
||||
)
|
||||
|
||||
|
||||
def test_roundtrip(
|
||||
sample_formatter,
|
||||
clip: data.Clip,
|
||||
sample_targets: TargetProtocol,
|
||||
tmp_path: Path,
|
||||
):
|
||||
detections = [
|
||||
RawPrediction(
|
||||
geometry=data.BoundingBox(
|
||||
coordinates=list(np.random.uniform(size=[4]))
|
||||
),
|
||||
detection_score=0.5,
|
||||
class_scores=np.random.uniform(
|
||||
size=len(sample_targets.class_names)
|
||||
),
|
||||
features=np.random.uniform(size=32),
|
||||
)
|
||||
for _ in range(10)
|
||||
]
|
||||
|
||||
prediction = BatDetect2Prediction(clip=clip, predictions=detections)
|
||||
|
||||
path = tmp_path / "predictions"
|
||||
|
||||
sample_formatter.save(predictions=[prediction], path=path)
|
||||
|
||||
recovered = sample_formatter.load(path=path)
|
||||
|
||||
assert len(recovered) == 1
|
||||
assert recovered[0].clip == prediction.clip
|
||||
|
||||
for recovered_prediction, detection in zip(
|
||||
recovered[0].predictions, detections
|
||||
):
|
||||
assert (
|
||||
recovered_prediction.detection_score == detection.detection_score
|
||||
)
|
||||
assert (
|
||||
recovered_prediction.class_scores == detection.class_scores
|
||||
).all()
|
||||
assert (recovered_prediction.features == detection.features).all()
|
||||
assert recovered_prediction.geometry == detection.geometry
|
||||
Loading…
Reference in New Issue
Block a user