mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 17:19:34 +01:00)
save predictions as individual nc files for speed
This commit is contained in:
parent bdb9e18964
commit dbd2d30ead
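
The change in a nutshell: instead of assembling every clip into a single xr.DataTree and writing one combined .nc file, the formatter now writes one netCDF file per clip into a directory and globs the files back on load. Below is a minimal, illustrative sketch of that layout; it is not batdetect2 API, and the directory name and dummy variables are placeholders.

# Illustrative sketch only: one netCDF file per clip in a directory,
# rather than a single combined file. Requires xarray with a netCDF backend.
from pathlib import Path

import numpy as np
import xarray as xr

out_dir = Path("predictions")          # placeholder output directory
out_dir.mkdir(parents=True, exist_ok=True)

for clip_id in ["clip-a", "clip-b"]:   # stand-ins for clip UUIDs
    dataset = xr.Dataset(
        data_vars={"score": (["detection"], np.random.uniform(size=5))},
        coords={"detection": ("detection", np.arange(5))},
        attrs={"clip_id": clip_id},
    )
    # Each clip is written independently, so saving N clips is N small writes
    # instead of one large rewrite of a growing tree.
    dataset.to_netcdf(out_dir / f"{clip_id}.nc")

# Loading globs the directory and reads each file back on its own.
recovered = [xr.load_dataset(p) for p in sorted(out_dir.glob("*.nc"))]
print([ds.attrs["clip_id"] for ds in recovered])

The actual RawFormatter changes, including the new pred_to_xr / pred_from_xr helpers, follow in the diff below.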
@@ -57,187 +57,181 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
         path: data.PathLike,
         audio_dir: Optional[data.PathLike] = None,
     ) -> None:
-        num_features = 0
-
-        tree = xr.DataTree()
-
-        for prediction in predictions:
-            clip = prediction.clip
-            recording = clip.recording
-
-            if audio_dir is not None:
-                recording = recording.model_copy(
-                    update=dict(
-                        path=make_path_relative(recording.path, audio_dir)
-                    )
-                )
-
-            clip_data = defaultdict(list)
-
-            for pred in prediction.predictions:
-                detection_id = str(uuid4())
-
-                clip_data["detection_id"].append(detection_id)
-                clip_data["detection_score"].append(pred.detection_score)
-
-                start_time, low_freq, end_time, high_freq = compute_bounds(
-                    pred.geometry
-                )
-
-                clip_data["start_time"].append(start_time)
-                clip_data["end_time"].append(end_time)
-                clip_data["low_freq"].append(low_freq)
-                clip_data["high_freq"].append(high_freq)
-
-                clip_data["geometry"].append(pred.geometry.model_dump_json())
-
-                top_class_index = int(np.argmax(pred.class_scores))
-                top_class_score = float(pred.class_scores[top_class_index])
-                top_class = self.targets.class_names[top_class_index]
-
-                clip_data["top_class"].append(top_class)
-                clip_data["top_class_score"].append(top_class_score)
-
-                clip_data["class_scores"].append(pred.class_scores)
-                clip_data["features"].append(pred.features)
-
-                num_features = len(pred.features)
-
-            data_vars = {
-                "score": (["detection"], clip_data["detection_score"]),
-                "start_time": (["detection"], clip_data["start_time"]),
-                "end_time": (["detection"], clip_data["end_time"]),
-                "low_freq": (["detection"], clip_data["low_freq"]),
-                "high_freq": (["detection"], clip_data["high_freq"]),
-                "top_class": (["detection"], clip_data["top_class"]),
-                "top_class_score": (
-                    ["detection"],
-                    clip_data["top_class_score"],
-                ),
-            }
-
-            coords = {
-                "detection": ("detection", clip_data["detection_id"]),
-                "clip_start": clip.start_time,
-                "clip_end": clip.end_time,
-                "clip_id": str(clip.uuid),
-            }
-
-            if self.include_class_scores:
-                data_vars["class_scores"] = (
-                    ["detection", "classes"],
-                    clip_data["class_scores"],
-                )
-                coords["classes"] = ("classes", self.targets.class_names)
-
-            if self.include_features:
-                data_vars["features"] = (
-                    ["detection", "feature"],
-                    clip_data["features"],
-                )
-                coords["feature"] = ("feature", np.arange(num_features))
-
-            if self.include_geometry:
-                data_vars["geometry"] = (["detection"], clip_data["geometry"])
-
-            dataset = xr.Dataset(
-                data_vars=data_vars,
-                coords=coords,
-                attrs={
-                    "recording": recording.model_dump_json(exclude_none=True),
-                },
-            )
-
-            tree = tree.assign(
-                {
-                    str(clip.uuid): xr.DataTree(
-                        dataset=dataset,
-                        name=str(clip.uuid),
-                    )
-                }
-            )
-
         path = Path(path)

-        if not path.suffix == ".nc":
-            path = Path(path).with_suffix(".nc")
+        if not path.exists():
+            path.mkdir(parents=True)

-        tree.to_netcdf(path)
+        for prediction in predictions:
+            logger.debug(f"Saving clip predictions {prediction.clip.uuid}")
+            clip = prediction.clip
+            dataset = self.pred_to_xr(prediction, audio_dir)
+            dataset.to_netcdf(path / f"{clip.uuid}.nc")

     def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
         path = Path(path)
-        root = xr.load_datatree(path)
+        files = list(path.glob("*.nc"))

         predictions: List[BatDetect2Prediction] = []

-        for _, clip_data in root.items():
-            logger.debug(f"Loading clip {clip_data.clip_id.item()}")
-            recording = data.Recording.model_validate_json(
-                clip_data.attrs["recording"]
-            )
-
-            clip_id = clip_data.clip_id.item()
-            clip = data.Clip(
-                recording=recording,
-                uuid=UUID(clip_id),
-                start_time=clip_data.clip_start,
-                end_time=clip_data.clip_end,
-            )
-
-            sound_events = []
-
-            for detection in clip_data.coords["detection"]:
-                detection_data = clip_data.sel(detection=detection)
-                score = detection_data.score.item()
-
-                if "geometry" in clip_data and self.parse_full_geometry:
-                    geometry = data.geometry_validate(
-                        detection_data.geometry.item()
-                    )
-                else:
-                    start_time = detection_data.start_time
-                    end_time = detection_data.end_time
-                    low_freq = detection_data.low_freq
-                    high_freq = detection_data.high_freq
-                    geometry = data.BoundingBox.model_construct(
-                        coordinates=[start_time, low_freq, end_time, high_freq]
-                    )
-
-                if "class_scores" in detection_data:
-                    class_scores = detection_data.class_scores.data
-                else:
-                    class_scores = np.zeros(len(self.targets.class_names))
-                    class_index = self.targets.class_names.index(
-                        detection_data.top_class.item()
-                    )
-                    class_scores[class_index] = (
-                        detection_data.top_class_score.item()
-                    )
-
-                if "features" in detection_data:
-                    features = detection_data.features.data
-                else:
-                    features = np.zeros(0)
-
-                sound_events.append(
-                    RawPrediction(
-                        geometry=geometry,
-                        detection_score=score,
-                        class_scores=class_scores,
-                        features=features,
-                    )
-                )
-
-            predictions.append(
-                BatDetect2Prediction(
-                    clip=clip,
-                    predictions=sound_events,
-                )
-            )
+        for filepath in files:
+            logger.debug(f"Loading clip predictions {filepath}")
+            clip_data = xr.load_dataset(filepath)
+            prediction = self.pred_from_xr(clip_data)
+            predictions.append(prediction)

         return predictions

+    def pred_to_xr(
+        self,
+        prediction: BatDetect2Prediction,
+        audio_dir: Optional[data.PathLike] = None,
+    ) -> xr.Dataset:
+        clip = prediction.clip
+        recording = clip.recording
+        num_features = 0
+
+        if audio_dir is not None:
+            recording = recording.model_copy(
+                update=dict(path=make_path_relative(recording.path, audio_dir))
+            )
+
+        data = defaultdict(list)
+
+        for pred in prediction.predictions:
+            detection_id = str(uuid4())
+
+            data["detection_id"].append(detection_id)
+            data["detection_score"].append(pred.detection_score)
+
+            start_time, low_freq, end_time, high_freq = compute_bounds(
+                pred.geometry
+            )
+
+            data["start_time"].append(start_time)
+            data["end_time"].append(end_time)
+            data["low_freq"].append(low_freq)
+            data["high_freq"].append(high_freq)
+
+            data["geometry"].append(pred.geometry.model_dump_json())
+
+            top_class_index = int(np.argmax(pred.class_scores))
+            top_class_score = float(pred.class_scores[top_class_index])
+            top_class = self.targets.class_names[top_class_index]
+
+            data["top_class"].append(top_class)
+            data["top_class_score"].append(top_class_score)
+
+            data["class_scores"].append(pred.class_scores)
+            data["features"].append(pred.features)
+
+            num_features = len(pred.features)
+
+        data_vars = {
+            "score": (["detection"], data["detection_score"]),
+            "start_time": (["detection"], data["start_time"]),
+            "end_time": (["detection"], data["end_time"]),
+            "low_freq": (["detection"], data["low_freq"]),
+            "high_freq": (["detection"], data["high_freq"]),
+            "top_class": (["detection"], data["top_class"]),
+            "top_class_score": (["detection"], data["top_class_score"]),
+        }
+
+        coords = {
+            "detection": ("detection", data["detection_id"]),
+            "clip_start": clip.start_time,
+            "clip_end": clip.end_time,
+            "clip_id": str(clip.uuid),
+        }
+
+        if self.include_class_scores:
+            class_scores = np.stack(data["class_scores"], axis=0)
+            data_vars["class_scores"] = (
+                ["detection", "classes"],
+                class_scores,
+            )
+            coords["classes"] = ("classes", self.targets.class_names)
+
+        if self.include_features:
+            features = np.stack(data["features"], axis=0)
+            data_vars["features"] = (["detection", "feature"], features)
+            coords["feature"] = ("feature", np.arange(num_features))
+
+        if self.include_geometry:
+            data_vars["geometry"] = (["detection"], data["geometry"])
+
+        return xr.Dataset(
+            data_vars=data_vars,
+            coords=coords,
+            attrs={
+                "recording": recording.model_dump_json(exclude_none=True),
+            },
+        )
+
+    def pred_from_xr(self, dataset: xr.Dataset) -> BatDetect2Prediction:
+        clip_data = dataset
+        clip_id = clip_data.clip_id.item()
+
+        recording = data.Recording.model_validate_json(
+            clip_data.attrs["recording"]
+        )
+
+        clip_id = clip_data.clip_id.item()
+        clip = data.Clip(
+            recording=recording,
+            uuid=UUID(clip_id),
+            start_time=clip_data.clip_start,
+            end_time=clip_data.clip_end,
+        )
+
+        sound_events = []
+
+        for detection in clip_data.coords["detection"]:
+            detection_data = clip_data.sel(detection=detection)
+            score = detection_data.score.item()
+
+            if "geometry" in clip_data and self.parse_full_geometry:
+                geometry = data.geometry_validate(
+                    detection_data.geometry.item()
+                )
+            else:
+                start_time = detection_data.start_time
+                end_time = detection_data.end_time
+                low_freq = detection_data.low_freq
+                high_freq = detection_data.high_freq
+                geometry = data.BoundingBox.model_construct(
+                    coordinates=[start_time, low_freq, end_time, high_freq]
+                )
+
+            if "class_scores" in detection_data:
+                class_scores = detection_data.class_scores.data
+            else:
+                class_scores = np.zeros(len(self.targets.class_names))
+                class_index = self.targets.class_names.index(
+                    detection_data.top_class.item()
+                )
+                class_scores[class_index] = (
+                    detection_data.top_class_score.item()
+                )
+
+            if "features" in detection_data:
+                features = detection_data.features.data
+            else:
+                features = np.zeros(0)
+
+            sound_events.append(
+                RawPrediction(
+                    geometry=geometry,
+                    detection_score=score,
+                    class_scores=class_scores,
+                    features=features,
+                )
+            )
+
+        return BatDetect2Prediction(
+            clip=clip,
+            predictions=sound_events,
+        )
+
     @prediction_formatters.register(RawOutputConfig)
     @staticmethod
     def from_config(config: RawOutputConfig, targets: TargetProtocol):
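
Because each clip now lives in its own file, a saved prediction can also be inspected with plain xarray without going through the formatter. A hedged quick-look sketch: the directory path is a placeholder, and the variable names follow the data_vars written by pred_to_xr above.

# Quick-look sketch, assuming a directory previously passed to RawFormatter.save.
from pathlib import Path

import xarray as xr

pred_dir = Path("predictions")             # placeholder: directory of <clip_uuid>.nc files
first_file = next(pred_dir.glob("*.nc"))   # pick any one clip file

ds = xr.load_dataset(first_file)
print(ds.clip_id.item(), "detections:", ds.sizes["detection"])
print(ds[["score", "start_time", "end_time", "top_class"]].to_dataframe())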
@@ -9,17 +9,17 @@ import soundfile as sf
 from scipy import signal
 from soundevent import data, terms

+from batdetect2.audio import build_audio_loader
+from batdetect2.audio.clips import build_clipper
 from batdetect2.data import DatasetConfig, load_dataset
 from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
 from batdetect2.preprocess import build_preprocessor
-from batdetect2.preprocess.audio import build_audio_loader
 from batdetect2.targets import (
     TargetConfig,
     build_targets,
     call_type,
 )
 from batdetect2.targets.classes import TargetClassConfig
-from batdetect2.train.clips import build_clipper
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.typing import (
     ClipLabeller,
@@ -442,7 +442,7 @@ def example_annotations(
 ) -> List[data.ClipAnnotation]:
     annotations = load_dataset(example_dataset)
     assert len(annotations) == 3
-    return annotations
+    return list(annotations)


 @pytest.fixture
tests/test_data/test_predictions/__init__.py (new file, 0 lines)
tests/test_data/test_predictions/test_raw.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+from pathlib import Path
+
+import numpy as np
+import pytest
+from soundevent import data
+
+from batdetect2.data.predictions import RawOutputConfig, build_output_formatter
+from batdetect2.typing import (
+    BatDetect2Prediction,
+    RawPrediction,
+    TargetProtocol,
+)
+
+
+@pytest.fixture
+def sample_formatter(sample_targets: TargetProtocol):
+    return build_output_formatter(
+        config=RawOutputConfig(),
+        targets=sample_targets,
+    )
+
+
+def test_roundtrip(
+    sample_formatter,
+    clip: data.Clip,
+    sample_targets: TargetProtocol,
+    tmp_path: Path,
+):
+    detections = [
+        RawPrediction(
+            geometry=data.BoundingBox(
+                coordinates=list(np.random.uniform(size=[4]))
+            ),
+            detection_score=0.5,
+            class_scores=np.random.uniform(
+                size=len(sample_targets.class_names)
+            ),
+            features=np.random.uniform(size=32),
+        )
+        for _ in range(10)
+    ]
+
+    prediction = BatDetect2Prediction(clip=clip, predictions=detections)
+
+    path = tmp_path / "predictions"
+
+    sample_formatter.save(predictions=[prediction], path=path)
+
+    recovered = sample_formatter.load(path=path)
+
+    assert len(recovered) == 1
+    assert recovered[0].clip == prediction.clip
+
+    for recovered_prediction, detection in zip(
+        recovered[0].predictions, detections
+    ):
+        assert (
+            recovered_prediction.detection_score == detection.detection_score
+        )
+        assert (
+            recovered_prediction.class_scores == detection.class_scores
+        ).all()
+        assert (recovered_prediction.features == detection.features).all()
+        assert recovered_prediction.geometry == detection.geometry