Fix legacy import to use reproducible UUIDs

mbsantiago 2025-11-16 21:37:33 +00:00
parent 76503fbd12
commit 960b9a92e4
4 changed files with 72 additions and 28 deletions

View File

@@ -19,6 +19,7 @@ from batdetect2.data.predictions import (
     SoundEventOutputConfig,
     build_output_formatter,
     get_output_formatter,
+    load_predictions,
 )
 from batdetect2.data.summary import (
     compute_class_summary,
@@ -46,4 +47,5 @@ __all__ = [
     "load_dataset",
     "load_dataset_config",
     "load_dataset_from_config",
+    "load_predictions",
 ]

View File

@@ -18,6 +18,14 @@ UNKNOWN_CLASS = "__UNKNOWN__"
 NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
+CLIP_NAMESPACE = uuid.uuid5(NAMESPACE, "clip")
+CLIP_ANNOTATION_NAMESPACE = uuid.uuid5(NAMESPACE, "clip_annotation")
+RECORDING_NAMESPACE = uuid.uuid5(NAMESPACE, "recording")
+SOUND_EVENT_NAMESPACE = uuid.uuid5(NAMESPACE, "sound_event")
+SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5(
+    NAMESPACE, "sound_event_annotation"
+)

 EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
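
Each entity type now hashes into its own `uuid5` namespace derived from the module's root `NAMESPACE`. Since `uuid5` is a pure function of (namespace, name), the same input data yields the same UUID on every run and machine. A minimal sketch of that property (the recording hash is made up):

import uuid

NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
RECORDING_NAMESPACE = uuid.uuid5(NAMESPACE, "recording")

# uuid5 is deterministic: re-deriving the ID from the same hash
# always yields the same UUID, unlike the random uuid4.
first = uuid.uuid5(RECORDING_NAMESPACE, "example-recording-hash")
second = uuid.uuid5(RECORDING_NAMESPACE, "example-recording-hash")
assert first == second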
@@ -71,8 +79,8 @@ def annotation_to_sound_event(
     """Convert annotation to sound event annotation."""
     sound_event = data.SoundEvent(
         uuid=uuid.uuid5(
-            NAMESPACE,
-            f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
+            SOUND_EVENT_NAMESPACE,
+            f"{recording.uuid}_{annotation.start_time}_{annotation.end_time}",
         ),
         recording=recording,
         geometry=data.BoundingBox(
@@ -86,7 +94,10 @@
     )
     return data.SoundEventAnnotation(
-        uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
+        uuid=uuid.uuid5(
+            SOUND_EVENT_ANNOTATION_NAMESPACE,
+            f"{sound_event.uuid}",
+        ),
         sound_event=sound_event,
         tags=get_sound_event_tags(
             annotation, label_key, event_key, individual_key
@@ -139,12 +150,18 @@ def file_annotation_to_clip(
         time_expansion=file_annotation.time_exp,
         tags=tags,
     )
+    recording.uuid = uuid.uuid5(RECORDING_NAMESPACE, f"{recording.hash}")
+
+    start_time = 0
+    end_time = recording.duration
+
     return data.Clip(
-        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
+        uuid=uuid.uuid5(
+            CLIP_NAMESPACE,
+            f"{recording.uuid}_{start_time}_{end_time}",
+        ),
         recording=recording,
-        start_time=0,
-        end_time=recording.duration,
+        start_time=start_time,
+        end_time=end_time,
     )
@@ -165,7 +182,7 @@ def file_annotation_to_clip_annotation(
         tags.append(data.Tag(key=label_key, value=file_annotation.label))
     return data.ClipAnnotation(
-        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
+        uuid=uuid.uuid5(CLIP_ANNOTATION_NAMESPACE, f"{clip.uuid}"),
         clip=clip,
         notes=notes,
         tags=tags,
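
Seen end to end, these hunks chain the identifiers: the recording UUID is derived from the file hash, the clip UUID from the recording UUID plus its time span, and each annotation UUID from the entity it annotates. A sketch of the resulting chain (the helper function is illustrative, not part of the codebase):

import uuid

NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
RECORDING_NAMESPACE = uuid.uuid5(NAMESPACE, "recording")
CLIP_NAMESPACE = uuid.uuid5(NAMESPACE, "clip")
CLIP_ANNOTATION_NAMESPACE = uuid.uuid5(NAMESPACE, "clip_annotation")

def derive_chain(recording_hash: str, duration: float):
    # Each UUID depends only on the previous one plus stable data,
    # so re-importing the same file reproduces the whole chain.
    recording_id = uuid.uuid5(RECORDING_NAMESPACE, recording_hash)
    clip_id = uuid.uuid5(CLIP_NAMESPACE, f"{recording_id}_0_{duration}")
    clip_annotation_id = uuid.uuid5(CLIP_ANNOTATION_NAMESPACE, f"{clip_id}")
    return recording_id, clip_id, clip_annotation_id

assert derive_chain("example-hash", 3.0) == derive_chain("example-hash", 3.0)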

View File

@@ -1,6 +1,7 @@
 from typing import Annotated, Optional, Union

 from pydantic import Field
+from soundevent.data import PathLike

 from batdetect2.data.predictions.base import (
     OutputFormatterProtocol,
@@ -21,7 +22,11 @@ __all__ = [
 OutputFormatConfig = Annotated[
-    Union[BatDetect2OutputConfig, SoundEventOutputConfig, RawOutputConfig],
+    Union[
+        BatDetect2OutputConfig,
+        SoundEventOutputConfig,
+        RawOutputConfig,
+    ],
     Field(discriminator="name"),
 ]
@@ -40,13 +45,16 @@
 def get_output_formatter(
-    name: str,
+    name: Optional[str] = None,
     targets: Optional[TargetProtocol] = None,
     config: Optional[OutputFormatConfig] = None,
 ) -> OutputFormatterProtocol:
     """Get the output formatter by name."""
     if config is None:
+        if name is None:
+            raise ValueError("Either config or name must be provided.")
+
         config_class = prediction_formatters.get_config_type(name)
         config = config_class()  # type: ignore
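
With `name` now optional, the formatter can be resolved from either argument, and calling it with neither fails fast with a clear error. A usage sketch (`my_targets` stands in for a real `TargetProtocol`; constructing `RawOutputConfig()` with defaults is assumed to be valid, as `get_output_formatter` itself does the same):

# Resolve by name; a default config for that formatter is built.
formatter = get_output_formatter("raw", targets=my_targets)

# Or resolve from an explicit config; the name can be omitted.
formatter = get_output_formatter(config=RawOutputConfig(), targets=my_targets)

# With neither, the call raises:
# ValueError: Either config or name must be provided.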
@@ -56,3 +64,17 @@
     )
     return build_output_formatter(targets, config)
+
+
+def load_predictions(
+    path: PathLike,
+    format: Optional[str] = "raw",
+    config: Optional[OutputFormatConfig] = None,
+    targets: Optional[TargetProtocol] = None,
+):
+    """Load predictions from a file."""
+    from batdetect2.targets import build_targets
+
+    targets = targets or build_targets()
+    formatter = get_output_formatter(format, targets, config)
+    return formatter.load(path)
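
Combined with the re-export in the first file, loading saved predictions becomes a one-liner. A usage sketch (the path and extension are hypothetical; "raw" is the default format):

from batdetect2.data import load_predictions

# Targets and formatter are built automatically when not supplied.
predictions = load_predictions("predictions.zarr")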

View File

@@ -5,6 +5,7 @@ from uuid import UUID, uuid4
 import numpy as np
 import xarray as xr
+from loguru import logger
 from soundevent import data
 from soundevent.geometry import compute_bounds
@@ -36,11 +37,13 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
         include_class_scores: bool = True,
         include_features: bool = True,
         include_geometry: bool = True,
+        parse_full_geometry: bool = False,
     ):
         self.targets = targets
         self.include_class_scores = include_class_scores
         self.include_features = include_features
         self.include_geometry = include_geometry
+        self.parse_full_geometry = parse_full_geometry

     def format(
         self,
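
The new flag defaults to `False`, so existing callers keep the cheap bounding-box reconstruction; opting in makes loading parse and validate the full stored geometry instead. A construction sketch (keyword usage assumed; `my_targets` stands in for a real `TargetProtocol`):

formatter = RawFormatter(
    targets=my_targets,
    parse_full_geometry=True,  # parse stored geometry objects on load
)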
@@ -169,6 +172,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
         predictions: List[BatDetect2Prediction] = []
         for _, clip_data in root.items():
+            logger.debug(f"Loading clip {clip_data.clip_id.item()}")
             recording = data.Recording.model_validate_json(
                 clip_data.attrs["recording"]
             )
@@ -183,37 +187,36 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
         sound_events = []
-        for detection in clip_data.detection:
-            score = clip_data.score.sel(detection=detection).item()
+        for detection in clip_data.coords["detection"]:
+            detection_data = clip_data.sel(detection=detection)
+            score = detection_data.score.item()

-            if "geometry" in clip_data:
+            if "geometry" in clip_data and self.parse_full_geometry:
                 geometry = data.geometry_validate(
-                    clip_data.geometry.sel(detection=detection).item()
+                    detection_data.geometry.item()
                 )
             else:
-                start_time = clip_data.start_time.sel(detection=detection)
-                end_time = clip_data.end_time.sel(detection=detection)
-                low_freq = clip_data.low_freq.sel(detection=detection)
-                high_freq = clip_data.high_freq.sel(detection=detection)
-                geometry = data.BoundingBox(
+                start_time = detection_data.start_time
+                end_time = detection_data.end_time
+                low_freq = detection_data.low_freq
+                high_freq = detection_data.high_freq
+                geometry = data.BoundingBox.model_construct(
                     coordinates=[start_time, low_freq, end_time, high_freq]
                 )

-            if "class_scores" in clip_data:
-                class_scores = clip_data.class_scores.sel(
-                    detection=detection
-                ).data
+            if "class_scores" in detection_data:
+                class_scores = detection_data.class_scores.data
             else:
                 class_scores = np.zeros(len(self.targets.class_names))

             class_index = self.targets.class_names.index(
-                clip_data.top_class.sel(detection=detection).item()
+                detection_data.top_class.item()
             )
-            class_scores[class_index] = clip_data.top_class_score.sel(
-                detection=detection
-            ).item()
+            class_scores[class_index] = (
+                detection_data.top_class_score.item()
+            )

-            if "features" in clip_data:
-                features = clip_data.features.sel(detection=detection).data
+            if "features" in detection_data:
+                features = detection_data.features.data
             else:
                 features = np.zeros(0)
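
Two details of the rewritten loop are worth noting. Selecting `detection_data = clip_data.sel(detection=detection)` once replaces a repeated per-variable `.sel(...)`, and `BoundingBox.model_construct(...)` builds the pydantic model without running validation, which speeds up loading many detections but trusts that the stored coordinates are already well formed. A minimal illustration of that trade-off (coordinate values are made up):

from soundevent import data

# model_construct skips pydantic validation and coercion entirely:
# faster, but it accepts values the validating constructor would reject.
fast = data.BoundingBox.model_construct(
    coordinates=[0.1, 20000.0, 0.3, 80000.0]
)

# The regular constructor validates the geometry on the way in.
checked = data.BoundingBox(coordinates=[0.1, 20000.0, 0.3, 80000.0])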