Create data annotation loader registry

This commit is contained in:
mbsantiago 2026-03-15 20:53:59 +00:00
parent 3c337a06cb
commit 4b7d23abde
6 changed files with 124 additions and 44 deletions

View File

@ -44,11 +44,12 @@ class SimpleRegistry(Generic[T]):
class Registry(Generic[T_Type, P_Type]):
"""A generic class to create and manage a registry of items."""
def __init__(self, name: str):
def __init__(self, name: str, discriminator: str = "name"):
self._name = name
self._registry: dict[
str, Callable[Concatenate[..., P_Type], T_Type]
] = {}
self._discriminator = discriminator
self._config_types: dict[str, Type[BaseModel]] = {}
def register(
@ -57,15 +58,20 @@ class Registry(Generic[T_Type, P_Type]):
):
fields = config_cls.model_fields
if "name" not in fields:
raise ValueError("Configuration object must have a 'name' field.")
if self._discriminator not in fields:
raise ValueError(
"Configuration object must have "
f"a '{self._discriminator}' field."
)
name = fields["name"].default
name = fields[self._discriminator].default
self._config_types[name] = config_cls
if not isinstance(name, str):
raise ValueError("'name' field must be a string literal.")
raise ValueError(
f"'{self._discriminator}' field must be a string literal."
)
def decorator(
func: Callable[Concatenate[T_Config, P_Type], T_Type],
@ -95,10 +101,12 @@ class Registry(Generic[T_Type, P_Type]):
) -> T_Type:
"""Builds a logic instance from a config object."""
name = getattr(config, "name") # noqa: B009
name = getattr(config, self._discriminator) # noqa: B009
if name is None:
raise ValueError("Config does not have a name field")
raise ValueError(
f"Config does not have a '{self._discriminator}' field"
)
if name not in self._registry:
raise NotImplementedError(

View File

@ -18,17 +18,13 @@ from typing import Annotated
from pydantic import Field
from soundevent import data
from batdetect2.data.annotations.aoef import (
AOEFAnnotations,
load_aoef_annotated_dataset,
)
from batdetect2.data.annotations.aoef import AOEFAnnotations
from batdetect2.data.annotations.batdetect2 import (
AnnotationFilter,
BatDetect2FilesAnnotations,
BatDetect2MergedAnnotations,
load_batdetect2_files_annotated_dataset,
load_batdetect2_merged_annotated_dataset,
)
from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [
@ -63,20 +59,20 @@ def load_annotated_dataset(
) -> data.AnnotationSet:
"""Load annotations for a single data source based on its configuration.
This function acts as a dispatcher. It inspects the type of the input
`source_config` object (which corresponds to a specific annotation format)
and calls the appropriate loading function (e.g.,
`load_aoef_annotated_dataset` for `AOEFAnnotations`).
This function acts as a dispatcher. It inspects the format of the input
`dataset` object and delegates to the appropriate format-specific loader
registered in the `annotation_format_registry` (e.g.,
`AOEFLoader` for `AOEFAnnotations`).
Parameters
----------
source_config : AnnotationFormats
dataset : AnnotatedDataset
The configuration object for the data source, specifying its format
and necessary details (like paths). Must be an instance of one of the
types included in the `AnnotationFormats` union.
base_dir : Path, optional
An optional base directory path. If provided, relative paths within
the `source_config` might be resolved relative to this directory by
the `dataset` will be resolved relative to this directory by
the underlying loading functions. Defaults to None.
Returns
@ -88,23 +84,8 @@ def load_annotated_dataset(
Raises
------
NotImplementedError
If the type of the `source_config` object does not match any of the
known format-specific loading functions implemented in the dispatch
logic.
If the `format` field of `dataset` does not match any registered
annotation format loader.
"""
if isinstance(dataset, AOEFAnnotations):
return load_aoef_annotated_dataset(dataset, base_dir=base_dir)
if isinstance(dataset, BatDetect2MergedAnnotations):
return load_batdetect2_merged_annotated_dataset(
dataset, base_dir=base_dir
)
if isinstance(dataset, BatDetect2FilesAnnotations):
return load_batdetect2_files_annotated_dataset(
dataset,
base_dir=base_dir,
)
raise NotImplementedError(f"Unknown annotation format: {dataset.name}")
loader = annotation_format_registry.build(dataset)
return loader.load(base_dir=base_dir)

View File

@ -19,10 +19,15 @@ from pydantic import Field
from soundevent import data, io
from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.types import AnnotatedDataset
from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import (
AnnotatedDataset,
AnnotationLoader,
)
__all__ = [
"AOEFAnnotations",
"AOEFLoader",
"load_aoef_annotated_dataset",
"AnnotationTaskFilter",
]
@ -82,6 +87,22 @@ class AOEFAnnotations(AnnotatedDataset):
)
class AOEFLoader(AnnotationLoader):
def __init__(self, config: AOEFAnnotations):
self.config = config
def load(
self,
base_dir: Optional[data.PathLike] = None,
) -> data.AnnotationSet:
return load_aoef_annotated_dataset(self.config, base_dir=base_dir)
@annotation_format_registry.register(AOEFAnnotations)
@staticmethod
def from_config(config: AOEFAnnotations):
return AOEFLoader(config)
def load_aoef_annotated_dataset(
dataset: AOEFAnnotations,
base_dir: data.PathLike | None = None,

View File

@ -41,7 +41,11 @@ from batdetect2.data.annotations.legacy import (
list_file_annotations,
load_file_annotation,
)
from batdetect2.data.annotations.types import AnnotatedDataset
from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import (
AnnotatedDataset,
AnnotationLoader,
)
PathLike = Path | str | os.PathLike
@ -302,7 +306,7 @@ def load_batdetect2_merged_annotated_dataset(
try:
ann = FileAnnotation.model_validate(ann)
except ValueError as err:
logger.warning(f"Invalid annotation file: {err}")
logger.warning("Invalid annotation file: {err}", err=err)
continue
if (
@ -310,17 +314,23 @@ def load_batdetect2_merged_annotated_dataset(
and dataset.filter.only_annotated
and not ann.annotated
):
logger.debug(f"Skipping incomplete annotation {ann.id}")
logger.debug(
"Skipping incomplete annotation {ann_id}",
ann_id=ann.id,
)
continue
if dataset.filter and dataset.filter.exclude_issues and ann.issues:
logger.debug(f"Skipping annotation with issues {ann.id}")
logger.debug(
"Skipping annotation with issues {ann_id}",
ann_id=ann.id,
)
continue
try:
clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
except FileNotFoundError as err:
logger.warning(f"Error loading annotations: {err}")
logger.warning("Error loading annotations: {err}", err=err)
continue
annotations.append(file_annotation_to_clip_annotation(ann, clip))
@ -330,3 +340,41 @@ def load_batdetect2_merged_annotated_dataset(
description=dataset.description,
clip_annotations=annotations,
)
class BatDetect2MergedLoader(AnnotationLoader):
def __init__(self, config: BatDetect2MergedAnnotations):
self.config = config
def load(
self,
base_dir: Optional[PathLike] = None,
) -> data.AnnotationSet:
return load_batdetect2_merged_annotated_dataset(
self.config,
base_dir=base_dir,
)
@annotation_format_registry.register(BatDetect2MergedAnnotations)
@staticmethod
def from_config(config: BatDetect2MergedAnnotations):
return BatDetect2MergedLoader(config)
class BatDetect2FilesLoader(AnnotationLoader):
def __init__(self, config: BatDetect2FilesAnnotations):
self.config = config
def load(
self,
base_dir: Optional[PathLike] = None,
) -> data.AnnotationSet:
return load_batdetect2_files_annotated_dataset(
self.config,
base_dir=base_dir,
)
@annotation_format_registry.register(BatDetect2FilesAnnotations)
@staticmethod
def from_config(config: BatDetect2FilesAnnotations):
return BatDetect2FilesLoader(config)

View File

@ -0,0 +1,11 @@
from batdetect2.core import Registry
from batdetect2.data.annotations.types import AnnotationLoader
__all__ = [
"annotation_format_registry",
]
annotation_format_registry: Registry[AnnotationLoader, []] = Registry(
"annotation_format",
discriminator="format",
)

View File

@ -1,9 +1,13 @@
from pathlib import Path
from typing import Optional, Protocol
from soundevent import data
from batdetect2.core.configs import BaseConfig
__all__ = [
"AnnotatedDataset",
"AnnotationLoader",
]
@ -34,3 +38,10 @@ class AnnotatedDataset(BaseConfig):
name: str
audio_dir: Path
description: str = ""
class AnnotationLoader(Protocol):
def load(
self,
base_dir: Optional[data.PathLike] = None,
) -> data.AnnotationSet: ...