Create data annotation loader registry

This commit is contained in:
mbsantiago 2026-03-15 20:53:59 +00:00
parent 3c337a06cb
commit 4b7d23abde
6 changed files with 124 additions and 44 deletions

View File

@ -44,11 +44,12 @@ class SimpleRegistry(Generic[T]):
class Registry(Generic[T_Type, P_Type]): class Registry(Generic[T_Type, P_Type]):
"""A generic class to create and manage a registry of items.""" """A generic class to create and manage a registry of items."""
def __init__(self, name: str): def __init__(self, name: str, discriminator: str = "name"):
self._name = name self._name = name
self._registry: dict[ self._registry: dict[
str, Callable[Concatenate[..., P_Type], T_Type] str, Callable[Concatenate[..., P_Type], T_Type]
] = {} ] = {}
self._discriminator = discriminator
self._config_types: dict[str, Type[BaseModel]] = {} self._config_types: dict[str, Type[BaseModel]] = {}
def register( def register(
@ -57,15 +58,20 @@ class Registry(Generic[T_Type, P_Type]):
): ):
fields = config_cls.model_fields fields = config_cls.model_fields
if "name" not in fields: if self._discriminator not in fields:
raise ValueError("Configuration object must have a 'name' field.") raise ValueError(
"Configuration object must have "
f"a '{self._discriminator}' field."
)
name = fields["name"].default name = fields[self._discriminator].default
self._config_types[name] = config_cls self._config_types[name] = config_cls
if not isinstance(name, str): if not isinstance(name, str):
raise ValueError("'name' field must be a string literal.") raise ValueError(
f"'{self._discriminator}' field must be a string literal."
)
def decorator( def decorator(
func: Callable[Concatenate[T_Config, P_Type], T_Type], func: Callable[Concatenate[T_Config, P_Type], T_Type],
@ -95,10 +101,12 @@ class Registry(Generic[T_Type, P_Type]):
) -> T_Type: ) -> T_Type:
"""Builds a logic instance from a config object.""" """Builds a logic instance from a config object."""
name = getattr(config, "name") # noqa: B009 name = getattr(config, self._discriminator) # noqa: B009
if name is None: if name is None:
raise ValueError("Config does not have a name field") raise ValueError(
f"Config does not have a '{self._discriminator}' field"
)
if name not in self._registry: if name not in self._registry:
raise NotImplementedError( raise NotImplementedError(

View File

@ -18,17 +18,13 @@ from typing import Annotated
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.data.annotations.aoef import ( from batdetect2.data.annotations.aoef import AOEFAnnotations
AOEFAnnotations,
load_aoef_annotated_dataset,
)
from batdetect2.data.annotations.batdetect2 import ( from batdetect2.data.annotations.batdetect2 import (
AnnotationFilter, AnnotationFilter,
BatDetect2FilesAnnotations, BatDetect2FilesAnnotations,
BatDetect2MergedAnnotations, BatDetect2MergedAnnotations,
load_batdetect2_files_annotated_dataset,
load_batdetect2_merged_annotated_dataset,
) )
from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import AnnotatedDataset from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [ __all__ = [
@ -63,20 +59,20 @@ def load_annotated_dataset(
) -> data.AnnotationSet: ) -> data.AnnotationSet:
"""Load annotations for a single data source based on its configuration. """Load annotations for a single data source based on its configuration.
This function acts as a dispatcher. It inspects the type of the input This function acts as a dispatcher. It inspects the format of the input
`source_config` object (which corresponds to a specific annotation format) `dataset` object and delegates to the appropriate format-specific loader
and calls the appropriate loading function (e.g., registered in the `annotation_format_registry` (e.g.,
`load_aoef_annotated_dataset` for `AOEFAnnotations`). `AOEFLoader` for `AOEFAnnotations`).
Parameters Parameters
---------- ----------
source_config : AnnotationFormats dataset : AnnotatedDataset
The configuration object for the data source, specifying its format The configuration object for the data source, specifying its format
and necessary details (like paths). Must be an instance of one of the and necessary details (like paths). Must be an instance of one of the
types included in the `AnnotationFormats` union. types included in the `AnnotationFormats` union.
base_dir : Path, optional base_dir : Path, optional
An optional base directory path. If provided, relative paths within An optional base directory path. If provided, relative paths within
the `source_config` might be resolved relative to this directory by the `dataset` will be resolved relative to this directory by
the underlying loading functions. Defaults to None. the underlying loading functions. Defaults to None.
Returns Returns
@ -88,23 +84,8 @@ def load_annotated_dataset(
Raises Raises
------ ------
NotImplementedError NotImplementedError
If the type of the `source_config` object does not match any of the If the `format` field of `dataset` does not match any registered
known format-specific loading functions implemented in the dispatch annotation format loader.
logic.
""" """
loader = annotation_format_registry.build(dataset)
if isinstance(dataset, AOEFAnnotations): return loader.load(base_dir=base_dir)
return load_aoef_annotated_dataset(dataset, base_dir=base_dir)
if isinstance(dataset, BatDetect2MergedAnnotations):
return load_batdetect2_merged_annotated_dataset(
dataset, base_dir=base_dir
)
if isinstance(dataset, BatDetect2FilesAnnotations):
return load_batdetect2_files_annotated_dataset(
dataset,
base_dir=base_dir,
)
raise NotImplementedError(f"Unknown annotation format: {dataset.name}")

View File

@ -19,10 +19,15 @@ from pydantic import Field
from soundevent import data, io from soundevent import data, io
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.types import AnnotatedDataset from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import (
AnnotatedDataset,
AnnotationLoader,
)
__all__ = [ __all__ = [
"AOEFAnnotations", "AOEFAnnotations",
"AOEFLoader",
"load_aoef_annotated_dataset", "load_aoef_annotated_dataset",
"AnnotationTaskFilter", "AnnotationTaskFilter",
] ]
@ -82,6 +87,22 @@ class AOEFAnnotations(AnnotatedDataset):
) )
class AOEFLoader(AnnotationLoader):
    """Annotation loader backed by an `AOEFAnnotations` configuration.

    Instances hold the dataset configuration and defer the actual work to
    `load_aoef_annotated_dataset`. The `from_config` factory is registered
    in `annotation_format_registry` under the `AOEFAnnotations` config type.
    """

    def __init__(self, config: AOEFAnnotations):
        """Store the dataset configuration used by later `load` calls."""
        self.config = config

    def load(
        self,
        base_dir: data.PathLike | None = None,
    ) -> data.AnnotationSet:
        """Load the annotations described by the stored configuration.

        Parameters
        ----------
        base_dir : data.PathLike, optional
            Forwarded unchanged to `load_aoef_annotated_dataset`.
            Defaults to None.

        Returns
        -------
        data.AnnotationSet
            Whatever `load_aoef_annotated_dataset` returns for this config.
        """
        # The union is spelled `X | None` to match the signature of
        # `load_aoef_annotated_dataset` in this same module.
        return load_aoef_annotated_dataset(self.config, base_dir=base_dir)

    @annotation_format_registry.register(AOEFAnnotations)
    @staticmethod
    def from_config(config: AOEFAnnotations) -> "AOEFLoader":
        """Build an `AOEFLoader` from its configuration object."""
        return AOEFLoader(config)
def load_aoef_annotated_dataset( def load_aoef_annotated_dataset(
dataset: AOEFAnnotations, dataset: AOEFAnnotations,
base_dir: data.PathLike | None = None, base_dir: data.PathLike | None = None,

View File

@ -41,7 +41,11 @@ from batdetect2.data.annotations.legacy import (
list_file_annotations, list_file_annotations,
load_file_annotation, load_file_annotation,
) )
from batdetect2.data.annotations.types import AnnotatedDataset from batdetect2.data.annotations.registry import annotation_format_registry
from batdetect2.data.annotations.types import (
AnnotatedDataset,
AnnotationLoader,
)
PathLike = Path | str | os.PathLike PathLike = Path | str | os.PathLike
@ -302,7 +306,7 @@ def load_batdetect2_merged_annotated_dataset(
try: try:
ann = FileAnnotation.model_validate(ann) ann = FileAnnotation.model_validate(ann)
except ValueError as err: except ValueError as err:
logger.warning(f"Invalid annotation file: {err}") logger.warning("Invalid annotation file: {err}", err=err)
continue continue
if ( if (
@ -310,17 +314,23 @@ def load_batdetect2_merged_annotated_dataset(
and dataset.filter.only_annotated and dataset.filter.only_annotated
and not ann.annotated and not ann.annotated
): ):
logger.debug(f"Skipping incomplete annotation {ann.id}") logger.debug(
"Skipping incomplete annotation {ann_id}",
ann_id=ann.id,
)
continue continue
if dataset.filter and dataset.filter.exclude_issues and ann.issues: if dataset.filter and dataset.filter.exclude_issues and ann.issues:
logger.debug(f"Skipping annotation with issues {ann.id}") logger.debug(
"Skipping annotation with issues {ann_id}",
ann_id=ann.id,
)
continue continue
try: try:
clip = file_annotation_to_clip(ann, audio_dir=audio_dir) clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
except FileNotFoundError as err: except FileNotFoundError as err:
logger.warning(f"Error loading annotations: {err}") logger.warning("Error loading annotations: {err}", err=err)
continue continue
annotations.append(file_annotation_to_clip_annotation(ann, clip)) annotations.append(file_annotation_to_clip_annotation(ann, clip))
@ -330,3 +340,41 @@ def load_batdetect2_merged_annotated_dataset(
description=dataset.description, description=dataset.description,
clip_annotations=annotations, clip_annotations=annotations,
) )
class BatDetect2MergedLoader(AnnotationLoader):
    """Annotation loader backed by a `BatDetect2MergedAnnotations` config.

    Instances hold the dataset configuration and defer the actual work to
    `load_batdetect2_merged_annotated_dataset`. The `from_config` factory is
    registered in `annotation_format_registry` under the
    `BatDetect2MergedAnnotations` config type.
    """

    def __init__(self, config: BatDetect2MergedAnnotations):
        """Store the dataset configuration used by later `load` calls."""
        self.config = config

    def load(
        self,
        base_dir: PathLike | None = None,
    ) -> data.AnnotationSet:
        """Load the annotations described by the stored configuration.

        Parameters
        ----------
        base_dir : PathLike, optional
            Forwarded unchanged to
            `load_batdetect2_merged_annotated_dataset`. Defaults to None.

        Returns
        -------
        data.AnnotationSet
            Whatever the underlying loading function returns for this
            config.
        """
        # The union is spelled `X | None` to match the module-level
        # `PathLike = Path | str | os.PathLike` alias.
        return load_batdetect2_merged_annotated_dataset(
            self.config,
            base_dir=base_dir,
        )

    @annotation_format_registry.register(BatDetect2MergedAnnotations)
    @staticmethod
    def from_config(
        config: BatDetect2MergedAnnotations,
    ) -> "BatDetect2MergedLoader":
        """Build a `BatDetect2MergedLoader` from its configuration object."""
        return BatDetect2MergedLoader(config)
class BatDetect2FilesLoader(AnnotationLoader):
    """Annotation loader backed by a `BatDetect2FilesAnnotations` config.

    Instances hold the dataset configuration and defer the actual work to
    `load_batdetect2_files_annotated_dataset`. The `from_config` factory is
    registered in `annotation_format_registry` under the
    `BatDetect2FilesAnnotations` config type.
    """

    def __init__(self, config: BatDetect2FilesAnnotations):
        """Store the dataset configuration used by later `load` calls."""
        self.config = config

    def load(
        self,
        base_dir: PathLike | None = None,
    ) -> data.AnnotationSet:
        """Load the annotations described by the stored configuration.

        Parameters
        ----------
        base_dir : PathLike, optional
            Forwarded unchanged to
            `load_batdetect2_files_annotated_dataset`. Defaults to None.

        Returns
        -------
        data.AnnotationSet
            Whatever the underlying loading function returns for this
            config.
        """
        # The union is spelled `X | None` to match the module-level
        # `PathLike = Path | str | os.PathLike` alias.
        return load_batdetect2_files_annotated_dataset(
            self.config,
            base_dir=base_dir,
        )

    @annotation_format_registry.register(BatDetect2FilesAnnotations)
    @staticmethod
    def from_config(
        config: BatDetect2FilesAnnotations,
    ) -> "BatDetect2FilesLoader":
        """Build a `BatDetect2FilesLoader` from its configuration object."""
        return BatDetect2FilesLoader(config)

View File

@ -0,0 +1,11 @@
from batdetect2.core import Registry
from batdetect2.data.annotations.types import AnnotationLoader
__all__ = ["annotation_format_registry"]

# Registry of annotation loader factories. Entries are keyed by the default
# value of each config class's `format` field rather than the registry's
# default `name` discriminator.
annotation_format_registry: Registry[AnnotationLoader, []] = Registry(
    name="annotation_format",
    discriminator="format",
)

View File

@ -1,9 +1,13 @@
from pathlib import Path from pathlib import Path
from typing import Optional, Protocol
from soundevent import data
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
__all__ = [ __all__ = [
"AnnotatedDataset", "AnnotatedDataset",
"AnnotationLoader",
] ]
@ -34,3 +38,10 @@ class AnnotatedDataset(BaseConfig):
name: str name: str
audio_dir: Path audio_dir: Path
description: str = "" description: str = ""
class AnnotationLoader(Protocol):
    """Structural interface for objects that can load an annotation set.

    Any object exposing a compatible `load` method satisfies this protocol;
    explicit subclassing is not required.
    """

    def load(
        self,
        base_dir: Optional[data.PathLike] = None,
    ) -> data.AnnotationSet:
        """Load and return the annotations as a `data.AnnotationSet`.

        Parameters
        ----------
        base_dir : data.PathLike, optional
            Optional base directory handed to the loader. Defaults to None.
        """
        ...