diff --git a/src/batdetect2/cli/data.py b/src/batdetect2/cli/data.py index a772cb5..cf0698c 100644 --- a/src/batdetect2/cli/data.py +++ b/src/batdetect2/cli/data.py @@ -110,12 +110,42 @@ def summary( "made relative to this directory." ), ) +@click.option( + "--add-source-tag", + is_flag=True, + help=( + "Add a source tag to each clip annotation. This is useful for " + "downstream tools that need to know which source the annotations " + "came from." + ), +) +@click.option( + "--include-sources", + type=str, + multiple=True, + help=( + "Only include sources with the specified names. If provided, only " + "sources with matching names will be included in the output." + ), +) +@click.option( + "--exclude-sources", + type=str, + multiple=True, + help=( + "Exclude sources with the specified names. If provided, sources with " + "matching names will be excluded from the output." + ), +) def convert( dataset_config: Path, field: str | None = None, output: Path = Path("annotations.json"), base_dir: Path | None = None, audio_dir: Path | None = None, + add_source_tag: bool = True, + include_sources: list[str] | None = None, + exclude_sources: list[str] | None = None, ): """Convert a dataset config into soundevent annotation-set format. @@ -130,7 +160,13 @@ def convert( config = load_dataset_config(dataset_config, field=field) - dataset = load_dataset(config, base_dir=base_dir) + dataset = load_dataset( + config, + base_dir=base_dir, + add_source_tag=add_source_tag, + include_sources=include_sources, + exclude_sources=exclude_sources, + ) annotation_set = data.AnnotationSet( clip_annotations=list(dataset), diff --git a/src/batdetect2/data/datasets.py b/src/batdetect2/data/datasets.py index a07f1bb..7f8ce7b 100644 --- a/src/batdetect2/data/datasets.py +++ b/src/batdetect2/data/datasets.py @@ -19,7 +19,7 @@ The core components are: """ from pathlib import Path -from typing import List, Sequence +from typing import Sequence from loguru import logger from pydantic import Field @@ -67,10 +67,10 @@ class DatasetConfig(BaseConfig): name: str description: str - sources: List[AnnotationFormats] + sources: list[AnnotationFormats] sound_event_filter: SoundEventConditionConfig | None = None - sound_event_transforms: List[SoundEventTransformConfig] = Field( + sound_event_transforms: list[SoundEventTransformConfig] = Field( default_factory=list ) @@ -78,6 +78,9 @@ class DatasetConfig(BaseConfig): def load_dataset( config: DatasetConfig, base_dir: data.PathLike | None = None, + add_source_tag: bool = True, + include_sources: list[str] | None = None, + exclude_sources: list[str] | None = None, ) -> Dataset: """Load all clip annotations from the sources defined in a DatasetConfig.""" clip_annotations = [] @@ -102,6 +105,12 @@ def load_dataset( for source in config.sources: annotated_source = load_annotated_dataset(source, base_dir=base_dir) + if include_sources and source.name not in include_sources: + continue + + if exclude_sources and source.name in exclude_sources: + continue + logger.debug( "Loaded {num_examples} from dataset source '{source_name}'", num_examples=len(annotated_source.clip_annotations), @@ -109,7 +118,8 @@ def load_dataset( ) for clip_annotation in annotated_source.clip_annotations: - clip_annotation = insert_source_tag(clip_annotation, source) + if add_source_tag: + clip_annotation = insert_source_tag(clip_annotation, source) if condition is not None: clip_annotation = filter_clip_annotation(