From 4b7d23abde4a495d9cc3aebf6b98447113055c27 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sun, 15 Mar 2026 20:53:59 +0000 Subject: [PATCH] Create data annotation loader registry --- src/batdetect2/core/registries.py | 22 ++++--- src/batdetect2/data/annotations/__init__.py | 43 ++++---------- src/batdetect2/data/annotations/aoef.py | 23 +++++++- src/batdetect2/data/annotations/batdetect2.py | 58 +++++++++++++++++-- src/batdetect2/data/annotations/registry.py | 11 ++++ src/batdetect2/data/annotations/types.py | 11 ++++ 6 files changed, 124 insertions(+), 44 deletions(-) create mode 100644 src/batdetect2/data/annotations/registry.py diff --git a/src/batdetect2/core/registries.py b/src/batdetect2/core/registries.py index 3cfc817..33a5d41 100644 --- a/src/batdetect2/core/registries.py +++ b/src/batdetect2/core/registries.py @@ -44,11 +44,12 @@ class SimpleRegistry(Generic[T]): class Registry(Generic[T_Type, P_Type]): """A generic class to create and manage a registry of items.""" - def __init__(self, name: str): + def __init__(self, name: str, discriminator: str = "name"): self._name = name self._registry: dict[ str, Callable[Concatenate[..., P_Type], T_Type] ] = {} + self._discriminator = discriminator self._config_types: dict[str, Type[BaseModel]] = {} def register( @@ -57,15 +58,20 @@ class Registry(Generic[T_Type, P_Type]): ): fields = config_cls.model_fields - if "name" not in fields: - raise ValueError("Configuration object must have a 'name' field.") + if self._discriminator not in fields: + raise ValueError( + "Configuration object must have " + f"a '{self._discriminator}' field." + ) - name = fields["name"].default + name = fields[self._discriminator].default self._config_types[name] = config_cls if not isinstance(name, str): - raise ValueError("'name' field must be a string literal.") + raise ValueError( + f"'{self._discriminator}' field must be a string literal." + ) def decorator( func: Callable[Concatenate[T_Config, P_Type], T_Type], @@ -95,10 +101,12 @@ class Registry(Generic[T_Type, P_Type]): ) -> T_Type: """Builds a logic instance from a config object.""" - name = getattr(config, "name") # noqa: B009 + name = getattr(config, self._discriminator) # noqa: B009 if name is None: - raise ValueError("Config does not have a name field") + raise ValueError( + f"Config does not have a '{self._discriminator}' field" + ) if name not in self._registry: raise NotImplementedError( diff --git a/src/batdetect2/data/annotations/__init__.py b/src/batdetect2/data/annotations/__init__.py index a23a2ee..5735a53 100644 --- a/src/batdetect2/data/annotations/__init__.py +++ b/src/batdetect2/data/annotations/__init__.py @@ -18,17 +18,13 @@ from typing import Annotated from pydantic import Field from soundevent import data -from batdetect2.data.annotations.aoef import ( - AOEFAnnotations, - load_aoef_annotated_dataset, -) +from batdetect2.data.annotations.aoef import AOEFAnnotations from batdetect2.data.annotations.batdetect2 import ( AnnotationFilter, BatDetect2FilesAnnotations, BatDetect2MergedAnnotations, - load_batdetect2_files_annotated_dataset, - load_batdetect2_merged_annotated_dataset, ) +from batdetect2.data.annotations.registry import annotation_format_registry from batdetect2.data.annotations.types import AnnotatedDataset __all__ = [ @@ -63,20 +59,20 @@ def load_annotated_dataset( ) -> data.AnnotationSet: """Load annotations for a single data source based on its configuration. - This function acts as a dispatcher. It inspects the type of the input - `source_config` object (which corresponds to a specific annotation format) - and calls the appropriate loading function (e.g., - `load_aoef_annotated_dataset` for `AOEFAnnotations`). + This function acts as a dispatcher. It inspects the format of the input + `dataset` object and delegates to the appropriate format-specific loader + registered in the `annotation_format_registry` (e.g., + `AOEFLoader` for `AOEFAnnotations`). Parameters ---------- - source_config : AnnotationFormats + dataset : AnnotatedDataset The configuration object for the data source, specifying its format and necessary details (like paths). Must be an instance of one of the types included in the `AnnotationFormats` union. base_dir : Path, optional An optional base directory path. If provided, relative paths within - the `source_config` might be resolved relative to this directory by + the `dataset` will be resolved relative to this directory by the underlying loading functions. Defaults to None. Returns @@ -88,23 +84,8 @@ def load_annotated_dataset( Raises ------ NotImplementedError - If the type of the `source_config` object does not match any of the - known format-specific loading functions implemented in the dispatch - logic. + If the `format` field of `dataset` does not match any registered + annotation format loader. """ - - if isinstance(dataset, AOEFAnnotations): - return load_aoef_annotated_dataset(dataset, base_dir=base_dir) - - if isinstance(dataset, BatDetect2MergedAnnotations): - return load_batdetect2_merged_annotated_dataset( - dataset, base_dir=base_dir - ) - - if isinstance(dataset, BatDetect2FilesAnnotations): - return load_batdetect2_files_annotated_dataset( - dataset, - base_dir=base_dir, - ) - - raise NotImplementedError(f"Unknown annotation format: {dataset.name}") + loader = annotation_format_registry.build(dataset) + return loader.load(base_dir=base_dir) diff --git a/src/batdetect2/data/annotations/aoef.py b/src/batdetect2/data/annotations/aoef.py index e9924b8..88fa001 100644 --- a/src/batdetect2/data/annotations/aoef.py +++ b/src/batdetect2/data/annotations/aoef.py @@ -19,10 +19,15 @@ from pydantic import Field from soundevent import data, io from batdetect2.core.configs import BaseConfig -from batdetect2.data.annotations.types import AnnotatedDataset +from batdetect2.data.annotations.registry import annotation_format_registry +from batdetect2.data.annotations.types import ( + AnnotatedDataset, + AnnotationLoader, +) __all__ = [ "AOEFAnnotations", + "AOEFLoader", "load_aoef_annotated_dataset", "AnnotationTaskFilter", ] @@ -82,6 +87,22 @@ class AOEFAnnotations(AnnotatedDataset): ) +class AOEFLoader(AnnotationLoader): + def __init__(self, config: AOEFAnnotations): + self.config = config + + def load( + self, + base_dir: Optional[data.PathLike] = None, + ) -> data.AnnotationSet: + return load_aoef_annotated_dataset(self.config, base_dir=base_dir) + + @annotation_format_registry.register(AOEFAnnotations) + @staticmethod + def from_config(config: AOEFAnnotations): + return AOEFLoader(config) + + def load_aoef_annotated_dataset( dataset: AOEFAnnotations, base_dir: data.PathLike | None = None, diff --git a/src/batdetect2/data/annotations/batdetect2.py b/src/batdetect2/data/annotations/batdetect2.py index 1dd3727..1659e55 100644 --- a/src/batdetect2/data/annotations/batdetect2.py +++ b/src/batdetect2/data/annotations/batdetect2.py @@ -41,7 +41,11 @@ from batdetect2.data.annotations.legacy import ( list_file_annotations, load_file_annotation, ) -from batdetect2.data.annotations.types import AnnotatedDataset +from batdetect2.data.annotations.registry import annotation_format_registry +from batdetect2.data.annotations.types import ( + AnnotatedDataset, + AnnotationLoader, +) PathLike = Path | str | os.PathLike @@ -302,7 +306,7 @@ def load_batdetect2_merged_annotated_dataset( try: ann = FileAnnotation.model_validate(ann) except ValueError as err: - logger.warning(f"Invalid annotation file: {err}") + logger.warning("Invalid annotation file: {err}", err=err) continue if ( @@ -310,17 +314,23 @@ def load_batdetect2_merged_annotated_dataset( and dataset.filter.only_annotated and not ann.annotated ): - logger.debug(f"Skipping incomplete annotation {ann.id}") + logger.debug( + "Skipping incomplete annotation {ann_id}", + ann_id=ann.id, + ) continue if dataset.filter and dataset.filter.exclude_issues and ann.issues: - logger.debug(f"Skipping annotation with issues {ann.id}") + logger.debug( + "Skipping annotation with issues {ann_id}", + ann_id=ann.id, + ) continue try: clip = file_annotation_to_clip(ann, audio_dir=audio_dir) except FileNotFoundError as err: - logger.warning(f"Error loading annotations: {err}") + logger.warning("Error loading annotations: {err}", err=err) continue annotations.append(file_annotation_to_clip_annotation(ann, clip)) @@ -330,3 +340,41 @@ def load_batdetect2_merged_annotated_dataset( description=dataset.description, clip_annotations=annotations, ) + + +class BatDetect2MergedLoader(AnnotationLoader): + def __init__(self, config: BatDetect2MergedAnnotations): + self.config = config + + def load( + self, + base_dir: Optional[PathLike] = None, + ) -> data.AnnotationSet: + return load_batdetect2_merged_annotated_dataset( + self.config, + base_dir=base_dir, + ) + + @annotation_format_registry.register(BatDetect2MergedAnnotations) + @staticmethod + def from_config(config: BatDetect2MergedAnnotations): + return BatDetect2MergedLoader(config) + + +class BatDetect2FilesLoader(AnnotationLoader): + def __init__(self, config: BatDetect2FilesAnnotations): + self.config = config + + def load( + self, + base_dir: Optional[PathLike] = None, + ) -> data.AnnotationSet: + return load_batdetect2_files_annotated_dataset( + self.config, + base_dir=base_dir, + ) + + @annotation_format_registry.register(BatDetect2FilesAnnotations) + @staticmethod + def from_config(config: BatDetect2FilesAnnotations): + return BatDetect2FilesLoader(config) diff --git a/src/batdetect2/data/annotations/registry.py b/src/batdetect2/data/annotations/registry.py new file mode 100644 index 0000000..6200cdc --- /dev/null +++ b/src/batdetect2/data/annotations/registry.py @@ -0,0 +1,11 @@ +from batdetect2.core import Registry +from batdetect2.data.annotations.types import AnnotationLoader + +__all__ = [ + "annotation_format_registry", +] + +annotation_format_registry: Registry[AnnotationLoader, []] = Registry( + "annotation_format", + discriminator="format", +) diff --git a/src/batdetect2/data/annotations/types.py b/src/batdetect2/data/annotations/types.py index 74e769b..a496d59 100644 --- a/src/batdetect2/data/annotations/types.py +++ b/src/batdetect2/data/annotations/types.py @@ -1,9 +1,13 @@ from pathlib import Path +from typing import Optional, Protocol + +from soundevent import data from batdetect2.core.configs import BaseConfig __all__ = [ "AnnotatedDataset", + "AnnotationLoader", ] @@ -34,3 +38,10 @@ class AnnotatedDataset(BaseConfig): name: str audio_dir: Path description: str = "" + + +class AnnotationLoader(Protocol): + def load( + self, + base_dir: Optional[data.PathLike] = None, + ) -> data.AnnotationSet: ...