Add csv list format

This commit is contained in:
mbsantiago 2026-04-04 10:23:14 +01:00
parent c67d9cbba0
commit 1579bbc6c5
6 changed files with 258 additions and 86 deletions

View File

@ -9,10 +9,14 @@ from batdetect2.data.conditions.clips import (
build_clip_annotation_condition, build_clip_annotation_condition,
) )
from batdetect2.data.conditions.common import ( from batdetect2.data.conditions.common import (
CsvList,
HasAllTagsConfig, HasAllTagsConfig,
HasAnyTagConfig, HasAnyTagConfig,
HasTagConfig, HasTagConfig,
IdInListConfig, IdInListConfig,
JsonList,
ListFormatConfig,
TxtList,
) )
from batdetect2.data.conditions.recordings import ( from batdetect2.data.conditions.recordings import (
RecordingAllOfConfig, RecordingAllOfConfig,
@ -46,12 +50,15 @@ __all__ = [
"ClipAnnotationConditionImportConfig", "ClipAnnotationConditionImportConfig",
"ClipAnyOfConfig", "ClipAnyOfConfig",
"ClipNotConfig", "ClipNotConfig",
"CsvList",
"DurationConfig", "DurationConfig",
"FrequencyConfig", "FrequencyConfig",
"HasAllTagsConfig", "HasAllTagsConfig",
"HasAnyTagConfig", "HasAnyTagConfig",
"HasTagConfig", "HasTagConfig",
"IdInListConfig", "IdInListConfig",
"JsonList",
"ListFormatConfig",
"NotConfig", "NotConfig",
"Operator", "Operator",
"RecordingCondition", "RecordingCondition",
@ -64,6 +71,7 @@ __all__ = [
"SoundEventCondition", "SoundEventCondition",
"SoundEventConditionConfig", "SoundEventConditionConfig",
"SoundEventConditionImportConfig", "SoundEventConditionImportConfig",
"TxtList",
"build_clip_annotation_condition", "build_clip_annotation_condition",
"build_recording_condition", "build_recording_condition",
"build_sound_event_condition", "build_sound_event_condition",

View File

@ -87,10 +87,16 @@ class RecordingSatisfies:
return RecordingSatisfies(condition) return RecordingSatisfies(condition)
register_has_tag_condition(clip_annotation_conditions)(HasTagConfig) register_has_tag_condition(clip_annotation_conditions, HasTagConfig)
register_has_all_tags_condition(clip_annotation_conditions)(HasAllTagsConfig) register_has_all_tags_condition(
register_has_any_tag_condition(clip_annotation_conditions)(HasAnyTagConfig) clip_annotation_conditions,
register_id_in_list_condition(clip_annotation_conditions)(IdInListConfig) HasAllTagsConfig,
)
register_has_any_tag_condition(
clip_annotation_conditions,
HasAnyTagConfig,
)
register_id_in_list_condition(clip_annotation_conditions, IdInListConfig)
@register_all_of_condition(clip_annotation_conditions) @register_all_of_condition(clip_annotation_conditions)

View File

@ -1,10 +1,11 @@
import csv
import json import json
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from pathlib import Path from pathlib import Path
from typing import Generic, Literal, ParamSpec, Protocol, TypeVar from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar
from uuid import UUID from uuid import UUID
from pydantic import BaseModel from pydantic import BaseModel, Field, model_validator
from soundevent import data from soundevent import data
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
@ -14,6 +15,7 @@ __all__ = [
"AllOf", "AllOf",
"AnyOf", "AnyOf",
"Condition", "Condition",
"CsvList",
"HasAllTags", "HasAllTags",
"HasAllTagsConfig", "HasAllTagsConfig",
"HasAnyTag", "HasAnyTag",
@ -22,11 +24,16 @@ __all__ = [
"HasTagConfig", "HasTagConfig",
"IdInList", "IdInList",
"IdInListConfig", "IdInListConfig",
"JsonList",
"ListLoader",
"ListFormatConfig",
"MultiConditionConfigBase", "MultiConditionConfigBase",
"Not", "Not",
"NotConditionConfigBase", "NotConditionConfigBase",
"ObjectWithTags", "ObjectWithTags",
"ObjectWithUUID", "ObjectWithUUID",
"TxtList",
"build_list_loader",
"register_all_of_condition", "register_all_of_condition",
"register_any_of_condition", "register_any_of_condition",
"register_has_all_tags_condition", "register_has_all_tags_condition",
@ -146,34 +153,143 @@ class HasAnyTagConfig(BaseConfig):
tags: list[data.Tag] tags: list[data.Tag]
class JsonList(BaseConfig):
    """Format config for an ID list stored in a JSON file."""

    # Discriminator value that selects this format.
    name: Literal["json"] = "json"

    # Optional key of a top-level JSON object under which the list is stored;
    # ``None`` means the top-level JSON value itself is expected to be the list.
    field: str | None = None
class TxtList(BaseConfig):
    """Format config for an ID list stored as plain text, one entry per line."""

    # Discriminator value that selects this format.
    name: Literal["txt"] = "txt"
class CsvList(BaseConfig):
    """Format config for an ID list stored in one column of a CSV file."""

    # Discriminator value that selects this format.
    name: Literal["csv"] = "csv"

    # Header name of the CSV column that holds the entries. Required: there is
    # no default, so the config cannot be built without it.
    column: str
# Discriminated union of the supported list-format configs; the ``name`` field
# of each member decides which config class a given input is parsed into.
ListFormatConfig = Annotated[
JsonList | TxtList | CsvList,
Field(discriminator="name"),
]
# A list loader receives the path of a list file and returns its entries as
# strings.
ListLoader = Callable[[Path], list[str]]
# Registry (named "list_loader") that the loader classes in this module
# register their ``from_config`` builders with.
list_loaders: Registry[ListLoader, []] = Registry("list_loader")
class IdInListConfig(BaseConfig): class IdInListConfig(BaseConfig):
name: Literal["id_in_list"] = "id_in_list" name: Literal["id_in_list"] = "id_in_list"
path: Path path: Path
list_format: Literal["json", "txt"] = "json" format: ListFormatConfig = JsonList()
@model_validator(mode="before")
@classmethod
def _normalize_format(cls, values):
if not isinstance(values, dict):
return values
format_config = values.get("format")
if isinstance(format_config, str):
values = values.copy()
config_class = list_loaders.get_config_type(format_config)
values["format"] = config_class().model_dump()
return values
def _load_ids( class JsonListLoader:
path: Path, def __init__(self, field: str | None):
list_format: Literal["json", "txt"], self.field = field
) -> list[str]:
if list_format == "json": def __call__(self, path: Path) -> list[str]:
content = json.loads(path.read_text()) content = json.loads(path.read_text())
if self.field is not None:
if not isinstance(content, dict):
raise TypeError(
"Expected JSON object with field for 'id_in_list'."
)
if self.field not in content:
raise KeyError(f"Field '{self.field}' not found in '{path}'.")
content = content[self.field]
if not isinstance(content, list): if not isinstance(content, list):
raise TypeError("Expected JSON list with IDs for 'id_in_list'.") raise TypeError("Expected JSON list with IDs for 'id_in_list'.")
return [str(value) for value in content] return [str(value) for value in content]
@list_loaders.register(JsonList)
@staticmethod
def from_config(config: JsonList) -> ListLoader:
return JsonListLoader(field=config.field)
class TxtListLoader:
def __call__(self, path: Path) -> list[str]:
return [ return [
line.strip() for line in path.read_text().splitlines() if line.strip() line.strip()
for line in path.read_text().splitlines()
if line.strip()
] ]
@list_loaders.register(TxtList)
@staticmethod
def from_config(config: TxtList) -> ListLoader:
return TxtListLoader()
class CsvListLoader:
    """Load list entries from a single named column of a CSV file.

    Parameters
    ----------
    column
        Header name of the column whose cells are returned.
    """

    def __init__(self, column: str):
        self.column = column

    def __call__(self, path: Path) -> list[str]:
        """Return the stripped, non-blank cells of the configured column.

        Parameters
        ----------
        path
            Location of the CSV file. Its first row is used as the header.

        Returns
        -------
        list[str]
            Cell values of ``self.column`` in file order, with surrounding
            whitespace removed and empty cells dropped.

        Raises
        ------
        ValueError
            If the file has no header row, or the header does not contain
            ``self.column``.
        """
        # "utf-8-sig" reads plain UTF-8 and additionally drops a leading
        # byte-order mark. Without it, a BOM written by spreadsheet exports
        # stays attached to the first header name and the column lookup below
        # fails even though the column is present.
        with path.open("r", newline="", encoding="utf-8-sig") as csv_file:
            reader = csv.DictReader(csv_file)
            header = reader.fieldnames

            if header is None:
                raise ValueError(
                    f"Expected CSV header row for 'id_in_list' in '{path}'."
                )

            if self.column not in header:
                raise ValueError(
                    f"Column '{self.column}' not found in '{path}'."
                )

            # Rows shorter than the header yield ``None`` for the missing
            # cells; those and whitespace-only cells are skipped.
            cells = (row.get(self.column) for row in reader)
            return [
                text
                for cell in cells
                if cell is not None and (text := cell.strip())
            ]

    @list_loaders.register(CsvList)
    @staticmethod
    def from_config(config: CsvList) -> ListLoader:
        """Build a loader that reads the column named in ``config``."""
        return CsvListLoader(column=config.column)
def build_list_loader(config: ListFormatConfig) -> ListLoader:
    """Build the list loader registered for the given format config.

    Parameters
    ----------
    config
        One of the list-format configs; it is handed to the ``list_loaders``
        registry, which produces the loader.

    Returns
    -------
    ListLoader
        Whatever ``list_loaders.build`` returns for ``config``.
    """
    loader = list_loaders.build(config)
    return loader
def register_id_in_list_condition( def register_id_in_list_condition(
registry: Registry[Condition[UUIDObject], [data.PathLike | None]], registry: Registry[Condition[UUIDObject], [data.PathLike | None]],
) -> Callable[[type[IdInListConfig]], type[IdInListConfig]]: config_cls: type[IdInListConfig],
def decorator(config_cls: type[IdInListConfig]) -> type[IdInListConfig]: ) -> None:
@registry.register(config_cls)
def builder( def builder(
config: IdInListConfig, config: IdInListConfig,
base_dir: data.PathLike | None = None, base_dir: data.PathLike | None = None,
@ -184,7 +300,9 @@ def register_id_in_list_condition(
path = Path(base_dir) / path path = Path(base_dir) / path
ids = set() ids = set()
for index, value in enumerate(_load_ids(path, config.list_format)): loader = build_list_loader(config.format)
values = loader(path)
for index, value in enumerate(values):
try: try:
ids.add(UUID(value)) ids.add(UUID(value))
except ValueError as err: except ValueError as err:
@ -194,16 +312,13 @@ def register_id_in_list_condition(
return IdInList(ids) return IdInList(ids)
return config_cls registry.register(config_cls)(builder)
return decorator
def register_has_tag_condition( def register_has_tag_condition(
registry: Registry[Condition[TaggedObject], P], registry: Registry[Condition[TaggedObject], P],
) -> Callable[[type[HasTagConfig]], type[HasTagConfig]]: config_cls: type[HasTagConfig],
def decorator(config_cls: type[HasTagConfig]) -> type[HasTagConfig]: ) -> None:
@registry.register(config_cls)
def builder( def builder(
config: HasTagConfig, config: HasTagConfig,
*args: P.args, *args: P.args,
@ -211,18 +326,13 @@ def register_has_tag_condition(
) -> Condition[TaggedObject]: ) -> Condition[TaggedObject]:
return HasTag(config.tag) return HasTag(config.tag)
return config_cls registry.register(config_cls)(builder)
return decorator
def register_has_all_tags_condition( def register_has_all_tags_condition(
registry: Registry[Condition[TaggedObject], P], registry: Registry[Condition[TaggedObject], P],
) -> Callable[[type[HasAllTagsConfig]], type[HasAllTagsConfig]]:
def decorator(
config_cls: type[HasAllTagsConfig], config_cls: type[HasAllTagsConfig],
) -> type[HasAllTagsConfig]: ) -> None:
@registry.register(config_cls)
def builder( def builder(
config: HasAllTagsConfig, config: HasAllTagsConfig,
*args: P.args, *args: P.args,
@ -230,18 +340,13 @@ def register_has_all_tags_condition(
) -> Condition[TaggedObject]: ) -> Condition[TaggedObject]:
return HasAllTags(config.tags) return HasAllTags(config.tags)
return config_cls registry.register(config_cls)(builder)
return decorator
def register_has_any_tag_condition( def register_has_any_tag_condition(
registry: Registry[Condition[TaggedObject], P], registry: Registry[Condition[TaggedObject], P],
) -> Callable[[type[HasAnyTagConfig]], type[HasAnyTagConfig]]:
def decorator(
config_cls: type[HasAnyTagConfig], config_cls: type[HasAnyTagConfig],
) -> type[HasAnyTagConfig]: ) -> None:
@registry.register(config_cls)
def builder( def builder(
config: HasAnyTagConfig, config: HasAnyTagConfig,
*args: P.args, *args: P.args,
@ -249,9 +354,7 @@ def register_has_any_tag_condition(
) -> Condition[TaggedObject]: ) -> Condition[TaggedObject]:
return HasAnyTag(config.tags) return HasAnyTag(config.tags)
return config_cls registry.register(config_cls)(builder)
return decorator
def register_not_condition( def register_not_condition(

View File

@ -54,10 +54,10 @@ class RecordingConditionImportConfig(ImportConfig):
name: Literal["import"] = "import" name: Literal["import"] = "import"
register_id_in_list_condition(recording_conditions)(IdInListConfig) register_id_in_list_condition(recording_conditions, IdInListConfig)
register_has_tag_condition(recording_conditions)(HasTagConfig) register_has_tag_condition(recording_conditions, HasTagConfig)
register_has_all_tags_condition(recording_conditions)(HasAllTagsConfig) register_has_all_tags_condition(recording_conditions, HasAllTagsConfig)
register_has_any_tag_condition(recording_conditions)(HasAnyTagConfig) register_has_any_tag_condition(recording_conditions, HasAnyTagConfig)
@register_all_of_condition(recording_conditions) @register_all_of_condition(recording_conditions)

View File

@ -63,10 +63,10 @@ class SoundEventConditionImportConfig(ImportConfig):
name: Literal["import"] = "import" name: Literal["import"] = "import"
register_has_tag_condition(sound_event_conditions)(HasTagConfig) register_has_tag_condition(sound_event_conditions, HasTagConfig)
register_has_all_tags_condition(sound_event_conditions)(HasAllTagsConfig) register_has_all_tags_condition(sound_event_conditions, HasAllTagsConfig)
register_has_any_tag_condition(sound_event_conditions)(HasAnyTagConfig) register_has_any_tag_condition(sound_event_conditions, HasAnyTagConfig)
register_id_in_list_condition(sound_event_conditions)(IdInListConfig) register_id_in_list_condition(sound_event_conditions, IdInListConfig)
Operator = Literal["gt", "gte", "lt", "lte", "eq"] Operator = Literal["gt", "gte", "lt", "lte", "eq"]

View File

@ -121,7 +121,62 @@ def test_id_in_list_condition_supports_txt_format(
f""" f"""
name: id_in_list name: id_in_list
path: {ids_path} path: {ids_path}
list_format: txt format: txt
""",
)
assert condition(recording_a)
assert not condition(recording_b)
# Checks the ``field`` option of the JSON list format: the ID file is a JSON
# object with a "train" list (recording A) and a "val" list (recording B), and
# the condition is configured to read only "train". Recording A must satisfy
# the condition and recording B must not.
def test_id_in_list_condition_supports_json_field(
tmp_path: Path,
create_recording,
) -> None:
recording_a = create_recording(path=tmp_path / "a.wav")
recording_b = create_recording(path=tmp_path / "b.wav")
ids_path = tmp_path / "recording_ids.json"
ids_path.write_text(
json.dumps(
{
"train": [str(recording_a.uuid)],
"val": [str(recording_b.uuid)],
}
)
)
condition = build_recording_condition_from_yaml(
tmp_path,
f"""
name: id_in_list
path: {ids_path}
format:
name: json
field: train
""",
)
assert condition(recording_a)
assert not condition(recording_b)
def test_id_in_list_condition_supports_csv_column(
tmp_path: Path,
create_recording,
) -> None:
recording_a = create_recording(path=tmp_path / "a.wav")
recording_b = create_recording(path=tmp_path / "b.wav")
ids_path = tmp_path / "recording_ids.csv"
ids_path.write_text(f"recording_uuid\n{recording_a.uuid}\n")
condition = build_recording_condition_from_yaml(
tmp_path,
f"""
name: id_in_list
path: {ids_path}
format:
name: csv
column: recording_uuid
""", """,
) )