Update compat module to use new term module

This commit is contained in:
mbsantiago 2025-08-08 12:25:16 +01:00
parent 62923a201b
commit e1908c35ca
2 changed files with 100 additions and 61 deletions

View File

@ -0,0 +1,15 @@
from batdetect2.compat.data import (
annotation_to_sound_event_annotation,
annotation_to_sound_event_prediction,
convert_to_annotation_group,
file_annotation_to_clip_annotation,
load_file_annotation,
)
__all__ = [
"annotation_to_sound_event_annotation",
"annotation_to_sound_event_prediction",
"convert_to_annotation_group",
"file_annotation_to_clip_annotation",
"load_file_annotation",
]

View File

@ -1,24 +1,30 @@
"""Compatibility functions between old and new data structures."""
import json
import os
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Union
import numpy as np
from pydantic import BaseModel, Field
from soundevent import data
from soundevent.geometry import compute_bounds
from soundevent.types import ClassMapper
from batdetect2 import types
from batdetect2.targets.terms import get_term_from_key
from batdetect2.types import (
Annotation,
AudioLoaderAnnotationGroup,
FileAnnotation,
)
PathLike = Union[Path, str, os.PathLike]
__all__ = [
"convert_to_annotation_group",
"load_file_annotation",
"annotation_to_sound_event",
"annotation_to_sound_event_annotation",
"annotation_to_sound_event_prediction",
]
SPECIES_TAG_KEY = "species"
@ -37,7 +43,7 @@ IndividualFn = Callable[[data.SoundEventAnnotation], int]
def get_recording_class_name(recording: data.Recording) -> str:
"""Get the class name for a recording."""
tag = data.find_tag(recording.tags, SPECIES_TAG_KEY)
tag = data.find_tag(recording.tags, label=SPECIES_TAG_KEY)
if tag is None:
return UNKNOWN_CLASS
return tag.value
@ -59,7 +65,7 @@ def convert_to_annotation_group(
event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
class_fn: ClassFn = lambda _: 0,
individual_fn: IndividualFn = lambda _: 0,
) -> types.AudioLoaderAnnotationGroup:
) -> AudioLoaderAnnotationGroup:
"""Convert a ClipAnnotation to an AudioLoaderAnnotationGroup."""
recording = annotation.clip.recording
@ -71,7 +77,7 @@ def convert_to_annotation_group(
x_inds = []
y_inds = []
individual_ids = []
annotations: List[types.Annotation] = []
annotations: List[Annotation] = []
class_id_file = class_fn(recording)
for sound_event in annotation.sound_events:
@ -133,42 +139,13 @@ def convert_to_annotation_group(
}
class Annotation(BaseModel):
"""Annotation class to hold batdetect annotations."""
label: str = Field(alias="class")
event: str
individual: int = 0
start_time: float
end_time: float
low_freq: float
high_freq: float
class FileAnnotation(BaseModel):
"""FileAnnotation class to hold batdetect annotations for a file."""
id: str
duration: float
time_exp: float = 1
label: str = Field(alias="class_name")
annotation: List[Annotation]
annotated: bool = False
issues: bool = False
notes: str = ""
def load_file_annotation(path: PathLike) -> FileAnnotation:
"""Load annotation from batdetect format."""
path = Path(path)
return FileAnnotation.model_validate_json(path.read_text())
return json.loads(path.read_text())
def annotation_to_sound_event(
def annotation_to_sound_event_annotation(
annotation: Annotation,
recording: data.Recording,
label_key: str = "class",
@ -179,15 +156,15 @@ def annotation_to_sound_event(
sound_event = data.SoundEvent(
uuid=uuid.uuid5(
NAMESPACE,
f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
f"{recording.hash}_{annotation['start_time']}_{annotation['end_time']}",
),
recording=recording,
geometry=data.BoundingBox(
coordinates=[
annotation.start_time,
annotation.low_freq,
annotation.end_time,
annotation.high_freq,
annotation["start_time"],
annotation["low_freq"],
annotation["end_time"],
annotation["high_freq"],
],
),
)
@ -197,16 +174,62 @@ def annotation_to_sound_event(
sound_event=sound_event,
tags=[
data.Tag(
term=data.term_from_key(label_key),
value=annotation.label,
term=get_term_from_key(label_key),
value=annotation["class"],
),
data.Tag(
term=data.term_from_key(event_key),
value=annotation.event,
term=get_term_from_key(event_key),
value=annotation["event"],
),
data.Tag(
term=data.term_from_key(individual_key),
value=str(annotation.individual),
term=get_term_from_key(individual_key),
value=str(annotation["individual"]),
),
],
)
def annotation_to_sound_event_prediction(
annotation: Annotation,
recording: data.Recording,
label_key: str = "class",
event_key: str = "event",
) -> data.SoundEventPrediction:
"""Convert annotation to sound event annotation."""
sound_event = data.SoundEvent(
uuid=uuid.uuid5(
NAMESPACE,
f"{recording.hash}_{annotation['start_time']}_{annotation['end_time']}",
),
recording=recording,
geometry=data.BoundingBox(
coordinates=[
annotation["start_time"],
annotation["low_freq"],
annotation["end_time"],
annotation["high_freq"],
],
),
)
return data.SoundEventPrediction(
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
sound_event=sound_event,
score=annotation["det_prob"],
tags=[
data.PredictedTag(
score=annotation["class_prob"],
tag=data.Tag(
term=get_term_from_key(label_key),
value=annotation["class"],
),
),
data.PredictedTag(
score=annotation["det_prob"],
tag=data.Tag(
term=get_term_from_key(event_key),
value=annotation["event"],
),
),
],
)
@ -220,24 +243,24 @@ def file_annotation_to_clip(
"""Convert file annotation to recording."""
audio_dir = audio_dir or Path.cwd()
full_path = Path(audio_dir) / file_annotation.id
full_path = Path(audio_dir) / file_annotation["id"]
if not full_path.exists():
raise FileNotFoundError(f"File {full_path} not found.")
recording = data.Recording.from_file(
full_path,
time_expansion=file_annotation.time_exp,
time_expansion=file_annotation["time_exp"],
tags=[
data.Tag(
term=data.term_from_key(label_key),
value=file_annotation.label,
value=file_annotation["class_name"],
)
],
)
return data.Clip(
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation['id']}_clip"),
recording=recording,
start_time=0,
end_time=recording.duration,
@ -253,27 +276,28 @@ def file_annotation_to_clip_annotation(
) -> data.ClipAnnotation:
"""Convert file annotation to clip annotation."""
notes = []
if file_annotation.notes:
notes.append(data.Note(message=file_annotation.notes))
if file_annotation["notes"]:
notes.append(data.Note(message=file_annotation["notes"]))
return data.ClipAnnotation(
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation['id']}_clip_annotation"),
clip=clip,
notes=notes,
tags=[
data.Tag(
term=data.term_from_key(label_key), value=file_annotation.label
term=data.term_from_key(label_key),
value=file_annotation["class_name"],
)
],
sound_events=[
annotation_to_sound_event(
annotation_to_sound_event_annotation(
annotation,
clip.recording,
label_key=label_key,
event_key=event_key,
individual_key=individual_key,
)
for annotation in file_annotation.annotation
for annotation in file_annotation["annotation"]
],
)
@ -284,17 +308,17 @@ def file_annotation_to_annotation_task(
) -> data.AnnotationTask:
status_badges = []
if file_annotation.issues:
if file_annotation["issues"]:
status_badges.append(
data.StatusBadge(state=data.AnnotationState.rejected)
)
elif file_annotation.annotated:
elif file_annotation["annotated"]:
status_badges.append(
data.StatusBadge(state=data.AnnotationState.completed)
)
return data.AnnotationTask(
uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation.id}_task"),
uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation['id']}_task"),
clip=clip,
status_badges=status_badges,
)