mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Create data annotation loader registry
This commit is contained in:
parent
3c337a06cb
commit
4b7d23abde
@ -44,11 +44,12 @@ class SimpleRegistry(Generic[T]):
|
|||||||
class Registry(Generic[T_Type, P_Type]):
|
class Registry(Generic[T_Type, P_Type]):
|
||||||
"""A generic class to create and manage a registry of items."""
|
"""A generic class to create and manage a registry of items."""
|
||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str, discriminator: str = "name"):
|
||||||
self._name = name
|
self._name = name
|
||||||
self._registry: dict[
|
self._registry: dict[
|
||||||
str, Callable[Concatenate[..., P_Type], T_Type]
|
str, Callable[Concatenate[..., P_Type], T_Type]
|
||||||
] = {}
|
] = {}
|
||||||
|
self._discriminator = discriminator
|
||||||
self._config_types: dict[str, Type[BaseModel]] = {}
|
self._config_types: dict[str, Type[BaseModel]] = {}
|
||||||
|
|
||||||
def register(
|
def register(
|
||||||
@ -57,15 +58,20 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
):
|
):
|
||||||
fields = config_cls.model_fields
|
fields = config_cls.model_fields
|
||||||
|
|
||||||
if "name" not in fields:
|
if self._discriminator not in fields:
|
||||||
raise ValueError("Configuration object must have a 'name' field.")
|
raise ValueError(
|
||||||
|
"Configuration object must have "
|
||||||
|
f"a '{self._discriminator}' field."
|
||||||
|
)
|
||||||
|
|
||||||
name = fields["name"].default
|
name = fields[self._discriminator].default
|
||||||
|
|
||||||
self._config_types[name] = config_cls
|
self._config_types[name] = config_cls
|
||||||
|
|
||||||
if not isinstance(name, str):
|
if not isinstance(name, str):
|
||||||
raise ValueError("'name' field must be a string literal.")
|
raise ValueError(
|
||||||
|
f"'{self._discriminator}' field must be a string literal."
|
||||||
|
)
|
||||||
|
|
||||||
def decorator(
|
def decorator(
|
||||||
func: Callable[Concatenate[T_Config, P_Type], T_Type],
|
func: Callable[Concatenate[T_Config, P_Type], T_Type],
|
||||||
@ -95,10 +101,12 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
) -> T_Type:
|
) -> T_Type:
|
||||||
"""Builds a logic instance from a config object."""
|
"""Builds a logic instance from a config object."""
|
||||||
|
|
||||||
name = getattr(config, "name") # noqa: B009
|
name = getattr(config, self._discriminator) # noqa: B009
|
||||||
|
|
||||||
if name is None:
|
if name is None:
|
||||||
raise ValueError("Config does not have a name field")
|
raise ValueError(
|
||||||
|
f"Config does not have a '{self._discriminator}' field"
|
||||||
|
)
|
||||||
|
|
||||||
if name not in self._registry:
|
if name not in self._registry:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
|
|||||||
@ -18,17 +18,13 @@ from typing import Annotated
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.data.annotations.aoef import (
|
from batdetect2.data.annotations.aoef import AOEFAnnotations
|
||||||
AOEFAnnotations,
|
|
||||||
load_aoef_annotated_dataset,
|
|
||||||
)
|
|
||||||
from batdetect2.data.annotations.batdetect2 import (
|
from batdetect2.data.annotations.batdetect2 import (
|
||||||
AnnotationFilter,
|
AnnotationFilter,
|
||||||
BatDetect2FilesAnnotations,
|
BatDetect2FilesAnnotations,
|
||||||
BatDetect2MergedAnnotations,
|
BatDetect2MergedAnnotations,
|
||||||
load_batdetect2_files_annotated_dataset,
|
|
||||||
load_batdetect2_merged_annotated_dataset,
|
|
||||||
)
|
)
|
||||||
|
from batdetect2.data.annotations.registry import annotation_format_registry
|
||||||
from batdetect2.data.annotations.types import AnnotatedDataset
|
from batdetect2.data.annotations.types import AnnotatedDataset
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -63,20 +59,20 @@ def load_annotated_dataset(
|
|||||||
) -> data.AnnotationSet:
|
) -> data.AnnotationSet:
|
||||||
"""Load annotations for a single data source based on its configuration.
|
"""Load annotations for a single data source based on its configuration.
|
||||||
|
|
||||||
This function acts as a dispatcher. It inspects the type of the input
|
This function acts as a dispatcher. It inspects the format of the input
|
||||||
`source_config` object (which corresponds to a specific annotation format)
|
`dataset` object and delegates to the appropriate format-specific loader
|
||||||
and calls the appropriate loading function (e.g.,
|
registered in the `annotation_format_registry` (e.g.,
|
||||||
`load_aoef_annotated_dataset` for `AOEFAnnotations`).
|
`AOEFLoader` for `AOEFAnnotations`).
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
source_config : AnnotationFormats
|
dataset : AnnotatedDataset
|
||||||
The configuration object for the data source, specifying its format
|
The configuration object for the data source, specifying its format
|
||||||
and necessary details (like paths). Must be an instance of one of the
|
and necessary details (like paths). Must be an instance of one of the
|
||||||
types included in the `AnnotationFormats` union.
|
types included in the `AnnotationFormats` union.
|
||||||
base_dir : Path, optional
|
base_dir : Path, optional
|
||||||
An optional base directory path. If provided, relative paths within
|
An optional base directory path. If provided, relative paths within
|
||||||
the `source_config` might be resolved relative to this directory by
|
the `dataset` will be resolved relative to this directory by
|
||||||
the underlying loading functions. Defaults to None.
|
the underlying loading functions. Defaults to None.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@ -88,23 +84,8 @@ def load_annotated_dataset(
|
|||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
NotImplementedError
|
NotImplementedError
|
||||||
If the type of the `source_config` object does not match any of the
|
If the `format` field of `dataset` does not match any registered
|
||||||
known format-specific loading functions implemented in the dispatch
|
annotation format loader.
|
||||||
logic.
|
|
||||||
"""
|
"""
|
||||||
|
loader = annotation_format_registry.build(dataset)
|
||||||
if isinstance(dataset, AOEFAnnotations):
|
return loader.load(base_dir=base_dir)
|
||||||
return load_aoef_annotated_dataset(dataset, base_dir=base_dir)
|
|
||||||
|
|
||||||
if isinstance(dataset, BatDetect2MergedAnnotations):
|
|
||||||
return load_batdetect2_merged_annotated_dataset(
|
|
||||||
dataset, base_dir=base_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(dataset, BatDetect2FilesAnnotations):
|
|
||||||
return load_batdetect2_files_annotated_dataset(
|
|
||||||
dataset,
|
|
||||||
base_dir=base_dir,
|
|
||||||
)
|
|
||||||
|
|
||||||
raise NotImplementedError(f"Unknown annotation format: {dataset.name}")
|
|
||||||
|
|||||||
@ -19,10 +19,15 @@ from pydantic import Field
|
|||||||
from soundevent import data, io
|
from soundevent import data, io
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.data.annotations.types import AnnotatedDataset
|
from batdetect2.data.annotations.registry import annotation_format_registry
|
||||||
|
from batdetect2.data.annotations.types import (
|
||||||
|
AnnotatedDataset,
|
||||||
|
AnnotationLoader,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AOEFAnnotations",
|
"AOEFAnnotations",
|
||||||
|
"AOEFLoader",
|
||||||
"load_aoef_annotated_dataset",
|
"load_aoef_annotated_dataset",
|
||||||
"AnnotationTaskFilter",
|
"AnnotationTaskFilter",
|
||||||
]
|
]
|
||||||
@ -82,6 +87,22 @@ class AOEFAnnotations(AnnotatedDataset):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AOEFLoader(AnnotationLoader):
|
||||||
|
def __init__(self, config: AOEFAnnotations):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def load(
|
||||||
|
self,
|
||||||
|
base_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> data.AnnotationSet:
|
||||||
|
return load_aoef_annotated_dataset(self.config, base_dir=base_dir)
|
||||||
|
|
||||||
|
@annotation_format_registry.register(AOEFAnnotations)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: AOEFAnnotations):
|
||||||
|
return AOEFLoader(config)
|
||||||
|
|
||||||
|
|
||||||
def load_aoef_annotated_dataset(
|
def load_aoef_annotated_dataset(
|
||||||
dataset: AOEFAnnotations,
|
dataset: AOEFAnnotations,
|
||||||
base_dir: data.PathLike | None = None,
|
base_dir: data.PathLike | None = None,
|
||||||
|
|||||||
@ -41,7 +41,11 @@ from batdetect2.data.annotations.legacy import (
|
|||||||
list_file_annotations,
|
list_file_annotations,
|
||||||
load_file_annotation,
|
load_file_annotation,
|
||||||
)
|
)
|
||||||
from batdetect2.data.annotations.types import AnnotatedDataset
|
from batdetect2.data.annotations.registry import annotation_format_registry
|
||||||
|
from batdetect2.data.annotations.types import (
|
||||||
|
AnnotatedDataset,
|
||||||
|
AnnotationLoader,
|
||||||
|
)
|
||||||
|
|
||||||
PathLike = Path | str | os.PathLike
|
PathLike = Path | str | os.PathLike
|
||||||
|
|
||||||
@ -302,7 +306,7 @@ def load_batdetect2_merged_annotated_dataset(
|
|||||||
try:
|
try:
|
||||||
ann = FileAnnotation.model_validate(ann)
|
ann = FileAnnotation.model_validate(ann)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
logger.warning(f"Invalid annotation file: {err}")
|
logger.warning("Invalid annotation file: {err}", err=err)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -310,17 +314,23 @@ def load_batdetect2_merged_annotated_dataset(
|
|||||||
and dataset.filter.only_annotated
|
and dataset.filter.only_annotated
|
||||||
and not ann.annotated
|
and not ann.annotated
|
||||||
):
|
):
|
||||||
logger.debug(f"Skipping incomplete annotation {ann.id}")
|
logger.debug(
|
||||||
|
"Skipping incomplete annotation {ann_id}",
|
||||||
|
ann_id=ann.id,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if dataset.filter and dataset.filter.exclude_issues and ann.issues:
|
if dataset.filter and dataset.filter.exclude_issues and ann.issues:
|
||||||
logger.debug(f"Skipping annotation with issues {ann.id}")
|
logger.debug(
|
||||||
|
"Skipping annotation with issues {ann_id}",
|
||||||
|
ann_id=ann.id,
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
|
clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
|
||||||
except FileNotFoundError as err:
|
except FileNotFoundError as err:
|
||||||
logger.warning(f"Error loading annotations: {err}")
|
logger.warning("Error loading annotations: {err}", err=err)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
annotations.append(file_annotation_to_clip_annotation(ann, clip))
|
annotations.append(file_annotation_to_clip_annotation(ann, clip))
|
||||||
@ -330,3 +340,41 @@ def load_batdetect2_merged_annotated_dataset(
|
|||||||
description=dataset.description,
|
description=dataset.description,
|
||||||
clip_annotations=annotations,
|
clip_annotations=annotations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BatDetect2MergedLoader(AnnotationLoader):
|
||||||
|
def __init__(self, config: BatDetect2MergedAnnotations):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def load(
|
||||||
|
self,
|
||||||
|
base_dir: Optional[PathLike] = None,
|
||||||
|
) -> data.AnnotationSet:
|
||||||
|
return load_batdetect2_merged_annotated_dataset(
|
||||||
|
self.config,
|
||||||
|
base_dir=base_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
@annotation_format_registry.register(BatDetect2MergedAnnotations)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: BatDetect2MergedAnnotations):
|
||||||
|
return BatDetect2MergedLoader(config)
|
||||||
|
|
||||||
|
|
||||||
|
class BatDetect2FilesLoader(AnnotationLoader):
|
||||||
|
def __init__(self, config: BatDetect2FilesAnnotations):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def load(
|
||||||
|
self,
|
||||||
|
base_dir: Optional[PathLike] = None,
|
||||||
|
) -> data.AnnotationSet:
|
||||||
|
return load_batdetect2_files_annotated_dataset(
|
||||||
|
self.config,
|
||||||
|
base_dir=base_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
@annotation_format_registry.register(BatDetect2FilesAnnotations)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: BatDetect2FilesAnnotations):
|
||||||
|
return BatDetect2FilesLoader(config)
|
||||||
|
|||||||
11
src/batdetect2/data/annotations/registry.py
Normal file
11
src/batdetect2/data/annotations/registry.py
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
from batdetect2.core import Registry
|
||||||
|
from batdetect2.data.annotations.types import AnnotationLoader
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"annotation_format_registry",
|
||||||
|
]
|
||||||
|
|
||||||
|
annotation_format_registry: Registry[AnnotationLoader, []] = Registry(
|
||||||
|
"annotation_format",
|
||||||
|
discriminator="format",
|
||||||
|
)
|
||||||
@ -1,9 +1,13 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional, Protocol
|
||||||
|
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnnotatedDataset",
|
"AnnotatedDataset",
|
||||||
|
"AnnotationLoader",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -34,3 +38,10 @@ class AnnotatedDataset(BaseConfig):
|
|||||||
name: str
|
name: str
|
||||||
audio_dir: Path
|
audio_dir: Path
|
||||||
description: str = ""
|
description: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationLoader(Protocol):
|
||||||
|
def load(
|
||||||
|
self,
|
||||||
|
base_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> data.AnnotationSet: ...
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user