Add source filtering to dataset loading

This commit is contained in:
mbsantiago 2026-03-28 19:43:13 +00:00
parent 67bb66db3c
commit 548cd366cd
2 changed files with 51 additions and 5 deletions

View File

@ -110,12 +110,42 @@ def summary(
"made relative to this directory."
),
)
@click.option(
"--add-source-tag",
is_flag=True,
help=(
"Add a source tag to each clip annotation. This is useful for "
"downstream tools that need to know which source the annotations "
"came from."
),
)
@click.option(
"--include-sources",
type=str,
multiple=True,
help=(
"Only include sources with the specified names. If provided, only "
"sources with matching names will be included in the output."
),
)
@click.option(
"--exclude-sources",
type=str,
multiple=True,
help=(
"Exclude sources with the specified names. If provided, sources with "
"matching names will be excluded from the output."
),
)
def convert(
dataset_config: Path,
field: str | None = None,
output: Path = Path("annotations.json"),
base_dir: Path | None = None,
audio_dir: Path | None = None,
add_source_tag: bool = True,
include_sources: list[str] | None = None,
exclude_sources: list[str] | None = None,
):
"""Convert a dataset config into soundevent annotation-set format.
@ -130,7 +160,13 @@ def convert(
config = load_dataset_config(dataset_config, field=field)
dataset = load_dataset(config, base_dir=base_dir)
dataset = load_dataset(
config,
base_dir=base_dir,
add_source_tag=add_source_tag,
include_sources=include_sources,
exclude_sources=exclude_sources,
)
annotation_set = data.AnnotationSet(
clip_annotations=list(dataset),

View File

@ -19,7 +19,7 @@ The core components are:
"""
from pathlib import Path
from typing import List, Sequence
from typing import Sequence
from loguru import logger
from pydantic import Field
@ -67,10 +67,10 @@ class DatasetConfig(BaseConfig):
name: str
description: str
sources: List[AnnotationFormats]
sources: list[AnnotationFormats]
sound_event_filter: SoundEventConditionConfig | None = None
sound_event_transforms: List[SoundEventTransformConfig] = Field(
sound_event_transforms: list[SoundEventTransformConfig] = Field(
default_factory=list
)
@ -78,6 +78,9 @@ class DatasetConfig(BaseConfig):
def load_dataset(
config: DatasetConfig,
base_dir: data.PathLike | None = None,
add_source_tag: bool = True,
include_sources: list[str] | None = None,
exclude_sources: list[str] | None = None,
) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig."""
clip_annotations = []
@ -102,6 +105,12 @@ def load_dataset(
for source in config.sources:
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
if include_sources and source.name not in include_sources:
continue
if exclude_sources and source.name in exclude_sources:
continue
logger.debug(
"Loaded {num_examples} from dataset source '{source_name}'",
num_examples=len(annotated_source.clip_annotations),
@ -109,7 +118,8 @@ def load_dataset(
)
for clip_annotation in annotated_source.clip_annotations:
clip_annotation = insert_source_tag(clip_annotation, source)
if add_source_tag:
clip_annotation = insert_source_tag(clip_annotation, source)
if condition is not None:
clip_annotation = filter_clip_annotation(