mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-23 23:01:22 +02:00
101 lines
2.8 KiB
Python
101 lines
2.8 KiB
Python
import json
|
|
from pathlib import Path
|
|
|
|
from soundevent import data
|
|
|
|
from batdetect2.data import DatasetConfig, load_dataset
|
|
from batdetect2.data.conditions import (
|
|
HasTagConfig,
|
|
IdInListConfig,
|
|
RecordingSatisfiesConfig,
|
|
)
|
|
|
|
|
|
def test_load_dataset_applies_clip_filter(
    example_dataset: DatasetConfig,
    tmp_path: Path,
) -> None:
    """Check that a configured ``clip_filter`` narrows the loaded dataset.

    The recording UUID of the first unfiltered clip annotation is written to
    a JSON file, and the dataset config is copied with a
    ``RecordingSatisfiesConfig`` / ``IdInListConfig`` clip filter pointing at
    that file. Loading the copied config must return exactly one clip
    annotation, and it must belong to that recording.
    """
    unfiltered = list(load_dataset(example_dataset))
    first_annotation = unfiltered[0]
    wanted_id = str(first_annotation.clip.recording.uuid)

    # Persist the single id that the filter is expected to keep.
    id_file = tmp_path / "train_ids.json"
    id_file.write_text(json.dumps([wanted_id]))

    id_condition = IdInListConfig(path=id_file)
    recording_filter = RecordingSatisfiesConfig(condition=id_condition)
    filtered_config = example_dataset.model_copy(
        update={"clip_filter": recording_filter},
    )

    result = list(load_dataset(filtered_config))

    assert len(result) == 1
    assert str(result[0].clip.recording.uuid) == wanted_id
|
|
|
|
|
|
def test_load_dataset_clip_filter_is_skipped_when_filters_disabled(
    example_dataset: DatasetConfig,
    tmp_path: Path,
) -> None:
    """Check that ``apply_filters=False`` leaves the dataset unfiltered.

    The config gets the same single-recording clip filter used to restrict
    the dataset, but ``load_dataset`` is called with ``apply_filters=False``.
    The result must then contain as many clip annotations as the unfiltered
    baseline.
    """
    unfiltered = list(load_dataset(example_dataset))
    wanted_id = str(unfiltered[0].clip.recording.uuid)

    # The id list holds only one recording id, taken from the first clip.
    id_file = tmp_path / "train_ids.json"
    id_file.write_text(json.dumps([wanted_id]))

    recording_filter = RecordingSatisfiesConfig(
        condition=IdInListConfig(path=id_file),
    )
    config_with_filter = example_dataset.model_copy(
        update={"clip_filter": recording_filter},
    )

    result = list(load_dataset(config_with_filter, apply_filters=False))

    expected_count = len(unfiltered)
    assert len(result) == expected_count
|
|
|
|
|
|
def test_load_dataset_resolves_clip_filter_paths_from_base_dir(
    example_dataset: DatasetConfig,
    tmp_path: Path,
) -> None:
    """Check that a relative clip-filter path works together with ``base_dir``.

    The id list is written to ``<tmp_path>/splits/train_ids.json``, while the
    filter condition is given only the relative path
    ``splits/train_ids.json``. ``load_dataset`` is called with
    ``base_dir=tmp_path`` and must return exactly the one clip annotation
    whose recording UUID was written to the file.
    """
    unfiltered = list(load_dataset(example_dataset))
    wanted_id = str(unfiltered[0].clip.recording.uuid)

    # Only this relative form is handed to the filter config.
    relative_ids_path = Path("splits/train_ids.json")

    # Create the actual file underneath tmp_path at that relative location.
    id_file = tmp_path / relative_ids_path
    id_file.parent.mkdir()
    id_file.write_text(json.dumps([wanted_id]))

    recording_filter = RecordingSatisfiesConfig(
        condition=IdInListConfig(path=relative_ids_path),
    )
    config_with_filter = example_dataset.model_copy(
        update={"clip_filter": recording_filter},
    )

    result = list(load_dataset(config_with_filter, base_dir=tmp_path))

    assert len(result) == 1
    assert str(result[0].clip.recording.uuid) == wanted_id
|
|
|
|
|
|
def test_sound_event_filter_keeps_empty_clips(
    example_dataset: DatasetConfig,
) -> None:
    """Check that clips stay in the dataset when their sound events are removed.

    The config is copied with a ``sound_event_filter`` requiring the tag
    ``species=__missing_species__``. The loaded result must still contain
    three clip annotations, each with zero sound events.
    """
    # NOTE(review): the tag value is presumably one that no annotation in the
    # example dataset carries, and 3 presumably matches the fixture's clip
    # count -- confirm against the fixture definition.
    unmatched_tag = data.Tag(key="species", value="__missing_species__")
    config_with_filter = example_dataset.model_copy(
        update={"sound_event_filter": HasTagConfig(tag=unmatched_tag)},
    )

    result = list(load_dataset(config_with_filter))

    assert len(result) == 3

    event_counts = [len(annotation.sound_events) for annotation in result]
    assert all(count == 0 for count in event_counts)
|