Add csv list format

This commit is contained in:
mbsantiago 2026-04-04 10:23:14 +01:00
parent c67d9cbba0
commit 1579bbc6c5
6 changed files with 258 additions and 86 deletions

View File

@ -9,10 +9,14 @@ from batdetect2.data.conditions.clips import (
build_clip_annotation_condition,
)
from batdetect2.data.conditions.common import (
CsvList,
HasAllTagsConfig,
HasAnyTagConfig,
HasTagConfig,
IdInListConfig,
JsonList,
ListFormatConfig,
TxtList,
)
from batdetect2.data.conditions.recordings import (
RecordingAllOfConfig,
@ -46,12 +50,15 @@ __all__ = [
"ClipAnnotationConditionImportConfig",
"ClipAnyOfConfig",
"ClipNotConfig",
"CsvList",
"DurationConfig",
"FrequencyConfig",
"HasAllTagsConfig",
"HasAnyTagConfig",
"HasTagConfig",
"IdInListConfig",
"JsonList",
"ListFormatConfig",
"NotConfig",
"Operator",
"RecordingCondition",
@ -64,6 +71,7 @@ __all__ = [
"SoundEventCondition",
"SoundEventConditionConfig",
"SoundEventConditionImportConfig",
"TxtList",
"build_clip_annotation_condition",
"build_recording_condition",
"build_sound_event_condition",

View File

@ -87,10 +87,16 @@ class RecordingSatisfies:
return RecordingSatisfies(condition)
register_has_tag_condition(clip_annotation_conditions)(HasTagConfig)
register_has_all_tags_condition(clip_annotation_conditions)(HasAllTagsConfig)
register_has_any_tag_condition(clip_annotation_conditions)(HasAnyTagConfig)
register_id_in_list_condition(clip_annotation_conditions)(IdInListConfig)
register_has_tag_condition(clip_annotation_conditions, HasTagConfig)
register_has_all_tags_condition(
clip_annotation_conditions,
HasAllTagsConfig,
)
register_has_any_tag_condition(
clip_annotation_conditions,
HasAnyTagConfig,
)
register_id_in_list_condition(clip_annotation_conditions, IdInListConfig)
@register_all_of_condition(clip_annotation_conditions)

View File

@ -1,10 +1,11 @@
import csv
import json
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Generic, Literal, ParamSpec, Protocol, TypeVar
from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar
from uuid import UUID
from pydantic import BaseModel
from pydantic import BaseModel, Field, model_validator
from soundevent import data
from batdetect2.core.configs import BaseConfig
@ -14,6 +15,7 @@ __all__ = [
"AllOf",
"AnyOf",
"Condition",
"CsvList",
"HasAllTags",
"HasAllTagsConfig",
"HasAnyTag",
@ -22,11 +24,16 @@ __all__ = [
"HasTagConfig",
"IdInList",
"IdInListConfig",
"JsonList",
"ListLoader",
"ListFormatConfig",
"MultiConditionConfigBase",
"Not",
"NotConditionConfigBase",
"ObjectWithTags",
"ObjectWithUUID",
"TxtList",
"build_list_loader",
"register_all_of_condition",
"register_any_of_condition",
"register_has_all_tags_condition",
@ -146,112 +153,208 @@ class HasAnyTagConfig(BaseConfig):
tags: list[data.Tag]
class JsonList(BaseConfig):
name: Literal["json"] = "json"
field: str | None = None
class TxtList(BaseConfig):
name: Literal["txt"] = "txt"
class CsvList(BaseConfig):
name: Literal["csv"] = "csv"
column: str
ListFormatConfig = Annotated[
JsonList | TxtList | CsvList,
Field(discriminator="name"),
]
ListLoader = Callable[[Path], list[str]]
list_loaders: Registry[ListLoader, []] = Registry("list_loader")
class IdInListConfig(BaseConfig):
name: Literal["id_in_list"] = "id_in_list"
path: Path
list_format: Literal["json", "txt"] = "json"
format: ListFormatConfig = JsonList()
@model_validator(mode="before")
@classmethod
def _normalize_format(cls, values):
if not isinstance(values, dict):
return values
format_config = values.get("format")
if isinstance(format_config, str):
values = values.copy()
config_class = list_loaders.get_config_type(format_config)
values["format"] = config_class().model_dump()
return values
def _load_ids(
path: Path,
list_format: Literal["json", "txt"],
) -> list[str]:
if list_format == "json":
class JsonListLoader:
def __init__(self, field: str | None):
self.field = field
def __call__(self, path: Path) -> list[str]:
content = json.loads(path.read_text())
if self.field is not None:
if not isinstance(content, dict):
raise TypeError(
"Expected JSON object with field for 'id_in_list'."
)
if self.field not in content:
raise KeyError(f"Field '{self.field}' not found in '{path}'.")
content = content[self.field]
if not isinstance(content, list):
raise TypeError("Expected JSON list with IDs for 'id_in_list'.")
return [str(value) for value in content]
return [
line.strip() for line in path.read_text().splitlines() if line.strip()
]
@list_loaders.register(JsonList)
@staticmethod
def from_config(config: JsonList) -> ListLoader:
return JsonListLoader(field=config.field)
class TxtListLoader:
def __call__(self, path: Path) -> list[str]:
return [
line.strip()
for line in path.read_text().splitlines()
if line.strip()
]
@list_loaders.register(TxtList)
@staticmethod
def from_config(config: TxtList) -> ListLoader:
return TxtListLoader()
class CsvListLoader:
def __init__(self, column: str):
self.column = column
def __call__(self, path: Path) -> list[str]:
with path.open("r", newline="") as csv_file:
reader = csv.DictReader(csv_file)
if reader.fieldnames is None:
raise ValueError(
f"Expected CSV header row for 'id_in_list' in '{path}'."
)
if self.column not in reader.fieldnames:
raise ValueError(
f"Column '{self.column}' not found in '{path}'."
)
values = []
for row in reader:
value = row.get(self.column)
if value is None:
continue
value = value.strip()
if not value:
continue
values.append(value)
return values
@list_loaders.register(CsvList)
@staticmethod
def from_config(config: CsvList) -> ListLoader:
return CsvListLoader(column=config.column)
def build_list_loader(config: ListFormatConfig) -> ListLoader:
return list_loaders.build(config)
def register_id_in_list_condition(
registry: Registry[Condition[UUIDObject], [data.PathLike | None]],
) -> Callable[[type[IdInListConfig]], type[IdInListConfig]]:
def decorator(config_cls: type[IdInListConfig]) -> type[IdInListConfig]:
@registry.register(config_cls)
def builder(
config: IdInListConfig,
base_dir: data.PathLike | None = None,
) -> Condition[UUIDObject]:
path = config.path
config_cls: type[IdInListConfig],
) -> None:
def builder(
config: IdInListConfig,
base_dir: data.PathLike | None = None,
) -> Condition[UUIDObject]:
path = config.path
if base_dir is not None and not path.is_absolute():
path = Path(base_dir) / path
if base_dir is not None and not path.is_absolute():
path = Path(base_dir) / path
ids = set()
for index, value in enumerate(_load_ids(path, config.list_format)):
try:
ids.add(UUID(value))
except ValueError as err:
raise ValueError(
f"Invalid ID at index {index} in '{path}': {value!r}."
) from err
ids = set()
loader = build_list_loader(config.format)
values = loader(path)
for index, value in enumerate(values):
try:
ids.add(UUID(value))
except ValueError as err:
raise ValueError(
f"Invalid ID at index {index} in '{path}': {value!r}."
) from err
return IdInList(ids)
return IdInList(ids)
return config_cls
return decorator
registry.register(config_cls)(builder)
def register_has_tag_condition(
registry: Registry[Condition[TaggedObject], P],
) -> Callable[[type[HasTagConfig]], type[HasTagConfig]]:
def decorator(config_cls: type[HasTagConfig]) -> type[HasTagConfig]:
@registry.register(config_cls)
def builder(
config: HasTagConfig,
*args: P.args,
**kwargs: P.kwargs,
) -> Condition[TaggedObject]:
return HasTag(config.tag)
config_cls: type[HasTagConfig],
) -> None:
def builder(
config: HasTagConfig,
*args: P.args,
**kwargs: P.kwargs,
) -> Condition[TaggedObject]:
return HasTag(config.tag)
return config_cls
return decorator
registry.register(config_cls)(builder)
def register_has_all_tags_condition(
registry: Registry[Condition[TaggedObject], P],
) -> Callable[[type[HasAllTagsConfig]], type[HasAllTagsConfig]]:
def decorator(
config_cls: type[HasAllTagsConfig],
) -> type[HasAllTagsConfig]:
@registry.register(config_cls)
def builder(
config: HasAllTagsConfig,
*args: P.args,
**kwargs: P.kwargs,
) -> Condition[TaggedObject]:
return HasAllTags(config.tags)
config_cls: type[HasAllTagsConfig],
) -> None:
def builder(
config: HasAllTagsConfig,
*args: P.args,
**kwargs: P.kwargs,
) -> Condition[TaggedObject]:
return HasAllTags(config.tags)
return config_cls
return decorator
registry.register(config_cls)(builder)
def register_has_any_tag_condition(
registry: Registry[Condition[TaggedObject], P],
) -> Callable[[type[HasAnyTagConfig]], type[HasAnyTagConfig]]:
def decorator(
config_cls: type[HasAnyTagConfig],
) -> type[HasAnyTagConfig]:
@registry.register(config_cls)
def builder(
config: HasAnyTagConfig,
*args: P.args,
**kwargs: P.kwargs,
) -> Condition[TaggedObject]:
return HasAnyTag(config.tags)
config_cls: type[HasAnyTagConfig],
) -> None:
def builder(
config: HasAnyTagConfig,
*args: P.args,
**kwargs: P.kwargs,
) -> Condition[TaggedObject]:
return HasAnyTag(config.tags)
return config_cls
return decorator
registry.register(config_cls)(builder)
def register_not_condition(

View File

@ -54,10 +54,10 @@ class RecordingConditionImportConfig(ImportConfig):
name: Literal["import"] = "import"
register_id_in_list_condition(recording_conditions)(IdInListConfig)
register_has_tag_condition(recording_conditions)(HasTagConfig)
register_has_all_tags_condition(recording_conditions)(HasAllTagsConfig)
register_has_any_tag_condition(recording_conditions)(HasAnyTagConfig)
register_id_in_list_condition(recording_conditions, IdInListConfig)
register_has_tag_condition(recording_conditions, HasTagConfig)
register_has_all_tags_condition(recording_conditions, HasAllTagsConfig)
register_has_any_tag_condition(recording_conditions, HasAnyTagConfig)
@register_all_of_condition(recording_conditions)

View File

@ -63,10 +63,10 @@ class SoundEventConditionImportConfig(ImportConfig):
name: Literal["import"] = "import"
register_has_tag_condition(sound_event_conditions)(HasTagConfig)
register_has_all_tags_condition(sound_event_conditions)(HasAllTagsConfig)
register_has_any_tag_condition(sound_event_conditions)(HasAnyTagConfig)
register_id_in_list_condition(sound_event_conditions)(IdInListConfig)
register_has_tag_condition(sound_event_conditions, HasTagConfig)
register_has_all_tags_condition(sound_event_conditions, HasAllTagsConfig)
register_has_any_tag_condition(sound_event_conditions, HasAnyTagConfig)
register_id_in_list_condition(sound_event_conditions, IdInListConfig)
Operator = Literal["gt", "gte", "lt", "lte", "eq"]

View File

@ -121,7 +121,62 @@ def test_id_in_list_condition_supports_txt_format(
f"""
name: id_in_list
path: {ids_path}
list_format: txt
format: txt
""",
)
assert condition(recording_a)
assert not condition(recording_b)
def test_id_in_list_condition_supports_json_field(
tmp_path: Path,
create_recording,
) -> None:
recording_a = create_recording(path=tmp_path / "a.wav")
recording_b = create_recording(path=tmp_path / "b.wav")
ids_path = tmp_path / "recording_ids.json"
ids_path.write_text(
json.dumps(
{
"train": [str(recording_a.uuid)],
"val": [str(recording_b.uuid)],
}
)
)
condition = build_recording_condition_from_yaml(
tmp_path,
f"""
name: id_in_list
path: {ids_path}
format:
name: json
field: train
""",
)
assert condition(recording_a)
assert not condition(recording_b)
def test_id_in_list_condition_supports_csv_column(
tmp_path: Path,
create_recording,
) -> None:
recording_a = create_recording(path=tmp_path / "a.wav")
recording_b = create_recording(path=tmp_path / "b.wav")
ids_path = tmp_path / "recording_ids.csv"
ids_path.write_text(f"recording_uuid\n{recording_a.uuid}\n")
condition = build_recording_condition_from_yaml(
tmp_path,
f"""
name: id_in_list
path: {ids_path}
format:
name: csv
column: recording_uuid
""",
)