mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Add csv list format
This commit is contained in:
parent
c67d9cbba0
commit
1579bbc6c5
@ -9,10 +9,14 @@ from batdetect2.data.conditions.clips import (
|
||||
build_clip_annotation_condition,
|
||||
)
|
||||
from batdetect2.data.conditions.common import (
|
||||
CsvList,
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
JsonList,
|
||||
ListFormatConfig,
|
||||
TxtList,
|
||||
)
|
||||
from batdetect2.data.conditions.recordings import (
|
||||
RecordingAllOfConfig,
|
||||
@ -46,12 +50,15 @@ __all__ = [
|
||||
"ClipAnnotationConditionImportConfig",
|
||||
"ClipAnyOfConfig",
|
||||
"ClipNotConfig",
|
||||
"CsvList",
|
||||
"DurationConfig",
|
||||
"FrequencyConfig",
|
||||
"HasAllTagsConfig",
|
||||
"HasAnyTagConfig",
|
||||
"HasTagConfig",
|
||||
"IdInListConfig",
|
||||
"JsonList",
|
||||
"ListFormatConfig",
|
||||
"NotConfig",
|
||||
"Operator",
|
||||
"RecordingCondition",
|
||||
@ -64,6 +71,7 @@ __all__ = [
|
||||
"SoundEventCondition",
|
||||
"SoundEventConditionConfig",
|
||||
"SoundEventConditionImportConfig",
|
||||
"TxtList",
|
||||
"build_clip_annotation_condition",
|
||||
"build_recording_condition",
|
||||
"build_sound_event_condition",
|
||||
|
||||
@ -87,10 +87,16 @@ class RecordingSatisfies:
|
||||
return RecordingSatisfies(condition)
|
||||
|
||||
|
||||
register_has_tag_condition(clip_annotation_conditions)(HasTagConfig)
|
||||
register_has_all_tags_condition(clip_annotation_conditions)(HasAllTagsConfig)
|
||||
register_has_any_tag_condition(clip_annotation_conditions)(HasAnyTagConfig)
|
||||
register_id_in_list_condition(clip_annotation_conditions)(IdInListConfig)
|
||||
register_has_tag_condition(clip_annotation_conditions, HasTagConfig)
|
||||
register_has_all_tags_condition(
|
||||
clip_annotation_conditions,
|
||||
HasAllTagsConfig,
|
||||
)
|
||||
register_has_any_tag_condition(
|
||||
clip_annotation_conditions,
|
||||
HasAnyTagConfig,
|
||||
)
|
||||
register_id_in_list_condition(clip_annotation_conditions, IdInListConfig)
|
||||
|
||||
|
||||
@register_all_of_condition(clip_annotation_conditions)
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
import csv
|
||||
import json
|
||||
from collections.abc import Callable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Generic, Literal, ParamSpec, Protocol, TypeVar
|
||||
from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
@ -14,6 +15,7 @@ __all__ = [
|
||||
"AllOf",
|
||||
"AnyOf",
|
||||
"Condition",
|
||||
"CsvList",
|
||||
"HasAllTags",
|
||||
"HasAllTagsConfig",
|
||||
"HasAnyTag",
|
||||
@ -22,11 +24,16 @@ __all__ = [
|
||||
"HasTagConfig",
|
||||
"IdInList",
|
||||
"IdInListConfig",
|
||||
"JsonList",
|
||||
"ListLoader",
|
||||
"ListFormatConfig",
|
||||
"MultiConditionConfigBase",
|
||||
"Not",
|
||||
"NotConditionConfigBase",
|
||||
"ObjectWithTags",
|
||||
"ObjectWithUUID",
|
||||
"TxtList",
|
||||
"build_list_loader",
|
||||
"register_all_of_condition",
|
||||
"register_any_of_condition",
|
||||
"register_has_all_tags_condition",
|
||||
@ -146,112 +153,208 @@ class HasAnyTagConfig(BaseConfig):
|
||||
tags: list[data.Tag]
|
||||
|
||||
|
||||
class JsonList(BaseConfig):
    """Select the JSON list format for an ID list file."""

    # Discriminator value that identifies this format.
    name: Literal["json"] = "json"

    # Key of the JSON object that holds the list. ``None`` means the
    # document itself is expected to be the list.
    field: str | None = None
|
||||
|
||||
|
||||
class TxtList(BaseConfig):
    """Select the plain-text list format for an ID list file."""

    # Discriminator value that identifies this format.
    name: Literal["txt"] = "txt"
|
||||
|
||||
|
||||
class CsvList(BaseConfig):
    """Select the CSV list format for an ID list file."""

    # Discriminator value that identifies this format.
    name: Literal["csv"] = "csv"

    # Header name of the column whose values are read. Required: there is
    # no default column.
    column: str
|
||||
|
||||
|
||||
# Tagged union of the supported list formats. The ``name`` field is the
# discriminator used to choose between the member configs.
ListFormatConfig = Annotated[
    JsonList | TxtList | CsvList,
    Field(discriminator="name"),
]

# A list loader takes the path of a list file and returns its entries as
# strings.
ListLoader = Callable[[Path], list[str]]

# Registry that maps list-format configs to the loaders built from them.
list_loaders: Registry[ListLoader, []] = Registry("list_loader")
|
||||
|
||||
|
||||
class IdInListConfig(BaseConfig):
    """Configuration for the ``id_in_list`` condition.

    ``format`` may be given either as a full mapping, for example
    ``{"name": "csv", "column": "recording_uuid"}``, or as a shorthand
    string containing only the format name, for example ``"txt"``.
    """

    # Discriminator value that identifies this condition.
    name: Literal["id_in_list"] = "id_in_list"

    # File that contains the IDs.
    path: Path

    # NOTE(review): this looks like the older format selector that
    # ``format`` supersedes — confirm whether it should still be declared.
    list_format: Literal["json", "txt"] = "json"

    # How the file at ``path`` is parsed; defaults to a plain JSON list.
    format: ListFormatConfig = JsonList()

    @model_validator(mode="before")
    @classmethod
    def _normalize_format(cls, values):
        """Expand the string shorthand for ``format`` into a mapping.

        Non-dict input and non-string ``format`` values are returned
        unchanged. The input mapping is never mutated.
        """
        if not isinstance(values, dict):
            return values

        format_name = values.get("format")

        if not isinstance(format_name, str):
            return values

        # Hand only the name to the discriminated union rather than
        # instantiating the config class here. Instantiating with no
        # arguments fails for formats that have required fields (``csv``
        # needs ``column``), and an unknown name would fail outside normal
        # field validation. Letting the union validate the mapping reports
        # both cases as ordinary validation errors on ``format``.
        return {**values, "format": {"name": format_name}}
|
||||
|
||||
|
||||
def _load_ids(
|
||||
path: Path,
|
||||
list_format: Literal["json", "txt"],
|
||||
) -> list[str]:
|
||||
if list_format == "json":
|
||||
class JsonListLoader:
    """Load a list of IDs from a JSON file.

    When ``field`` is set, the document must be a JSON object and the list
    is read from that key; otherwise the document itself must be the list.
    """

    def __init__(self, field: str | None):
        self.field = field

    def __call__(self, path: Path) -> list[str]:
        """Return the entries of the JSON list at ``path`` as strings.

        Raises:
            TypeError: If ``field`` is set but the document is not an
                object, or if the selected value is not a list.
            KeyError: If ``field`` is set but missing from the document.
        """
        payload = json.loads(path.read_text())
        key = self.field

        if key is not None:
            if not isinstance(payload, dict):
                raise TypeError(
                    "Expected JSON object with field for 'id_in_list'."
                )

            if key not in payload:
                raise KeyError(f"Field '{self.field}' not found in '{path}'.")

            payload = payload[key]

        if not isinstance(payload, list):
            raise TypeError("Expected JSON list with IDs for 'id_in_list'.")

        # Entries may be any JSON value; normalise every one to ``str``.
        return list(map(str, payload))

    @list_loaders.register(JsonList)
    @staticmethod
    def from_config(config: JsonList) -> ListLoader:
        """Build a loader from a ``JsonList`` config."""
        return JsonListLoader(field=config.field)
|
||||
|
||||
|
||||
class TxtListLoader:
    """Load a list of IDs from a text file, one entry per line."""

    def __call__(self, path: Path) -> list[str]:
        """Return the stripped, non-blank lines of the file at ``path``."""
        stripped = (line.strip() for line in path.read_text().splitlines())
        return [line for line in stripped if line]

    @list_loaders.register(TxtList)
    @staticmethod
    def from_config(config: TxtList) -> ListLoader:
        """Build a loader from a ``TxtList`` config."""
        return TxtListLoader()
|
||||
|
||||
|
||||
class CsvListLoader:
    """Load a list of IDs from one named column of a CSV file."""

    def __init__(self, column: str):
        self.column = column

    def __call__(self, path: Path) -> list[str]:
        """Return the stripped, non-empty values of the configured column.

        Rows in which the column is missing or blank are skipped.

        Raises:
            ValueError: If the file has no header row, or the header does
                not contain the configured column.
        """
        with path.open("r", newline="") as handle:
            rows = csv.DictReader(handle)
            header = rows.fieldnames

            if header is None:
                raise ValueError(
                    f"Expected CSV header row for 'id_in_list' in '{path}'."
                )

            if self.column not in header:
                raise ValueError(
                    f"Column '{self.column}' not found in '{path}'."
                )

            # The list is built before returning so the rows are consumed
            # while the file is still open.
            cells = (row.get(self.column) for row in rows)
            stripped = (cell.strip() for cell in cells if cell is not None)
            return [cell for cell in stripped if cell]

    @list_loaders.register(CsvList)
    @staticmethod
    def from_config(config: CsvList) -> ListLoader:
        """Build a loader from a ``CsvList`` config."""
        return CsvListLoader(column=config.column)
|
||||
|
||||
|
||||
def build_list_loader(config: ListFormatConfig) -> ListLoader:
    """Build the list loader for ``config`` through the ``list_loaders`` registry."""
    loader = list_loaders.build(config)
    return loader
|
||||
|
||||
|
||||
def register_id_in_list_condition(
|
||||
registry: Registry[Condition[UUIDObject], [data.PathLike | None]],
|
||||
) -> Callable[[type[IdInListConfig]], type[IdInListConfig]]:
|
||||
def decorator(config_cls: type[IdInListConfig]) -> type[IdInListConfig]:
|
||||
@registry.register(config_cls)
|
||||
def builder(
|
||||
config: IdInListConfig,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> Condition[UUIDObject]:
|
||||
path = config.path
|
||||
config_cls: type[IdInListConfig],
|
||||
) -> None:
|
||||
def builder(
|
||||
config: IdInListConfig,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> Condition[UUIDObject]:
|
||||
path = config.path
|
||||
|
||||
if base_dir is not None and not path.is_absolute():
|
||||
path = Path(base_dir) / path
|
||||
if base_dir is not None and not path.is_absolute():
|
||||
path = Path(base_dir) / path
|
||||
|
||||
ids = set()
|
||||
for index, value in enumerate(_load_ids(path, config.list_format)):
|
||||
try:
|
||||
ids.add(UUID(value))
|
||||
except ValueError as err:
|
||||
raise ValueError(
|
||||
f"Invalid ID at index {index} in '{path}': {value!r}."
|
||||
) from err
|
||||
ids = set()
|
||||
loader = build_list_loader(config.format)
|
||||
values = loader(path)
|
||||
for index, value in enumerate(values):
|
||||
try:
|
||||
ids.add(UUID(value))
|
||||
except ValueError as err:
|
||||
raise ValueError(
|
||||
f"Invalid ID at index {index} in '{path}': {value!r}."
|
||||
) from err
|
||||
|
||||
return IdInList(ids)
|
||||
return IdInList(ids)
|
||||
|
||||
return config_cls
|
||||
|
||||
return decorator
|
||||
registry.register(config_cls)(builder)
|
||||
|
||||
|
||||
def register_has_tag_condition(
|
||||
registry: Registry[Condition[TaggedObject], P],
|
||||
) -> Callable[[type[HasTagConfig]], type[HasTagConfig]]:
|
||||
def decorator(config_cls: type[HasTagConfig]) -> type[HasTagConfig]:
|
||||
@registry.register(config_cls)
|
||||
def builder(
|
||||
config: HasTagConfig,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Condition[TaggedObject]:
|
||||
return HasTag(config.tag)
|
||||
config_cls: type[HasTagConfig],
|
||||
) -> None:
|
||||
def builder(
|
||||
config: HasTagConfig,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Condition[TaggedObject]:
|
||||
return HasTag(config.tag)
|
||||
|
||||
return config_cls
|
||||
|
||||
return decorator
|
||||
registry.register(config_cls)(builder)
|
||||
|
||||
|
||||
def register_has_all_tags_condition(
|
||||
registry: Registry[Condition[TaggedObject], P],
|
||||
) -> Callable[[type[HasAllTagsConfig]], type[HasAllTagsConfig]]:
|
||||
def decorator(
|
||||
config_cls: type[HasAllTagsConfig],
|
||||
) -> type[HasAllTagsConfig]:
|
||||
@registry.register(config_cls)
|
||||
def builder(
|
||||
config: HasAllTagsConfig,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Condition[TaggedObject]:
|
||||
return HasAllTags(config.tags)
|
||||
config_cls: type[HasAllTagsConfig],
|
||||
) -> None:
|
||||
def builder(
|
||||
config: HasAllTagsConfig,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Condition[TaggedObject]:
|
||||
return HasAllTags(config.tags)
|
||||
|
||||
return config_cls
|
||||
|
||||
return decorator
|
||||
registry.register(config_cls)(builder)
|
||||
|
||||
|
||||
def register_has_any_tag_condition(
|
||||
registry: Registry[Condition[TaggedObject], P],
|
||||
) -> Callable[[type[HasAnyTagConfig]], type[HasAnyTagConfig]]:
|
||||
def decorator(
|
||||
config_cls: type[HasAnyTagConfig],
|
||||
) -> type[HasAnyTagConfig]:
|
||||
@registry.register(config_cls)
|
||||
def builder(
|
||||
config: HasAnyTagConfig,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Condition[TaggedObject]:
|
||||
return HasAnyTag(config.tags)
|
||||
config_cls: type[HasAnyTagConfig],
|
||||
) -> None:
|
||||
def builder(
|
||||
config: HasAnyTagConfig,
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> Condition[TaggedObject]:
|
||||
return HasAnyTag(config.tags)
|
||||
|
||||
return config_cls
|
||||
|
||||
return decorator
|
||||
registry.register(config_cls)(builder)
|
||||
|
||||
|
||||
def register_not_condition(
|
||||
|
||||
@ -54,10 +54,10 @@ class RecordingConditionImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
register_id_in_list_condition(recording_conditions)(IdInListConfig)
|
||||
register_has_tag_condition(recording_conditions)(HasTagConfig)
|
||||
register_has_all_tags_condition(recording_conditions)(HasAllTagsConfig)
|
||||
register_has_any_tag_condition(recording_conditions)(HasAnyTagConfig)
|
||||
register_id_in_list_condition(recording_conditions, IdInListConfig)
|
||||
register_has_tag_condition(recording_conditions, HasTagConfig)
|
||||
register_has_all_tags_condition(recording_conditions, HasAllTagsConfig)
|
||||
register_has_any_tag_condition(recording_conditions, HasAnyTagConfig)
|
||||
|
||||
|
||||
@register_all_of_condition(recording_conditions)
|
||||
|
||||
@ -63,10 +63,10 @@ class SoundEventConditionImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
register_has_tag_condition(sound_event_conditions)(HasTagConfig)
|
||||
register_has_all_tags_condition(sound_event_conditions)(HasAllTagsConfig)
|
||||
register_has_any_tag_condition(sound_event_conditions)(HasAnyTagConfig)
|
||||
register_id_in_list_condition(sound_event_conditions)(IdInListConfig)
|
||||
register_has_tag_condition(sound_event_conditions, HasTagConfig)
|
||||
register_has_all_tags_condition(sound_event_conditions, HasAllTagsConfig)
|
||||
register_has_any_tag_condition(sound_event_conditions, HasAnyTagConfig)
|
||||
register_id_in_list_condition(sound_event_conditions, IdInListConfig)
|
||||
|
||||
|
||||
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
||||
|
||||
@ -121,7 +121,62 @@ def test_id_in_list_condition_supports_txt_format(
|
||||
f"""
|
||||
name: id_in_list
|
||||
path: {ids_path}
|
||||
list_format: txt
|
||||
format: txt
|
||||
""",
|
||||
)
|
||||
|
||||
assert condition(recording_a)
|
||||
assert not condition(recording_b)
|
||||
|
||||
|
||||
def test_id_in_list_condition_supports_json_field(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
) -> None:
|
||||
recording_a = create_recording(path=tmp_path / "a.wav")
|
||||
recording_b = create_recording(path=tmp_path / "b.wav")
|
||||
ids_path = tmp_path / "recording_ids.json"
|
||||
ids_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"train": [str(recording_a.uuid)],
|
||||
"val": [str(recording_b.uuid)],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
condition = build_recording_condition_from_yaml(
|
||||
tmp_path,
|
||||
f"""
|
||||
name: id_in_list
|
||||
path: {ids_path}
|
||||
format:
|
||||
name: json
|
||||
field: train
|
||||
""",
|
||||
)
|
||||
|
||||
assert condition(recording_a)
|
||||
assert not condition(recording_b)
|
||||
|
||||
|
||||
def test_id_in_list_condition_supports_csv_column(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
) -> None:
|
||||
recording_a = create_recording(path=tmp_path / "a.wav")
|
||||
recording_b = create_recording(path=tmp_path / "b.wav")
|
||||
ids_path = tmp_path / "recording_ids.csv"
|
||||
ids_path.write_text(f"recording_uuid\n{recording_a.uuid}\n")
|
||||
|
||||
condition = build_recording_condition_from_yaml(
|
||||
tmp_path,
|
||||
f"""
|
||||
name: id_in_list
|
||||
path: {ids_path}
|
||||
format:
|
||||
name: csv
|
||||
column: recording_uuid
|
||||
""",
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user