mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add source filtering to dataset loading
This commit is contained in:
parent
67bb66db3c
commit
548cd366cd
@ -110,12 +110,42 @@ def summary(
|
||||
"made relative to this directory."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--add-source-tag",
|
||||
is_flag=True,
|
||||
help=(
|
||||
"Add a source tag to each clip annotation. This is useful for "
|
||||
"downstream tools that need to know which source the annotations "
|
||||
"came from."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--include-sources",
|
||||
type=str,
|
||||
multiple=True,
|
||||
help=(
|
||||
"Only include sources with the specified names. If provided, only "
|
||||
"sources with matching names will be included in the output."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--exclude-sources",
|
||||
type=str,
|
||||
multiple=True,
|
||||
help=(
|
||||
"Exclude sources with the specified names. If provided, sources with "
|
||||
"matching names will be excluded from the output."
|
||||
),
|
||||
)
|
||||
def convert(
|
||||
dataset_config: Path,
|
||||
field: str | None = None,
|
||||
output: Path = Path("annotations.json"),
|
||||
base_dir: Path | None = None,
|
||||
audio_dir: Path | None = None,
|
||||
add_source_tag: bool = True,
|
||||
include_sources: list[str] | None = None,
|
||||
exclude_sources: list[str] | None = None,
|
||||
):
|
||||
"""Convert a dataset config into soundevent annotation-set format.
|
||||
|
||||
@ -130,7 +160,13 @@ def convert(
|
||||
|
||||
config = load_dataset_config(dataset_config, field=field)
|
||||
|
||||
dataset = load_dataset(config, base_dir=base_dir)
|
||||
dataset = load_dataset(
|
||||
config,
|
||||
base_dir=base_dir,
|
||||
add_source_tag=add_source_tag,
|
||||
include_sources=include_sources,
|
||||
exclude_sources=exclude_sources,
|
||||
)
|
||||
|
||||
annotation_set = data.AnnotationSet(
|
||||
clip_annotations=list(dataset),
|
||||
|
||||
@ -19,7 +19,7 @@ The core components are:
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Sequence
|
||||
from typing import Sequence
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
@ -67,10 +67,10 @@ class DatasetConfig(BaseConfig):
|
||||
|
||||
name: str
|
||||
description: str
|
||||
sources: List[AnnotationFormats]
|
||||
sources: list[AnnotationFormats]
|
||||
|
||||
sound_event_filter: SoundEventConditionConfig | None = None
|
||||
sound_event_transforms: List[SoundEventTransformConfig] = Field(
|
||||
sound_event_transforms: list[SoundEventTransformConfig] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
@ -78,6 +78,9 @@ class DatasetConfig(BaseConfig):
|
||||
def load_dataset(
|
||||
config: DatasetConfig,
|
||||
base_dir: data.PathLike | None = None,
|
||||
add_source_tag: bool = True,
|
||||
include_sources: list[str] | None = None,
|
||||
exclude_sources: list[str] | None = None,
|
||||
) -> Dataset:
|
||||
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
||||
clip_annotations = []
|
||||
@ -102,6 +105,12 @@ def load_dataset(
|
||||
for source in config.sources:
|
||||
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
|
||||
|
||||
if include_sources and source.name not in include_sources:
|
||||
continue
|
||||
|
||||
if exclude_sources and source.name in exclude_sources:
|
||||
continue
|
||||
|
||||
logger.debug(
|
||||
"Loaded {num_examples} from dataset source '{source_name}'",
|
||||
num_examples=len(annotated_source.clip_annotations),
|
||||
@ -109,7 +118,8 @@ def load_dataset(
|
||||
)
|
||||
|
||||
for clip_annotation in annotated_source.clip_annotations:
|
||||
clip_annotation = insert_source_tag(clip_annotation, source)
|
||||
if add_source_tag:
|
||||
clip_annotation = insert_source_tag(clip_annotation, source)
|
||||
|
||||
if condition is not None:
|
||||
clip_annotation = filter_clip_annotation(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user