Add source filtering to dataset loading

This commit is contained in:
mbsantiago 2026-03-28 19:43:13 +00:00
parent 67bb66db3c
commit 548cd366cd
2 changed files with 51 additions and 5 deletions

View File

@ -110,12 +110,42 @@ def summary(
"made relative to this directory." "made relative to this directory."
), ),
) )
@click.option(
"--add-source-tag",
is_flag=True,
help=(
"Add a source tag to each clip annotation. This is useful for "
"downstream tools that need to know which source the annotations "
"came from."
),
)
@click.option(
"--include-sources",
type=str,
multiple=True,
help=(
"Only include sources with the specified names. If provided, only "
"sources with matching names will be included in the output."
),
)
@click.option(
"--exclude-sources",
type=str,
multiple=True,
help=(
"Exclude sources with the specified names. If provided, sources with "
"matching names will be excluded from the output."
),
)
def convert( def convert(
dataset_config: Path, dataset_config: Path,
field: str | None = None, field: str | None = None,
output: Path = Path("annotations.json"), output: Path = Path("annotations.json"),
base_dir: Path | None = None, base_dir: Path | None = None,
audio_dir: Path | None = None, audio_dir: Path | None = None,
add_source_tag: bool = True,
include_sources: list[str] | None = None,
exclude_sources: list[str] | None = None,
): ):
"""Convert a dataset config into soundevent annotation-set format. """Convert a dataset config into soundevent annotation-set format.
@ -130,7 +160,13 @@ def convert(
config = load_dataset_config(dataset_config, field=field) config = load_dataset_config(dataset_config, field=field)
dataset = load_dataset(config, base_dir=base_dir) dataset = load_dataset(
config,
base_dir=base_dir,
add_source_tag=add_source_tag,
include_sources=include_sources,
exclude_sources=exclude_sources,
)
annotation_set = data.AnnotationSet( annotation_set = data.AnnotationSet(
clip_annotations=list(dataset), clip_annotations=list(dataset),

View File

@ -19,7 +19,7 @@ The core components are:
""" """
from pathlib import Path from pathlib import Path
from typing import List, Sequence from typing import Sequence
from loguru import logger from loguru import logger
from pydantic import Field from pydantic import Field
@ -67,10 +67,10 @@ class DatasetConfig(BaseConfig):
name: str name: str
description: str description: str
sources: List[AnnotationFormats] sources: list[AnnotationFormats]
sound_event_filter: SoundEventConditionConfig | None = None sound_event_filter: SoundEventConditionConfig | None = None
sound_event_transforms: List[SoundEventTransformConfig] = Field( sound_event_transforms: list[SoundEventTransformConfig] = Field(
default_factory=list default_factory=list
) )
@ -78,6 +78,9 @@ class DatasetConfig(BaseConfig):
def load_dataset( def load_dataset(
config: DatasetConfig, config: DatasetConfig,
base_dir: data.PathLike | None = None, base_dir: data.PathLike | None = None,
add_source_tag: bool = True,
include_sources: list[str] | None = None,
exclude_sources: list[str] | None = None,
) -> Dataset: ) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig.""" """Load all clip annotations from the sources defined in a DatasetConfig."""
clip_annotations = [] clip_annotations = []
@ -102,6 +105,12 @@ def load_dataset(
for source in config.sources: for source in config.sources:
annotated_source = load_annotated_dataset(source, base_dir=base_dir) annotated_source = load_annotated_dataset(source, base_dir=base_dir)
if include_sources and source.name not in include_sources:
continue
if exclude_sources and source.name in exclude_sources:
continue
logger.debug( logger.debug(
"Loaded {num_examples} from dataset source '{source_name}'", "Loaded {num_examples} from dataset source '{source_name}'",
num_examples=len(annotated_source.clip_annotations), num_examples=len(annotated_source.clip_annotations),
@ -109,7 +118,8 @@ def load_dataset(
) )
for clip_annotation in annotated_source.clip_annotations: for clip_annotation in annotated_source.clip_annotations:
clip_annotation = insert_source_tag(clip_annotation, source) if add_source_tag:
clip_annotation = insert_source_tag(clip_annotation, source)
if condition is not None: if condition is not None:
clip_annotation = filter_clip_annotation( clip_annotation = filter_clip_annotation(