mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Add path_in_list condition
This commit is contained in:
parent
1579bbc6c5
commit
da113eaea8
@ -19,6 +19,7 @@ from batdetect2.data.conditions.common import (
|
||||
TxtList,
|
||||
)
|
||||
from batdetect2.data.conditions.recordings import (
|
||||
PathInListConfig,
|
||||
RecordingAllOfConfig,
|
||||
RecordingAnyOfConfig,
|
||||
RecordingCondition,
|
||||
@ -61,6 +62,7 @@ __all__ = [
|
||||
"ListFormatConfig",
|
||||
"NotConfig",
|
||||
"Operator",
|
||||
"PathInListConfig",
|
||||
"RecordingCondition",
|
||||
"RecordingConditionConfig",
|
||||
"RecordingConditionImportConfig",
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from loguru import logger
|
||||
from pydantic import Field, model_validator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
@ -14,8 +17,12 @@ from batdetect2.data.conditions.common import (
|
||||
HasAnyTagConfig,
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
JsonList,
|
||||
ListFormatConfig,
|
||||
MultiConditionConfigBase,
|
||||
NotConditionConfigBase,
|
||||
build_list_loader,
|
||||
list_loaders,
|
||||
register_all_of_condition,
|
||||
register_any_of_condition,
|
||||
register_has_all_tags_condition,
|
||||
@ -27,6 +34,7 @@ from batdetect2.data.conditions.common import (
|
||||
|
||||
__all__ = [
|
||||
"IdInListConfig",
|
||||
"PathInListConfig",
|
||||
"RecordingAllOfConfig",
|
||||
"RecordingAnyOfConfig",
|
||||
"RecordingCondition",
|
||||
@ -60,6 +68,116 @@ register_has_all_tags_condition(recording_conditions, HasAllTagsConfig)
|
||||
register_has_any_tag_condition(recording_conditions, HasAnyTagConfig)
|
||||
|
||||
|
||||
class PathInListConfig(BaseConfig):
|
||||
name: Literal["path_in_list"] = "path_in_list"
|
||||
path: Path
|
||||
format: ListFormatConfig = JsonList()
|
||||
base_dir: Path | None = None
|
||||
on_outside: Literal["allow", "warn", "error"] = "allow"
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_format(cls, values):
|
||||
if not isinstance(values, dict):
|
||||
return values
|
||||
|
||||
format_config = values.get("format")
|
||||
|
||||
if isinstance(format_config, str):
|
||||
values = values.copy()
|
||||
config_class = list_loaders.get_config_type(format_config)
|
||||
values["format"] = config_class().model_dump()
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class PathInList:
|
||||
def __init__(
|
||||
self,
|
||||
paths: set[Path],
|
||||
base_dir: Path | None,
|
||||
on_outside: Literal["allow", "warn", "error"],
|
||||
):
|
||||
self.paths = paths
|
||||
self.base_dir = base_dir
|
||||
self.on_outside = on_outside
|
||||
|
||||
def __call__(self, recording: data.Recording) -> bool:
|
||||
normalized_path = self._normalize_recording_path(recording.path)
|
||||
|
||||
if normalized_path is None:
|
||||
return True
|
||||
|
||||
return normalized_path in self.paths
|
||||
|
||||
def _normalize_recording_path(self, path: data.PathLike) -> Path | None:
|
||||
recording_path = Path(path)
|
||||
|
||||
if self.base_dir is None:
|
||||
return recording_path
|
||||
|
||||
if not recording_path.is_absolute():
|
||||
return recording_path
|
||||
|
||||
try:
|
||||
return recording_path.relative_to(self.base_dir)
|
||||
except ValueError as err:
|
||||
if self.on_outside == "allow":
|
||||
return None
|
||||
|
||||
if self.on_outside == "warn":
|
||||
logger.warning(
|
||||
"Recording path '{}' is outside '{}' in path_in_list; "
|
||||
"allowing.",
|
||||
recording_path,
|
||||
self.base_dir,
|
||||
)
|
||||
return None
|
||||
|
||||
raise ValueError(
|
||||
f"Recording path '{recording_path}' is outside "
|
||||
f"'{self.base_dir}' for 'path_in_list'."
|
||||
) from err
|
||||
|
||||
@recording_conditions.register(PathInListConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: PathInListConfig,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> "PathInList":
|
||||
list_path = config.path
|
||||
|
||||
if base_dir is not None and not list_path.is_absolute():
|
||||
list_path = Path(base_dir) / list_path
|
||||
|
||||
match_base_dir = config.base_dir
|
||||
if (
|
||||
match_base_dir is not None
|
||||
and base_dir is not None
|
||||
and not match_base_dir.is_absolute()
|
||||
):
|
||||
match_base_dir = Path(base_dir) / match_base_dir
|
||||
|
||||
loader = build_list_loader(config.format)
|
||||
|
||||
paths = {
|
||||
Path(value).relative_to(match_base_dir)
|
||||
if (
|
||||
match_base_dir is not None
|
||||
and Path(value).is_absolute()
|
||||
and Path(value).is_relative_to(match_base_dir)
|
||||
)
|
||||
else Path(value)
|
||||
for value in loader(list_path)
|
||||
}
|
||||
|
||||
return PathInList(
|
||||
paths=paths,
|
||||
base_dir=match_base_dir,
|
||||
on_outside=config.on_outside,
|
||||
)
|
||||
|
||||
|
||||
@register_all_of_condition(recording_conditions)
|
||||
class RecordingAllOfConfig(MultiConditionConfigBase):
|
||||
name: Literal["all_of"] = "all_of"
|
||||
@ -80,6 +198,7 @@ class RecordingNotConfig(NotConditionConfigBase):
|
||||
|
||||
RecordingConditionConfig = Annotated[
|
||||
IdInListConfig
|
||||
| PathInListConfig
|
||||
| HasTagConfig
|
||||
| HasAllTagsConfig
|
||||
| HasAnyTagConfig
|
||||
|
||||
@ -184,6 +184,203 @@ def test_id_in_list_condition_supports_csv_column(
|
||||
assert not condition(recording_b)
|
||||
|
||||
|
||||
def test_path_in_list_condition_supports_txt_format(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
) -> None:
|
||||
audio_dir = tmp_path / "audio"
|
||||
audio_dir.mkdir()
|
||||
recording_a = create_recording(path=audio_dir / "a.wav")
|
||||
recording_b = create_recording(path=audio_dir / "b.wav")
|
||||
paths_file = tmp_path / "recording_paths.txt"
|
||||
paths_file.write_text(f"{recording_a.path}\n")
|
||||
|
||||
condition = build_recording_condition_from_yaml(
|
||||
tmp_path,
|
||||
f"""
|
||||
name: path_in_list
|
||||
path: {paths_file}
|
||||
format: txt
|
||||
""",
|
||||
)
|
||||
|
||||
assert condition(recording_a)
|
||||
assert not condition(recording_b)
|
||||
|
||||
|
||||
def test_path_in_list_condition_supports_json_field(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
) -> None:
|
||||
audio_dir = tmp_path / "audio"
|
||||
audio_dir.mkdir()
|
||||
recording_a = create_recording(path=audio_dir / "a.wav")
|
||||
recording_b = create_recording(path=audio_dir / "b.wav")
|
||||
paths_file = tmp_path / "recording_paths.json"
|
||||
paths_file.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"train": [str(recording_a.path)],
|
||||
"val": [str(recording_b.path)],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
condition = build_recording_condition_from_yaml(
|
||||
tmp_path,
|
||||
f"""
|
||||
name: path_in_list
|
||||
path: {paths_file}
|
||||
format:
|
||||
name: json
|
||||
field: train
|
||||
""",
|
||||
)
|
||||
|
||||
assert condition(recording_a)
|
||||
assert not condition(recording_b)
|
||||
|
||||
|
||||
def test_path_in_list_condition_supports_csv_column(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
) -> None:
|
||||
audio_dir = tmp_path / "audio"
|
||||
audio_dir.mkdir()
|
||||
recording_a = create_recording(path=audio_dir / "a.wav")
|
||||
recording_b = create_recording(path=audio_dir / "b.wav")
|
||||
paths_file = tmp_path / "recording_paths.csv"
|
||||
paths_file.write_text(f"recording_path\n{recording_a.path}\n")
|
||||
|
||||
condition = build_recording_condition_from_yaml(
|
||||
tmp_path,
|
||||
f"""
|
||||
name: path_in_list
|
||||
path: {paths_file}
|
||||
format:
|
||||
name: csv
|
||||
column: recording_path
|
||||
""",
|
||||
)
|
||||
|
||||
assert condition(recording_a)
|
||||
assert not condition(recording_b)
|
||||
|
||||
|
||||
def test_path_in_list_condition_uses_base_dir(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
) -> None:
|
||||
data_dir = tmp_path / "dataset"
|
||||
audio_dir = data_dir / "audio"
|
||||
audio_dir.mkdir(parents=True)
|
||||
recording_a = create_recording(path=audio_dir / "a.wav")
|
||||
recording_b = create_recording(path=audio_dir / "b.wav")
|
||||
paths_file = tmp_path / "recording_paths.txt"
|
||||
paths_file.write_text(f"{recording_a.path}\n")
|
||||
|
||||
condition = build_recording_condition_from_yaml(
|
||||
tmp_path,
|
||||
f"""
|
||||
name: path_in_list
|
||||
path: {paths_file}
|
||||
format: txt
|
||||
base_dir: {data_dir}
|
||||
""",
|
||||
)
|
||||
|
||||
assert condition(recording_a)
|
||||
assert not condition(recording_b)
|
||||
|
||||
|
||||
def test_path_in_list_condition_outside_allow(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
) -> None:
|
||||
data_dir = tmp_path / "dataset"
|
||||
inside_dir = data_dir / "audio"
|
||||
inside_dir.mkdir(parents=True)
|
||||
outside_dir = tmp_path / "other"
|
||||
outside_dir.mkdir()
|
||||
recording_inside = create_recording(path=inside_dir / "a.wav")
|
||||
recording_outside = create_recording(path=outside_dir / "x.wav")
|
||||
paths_file = tmp_path / "recording_paths.txt"
|
||||
paths_file.write_text("dataset/audio/unknown.wav\n")
|
||||
|
||||
condition = build_recording_condition_from_yaml(
|
||||
tmp_path,
|
||||
f"""
|
||||
name: path_in_list
|
||||
path: {paths_file}
|
||||
format: txt
|
||||
base_dir: {data_dir}
|
||||
on_outside: allow
|
||||
""",
|
||||
)
|
||||
|
||||
assert condition(recording_outside)
|
||||
assert not condition(recording_inside)
|
||||
|
||||
|
||||
def test_path_in_list_condition_outside_warn(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
) -> None:
|
||||
data_dir = tmp_path / "dataset"
|
||||
inside_dir = data_dir / "audio"
|
||||
inside_dir.mkdir(parents=True)
|
||||
outside_dir = tmp_path / "other"
|
||||
outside_dir.mkdir()
|
||||
recording_inside = create_recording(path=inside_dir / "a.wav")
|
||||
recording_outside = create_recording(path=outside_dir / "x.wav")
|
||||
paths_file = tmp_path / "recording_paths.txt"
|
||||
paths_file.write_text("dataset/audio/unknown.wav\n")
|
||||
|
||||
condition = build_recording_condition_from_yaml(
|
||||
tmp_path,
|
||||
f"""
|
||||
name: path_in_list
|
||||
path: {paths_file}
|
||||
format: txt
|
||||
base_dir: {data_dir}
|
||||
on_outside: warn
|
||||
""",
|
||||
)
|
||||
|
||||
assert condition(recording_outside)
|
||||
assert not condition(recording_inside)
|
||||
|
||||
|
||||
def test_path_in_list_condition_outside_error(
|
||||
tmp_path: Path,
|
||||
create_recording,
|
||||
) -> None:
|
||||
data_dir = tmp_path / "dataset"
|
||||
inside_dir = data_dir / "audio"
|
||||
inside_dir.mkdir(parents=True)
|
||||
outside_dir = tmp_path / "other"
|
||||
outside_dir.mkdir()
|
||||
recording_inside = create_recording(path=inside_dir / "a.wav")
|
||||
recording_outside = create_recording(path=outside_dir / "x.wav")
|
||||
paths_file = tmp_path / "recording_paths.txt"
|
||||
paths_file.write_text(f"{recording_inside.path}\n")
|
||||
|
||||
condition = build_recording_condition_from_yaml(
|
||||
tmp_path,
|
||||
f"""
|
||||
name: path_in_list
|
||||
path: {paths_file}
|
||||
format: txt
|
||||
base_dir: {data_dir}
|
||||
on_outside: error
|
||||
""",
|
||||
)
|
||||
|
||||
assert condition(recording_inside)
|
||||
with pytest.raises(ValueError, match="outside"):
|
||||
condition(recording_outside)
|
||||
|
||||
|
||||
def test_has_tag_condition(tmp_path: Path, create_recording) -> None:
|
||||
train = data.Tag(key="split", value="train")
|
||||
val = data.Tag(key="split", value="val")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user