mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Add csv list format
This commit is contained in:
parent
c67d9cbba0
commit
1579bbc6c5
@ -9,10 +9,14 @@ from batdetect2.data.conditions.clips import (
|
|||||||
build_clip_annotation_condition,
|
build_clip_annotation_condition,
|
||||||
)
|
)
|
||||||
from batdetect2.data.conditions.common import (
|
from batdetect2.data.conditions.common import (
|
||||||
|
CsvList,
|
||||||
HasAllTagsConfig,
|
HasAllTagsConfig,
|
||||||
HasAnyTagConfig,
|
HasAnyTagConfig,
|
||||||
HasTagConfig,
|
HasTagConfig,
|
||||||
IdInListConfig,
|
IdInListConfig,
|
||||||
|
JsonList,
|
||||||
|
ListFormatConfig,
|
||||||
|
TxtList,
|
||||||
)
|
)
|
||||||
from batdetect2.data.conditions.recordings import (
|
from batdetect2.data.conditions.recordings import (
|
||||||
RecordingAllOfConfig,
|
RecordingAllOfConfig,
|
||||||
@ -46,12 +50,15 @@ __all__ = [
|
|||||||
"ClipAnnotationConditionImportConfig",
|
"ClipAnnotationConditionImportConfig",
|
||||||
"ClipAnyOfConfig",
|
"ClipAnyOfConfig",
|
||||||
"ClipNotConfig",
|
"ClipNotConfig",
|
||||||
|
"CsvList",
|
||||||
"DurationConfig",
|
"DurationConfig",
|
||||||
"FrequencyConfig",
|
"FrequencyConfig",
|
||||||
"HasAllTagsConfig",
|
"HasAllTagsConfig",
|
||||||
"HasAnyTagConfig",
|
"HasAnyTagConfig",
|
||||||
"HasTagConfig",
|
"HasTagConfig",
|
||||||
"IdInListConfig",
|
"IdInListConfig",
|
||||||
|
"JsonList",
|
||||||
|
"ListFormatConfig",
|
||||||
"NotConfig",
|
"NotConfig",
|
||||||
"Operator",
|
"Operator",
|
||||||
"RecordingCondition",
|
"RecordingCondition",
|
||||||
@ -64,6 +71,7 @@ __all__ = [
|
|||||||
"SoundEventCondition",
|
"SoundEventCondition",
|
||||||
"SoundEventConditionConfig",
|
"SoundEventConditionConfig",
|
||||||
"SoundEventConditionImportConfig",
|
"SoundEventConditionImportConfig",
|
||||||
|
"TxtList",
|
||||||
"build_clip_annotation_condition",
|
"build_clip_annotation_condition",
|
||||||
"build_recording_condition",
|
"build_recording_condition",
|
||||||
"build_sound_event_condition",
|
"build_sound_event_condition",
|
||||||
|
|||||||
@ -87,10 +87,16 @@ class RecordingSatisfies:
|
|||||||
return RecordingSatisfies(condition)
|
return RecordingSatisfies(condition)
|
||||||
|
|
||||||
|
|
||||||
register_has_tag_condition(clip_annotation_conditions)(HasTagConfig)
|
register_has_tag_condition(clip_annotation_conditions, HasTagConfig)
|
||||||
register_has_all_tags_condition(clip_annotation_conditions)(HasAllTagsConfig)
|
register_has_all_tags_condition(
|
||||||
register_has_any_tag_condition(clip_annotation_conditions)(HasAnyTagConfig)
|
clip_annotation_conditions,
|
||||||
register_id_in_list_condition(clip_annotation_conditions)(IdInListConfig)
|
HasAllTagsConfig,
|
||||||
|
)
|
||||||
|
register_has_any_tag_condition(
|
||||||
|
clip_annotation_conditions,
|
||||||
|
HasAnyTagConfig,
|
||||||
|
)
|
||||||
|
register_id_in_list_condition(clip_annotation_conditions, IdInListConfig)
|
||||||
|
|
||||||
|
|
||||||
@register_all_of_condition(clip_annotation_conditions)
|
@register_all_of_condition(clip_annotation_conditions)
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
|
import csv
|
||||||
import json
|
import json
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Generic, Literal, ParamSpec, Protocol, TypeVar
|
from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field, model_validator
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
@ -14,6 +15,7 @@ __all__ = [
|
|||||||
"AllOf",
|
"AllOf",
|
||||||
"AnyOf",
|
"AnyOf",
|
||||||
"Condition",
|
"Condition",
|
||||||
|
"CsvList",
|
||||||
"HasAllTags",
|
"HasAllTags",
|
||||||
"HasAllTagsConfig",
|
"HasAllTagsConfig",
|
||||||
"HasAnyTag",
|
"HasAnyTag",
|
||||||
@ -22,11 +24,16 @@ __all__ = [
|
|||||||
"HasTagConfig",
|
"HasTagConfig",
|
||||||
"IdInList",
|
"IdInList",
|
||||||
"IdInListConfig",
|
"IdInListConfig",
|
||||||
|
"JsonList",
|
||||||
|
"ListLoader",
|
||||||
|
"ListFormatConfig",
|
||||||
"MultiConditionConfigBase",
|
"MultiConditionConfigBase",
|
||||||
"Not",
|
"Not",
|
||||||
"NotConditionConfigBase",
|
"NotConditionConfigBase",
|
||||||
"ObjectWithTags",
|
"ObjectWithTags",
|
||||||
"ObjectWithUUID",
|
"ObjectWithUUID",
|
||||||
|
"TxtList",
|
||||||
|
"build_list_loader",
|
||||||
"register_all_of_condition",
|
"register_all_of_condition",
|
||||||
"register_any_of_condition",
|
"register_any_of_condition",
|
||||||
"register_has_all_tags_condition",
|
"register_has_all_tags_condition",
|
||||||
@ -146,34 +153,143 @@ class HasAnyTagConfig(BaseConfig):
|
|||||||
tags: list[data.Tag]
|
tags: list[data.Tag]
|
||||||
|
|
||||||
|
|
||||||
|
class JsonList(BaseConfig):
|
||||||
|
name: Literal["json"] = "json"
|
||||||
|
field: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class TxtList(BaseConfig):
|
||||||
|
name: Literal["txt"] = "txt"
|
||||||
|
|
||||||
|
|
||||||
|
class CsvList(BaseConfig):
|
||||||
|
name: Literal["csv"] = "csv"
|
||||||
|
column: str
|
||||||
|
|
||||||
|
|
||||||
|
ListFormatConfig = Annotated[
|
||||||
|
JsonList | TxtList | CsvList,
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
ListLoader = Callable[[Path], list[str]]
|
||||||
|
|
||||||
|
list_loaders: Registry[ListLoader, []] = Registry("list_loader")
|
||||||
|
|
||||||
|
|
||||||
class IdInListConfig(BaseConfig):
|
class IdInListConfig(BaseConfig):
|
||||||
name: Literal["id_in_list"] = "id_in_list"
|
name: Literal["id_in_list"] = "id_in_list"
|
||||||
path: Path
|
path: Path
|
||||||
list_format: Literal["json", "txt"] = "json"
|
format: ListFormatConfig = JsonList()
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _normalize_format(cls, values):
|
||||||
|
if not isinstance(values, dict):
|
||||||
|
return values
|
||||||
|
|
||||||
|
format_config = values.get("format")
|
||||||
|
|
||||||
|
if isinstance(format_config, str):
|
||||||
|
values = values.copy()
|
||||||
|
config_class = list_loaders.get_config_type(format_config)
|
||||||
|
values["format"] = config_class().model_dump()
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
def _load_ids(
|
class JsonListLoader:
|
||||||
path: Path,
|
def __init__(self, field: str | None):
|
||||||
list_format: Literal["json", "txt"],
|
self.field = field
|
||||||
) -> list[str]:
|
|
||||||
if list_format == "json":
|
def __call__(self, path: Path) -> list[str]:
|
||||||
content = json.loads(path.read_text())
|
content = json.loads(path.read_text())
|
||||||
|
|
||||||
|
if self.field is not None:
|
||||||
|
if not isinstance(content, dict):
|
||||||
|
raise TypeError(
|
||||||
|
"Expected JSON object with field for 'id_in_list'."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.field not in content:
|
||||||
|
raise KeyError(f"Field '{self.field}' not found in '{path}'.")
|
||||||
|
|
||||||
|
content = content[self.field]
|
||||||
|
|
||||||
if not isinstance(content, list):
|
if not isinstance(content, list):
|
||||||
raise TypeError("Expected JSON list with IDs for 'id_in_list'.")
|
raise TypeError("Expected JSON list with IDs for 'id_in_list'.")
|
||||||
|
|
||||||
return [str(value) for value in content]
|
return [str(value) for value in content]
|
||||||
|
|
||||||
|
@list_loaders.register(JsonList)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: JsonList) -> ListLoader:
|
||||||
|
return JsonListLoader(field=config.field)
|
||||||
|
|
||||||
|
|
||||||
|
class TxtListLoader:
|
||||||
|
def __call__(self, path: Path) -> list[str]:
|
||||||
return [
|
return [
|
||||||
line.strip() for line in path.read_text().splitlines() if line.strip()
|
line.strip()
|
||||||
|
for line in path.read_text().splitlines()
|
||||||
|
if line.strip()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@list_loaders.register(TxtList)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: TxtList) -> ListLoader:
|
||||||
|
return TxtListLoader()
|
||||||
|
|
||||||
|
|
||||||
|
class CsvListLoader:
|
||||||
|
def __init__(self, column: str):
|
||||||
|
self.column = column
|
||||||
|
|
||||||
|
def __call__(self, path: Path) -> list[str]:
|
||||||
|
with path.open("r", newline="") as csv_file:
|
||||||
|
reader = csv.DictReader(csv_file)
|
||||||
|
|
||||||
|
if reader.fieldnames is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected CSV header row for 'id_in_list' in '{path}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.column not in reader.fieldnames:
|
||||||
|
raise ValueError(
|
||||||
|
f"Column '{self.column}' not found in '{path}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
values = []
|
||||||
|
for row in reader:
|
||||||
|
value = row.get(self.column)
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
value = value.strip()
|
||||||
|
|
||||||
|
if not value:
|
||||||
|
continue
|
||||||
|
|
||||||
|
values.append(value)
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
@list_loaders.register(CsvList)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: CsvList) -> ListLoader:
|
||||||
|
return CsvListLoader(column=config.column)
|
||||||
|
|
||||||
|
|
||||||
|
def build_list_loader(config: ListFormatConfig) -> ListLoader:
|
||||||
|
return list_loaders.build(config)
|
||||||
|
|
||||||
|
|
||||||
def register_id_in_list_condition(
|
def register_id_in_list_condition(
|
||||||
registry: Registry[Condition[UUIDObject], [data.PathLike | None]],
|
registry: Registry[Condition[UUIDObject], [data.PathLike | None]],
|
||||||
) -> Callable[[type[IdInListConfig]], type[IdInListConfig]]:
|
config_cls: type[IdInListConfig],
|
||||||
def decorator(config_cls: type[IdInListConfig]) -> type[IdInListConfig]:
|
) -> None:
|
||||||
@registry.register(config_cls)
|
|
||||||
def builder(
|
def builder(
|
||||||
config: IdInListConfig,
|
config: IdInListConfig,
|
||||||
base_dir: data.PathLike | None = None,
|
base_dir: data.PathLike | None = None,
|
||||||
@ -184,7 +300,9 @@ def register_id_in_list_condition(
|
|||||||
path = Path(base_dir) / path
|
path = Path(base_dir) / path
|
||||||
|
|
||||||
ids = set()
|
ids = set()
|
||||||
for index, value in enumerate(_load_ids(path, config.list_format)):
|
loader = build_list_loader(config.format)
|
||||||
|
values = loader(path)
|
||||||
|
for index, value in enumerate(values):
|
||||||
try:
|
try:
|
||||||
ids.add(UUID(value))
|
ids.add(UUID(value))
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
@ -194,16 +312,13 @@ def register_id_in_list_condition(
|
|||||||
|
|
||||||
return IdInList(ids)
|
return IdInList(ids)
|
||||||
|
|
||||||
return config_cls
|
registry.register(config_cls)(builder)
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def register_has_tag_condition(
|
def register_has_tag_condition(
|
||||||
registry: Registry[Condition[TaggedObject], P],
|
registry: Registry[Condition[TaggedObject], P],
|
||||||
) -> Callable[[type[HasTagConfig]], type[HasTagConfig]]:
|
config_cls: type[HasTagConfig],
|
||||||
def decorator(config_cls: type[HasTagConfig]) -> type[HasTagConfig]:
|
) -> None:
|
||||||
@registry.register(config_cls)
|
|
||||||
def builder(
|
def builder(
|
||||||
config: HasTagConfig,
|
config: HasTagConfig,
|
||||||
*args: P.args,
|
*args: P.args,
|
||||||
@ -211,18 +326,13 @@ def register_has_tag_condition(
|
|||||||
) -> Condition[TaggedObject]:
|
) -> Condition[TaggedObject]:
|
||||||
return HasTag(config.tag)
|
return HasTag(config.tag)
|
||||||
|
|
||||||
return config_cls
|
registry.register(config_cls)(builder)
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def register_has_all_tags_condition(
|
def register_has_all_tags_condition(
|
||||||
registry: Registry[Condition[TaggedObject], P],
|
registry: Registry[Condition[TaggedObject], P],
|
||||||
) -> Callable[[type[HasAllTagsConfig]], type[HasAllTagsConfig]]:
|
|
||||||
def decorator(
|
|
||||||
config_cls: type[HasAllTagsConfig],
|
config_cls: type[HasAllTagsConfig],
|
||||||
) -> type[HasAllTagsConfig]:
|
) -> None:
|
||||||
@registry.register(config_cls)
|
|
||||||
def builder(
|
def builder(
|
||||||
config: HasAllTagsConfig,
|
config: HasAllTagsConfig,
|
||||||
*args: P.args,
|
*args: P.args,
|
||||||
@ -230,18 +340,13 @@ def register_has_all_tags_condition(
|
|||||||
) -> Condition[TaggedObject]:
|
) -> Condition[TaggedObject]:
|
||||||
return HasAllTags(config.tags)
|
return HasAllTags(config.tags)
|
||||||
|
|
||||||
return config_cls
|
registry.register(config_cls)(builder)
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def register_has_any_tag_condition(
|
def register_has_any_tag_condition(
|
||||||
registry: Registry[Condition[TaggedObject], P],
|
registry: Registry[Condition[TaggedObject], P],
|
||||||
) -> Callable[[type[HasAnyTagConfig]], type[HasAnyTagConfig]]:
|
|
||||||
def decorator(
|
|
||||||
config_cls: type[HasAnyTagConfig],
|
config_cls: type[HasAnyTagConfig],
|
||||||
) -> type[HasAnyTagConfig]:
|
) -> None:
|
||||||
@registry.register(config_cls)
|
|
||||||
def builder(
|
def builder(
|
||||||
config: HasAnyTagConfig,
|
config: HasAnyTagConfig,
|
||||||
*args: P.args,
|
*args: P.args,
|
||||||
@ -249,9 +354,7 @@ def register_has_any_tag_condition(
|
|||||||
) -> Condition[TaggedObject]:
|
) -> Condition[TaggedObject]:
|
||||||
return HasAnyTag(config.tags)
|
return HasAnyTag(config.tags)
|
||||||
|
|
||||||
return config_cls
|
registry.register(config_cls)(builder)
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def register_not_condition(
|
def register_not_condition(
|
||||||
|
|||||||
@ -54,10 +54,10 @@ class RecordingConditionImportConfig(ImportConfig):
|
|||||||
name: Literal["import"] = "import"
|
name: Literal["import"] = "import"
|
||||||
|
|
||||||
|
|
||||||
register_id_in_list_condition(recording_conditions)(IdInListConfig)
|
register_id_in_list_condition(recording_conditions, IdInListConfig)
|
||||||
register_has_tag_condition(recording_conditions)(HasTagConfig)
|
register_has_tag_condition(recording_conditions, HasTagConfig)
|
||||||
register_has_all_tags_condition(recording_conditions)(HasAllTagsConfig)
|
register_has_all_tags_condition(recording_conditions, HasAllTagsConfig)
|
||||||
register_has_any_tag_condition(recording_conditions)(HasAnyTagConfig)
|
register_has_any_tag_condition(recording_conditions, HasAnyTagConfig)
|
||||||
|
|
||||||
|
|
||||||
@register_all_of_condition(recording_conditions)
|
@register_all_of_condition(recording_conditions)
|
||||||
|
|||||||
@ -63,10 +63,10 @@ class SoundEventConditionImportConfig(ImportConfig):
|
|||||||
name: Literal["import"] = "import"
|
name: Literal["import"] = "import"
|
||||||
|
|
||||||
|
|
||||||
register_has_tag_condition(sound_event_conditions)(HasTagConfig)
|
register_has_tag_condition(sound_event_conditions, HasTagConfig)
|
||||||
register_has_all_tags_condition(sound_event_conditions)(HasAllTagsConfig)
|
register_has_all_tags_condition(sound_event_conditions, HasAllTagsConfig)
|
||||||
register_has_any_tag_condition(sound_event_conditions)(HasAnyTagConfig)
|
register_has_any_tag_condition(sound_event_conditions, HasAnyTagConfig)
|
||||||
register_id_in_list_condition(sound_event_conditions)(IdInListConfig)
|
register_id_in_list_condition(sound_event_conditions, IdInListConfig)
|
||||||
|
|
||||||
|
|
||||||
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
||||||
|
|||||||
@ -121,7 +121,62 @@ def test_id_in_list_condition_supports_txt_format(
|
|||||||
f"""
|
f"""
|
||||||
name: id_in_list
|
name: id_in_list
|
||||||
path: {ids_path}
|
path: {ids_path}
|
||||||
list_format: txt
|
format: txt
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert condition(recording_a)
|
||||||
|
assert not condition(recording_b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_id_in_list_condition_supports_json_field(
|
||||||
|
tmp_path: Path,
|
||||||
|
create_recording,
|
||||||
|
) -> None:
|
||||||
|
recording_a = create_recording(path=tmp_path / "a.wav")
|
||||||
|
recording_b = create_recording(path=tmp_path / "b.wav")
|
||||||
|
ids_path = tmp_path / "recording_ids.json"
|
||||||
|
ids_path.write_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"train": [str(recording_a.uuid)],
|
||||||
|
"val": [str(recording_b.uuid)],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
condition = build_recording_condition_from_yaml(
|
||||||
|
tmp_path,
|
||||||
|
f"""
|
||||||
|
name: id_in_list
|
||||||
|
path: {ids_path}
|
||||||
|
format:
|
||||||
|
name: json
|
||||||
|
field: train
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert condition(recording_a)
|
||||||
|
assert not condition(recording_b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_id_in_list_condition_supports_csv_column(
|
||||||
|
tmp_path: Path,
|
||||||
|
create_recording,
|
||||||
|
) -> None:
|
||||||
|
recording_a = create_recording(path=tmp_path / "a.wav")
|
||||||
|
recording_b = create_recording(path=tmp_path / "b.wav")
|
||||||
|
ids_path = tmp_path / "recording_ids.csv"
|
||||||
|
ids_path.write_text(f"recording_uuid\n{recording_a.uuid}\n")
|
||||||
|
|
||||||
|
condition = build_recording_condition_from_yaml(
|
||||||
|
tmp_path,
|
||||||
|
f"""
|
||||||
|
name: id_in_list
|
||||||
|
path: {ids_path}
|
||||||
|
format:
|
||||||
|
name: csv
|
||||||
|
column: recording_uuid
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user