diff --git a/batdetect2/data/datasets.py b/batdetect2/data/datasets.py index f8d94be..f0e3278 100644 --- a/batdetect2/data/datasets.py +++ b/batdetect2/data/datasets.py @@ -26,9 +26,11 @@ from soundevent import data, io from batdetect2.configs import BaseConfig, load_config from batdetect2.data.annotations import ( + AnnotatedDataset, AnnotationFormats, load_annotated_dataset, ) +from batdetect2.targets.terms import data_source __all__ = [ "load_dataset", @@ -113,10 +115,46 @@ def load_dataset( clip_annotations = [] for source in dataset.sources: annotated_source = load_annotated_dataset(source, base_dir=base_dir) - clip_annotations.extend(annotated_source.clip_annotations) + clip_annotations.extend( + insert_source_tag(clip_annotation, source) + for clip_annotation in annotated_source.clip_annotations + ) return clip_annotations +def insert_source_tag( + clip_annotation: data.ClipAnnotation, + source: AnnotatedDataset, +) -> data.ClipAnnotation: + """Insert the source tag into a ClipAnnotation. + + This function adds a tag to the `ClipAnnotation` object, indicating the + source from which it was loaded. The source information is derived from + the `recording` attribute of the `ClipAnnotation`. + + Parameters + ---------- + clip_annotation : data.ClipAnnotation + The `ClipAnnotation` object to which the source tag will be added. + + Returns + ------- + data.ClipAnnotation + The modified `ClipAnnotation` object with the source tag added. + """ + return clip_annotation.model_copy( + update=dict( + tags=[ + *clip_annotation.tags, + data.Tag( + term=data_source, + value=source.name, + ), + ] + ), + ) + + def load_dataset_from_config( path: data.PathLike, field: Optional[str] = None,