diff --git a/src/batdetect2/data/conditions/__init__.py b/src/batdetect2/data/conditions/__init__.py index 1115c9c..e6470c9 100644 --- a/src/batdetect2/data/conditions/__init__.py +++ b/src/batdetect2/data/conditions/__init__.py @@ -9,10 +9,14 @@ from batdetect2.data.conditions.clips import ( build_clip_annotation_condition, ) from batdetect2.data.conditions.common import ( + CsvList, HasAllTagsConfig, HasAnyTagConfig, HasTagConfig, IdInListConfig, + JsonList, + ListFormatConfig, + TxtList, ) from batdetect2.data.conditions.recordings import ( RecordingAllOfConfig, @@ -46,12 +50,15 @@ __all__ = [ "ClipAnnotationConditionImportConfig", "ClipAnyOfConfig", "ClipNotConfig", + "CsvList", "DurationConfig", "FrequencyConfig", "HasAllTagsConfig", "HasAnyTagConfig", "HasTagConfig", "IdInListConfig", + "JsonList", + "ListFormatConfig", "NotConfig", "Operator", "RecordingCondition", @@ -64,6 +71,7 @@ __all__ = [ "SoundEventCondition", "SoundEventConditionConfig", "SoundEventConditionImportConfig", + "TxtList", "build_clip_annotation_condition", "build_recording_condition", "build_sound_event_condition", diff --git a/src/batdetect2/data/conditions/clips.py b/src/batdetect2/data/conditions/clips.py index 4ee187b..93dcbbc 100644 --- a/src/batdetect2/data/conditions/clips.py +++ b/src/batdetect2/data/conditions/clips.py @@ -87,10 +87,16 @@ class RecordingSatisfies: return RecordingSatisfies(condition) -register_has_tag_condition(clip_annotation_conditions)(HasTagConfig) -register_has_all_tags_condition(clip_annotation_conditions)(HasAllTagsConfig) -register_has_any_tag_condition(clip_annotation_conditions)(HasAnyTagConfig) -register_id_in_list_condition(clip_annotation_conditions)(IdInListConfig) +register_has_tag_condition(clip_annotation_conditions, HasTagConfig) +register_has_all_tags_condition( + clip_annotation_conditions, + HasAllTagsConfig, +) +register_has_any_tag_condition( + clip_annotation_conditions, + HasAnyTagConfig, +) +register_id_in_list_condition(clip_annotation_conditions, IdInListConfig) @register_all_of_condition(clip_annotation_conditions) diff --git a/src/batdetect2/data/conditions/common.py b/src/batdetect2/data/conditions/common.py index 2f18eb0..f340ba2 100644 --- a/src/batdetect2/data/conditions/common.py +++ b/src/batdetect2/data/conditions/common.py @@ -1,10 +1,11 @@ +import csv import json from collections.abc import Callable, Sequence from pathlib import Path -from typing import Generic, Literal, ParamSpec, Protocol, TypeVar +from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar from uuid import UUID -from pydantic import BaseModel +from pydantic import BaseModel, Field, model_validator from soundevent import data from batdetect2.core.configs import BaseConfig @@ -14,6 +15,7 @@ __all__ = [ "AllOf", "AnyOf", "Condition", + "CsvList", "HasAllTags", "HasAllTagsConfig", "HasAnyTag", @@ -22,11 +24,16 @@ __all__ = [ "HasTagConfig", "IdInList", "IdInListConfig", + "JsonList", + "ListLoader", + "ListFormatConfig", "MultiConditionConfigBase", "Not", "NotConditionConfigBase", "ObjectWithTags", "ObjectWithUUID", + "TxtList", + "build_list_loader", "register_all_of_condition", "register_any_of_condition", "register_has_all_tags_condition", @@ -146,112 +153,208 @@ class HasAnyTagConfig(BaseConfig): tags: list[data.Tag] +class JsonList(BaseConfig): + name: Literal["json"] = "json" + field: str | None = None + + +class TxtList(BaseConfig): + name: Literal["txt"] = "txt" + + +class CsvList(BaseConfig): + name: Literal["csv"] = "csv" + column: str + + +ListFormatConfig = Annotated[ + JsonList | TxtList | CsvList, + Field(discriminator="name"), +] + + +ListLoader = Callable[[Path], list[str]] + +list_loaders: Registry[ListLoader, []] = Registry("list_loader") + + class IdInListConfig(BaseConfig): name: Literal["id_in_list"] = "id_in_list" path: Path - list_format: Literal["json", "txt"] = "json" + format: ListFormatConfig = JsonList() + + @model_validator(mode="before") + @classmethod + def _normalize_format(cls, values): + if not isinstance(values, dict): + return values + + format_config = values.get("format") + + if isinstance(format_config, str): + values = values.copy() + config_class = list_loaders.get_config_type(format_config) + values["format"] = config_class().model_dump() + + return values -def _load_ids( - path: Path, - list_format: Literal["json", "txt"], -) -> list[str]: - if list_format == "json": +class JsonListLoader: + def __init__(self, field: str | None): + self.field = field + + def __call__(self, path: Path) -> list[str]: content = json.loads(path.read_text()) + if self.field is not None: + if not isinstance(content, dict): + raise TypeError( + "Expected JSON object with field for 'id_in_list'." + ) + + if self.field not in content: + raise KeyError(f"Field '{self.field}' not found in '{path}'.") + + content = content[self.field] + if not isinstance(content, list): raise TypeError("Expected JSON list with IDs for 'id_in_list'.") return [str(value) for value in content] - return [ - line.strip() for line in path.read_text().splitlines() if line.strip() - ] + @list_loaders.register(JsonList) + @staticmethod + def from_config(config: JsonList) -> ListLoader: + return JsonListLoader(field=config.field) + + +class TxtListLoader: + def __call__(self, path: Path) -> list[str]: + return [ + line.strip() + for line in path.read_text().splitlines() + if line.strip() + ] + + @list_loaders.register(TxtList) + @staticmethod + def from_config(config: TxtList) -> ListLoader: + return TxtListLoader() + + +class CsvListLoader: + def __init__(self, column: str): + self.column = column + + def __call__(self, path: Path) -> list[str]: + with path.open("r", newline="") as csv_file: + reader = csv.DictReader(csv_file) + + if reader.fieldnames is None: + raise ValueError( + f"Expected CSV header row for 'id_in_list' in '{path}'." + ) + + if self.column not in reader.fieldnames: + raise ValueError( + f"Column '{self.column}' not found in '{path}'." + ) + + values = [] + for row in reader: + value = row.get(self.column) + + if value is None: + continue + + value = value.strip() + + if not value: + continue + + values.append(value) + + return values + + @list_loaders.register(CsvList) + @staticmethod + def from_config(config: CsvList) -> ListLoader: + return CsvListLoader(column=config.column) + + +def build_list_loader(config: ListFormatConfig) -> ListLoader: + return list_loaders.build(config) def register_id_in_list_condition( registry: Registry[Condition[UUIDObject], [data.PathLike | None]], -) -> Callable[[type[IdInListConfig]], type[IdInListConfig]]: - def decorator(config_cls: type[IdInListConfig]) -> type[IdInListConfig]: - @registry.register(config_cls) - def builder( - config: IdInListConfig, - base_dir: data.PathLike | None = None, - ) -> Condition[UUIDObject]: - path = config.path + config_cls: type[IdInListConfig], +) -> None: + def builder( + config: IdInListConfig, + base_dir: data.PathLike | None = None, + ) -> Condition[UUIDObject]: + path = config.path - if base_dir is not None and not path.is_absolute(): - path = Path(base_dir) / path + if base_dir is not None and not path.is_absolute(): + path = Path(base_dir) / path - ids = set() - for index, value in enumerate(_load_ids(path, config.list_format)): - try: - ids.add(UUID(value)) - except ValueError as err: - raise ValueError( - f"Invalid ID at index {index} in '{path}': {value!r}." - ) from err + ids = set() + loader = build_list_loader(config.format) + values = loader(path) + for index, value in enumerate(values): + try: + ids.add(UUID(value)) + except ValueError as err: + raise ValueError( + f"Invalid ID at index {index} in '{path}': {value!r}." + ) from err - return IdInList(ids) + return IdInList(ids) - return config_cls - - return decorator + registry.register(config_cls)(builder) def register_has_tag_condition( registry: Registry[Condition[TaggedObject], P], -) -> Callable[[type[HasTagConfig]], type[HasTagConfig]]: - def decorator(config_cls: type[HasTagConfig]) -> type[HasTagConfig]: - @registry.register(config_cls) - def builder( - config: HasTagConfig, - *args: P.args, - **kwargs: P.kwargs, - ) -> Condition[TaggedObject]: - return HasTag(config.tag) + config_cls: type[HasTagConfig], +) -> None: + def builder( + config: HasTagConfig, + *args: P.args, + **kwargs: P.kwargs, + ) -> Condition[TaggedObject]: + return HasTag(config.tag) - return config_cls - - return decorator + registry.register(config_cls)(builder) def register_has_all_tags_condition( registry: Registry[Condition[TaggedObject], P], -) -> Callable[[type[HasAllTagsConfig]], type[HasAllTagsConfig]]: - def decorator( - config_cls: type[HasAllTagsConfig], - ) -> type[HasAllTagsConfig]: - @registry.register(config_cls) - def builder( - config: HasAllTagsConfig, - *args: P.args, - **kwargs: P.kwargs, - ) -> Condition[TaggedObject]: - return HasAllTags(config.tags) + config_cls: type[HasAllTagsConfig], +) -> None: + def builder( + config: HasAllTagsConfig, + *args: P.args, + **kwargs: P.kwargs, + ) -> Condition[TaggedObject]: + return HasAllTags(config.tags) - return config_cls - - return decorator + registry.register(config_cls)(builder) def register_has_any_tag_condition( registry: Registry[Condition[TaggedObject], P], -) -> Callable[[type[HasAnyTagConfig]], type[HasAnyTagConfig]]: - def decorator( - config_cls: type[HasAnyTagConfig], - ) -> type[HasAnyTagConfig]: - @registry.register(config_cls) - def builder( - config: HasAnyTagConfig, - *args: P.args, - **kwargs: P.kwargs, - ) -> Condition[TaggedObject]: - return HasAnyTag(config.tags) + config_cls: type[HasAnyTagConfig], +) -> None: + def builder( + config: HasAnyTagConfig, + *args: P.args, + **kwargs: P.kwargs, + ) -> Condition[TaggedObject]: + return HasAnyTag(config.tags) - return config_cls - - return decorator + registry.register(config_cls)(builder) def register_not_condition( diff --git a/src/batdetect2/data/conditions/recordings.py b/src/batdetect2/data/conditions/recordings.py index dcb4762..d6757f1 100644 --- a/src/batdetect2/data/conditions/recordings.py +++ b/src/batdetect2/data/conditions/recordings.py @@ -54,10 +54,10 @@ class RecordingConditionImportConfig(ImportConfig): name: Literal["import"] = "import" -register_id_in_list_condition(recording_conditions)(IdInListConfig) -register_has_tag_condition(recording_conditions)(HasTagConfig) -register_has_all_tags_condition(recording_conditions)(HasAllTagsConfig) -register_has_any_tag_condition(recording_conditions)(HasAnyTagConfig) +register_id_in_list_condition(recording_conditions, IdInListConfig) +register_has_tag_condition(recording_conditions, HasTagConfig) +register_has_all_tags_condition(recording_conditions, HasAllTagsConfig) +register_has_any_tag_condition(recording_conditions, HasAnyTagConfig) @register_all_of_condition(recording_conditions) diff --git a/src/batdetect2/data/conditions/sound_events.py b/src/batdetect2/data/conditions/sound_events.py index 1fb38c0..5786d16 100644 --- a/src/batdetect2/data/conditions/sound_events.py +++ b/src/batdetect2/data/conditions/sound_events.py @@ -63,10 +63,10 @@ class SoundEventConditionImportConfig(ImportConfig): name: Literal["import"] = "import" -register_has_tag_condition(sound_event_conditions)(HasTagConfig) -register_has_all_tags_condition(sound_event_conditions)(HasAllTagsConfig) -register_has_any_tag_condition(sound_event_conditions)(HasAnyTagConfig) -register_id_in_list_condition(sound_event_conditions)(IdInListConfig) +register_has_tag_condition(sound_event_conditions, HasTagConfig) +register_has_all_tags_condition(sound_event_conditions, HasAllTagsConfig) +register_has_any_tag_condition(sound_event_conditions, HasAnyTagConfig) +register_id_in_list_condition(sound_event_conditions, IdInListConfig) Operator = Literal["gt", "gte", "lt", "lte", "eq"] diff --git a/tests/test_data/test_conditions/test_recording.py b/tests/test_data/test_conditions/test_recording.py index 2290bdb..35f8e84 100644 --- a/tests/test_data/test_conditions/test_recording.py +++ b/tests/test_data/test_conditions/test_recording.py @@ -121,7 +121,62 @@ def test_id_in_list_condition_supports_txt_format( f""" name: id_in_list path: {ids_path} - list_format: txt + format: txt + """, + ) + + assert condition(recording_a) + assert not condition(recording_b) + + +def test_id_in_list_condition_supports_json_field( + tmp_path: Path, + create_recording, +) -> None: + recording_a = create_recording(path=tmp_path / "a.wav") + recording_b = create_recording(path=tmp_path / "b.wav") + ids_path = tmp_path / "recording_ids.json" + ids_path.write_text( + json.dumps( + { + "train": [str(recording_a.uuid)], + "val": [str(recording_b.uuid)], + } + ) + ) + + condition = build_recording_condition_from_yaml( + tmp_path, + f""" + name: id_in_list + path: {ids_path} + format: + name: json + field: train + """, + ) + + assert condition(recording_a) + assert not condition(recording_b) + + +def test_id_in_list_condition_supports_csv_column( + tmp_path: Path, + create_recording, +) -> None: + recording_a = create_recording(path=tmp_path / "a.wav") + recording_b = create_recording(path=tmp_path / "b.wav") + ids_path = tmp_path / "recording_ids.csv" + ids_path.write_text(f"recording_uuid\n{recording_a.uuid}\n") + + condition = build_recording_condition_from_yaml( + tmp_path, + f""" + name: id_in_list + path: {ids_path} + format: + name: csv + column: recording_uuid """, )