diff --git a/src/batdetect2/data/datasets.py b/src/batdetect2/data/datasets.py
index 438c3e5..af13f3c 100644
--- a/src/batdetect2/data/datasets.py
+++ b/src/batdetect2/data/datasets.py
@@ -32,7 +32,9 @@ from batdetect2.data.annotations import (
     load_annotated_dataset,
 )
 from batdetect2.data.conditions import (
+    ClipAnnotationConditionConfig,
     SoundEventConditionConfig,
+    build_clip_annotation_condition,
     build_sound_event_condition,
     filter_clip_annotation,
 )
@@ -69,6 +71,7 @@ class DatasetConfig(BaseConfig):
     description: str
     sources: list[AnnotationFormats]
 
+    clip_filter: ClipAnnotationConditionConfig | None = None
     sound_event_filter: SoundEventConditionConfig | None = None
     sound_event_transforms: list[SoundEventTransformConfig] = Field(
         default_factory=list
@@ -84,11 +87,58 @@ def load_dataset(
     apply_transforms: bool = True,
     apply_filters: bool = True,
 ) -> Dataset:
-    """Load all clip annotations from the sources defined in a DatasetConfig."""
+    """Load and merge clip annotations from configured dataset sources.
+
+    Loads each source listed in ``config.sources`` and returns a flat
+    collection of ``soundevent.data.ClipAnnotation`` objects. Source tags,
+    dataset-level filters, and dataset-level transforms can be enabled or
+    disabled with flags.
+
+    Parameters
+    ----------
+    config : DatasetConfig
+        Dataset definition containing source configurations, optional
+        clip-level filter, sound-event filter, and optional sound-event
+        transform pipeline.
+    base_dir : data.PathLike, optional
+        Base directory used to resolve relative paths in source
+        configurations.
+    add_source_tag : bool, default=True
+        If True, append a ``data_source`` tag to each clip annotation with
+        the source name.
+    include_sources : list[str], optional
+        Source names to include. If None, all sources are eligible.
+    exclude_sources : list[str], optional
+        Source names to skip after include filtering. If a source appears in
+        both include and exclude lists, it is skipped.
+    apply_transforms : bool, default=True
+        If True, apply transforms defined in
+        ``config.sound_event_transforms``.
+    apply_filters : bool, default=True
+        If True, build and apply the filters defined in
+        ``config.clip_filter`` and ``config.sound_event_filter``.
+
+    Returns
+    -------
+    Dataset
+        Flat collection of clip annotations loaded from the selected sources.
+    """
     clip_annotations = []
 
-    condition = (
-        build_sound_event_condition(config.sound_event_filter)
-        if config.sound_event_filter is not None
+    clip_condition = (
+        build_clip_annotation_condition(
+            config.clip_filter,
+            base_dir=base_dir,
+        )
+        if apply_filters and config.clip_filter is not None
+        else None
+    )
+
+    sound_event_condition = (
+        build_sound_event_condition(
+            config.sound_event_filter,
+            base_dir=base_dir,
+        )
+        if apply_filters and config.sound_event_filter is not None
         else None
     )
@@ -123,10 +173,17 @@ def load_dataset(
             if add_source_tag:
                 clip_annotation = insert_source_tag(clip_annotation, source)
 
-            if condition is not None and apply_filters:
+            if (
+                clip_condition is not None
+                and apply_filters
+                and not clip_condition(clip_annotation)
+            ):
+                continue
+
+            if sound_event_condition is not None and apply_filters:
                 clip_annotation = filter_clip_annotation(
                     clip_annotation,
-                    condition,
+                    sound_event_condition,
                 )
 
             if transform is not None and apply_transforms:
@@ -181,47 +238,58 @@ def load_dataset_from_config(
     path: data.PathLike,
     field: str | None = None,
     base_dir: data.PathLike | None = None,
+    add_source_tag: bool = True,
+    include_sources: list[str] | None = None,
+    exclude_sources: list[str] | None = None,
+    apply_transforms: bool = True,
+    apply_filters: bool = True,
 ) -> Dataset:
-    """Load dataset annotation metadata from a configuration file.
+    """Load a dataset by reading a ``DatasetConfig`` from disk.
 
-    This is a convenience function that first loads the `DatasetConfig` from
-    the specified file path and optional nested field, and then calls
-    `load_dataset` to load all corresponding `ClipAnnotation` objects.
+    This convenience wrapper first loads a ``DatasetConfig`` from ``path``
+    and optional ``field``, then delegates to :func:`load_dataset`.
 
     Parameters
     ----------
     path : data.PathLike
-        Path to the configuration file (e.g., YAML).
+        Path to a configuration file containing a ``DatasetConfig``.
     field : str, optional
-        Dot-separated path to a nested section within the file containing the
-        dataset configuration (e.g., "data.training_set"). If None, the
-        entire file content is assumed to be the `DatasetConfig`.
-    base_dir : Path, optional
-        An optional base directory path to resolve relative paths within the
-        configuration sources. Passed to `load_dataset`. Defaults to None.
+        Dot-separated field path to a nested config section. If None, the
+        full file is parsed as ``DatasetConfig``.
+    base_dir : data.PathLike, optional
+        Base directory used to resolve relative paths in source
+        configurations.
+    add_source_tag : bool, default=True
+        If True, append a ``data_source`` tag to each clip annotation.
+    include_sources : list[str], optional
+        Source names to include. If None, all sources are eligible.
+    exclude_sources : list[str], optional
+        Source names to skip after include filtering.
+    apply_transforms : bool, default=True
+        If True, apply transforms defined in the loaded config.
+    apply_filters : bool, default=True
+        If True, apply clip and sound-event filters defined in the loaded
+        config.
 
     Returns
     -------
-    Dataset (List[data.ClipAnnotation])
-        A flat list containing all loaded `ClipAnnotation` metadata objects.
-
-    Raises
-    ------
-    FileNotFoundError
-        If the config file `path` does not exist.
-    yaml.YAMLError, pydantic.ValidationError, KeyError, TypeError
-        If the configuration file is invalid, cannot be parsed, or does not
-        match the `DatasetConfig` schema.
-    Exception
-        Can raise exceptions from `load_dataset` if loading data from sources
-        fails.
+    Dataset
+        Flat collection of clip annotations loaded from the selected sources.
     """
     config = load_config(
         path=path,
         schema=DatasetConfig,
         field=field,
     )
-    return load_dataset(config, base_dir=base_dir)
+    return load_dataset(
+        config,
+        base_dir=base_dir,
+        add_source_tag=add_source_tag,
+        include_sources=include_sources,
+        exclude_sources=exclude_sources,
+        apply_transforms=apply_transforms,
+        apply_filters=apply_filters,
+    )
 
 
 def save_dataset(
diff --git a/tests/test_data/test_datasets.py b/tests/test_data/test_datasets.py
new file mode 100644
index 0000000..ca4e5c0
--- /dev/null
+++ b/tests/test_data/test_datasets.py
@@ -0,0 +1,100 @@
+import json
+from pathlib import Path
+
+from soundevent import data
+
+from batdetect2.data import DatasetConfig, load_dataset
+from batdetect2.data.conditions import (
+    HasTagConfig,
+    IdInListConfig,
+    RecordingSatisfiesConfig,
+)
+
+
+def test_load_dataset_applies_clip_filter(
+    example_dataset: DatasetConfig,
+    tmp_path: Path,
+) -> None:
+    baseline = list(load_dataset(example_dataset))
+    keep_recording_id = str(baseline[0].clip.recording.uuid)
+    ids_path = tmp_path / "train_ids.json"
+    ids_path.write_text(json.dumps([keep_recording_id]))
+
+    config = example_dataset.model_copy(
+        update={
+            "clip_filter": RecordingSatisfiesConfig(
+                condition=IdInListConfig(path=ids_path)
+            )
+        }
+    )
+
+    filtered = list(load_dataset(config))
+
+    assert len(filtered) == 1
+    assert str(filtered[0].clip.recording.uuid) == keep_recording_id
+
+
+def test_load_dataset_clip_filter_is_skipped_when_filters_disabled(
+    example_dataset: DatasetConfig,
+    tmp_path: Path,
+) -> None:
+    baseline = list(load_dataset(example_dataset))
+    keep_recording_id = str(baseline[0].clip.recording.uuid)
+    ids_path = tmp_path / "train_ids.json"
+    ids_path.write_text(json.dumps([keep_recording_id]))
+
+    config = example_dataset.model_copy(
+        update={
+            "clip_filter": RecordingSatisfiesConfig(
+                condition=IdInListConfig(path=ids_path)
+            )
+        }
+    )
+
+    filtered = list(load_dataset(config, apply_filters=False))
+
+    assert len(filtered) == len(baseline)
+
+
+def test_load_dataset_resolves_clip_filter_paths_from_base_dir(
+    example_dataset: DatasetConfig,
+    tmp_path: Path,
+) -> None:
+    baseline = list(load_dataset(example_dataset))
+    keep_recording_id = str(baseline[0].clip.recording.uuid)
+    split_dir = tmp_path / "splits"
+    split_dir.mkdir()
+    ids_path = split_dir / "train_ids.json"
+    ids_path.write_text(json.dumps([keep_recording_id]))
+
+    config = example_dataset.model_copy(
+        update={
+            "clip_filter": RecordingSatisfiesConfig(
+                condition=IdInListConfig(path=Path("splits/train_ids.json"))
+            )
+        }
+    )
+
+    filtered = list(load_dataset(config, base_dir=tmp_path))
+
+    assert len(filtered) == 1
+    assert str(filtered[0].clip.recording.uuid) == keep_recording_id
+
+
+def test_sound_event_filter_keeps_empty_clips(
+    example_dataset: DatasetConfig,
+) -> None:
+    config = example_dataset.model_copy(
+        update={
+            "sound_event_filter": HasTagConfig(
+                tag=data.Tag(key="species", value="__missing_species__")
+            )
+        }
+    )
+
+    filtered = list(load_dataset(config))
+
+    assert len(filtered) == 3
+    assert all(
+        len(clip_annotation.sound_events) == 0 for clip_annotation in filtered
+    )