diff --git a/src/batdetect2/data/conditions/__init__.py b/src/batdetect2/data/conditions/__init__.py
index e6470c9..6721393 100644
--- a/src/batdetect2/data/conditions/__init__.py
+++ b/src/batdetect2/data/conditions/__init__.py
@@ -19,6 +19,7 @@ from batdetect2.data.conditions.common import (
     TxtList,
 )
 from batdetect2.data.conditions.recordings import (
+    PathInListConfig,
     RecordingAllOfConfig,
     RecordingAnyOfConfig,
     RecordingCondition,
@@ -61,6 +62,7 @@ __all__ = [
     "ListFormatConfig",
     "NotConfig",
     "Operator",
+    "PathInListConfig",
     "RecordingCondition",
     "RecordingConditionConfig",
     "RecordingConditionImportConfig",
diff --git a/src/batdetect2/data/conditions/recordings.py b/src/batdetect2/data/conditions/recordings.py
index d6757f1..754cf4f 100644
--- a/src/batdetect2/data/conditions/recordings.py
+++ b/src/batdetect2/data/conditions/recordings.py
@@ -1,9 +1,12 @@
 from collections.abc import Callable, Sequence
+from pathlib import Path
 from typing import Annotated, Literal
 
-from pydantic import Field
+from loguru import logger
+from pydantic import Field, model_validator
 from soundevent import data
 
+from batdetect2.core.configs import BaseConfig
 from batdetect2.core.registries import (
     ImportConfig,
     Registry,
@@ -14,8 +17,12 @@ from batdetect2.data.conditions.common import (
     HasAnyTagConfig,
     HasTagConfig,
     IdInListConfig,
+    JsonList,
+    ListFormatConfig,
     MultiConditionConfigBase,
     NotConditionConfigBase,
+    build_list_loader,
+    list_loaders,
     register_all_of_condition,
     register_any_of_condition,
     register_has_all_tags_condition,
@@ -27,6 +34,7 @@ from batdetect2.data.conditions.common import (
 
 __all__ = [
     "IdInListConfig",
+    "PathInListConfig",
     "RecordingAllOfConfig",
     "RecordingAnyOfConfig",
     "RecordingCondition",
@@ -60,6 +68,158 @@ register_has_all_tags_condition(recording_conditions, HasAllTagsConfig)
 register_has_any_tag_condition(recording_conditions, HasAnyTagConfig)
 
 
+class PathInListConfig(BaseConfig):
+    """Configuration for the ``path_in_list`` recording condition.
+
+    ``path`` points to the file holding the listed recording paths and
+    ``format`` selects how that file is loaded. When ``base_dir`` is set,
+    paths located under it are compared relative to it, and ``on_outside``
+    chooses what happens to absolute recording paths that are not under
+    it: ``"allow"`` and ``"warn"`` accept the recording, ``"error"``
+    raises.
+    """
+
+    name: Literal["path_in_list"] = "path_in_list"
+    path: Path
+    format: ListFormatConfig = JsonList()
+    base_dir: Path | None = None
+    on_outside: Literal["allow", "warn", "error"] = "allow"
+
+    @model_validator(mode="before")
+    @classmethod
+    def _normalize_format(cls, values):
+        """Expand a bare format name (e.g. ``format: txt``) into a config.
+
+        Non-dict input is returned untouched; a string ``format`` is
+        replaced, on a copy of the input, by the dumped default instance of
+        the config type ``list_loaders`` returns for that name.
+        """
+        if not isinstance(values, dict):
+            return values
+
+        format_config = values.get("format")
+
+        if isinstance(format_config, str):
+            # Copy so the caller's dict is not mutated during validation.
+            values = values.copy()
+            config_class = list_loaders.get_config_type(format_config)
+            values["format"] = config_class().model_dump()
+
+        return values
+
+
+class PathInList:
+    """Recording condition that keeps recordings whose path is listed."""
+
+    def __init__(
+        self,
+        paths: set[Path],
+        base_dir: Path | None,
+        on_outside: Literal["allow", "warn", "error"],
+    ):
+        self.paths = paths
+        self.base_dir = base_dir
+        self.on_outside = on_outside
+
+    def __call__(self, recording: data.Recording) -> bool:
+        """Return whether ``recording`` passes the condition.
+
+        Recordings outside ``base_dir`` pass unconditionally under the
+        ``"allow"`` and ``"warn"`` policies.
+        """
+        normalized_path = self._normalize_recording_path(recording.path)
+
+        # ``None`` means the recording is outside ``base_dir`` and allowed.
+        if normalized_path is None:
+            return True
+
+        return normalized_path in self.paths
+
+    def _normalize_recording_path(self, path: data.PathLike) -> Path | None:
+        """Return ``path`` in the form used for lookup in ``self.paths``.
+
+        Relative paths, and every path when no ``base_dir`` is set, are
+        returned unchanged. Absolute paths under ``base_dir`` are made
+        relative to it. For absolute paths outside ``base_dir`` the
+        result is ``None`` (``"allow"``/``"warn"``) or a ``ValueError``
+        (``"error"``).
+        """
+        recording_path = Path(path)
+
+        if self.base_dir is None:
+            return recording_path
+
+        if not recording_path.is_absolute():
+            return recording_path
+
+        try:
+            return recording_path.relative_to(self.base_dir)
+        except ValueError as err:
+            if self.on_outside == "allow":
+                return None
+
+            if self.on_outside == "warn":
+                logger.warning(
+                    "Recording path '{}' is outside '{}' in path_in_list; "
+                    "allowing.",
+                    recording_path,
+                    self.base_dir,
+                )
+                return None
+
+            raise ValueError(
+                f"Recording path '{recording_path}' is outside "
+                f"'{self.base_dir}' for 'path_in_list'."
+            ) from err
+
+    @recording_conditions.register(PathInListConfig)
+    @staticmethod
+    def from_config(
+        config: PathInListConfig,
+        base_dir: data.PathLike | None = None,
+    ) -> "PathInList":
+        """Build the condition from its config.
+
+        Relative ``config.path`` and ``config.base_dir`` values are
+        resolved against ``base_dir`` when one is given. Listed absolute
+        paths under the matching base directory are stored relative to
+        it; all other entries are stored as written.
+        """
+        list_path = config.path
+
+        if base_dir is not None and not list_path.is_absolute():
+            list_path = Path(base_dir) / list_path
+
+        match_base_dir = config.base_dir
+        if (
+            match_base_dir is not None
+            and base_dir is not None
+            and not match_base_dir.is_absolute()
+        ):
+            match_base_dir = Path(base_dir) / match_base_dir
+
+        loader = build_list_loader(config.format)
+
+        def normalize_entry(value) -> Path:
+            # Build the Path once per entry instead of once per check.
+            entry = Path(value)
+            if (
+                match_base_dir is not None
+                and entry.is_absolute()
+                and entry.is_relative_to(match_base_dir)
+            ):
+                return entry.relative_to(match_base_dir)
+            return entry
+
+        paths = {normalize_entry(value) for value in loader(list_path)}
+
+        return PathInList(
+            paths=paths,
+            base_dir=match_base_dir,
+            on_outside=config.on_outside,
+        )
+
+
 @register_all_of_condition(recording_conditions)
 class RecordingAllOfConfig(MultiConditionConfigBase):
     name: Literal["all_of"] = "all_of"
@@ -80,6 +240,7 @@ class RecordingNotConfig(NotConditionConfigBase):
 
 RecordingConditionConfig = Annotated[
     IdInListConfig
+    | PathInListConfig
     | HasTagConfig
     | HasAllTagsConfig
     | HasAnyTagConfig
diff --git a/tests/test_data/test_conditions/test_recording.py b/tests/test_data/test_conditions/test_recording.py
index 35f8e84..0ae747a 100644
--- a/tests/test_data/test_conditions/test_recording.py
+++ b/tests/test_data/test_conditions/test_recording.py
@@ -184,6 +184,203 @@ def test_id_in_list_condition_supports_csv_column(
     assert not condition(recording_b)
 
 
+def test_path_in_list_condition_supports_txt_format(
+    tmp_path: Path,
+    create_recording,
+) -> None:
+    audio_dir = tmp_path / "audio"
+    audio_dir.mkdir()
+    recording_a = create_recording(path=audio_dir / "a.wav")
+    recording_b = create_recording(path=audio_dir / "b.wav")
+    paths_file = tmp_path / "recording_paths.txt"
+    paths_file.write_text(f"{recording_a.path}\n")
+
+    condition = build_recording_condition_from_yaml(
+        tmp_path,
+        f"""
+        name: path_in_list
+        path: {paths_file}
+        format: txt
+        """,
+    )
+
+    assert condition(recording_a)
+    assert not condition(recording_b)
+
+
+def test_path_in_list_condition_supports_json_field(
+    tmp_path: Path,
+    create_recording,
+) -> None:
+    audio_dir = tmp_path / "audio"
+    audio_dir.mkdir()
+    recording_a = create_recording(path=audio_dir / "a.wav")
+    recording_b = create_recording(path=audio_dir / "b.wav")
+    paths_file = tmp_path / "recording_paths.json"
+    paths_file.write_text(
+        json.dumps(
+            {
+                "train": [str(recording_a.path)],
+                "val": [str(recording_b.path)],
+            }
+        )
+    )
+
+    condition = build_recording_condition_from_yaml(
+        tmp_path,
+        f"""
+        name: path_in_list
+        path: {paths_file}
+        format:
+            name: json
+            field: train
+        """,
+    )
+
+    assert condition(recording_a)
+    assert not condition(recording_b)
+
+
+def test_path_in_list_condition_supports_csv_column(
+    tmp_path: Path,
+    create_recording,
+) -> None:
+    audio_dir = tmp_path / "audio"
+    audio_dir.mkdir()
+    recording_a = create_recording(path=audio_dir / "a.wav")
+    recording_b = create_recording(path=audio_dir / "b.wav")
+    paths_file = tmp_path / "recording_paths.csv"
+    paths_file.write_text(f"recording_path\n{recording_a.path}\n")
+
+    condition = build_recording_condition_from_yaml(
+        tmp_path,
+        f"""
+        name: path_in_list
+        path: {paths_file}
+        format:
+            name: csv
+            column: recording_path
+        """,
+    )
+
+    assert condition(recording_a)
+    assert not condition(recording_b)
+
+
+def test_path_in_list_condition_uses_base_dir(
+    tmp_path: Path,
+    create_recording,
+) -> None:
+    data_dir = tmp_path / "dataset"
+    audio_dir = data_dir / "audio"
+    audio_dir.mkdir(parents=True)
+    recording_a = create_recording(path=audio_dir / "a.wav")
+    recording_b = create_recording(path=audio_dir / "b.wav")
+    paths_file = tmp_path / "recording_paths.txt"
+    paths_file.write_text(f"{recording_a.path}\n")
+
+    condition = build_recording_condition_from_yaml(
+        tmp_path,
+        f"""
+        name: path_in_list
+        path: {paths_file}
+        format: txt
+        base_dir: {data_dir}
+        """,
+    )
+
+    assert condition(recording_a)
+    assert not condition(recording_b)
+
+
+def test_path_in_list_condition_outside_allow(
+    tmp_path: Path,
+    create_recording,
+) -> None:
+    data_dir = tmp_path / "dataset"
+    inside_dir = data_dir / "audio"
+    inside_dir.mkdir(parents=True)
+    outside_dir = tmp_path / "other"
+    outside_dir.mkdir()
+    recording_inside = create_recording(path=inside_dir / "a.wav")
+    recording_outside = create_recording(path=outside_dir / "x.wav")
+    paths_file = tmp_path / "recording_paths.txt"
+    paths_file.write_text("dataset/audio/unknown.wav\n")
+
+    condition = build_recording_condition_from_yaml(
+        tmp_path,
+        f"""
+        name: path_in_list
+        path: {paths_file}
+        format: txt
+        base_dir: {data_dir}
+        on_outside: allow
+        """,
+    )
+
+    assert condition(recording_outside)
+    assert not condition(recording_inside)
+
+
+def test_path_in_list_condition_outside_warn(
+    tmp_path: Path,
+    create_recording,
+) -> None:
+    data_dir = tmp_path / "dataset"
+    inside_dir = data_dir / "audio"
+    inside_dir.mkdir(parents=True)
+    outside_dir = tmp_path / "other"
+    outside_dir.mkdir()
+    recording_inside = create_recording(path=inside_dir / "a.wav")
+    recording_outside = create_recording(path=outside_dir / "x.wav")
+    paths_file = tmp_path / "recording_paths.txt"
+    paths_file.write_text("dataset/audio/unknown.wav\n")
+
+    condition = build_recording_condition_from_yaml(
+        tmp_path,
+        f"""
+        name: path_in_list
+        path: {paths_file}
+        format: txt
+        base_dir: {data_dir}
+        on_outside: warn
+        """,
+    )
+
+    assert condition(recording_outside)
+    assert not condition(recording_inside)
+
+
+def test_path_in_list_condition_outside_error(
+    tmp_path: Path,
+    create_recording,
+) -> None:
+    data_dir = tmp_path / "dataset"
+    inside_dir = data_dir / "audio"
+    inside_dir.mkdir(parents=True)
+    outside_dir = tmp_path / "other"
+    outside_dir.mkdir()
+    recording_inside = create_recording(path=inside_dir / "a.wav")
+    recording_outside = create_recording(path=outside_dir / "x.wav")
+    paths_file = tmp_path / "recording_paths.txt"
+    paths_file.write_text(f"{recording_inside.path}\n")
+
+    condition = build_recording_condition_from_yaml(
+        tmp_path,
+        f"""
+        name: path_in_list
+        path: {paths_file}
+        format: txt
+        base_dir: {data_dir}
+        on_outside: error
+        """,
+    )
+
+    assert condition(recording_inside)
+    with pytest.raises(ValueError, match="outside"):
+        condition(recording_outside)
+
+
 def test_has_tag_condition(tmp_path: Path, create_recording) -> None:
     train = data.Tag(key="split", value="train")
     val = data.Tag(key="split", value="val")