mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Create data annotation loader registry
This commit is contained in:
parent
3c337a06cb
commit
4b7d23abde
@ -44,11 +44,12 @@ class SimpleRegistry(Generic[T]):
|
||||
class Registry(Generic[T_Type, P_Type]):
|
||||
"""A generic class to create and manage a registry of items."""
|
||||
|
||||
def __init__(self, name: str):
|
||||
def __init__(self, name: str, discriminator: str = "name"):
|
||||
self._name = name
|
||||
self._registry: dict[
|
||||
str, Callable[Concatenate[..., P_Type], T_Type]
|
||||
] = {}
|
||||
self._discriminator = discriminator
|
||||
self._config_types: dict[str, Type[BaseModel]] = {}
|
||||
|
||||
def register(
|
||||
@ -57,15 +58,20 @@ class Registry(Generic[T_Type, P_Type]):
|
||||
):
|
||||
fields = config_cls.model_fields
|
||||
|
||||
if "name" not in fields:
|
||||
raise ValueError("Configuration object must have a 'name' field.")
|
||||
if self._discriminator not in fields:
|
||||
raise ValueError(
|
||||
"Configuration object must have "
|
||||
f"a '{self._discriminator}' field."
|
||||
)
|
||||
|
||||
name = fields["name"].default
|
||||
name = fields[self._discriminator].default
|
||||
|
||||
self._config_types[name] = config_cls
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise ValueError("'name' field must be a string literal.")
|
||||
raise ValueError(
|
||||
f"'{self._discriminator}' field must be a string literal."
|
||||
)
|
||||
|
||||
def decorator(
|
||||
func: Callable[Concatenate[T_Config, P_Type], T_Type],
|
||||
@ -95,10 +101,12 @@ class Registry(Generic[T_Type, P_Type]):
|
||||
) -> T_Type:
|
||||
"""Builds a logic instance from a config object."""
|
||||
|
||||
name = getattr(config, "name") # noqa: B009
|
||||
name = getattr(config, self._discriminator) # noqa: B009
|
||||
|
||||
if name is None:
|
||||
raise ValueError("Config does not have a name field")
|
||||
raise ValueError(
|
||||
f"Config does not have a '{self._discriminator}' field"
|
||||
)
|
||||
|
||||
if name not in self._registry:
|
||||
raise NotImplementedError(
|
||||
|
||||
@ -18,17 +18,13 @@ from typing import Annotated
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.annotations.aoef import (
|
||||
AOEFAnnotations,
|
||||
load_aoef_annotated_dataset,
|
||||
)
|
||||
from batdetect2.data.annotations.aoef import AOEFAnnotations
|
||||
from batdetect2.data.annotations.batdetect2 import (
|
||||
AnnotationFilter,
|
||||
BatDetect2FilesAnnotations,
|
||||
BatDetect2MergedAnnotations,
|
||||
load_batdetect2_files_annotated_dataset,
|
||||
load_batdetect2_merged_annotated_dataset,
|
||||
)
|
||||
from batdetect2.data.annotations.registry import annotation_format_registry
|
||||
from batdetect2.data.annotations.types import AnnotatedDataset
|
||||
|
||||
__all__ = [
|
||||
@ -63,20 +59,20 @@ def load_annotated_dataset(
|
||||
) -> data.AnnotationSet:
|
||||
"""Load annotations for a single data source based on its configuration.
|
||||
|
||||
This function acts as a dispatcher. It inspects the type of the input
|
||||
`source_config` object (which corresponds to a specific annotation format)
|
||||
and calls the appropriate loading function (e.g.,
|
||||
`load_aoef_annotated_dataset` for `AOEFAnnotations`).
|
||||
This function acts as a dispatcher. It inspects the format of the input
|
||||
`dataset` object and delegates to the appropriate format-specific loader
|
||||
registered in the `annotation_format_registry` (e.g.,
|
||||
`AOEFLoader` for `AOEFAnnotations`).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_config : AnnotationFormats
|
||||
dataset : AnnotatedDataset
|
||||
The configuration object for the data source, specifying its format
|
||||
and necessary details (like paths). Must be an instance of one of the
|
||||
types included in the `AnnotationFormats` union.
|
||||
base_dir : Path, optional
|
||||
An optional base directory path. If provided, relative paths within
|
||||
the `source_config` might be resolved relative to this directory by
|
||||
the `dataset` will be resolved relative to this directory by
|
||||
the underlying loading functions. Defaults to None.
|
||||
|
||||
Returns
|
||||
@ -88,23 +84,8 @@ def load_annotated_dataset(
|
||||
Raises
|
||||
------
|
||||
NotImplementedError
|
||||
If the type of the `source_config` object does not match any of the
|
||||
known format-specific loading functions implemented in the dispatch
|
||||
logic.
|
||||
If the `format` field of `dataset` does not match any registered
|
||||
annotation format loader.
|
||||
"""
|
||||
|
||||
if isinstance(dataset, AOEFAnnotations):
|
||||
return load_aoef_annotated_dataset(dataset, base_dir=base_dir)
|
||||
|
||||
if isinstance(dataset, BatDetect2MergedAnnotations):
|
||||
return load_batdetect2_merged_annotated_dataset(
|
||||
dataset, base_dir=base_dir
|
||||
)
|
||||
|
||||
if isinstance(dataset, BatDetect2FilesAnnotations):
|
||||
return load_batdetect2_files_annotated_dataset(
|
||||
dataset,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
|
||||
raise NotImplementedError(f"Unknown annotation format: {dataset.name}")
|
||||
loader = annotation_format_registry.build(dataset)
|
||||
return loader.load(base_dir=base_dir)
|
||||
|
||||
@ -19,10 +19,15 @@ from pydantic import Field
|
||||
from soundevent import data, io
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.data.annotations.types import AnnotatedDataset
|
||||
from batdetect2.data.annotations.registry import annotation_format_registry
|
||||
from batdetect2.data.annotations.types import (
|
||||
AnnotatedDataset,
|
||||
AnnotationLoader,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AOEFAnnotations",
|
||||
"AOEFLoader",
|
||||
"load_aoef_annotated_dataset",
|
||||
"AnnotationTaskFilter",
|
||||
]
|
||||
@ -82,6 +87,22 @@ class AOEFAnnotations(AnnotatedDataset):
|
||||
)
|
||||
|
||||
|
||||
class AOEFLoader(AnnotationLoader):
    """Annotation loader for datasets described by `AOEFAnnotations`.

    The loader holds on to the configuration object and defers the actual
    reading to `load_aoef_annotated_dataset` when `load` is called.
    """

    def __init__(self, config: AOEFAnnotations):
        # Only the config is stored here; nothing is read until `load`.
        self.config = config

    def load(
        self, base_dir: Optional[data.PathLike] = None
    ) -> data.AnnotationSet:
        """Load the annotation set described by the stored configuration.

        Parameters
        ----------
        base_dir : data.PathLike, optional
            Forwarded unchanged to `load_aoef_annotated_dataset`.
            Defaults to None.

        Returns
        -------
        data.AnnotationSet
            The value returned by `load_aoef_annotated_dataset` for the
            stored configuration.
        """
        annotation_set = load_aoef_annotated_dataset(
            self.config,
            base_dir=base_dir,
        )
        return annotation_set

    @annotation_format_registry.register(AOEFAnnotations)
    @staticmethod
    def from_config(config: AOEFAnnotations):
        """Build an `AOEFLoader` from its configuration object.

        Registered with `annotation_format_registry` under the
        `AOEFAnnotations` config type.
        """
        loader = AOEFLoader(config)
        return loader
|
||||
|
||||
|
||||
def load_aoef_annotated_dataset(
|
||||
dataset: AOEFAnnotations,
|
||||
base_dir: data.PathLike | None = None,
|
||||
|
||||
@ -41,7 +41,11 @@ from batdetect2.data.annotations.legacy import (
|
||||
list_file_annotations,
|
||||
load_file_annotation,
|
||||
)
|
||||
from batdetect2.data.annotations.types import AnnotatedDataset
|
||||
from batdetect2.data.annotations.registry import annotation_format_registry
|
||||
from batdetect2.data.annotations.types import (
|
||||
AnnotatedDataset,
|
||||
AnnotationLoader,
|
||||
)
|
||||
|
||||
PathLike = Path | str | os.PathLike
|
||||
|
||||
@ -302,7 +306,7 @@ def load_batdetect2_merged_annotated_dataset(
|
||||
try:
|
||||
ann = FileAnnotation.model_validate(ann)
|
||||
except ValueError as err:
|
||||
logger.warning(f"Invalid annotation file: {err}")
|
||||
logger.warning("Invalid annotation file: {err}", err=err)
|
||||
continue
|
||||
|
||||
if (
|
||||
@ -310,17 +314,23 @@ def load_batdetect2_merged_annotated_dataset(
|
||||
and dataset.filter.only_annotated
|
||||
and not ann.annotated
|
||||
):
|
||||
logger.debug(f"Skipping incomplete annotation {ann.id}")
|
||||
logger.debug(
|
||||
"Skipping incomplete annotation {ann_id}",
|
||||
ann_id=ann.id,
|
||||
)
|
||||
continue
|
||||
|
||||
if dataset.filter and dataset.filter.exclude_issues and ann.issues:
|
||||
logger.debug(f"Skipping annotation with issues {ann.id}")
|
||||
logger.debug(
|
||||
"Skipping annotation with issues {ann_id}",
|
||||
ann_id=ann.id,
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
|
||||
except FileNotFoundError as err:
|
||||
logger.warning(f"Error loading annotations: {err}")
|
||||
logger.warning("Error loading annotations: {err}", err=err)
|
||||
continue
|
||||
|
||||
annotations.append(file_annotation_to_clip_annotation(ann, clip))
|
||||
@ -330,3 +340,41 @@ def load_batdetect2_merged_annotated_dataset(
|
||||
description=dataset.description,
|
||||
clip_annotations=annotations,
|
||||
)
|
||||
|
||||
|
||||
class BatDetect2MergedLoader(AnnotationLoader):
    """Annotation loader for `BatDetect2MergedAnnotations` configurations.

    The loader stores the configuration and delegates the reading to
    `load_batdetect2_merged_annotated_dataset` when `load` is called.
    """

    def __init__(self, config: BatDetect2MergedAnnotations):
        # Only the config is stored here; nothing is read until `load`.
        self.config = config

    def load(
        self, base_dir: Optional[PathLike] = None
    ) -> data.AnnotationSet:
        """Load the annotation set described by the stored configuration.

        Parameters
        ----------
        base_dir : PathLike, optional
            Forwarded unchanged to
            `load_batdetect2_merged_annotated_dataset`. Defaults to None.

        Returns
        -------
        data.AnnotationSet
            The value returned by
            `load_batdetect2_merged_annotated_dataset` for the stored
            configuration.
        """
        annotation_set = load_batdetect2_merged_annotated_dataset(
            self.config,
            base_dir=base_dir,
        )
        return annotation_set

    @annotation_format_registry.register(BatDetect2MergedAnnotations)
    @staticmethod
    def from_config(config: BatDetect2MergedAnnotations):
        """Build a `BatDetect2MergedLoader` from its configuration object.

        Registered with `annotation_format_registry` under the
        `BatDetect2MergedAnnotations` config type.
        """
        loader = BatDetect2MergedLoader(config)
        return loader
|
||||
|
||||
|
||||
class BatDetect2FilesLoader(AnnotationLoader):
    """Annotation loader for `BatDetect2FilesAnnotations` configurations.

    The loader stores the configuration and delegates the reading to
    `load_batdetect2_files_annotated_dataset` when `load` is called.
    """

    def __init__(self, config: BatDetect2FilesAnnotations):
        # Only the config is stored here; nothing is read until `load`.
        self.config = config

    def load(
        self, base_dir: Optional[PathLike] = None
    ) -> data.AnnotationSet:
        """Load the annotation set described by the stored configuration.

        Parameters
        ----------
        base_dir : PathLike, optional
            Forwarded unchanged to
            `load_batdetect2_files_annotated_dataset`. Defaults to None.

        Returns
        -------
        data.AnnotationSet
            The value returned by
            `load_batdetect2_files_annotated_dataset` for the stored
            configuration.
        """
        annotation_set = load_batdetect2_files_annotated_dataset(
            self.config,
            base_dir=base_dir,
        )
        return annotation_set

    @annotation_format_registry.register(BatDetect2FilesAnnotations)
    @staticmethod
    def from_config(config: BatDetect2FilesAnnotations):
        """Build a `BatDetect2FilesLoader` from its configuration object.

        Registered with `annotation_format_registry` under the
        `BatDetect2FilesAnnotations` config type.
        """
        loader = BatDetect2FilesLoader(config)
        return loader
|
||||
|
||||
11
src/batdetect2/data/annotations/registry.py
Normal file
11
src/batdetect2/data/annotations/registry.py
Normal file
@ -0,0 +1,11 @@
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.data.annotations.types import AnnotationLoader
|
||||
|
||||
__all__ = [
|
||||
"annotation_format_registry",
|
||||
]
|
||||
|
||||
# Shared registry of annotation-format loaders. Entries are keyed on the
# config's "format" field instead of the registry's default "name" field.
annotation_format_registry: Registry[AnnotationLoader, []] = Registry(
    name="annotation_format",
    discriminator="format",
)
|
||||
@ -1,9 +1,13 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Protocol
|
||||
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"AnnotatedDataset",
|
||||
"AnnotationLoader",
|
||||
]
|
||||
|
||||
|
||||
@ -34,3 +38,10 @@ class AnnotatedDataset(BaseConfig):
|
||||
name: str
|
||||
audio_dir: Path
|
||||
description: str = ""
|
||||
|
||||
|
||||
class AnnotationLoader(Protocol):
    """Structural interface for objects that can load an annotation set.

    Any object exposing a compatible `load` method satisfies this protocol.
    """

    def load(
        self, base_dir: Optional[data.PathLike] = None
    ) -> data.AnnotationSet:
        """Load and return an annotation set.

        Parameters
        ----------
        base_dir : data.PathLike, optional
            Optional directory passed by the caller. How it is used is up
            to the implementation (presumably as the root for relative
            paths; confirm against concrete loaders). Defaults to None.

        Returns
        -------
        data.AnnotationSet
            The loaded annotations.
        """
        ...
|
||||
|
||||
Loading…
Reference in New Issue
Block a user