mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Add path_in_list condition
This commit is contained in:
parent
1579bbc6c5
commit
da113eaea8
@ -19,6 +19,7 @@ from batdetect2.data.conditions.common import (
|
|||||||
TxtList,
|
TxtList,
|
||||||
)
|
)
|
||||||
from batdetect2.data.conditions.recordings import (
|
from batdetect2.data.conditions.recordings import (
|
||||||
|
PathInListConfig,
|
||||||
RecordingAllOfConfig,
|
RecordingAllOfConfig,
|
||||||
RecordingAnyOfConfig,
|
RecordingAnyOfConfig,
|
||||||
RecordingCondition,
|
RecordingCondition,
|
||||||
@ -61,6 +62,7 @@ __all__ = [
|
|||||||
"ListFormatConfig",
|
"ListFormatConfig",
|
||||||
"NotConfig",
|
"NotConfig",
|
||||||
"Operator",
|
"Operator",
|
||||||
|
"PathInListConfig",
|
||||||
"RecordingCondition",
|
"RecordingCondition",
|
||||||
"RecordingConditionConfig",
|
"RecordingConditionConfig",
|
||||||
"RecordingConditionImportConfig",
|
"RecordingConditionImportConfig",
|
||||||
|
|||||||
@ -1,9 +1,12 @@
|
|||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
|
from pathlib import Path
|
||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from loguru import logger
|
||||||
|
from pydantic import Field, model_validator
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.core.registries import (
|
from batdetect2.core.registries import (
|
||||||
ImportConfig,
|
ImportConfig,
|
||||||
Registry,
|
Registry,
|
||||||
@ -14,8 +17,12 @@ from batdetect2.data.conditions.common import (
|
|||||||
HasAnyTagConfig,
|
HasAnyTagConfig,
|
||||||
HasTagConfig,
|
HasTagConfig,
|
||||||
IdInListConfig,
|
IdInListConfig,
|
||||||
|
JsonList,
|
||||||
|
ListFormatConfig,
|
||||||
MultiConditionConfigBase,
|
MultiConditionConfigBase,
|
||||||
NotConditionConfigBase,
|
NotConditionConfigBase,
|
||||||
|
build_list_loader,
|
||||||
|
list_loaders,
|
||||||
register_all_of_condition,
|
register_all_of_condition,
|
||||||
register_any_of_condition,
|
register_any_of_condition,
|
||||||
register_has_all_tags_condition,
|
register_has_all_tags_condition,
|
||||||
@ -27,6 +34,7 @@ from batdetect2.data.conditions.common import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"IdInListConfig",
|
"IdInListConfig",
|
||||||
|
"PathInListConfig",
|
||||||
"RecordingAllOfConfig",
|
"RecordingAllOfConfig",
|
||||||
"RecordingAnyOfConfig",
|
"RecordingAnyOfConfig",
|
||||||
"RecordingCondition",
|
"RecordingCondition",
|
||||||
@ -60,6 +68,116 @@ register_has_all_tags_condition(recording_conditions, HasAllTagsConfig)
|
|||||||
register_has_any_tag_condition(recording_conditions, HasAnyTagConfig)
|
register_has_any_tag_condition(recording_conditions, HasAnyTagConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class PathInListConfig(BaseConfig):
    """Configuration for the ``path_in_list`` recording condition."""

    # Discriminator value identifying this condition type.
    name: Literal["path_in_list"] = "path_in_list"
    # File that holds the listed recording paths.
    path: Path
    # How the list file is read; a bare string is expanded by
    # ``_normalize_format`` below. Defaults to ``JsonList()``.
    format: ListFormatConfig = JsonList()
    # Optional directory that paths are made relative to before matching.
    base_dir: Path | None = None
    # What to do with recordings located outside ``base_dir``.
    on_outside: Literal["allow", "warn", "error"] = "allow"

    @model_validator(mode="before")
    @classmethod
    def _normalize_format(cls, values):
        """Expand a string ``format`` into the matching loader config dict.

        Non-dict input and inputs whose ``format`` is not a string are
        returned untouched. Otherwise a shallow copy of ``values`` is
        returned in which ``format`` holds the dumped default instance of
        the config type that ``list_loaders`` associates with that string.
        """
        if not isinstance(values, dict):
            return values

        shorthand = values.get("format")
        if not isinstance(shorthand, str):
            return values

        loader_config = list_loaders.get_config_type(shorthand)

        # Work on a copy so the caller's mapping is left unmodified.
        expanded = values.copy()
        expanded["format"] = loader_config().model_dump()
        return expanded
|
||||||
|
|
||||||
|
|
||||||
|
class PathInList:
    """Recording condition that checks a recording path against a path set.

    When ``base_dir`` is set, absolute recording paths are made relative to
    it before the lookup. Recordings that are not located under ``base_dir``
    are handled according to ``on_outside``:

    * ``"allow"``: the recording passes the condition silently.
    * ``"warn"``: a warning is logged and the recording passes.
    * anything else (``"error"``): a ``ValueError`` is raised.
    """

    def __init__(
        self,
        paths: set[Path],
        base_dir: Path | None,
        on_outside: Literal["allow", "warn", "error"],
    ):
        """Store the listed paths, the matching base directory and policy."""
        self.paths = paths
        self.base_dir = base_dir
        self.on_outside = on_outside

    def __call__(self, recording: data.Recording) -> bool:
        """Return whether ``recording`` satisfies the condition.

        Raises:
            ValueError: If the recording is outside ``base_dir`` and the
                policy is not ``"allow"`` or ``"warn"``.
        """
        normalized_path = self._normalize_recording_path(recording.path)

        # ``None`` signals a recording outside ``base_dir`` that the
        # ``on_outside`` policy lets through.
        if normalized_path is None:
            return True

        return normalized_path in self.paths

    def _normalize_recording_path(self, path: data.PathLike) -> Path | None:
        """Return the path to look up, or ``None`` if it is outside.

        Without a ``base_dir``, and for relative paths, the path is returned
        unchanged. Absolute paths are made relative to ``base_dir``.

        Raises:
            ValueError: If the path is not under ``base_dir`` and the policy
                is not ``"allow"`` or ``"warn"``.
        """
        recording_path = Path(path)

        if self.base_dir is None or not recording_path.is_absolute():
            return recording_path

        try:
            return recording_path.relative_to(self.base_dir)
        except ValueError as err:
            if self.on_outside == "warn":
                logger.warning(
                    "Recording path '{}' is outside '{}' in path_in_list; "
                    "allowing.",
                    recording_path,
                    self.base_dir,
                )

            if self.on_outside in ("allow", "warn"):
                return None

            raise ValueError(
                f"Recording path '{recording_path}' is outside "
                f"'{self.base_dir}' for 'path_in_list'."
            ) from err

    @recording_conditions.register(PathInListConfig)
    @staticmethod
    def from_config(
        config: PathInListConfig,
        base_dir: data.PathLike | None = None,
    ) -> "PathInList":
        """Build the condition from its configuration.

        Args:
            config: The ``path_in_list`` configuration.
            base_dir: Optional directory used to anchor a relative
                ``config.path`` and a relative ``config.base_dir``.

        Returns:
            A ``PathInList`` whose path set holds every entry read from the
            list file. Absolute entries located under the matching base
            directory are stored relative to it; all other entries are
            stored as given.
        """
        list_path = config.path
        match_base_dir = config.base_dir

        if base_dir is not None:
            root = Path(base_dir)

            if not list_path.is_absolute():
                list_path = root / list_path

            if match_base_dir is not None and not match_base_dir.is_absolute():
                match_base_dir = root / match_base_dir

        loader = build_list_loader(config.format)

        paths: set[Path] = set()
        for value in loader(list_path):
            # Parse each entry once instead of once per check.
            entry = Path(value)

            if (
                match_base_dir is not None
                and entry.is_absolute()
                and entry.is_relative_to(match_base_dir)
            ):
                entry = entry.relative_to(match_base_dir)

            paths.add(entry)

        return PathInList(
            paths=paths,
            base_dir=match_base_dir,
            on_outside=config.on_outside,
        )
|
||||||
|
|
||||||
|
|
||||||
@register_all_of_condition(recording_conditions)
|
@register_all_of_condition(recording_conditions)
|
||||||
class RecordingAllOfConfig(MultiConditionConfigBase):
|
class RecordingAllOfConfig(MultiConditionConfigBase):
|
||||||
name: Literal["all_of"] = "all_of"
|
name: Literal["all_of"] = "all_of"
|
||||||
@ -80,6 +198,7 @@ class RecordingNotConfig(NotConditionConfigBase):
|
|||||||
|
|
||||||
RecordingConditionConfig = Annotated[
|
RecordingConditionConfig = Annotated[
|
||||||
IdInListConfig
|
IdInListConfig
|
||||||
|
| PathInListConfig
|
||||||
| HasTagConfig
|
| HasTagConfig
|
||||||
| HasAllTagsConfig
|
| HasAllTagsConfig
|
||||||
| HasAnyTagConfig
|
| HasAnyTagConfig
|
||||||
|
|||||||
@ -184,6 +184,203 @@ def test_id_in_list_condition_supports_csv_column(
|
|||||||
assert not condition(recording_b)
|
assert not condition(recording_b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_path_in_list_condition_supports_txt_format(
    tmp_path: Path,
    create_recording,
) -> None:
    """A ``txt`` list accepts the listed recording and rejects the other."""
    audio_dir = tmp_path / "audio"
    audio_dir.mkdir()
    listed = create_recording(path=audio_dir / "a.wav")
    unlisted = create_recording(path=audio_dir / "b.wav")

    # Only the first recording's path is written to the list file.
    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text(f"{listed.path}\n")

    condition = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format: txt
        """,
    )

    assert condition(listed)
    assert not condition(unlisted)
|
||||||
|
|
||||||
|
|
||||||
|
def test_path_in_list_condition_supports_json_field(
    tmp_path: Path,
    create_recording,
) -> None:
    """A ``json`` list with ``field: train`` only matches that field."""
    audio_dir = tmp_path / "audio"
    audio_dir.mkdir()
    train_recording = create_recording(path=audio_dir / "a.wav")
    val_recording = create_recording(path=audio_dir / "b.wav")

    # Each recording is listed under a different top-level field.
    splits = {
        "train": [str(train_recording.path)],
        "val": [str(val_recording.path)],
    }
    list_file = tmp_path / "recording_paths.json"
    list_file.write_text(json.dumps(splits))

    condition = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format:
            name: json
            field: train
        """,
    )

    assert condition(train_recording)
    assert not condition(val_recording)
|
||||||
|
|
||||||
|
|
||||||
|
def test_path_in_list_condition_supports_csv_column(
    tmp_path: Path,
    create_recording,
) -> None:
    """A ``csv`` list reads paths from the configured column."""
    audio_dir = tmp_path / "audio"
    audio_dir.mkdir()
    listed = create_recording(path=audio_dir / "a.wav")
    unlisted = create_recording(path=audio_dir / "b.wav")

    # Header row followed by a single data row for the listed recording.
    list_file = tmp_path / "recording_paths.csv"
    list_file.write_text(f"recording_path\n{listed.path}\n")

    condition = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format:
            name: csv
            column: recording_path
        """,
    )

    assert condition(listed)
    assert not condition(unlisted)
|
||||||
|
|
||||||
|
|
||||||
|
def test_path_in_list_condition_uses_base_dir(
    tmp_path: Path,
    create_recording,
) -> None:
    """Matching still works when ``base_dir`` is configured."""
    dataset_dir = tmp_path / "dataset"
    audio_dir = dataset_dir / "audio"
    audio_dir.mkdir(parents=True)
    listed = create_recording(path=audio_dir / "a.wav")
    unlisted = create_recording(path=audio_dir / "b.wav")

    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text(f"{listed.path}\n")

    condition = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format: txt
        base_dir: {dataset_dir}
        """,
    )

    assert condition(listed)
    assert not condition(unlisted)
|
||||||
|
|
||||||
|
|
||||||
|
def test_path_in_list_condition_outside_allow(
    tmp_path: Path,
    create_recording,
) -> None:
    """With ``on_outside: allow`` a recording outside ``base_dir`` passes."""
    dataset_dir = tmp_path / "dataset"
    inside_dir = dataset_dir / "audio"
    inside_dir.mkdir(parents=True)
    outside_dir = tmp_path / "other"
    outside_dir.mkdir()
    inside = create_recording(path=inside_dir / "a.wav")
    outside = create_recording(path=outside_dir / "x.wav")

    # The list names neither recording, so the inside one must be rejected.
    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text("dataset/audio/unknown.wav\n")

    condition = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format: txt
        base_dir: {dataset_dir}
        on_outside: allow
        """,
    )

    assert condition(outside)
    assert not condition(inside)
|
||||||
|
|
||||||
|
|
||||||
|
def test_path_in_list_condition_outside_warn(
    tmp_path: Path,
    create_recording,
) -> None:
    """With ``on_outside: warn`` a recording outside ``base_dir`` passes."""
    dataset_dir = tmp_path / "dataset"
    inside_dir = dataset_dir / "audio"
    inside_dir.mkdir(parents=True)
    outside_dir = tmp_path / "other"
    outside_dir.mkdir()
    inside = create_recording(path=inside_dir / "a.wav")
    outside = create_recording(path=outside_dir / "x.wav")

    # The list names neither recording, so the inside one must be rejected.
    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text("dataset/audio/unknown.wav\n")

    condition = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format: txt
        base_dir: {dataset_dir}
        on_outside: warn
        """,
    )

    assert condition(outside)
    assert not condition(inside)
|
||||||
|
|
||||||
|
|
||||||
|
def test_path_in_list_condition_outside_error(
    tmp_path: Path,
    create_recording,
) -> None:
    """With ``on_outside: error`` a recording outside ``base_dir`` raises."""
    dataset_dir = tmp_path / "dataset"
    inside_dir = dataset_dir / "audio"
    inside_dir.mkdir(parents=True)
    outside_dir = tmp_path / "other"
    outside_dir.mkdir()
    inside = create_recording(path=inside_dir / "a.wav")
    outside = create_recording(path=outside_dir / "x.wav")

    # Only the recording under ``base_dir`` is listed.
    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text(f"{inside.path}\n")

    condition = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format: txt
        base_dir: {dataset_dir}
        on_outside: error
        """,
    )

    assert condition(inside)

    with pytest.raises(ValueError, match="outside"):
        condition(outside)
|
||||||
|
|
||||||
|
|
||||||
def test_has_tag_condition(tmp_path: Path, create_recording) -> None:
|
def test_has_tag_condition(tmp_path: Path, create_recording) -> None:
|
||||||
train = data.Tag(key="split", value="train")
|
train = data.Tag(key="split", value="train")
|
||||||
val = data.Tag(key="split", value="val")
|
val = data.Tag(key="split", value="val")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user