Add clip annotation filtering to data loading

This commit is contained in:
mbsantiago 2026-04-03 16:40:23 +01:00
parent c8dd4155bf
commit e04d86808d
2 changed files with 198 additions and 30 deletions

View File

@ -32,7 +32,9 @@ from batdetect2.data.annotations import (
load_annotated_dataset, load_annotated_dataset,
) )
from batdetect2.data.conditions import ( from batdetect2.data.conditions import (
ClipAnnotationConditionConfig,
SoundEventConditionConfig, SoundEventConditionConfig,
build_clip_annotation_condition,
build_sound_event_condition, build_sound_event_condition,
filter_clip_annotation, filter_clip_annotation,
) )
@ -69,6 +71,7 @@ class DatasetConfig(BaseConfig):
description: str description: str
sources: list[AnnotationFormats] sources: list[AnnotationFormats]
clip_filter: ClipAnnotationConditionConfig | None = None
sound_event_filter: SoundEventConditionConfig | None = None sound_event_filter: SoundEventConditionConfig | None = None
sound_event_transforms: list[SoundEventTransformConfig] = Field( sound_event_transforms: list[SoundEventTransformConfig] = Field(
default_factory=list default_factory=list
@ -84,11 +87,58 @@ def load_dataset(
apply_transforms: bool = True, apply_transforms: bool = True,
apply_filters: bool = True, apply_filters: bool = True,
) -> Dataset: ) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig.""" """Load and merge clip annotations from configured dataset sources.
Loads each source listed in ``config.sources`` and returns a flat
collection of ``soundevent.data.ClipAnnotation`` objects. Source tags,
dataset-level filters, and dataset-level transforms can be enabled or
disabled with flags.
Parameters
----------
config : DatasetConfig
Dataset definition containing source configurations, optional
clip-level filter, sound-event filter, and optional sound-event
transform pipeline.
base_dir : data.PathLike, optional
Base directory used to resolve relative paths in source
configurations.
add_source_tag : bool, default=True
If True, append a ``data_source`` tag to each clip annotation with
the source name.
include_sources : list[str], optional
Source names to include. If None, all sources are eligible.
exclude_sources : list[str], optional
Source names to skip after include filtering. If a source appears in
both include and exclude lists, it is skipped.
apply_transforms : bool, default=True
If True, apply transforms defined in
``config.sound_event_transforms``.
apply_filters : bool, default=True
If True, apply filters defined in ``config.clip_filter`` and
``config.sound_event_filter``.
Returns
-------
Dataset
Flat collection of clip annotations loaded from the selected sources.
"""
clip_annotations = [] clip_annotations = []
condition = ( clip_condition = (
build_sound_event_condition(config.sound_event_filter) build_clip_annotation_condition(
config.clip_filter,
base_dir=base_dir,
)
if config.clip_filter is not None
else None
)
sound_event_condition = (
build_sound_event_condition(
config.sound_event_filter,
base_dir=base_dir,
)
if config.sound_event_filter is not None if config.sound_event_filter is not None
else None else None
) )
@ -123,10 +173,17 @@ def load_dataset(
if add_source_tag: if add_source_tag:
clip_annotation = insert_source_tag(clip_annotation, source) clip_annotation = insert_source_tag(clip_annotation, source)
if condition is not None and apply_filters: if (
clip_condition is not None
and apply_filters
and not clip_condition(clip_annotation)
):
continue
if sound_event_condition is not None and apply_filters:
clip_annotation = filter_clip_annotation( clip_annotation = filter_clip_annotation(
clip_annotation, clip_annotation,
condition, sound_event_condition,
) )
if transform is not None and apply_transforms: if transform is not None and apply_transforms:
@ -181,47 +238,58 @@ def load_dataset_from_config(
path: data.PathLike, path: data.PathLike,
field: str | None = None, field: str | None = None,
base_dir: data.PathLike | None = None, base_dir: data.PathLike | None = None,
add_source_tag: bool = True,
include_sources: list[str] | None = None,
exclude_sources: list[str] | None = None,
apply_transforms: bool = True,
apply_filters: bool = True,
) -> Dataset: ) -> Dataset:
"""Load dataset annotation metadata from a configuration file. """Load a dataset by reading a ``DatasetConfig`` from disk.
This is a convenience function that first loads the `DatasetConfig` from This convenience wrapper first loads a ``DatasetConfig`` from ``path``
the specified file path and optional nested field, and then calls and optional ``field``, then delegates to :func:`load_dataset`.
`load_dataset` to load all corresponding `ClipAnnotation` objects.
Parameters Parameters
---------- ----------
path : data.PathLike path : data.PathLike
Path to the configuration file (e.g., YAML). Path to a configuration file containing a ``DatasetConfig``.
field : str, optional field : str, optional
Dot-separated path to a nested section within the file containing the Dot-separated field path to a nested config section. If None, the
dataset configuration (e.g., "data.training_set"). If None, the full file is parsed as ``DatasetConfig``.
entire file content is assumed to be the `DatasetConfig`. base_dir : data.PathLike, optional
base_dir : Path, optional Base directory used to resolve relative paths in source
An optional base directory path to resolve relative paths within the configurations.
configuration sources. Passed to `load_dataset`. Defaults to None. add_source_tag : bool, default=True
If True, append a ``data_source`` tag to each clip annotation.
include_sources : list[str], optional
Source names to include. If None, all sources are eligible.
exclude_sources : list[str], optional
Source names to skip after include filtering.
apply_transforms : bool, default=True
If True, apply transforms defined in the loaded config.
apply_filters : bool, default=True
If True, apply clip and sound-event filters defined in the loaded
config.
Returns Returns
------- -------
Dataset (List[data.ClipAnnotation]) Dataset
A flat list containing all loaded `ClipAnnotation` metadata objects. Flat collection of clip annotations loaded from the selected sources.
Raises
------
FileNotFoundError
If the config file `path` does not exist.
yaml.YAMLError, pydantic.ValidationError, KeyError, TypeError
If the configuration file is invalid, cannot be parsed, or does not
match the `DatasetConfig` schema.
Exception
Can raise exceptions from `load_dataset` if loading data from sources
fails.
""" """
config = load_config( config = load_config(
path=path, path=path,
schema=DatasetConfig, schema=DatasetConfig,
field=field, field=field,
) )
return load_dataset(config, base_dir=base_dir) return load_dataset(
config,
base_dir=base_dir,
add_source_tag=add_source_tag,
include_sources=include_sources,
exclude_sources=exclude_sources,
apply_transforms=apply_transforms,
apply_filters=apply_filters,
)
def save_dataset( def save_dataset(

View File

@ -0,0 +1,100 @@
import json
from pathlib import Path
from soundevent import data
from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.conditions import (
HasTagConfig,
IdInListConfig,
RecordingSatisfiesConfig,
)
def test_load_dataset_applies_clip_filter(
example_dataset: DatasetConfig,
tmp_path: Path,
) -> None:
baseline = list(load_dataset(example_dataset))
keep_recording_id = str(baseline[0].clip.recording.uuid)
ids_path = tmp_path / "train_ids.json"
ids_path.write_text(json.dumps([keep_recording_id]))
config = example_dataset.model_copy(
update={
"clip_filter": RecordingSatisfiesConfig(
condition=IdInListConfig(path=ids_path)
)
}
)
filtered = list(load_dataset(config))
assert len(filtered) == 1
assert str(filtered[0].clip.recording.uuid) == keep_recording_id
def test_load_dataset_clip_filter_is_skipped_when_filters_disabled(
example_dataset: DatasetConfig,
tmp_path: Path,
) -> None:
baseline = list(load_dataset(example_dataset))
keep_recording_id = str(baseline[0].clip.recording.uuid)
ids_path = tmp_path / "train_ids.json"
ids_path.write_text(json.dumps([keep_recording_id]))
config = example_dataset.model_copy(
update={
"clip_filter": RecordingSatisfiesConfig(
condition=IdInListConfig(path=ids_path)
)
}
)
filtered = list(load_dataset(config, apply_filters=False))
assert len(filtered) == len(baseline)
def test_load_dataset_resolves_clip_filter_paths_from_base_dir(
example_dataset: DatasetConfig,
tmp_path: Path,
) -> None:
baseline = list(load_dataset(example_dataset))
keep_recording_id = str(baseline[0].clip.recording.uuid)
split_dir = tmp_path / "splits"
split_dir.mkdir()
ids_path = split_dir / "train_ids.json"
ids_path.write_text(json.dumps([keep_recording_id]))
config = example_dataset.model_copy(
update={
"clip_filter": RecordingSatisfiesConfig(
condition=IdInListConfig(path=Path("splits/train_ids.json"))
)
}
)
filtered = list(load_dataset(config, base_dir=tmp_path))
assert len(filtered) == 1
assert str(filtered[0].clip.recording.uuid) == keep_recording_id
def test_sound_event_filter_keeps_empty_clips(
example_dataset: DatasetConfig,
) -> None:
config = example_dataset.model_copy(
update={
"sound_event_filter": HasTagConfig(
tag=data.Tag(key="species", value="__missing_species__")
)
}
)
filtered = list(load_dataset(config))
assert len(filtered) == 3
assert all(
len(clip_annotation.sound_events) == 0 for clip_annotation in filtered
)