From f9e005ec8bfa2b96d9a5383b93358a192a57aedc Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Fri, 18 Apr 2025 13:32:50 +0100 Subject: [PATCH] Add tests for aoef loading --- batdetect2/data/__init__.py | 9 +- batdetect2/data/annotations.py | 33 - batdetect2/data/annotations/__init__.py | 57 +- batdetect2/data/annotations/aeof.py | 37 - batdetect2/data/annotations/aoef.py | 270 ++++++ batdetect2/data/annotations/types.py | 9 +- batdetect2/data/data.py | 37 - batdetect2/data/datasets.py | 207 +++++ batdetect2/data/types.py | 29 - pyproject.toml | 2 +- tests/conftest.py | 109 ++- tests/test_data/test_annotations/__init__.py | 0 tests/test_data/test_annotations/test_aoef.py | 784 ++++++++++++++++++ tests/test_train/test_augmentations.py | 24 +- 14 files changed, 1443 insertions(+), 164 deletions(-) delete mode 100644 batdetect2/data/annotations.py delete mode 100644 batdetect2/data/annotations/aeof.py create mode 100644 batdetect2/data/annotations/aoef.py delete mode 100644 batdetect2/data/data.py create mode 100644 batdetect2/data/datasets.py delete mode 100644 batdetect2/data/types.py create mode 100644 tests/test_data/test_annotations/__init__.py create mode 100644 tests/test_data/test_annotations/test_aoef.py diff --git a/batdetect2/data/__init__.py b/batdetect2/data/__init__.py index 104dfdb..8d10f08 100644 --- a/batdetect2/data/__init__.py +++ b/batdetect2/data/__init__.py @@ -5,15 +5,18 @@ from batdetect2.data.annotations import ( BatDetect2MergedAnnotations, load_annotated_dataset, ) -from batdetect2.data.data import load_dataset, load_dataset_from_config -from batdetect2.data.types import Dataset +from batdetect2.data.datasets import ( + DatasetConfig, + load_dataset, + load_dataset_from_config, +) __all__ = [ "AOEFAnnotations", "AnnotatedDataset", "BatDetect2FilesAnnotations", "BatDetect2MergedAnnotations", - "Dataset", + "DatasetConfig", "load_annotated_dataset", "load_dataset", "load_dataset_from_config", diff --git a/batdetect2/data/annotations.py 
b/batdetect2/data/annotations.py deleted file mode 100644 index a1e6ed2..0000000 --- a/batdetect2/data/annotations.py +++ /dev/null @@ -1,33 +0,0 @@ -from pathlib import Path -from typing import Literal, Union - -from batdetect2.configs import BaseConfig - -__all__ = [ - "AOEFAnnotationFile", - "AnnotationFormats", - "BatDetect2AnnotationFile", - "BatDetect2AnnotationFiles", -] - - -class BatDetect2AnnotationFiles(BaseConfig): - format: Literal["batdetect2"] = "batdetect2" - path: Path - - -class BatDetect2AnnotationFile(BaseConfig): - format: Literal["batdetect2_file"] = "batdetect2_file" - path: Path - - -class AOEFAnnotationFile(BaseConfig): - format: Literal["aoef"] = "aoef" - path: Path - - -AnnotationFormats = Union[ - BatDetect2AnnotationFiles, - BatDetect2AnnotationFile, - AOEFAnnotationFile, -] diff --git a/batdetect2/data/annotations/__init__.py b/batdetect2/data/annotations/__init__.py index 5193cc1..e642054 100644 --- a/batdetect2/data/annotations/__init__.py +++ b/batdetect2/data/annotations/__init__.py @@ -1,9 +1,24 @@ +"""Handles loading of annotation data from various formats. + +This module serves as the central dispatcher for parsing annotation data +associated with BatDetect2 datasets. Datasets can be composed of multiple +sources, each potentially using a different annotation format (e.g., the +standard AOEF/soundevent format, or legacy BatDetect2 formats). + +This module defines the `AnnotationFormats` type, which represents the union +of possible configuration models for these different formats (each identified +by a unique `format` field). The primary function, `load_annotated_dataset`, +inspects the configuration for a single data source and calls the appropriate +format-specific loading function to retrieve the annotations as a standard +`soundevent.data.AnnotationSet`. 
+""" + from pathlib import Path from typing import Optional, Union from soundevent import data -from batdetect2.data.annotations.aeof import ( +from batdetect2.data.annotations.aoef import ( AOEFAnnotations, load_aoef_annotated_dataset, ) @@ -32,12 +47,52 @@ AnnotationFormats = Union[ BatDetect2FilesAnnotations, AOEFAnnotations, ] +"""Type Alias representing all supported data source configurations. + +Each specific configuration model within this union (e.g., `AOEFAnnotations`, +`BatDetect2FilesAnnotations`) corresponds to a different annotation format +or storage structure. These models are typically discriminated by a `format` +field (e.g., `format="aoef"`, `format="batdetect2_files"`), allowing Pydantic +and functions like `load_annotated_dataset` to determine which format a given +source configuration represents. +""" def load_annotated_dataset( dataset: AnnotatedDataset, base_dir: Optional[Path] = None, ) -> data.AnnotationSet: + """Load annotations for a single data source based on its configuration. + + This function acts as a dispatcher. It inspects the type of the input + `source_config` object (which corresponds to a specific annotation format) + and calls the appropriate loading function (e.g., + `load_aoef_annotated_dataset` for `AOEFAnnotations`). + + Parameters + ---------- + source_config : AnnotationFormats + The configuration object for the data source, specifying its format + and necessary details (like paths). Must be an instance of one of the + types included in the `AnnotationFormats` union. + base_dir : Path, optional + An optional base directory path. If provided, relative paths within + the `source_config` might be resolved relative to this directory by + the underlying loading functions. Defaults to None. + + Returns + ------- + soundevent.data.AnnotationSet + An AnnotationSet containing the `ClipAnnotation` objects loaded and + parsed from the specified data source. 
+ + Raises + ------ + NotImplementedError + If the type of the `source_config` object does not match any of the + known format-specific loading functions implemented in the dispatch + logic. + """ if isinstance(dataset, AOEFAnnotations): return load_aoef_annotated_dataset(dataset, base_dir=base_dir) diff --git a/batdetect2/data/annotations/aeof.py b/batdetect2/data/annotations/aeof.py deleted file mode 100644 index e634a02..0000000 --- a/batdetect2/data/annotations/aeof.py +++ /dev/null @@ -1,37 +0,0 @@ -from pathlib import Path -from typing import Literal, Optional - -from soundevent import data, io - -from batdetect2.data.annotations.types import AnnotatedDataset - -__all__ = [ - "AOEFAnnotations", - "load_aoef_annotated_dataset", -] - - -class AOEFAnnotations(AnnotatedDataset): - format: Literal["aoef"] = "aoef" - annotations_path: Path - - -def load_aoef_annotated_dataset( - dataset: AOEFAnnotations, - base_dir: Optional[Path] = None, -) -> data.AnnotationSet: - audio_dir = dataset.audio_dir - path = dataset.annotations_path - - if base_dir: - audio_dir = base_dir / audio_dir - path = base_dir / path - - loaded = io.load(path, audio_dir=audio_dir) - - if not isinstance(loaded, (data.AnnotationSet, data.AnnotationProject)): - raise ValueError( - f"The AOEF file at {path} does not contain a set of annotations" - ) - - return loaded diff --git a/batdetect2/data/annotations/aoef.py b/batdetect2/data/annotations/aoef.py new file mode 100644 index 0000000..f57393b --- /dev/null +++ b/batdetect2/data/annotations/aoef.py @@ -0,0 +1,270 @@ +"""Loads annotation data specifically from the AOEF / soundevent format. + +This module provides the necessary configuration model and loading function +to handle data sources where annotations are stored in the standard format +used by the `soundevent` library (often as `.json` or `.aoef` files), +which includes outputs from annotation tools like Whombat. 
+ +It supports loading both simple `AnnotationSet` files and more complex +`AnnotationProject` files. For `AnnotationProject` files, it offers optional +filtering capabilities to select only annotations associated with tasks +that meet specific status criteria (e.g., completed, verified, without issues). +""" + +from pathlib import Path +from typing import Literal, Optional +from uuid import uuid5 + +from pydantic import Field +from soundevent import data, io + +from batdetect2.configs import BaseConfig +from batdetect2.data.annotations.types import AnnotatedDataset + +__all__ = [ + "AOEFAnnotations", + "load_aoef_annotated_dataset", + "AnnotationTaskFilter", +] + + +class AnnotationTaskFilter(BaseConfig): + """Configuration for filtering Annotation Tasks within an AnnotationProject. + + Specifies criteria based on task status badges to select relevant + annotations, typically used when loading data from annotation projects + that might contain work-in-progress. + + Attributes + ---------- + only_completed : bool, default=True + If True, only include annotations from tasks marked as 'completed'. + only_verified : bool, default=False + If True, only include annotations from tasks marked as 'verified'. + exclude_issues : bool, default=True + If True, exclude annotations from tasks marked as 'rejected' (indicating + issues). + """ + + only_completed: bool = True + only_verified: bool = False + exclude_issues: bool = True + + +class AOEFAnnotations(AnnotatedDataset): + """Configuration defining a data source stored in AOEF format. + + This model specifies how to load annotations from an AOEF (JSON file) file + compatible with the `soundevent` library. It inherits `name`, + `description`, and `audio_dir` from `AnnotatedDataset`. + + Attributes + ---------- + format : Literal["aoef"] + The fixed format identifier for this configuration type. 
+ annotations_path : Path + The file system path to the `.aoef` or `.json` file containing the + `AnnotationSet` or `AnnotationProject`. + filter : AnnotationTaskFilter, optional + Configuration for filtering tasks if the `annotations_path` points to + an `AnnotationProject`. If omitted, default filtering + (only completed, exclude issues, verification not required) is applied + to projects. Set explicitly to `None` in config (e.g., `filter: null`) + to disable filtering for projects entirely. + """ + + format: Literal["aoef"] = "aoef" + + annotations_path: Path + + filter: Optional[AnnotationTaskFilter] = Field( + default_factory=AnnotationTaskFilter + ) + + +def load_aoef_annotated_dataset( + dataset: AOEFAnnotations, + base_dir: Optional[Path] = None, +) -> data.AnnotationSet: + """Load annotations from an AnnotationSet or AnnotationProject file. + + Reads the file specified in the `dataset` configuration using + `soundevent.io.load`. If the loaded file contains an `AnnotationProject` + and filtering is enabled via `dataset.filter`, it applies the filter + criteria based on task status and returns a new `AnnotationSet` containing + only the selected annotations. If the file contains an `AnnotationSet`, + or if it's a project and filtering is disabled, the all annotations are + returned. + + Parameters + ---------- + dataset : AOEFAnnotations + The configuration object describing the AOEF data source, including + the path to the annotation file and optional filtering settings. + base_dir : Path, optional + An optional base directory. If provided, `dataset.annotations_path` + and `dataset.audio_dir` will be resolved relative to this + directory. Defaults to None. + + Returns + ------- + soundevent.data.AnnotationSet + An AnnotationSet containing the loaded (and potentially filtered) + `ClipAnnotation` objects. + + Raises + ------ + FileNotFoundError + If the specified `annotations_path` (after resolving `base_dir`) + does not exist. 
+ ValueError + If the loaded file does not contain a valid `AnnotationSet` or + `AnnotationProject`. + Exception + May re-raise errors from `soundevent.io.load` related to parsing + or file format issues. + + Notes + ----- + - The `soundevent` library handles parsing of `.json` or `.aoef` formats. + - If an `AnnotationProject` is loaded and `dataset.filter` is *not* None, + a *new* `AnnotationSet` instance is created containing only the filtered + clip annotations. + """ + audio_dir = dataset.audio_dir + path = dataset.annotations_path + + if base_dir: + audio_dir = base_dir / audio_dir + path = base_dir / path + + loaded = io.load(path, audio_dir=audio_dir) + + if not isinstance(loaded, (data.AnnotationSet, data.AnnotationProject)): + raise ValueError( + f"The file at {path} loaded successfully but does not " + "contain a soundevent AnnotationSet or AnnotationProject " + f"(loaded type: {type(loaded).__name__})." + ) + + if isinstance(loaded, data.AnnotationProject) and dataset.filter: + loaded = filter_ready_clips( + loaded, + only_completed=dataset.filter.only_completed, + only_verified=dataset.filter.only_verified, + exclude_issues=dataset.filter.exclude_issues, + ) + + return loaded + + +def select_task( + annotation_task: data.AnnotationTask, + only_completed: bool = True, + only_verified: bool = False, + exclude_issues: bool = True, +) -> bool: + """Check if an AnnotationTask meets specified status criteria. + + Evaluates the `status_badges` of the task against the filter flags. + + Parameters + ---------- + annotation_task : data.AnnotationTask + The annotation task to check. + only_completed : bool, default=True + Task must be marked 'completed' to pass. + only_verified : bool, default=False + Task must be marked 'verified' to pass. + exclude_issues : bool, default=True + Task must *not* be marked 'rejected' (have issues) to pass. + + Returns + ------- + bool + True if the task meets all active filter criteria, False otherwise. 
+ """ + has_issues = False + is_completed = False + is_verified = False + + for badge in annotation_task.status_badges: + if badge.state == data.AnnotationState.completed: + is_completed = True + continue + + if badge.state == data.AnnotationState.rejected: + has_issues = True + continue + + if badge.state == data.AnnotationState.verified: + is_verified = True + + if exclude_issues and has_issues: + return False + + if only_verified and not is_verified: + return False + + if only_completed and not is_completed: + return False + + return True + + +def filter_ready_clips( + annotation_project: data.AnnotationProject, + only_completed: bool = True, + only_verified: bool = False, + exclude_issues: bool = True, +) -> data.AnnotationSet: + """Filter AnnotationProject to create an AnnotationSet of 'ready' clips. + + Iterates through tasks in the project, selects tasks meeting the status + criteria using `select_task`, and creates a new `AnnotationSet` containing + only the `ClipAnnotation` objects associated with those selected tasks. + + Parameters + ---------- + annotation_project : data.AnnotationProject + The input annotation project. + only_completed : bool, default=True + Filter flag passed to `select_task`. + only_verified : bool, default=False + Filter flag passed to `select_task`. + exclude_issues : bool, default=True + Filter flag passed to `select_task`. + + Returns + ------- + data.AnnotationSet + A new annotation set containing only the clip annotations linked to + tasks that satisfied the filtering criteria. The returned set has a + deterministic UUID based on the project UUID and filter settings. 
+ """ + ready_clip_uuids = set() + + for annotation_task in annotation_project.tasks: + if not select_task( + annotation_task, + only_completed=only_completed, + only_verified=only_verified, + exclude_issues=exclude_issues, + ): + continue + + ready_clip_uuids.add(annotation_task.clip.uuid) + + return data.AnnotationSet( + uuid=uuid5( + annotation_project.uuid, + f"{only_completed}_{only_verified}_{exclude_issues}", + ), + name=annotation_project.name, + description=annotation_project.description, + clip_annotations=[ + annotation + for annotation in annotation_project.clip_annotations + if annotation.clip.uuid in ready_clip_uuids + ], + ) diff --git a/batdetect2/data/annotations/types.py b/batdetect2/data/annotations/types.py index 188eb3d..73adf83 100644 --- a/batdetect2/data/annotations/types.py +++ b/batdetect2/data/annotations/types.py @@ -17,11 +17,10 @@ class AnnotatedDataset(BaseConfig): Annotations associated with these recordings are defined by the `annotations` field, which supports various formats (e.g., AOEF files, - specific CSV - structures). - Crucially, file paths referenced within the annotation data *must* be - relative to the `audio_dir`. This ensures that the dataset definition - remains portable across different systems and base directories. + specific CSV structures). Crucially, file paths referenced within the + annotation data *must* be relative to the `audio_dir`. This ensures that + the dataset definition remains portable across different systems and base + directories. Attributes: name: A unique identifier for this data source. 
diff --git a/batdetect2/data/data.py b/batdetect2/data/data.py deleted file mode 100644 index c227991..0000000 --- a/batdetect2/data/data.py +++ /dev/null @@ -1,37 +0,0 @@ -from pathlib import Path -from typing import Optional - -from soundevent import data - -from batdetect2.configs import load_config -from batdetect2.data.annotations import load_annotated_dataset -from batdetect2.data.types import Dataset - -__all__ = [ - "load_dataset", - "load_dataset_from_config", -] - - -def load_dataset( - dataset: Dataset, - base_dir: Optional[Path] = None, -) -> data.AnnotationSet: - clip_annotations = [] - for source in dataset.sources: - annotated_source = load_annotated_dataset(source, base_dir=base_dir) - clip_annotations.extend(annotated_source.clip_annotations) - return data.AnnotationSet(clip_annotations=clip_annotations) - - -def load_dataset_from_config( - path: data.PathLike, - field: Optional[str] = None, - base_dir: Optional[Path] = None, -): - config = load_config( - path=path, - schema=Dataset, - field=field, - ) - return load_dataset(config, base_dir=base_dir) diff --git a/batdetect2/data/datasets.py b/batdetect2/data/datasets.py new file mode 100644 index 0000000..f8d94be --- /dev/null +++ b/batdetect2/data/datasets.py @@ -0,0 +1,207 @@ +"""Defines the overall dataset structure and provides loading/saving utilities. + +This module focuses on defining what constitutes a BatDetect2 dataset, +potentially composed of multiple distinct data sources with varying annotation +formats. It provides mechanisms to load the annotation metadata from these +sources into a unified representation. + +The core components are: +- `DatasetConfig`: A configuration class (typically loaded from YAML) that + describes the dataset's name, description, and constituent sources. +- `Dataset`: A type alias representing the loaded dataset as a list of + `soundevent.data.ClipAnnotation` objects. Note that this implies all + annotation metadata is loaded into memory. 
+- Loading functions (`load_dataset`, `load_dataset_from_config`): To parse + a `DatasetConfig` and load the corresponding annotation metadata. +- Saving function (`save_dataset`): To save a loaded list of annotations + into a standard `soundevent` format. + +""" + +from pathlib import Path +from typing import Annotated, List, Optional + +from pydantic import Field +from soundevent import data, io + +from batdetect2.configs import BaseConfig, load_config +from batdetect2.data.annotations import ( + AnnotationFormats, + load_annotated_dataset, +) + +__all__ = [ + "load_dataset", + "load_dataset_from_config", + "save_dataset", + "Dataset", + "DatasetConfig", +] + + +Dataset = List[data.ClipAnnotation] +"""Type alias for a loaded dataset representation. + +Represents an entire dataset *after loading* as a flat Python list containing +all `soundevent.data.ClipAnnotation` objects gathered from all configured data +sources. +""" + + +class DatasetConfig(BaseConfig): + """Configuration model defining the structure of a BatDetect2 dataset. + + This class is typically loaded from a YAML file and describes the components + of the dataset, including metadata and a list of data sources. + + Attributes + ---------- + name : str + A descriptive name for the dataset (e.g., "UK_Bats_Project_2024"). + description : str + A longer description of the dataset's contents, origin, purpose, etc. + sources : List[AnnotationFormats] + A list defining the different data sources contributing to this + dataset. Each item in the list must conform to one of the Pydantic + models defined in the `AnnotationFormats` type union. The specific + model used for each source is determined by the mandatory `format` + field within the source's configuration, allowing BatDetect2 to use the + correct parser for different annotation styles. 
+ """ + + name: str + description: str + sources: List[ + Annotated[AnnotationFormats, Field(..., discriminator="format")] + ] + + +def load_dataset( + dataset: DatasetConfig, + base_dir: Optional[Path] = None, +) -> Dataset: + """Load all clip annotations from the sources defined in a DatasetConfig. + + Iterates through each data source specified in the `dataset_config`, + delegates the loading and parsing of that source's annotations to + `batdetect2.data.annotations.load_annotated_dataset` (which handles + different data formats), and aggregates all resulting `ClipAnnotation` + objects into a single flat list. + + Parameters + ---------- + dataset_config : DatasetConfig + The configuration object describing the dataset and its sources. + base_dir : Path, optional + An optional base directory path. If provided, relative paths for + metadata files or data directories within the `dataset_config`'s + sources might be resolved relative to this directory. Defaults to None. + + Returns + ------- + Dataset (List[data.ClipAnnotation]) + A flat list containing all loaded `ClipAnnotation` metadata objects + from all specified sources. + + Raises + ------ + Exception + Can raise various exceptions during the delegated loading process + (`load_annotated_dataset`) if files are not found, cannot be parsed + according to the specified format, or other I/O errors occur. + """ + clip_annotations = [] + for source in dataset.sources: + annotated_source = load_annotated_dataset(source, base_dir=base_dir) + clip_annotations.extend(annotated_source.clip_annotations) + return clip_annotations + + +def load_dataset_from_config( + path: data.PathLike, + field: Optional[str] = None, + base_dir: Optional[Path] = None, +) -> Dataset: + """Load dataset annotation metadata from a configuration file. 
+ + This is a convenience function that first loads the `DatasetConfig` from + the specified file path and optional nested field, and then calls + `load_dataset` to load all corresponding `ClipAnnotation` objects. + + Parameters + ---------- + path : data.PathLike + Path to the configuration file (e.g., YAML). + field : str, optional + Dot-separated path to a nested section within the file containing the + dataset configuration (e.g., "data.training_set"). If None, the + entire file content is assumed to be the `DatasetConfig`. + base_dir : Path, optional + An optional base directory path to resolve relative paths within the + configuration sources. Passed to `load_dataset`. Defaults to None. + + Returns + ------- + Dataset (List[data.ClipAnnotation]) + A flat list containing all loaded `ClipAnnotation` metadata objects. + + Raises + ------ + FileNotFoundError + If the config file `path` does not exist. + yaml.YAMLError, pydantic.ValidationError, KeyError, TypeError + If the configuration file is invalid, cannot be parsed, or does not + match the `DatasetConfig` schema. + Exception + Can raise exceptions from `load_dataset` if loading data from sources + fails. + """ + config = load_config( + path=path, + schema=DatasetConfig, + field=field, + ) + return load_dataset(config, base_dir=base_dir) + + +def save_dataset( + dataset: Dataset, + path: data.PathLike, + name: Optional[str] = None, + description: Optional[str] = None, + audio_dir: Optional[Path] = None, +) -> None: + """Save a loaded dataset (list of ClipAnnotations) to a file. + + Wraps the provided list of `ClipAnnotation` objects into a + `soundevent.data.AnnotationSet` and saves it using `soundevent.io.save`. + This saves the aggregated annotation metadata in the standard soundevent + format. + + Note: This function saves the *loaded annotation data*, not the original + `DatasetConfig` structure that defined how the data was assembled from + various sources. 
+ + Parameters + ---------- + dataset : Dataset (List[data.ClipAnnotation]) + The list of clip annotations to save (typically the result of + `load_dataset` or a split thereof). + path : data.PathLike + The output file path (e.g., 'train_annotations.json', + 'val_annotations.json'). The format is determined by `soundevent.io`. + name : str, optional + An optional name to assign to the saved `AnnotationSet`. + description : str, optional + An optional description to assign to the saved `AnnotationSet`. + audio_dir : Path, optional + Passed to `soundevent.io.save`. May be used to relativize audio file + paths within the saved annotations if applicable to the save format. + """ + + annotation_set = data.AnnotationSet( + name=name, + description=description, + clip_annotations=dataset, + ) + io.save(annotation_set, path, audio_dir=audio_dir) diff --git a/batdetect2/data/types.py b/batdetect2/data/types.py deleted file mode 100644 index 237f215..0000000 --- a/batdetect2/data/types.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Annotated, List - -from pydantic import Field - -from batdetect2.configs import BaseConfig -from batdetect2.data.annotations import AnnotationFormats - - -class Dataset(BaseConfig): - """Represents a collection of one or more DatasetSources. - - In the context of batdetect2, a Dataset aggregates multiple `DatasetSource` - instances. It serves as the primary unit for defining data splits, - typically used for model training, validation, or testing phases. - - Attributes: - name: A descriptive name for the overall dataset - (e.g., "UK Training Set"). - description: A detailed explanation of the dataset's purpose, - composition, how it was assembled, or any specific characteristics. - sources: A list containing the `DatasetSource` objects included in this - dataset. 
- """ - - name: str - description: str - sources: List[ - Annotated[AnnotationFormats, Field(..., discriminator="format")] - ] diff --git a/pyproject.toml b/pyproject.toml index 7685559..821720a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "torch>=1.13.1,<2.5.0", "torchaudio>=1.13.1,<2.5.0", "torchvision>=0.14.0", - "soundevent[audio,geometry,plot]>=2.3", + "soundevent[audio,geometry,plot]>=2.4.1", "click>=8.1.7", "netcdf4>=1.6.5", "tqdm>=4.66.2", diff --git a/tests/conftest.py b/tests/conftest.py index 99c9877..05fdcbb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -86,8 +86,8 @@ def wav_factory(tmp_path: Path): @pytest.fixture -def recording_factory(wav_factory: Callable[..., Path]): - def _recording_factory( +def create_recording(wav_factory: Callable[..., Path]): + def factory( tags: Optional[list[data.Tag]] = None, path: Optional[Path] = None, recording_id: Optional[uuid.UUID] = None, @@ -96,7 +96,8 @@ def recording_factory(wav_factory: Callable[..., Path]): samplerate: int = 256_000, time_expansion: float = 1, ) -> data.Recording: - path = path or wav_factory( + path = wav_factory( + path=path, duration=duration, channels=channels, samplerate=samplerate, @@ -108,14 +109,30 @@ def recording_factory(wav_factory: Callable[..., Path]): tags=tags or [], ) - return _recording_factory + return factory @pytest.fixture def recording( - recording_factory: Callable[..., data.Recording], + create_recording: Callable[..., data.Recording], ) -> data.Recording: - return recording_factory() + return create_recording() + + +@pytest.fixture +def create_clip(): + def factory( + recording: data.Recording, + start_time: float = 0, + end_time: float = 0.5, + ) -> data.Clip: + return data.Clip( + recording=recording, + start_time=start_time, + end_time=end_time, + ) + + return factory @pytest.fixture @@ -123,6 +140,22 @@ def clip(recording: data.Recording) -> data.Clip: return data.Clip(recording=recording, start_time=0, 
end_time=0.5) +@pytest.fixture +def create_sound_event(): + def factory( + recording: data.Recording, + coords: Optional[List[float]] = None, + ) -> data.SoundEvent: + coords = coords or [0.2, 60_000, 0.3, 70_000] + + return data.SoundEvent( + geometry=data.BoundingBox(coordinates=coords), + recording=recording, + ) + + return factory + + @pytest.fixture def sound_event(recording: data.Recording) -> data.SoundEvent: return data.SoundEvent( @@ -131,6 +164,20 @@ def sound_event(recording: data.Recording) -> data.SoundEvent: ) +@pytest.fixture +def create_sound_event_annotation(): + def factory( + sound_event: data.SoundEvent, + tags: Optional[List[data.Tag]] = None, + ) -> data.SoundEventAnnotation: + return data.SoundEventAnnotation( + sound_event=sound_event, + tags=tags or [], + ) + + return factory + + @pytest.fixture def echolocation_call(recording: data.Recording) -> data.SoundEventAnnotation: return data.SoundEventAnnotation( @@ -181,6 +228,22 @@ def non_relevant_sound_event( ) +@pytest.fixture +def create_clip_annotation(): + def factory( + clip: data.Clip, + clip_tags: Optional[List[data.Tag]] = None, + sound_events: Optional[List[data.SoundEventAnnotation]] = None, + ) -> data.ClipAnnotation: + return data.ClipAnnotation( + clip=clip, + tags=clip_tags or [], + sound_events=sound_events or [], + ) + + return factory + + @pytest.fixture def clip_annotation( clip: data.Clip, @@ -196,3 +259,37 @@ def clip_annotation( non_relevant_sound_event, ], ) + + +@pytest.fixture +def create_annotation_set(): + def factory( + name: str = "test", + description: str = "Test annotation set", + annotations: Optional[List[data.ClipAnnotation]] = None, + ) -> data.AnnotationSet: + return data.AnnotationSet( + name=name, + description=description, + clip_annotations=annotations or [], + ) + + return factory + + +@pytest.fixture +def create_annotation_project(): + def factory( + name: str = "test_project", + description: str = "Test Annotation Project", + tasks: 
Optional[List[data.AnnotationTask]] = None, + annotations: Optional[List[data.ClipAnnotation]] = None, + ) -> data.AnnotationProject: + return data.AnnotationProject( + name=name, + description=description, + tasks=tasks or [], + clip_annotations=annotations or [], + ) + + return factory diff --git a/tests/test_data/test_annotations/__init__.py b/tests/test_data/test_annotations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_data/test_annotations/test_aoef.py b/tests/test_data/test_annotations/test_aoef.py new file mode 100644 index 0000000..0bedfd5 --- /dev/null +++ b/tests/test_data/test_annotations/test_aoef.py @@ -0,0 +1,784 @@ +import uuid +from pathlib import Path +from typing import Callable, Optional, Sequence + +import pytest +from pydantic import ValidationError +from soundevent import data, io +from soundevent.data.annotation_tasks import AnnotationState + +from batdetect2.data.annotations import aoef + + +@pytest.fixture +def base_dir(tmp_path: Path) -> Path: + path = tmp_path / "base_dir" + path.mkdir(parents=True, exist_ok=True) + return path + + +@pytest.fixture +def audio_dir(base_dir: Path) -> Path: + path = base_dir / "audio" + path.mkdir(parents=True, exist_ok=True) + return path + + +@pytest.fixture +def anns_dir(base_dir: Path) -> Path: + path = base_dir / "annotations" + path.mkdir(parents=True, exist_ok=True) + return path + + +def create_task( + clip: data.Clip, + badges: list[data.StatusBadge], + task_id: Optional[uuid.UUID] = None, +) -> data.AnnotationTask: + """Creates a simple AnnotationTask for testing.""" + return data.AnnotationTask( + uuid=task_id or uuid.uuid4(), + clip=clip, + status_badges=badges, + ) + + +def test_annotation_task_filter_defaults(): + """Test default values of AnnotationTaskFilter.""" + f = aoef.AnnotationTaskFilter() + assert f.only_completed is True + assert f.only_verified is False + assert f.exclude_issues is True + + +def test_annotation_task_filter_initialization(): + 
"""Test initialization of AnnotationTaskFilter with non-default values.""" + f = aoef.AnnotationTaskFilter( + only_completed=False, + only_verified=True, + exclude_issues=False, + ) + assert f.only_completed is False + assert f.only_verified is True + assert f.exclude_issues is False + + +def test_aoef_annotations_defaults( + audio_dir: Path, + anns_dir: Path, +): + """Test default values of AOEFAnnotations.""" + annotations_path = anns_dir / "test.aoef" + config = aoef.AOEFAnnotations( + name="default_name", + audio_dir=audio_dir, + annotations_path=annotations_path, + ) + assert config.format == "aoef" + assert config.annotations_path == annotations_path + assert config.audio_dir == audio_dir + assert isinstance(config.filter, aoef.AnnotationTaskFilter) + assert config.filter.only_completed is True + assert config.filter.only_verified is False + assert config.filter.exclude_issues is True + + +def test_aoef_annotations_initialization(tmp_path): + """Test initialization of AOEFAnnotations with specific values.""" + annotations_path = tmp_path / "custom.json" + audio_dir = Path("audio/files") + custom_filter = aoef.AnnotationTaskFilter( + only_completed=False, only_verified=True + ) + config = aoef.AOEFAnnotations( + name="custom_name", + description="custom_desc", + audio_dir=audio_dir, + annotations_path=annotations_path, + filter=custom_filter, + ) + assert config.name == "custom_name" + assert config.description == "custom_desc" + assert config.format == "aoef" + assert config.audio_dir == audio_dir + assert config.annotations_path == annotations_path + assert config.filter is custom_filter + + +def test_aoef_annotations_initialization_no_filter(tmp_path): + """Test initialization of AOEFAnnotations with filter=None.""" + annotations_path = tmp_path / "no_filter.aoef" + audio_dir = tmp_path / "audio" + config = aoef.AOEFAnnotations( + name="no_filter_name", + description="no_filter_desc", + audio_dir=audio_dir, + annotations_path=annotations_path, + 
filter=None, + ) + assert config.format == "aoef" + assert config.annotations_path == annotations_path + assert config.filter is None + + +def test_aoef_annotations_validation_error(tmp_path): + """Test Pydantic validation for missing required fields.""" + with pytest.raises(ValidationError, match="annotations_path"): + aoef.AOEFAnnotations( # type: ignore + name="test_name", + audio_dir=tmp_path, + ) + with pytest.raises(ValidationError, match="name"): + aoef.AOEFAnnotations( # type: ignore + annotations_path=tmp_path / "dummy.aoef", + audio_dir=tmp_path, + ) + + +@pytest.mark.parametrize( + "badges, only_completed, only_verified, exclude_issues, expected", + [ + ([], True, False, True, False), # No badges -> not completed + ( + [data.StatusBadge(state=AnnotationState.completed)], + True, + False, + True, + True, + ), + ( + [data.StatusBadge(state=AnnotationState.verified)], + True, + False, + True, + False, + ), # Not completed + ( + [data.StatusBadge(state=AnnotationState.rejected)], + True, + False, + True, + False, + ), # Has issues + ( + [ + data.StatusBadge(state=AnnotationState.completed), + data.StatusBadge(state=AnnotationState.rejected), + ], + True, + False, + True, + False, + ), # Completed but has issues + ( + [ + data.StatusBadge(state=AnnotationState.completed), + data.StatusBadge(state=AnnotationState.verified), + ], + True, + False, + True, + True, + ), # Completed, verified doesn't matter + # Verified only (completed=F, verified=T, exclude_issues=T) + ( + [data.StatusBadge(state=AnnotationState.verified)], + False, + True, + True, + True, + ), + ( + [data.StatusBadge(state=AnnotationState.completed)], + False, + True, + True, + False, + ), # Not verified + ( + [ + data.StatusBadge(state=AnnotationState.verified), + data.StatusBadge(state=AnnotationState.rejected), + ], + False, + True, + True, + False, + ), # Verified but has issues + # Completed AND Verified (completed=T, verified=T, exclude_issues=T) + ( + [ + 
data.StatusBadge(state=AnnotationState.completed), + data.StatusBadge(state=AnnotationState.verified), + ], + True, + True, + True, + True, + ), + ( + [data.StatusBadge(state=AnnotationState.completed)], + True, + True, + True, + False, + ), # Not verified + ( + [data.StatusBadge(state=AnnotationState.verified)], + True, + True, + True, + False, + ), # Not completed + # Include Issues (completed=T, verified=F, exclude_issues=F) + ( + [ + data.StatusBadge(state=AnnotationState.completed), + data.StatusBadge(state=AnnotationState.rejected), + ], + True, + False, + False, + True, + ), # Completed, issues allowed + ( + [data.StatusBadge(state=AnnotationState.rejected)], + True, + False, + False, + False, + ), # Has issues, but not completed + # No filters (completed=F, verified=F, exclude_issues=F) + ([], False, False, False, True), + ( + [data.StatusBadge(state=AnnotationState.rejected)], + False, + False, + False, + True, + ), + ( + [data.StatusBadge(state=AnnotationState.completed)], + False, + False, + False, + True, + ), + ( + [data.StatusBadge(state=AnnotationState.verified)], + False, + False, + False, + True, + ), + ], +) +def test_select_task( + badges: Sequence[data.StatusBadge], + only_completed: bool, + only_verified: bool, + exclude_issues: bool, + expected: bool, + create_recording: Callable[..., data.Recording], + create_clip: Callable[..., data.Clip], +): + """Test select_task logic with various badge and filter combinations.""" + rec = create_recording() + clip = create_clip(rec) + task = create_task(clip, badges=list(badges)) + result = aoef.select_task( + task, + only_completed=only_completed, + only_verified=only_verified, + exclude_issues=exclude_issues, + ) + assert result == expected + + +def test_filter_ready_clips_default( + tmp_path: Path, + create_recording: Callable[..., data.Recording], + create_clip: Callable[..., data.Clip], + create_clip_annotation: Callable[..., data.ClipAnnotation], + create_annotation_project: Callable[..., 
data.AnnotationProject], +): + """Test filter_ready_clips with default filtering.""" + rec = create_recording(path=tmp_path / "rec.wav") + clip_completed = create_clip(rec, 0, 1) + clip_verified = create_clip(rec, 1, 2) + clip_rejected = create_clip(rec, 2, 3) + clip_completed_rejected = create_clip(rec, 3, 4) + clip_no_badges = create_clip(rec, 4, 5) + + task_completed = create_task( + clip_completed, [data.StatusBadge(state=AnnotationState.completed)] + ) + task_verified = create_task( + clip_verified, [data.StatusBadge(state=AnnotationState.verified)] + ) + task_rejected = create_task( + clip_rejected, [data.StatusBadge(state=AnnotationState.rejected)] + ) + task_completed_rejected = create_task( + clip_completed_rejected, + [ + data.StatusBadge(state=AnnotationState.completed), + data.StatusBadge(state=AnnotationState.rejected), + ], + ) + task_no_badges = create_task(clip_no_badges, []) + + ann_completed = create_clip_annotation(clip_completed) + ann_verified = create_clip_annotation(clip_verified) + ann_rejected = create_clip_annotation(clip_rejected) + ann_completed_rejected = create_clip_annotation(clip_completed_rejected) + ann_no_badges = create_clip_annotation(clip_no_badges) + + project = create_annotation_project( + name="FilterTestProject", + description="Project for testing filters", + tasks=[ + task_completed, + task_verified, + task_rejected, + task_completed_rejected, + task_no_badges, + ], + annotations=[ + ann_completed, + ann_verified, + ann_rejected, + ann_completed_rejected, + ann_no_badges, + ], + ) + + filtered_set = aoef.filter_ready_clips(project) + + assert isinstance(filtered_set, data.AnnotationSet) + assert filtered_set.name == project.name + assert filtered_set.description == project.description + assert len(filtered_set.clip_annotations) == 1 + assert filtered_set.clip_annotations[0].clip.uuid == clip_completed.uuid + + expected_uuid = uuid.uuid5(project.uuid, f"{True}_{False}_{True}") + assert filtered_set.uuid == expected_uuid + + 
+def test_filter_ready_clips_custom_filter( + tmp_path: Path, + create_recording: Callable[..., data.Recording], + create_clip: Callable[..., data.Clip], + create_clip_annotation: Callable[..., data.ClipAnnotation], + create_annotation_project: Callable[..., data.AnnotationProject], +): + """Test filter_ready_clips with custom filtering (verified=T, issues=F).""" + rec = create_recording(path=tmp_path / "rec.wav") + clip_completed = create_clip(rec, 0, 1) + clip_verified = create_clip(rec, 1, 2) + clip_rejected = create_clip(rec, 2, 3) + clip_completed_verified = create_clip(rec, 3, 4) + clip_verified_rejected = create_clip(rec, 4, 5) + + task_completed = create_task( + clip_completed, [data.StatusBadge(state=AnnotationState.completed)] + ) + task_verified = create_task( + clip_verified, [data.StatusBadge(state=AnnotationState.verified)] + ) + task_rejected = create_task( + clip_rejected, [data.StatusBadge(state=AnnotationState.rejected)] + ) + task_completed_verified = create_task( + clip_completed_verified, + [ + data.StatusBadge(state=AnnotationState.completed), + data.StatusBadge(state=AnnotationState.verified), + ], + ) + task_verified_rejected = create_task( + clip_verified_rejected, + [ + data.StatusBadge(state=AnnotationState.verified), + data.StatusBadge(state=AnnotationState.rejected), + ], + ) + + ann_completed = create_clip_annotation(clip_completed) + ann_verified = create_clip_annotation(clip_verified) + ann_rejected = create_clip_annotation(clip_rejected) + ann_completed_verified = create_clip_annotation(clip_completed_verified) + ann_verified_rejected = create_clip_annotation(clip_verified_rejected) + + project = create_annotation_project( + tasks=[ + task_completed, + task_verified, + task_rejected, + task_completed_verified, + task_verified_rejected, + ], + annotations=[ + ann_completed, + ann_verified, + ann_rejected, + ann_completed_verified, + ann_verified_rejected, + ], + ) + + filtered_set = aoef.filter_ready_clips( + project, 
only_completed=False, only_verified=True, exclude_issues=False + ) + + assert len(filtered_set.clip_annotations) == 3 + filtered_clip_uuids = { + ann.clip.uuid for ann in filtered_set.clip_annotations + } + assert clip_verified.uuid in filtered_clip_uuids + assert clip_completed_verified.uuid in filtered_clip_uuids + assert clip_verified_rejected.uuid in filtered_clip_uuids + + expected_uuid = uuid.uuid5(project.uuid, f"{False}_{True}_{False}") + assert filtered_set.uuid == expected_uuid + + +def test_filter_ready_clips_no_filters( + tmp_path: Path, + create_recording: Callable[..., data.Recording], + create_clip: Callable[..., data.Clip], + create_clip_annotation: Callable[..., data.ClipAnnotation], + create_annotation_project: Callable[..., data.AnnotationProject], +): + """Test filter_ready_clips with all filters disabled.""" + rec = create_recording(path=tmp_path / "rec.wav") + clip1 = create_clip(rec, 0, 1) + clip2 = create_clip(rec, 1, 2) + + task1 = create_task( + clip1, [data.StatusBadge(state=AnnotationState.rejected)] + ) + task2 = create_task(clip2, []) + ann1 = create_clip_annotation(clip1) + ann2 = create_clip_annotation(clip2) + + project = create_annotation_project( + tasks=[task1, task2], annotations=[ann1, ann2] + ) + + filtered_set = aoef.filter_ready_clips( + project, + only_completed=False, + only_verified=False, + exclude_issues=False, + ) + + assert len(filtered_set.clip_annotations) == 2 + filtered_clip_uuids = { + ann.clip.uuid for ann in filtered_set.clip_annotations + } + assert clip1.uuid in filtered_clip_uuids + assert clip2.uuid in filtered_clip_uuids + + expected_uuid = uuid.uuid5(project.uuid, f"{False}_{False}_{False}") + assert filtered_set.uuid == expected_uuid + + +def test_filter_ready_clips_empty_project( + create_annotation_project: Callable[..., data.AnnotationProject], +): + """Test filter_ready_clips with an empty project.""" + project = create_annotation_project(tasks=[], annotations=[]) + filtered_set = 
aoef.filter_ready_clips(project) + assert len(filtered_set.clip_annotations) == 0 + assert filtered_set.name == project.name + assert filtered_set.description == project.description + + +def test_filter_ready_clips_no_matching_tasks( + tmp_path: Path, + create_recording: Callable[..., data.Recording], + create_clip: Callable[..., data.Clip], + create_clip_annotation: Callable[..., data.ClipAnnotation], + create_annotation_project: Callable[..., data.AnnotationProject], +): + """Test filter_ready_clips when no tasks match the criteria.""" + rec = create_recording(path=tmp_path / "rec.wav") + clip_rejected = create_clip(rec, 0, 1) + + task_rejected = create_task( + clip_rejected, [data.StatusBadge(state=AnnotationState.rejected)] + ) + ann_rejected = create_clip_annotation(clip_rejected) + + project = create_annotation_project( + tasks=[task_rejected], annotations=[ann_rejected] + ) + + filtered_set = aoef.filter_ready_clips(project) + assert len(filtered_set.clip_annotations) == 0 + + +def test_load_aoef_annotated_dataset_set( + tmp_path: Path, + create_recording: Callable[..., data.Recording], + create_clip: Callable[..., data.Clip], + create_clip_annotation: Callable[..., data.ClipAnnotation], + create_annotation_set: Callable[..., data.AnnotationSet], +): + """Test loading a standard AnnotationSet file.""" + rec_path = tmp_path / "audio" / "rec1.wav" + rec_path.parent.mkdir() + rec = create_recording(path=rec_path) + clip = create_clip(rec) + ann = create_clip_annotation(clip) + original_set = create_annotation_set(annotations=[ann]) + + annotations_file = tmp_path / "set.json" + io.save(original_set, annotations_file) + + config = aoef.AOEFAnnotations( + name="test_set_load", + annotations_path=annotations_file, + audio_dir=rec_path.parent, + ) + + loaded_set = aoef.load_aoef_annotated_dataset(config) + + assert isinstance(loaded_set, data.AnnotationSet) + + assert loaded_set.name == original_set.name + assert len(loaded_set.clip_annotations) == len( + 
original_set.clip_annotations + ) + assert ( + loaded_set.clip_annotations[0].clip.uuid + == original_set.clip_annotations[0].clip.uuid + ) + assert ( + loaded_set.clip_annotations[0].clip.recording.path + == rec_path.resolve() + ) + + +def test_load_aoef_annotated_dataset_project_with_filter( + tmp_path: Path, + create_recording: Callable[..., data.Recording], + create_clip: Callable[..., data.Clip], + create_clip_annotation: Callable[..., data.ClipAnnotation], + create_annotation_project: Callable[..., data.AnnotationProject], +): + """Test loading an AnnotationProject file with filtering enabled.""" + rec_path = tmp_path / "audio" / "rec.wav" + rec_path.parent.mkdir() + rec = create_recording(path=rec_path) + + clip_completed = create_clip(rec, 0, 1) + clip_rejected = create_clip(rec, 1, 2) + + task_completed = create_task( + clip_completed, [data.StatusBadge(state=AnnotationState.completed)] + ) + task_rejected = create_task( + clip_rejected, [data.StatusBadge(state=AnnotationState.rejected)] + ) + + ann_completed = create_clip_annotation(clip_completed) + ann_rejected = create_clip_annotation(clip_rejected) + + project = create_annotation_project( + name="ProjectToFilter", + tasks=[task_completed, task_rejected], + annotations=[ann_completed, ann_rejected], + ) + + annotations_file = tmp_path / "project.json" + io.save(project, annotations_file) + + config = aoef.AOEFAnnotations( + name="test_project_filter_load", + annotations_path=annotations_file, + audio_dir=rec_path.parent, + ) + + loaded_data = aoef.load_aoef_annotated_dataset(config) + + assert isinstance(loaded_data, data.AnnotationSet) + assert loaded_data.name == project.name + assert len(loaded_data.clip_annotations) == 1 + assert loaded_data.clip_annotations[0].clip.uuid == clip_completed.uuid + assert ( + loaded_data.clip_annotations[0].clip.recording.path + == rec_path.resolve() + ) + + +def test_load_aoef_annotated_dataset_project_no_filter( + tmp_path: Path, + create_recording: Callable[..., 
data.Recording], + create_clip: Callable[..., data.Clip], + create_clip_annotation: Callable[..., data.ClipAnnotation], + create_annotation_project: Callable[..., data.AnnotationProject], +): + """Test loading an AnnotationProject file with filtering disabled.""" + rec_path = tmp_path / "audio" / "rec.wav" + rec_path.parent.mkdir() + rec = create_recording(path=rec_path) + clip1 = create_clip(rec, 0, 1) + clip2 = create_clip(rec, 1, 2) + + task1 = create_task( + clip1, [data.StatusBadge(state=AnnotationState.completed)] + ) + task2 = create_task( + clip2, [data.StatusBadge(state=AnnotationState.rejected)] + ) + ann1 = create_clip_annotation(clip1) + ann2 = create_clip_annotation(clip2) + + original_project = create_annotation_project( + tasks=[task1, task2], annotations=[ann1, ann2] + ) + + annotations_file = tmp_path / "project_nofilter.json" + io.save(original_project, annotations_file) + + config = aoef.AOEFAnnotations( + name="test_project_nofilter_load", + annotations_path=annotations_file, + audio_dir=rec_path.parent, + filter=None, + ) + + loaded_data = aoef.load_aoef_annotated_dataset(config) + + assert isinstance(loaded_data, data.AnnotationProject) + assert loaded_data.uuid == original_project.uuid + assert len(loaded_data.clip_annotations) == 2 + assert ( + loaded_data.clip_annotations[0].clip.recording.path + == rec_path.resolve() + ) + assert ( + loaded_data.clip_annotations[1].clip.recording.path + == rec_path.resolve() + ) + + +def test_load_aoef_annotated_dataset_base_dir( + tmp_path: Path, + create_recording: Callable[..., data.Recording], + create_clip: Callable[..., data.Clip], + create_clip_annotation: Callable[..., data.ClipAnnotation], + create_annotation_project: Callable[..., data.AnnotationProject], +): + """Test loading with a base_dir specified.""" + base = tmp_path / "basedir" + base.mkdir() + audio_rel = Path("audio") + ann_rel = Path("annotations/project.json") + + abs_audio_dir = base / audio_rel + abs_ann_path = base / ann_rel + 
abs_audio_dir.mkdir(parents=True)
+    abs_ann_path.parent.mkdir(parents=True)
+
+    rec = create_recording(path=abs_audio_dir / "rec.wav")
+    rec_path = rec.path
+
+    clip = create_clip(rec)
+
+    task = create_task(
+        clip, [data.StatusBadge(state=AnnotationState.completed)]
+    )
+    ann = create_clip_annotation(clip)
+    project = create_annotation_project(tasks=[task], annotations=[ann])
+    io.save(project, abs_ann_path)
+
+    config = aoef.AOEFAnnotations(
+        name="test_base_dir_load",
+        annotations_path=ann_rel,
+        audio_dir=audio_rel,
+        filter=aoef.AnnotationTaskFilter(),
+    )
+
+    loaded_set = aoef.load_aoef_annotated_dataset(config, base_dir=base)
+
+    assert isinstance(loaded_set, data.AnnotationSet)
+    assert len(loaded_set.clip_annotations) == 1
+
+    assert (
+        loaded_set.clip_annotations[0].clip.recording.path
+        == rec_path.resolve()
+    )
+
+
+def test_load_aoef_annotated_dataset_file_not_found(tmp_path):
+    """Test FileNotFoundError when annotation file doesn't exist."""
+    config = aoef.AOEFAnnotations(
+        name="test_not_found",
+        annotations_path=tmp_path / "nonexistent.aoef",
+        audio_dir=tmp_path,
+    )
+    with pytest.raises(FileNotFoundError):
+        aoef.load_aoef_annotated_dataset(config)
+
+
+def test_load_aoef_annotated_dataset_file_not_found_with_base_dir(tmp_path):
+    """Test FileNotFoundError with base_dir."""
+    base = tmp_path / "base"
+    base.mkdir()
+    config = aoef.AOEFAnnotations(
+        name="test_not_found_base",
+        annotations_path=Path("nonexistent.aoef"),
+        audio_dir=Path("audio"),
+    )
+    with pytest.raises(FileNotFoundError):
+        aoef.load_aoef_annotated_dataset(config, base_dir=base)
+
+
+def test_load_aoef_annotated_dataset_invalid_content(tmp_path):
+    """Test ValidationError when file contains invalid JSON or non-soundevent data."""
+    invalid_file = tmp_path / "invalid.json"
+    invalid_file.write_text("{invalid json")
+
+    config = aoef.AOEFAnnotations(
+        name="test_invalid_content",
+        annotations_path=invalid_file,
+        audio_dir=tmp_path,
+    )
+    with pytest.raises(ValidationError):
+ aoef.load_aoef_annotated_dataset(config) + + +def test_load_aoef_annotated_dataset_wrong_object_type( + tmp_path: Path, + create_recording: Callable[..., data.Recording], +): + """Test ValueError when file contains correct soundevent obj but wrong type.""" + rec_path = tmp_path / "audio" / "rec.wav" + rec_path.parent.mkdir() + rec = create_recording(path=rec_path) + dataset = data.Dataset( + name="test_wrong_type", + description="Test for wrong type", + recordings=[rec], + ) + + wrong_type_file = tmp_path / "wrong_type.json" + io.save(dataset, wrong_type_file) # type: ignore + + config = aoef.AOEFAnnotations( + name="test_wrong_type", + annotations_path=wrong_type_file, + audio_dir=rec_path.parent, + ) + + with pytest.raises(ValueError) as excinfo: + aoef.load_aoef_annotated_dataset(config) + + assert ( + "does not contain a soundevent AnnotationSet or AnnotationProject" + in str(excinfo.value) + ) diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py index e69f107..f5e04f6 100644 --- a/tests/test_train/test_augmentations.py +++ b/tests/test_train/test_augmentations.py @@ -17,10 +17,10 @@ from batdetect2.train.preprocess import ( def test_mix_examples( - recording_factory: Callable[..., data.Recording], + create_recording: Callable[..., data.Recording], ): - recording1 = recording_factory() - recording2 = recording_factory() + recording1 = create_recording() + recording2 = create_recording() clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7) clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8) @@ -54,12 +54,12 @@ def test_mix_examples( @pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7]) @pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7]) def test_mix_examples_of_different_durations( - recording_factory: Callable[..., data.Recording], + create_recording: Callable[..., data.Recording], duration1: float, duration2: float, ): - recording1 = recording_factory() - recording2 = 
recording_factory() + recording1 = create_recording() + recording2 = create_recording() clip1 = data.Clip(recording=recording1, start_time=0, end_time=duration1) clip2 = data.Clip(recording=recording2, start_time=0, end_time=duration2) @@ -92,9 +92,9 @@ def test_mix_examples_of_different_durations( def test_add_echo( - recording_factory: Callable[..., data.Recording], + create_recording: Callable[..., data.Recording], ): - recording1 = recording_factory() + recording1 = create_recording() clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7) clip_annotation_1 = data.ClipAnnotation(clip=clip1) config = TrainPreprocessingConfig() @@ -113,9 +113,9 @@ def test_add_echo( def test_selected_random_subclip_has_the_correct_width( - recording_factory: Callable[..., data.Recording], + create_recording: Callable[..., data.Recording], ): - recording1 = recording_factory() + recording1 = create_recording() clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7) clip_annotation_1 = data.ClipAnnotation(clip=clip1) config = TrainPreprocessingConfig() @@ -131,9 +131,9 @@ def test_selected_random_subclip_has_the_correct_width( def test_add_echo_after_subclip( - recording_factory: Callable[..., data.Recording], + create_recording: Callable[..., data.Recording], ): - recording1 = recording_factory(duration=2) + recording1 = create_recording(duration=2) clip1 = data.Clip(recording=recording1, start_time=0, end_time=1) clip_annotation_1 = data.ClipAnnotation(clip=clip1) config = TrainPreprocessingConfig()