Add path_in_list condition

This commit is contained in:
mbsantiago 2026-04-04 10:44:35 +01:00
parent 1579bbc6c5
commit da113eaea8
3 changed files with 319 additions and 1 deletions

View File

@ -19,6 +19,7 @@ from batdetect2.data.conditions.common import (
TxtList, TxtList,
) )
from batdetect2.data.conditions.recordings import ( from batdetect2.data.conditions.recordings import (
PathInListConfig,
RecordingAllOfConfig, RecordingAllOfConfig,
RecordingAnyOfConfig, RecordingAnyOfConfig,
RecordingCondition, RecordingCondition,
@ -61,6 +62,7 @@ __all__ = [
"ListFormatConfig", "ListFormatConfig",
"NotConfig", "NotConfig",
"Operator", "Operator",
"PathInListConfig",
"RecordingCondition", "RecordingCondition",
"RecordingConditionConfig", "RecordingConditionConfig",
"RecordingConditionImportConfig", "RecordingConditionImportConfig",

View File

@ -1,9 +1,12 @@
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Annotated, Literal from typing import Annotated, Literal
from pydantic import Field from loguru import logger
from pydantic import Field, model_validator
from soundevent import data from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import ( from batdetect2.core.registries import (
ImportConfig, ImportConfig,
Registry, Registry,
@ -14,8 +17,12 @@ from batdetect2.data.conditions.common import (
HasAnyTagConfig, HasAnyTagConfig,
HasTagConfig, HasTagConfig,
IdInListConfig, IdInListConfig,
JsonList,
ListFormatConfig,
MultiConditionConfigBase, MultiConditionConfigBase,
NotConditionConfigBase, NotConditionConfigBase,
build_list_loader,
list_loaders,
register_all_of_condition, register_all_of_condition,
register_any_of_condition, register_any_of_condition,
register_has_all_tags_condition, register_has_all_tags_condition,
@ -27,6 +34,7 @@ from batdetect2.data.conditions.common import (
__all__ = [ __all__ = [
"IdInListConfig", "IdInListConfig",
"PathInListConfig",
"RecordingAllOfConfig", "RecordingAllOfConfig",
"RecordingAnyOfConfig", "RecordingAnyOfConfig",
"RecordingCondition", "RecordingCondition",
@ -60,6 +68,116 @@ register_has_all_tags_condition(recording_conditions, HasAllTagsConfig)
register_has_any_tag_condition(recording_conditions, HasAnyTagConfig) register_has_any_tag_condition(recording_conditions, HasAnyTagConfig)
class PathInListConfig(BaseConfig):
name: Literal["path_in_list"] = "path_in_list"
path: Path
format: ListFormatConfig = JsonList()
base_dir: Path | None = None
on_outside: Literal["allow", "warn", "error"] = "allow"
@model_validator(mode="before")
@classmethod
def _normalize_format(cls, values):
if not isinstance(values, dict):
return values
format_config = values.get("format")
if isinstance(format_config, str):
values = values.copy()
config_class = list_loaders.get_config_type(format_config)
values["format"] = config_class().model_dump()
return values
class PathInList:
def __init__(
self,
paths: set[Path],
base_dir: Path | None,
on_outside: Literal["allow", "warn", "error"],
):
self.paths = paths
self.base_dir = base_dir
self.on_outside = on_outside
def __call__(self, recording: data.Recording) -> bool:
normalized_path = self._normalize_recording_path(recording.path)
if normalized_path is None:
return True
return normalized_path in self.paths
def _normalize_recording_path(self, path: data.PathLike) -> Path | None:
recording_path = Path(path)
if self.base_dir is None:
return recording_path
if not recording_path.is_absolute():
return recording_path
try:
return recording_path.relative_to(self.base_dir)
except ValueError as err:
if self.on_outside == "allow":
return None
if self.on_outside == "warn":
logger.warning(
"Recording path '{}' is outside '{}' in path_in_list; "
"allowing.",
recording_path,
self.base_dir,
)
return None
raise ValueError(
f"Recording path '{recording_path}' is outside "
f"'{self.base_dir}' for 'path_in_list'."
) from err
@recording_conditions.register(PathInListConfig)
@staticmethod
def from_config(
config: PathInListConfig,
base_dir: data.PathLike | None = None,
) -> "PathInList":
list_path = config.path
if base_dir is not None and not list_path.is_absolute():
list_path = Path(base_dir) / list_path
match_base_dir = config.base_dir
if (
match_base_dir is not None
and base_dir is not None
and not match_base_dir.is_absolute()
):
match_base_dir = Path(base_dir) / match_base_dir
loader = build_list_loader(config.format)
paths = {
Path(value).relative_to(match_base_dir)
if (
match_base_dir is not None
and Path(value).is_absolute()
and Path(value).is_relative_to(match_base_dir)
)
else Path(value)
for value in loader(list_path)
}
return PathInList(
paths=paths,
base_dir=match_base_dir,
on_outside=config.on_outside,
)
@register_all_of_condition(recording_conditions) @register_all_of_condition(recording_conditions)
class RecordingAllOfConfig(MultiConditionConfigBase): class RecordingAllOfConfig(MultiConditionConfigBase):
name: Literal["all_of"] = "all_of" name: Literal["all_of"] = "all_of"
@ -80,6 +198,7 @@ class RecordingNotConfig(NotConditionConfigBase):
RecordingConditionConfig = Annotated[ RecordingConditionConfig = Annotated[
IdInListConfig IdInListConfig
| PathInListConfig
| HasTagConfig | HasTagConfig
| HasAllTagsConfig | HasAllTagsConfig
| HasAnyTagConfig | HasAnyTagConfig

View File

@ -184,6 +184,203 @@ def test_id_in_list_condition_supports_csv_column(
assert not condition(recording_b) assert not condition(recording_b)
def test_path_in_list_condition_supports_txt_format(
tmp_path: Path,
create_recording,
) -> None:
audio_dir = tmp_path / "audio"
audio_dir.mkdir()
recording_a = create_recording(path=audio_dir / "a.wav")
recording_b = create_recording(path=audio_dir / "b.wav")
paths_file = tmp_path / "recording_paths.txt"
paths_file.write_text(f"{recording_a.path}\n")
condition = build_recording_condition_from_yaml(
tmp_path,
f"""
name: path_in_list
path: {paths_file}
format: txt
""",
)
assert condition(recording_a)
assert not condition(recording_b)
def test_path_in_list_condition_supports_json_field(
tmp_path: Path,
create_recording,
) -> None:
audio_dir = tmp_path / "audio"
audio_dir.mkdir()
recording_a = create_recording(path=audio_dir / "a.wav")
recording_b = create_recording(path=audio_dir / "b.wav")
paths_file = tmp_path / "recording_paths.json"
paths_file.write_text(
json.dumps(
{
"train": [str(recording_a.path)],
"val": [str(recording_b.path)],
}
)
)
condition = build_recording_condition_from_yaml(
tmp_path,
f"""
name: path_in_list
path: {paths_file}
format:
name: json
field: train
""",
)
assert condition(recording_a)
assert not condition(recording_b)
def test_path_in_list_condition_supports_csv_column(
tmp_path: Path,
create_recording,
) -> None:
audio_dir = tmp_path / "audio"
audio_dir.mkdir()
recording_a = create_recording(path=audio_dir / "a.wav")
recording_b = create_recording(path=audio_dir / "b.wav")
paths_file = tmp_path / "recording_paths.csv"
paths_file.write_text(f"recording_path\n{recording_a.path}\n")
condition = build_recording_condition_from_yaml(
tmp_path,
f"""
name: path_in_list
path: {paths_file}
format:
name: csv
column: recording_path
""",
)
assert condition(recording_a)
assert not condition(recording_b)
def test_path_in_list_condition_uses_base_dir(
tmp_path: Path,
create_recording,
) -> None:
data_dir = tmp_path / "dataset"
audio_dir = data_dir / "audio"
audio_dir.mkdir(parents=True)
recording_a = create_recording(path=audio_dir / "a.wav")
recording_b = create_recording(path=audio_dir / "b.wav")
paths_file = tmp_path / "recording_paths.txt"
paths_file.write_text(f"{recording_a.path}\n")
condition = build_recording_condition_from_yaml(
tmp_path,
f"""
name: path_in_list
path: {paths_file}
format: txt
base_dir: {data_dir}
""",
)
assert condition(recording_a)
assert not condition(recording_b)
def test_path_in_list_condition_outside_allow(
tmp_path: Path,
create_recording,
) -> None:
data_dir = tmp_path / "dataset"
inside_dir = data_dir / "audio"
inside_dir.mkdir(parents=True)
outside_dir = tmp_path / "other"
outside_dir.mkdir()
recording_inside = create_recording(path=inside_dir / "a.wav")
recording_outside = create_recording(path=outside_dir / "x.wav")
paths_file = tmp_path / "recording_paths.txt"
paths_file.write_text("dataset/audio/unknown.wav\n")
condition = build_recording_condition_from_yaml(
tmp_path,
f"""
name: path_in_list
path: {paths_file}
format: txt
base_dir: {data_dir}
on_outside: allow
""",
)
assert condition(recording_outside)
assert not condition(recording_inside)
def test_path_in_list_condition_outside_warn(
tmp_path: Path,
create_recording,
) -> None:
data_dir = tmp_path / "dataset"
inside_dir = data_dir / "audio"
inside_dir.mkdir(parents=True)
outside_dir = tmp_path / "other"
outside_dir.mkdir()
recording_inside = create_recording(path=inside_dir / "a.wav")
recording_outside = create_recording(path=outside_dir / "x.wav")
paths_file = tmp_path / "recording_paths.txt"
paths_file.write_text("dataset/audio/unknown.wav\n")
condition = build_recording_condition_from_yaml(
tmp_path,
f"""
name: path_in_list
path: {paths_file}
format: txt
base_dir: {data_dir}
on_outside: warn
""",
)
assert condition(recording_outside)
assert not condition(recording_inside)
def test_path_in_list_condition_outside_error(
tmp_path: Path,
create_recording,
) -> None:
data_dir = tmp_path / "dataset"
inside_dir = data_dir / "audio"
inside_dir.mkdir(parents=True)
outside_dir = tmp_path / "other"
outside_dir.mkdir()
recording_inside = create_recording(path=inside_dir / "a.wav")
recording_outside = create_recording(path=outside_dir / "x.wav")
paths_file = tmp_path / "recording_paths.txt"
paths_file.write_text(f"{recording_inside.path}\n")
condition = build_recording_condition_from_yaml(
tmp_path,
f"""
name: path_in_list
path: {paths_file}
format: txt
base_dir: {data_dir}
on_outside: error
""",
)
assert condition(recording_inside)
with pytest.raises(ValueError, match="outside"):
condition(recording_outside)
def test_has_tag_condition(tmp_path: Path, create_recording) -> None: def test_has_tag_condition(tmp_path: Path, create_recording) -> None:
train = data.Tag(key="split", value="train") train = data.Tag(key="split", value="train")
val = data.Tag(key="split", value="val") val = data.Tag(key="split", value="val")