mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add source filtering to dataset loading
This commit is contained in:
parent
67bb66db3c
commit
548cd366cd
@ -110,12 +110,42 @@ def summary(
|
|||||||
"made relative to this directory."
|
"made relative to this directory."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--add-source-tag",
|
||||||
|
is_flag=True,
|
||||||
|
help=(
|
||||||
|
"Add a source tag to each clip annotation. This is useful for "
|
||||||
|
"downstream tools that need to know which source the annotations "
|
||||||
|
"came from."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--include-sources",
|
||||||
|
type=str,
|
||||||
|
multiple=True,
|
||||||
|
help=(
|
||||||
|
"Only include sources with the specified names. If provided, only "
|
||||||
|
"sources with matching names will be included in the output."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--exclude-sources",
|
||||||
|
type=str,
|
||||||
|
multiple=True,
|
||||||
|
help=(
|
||||||
|
"Exclude sources with the specified names. If provided, sources with "
|
||||||
|
"matching names will be excluded from the output."
|
||||||
|
),
|
||||||
|
)
|
||||||
def convert(
|
def convert(
|
||||||
dataset_config: Path,
|
dataset_config: Path,
|
||||||
field: str | None = None,
|
field: str | None = None,
|
||||||
output: Path = Path("annotations.json"),
|
output: Path = Path("annotations.json"),
|
||||||
base_dir: Path | None = None,
|
base_dir: Path | None = None,
|
||||||
audio_dir: Path | None = None,
|
audio_dir: Path | None = None,
|
||||||
|
add_source_tag: bool = True,
|
||||||
|
include_sources: list[str] | None = None,
|
||||||
|
exclude_sources: list[str] | None = None,
|
||||||
):
|
):
|
||||||
"""Convert a dataset config into soundevent annotation-set format.
|
"""Convert a dataset config into soundevent annotation-set format.
|
||||||
|
|
||||||
@ -130,7 +160,13 @@ def convert(
|
|||||||
|
|
||||||
config = load_dataset_config(dataset_config, field=field)
|
config = load_dataset_config(dataset_config, field=field)
|
||||||
|
|
||||||
dataset = load_dataset(config, base_dir=base_dir)
|
dataset = load_dataset(
|
||||||
|
config,
|
||||||
|
base_dir=base_dir,
|
||||||
|
add_source_tag=add_source_tag,
|
||||||
|
include_sources=include_sources,
|
||||||
|
exclude_sources=exclude_sources,
|
||||||
|
)
|
||||||
|
|
||||||
annotation_set = data.AnnotationSet(
|
annotation_set = data.AnnotationSet(
|
||||||
clip_annotations=list(dataset),
|
clip_annotations=list(dataset),
|
||||||
|
|||||||
@ -19,7 +19,7 @@ The core components are:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -67,10 +67,10 @@ class DatasetConfig(BaseConfig):
|
|||||||
|
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
sources: List[AnnotationFormats]
|
sources: list[AnnotationFormats]
|
||||||
|
|
||||||
sound_event_filter: SoundEventConditionConfig | None = None
|
sound_event_filter: SoundEventConditionConfig | None = None
|
||||||
sound_event_transforms: List[SoundEventTransformConfig] = Field(
|
sound_event_transforms: list[SoundEventTransformConfig] = Field(
|
||||||
default_factory=list
|
default_factory=list
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -78,6 +78,9 @@ class DatasetConfig(BaseConfig):
|
|||||||
def load_dataset(
|
def load_dataset(
|
||||||
config: DatasetConfig,
|
config: DatasetConfig,
|
||||||
base_dir: data.PathLike | None = None,
|
base_dir: data.PathLike | None = None,
|
||||||
|
add_source_tag: bool = True,
|
||||||
|
include_sources: list[str] | None = None,
|
||||||
|
exclude_sources: list[str] | None = None,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
||||||
clip_annotations = []
|
clip_annotations = []
|
||||||
@ -102,6 +105,12 @@ def load_dataset(
|
|||||||
for source in config.sources:
|
for source in config.sources:
|
||||||
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
|
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
|
||||||
|
|
||||||
|
if include_sources and source.name not in include_sources:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if exclude_sources and source.name in exclude_sources:
|
||||||
|
continue
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Loaded {num_examples} from dataset source '{source_name}'",
|
"Loaded {num_examples} from dataset source '{source_name}'",
|
||||||
num_examples=len(annotated_source.clip_annotations),
|
num_examples=len(annotated_source.clip_annotations),
|
||||||
@ -109,7 +118,8 @@ def load_dataset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for clip_annotation in annotated_source.clip_annotations:
|
for clip_annotation in annotated_source.clip_annotations:
|
||||||
clip_annotation = insert_source_tag(clip_annotation, source)
|
if add_source_tag:
|
||||||
|
clip_annotation = insert_source_tag(clip_annotation, source)
|
||||||
|
|
||||||
if condition is not None:
|
if condition is not None:
|
||||||
clip_annotation = filter_clip_annotation(
|
clip_annotation = filter_clip_annotation(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user