mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Fix legacy import to use reproducible UUIDs
This commit is contained in:
parent
76503fbd12
commit
960b9a92e4
@ -19,6 +19,7 @@ from batdetect2.data.predictions import (
|
||||
SoundEventOutputConfig,
|
||||
build_output_formatter,
|
||||
get_output_formatter,
|
||||
load_predictions,
|
||||
)
|
||||
from batdetect2.data.summary import (
|
||||
compute_class_summary,
|
||||
@ -46,4 +47,5 @@ __all__ = [
|
||||
"load_dataset",
|
||||
"load_dataset_config",
|
||||
"load_dataset_from_config",
|
||||
"load_predictions",
|
||||
]
|
||||
|
||||
@ -18,6 +18,14 @@ UNKNOWN_CLASS = "__UNKNOWN__"
|
||||
|
||||
NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
|
||||
|
||||
CLIP_NAMESPACE = uuid.uuid5(NAMESPACE, "clip")
|
||||
CLIP_ANNOTATION_NAMESPACE = uuid.uuid5(NAMESPACE, "clip_annotation")
|
||||
RECORDING_NAMESPACE = uuid.uuid5(NAMESPACE, "recording")
|
||||
SOUND_EVENT_NAMESPACE = uuid.uuid5(NAMESPACE, "sound_event")
|
||||
SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5(
|
||||
NAMESPACE, "sound_event_annotation"
|
||||
)
|
||||
|
||||
|
||||
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||
|
||||
@ -71,8 +79,8 @@ def annotation_to_sound_event(
|
||||
"""Convert annotation to sound event annotation."""
|
||||
sound_event = data.SoundEvent(
|
||||
uuid=uuid.uuid5(
|
||||
NAMESPACE,
|
||||
f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
|
||||
SOUND_EVENT_NAMESPACE,
|
||||
f"{recording.uuid}_{annotation.start_time}_{annotation.end_time}",
|
||||
),
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(
|
||||
@ -86,7 +94,10 @@ def annotation_to_sound_event(
|
||||
)
|
||||
|
||||
return data.SoundEventAnnotation(
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
|
||||
uuid=uuid.uuid5(
|
||||
SOUND_EVENT_ANNOTATION_NAMESPACE,
|
||||
f"{sound_event.uuid}",
|
||||
),
|
||||
sound_event=sound_event,
|
||||
tags=get_sound_event_tags(
|
||||
annotation, label_key, event_key, individual_key
|
||||
@ -139,12 +150,18 @@ def file_annotation_to_clip(
|
||||
time_expansion=file_annotation.time_exp,
|
||||
tags=tags,
|
||||
)
|
||||
recording.uuid = uuid.uuid5(RECORDING_NAMESPACE, f"{recording.hash}")
|
||||
|
||||
start_time = 0
|
||||
end_time = recording.duration
|
||||
return data.Clip(
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
|
||||
uuid=uuid.uuid5(
|
||||
CLIP_NAMESPACE,
|
||||
f"{recording.uuid}_{start_time}_{end_time}",
|
||||
),
|
||||
recording=recording,
|
||||
start_time=0,
|
||||
end_time=recording.duration,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
|
||||
@ -165,7 +182,7 @@ def file_annotation_to_clip_annotation(
|
||||
tags.append(data.Tag(key=label_key, value=file_annotation.label))
|
||||
|
||||
return data.ClipAnnotation(
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
|
||||
uuid=uuid.uuid5(CLIP_ANNOTATION_NAMESPACE, f"{clip.uuid}"),
|
||||
clip=clip,
|
||||
notes=notes,
|
||||
tags=tags,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from typing import Annotated, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.data.predictions.base import (
|
||||
OutputFormatterProtocol,
|
||||
@ -21,7 +22,11 @@ __all__ = [
|
||||
|
||||
|
||||
OutputFormatConfig = Annotated[
|
||||
Union[BatDetect2OutputConfig, SoundEventOutputConfig, RawOutputConfig],
|
||||
Union[
|
||||
BatDetect2OutputConfig,
|
||||
SoundEventOutputConfig,
|
||||
RawOutputConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
@ -40,13 +45,16 @@ def build_output_formatter(
|
||||
|
||||
|
||||
def get_output_formatter(
|
||||
name: str,
|
||||
name: Optional[str] = None,
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
config: Optional[OutputFormatConfig] = None,
|
||||
) -> OutputFormatterProtocol:
|
||||
"""Get the output formatter by name."""
|
||||
|
||||
if config is None:
|
||||
if name is None:
|
||||
raise ValueError("Either config or name must be provided.")
|
||||
|
||||
config_class = prediction_formatters.get_config_type(name)
|
||||
config = config_class() # type: ignore
|
||||
|
||||
@ -56,3 +64,17 @@ def get_output_formatter(
|
||||
)
|
||||
|
||||
return build_output_formatter(targets, config)
|
||||
|
||||
|
||||
def load_predictions(
|
||||
path: PathLike,
|
||||
format: Optional[str] = "raw",
|
||||
config: Optional[OutputFormatConfig] = None,
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
):
|
||||
"""Load predictions from a file."""
|
||||
from batdetect2.targets import build_targets
|
||||
|
||||
targets = targets or build_targets()
|
||||
formatter = get_output_formatter(format, targets, config)
|
||||
return formatter.load(path)
|
||||
|
||||
@ -5,6 +5,7 @@ from uuid import UUID, uuid4
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
@ -36,11 +37,13 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
include_class_scores: bool = True,
|
||||
include_features: bool = True,
|
||||
include_geometry: bool = True,
|
||||
parse_full_geometry: bool = False,
|
||||
):
|
||||
self.targets = targets
|
||||
self.include_class_scores = include_class_scores
|
||||
self.include_features = include_features
|
||||
self.include_geometry = include_geometry
|
||||
self.parse_full_geometry = parse_full_geometry
|
||||
|
||||
def format(
|
||||
self,
|
||||
@ -169,6 +172,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
predictions: List[BatDetect2Prediction] = []
|
||||
|
||||
for _, clip_data in root.items():
|
||||
logger.debug(f"Loading clip {clip_data.clip_id.item()}")
|
||||
recording = data.Recording.model_validate_json(
|
||||
clip_data.attrs["recording"]
|
||||
)
|
||||
@ -183,37 +187,36 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
|
||||
sound_events = []
|
||||
|
||||
for detection in clip_data.detection:
|
||||
score = clip_data.score.sel(detection=detection).item()
|
||||
for detection in clip_data.coords["detection"]:
|
||||
detection_data = clip_data.sel(detection=detection)
|
||||
score = detection_data.score.item()
|
||||
|
||||
if "geometry" in clip_data:
|
||||
if "geometry" in clip_data and self.parse_full_geometry:
|
||||
geometry = data.geometry_validate(
|
||||
clip_data.geometry.sel(detection=detection).item()
|
||||
detection_data.geometry.item()
|
||||
)
|
||||
else:
|
||||
start_time = clip_data.start_time.sel(detection=detection)
|
||||
end_time = clip_data.end_time.sel(detection=detection)
|
||||
low_freq = clip_data.low_freq.sel(detection=detection)
|
||||
high_freq = clip_data.high_freq.sel(detection=detection)
|
||||
geometry = data.BoundingBox(
|
||||
start_time = detection_data.start_time
|
||||
end_time = detection_data.end_time
|
||||
low_freq = detection_data.low_freq
|
||||
high_freq = detection_data.high_freq
|
||||
geometry = data.BoundingBox.model_construct(
|
||||
coordinates=[start_time, low_freq, end_time, high_freq]
|
||||
)
|
||||
|
||||
if "class_scores" in clip_data:
|
||||
class_scores = clip_data.class_scores.sel(
|
||||
detection=detection
|
||||
).data
|
||||
if "class_scores" in detection_data:
|
||||
class_scores = detection_data.class_scores.data
|
||||
else:
|
||||
class_scores = np.zeros(len(self.targets.class_names))
|
||||
class_index = self.targets.class_names.index(
|
||||
clip_data.top_class.sel(detection=detection).item()
|
||||
detection_data.top_class.item()
|
||||
)
|
||||
class_scores[class_index] = (
|
||||
detection_data.top_class_score.item()
|
||||
)
|
||||
class_scores[class_index] = clip_data.top_class_score.sel(
|
||||
detection=detection
|
||||
).item()
|
||||
|
||||
if "features" in clip_data:
|
||||
features = clip_data.features.sel(detection=detection).data
|
||||
if "features" in detection_data:
|
||||
features = detection_data.features.data
|
||||
else:
|
||||
features = np.zeros(0)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user