mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Add clip annotation filtering to data loading
This commit is contained in:
parent
c8dd4155bf
commit
e04d86808d
@ -32,7 +32,9 @@ from batdetect2.data.annotations import (
|
||||
load_annotated_dataset,
|
||||
)
|
||||
from batdetect2.data.conditions import (
|
||||
ClipAnnotationConditionConfig,
|
||||
SoundEventConditionConfig,
|
||||
build_clip_annotation_condition,
|
||||
build_sound_event_condition,
|
||||
filter_clip_annotation,
|
||||
)
|
||||
@ -69,6 +71,7 @@ class DatasetConfig(BaseConfig):
|
||||
description: str
|
||||
sources: list[AnnotationFormats]
|
||||
|
||||
clip_filter: ClipAnnotationConditionConfig | None = None
|
||||
sound_event_filter: SoundEventConditionConfig | None = None
|
||||
sound_event_transforms: list[SoundEventTransformConfig] = Field(
|
||||
default_factory=list
|
||||
@ -84,11 +87,58 @@ def load_dataset(
|
||||
apply_transforms: bool = True,
|
||||
apply_filters: bool = True,
|
||||
) -> Dataset:
|
||||
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
||||
"""Load and merge clip annotations from configured dataset sources.
|
||||
|
||||
Loads each source listed in ``config.sources`` and returns a flat
|
||||
collection of ``soundevent.data.ClipAnnotation`` objects. Source tags,
|
||||
dataset-level filters, and dataset-level transforms can be enabled or
|
||||
disabled with flags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : DatasetConfig
|
||||
Dataset definition containing source configurations, optional
|
||||
clip-level filter, sound-event filter, and optional sound-event
|
||||
transform pipeline.
|
||||
base_dir : data.PathLike, optional
|
||||
Base directory used to resolve relative paths in source
|
||||
configurations.
|
||||
add_source_tag : bool, default=True
|
||||
If True, append a ``data_source`` tag to each clip annotation with
|
||||
the source name.
|
||||
include_sources : list[str], optional
|
||||
Source names to include. If None, all sources are eligible.
|
||||
exclude_sources : list[str], optional
|
||||
Source names to skip after include filtering. If a source appears in
|
||||
both include and exclude lists, it is skipped.
|
||||
apply_transforms : bool, default=True
|
||||
If True, apply transforms defined in
|
||||
``config.sound_event_transforms``.
|
||||
apply_filters : bool, default=True
|
||||
If True, apply filters defined in ``config.clip_filter`` and
|
||||
``config.sound_event_filter``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dataset
|
||||
Flat collection of clip annotations loaded from the selected sources.
|
||||
"""
|
||||
clip_annotations = []
|
||||
|
||||
condition = (
|
||||
build_sound_event_condition(config.sound_event_filter)
|
||||
clip_condition = (
|
||||
build_clip_annotation_condition(
|
||||
config.clip_filter,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
if config.clip_filter is not None
|
||||
else None
|
||||
)
|
||||
|
||||
sound_event_condition = (
|
||||
build_sound_event_condition(
|
||||
config.sound_event_filter,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
if config.sound_event_filter is not None
|
||||
else None
|
||||
)
|
||||
@ -123,10 +173,17 @@ def load_dataset(
|
||||
if add_source_tag:
|
||||
clip_annotation = insert_source_tag(clip_annotation, source)
|
||||
|
||||
if condition is not None and apply_filters:
|
||||
if (
|
||||
clip_condition is not None
|
||||
and apply_filters
|
||||
and not clip_condition(clip_annotation)
|
||||
):
|
||||
continue
|
||||
|
||||
if sound_event_condition is not None and apply_filters:
|
||||
clip_annotation = filter_clip_annotation(
|
||||
clip_annotation,
|
||||
condition,
|
||||
sound_event_condition,
|
||||
)
|
||||
|
||||
if transform is not None and apply_transforms:
|
||||
@ -181,47 +238,58 @@ def load_dataset_from_config(
|
||||
path: data.PathLike,
|
||||
field: str | None = None,
|
||||
base_dir: data.PathLike | None = None,
|
||||
add_source_tag: bool = True,
|
||||
include_sources: list[str] | None = None,
|
||||
exclude_sources: list[str] | None = None,
|
||||
apply_transforms: bool = True,
|
||||
apply_filters: bool = True,
|
||||
) -> Dataset:
|
||||
"""Load dataset annotation metadata from a configuration file.
|
||||
"""Load a dataset by reading a ``DatasetConfig`` from disk.
|
||||
|
||||
This is a convenience function that first loads the `DatasetConfig` from
|
||||
the specified file path and optional nested field, and then calls
|
||||
`load_dataset` to load all corresponding `ClipAnnotation` objects.
|
||||
This convenience wrapper first loads a ``DatasetConfig`` from ``path``
|
||||
and optional ``field``, then delegates to :func:`load_dataset`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (e.g., YAML).
|
||||
Path to a configuration file containing a ``DatasetConfig``.
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing the
|
||||
dataset configuration (e.g., "data.training_set"). If None, the
|
||||
entire file content is assumed to be the `DatasetConfig`.
|
||||
base_dir : Path, optional
|
||||
An optional base directory path to resolve relative paths within the
|
||||
configuration sources. Passed to `load_dataset`. Defaults to None.
|
||||
Dot-separated field path to a nested config section. If None, the
|
||||
full file is parsed as ``DatasetConfig``.
|
||||
base_dir : data.PathLike, optional
|
||||
Base directory used to resolve relative paths in source
|
||||
configurations.
|
||||
add_source_tag : bool, default=True
|
||||
If True, append a ``data_source`` tag to each clip annotation.
|
||||
include_sources : list[str], optional
|
||||
Source names to include. If None, all sources are eligible.
|
||||
exclude_sources : list[str], optional
|
||||
Source names to skip after include filtering.
|
||||
apply_transforms : bool, default=True
|
||||
If True, apply transforms defined in the loaded config.
|
||||
apply_filters : bool, default=True
|
||||
If True, apply clip and sound-event filters defined in the loaded
|
||||
config.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dataset (List[data.ClipAnnotation])
|
||||
A flat list containing all loaded `ClipAnnotation` metadata objects.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file `path` does not exist.
|
||||
yaml.YAMLError, pydantic.ValidationError, KeyError, TypeError
|
||||
If the configuration file is invalid, cannot be parsed, or does not
|
||||
match the `DatasetConfig` schema.
|
||||
Exception
|
||||
Can raise exceptions from `load_dataset` if loading data from sources
|
||||
fails.
|
||||
Dataset
|
||||
Flat collection of clip annotations loaded from the selected sources.
|
||||
"""
|
||||
config = load_config(
|
||||
path=path,
|
||||
schema=DatasetConfig,
|
||||
field=field,
|
||||
)
|
||||
return load_dataset(config, base_dir=base_dir)
|
||||
return load_dataset(
|
||||
config,
|
||||
base_dir=base_dir,
|
||||
add_source_tag=add_source_tag,
|
||||
include_sources=include_sources,
|
||||
exclude_sources=exclude_sources,
|
||||
apply_transforms=apply_transforms,
|
||||
apply_filters=apply_filters,
|
||||
)
|
||||
|
||||
|
||||
def save_dataset(
|
||||
|
||||
100
tests/test_data/test_datasets.py
Normal file
100
tests/test_data/test_datasets.py
Normal file
@ -0,0 +1,100 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data import DatasetConfig, load_dataset
|
||||
from batdetect2.data.conditions import (
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
RecordingSatisfiesConfig,
|
||||
)
|
||||
|
||||
|
||||
def test_load_dataset_applies_clip_filter(
|
||||
example_dataset: DatasetConfig,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
baseline = list(load_dataset(example_dataset))
|
||||
keep_recording_id = str(baseline[0].clip.recording.uuid)
|
||||
ids_path = tmp_path / "train_ids.json"
|
||||
ids_path.write_text(json.dumps([keep_recording_id]))
|
||||
|
||||
config = example_dataset.model_copy(
|
||||
update={
|
||||
"clip_filter": RecordingSatisfiesConfig(
|
||||
condition=IdInListConfig(path=ids_path)
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
filtered = list(load_dataset(config))
|
||||
|
||||
assert len(filtered) == 1
|
||||
assert str(filtered[0].clip.recording.uuid) == keep_recording_id
|
||||
|
||||
|
||||
def test_load_dataset_clip_filter_is_skipped_when_filters_disabled(
|
||||
example_dataset: DatasetConfig,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
baseline = list(load_dataset(example_dataset))
|
||||
keep_recording_id = str(baseline[0].clip.recording.uuid)
|
||||
ids_path = tmp_path / "train_ids.json"
|
||||
ids_path.write_text(json.dumps([keep_recording_id]))
|
||||
|
||||
config = example_dataset.model_copy(
|
||||
update={
|
||||
"clip_filter": RecordingSatisfiesConfig(
|
||||
condition=IdInListConfig(path=ids_path)
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
filtered = list(load_dataset(config, apply_filters=False))
|
||||
|
||||
assert len(filtered) == len(baseline)
|
||||
|
||||
|
||||
def test_load_dataset_resolves_clip_filter_paths_from_base_dir(
|
||||
example_dataset: DatasetConfig,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
baseline = list(load_dataset(example_dataset))
|
||||
keep_recording_id = str(baseline[0].clip.recording.uuid)
|
||||
split_dir = tmp_path / "splits"
|
||||
split_dir.mkdir()
|
||||
ids_path = split_dir / "train_ids.json"
|
||||
ids_path.write_text(json.dumps([keep_recording_id]))
|
||||
|
||||
config = example_dataset.model_copy(
|
||||
update={
|
||||
"clip_filter": RecordingSatisfiesConfig(
|
||||
condition=IdInListConfig(path=Path("splits/train_ids.json"))
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
filtered = list(load_dataset(config, base_dir=tmp_path))
|
||||
|
||||
assert len(filtered) == 1
|
||||
assert str(filtered[0].clip.recording.uuid) == keep_recording_id
|
||||
|
||||
|
||||
def test_sound_event_filter_keeps_empty_clips(
|
||||
example_dataset: DatasetConfig,
|
||||
) -> None:
|
||||
config = example_dataset.model_copy(
|
||||
update={
|
||||
"sound_event_filter": HasTagConfig(
|
||||
tag=data.Tag(key="species", value="__missing_species__")
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
filtered = list(load_dataset(config))
|
||||
|
||||
assert len(filtered) == 3
|
||||
assert all(
|
||||
len(clip_annotation.sound_events) == 0 for clip_annotation in filtered
|
||||
)
|
||||
Loading…
Reference in New Issue
Block a user