Add path_in_list condition

This commit is contained in:
mbsantiago 2026-04-04 10:44:35 +01:00
parent 1579bbc6c5
commit da113eaea8
3 changed files with 319 additions and 1 deletions

View File

@ -19,6 +19,7 @@ from batdetect2.data.conditions.common import (
TxtList,
)
from batdetect2.data.conditions.recordings import (
PathInListConfig,
RecordingAllOfConfig,
RecordingAnyOfConfig,
RecordingCondition,
@ -61,6 +62,7 @@ __all__ = [
"ListFormatConfig",
"NotConfig",
"Operator",
"PathInListConfig",
"RecordingCondition",
"RecordingConditionConfig",
"RecordingConditionImportConfig",

View File

@ -1,9 +1,12 @@
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Annotated, Literal
from pydantic import Field
from loguru import logger
from pydantic import Field, model_validator
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import (
ImportConfig,
Registry,
@ -14,8 +17,12 @@ from batdetect2.data.conditions.common import (
HasAnyTagConfig,
HasTagConfig,
IdInListConfig,
JsonList,
ListFormatConfig,
MultiConditionConfigBase,
NotConditionConfigBase,
build_list_loader,
list_loaders,
register_all_of_condition,
register_any_of_condition,
register_has_all_tags_condition,
@ -27,6 +34,7 @@ from batdetect2.data.conditions.common import (
__all__ = [
"IdInListConfig",
"PathInListConfig",
"RecordingAllOfConfig",
"RecordingAnyOfConfig",
"RecordingCondition",
@ -60,6 +68,116 @@ register_has_all_tags_condition(recording_conditions, HasAllTagsConfig)
register_has_any_tag_condition(recording_conditions, HasAnyTagConfig)
class PathInListConfig(BaseConfig):
    """Configuration for the ``path_in_list`` recording condition.

    The condition keeps recordings whose path appears in a list stored in
    an external file. ``format`` may be given either as a full list-format
    config or as a bare format name, which is expanded to that format's
    default config before validation.
    """

    # Discriminator used to select this config.
    name: Literal["path_in_list"] = "path_in_list"
    # Location of the file holding the listed paths.
    path: Path
    # How the list file is parsed; defaults to the JSON list format.
    format: ListFormatConfig = JsonList()
    # Optional directory that absolute paths are made relative to before
    # they are compared.
    base_dir: Path | None = None
    # What to do with an absolute recording path that is not under
    # ``base_dir``.
    on_outside: Literal["allow", "warn", "error"] = "allow"

    @model_validator(mode="before")
    @classmethod
    def _normalize_format(cls, values):
        """Expand a string ``format`` into the matching config mapping.

        Non-dict input and non-string ``format`` values are returned
        untouched. The input dict itself is never mutated.
        """
        if not isinstance(values, dict):
            return values

        shorthand = values.get("format")
        if not isinstance(shorthand, str):
            return values

        expanded = values.copy()
        loader_config = list_loaders.get_config_type(shorthand)
        expanded["format"] = loader_config().model_dump()
        return expanded
class PathInList:
    """Recording condition that tests a recording path against a path set.

    When ``base_dir`` is set, absolute recording paths are made relative to
    it before the lookup. Absolute paths that are not under ``base_dir``
    are handled according to ``on_outside``: accepted silently
    (``"allow"``), accepted with a logged warning (``"warn"``), or rejected
    with a ``ValueError`` (``"error"``).
    """

    def __init__(
        self,
        paths: set[Path],
        base_dir: Path | None,
        on_outside: Literal["allow", "warn", "error"],
    ):
        self.paths = paths
        self.base_dir = base_dir
        self.on_outside = on_outside

    def __call__(self, recording: data.Recording) -> bool:
        """Return whether ``recording`` satisfies the condition."""
        candidate = self._normalize_recording_path(recording.path)
        # ``None`` means the path is outside ``base_dir`` but tolerated by
        # the ``on_outside`` policy, so the recording is accepted.
        return candidate is None or candidate in self.paths

    def _normalize_recording_path(self, path: data.PathLike) -> Path | None:
        """Convert a recording path to the form used for the lookup.

        Returns the path unchanged when no ``base_dir`` is configured or
        when the path is relative, the path relative to ``base_dir`` when
        it lies under it, and ``None`` when it lies outside and the policy
        tolerates that.

        Raises:
            ValueError: If the path is outside ``base_dir`` and
                ``on_outside`` is neither ``"allow"`` nor ``"warn"``.
        """
        recording_path = Path(path)
        if self.base_dir is None or not recording_path.is_absolute():
            return recording_path

        try:
            return recording_path.relative_to(self.base_dir)
        except ValueError as err:
            policy = self.on_outside
            if policy not in ("allow", "warn"):
                raise ValueError(
                    f"Recording path '{recording_path}' is outside "
                    f"'{self.base_dir}' for 'path_in_list'."
                ) from err
            if policy == "warn":
                logger.warning(
                    "Recording path '{}' is outside '{}' in path_in_list; "
                    "allowing.",
                    recording_path,
                    self.base_dir,
                )
            return None

    @recording_conditions.register(PathInListConfig)
    @staticmethod
    def from_config(
        config: PathInListConfig,
        base_dir: data.PathLike | None = None,
    ) -> "PathInList":
        """Build the condition from its config.

        Relative ``config.path`` and ``config.base_dir`` values are joined
        onto ``base_dir`` when one is given. Listed entries that are
        absolute and under the matching base directory are stored relative
        to it; every other entry is stored as written.
        """
        list_path = config.path
        match_base_dir = config.base_dir
        if base_dir is not None:
            if not list_path.is_absolute():
                list_path = Path(base_dir) / list_path
            if match_base_dir is not None and not match_base_dir.is_absolute():
                match_base_dir = Path(base_dir) / match_base_dir

        load_entries = build_list_loader(config.format)

        paths: set[Path] = set()
        for value in load_entries(list_path):
            entry = Path(value)
            if (
                match_base_dir is not None
                and entry.is_absolute()
                and entry.is_relative_to(match_base_dir)
            ):
                entry = entry.relative_to(match_base_dir)
            paths.add(entry)

        return PathInList(
            paths=paths,
            base_dir=match_base_dir,
            on_outside=config.on_outside,
        )
@register_all_of_condition(recording_conditions)
class RecordingAllOfConfig(MultiConditionConfigBase):
name: Literal["all_of"] = "all_of"
@ -80,6 +198,7 @@ class RecordingNotConfig(NotConditionConfigBase):
RecordingConditionConfig = Annotated[
IdInListConfig
| PathInListConfig
| HasTagConfig
| HasAllTagsConfig
| HasAnyTagConfig

View File

@ -184,6 +184,203 @@ def test_id_in_list_condition_supports_csv_column(
assert not condition(recording_b)
def test_path_in_list_condition_supports_txt_format(
    tmp_path: Path,
    create_recording,
) -> None:
    """A txt list selects only the recording whose path it contains."""
    recordings_dir = tmp_path / "audio"
    recordings_dir.mkdir()
    listed = create_recording(path=recordings_dir / "a.wav")
    unlisted = create_recording(path=recordings_dir / "b.wav")

    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text(f"{listed.path}\n")

    # The YAML body is kept flush-left so the literal carries no extra
    # leading whitespace.
    yaml_config = f"""
name: path_in_list
path: {list_file}
format: txt
"""
    condition = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert condition(listed)
    assert not condition(unlisted)
def test_path_in_list_condition_supports_json_field(
    tmp_path: Path,
    create_recording,
) -> None:
    """A JSON list with a ``field`` selects only paths under that key."""
    audio_dir = tmp_path / "audio"
    audio_dir.mkdir()
    recording_a = create_recording(path=audio_dir / "a.wav")
    recording_b = create_recording(path=audio_dir / "b.wav")

    paths_file = tmp_path / "recording_paths.json"
    paths_file.write_text(
        json.dumps(
            {
                "train": [str(recording_a.path)],
                "val": [str(recording_b.path)],
            }
        )
    )

    # ``name`` and ``field`` must be nested under ``format``; at top level
    # they would collide with the condition's own ``name`` key and leave
    # ``format`` empty.
    yaml_config = f"""
name: path_in_list
path: {paths_file}
format:
  name: json
  field: train
"""
    condition = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert condition(recording_a)
    assert not condition(recording_b)
def test_path_in_list_condition_supports_csv_column(
    tmp_path: Path,
    create_recording,
) -> None:
    """A CSV list with a ``column`` selects only paths in that column."""
    audio_dir = tmp_path / "audio"
    audio_dir.mkdir()
    recording_a = create_recording(path=audio_dir / "a.wav")
    recording_b = create_recording(path=audio_dir / "b.wav")

    paths_file = tmp_path / "recording_paths.csv"
    paths_file.write_text(f"recording_path\n{recording_a.path}\n")

    # ``name`` and ``column`` must be nested under ``format``; at top level
    # they would collide with the condition's own ``name`` key and leave
    # ``format`` empty.
    yaml_config = f"""
name: path_in_list
path: {paths_file}
format:
  name: csv
  column: recording_path
"""
    condition = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert condition(recording_a)
    assert not condition(recording_b)
def test_path_in_list_condition_uses_base_dir(
    tmp_path: Path,
    create_recording,
) -> None:
    """With ``base_dir`` set, a listed absolute path still matches."""
    dataset_root = tmp_path / "dataset"
    recordings_dir = dataset_root / "audio"
    recordings_dir.mkdir(parents=True)
    listed = create_recording(path=recordings_dir / "a.wav")
    unlisted = create_recording(path=recordings_dir / "b.wav")

    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text(f"{listed.path}\n")

    # The YAML body is kept flush-left so the literal carries no extra
    # leading whitespace.
    yaml_config = f"""
name: path_in_list
path: {list_file}
format: txt
base_dir: {dataset_root}
"""
    condition = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert condition(listed)
    assert not condition(unlisted)
def test_path_in_list_condition_outside_allow(
    tmp_path: Path,
    create_recording,
) -> None:
    """``on_outside: allow`` accepts a recording outside ``base_dir``.

    The unlisted recording inside ``base_dir`` is still rejected.
    """
    dataset_root = tmp_path / "dataset"
    within_dir = dataset_root / "audio"
    within_dir.mkdir(parents=True)
    elsewhere_dir = tmp_path / "other"
    elsewhere_dir.mkdir()
    inside = create_recording(path=within_dir / "a.wav")
    outside = create_recording(path=elsewhere_dir / "x.wav")

    # The list names a path that matches neither recording.
    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text("dataset/audio/unknown.wav\n")

    yaml_config = f"""
name: path_in_list
path: {list_file}
format: txt
base_dir: {dataset_root}
on_outside: allow
"""
    condition = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert condition(outside)
    assert not condition(inside)
def test_path_in_list_condition_outside_warn(
    tmp_path: Path,
    create_recording,
) -> None:
    """``on_outside: warn`` accepts a recording outside ``base_dir``.

    The unlisted recording inside ``base_dir`` is still rejected. Only the
    condition's return values are checked here.
    """
    dataset_root = tmp_path / "dataset"
    within_dir = dataset_root / "audio"
    within_dir.mkdir(parents=True)
    elsewhere_dir = tmp_path / "other"
    elsewhere_dir.mkdir()
    inside = create_recording(path=within_dir / "a.wav")
    outside = create_recording(path=elsewhere_dir / "x.wav")

    # The list names a path that matches neither recording.
    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text("dataset/audio/unknown.wav\n")

    yaml_config = f"""
name: path_in_list
path: {list_file}
format: txt
base_dir: {dataset_root}
on_outside: warn
"""
    condition = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert condition(outside)
    assert not condition(inside)
def test_path_in_list_condition_outside_error(
    tmp_path: Path,
    create_recording,
) -> None:
    """``on_outside: error`` raises for a recording outside ``base_dir``.

    The listed recording inside ``base_dir`` is still accepted.
    """
    dataset_root = tmp_path / "dataset"
    within_dir = dataset_root / "audio"
    within_dir.mkdir(parents=True)
    elsewhere_dir = tmp_path / "other"
    elsewhere_dir.mkdir()
    inside = create_recording(path=within_dir / "a.wav")
    outside = create_recording(path=elsewhere_dir / "x.wav")

    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text(f"{inside.path}\n")

    yaml_config = f"""
name: path_in_list
path: {list_file}
format: txt
base_dir: {dataset_root}
on_outside: error
"""
    condition = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert condition(inside)
    with pytest.raises(ValueError, match="outside"):
        condition(outside)
def test_has_tag_condition(tmp_path: Path, create_recording) -> None:
train = data.Tag(key="split", value="train")
val = data.Tag(key="split", value="val")