diff --git a/src/batdetect2/data/__init__.py b/src/batdetect2/data/__init__.py
index 647a135..e405a0c 100644
--- a/src/batdetect2/data/__init__.py
+++ b/src/batdetect2/data/__init__.py
@@ -19,6 +19,7 @@ from batdetect2.data.predictions import (
     SoundEventOutputConfig,
     build_output_formatter,
     get_output_formatter,
+    load_predictions,
 )
 from batdetect2.data.summary import (
     compute_class_summary,
@@ -46,4 +47,5 @@ __all__ = [
     "load_dataset",
     "load_dataset_config",
     "load_dataset_from_config",
+    "load_predictions",
 ]
diff --git a/src/batdetect2/data/annotations/legacy.py b/src/batdetect2/data/annotations/legacy.py
index a433264..343cfe1 100644
--- a/src/batdetect2/data/annotations/legacy.py
+++ b/src/batdetect2/data/annotations/legacy.py
@@ -18,6 +18,14 @@ UNKNOWN_CLASS = "__UNKNOWN__"
 
 NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
 
+CLIP_NAMESPACE = uuid.uuid5(NAMESPACE, "clip")
+CLIP_ANNOTATION_NAMESPACE = uuid.uuid5(NAMESPACE, "clip_annotation")
+RECORDING_NAMESPACE = uuid.uuid5(NAMESPACE, "recording")
+SOUND_EVENT_NAMESPACE = uuid.uuid5(NAMESPACE, "sound_event")
+SOUND_EVENT_ANNOTATION_NAMESPACE = uuid.uuid5(
+    NAMESPACE, "sound_event_annotation"
+)
+
 EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
 
 
@@ -71,8 +79,8 @@ def annotation_to_sound_event(
     """Convert annotation to sound event annotation."""
     sound_event = data.SoundEvent(
         uuid=uuid.uuid5(
-            NAMESPACE,
-            f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
+            SOUND_EVENT_NAMESPACE,
+            f"{recording.uuid}_{annotation.start_time}_{annotation.end_time}",
         ),
         recording=recording,
         geometry=data.BoundingBox(
@@ -86,7 +94,10 @@ def annotation_to_sound_event(
     )
 
     return data.SoundEventAnnotation(
-        uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
+        uuid=uuid.uuid5(
+            SOUND_EVENT_ANNOTATION_NAMESPACE,
+            f"{sound_event.uuid}",
+        ),
         sound_event=sound_event,
         tags=get_sound_event_tags(
             annotation, label_key, event_key, individual_key
@@ -139,12 +150,18 @@ def file_annotation_to_clip(
         time_expansion=file_annotation.time_exp,
         tags=tags,
     )
+    recording.uuid = uuid.uuid5(RECORDING_NAMESPACE, f"{recording.hash}")
+
+    start_time = 0
+    end_time = recording.duration
 
     return data.Clip(
-        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
+        uuid=uuid.uuid5(
+            CLIP_NAMESPACE,
+            f"{recording.uuid}_{start_time}_{end_time}",
+        ),
         recording=recording,
-        start_time=0,
-        end_time=recording.duration,
+        start_time=start_time,
+        end_time=end_time,
     )
 
@@ -165,7 +182,7 @@ def file_annotation_to_clip_annotation(
         tags.append(data.Tag(key=label_key, value=file_annotation.label))
 
     return data.ClipAnnotation(
-        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
+        uuid=uuid.uuid5(CLIP_ANNOTATION_NAMESPACE, f"{clip.uuid}"),
         clip=clip,
         notes=notes,
         tags=tags,
diff --git a/src/batdetect2/data/predictions/__init__.py b/src/batdetect2/data/predictions/__init__.py
index 8839636..8c3cc1a 100644
--- a/src/batdetect2/data/predictions/__init__.py
+++ b/src/batdetect2/data/predictions/__init__.py
@@ -1,6 +1,7 @@
 from typing import Annotated, Optional, Union
 
 from pydantic import Field
+from soundevent.data import PathLike
 
 from batdetect2.data.predictions.base import (
     OutputFormatterProtocol,
@@ -21,7 +22,11 @@ __all__ = [
 ]
 
 OutputFormatConfig = Annotated[
-    Union[BatDetect2OutputConfig, SoundEventOutputConfig, RawOutputConfig],
+    Union[
+        BatDetect2OutputConfig,
+        SoundEventOutputConfig,
+        RawOutputConfig,
+    ],
     Field(discriminator="name"),
 ]
@@ -40,13 +45,16 @@ def build_output_formatter(
 
 
 def get_output_formatter(
-    name: str,
+    name: Optional[str] = None,
     targets: Optional[TargetProtocol] = None,
     config: Optional[OutputFormatConfig] = None,
 ) -> OutputFormatterProtocol:
     """Get the output formatter by name."""
 
     if config is None:
+        if name is None:
+            raise ValueError("Either config or name must be provided.")
+
         config_class = prediction_formatters.get_config_type(name)
         config = config_class()  # type: ignore
 
@@ -56,3 +64,17 @@ def get_output_formatter(
     )
 
     return build_output_formatter(targets, config)
+
+
+def load_predictions(
+    path: PathLike,
+    format: Optional[str] = "raw",
+    config: Optional[OutputFormatConfig] = None,
+    targets: Optional[TargetProtocol] = None,
+):
+    """Load predictions from a file."""
+    from batdetect2.targets import build_targets
+
+    targets = targets or build_targets()
+    formatter = get_output_formatter(format, targets, config)
+    return formatter.load(path)
diff --git a/src/batdetect2/data/predictions/raw.py b/src/batdetect2/data/predictions/raw.py
index 715a595..416c74b 100644
--- a/src/batdetect2/data/predictions/raw.py
+++ b/src/batdetect2/data/predictions/raw.py
@@ -5,6 +5,7 @@ from uuid import UUID, uuid4
 
 import numpy as np
 import xarray as xr
+from loguru import logger
 from soundevent import data
 from soundevent.geometry import compute_bounds
 
@@ -36,11 +37,13 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
         include_class_scores: bool = True,
         include_features: bool = True,
         include_geometry: bool = True,
+        parse_full_geometry: bool = False,
     ):
         self.targets = targets
         self.include_class_scores = include_class_scores
         self.include_features = include_features
         self.include_geometry = include_geometry
+        self.parse_full_geometry = parse_full_geometry
 
     def format(
         self,
@@ -169,6 +172,7 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
         predictions: List[BatDetect2Prediction] = []
 
         for _, clip_data in root.items():
+            logger.debug(f"Loading clip {clip_data.clip_id.item()}")
             recording = data.Recording.model_validate_json(
                 clip_data.attrs["recording"]
             )
@@ -183,37 +187,36 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
 
             sound_events = []
 
-            for detection in clip_data.detection:
-                score = clip_data.score.sel(detection=detection).item()
+            for detection in clip_data.coords["detection"]:
+                detection_data = clip_data.sel(detection=detection)
+                score = detection_data.score.item()
 
-                if "geometry" in clip_data:
+                if "geometry" in clip_data and self.parse_full_geometry:
                     geometry = data.geometry_validate(
-                        clip_data.geometry.sel(detection=detection).item()
+                        detection_data.geometry.item()
                     )
                 else:
-                    start_time = clip_data.start_time.sel(detection=detection)
-                    end_time = clip_data.end_time.sel(detection=detection)
-                    low_freq = clip_data.low_freq.sel(detection=detection)
-                    high_freq = clip_data.high_freq.sel(detection=detection)
-                    geometry = data.BoundingBox(
+                    start_time = detection_data.start_time
+                    end_time = detection_data.end_time
+                    low_freq = detection_data.low_freq
+                    high_freq = detection_data.high_freq
+                    geometry = data.BoundingBox.model_construct(
                         coordinates=[start_time, low_freq, end_time, high_freq]
                     )
 
-                if "class_scores" in clip_data:
-                    class_scores = clip_data.class_scores.sel(
-                        detection=detection
-                    ).data
+                if "class_scores" in detection_data:
+                    class_scores = detection_data.class_scores.data
                 else:
                     class_scores = np.zeros(len(self.targets.class_names))
                     class_index = self.targets.class_names.index(
-                        clip_data.top_class.sel(detection=detection).item()
+                        detection_data.top_class.item()
+                    )
+                    class_scores[class_index] = (
+                        detection_data.top_class_score.item()
                     )
-                    class_scores[class_index] = clip_data.top_class_score.sel(
-                        detection=detection
-                    ).item()
 
-                if "features" in clip_data:
-                    features = clip_data.features.sel(detection=detection).data
+                if "features" in detection_data:
+                    features = detection_data.features.data
                 else:
                     features = np.zeros(0)
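
Note on the uuid5 namespacing introduced in legacy.py: deriving one sub-namespace per entity type means two different entity kinds built from the same seed string can no longer collide, and re-importing the same legacy file reproduces identical IDs. A minimal sketch of that property, using a hypothetical recording hash (only the namespace constants are copied from the diff; everything else here is illustrative):

    import uuid

    NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
    RECORDING_NAMESPACE = uuid.uuid5(NAMESPACE, "recording")
    CLIP_NAMESPACE = uuid.uuid5(NAMESPACE, "clip")

    # Hypothetical recording hash standing in for recording.hash.
    recording_uuid = uuid.uuid5(RECORDING_NAMESPACE, "abc123")

    # Clip IDs now derive from the recording UUID plus the clip's time span,
    # so importing the same file twice yields byte-identical UUIDs ...
    clip_uuid = uuid.uuid5(CLIP_NAMESPACE, f"{recording_uuid}_0_3.5")
    assert clip_uuid == uuid.uuid5(CLIP_NAMESPACE, f"{recording_uuid}_0_3.5")

    # ... while per-entity namespaces keep clip and recording IDs distinct
    # even when the seed string happens to be identical.
    assert uuid.uuid5(CLIP_NAMESPACE, "x") != uuid.uuid5(RECORDING_NAMESPACE, "x")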
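
The new load_predictions helper ties the pieces together: it builds default targets when none are supplied, resolves a formatter by name or config, and delegates to the formatter's load method. A hypothetical round trip is sketched below; the file path is an assumption, and only the function signature comes from the diff:

    from batdetect2.data import load_predictions

    # Load predictions previously written by the RawFormatter. With no
    # explicit targets or config, defaults are built internally.
    predictions = load_predictions("predictions_output", format="raw")

    for prediction in predictions:
        print(prediction)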