Add tests for aoef loading

mbsantiago 2025-04-18 13:32:50 +01:00
parent fd7f2b0081
commit f9e005ec8b
14 changed files with 1443 additions and 164 deletions

View File

@@ -5,15 +5,18 @@ from batdetect2.data.annotations import (
     BatDetect2MergedAnnotations,
     load_annotated_dataset,
 )
-from batdetect2.data.data import load_dataset, load_dataset_from_config
-from batdetect2.data.types import Dataset
+from batdetect2.data.datasets import (
+    DatasetConfig,
+    load_dataset,
+    load_dataset_from_config,
+)

 __all__ = [
     "AOEFAnnotations",
     "AnnotatedDataset",
     "BatDetect2FilesAnnotations",
     "BatDetect2MergedAnnotations",
-    "Dataset",
+    "DatasetConfig",
     "load_annotated_dataset",
     "load_dataset",
     "load_dataset_from_config",

View File

@@ -1,33 +0,0 @@
from pathlib import Path
from typing import Literal, Union
from batdetect2.configs import BaseConfig
__all__ = [
"AOEFAnnotationFile",
"AnnotationFormats",
"BatDetect2AnnotationFile",
"BatDetect2AnnotationFiles",
]
class BatDetect2AnnotationFiles(BaseConfig):
format: Literal["batdetect2"] = "batdetect2"
path: Path
class BatDetect2AnnotationFile(BaseConfig):
format: Literal["batdetect2_file"] = "batdetect2_file"
path: Path
class AOEFAnnotationFile(BaseConfig):
format: Literal["aoef"] = "aoef"
path: Path
AnnotationFormats = Union[
BatDetect2AnnotationFiles,
BatDetect2AnnotationFile,
AOEFAnnotationFile,
]

View File

@@ -1,9 +1,24 @@
+"""Handles loading of annotation data from various formats.
+
+This module serves as the central dispatcher for parsing annotation data
+associated with BatDetect2 datasets. Datasets can be composed of multiple
+sources, each potentially using a different annotation format (e.g., the
+standard AOEF/soundevent format, or legacy BatDetect2 formats).
+
+This module defines the `AnnotationFormats` type, which represents the union
+of possible configuration models for these different formats (each identified
+by a unique `format` field). The primary function, `load_annotated_dataset`,
+inspects the configuration for a single data source and calls the appropriate
+format-specific loading function to retrieve the annotations as a standard
+`soundevent.data.AnnotationSet`.
+"""
+
 from pathlib import Path
 from typing import Optional, Union

 from soundevent import data

-from batdetect2.data.annotations.aeof import (
+from batdetect2.data.annotations.aoef import (
     AOEFAnnotations,
     load_aoef_annotated_dataset,
 )
@@ -32,12 +47,52 @@ AnnotationFormats = Union[
     BatDetect2FilesAnnotations,
     AOEFAnnotations,
 ]
+"""Type alias representing all supported data source configurations.
+
+Each specific configuration model within this union (e.g., `AOEFAnnotations`,
+`BatDetect2FilesAnnotations`) corresponds to a different annotation format
+or storage structure. These models are typically discriminated by a `format`
+field (e.g., `format="aoef"`, `format="batdetect2_files"`), allowing Pydantic
+and functions like `load_annotated_dataset` to determine which format a given
+source configuration represents.
+"""
+
+
 def load_annotated_dataset(
     dataset: AnnotatedDataset,
     base_dir: Optional[Path] = None,
 ) -> data.AnnotationSet:
+    """Load annotations for a single data source based on its configuration.
+
+    This function acts as a dispatcher. It inspects the type of the input
+    `dataset` object (which corresponds to a specific annotation format)
+    and calls the appropriate loading function (e.g.,
+    `load_aoef_annotated_dataset` for `AOEFAnnotations`).
+
+    Parameters
+    ----------
+    dataset : AnnotationFormats
+        The configuration object for the data source, specifying its format
+        and necessary details (like paths). Must be an instance of one of the
+        types included in the `AnnotationFormats` union.
+    base_dir : Path, optional
+        An optional base directory path. If provided, relative paths within
+        the source configuration might be resolved relative to this directory
+        by the underlying loading functions. Defaults to None.
+
+    Returns
+    -------
+    soundevent.data.AnnotationSet
+        An AnnotationSet containing the `ClipAnnotation` objects loaded and
+        parsed from the specified data source.
+
+    Raises
+    ------
+    NotImplementedError
+        If the type of the `dataset` object does not match any of the known
+        format-specific loading functions implemented in the dispatch logic.
+    """
     if isinstance(dataset, AOEFAnnotations):
         return load_aoef_annotated_dataset(dataset, base_dir=base_dir)
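For orientation, a minimal usage sketch of this dispatcher (names and paths here are hypothetical; `AOEFAnnotations` is the only branch shown in this hunk):

from pathlib import Path

from batdetect2.data.annotations import load_annotated_dataset
from batdetect2.data.annotations.aoef import AOEFAnnotations

# The `format="aoef"` discriminator on this model is what identifies the
# source type to Pydantic and to the dispatch logic above.
source = AOEFAnnotations(
    name="example-source",  # hypothetical source name
    audio_dir=Path("audio"),
    annotations_path=Path("annotations/example.aoef"),
)

# Dispatches on the config type; returns a soundevent AnnotationSet.
annotations = load_annotated_dataset(source, base_dir=Path("/data/example"))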

View File

@@ -1,37 +0,0 @@
from pathlib import Path
from typing import Literal, Optional
from soundevent import data, io
from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [
"AOEFAnnotations",
"load_aoef_annotated_dataset",
]
class AOEFAnnotations(AnnotatedDataset):
format: Literal["aoef"] = "aoef"
annotations_path: Path
def load_aoef_annotated_dataset(
dataset: AOEFAnnotations,
base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
audio_dir = dataset.audio_dir
path = dataset.annotations_path
if base_dir:
audio_dir = base_dir / audio_dir
path = base_dir / path
loaded = io.load(path, audio_dir=audio_dir)
if not isinstance(loaded, (data.AnnotationSet, data.AnnotationProject)):
raise ValueError(
f"The AOEF file at {path} does not contain a set of annotations"
)
return loaded

View File

@@ -0,0 +1,270 @@
"""Loads annotation data specifically from the AOEF / soundevent format.
This module provides the necessary configuration model and loading function
to handle data sources where annotations are stored in the standard format
used by the `soundevent` library (often as `.json` or `.aoef` files),
which includes outputs from annotation tools like Whombat.
It supports loading both simple `AnnotationSet` files and more complex
`AnnotationProject` files. For `AnnotationProject` files, it offers optional
filtering capabilities to select only annotations associated with tasks
that meet specific status criteria (e.g., completed, verified, without issues).
"""
from pathlib import Path
from typing import Literal, Optional
from uuid import uuid5
from pydantic import Field
from soundevent import data, io
from batdetect2.configs import BaseConfig
from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [
"AOEFAnnotations",
"load_aoef_annotated_dataset",
"AnnotationTaskFilter",
]
class AnnotationTaskFilter(BaseConfig):
"""Configuration for filtering Annotation Tasks within an AnnotationProject.
Specifies criteria based on task status badges to select relevant
annotations, typically used when loading data from annotation projects
that might contain work-in-progress.
Attributes
----------
only_completed : bool, default=True
If True, only include annotations from tasks marked as 'completed'.
only_verified : bool, default=False
If True, only include annotations from tasks marked as 'verified'.
exclude_issues : bool, default=True
If True, exclude annotations from tasks marked as 'rejected' (indicating
issues).
"""
only_completed: bool = True
only_verified: bool = False
exclude_issues: bool = True
class AOEFAnnotations(AnnotatedDataset):
"""Configuration defining a data source stored in AOEF format.
This model specifies how to load annotations from an AOEF (JSON) file
compatible with the `soundevent` library. It inherits `name`,
`description`, and `audio_dir` from `AnnotatedDataset`.
Attributes
----------
format : Literal["aoef"]
The fixed format identifier for this configuration type.
annotations_path : Path
The file system path to the `.aoef` or `.json` file containing the
`AnnotationSet` or `AnnotationProject`.
filter : AnnotationTaskFilter, optional
Configuration for filtering tasks if the `annotations_path` points to
an `AnnotationProject`. If omitted, default filtering
(only completed, exclude issues, verification not required) is applied
to projects. Set explicitly to `None` in config (e.g., `filter: null`)
to disable filtering for projects entirely.
"""
format: Literal["aoef"] = "aoef"
annotations_path: Path
filter: Optional[AnnotationTaskFilter] = Field(
default_factory=AnnotationTaskFilter
)
def load_aoef_annotated_dataset(
dataset: AOEFAnnotations,
base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
"""Load annotations from an AnnotationSet or AnnotationProject file.
Reads the file specified in the `dataset` configuration using
`soundevent.io.load`. If the loaded file contains an `AnnotationProject`
and filtering is enabled via `dataset.filter`, it applies the filter
criteria based on task status and returns a new `AnnotationSet` containing
only the selected annotations. If the file contains an `AnnotationSet`,
or if it's a project and filtering is disabled, all annotations are
returned.
Parameters
----------
dataset : AOEFAnnotations
The configuration object describing the AOEF data source, including
the path to the annotation file and optional filtering settings.
base_dir : Path, optional
An optional base directory. If provided, `dataset.annotations_path`
and `dataset.audio_dir` will be resolved relative to this
directory. Defaults to None.
Returns
-------
soundevent.data.AnnotationSet
An AnnotationSet containing the loaded (and potentially filtered)
`ClipAnnotation` objects.
Raises
------
FileNotFoundError
If the specified `annotations_path` (after resolving `base_dir`)
does not exist.
ValueError
If the loaded file does not contain a valid `AnnotationSet` or
`AnnotationProject`.
Exception
May re-raise errors from `soundevent.io.load` related to parsing
or file format issues.
Notes
-----
- The `soundevent` library handles parsing of `.json` or `.aoef` formats.
- If an `AnnotationProject` is loaded and `dataset.filter` is *not* None,
a *new* `AnnotationSet` instance is created containing only the filtered
clip annotations.
"""
audio_dir = dataset.audio_dir
path = dataset.annotations_path
if base_dir:
audio_dir = base_dir / audio_dir
path = base_dir / path
loaded = io.load(path, audio_dir=audio_dir)
if not isinstance(loaded, (data.AnnotationSet, data.AnnotationProject)):
raise ValueError(
f"The file at {path} loaded successfully but does not "
"contain a soundevent AnnotationSet or AnnotationProject "
f"(loaded type: {type(loaded).__name__})."
)
if isinstance(loaded, data.AnnotationProject) and dataset.filter:
loaded = filter_ready_clips(
loaded,
only_completed=dataset.filter.only_completed,
only_verified=dataset.filter.only_verified,
exclude_issues=dataset.filter.exclude_issues,
)
return loaded
def select_task(
annotation_task: data.AnnotationTask,
only_completed: bool = True,
only_verified: bool = False,
exclude_issues: bool = True,
) -> bool:
"""Check if an AnnotationTask meets specified status criteria.
Evaluates the `status_badges` of the task against the filter flags.
Parameters
----------
annotation_task : data.AnnotationTask
The annotation task to check.
only_completed : bool, default=True
Task must be marked 'completed' to pass.
only_verified : bool, default=False
Task must be marked 'verified' to pass.
exclude_issues : bool, default=True
Task must *not* be marked 'rejected' (have issues) to pass.
Returns
-------
bool
True if the task meets all active filter criteria, False otherwise.
"""
has_issues = False
is_completed = False
is_verified = False
for badge in annotation_task.status_badges:
if badge.state == data.AnnotationState.completed:
is_completed = True
continue
if badge.state == data.AnnotationState.rejected:
has_issues = True
continue
if badge.state == data.AnnotationState.verified:
is_verified = True
if exclude_issues and has_issues:
return False
if only_verified and not is_verified:
return False
if only_completed and not is_completed:
return False
return True
def filter_ready_clips(
annotation_project: data.AnnotationProject,
only_completed: bool = True,
only_verified: bool = False,
exclude_issues: bool = True,
) -> data.AnnotationSet:
"""Filter AnnotationProject to create an AnnotationSet of 'ready' clips.
Iterates through tasks in the project, selects tasks meeting the status
criteria using `select_task`, and creates a new `AnnotationSet` containing
only the `ClipAnnotation` objects associated with those selected tasks.
Parameters
----------
annotation_project : data.AnnotationProject
The input annotation project.
only_completed : bool, default=True
Filter flag passed to `select_task`.
only_verified : bool, default=False
Filter flag passed to `select_task`.
exclude_issues : bool, default=True
Filter flag passed to `select_task`.
Returns
-------
data.AnnotationSet
A new annotation set containing only the clip annotations linked to
tasks that satisfied the filtering criteria. The returned set has a
deterministic UUID based on the project UUID and filter settings.
"""
ready_clip_uuids = set()
for annotation_task in annotation_project.tasks:
if not select_task(
annotation_task,
only_completed=only_completed,
only_verified=only_verified,
exclude_issues=exclude_issues,
):
continue
ready_clip_uuids.add(annotation_task.clip.uuid)
return data.AnnotationSet(
uuid=uuid5(
annotation_project.uuid,
f"{only_completed}_{only_verified}_{exclude_issues}",
),
name=annotation_project.name,
description=annotation_project.description,
clip_annotations=[
annotation
for annotation in annotation_project.clip_annotations
if annotation.clip.uuid in ready_clip_uuids
],
)
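To make the filtering behaviour concrete, here is a short sketch that mirrors what the new tests below exercise (paths and the source name are hypothetical):

from pathlib import Path

from batdetect2.data.annotations.aoef import (
    AOEFAnnotations,
    AnnotationTaskFilter,
    load_aoef_annotated_dataset,
)

config = AOEFAnnotations(
    name="field-season-2024",  # hypothetical source name
    audio_dir=Path("audio"),
    annotations_path=Path("annotations/project.json"),
    # Stricter than the defaults: additionally require verification.
    filter=AnnotationTaskFilter(only_verified=True),
)

# Relative paths above are resolved against base_dir. If the file holds an
# AnnotationProject, it is reduced to an AnnotationSet of "ready" clips.
ready = load_aoef_annotated_dataset(config, base_dir=Path("/data/bats"))

Passing `filter=None` instead would return the loaded project unfiltered.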

View File

@@ -17,11 +17,10 @@ class AnnotatedDataset(BaseConfig):
     Annotations associated with these recordings are defined by the
     `annotations` field, which supports various formats (e.g., AOEF files,
-    specific CSV
-    structures).
-    Crucially, file paths referenced within the annotation data *must* be
-    relative to the `audio_dir`. This ensures that the dataset definition
-    remains portable across different systems and base directories.
+    specific CSV structures). Crucially, file paths referenced within the
+    annotation data *must* be relative to the `audio_dir`. This ensures that
+    the dataset definition remains portable across different systems and base
+    directories.

     Attributes:
         name: A unique identifier for this data source.

View File

@@ -1,37 +0,0 @@
from pathlib import Path
from typing import Optional
from soundevent import data
from batdetect2.configs import load_config
from batdetect2.data.annotations import load_annotated_dataset
from batdetect2.data.types import Dataset
__all__ = [
"load_dataset",
"load_dataset_from_config",
]
def load_dataset(
dataset: Dataset,
base_dir: Optional[Path] = None,
) -> data.AnnotationSet:
clip_annotations = []
for source in dataset.sources:
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
clip_annotations.extend(annotated_source.clip_annotations)
return data.AnnotationSet(clip_annotations=clip_annotations)
def load_dataset_from_config(
path: data.PathLike,
field: Optional[str] = None,
base_dir: Optional[Path] = None,
):
config = load_config(
path=path,
schema=Dataset,
field=field,
)
return load_dataset(config, base_dir=base_dir)

batdetect2/data/datasets.py Normal file
View File

@@ -0,0 +1,207 @@
"""Defines the overall dataset structure and provides loading/saving utilities.
This module focuses on defining what constitutes a BatDetect2 dataset,
potentially composed of multiple distinct data sources with varying annotation
formats. It provides mechanisms to load the annotation metadata from these
sources into a unified representation.
The core components are:
- `DatasetConfig`: A configuration class (typically loaded from YAML) that
describes the dataset's name, description, and constituent sources.
- `Dataset`: A type alias representing the loaded dataset as a list of
`soundevent.data.ClipAnnotation` objects. Note that this implies all
annotation metadata is loaded into memory.
- Loading functions (`load_dataset`, `load_dataset_from_config`): To parse
a `DatasetConfig` and load the corresponding annotation metadata.
- Saving function (`save_dataset`): To save a loaded list of annotations
into a standard `soundevent` format.
"""
from pathlib import Path
from typing import Annotated, List, Optional
from pydantic import Field
from soundevent import data, io
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.annotations import (
AnnotationFormats,
load_annotated_dataset,
)
__all__ = [
"load_dataset",
"load_dataset_from_config",
"save_dataset",
"Dataset",
"DatasetConfig",
]
Dataset = List[data.ClipAnnotation]
"""Type alias for a loaded dataset representation.
Represents an entire dataset *after loading* as a flat Python list containing
all `soundevent.data.ClipAnnotation` objects gathered from all configured data
sources.
"""
class DatasetConfig(BaseConfig):
"""Configuration model defining the structure of a BatDetect2 dataset.
This class is typically loaded from a YAML file and describes the components
of the dataset, including metadata and a list of data sources.
Attributes
----------
name : str
A descriptive name for the dataset (e.g., "UK_Bats_Project_2024").
description : str
A longer description of the dataset's contents, origin, purpose, etc.
sources : List[AnnotationFormats]
A list defining the different data sources contributing to this
dataset. Each item in the list must conform to one of the Pydantic
models defined in the `AnnotationFormats` type union. The specific
model used for each source is determined by the mandatory `format`
field within the source's configuration, allowing BatDetect2 to use the
correct parser for different annotation styles.
"""
name: str
description: str
sources: List[
Annotated[AnnotationFormats, Field(..., discriminator="format")]
]
def load_dataset(
dataset: DatasetConfig,
base_dir: Optional[Path] = None,
) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig.
Iterates through each data source specified in the `dataset` config,
delegates the loading and parsing of that source's annotations to
`batdetect2.data.annotations.load_annotated_dataset` (which handles
different data formats), and aggregates all resulting `ClipAnnotation`
objects into a single flat list.
Parameters
----------
dataset : DatasetConfig
The configuration object describing the dataset and its sources.
base_dir : Path, optional
An optional base directory path. If provided, relative paths for
metadata files or data directories within the `dataset` config's
sources might be resolved relative to this directory. Defaults to None.
Returns
-------
Dataset (List[data.ClipAnnotation])
A flat list containing all loaded `ClipAnnotation` metadata objects
from all specified sources.
Raises
------
Exception
Can raise various exceptions during the delegated loading process
(`load_annotated_dataset`) if files are not found, cannot be parsed
according to the specified format, or other I/O errors occur.
"""
clip_annotations = []
for source in dataset.sources:
annotated_source = load_annotated_dataset(source, base_dir=base_dir)
clip_annotations.extend(annotated_source.clip_annotations)
return clip_annotations
def load_dataset_from_config(
path: data.PathLike,
field: Optional[str] = None,
base_dir: Optional[Path] = None,
) -> Dataset:
"""Load dataset annotation metadata from a configuration file.
This is a convenience function that first loads the `DatasetConfig` from
the specified file path and optional nested field, and then calls
`load_dataset` to load all corresponding `ClipAnnotation` objects.
Parameters
----------
path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing the
dataset configuration (e.g., "data.training_set"). If None, the
entire file content is assumed to be the `DatasetConfig`.
base_dir : Path, optional
An optional base directory path to resolve relative paths within the
configuration sources. Passed to `load_dataset`. Defaults to None.
Returns
-------
Dataset (List[data.ClipAnnotation])
A flat list containing all loaded `ClipAnnotation` metadata objects.
Raises
------
FileNotFoundError
If the config file `path` does not exist.
yaml.YAMLError, pydantic.ValidationError, KeyError, TypeError
If the configuration file is invalid, cannot be parsed, or does not
match the `DatasetConfig` schema.
Exception
Can raise exceptions from `load_dataset` if loading data from sources
fails.
"""
config = load_config(
path=path,
schema=DatasetConfig,
field=field,
)
return load_dataset(config, base_dir=base_dir)
def save_dataset(
dataset: Dataset,
path: data.PathLike,
name: Optional[str] = None,
description: Optional[str] = None,
audio_dir: Optional[Path] = None,
) -> None:
"""Save a loaded dataset (list of ClipAnnotations) to a file.
Wraps the provided list of `ClipAnnotation` objects into a
`soundevent.data.AnnotationSet` and saves it using `soundevent.io.save`.
This saves the aggregated annotation metadata in the standard soundevent
format.
Note: This function saves the *loaded annotation data*, not the original
`DatasetConfig` structure that defined how the data was assembled from
various sources.
Parameters
----------
dataset : Dataset (List[data.ClipAnnotation])
The list of clip annotations to save (typically the result of
`load_dataset` or a split thereof).
path : data.PathLike
The output file path (e.g., 'train_annotations.json',
'val_annotations.json'). The format is determined by `soundevent.io`.
name : str, optional
An optional name to assign to the saved `AnnotationSet`.
description : str, optional
An optional description to assign to the saved `AnnotationSet`.
audio_dir : Path, optional
Passed to `soundevent.io.save`. May be used to relativize audio file
paths within the saved annotations if applicable to the save format.
"""
annotation_set = data.AnnotationSet(
name=name,
description=description,
clip_annotations=dataset,
)
io.save(annotation_set, path, audio_dir=audio_dir)
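End to end, the new API composes as in the following sketch (names and paths are hypothetical; `AOEFAnnotations` stands in for any member of the `AnnotationFormats` union):

from pathlib import Path

from batdetect2.data.annotations.aoef import AOEFAnnotations
from batdetect2.data.datasets import (
    DatasetConfig,
    load_dataset,
    save_dataset,
)

config = DatasetConfig(
    name="uk-bats-2024",  # hypothetical dataset name
    description="Example dataset assembled from a single AOEF source.",
    sources=[
        AOEFAnnotations(
            name="whombat-export",  # hypothetical source name
            audio_dir=Path("audio"),
            annotations_path=Path("annotations/project.json"),
        ),
    ],
)

# A flat List[data.ClipAnnotation] gathered from every configured source.
annotations = load_dataset(config, base_dir=Path("/data/uk-bats"))

# Persist the merged annotations as a plain soundevent AnnotationSet.
save_dataset(annotations, "train_annotations.json", name=config.name)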

View File

@@ -1,29 +0,0 @@
from typing import Annotated, List
from pydantic import Field
from batdetect2.configs import BaseConfig
from batdetect2.data.annotations import AnnotationFormats
class Dataset(BaseConfig):
"""Represents a collection of one or more DatasetSources.
In the context of batdetect2, a Dataset aggregates multiple `DatasetSource`
instances. It serves as the primary unit for defining data splits,
typically used for model training, validation, or testing phases.
Attributes:
name: A descriptive name for the overall dataset
(e.g., "UK Training Set").
description: A detailed explanation of the dataset's purpose,
composition, how it was assembled, or any specific characteristics.
sources: A list containing the `DatasetSource` objects included in this
dataset.
"""
name: str
description: str
sources: List[
Annotated[AnnotationFormats, Field(..., discriminator="format")]
]

View File

@@ -17,7 +17,7 @@ dependencies = [
     "torch>=1.13.1,<2.5.0",
     "torchaudio>=1.13.1,<2.5.0",
     "torchvision>=0.14.0",
-    "soundevent[audio,geometry,plot]>=2.3",
+    "soundevent[audio,geometry,plot]>=2.4.1",
     "click>=8.1.7",
     "netcdf4>=1.6.5",
     "tqdm>=4.66.2",

View File

@@ -86,8 +86,8 @@ def wav_factory(tmp_path: Path):

 @pytest.fixture
-def recording_factory(wav_factory: Callable[..., Path]):
-    def _recording_factory(
+def create_recording(wav_factory: Callable[..., Path]):
+    def factory(
         tags: Optional[list[data.Tag]] = None,
         path: Optional[Path] = None,
         recording_id: Optional[uuid.UUID] = None,
@@ -96,7 +96,8 @@ def recording_factory(wav_factory: Callable[..., Path]):
         samplerate: int = 256_000,
         time_expansion: float = 1,
     ) -> data.Recording:
-        path = path or wav_factory(
+        path = wav_factory(
+            path=path,
             duration=duration,
             channels=channels,
             samplerate=samplerate,
@@ -108,14 +109,30 @@ def recording_factory(wav_factory: Callable[..., Path]):
         tags=tags or [],
     )

-    return _recording_factory
+    return factory


 @pytest.fixture
 def recording(
-    recording_factory: Callable[..., data.Recording],
+    create_recording: Callable[..., data.Recording],
 ) -> data.Recording:
-    return recording_factory()
+    return create_recording()
+
+
+@pytest.fixture
+def create_clip():
+    def factory(
+        recording: data.Recording,
+        start_time: float = 0,
+        end_time: float = 0.5,
+    ) -> data.Clip:
+        return data.Clip(
+            recording=recording,
+            start_time=start_time,
+            end_time=end_time,
+        )
+
+    return factory


 @pytest.fixture
@@ -123,6 +140,22 @@ def clip(recording: data.Recording) -> data.Clip:
     return data.Clip(recording=recording, start_time=0, end_time=0.5)


+@pytest.fixture
+def create_sound_event():
+    def factory(
+        recording: data.Recording,
+        coords: Optional[List[float]] = None,
+    ) -> data.SoundEvent:
+        coords = coords or [0.2, 60_000, 0.3, 70_000]
+        return data.SoundEvent(
+            geometry=data.BoundingBox(coordinates=coords),
+            recording=recording,
+        )
+
+    return factory
+
+
 @pytest.fixture
 def sound_event(recording: data.Recording) -> data.SoundEvent:
     return data.SoundEvent(
@@ -131,6 +164,20 @@ def sound_event(recording: data.Recording) -> data.SoundEvent:
     )


+@pytest.fixture
+def create_sound_event_annotation():
+    def factory(
+        sound_event: data.SoundEvent,
+        tags: Optional[List[data.Tag]] = None,
+    ) -> data.SoundEventAnnotation:
+        return data.SoundEventAnnotation(
+            sound_event=sound_event,
+            tags=tags or [],
+        )
+
+    return factory
+
+
 @pytest.fixture
 def echolocation_call(recording: data.Recording) -> data.SoundEventAnnotation:
     return data.SoundEventAnnotation(
@@ -181,6 +228,22 @@ def non_relevant_sound_event(
     )


+@pytest.fixture
+def create_clip_annotation():
+    def factory(
+        clip: data.Clip,
+        clip_tags: Optional[List[data.Tag]] = None,
+        sound_events: Optional[List[data.SoundEventAnnotation]] = None,
+    ) -> data.ClipAnnotation:
+        return data.ClipAnnotation(
+            clip=clip,
+            tags=clip_tags or [],
+            sound_events=sound_events or [],
+        )
+
+    return factory
+
+
 @pytest.fixture
 def clip_annotation(
     clip: data.Clip,
@@ -196,3 +259,37 @@ def clip_annotation(
             non_relevant_sound_event,
         ],
     )
+
+
+@pytest.fixture
+def create_annotation_set():
+    def factory(
+        name: str = "test",
+        description: str = "Test annotation set",
+        annotations: Optional[List[data.ClipAnnotation]] = None,
+    ) -> data.AnnotationSet:
+        return data.AnnotationSet(
+            name=name,
+            description=description,
+            clip_annotations=annotations or [],
+        )
+
+    return factory
+
+
+@pytest.fixture
+def create_annotation_project():
+    def factory(
+        name: str = "test_project",
+        description: str = "Test Annotation Project",
+        tasks: Optional[List[data.AnnotationTask]] = None,
+        annotations: Optional[List[data.ClipAnnotation]] = None,
+    ) -> data.AnnotationProject:
+        return data.AnnotationProject(
+            name=name,
+            description=description,
+            tasks=tasks or [],
+            clip_annotations=annotations or [],
+        )
+
+    return factory

View File

@@ -0,0 +1,784 @@
import uuid
from pathlib import Path
from typing import Callable, Optional, Sequence
import pytest
from pydantic import ValidationError
from soundevent import data, io
from soundevent.data.annotation_tasks import AnnotationState
from batdetect2.data.annotations import aoef
@pytest.fixture
def base_dir(tmp_path: Path) -> Path:
path = tmp_path / "base_dir"
path.mkdir(parents=True, exist_ok=True)
return path
@pytest.fixture
def audio_dir(base_dir: Path) -> Path:
path = base_dir / "audio"
path.mkdir(parents=True, exist_ok=True)
return path
@pytest.fixture
def anns_dir(base_dir: Path) -> Path:
path = base_dir / "annotations"
path.mkdir(parents=True, exist_ok=True)
return path
def create_task(
clip: data.Clip,
badges: list[data.StatusBadge],
task_id: Optional[uuid.UUID] = None,
) -> data.AnnotationTask:
"""Creates a simple AnnotationTask for testing."""
return data.AnnotationTask(
uuid=task_id or uuid.uuid4(),
clip=clip,
status_badges=badges,
)
def test_annotation_task_filter_defaults():
"""Test default values of AnnotationTaskFilter."""
f = aoef.AnnotationTaskFilter()
assert f.only_completed is True
assert f.only_verified is False
assert f.exclude_issues is True
def test_annotation_task_filter_initialization():
"""Test initialization of AnnotationTaskFilter with non-default values."""
f = aoef.AnnotationTaskFilter(
only_completed=False,
only_verified=True,
exclude_issues=False,
)
assert f.only_completed is False
assert f.only_verified is True
assert f.exclude_issues is False
def test_aoef_annotations_defaults(
audio_dir: Path,
anns_dir: Path,
):
"""Test default values of AOEFAnnotations."""
annotations_path = anns_dir / "test.aoef"
config = aoef.AOEFAnnotations(
name="default_name",
audio_dir=audio_dir,
annotations_path=annotations_path,
)
assert config.format == "aoef"
assert config.annotations_path == annotations_path
assert config.audio_dir == audio_dir
assert isinstance(config.filter, aoef.AnnotationTaskFilter)
assert config.filter.only_completed is True
assert config.filter.only_verified is False
assert config.filter.exclude_issues is True
def test_aoef_annotations_initialization(tmp_path):
"""Test initialization of AOEFAnnotations with specific values."""
annotations_path = tmp_path / "custom.json"
audio_dir = Path("audio/files")
custom_filter = aoef.AnnotationTaskFilter(
only_completed=False, only_verified=True
)
config = aoef.AOEFAnnotations(
name="custom_name",
description="custom_desc",
audio_dir=audio_dir,
annotations_path=annotations_path,
filter=custom_filter,
)
assert config.name == "custom_name"
assert config.description == "custom_desc"
assert config.format == "aoef"
assert config.audio_dir == audio_dir
assert config.annotations_path == annotations_path
assert config.filter is custom_filter
def test_aoef_annotations_initialization_no_filter(tmp_path):
"""Test initialization of AOEFAnnotations with filter=None."""
annotations_path = tmp_path / "no_filter.aoef"
audio_dir = tmp_path / "audio"
config = aoef.AOEFAnnotations(
name="no_filter_name",
description="no_filter_desc",
audio_dir=audio_dir,
annotations_path=annotations_path,
filter=None,
)
assert config.format == "aoef"
assert config.annotations_path == annotations_path
assert config.filter is None
def test_aoef_annotations_validation_error(tmp_path):
"""Test Pydantic validation for missing required fields."""
with pytest.raises(ValidationError, match="annotations_path"):
aoef.AOEFAnnotations( # type: ignore
name="test_name",
audio_dir=tmp_path,
)
with pytest.raises(ValidationError, match="name"):
aoef.AOEFAnnotations( # type: ignore
annotations_path=tmp_path / "dummy.aoef",
audio_dir=tmp_path,
)
@pytest.mark.parametrize(
"badges, only_completed, only_verified, exclude_issues, expected",
[
([], True, False, True, False), # No badges -> not completed
(
[data.StatusBadge(state=AnnotationState.completed)],
True,
False,
True,
True,
),
(
[data.StatusBadge(state=AnnotationState.verified)],
True,
False,
True,
False,
), # Not completed
(
[data.StatusBadge(state=AnnotationState.rejected)],
True,
False,
True,
False,
), # Has issues
(
[
data.StatusBadge(state=AnnotationState.completed),
data.StatusBadge(state=AnnotationState.rejected),
],
True,
False,
True,
False,
), # Completed but has issues
(
[
data.StatusBadge(state=AnnotationState.completed),
data.StatusBadge(state=AnnotationState.verified),
],
True,
False,
True,
True,
), # Completed, verified doesn't matter
# Verified only (completed=F, verified=T, exclude_issues=T)
(
[data.StatusBadge(state=AnnotationState.verified)],
False,
True,
True,
True,
),
(
[data.StatusBadge(state=AnnotationState.completed)],
False,
True,
True,
False,
), # Not verified
(
[
data.StatusBadge(state=AnnotationState.verified),
data.StatusBadge(state=AnnotationState.rejected),
],
False,
True,
True,
False,
), # Verified but has issues
# Completed AND Verified (completed=T, verified=T, exclude_issues=T)
(
[
data.StatusBadge(state=AnnotationState.completed),
data.StatusBadge(state=AnnotationState.verified),
],
True,
True,
True,
True,
),
(
[data.StatusBadge(state=AnnotationState.completed)],
True,
True,
True,
False,
), # Not verified
(
[data.StatusBadge(state=AnnotationState.verified)],
True,
True,
True,
False,
), # Not completed
# Include Issues (completed=T, verified=F, exclude_issues=F)
(
[
data.StatusBadge(state=AnnotationState.completed),
data.StatusBadge(state=AnnotationState.rejected),
],
True,
False,
False,
True,
), # Completed, issues allowed
(
[data.StatusBadge(state=AnnotationState.rejected)],
True,
False,
False,
False,
), # Has issues, but not completed
# No filters (completed=F, verified=F, exclude_issues=F)
([], False, False, False, True),
(
[data.StatusBadge(state=AnnotationState.rejected)],
False,
False,
False,
True,
),
(
[data.StatusBadge(state=AnnotationState.completed)],
False,
False,
False,
True,
),
(
[data.StatusBadge(state=AnnotationState.verified)],
False,
False,
False,
True,
),
],
)
def test_select_task(
badges: Sequence[data.StatusBadge],
only_completed: bool,
only_verified: bool,
exclude_issues: bool,
expected: bool,
create_recording: Callable[..., data.Recording],
create_clip: Callable[..., data.Clip],
):
"""Test select_task logic with various badge and filter combinations."""
rec = create_recording()
clip = create_clip(rec)
task = create_task(clip, badges=list(badges))
result = aoef.select_task(
task,
only_completed=only_completed,
only_verified=only_verified,
exclude_issues=exclude_issues,
)
assert result == expected
def test_filter_ready_clips_default(
tmp_path: Path,
create_recording: Callable[..., data.Recording],
create_clip: Callable[..., data.Clip],
create_clip_annotation: Callable[..., data.ClipAnnotation],
create_annotation_project: Callable[..., data.AnnotationProject],
):
"""Test filter_ready_clips with default filtering."""
rec = create_recording(path=tmp_path / "rec.wav")
clip_completed = create_clip(rec, 0, 1)
clip_verified = create_clip(rec, 1, 2)
clip_rejected = create_clip(rec, 2, 3)
clip_completed_rejected = create_clip(rec, 3, 4)
clip_no_badges = create_clip(rec, 4, 5)
task_completed = create_task(
clip_completed, [data.StatusBadge(state=AnnotationState.completed)]
)
task_verified = create_task(
clip_verified, [data.StatusBadge(state=AnnotationState.verified)]
)
task_rejected = create_task(
clip_rejected, [data.StatusBadge(state=AnnotationState.rejected)]
)
task_completed_rejected = create_task(
clip_completed_rejected,
[
data.StatusBadge(state=AnnotationState.completed),
data.StatusBadge(state=AnnotationState.rejected),
],
)
task_no_badges = create_task(clip_no_badges, [])
ann_completed = create_clip_annotation(clip_completed)
ann_verified = create_clip_annotation(clip_verified)
ann_rejected = create_clip_annotation(clip_rejected)
ann_completed_rejected = create_clip_annotation(clip_completed_rejected)
ann_no_badges = create_clip_annotation(clip_no_badges)
project = create_annotation_project(
name="FilterTestProject",
description="Project for testing filters",
tasks=[
task_completed,
task_verified,
task_rejected,
task_completed_rejected,
task_no_badges,
],
annotations=[
ann_completed,
ann_verified,
ann_rejected,
ann_completed_rejected,
ann_no_badges,
],
)
filtered_set = aoef.filter_ready_clips(project)
assert isinstance(filtered_set, data.AnnotationSet)
assert filtered_set.name == project.name
assert filtered_set.description == project.description
assert len(filtered_set.clip_annotations) == 1
assert filtered_set.clip_annotations[0].clip.uuid == clip_completed.uuid
expected_uuid = uuid.uuid5(project.uuid, f"{True}_{False}_{True}")
assert filtered_set.uuid == expected_uuid
def test_filter_ready_clips_custom_filter(
tmp_path: Path,
create_recording: Callable[..., data.Recording],
create_clip: Callable[..., data.Clip],
create_clip_annotation: Callable[..., data.ClipAnnotation],
create_annotation_project: Callable[..., data.AnnotationProject],
):
"""Test filter_ready_clips with custom filtering (verified=T, issues=F)."""
rec = create_recording(path=tmp_path / "rec.wav")
clip_completed = create_clip(rec, 0, 1)
clip_verified = create_clip(rec, 1, 2)
clip_rejected = create_clip(rec, 2, 3)
clip_completed_verified = create_clip(rec, 3, 4)
clip_verified_rejected = create_clip(rec, 4, 5)
task_completed = create_task(
clip_completed, [data.StatusBadge(state=AnnotationState.completed)]
)
task_verified = create_task(
clip_verified, [data.StatusBadge(state=AnnotationState.verified)]
)
task_rejected = create_task(
clip_rejected, [data.StatusBadge(state=AnnotationState.rejected)]
)
task_completed_verified = create_task(
clip_completed_verified,
[
data.StatusBadge(state=AnnotationState.completed),
data.StatusBadge(state=AnnotationState.verified),
],
)
task_verified_rejected = create_task(
clip_verified_rejected,
[
data.StatusBadge(state=AnnotationState.verified),
data.StatusBadge(state=AnnotationState.rejected),
],
)
ann_completed = create_clip_annotation(clip_completed)
ann_verified = create_clip_annotation(clip_verified)
ann_rejected = create_clip_annotation(clip_rejected)
ann_completed_verified = create_clip_annotation(clip_completed_verified)
ann_verified_rejected = create_clip_annotation(clip_verified_rejected)
project = create_annotation_project(
tasks=[
task_completed,
task_verified,
task_rejected,
task_completed_verified,
task_verified_rejected,
],
annotations=[
ann_completed,
ann_verified,
ann_rejected,
ann_completed_verified,
ann_verified_rejected,
],
)
filtered_set = aoef.filter_ready_clips(
project, only_completed=False, only_verified=True, exclude_issues=False
)
assert len(filtered_set.clip_annotations) == 3
filtered_clip_uuids = {
ann.clip.uuid for ann in filtered_set.clip_annotations
}
assert clip_verified.uuid in filtered_clip_uuids
assert clip_completed_verified.uuid in filtered_clip_uuids
assert clip_verified_rejected.uuid in filtered_clip_uuids
expected_uuid = uuid.uuid5(project.uuid, f"{False}_{True}_{False}")
assert filtered_set.uuid == expected_uuid
def test_filter_ready_clips_no_filters(
tmp_path: Path,
create_recording: Callable[..., data.Recording],
create_clip: Callable[..., data.Clip],
create_clip_annotation: Callable[..., data.ClipAnnotation],
create_annotation_project: Callable[..., data.AnnotationProject],
):
"""Test filter_ready_clips with all filters disabled."""
rec = create_recording(path=tmp_path / "rec.wav")
clip1 = create_clip(rec, 0, 1)
clip2 = create_clip(rec, 1, 2)
task1 = create_task(
clip1, [data.StatusBadge(state=AnnotationState.rejected)]
)
task2 = create_task(clip2, [])
ann1 = create_clip_annotation(clip1)
ann2 = create_clip_annotation(clip2)
project = create_annotation_project(
tasks=[task1, task2], annotations=[ann1, ann2]
)
filtered_set = aoef.filter_ready_clips(
project,
only_completed=False,
only_verified=False,
exclude_issues=False,
)
assert len(filtered_set.clip_annotations) == 2
filtered_clip_uuids = {
ann.clip.uuid for ann in filtered_set.clip_annotations
}
assert clip1.uuid in filtered_clip_uuids
assert clip2.uuid in filtered_clip_uuids
expected_uuid = uuid.uuid5(project.uuid, f"{False}_{False}_{False}")
assert filtered_set.uuid == expected_uuid
def test_filter_ready_clips_empty_project(
create_annotation_project: Callable[..., data.AnnotationProject],
):
"""Test filter_ready_clips with an empty project."""
project = create_annotation_project(tasks=[], annotations=[])
filtered_set = aoef.filter_ready_clips(project)
assert len(filtered_set.clip_annotations) == 0
assert filtered_set.name == project.name
assert filtered_set.description == project.description
def test_filter_ready_clips_no_matching_tasks(
tmp_path: Path,
create_recording: Callable[..., data.Recording],
create_clip: Callable[..., data.Clip],
create_clip_annotation: Callable[..., data.ClipAnnotation],
create_annotation_project: Callable[..., data.AnnotationProject],
):
"""Test filter_ready_clips when no tasks match the criteria."""
rec = create_recording(path=tmp_path / "rec.wav")
clip_rejected = create_clip(rec, 0, 1)
task_rejected = create_task(
clip_rejected, [data.StatusBadge(state=AnnotationState.rejected)]
)
ann_rejected = create_clip_annotation(clip_rejected)
project = create_annotation_project(
tasks=[task_rejected], annotations=[ann_rejected]
)
filtered_set = aoef.filter_ready_clips(project)
assert len(filtered_set.clip_annotations) == 0
def test_load_aoef_annotated_dataset_set(
tmp_path: Path,
create_recording: Callable[..., data.Recording],
create_clip: Callable[..., data.Clip],
create_clip_annotation: Callable[..., data.ClipAnnotation],
create_annotation_set: Callable[..., data.AnnotationSet],
):
"""Test loading a standard AnnotationSet file."""
rec_path = tmp_path / "audio" / "rec1.wav"
rec_path.parent.mkdir()
rec = create_recording(path=rec_path)
clip = create_clip(rec)
ann = create_clip_annotation(clip)
original_set = create_annotation_set(annotations=[ann])
annotations_file = tmp_path / "set.json"
io.save(original_set, annotations_file)
config = aoef.AOEFAnnotations(
name="test_set_load",
annotations_path=annotations_file,
audio_dir=rec_path.parent,
)
loaded_set = aoef.load_aoef_annotated_dataset(config)
assert isinstance(loaded_set, data.AnnotationSet)
assert loaded_set.name == original_set.name
assert len(loaded_set.clip_annotations) == len(
original_set.clip_annotations
)
assert (
loaded_set.clip_annotations[0].clip.uuid
== original_set.clip_annotations[0].clip.uuid
)
assert (
loaded_set.clip_annotations[0].clip.recording.path
== rec_path.resolve()
)
def test_load_aoef_annotated_dataset_project_with_filter(
tmp_path: Path,
create_recording: Callable[..., data.Recording],
create_clip: Callable[..., data.Clip],
create_clip_annotation: Callable[..., data.ClipAnnotation],
create_annotation_project: Callable[..., data.AnnotationProject],
):
"""Test loading an AnnotationProject file with filtering enabled."""
rec_path = tmp_path / "audio" / "rec.wav"
rec_path.parent.mkdir()
rec = create_recording(path=rec_path)
clip_completed = create_clip(rec, 0, 1)
clip_rejected = create_clip(rec, 1, 2)
task_completed = create_task(
clip_completed, [data.StatusBadge(state=AnnotationState.completed)]
)
task_rejected = create_task(
clip_rejected, [data.StatusBadge(state=AnnotationState.rejected)]
)
ann_completed = create_clip_annotation(clip_completed)
ann_rejected = create_clip_annotation(clip_rejected)
project = create_annotation_project(
name="ProjectToFilter",
tasks=[task_completed, task_rejected],
annotations=[ann_completed, ann_rejected],
)
annotations_file = tmp_path / "project.json"
io.save(project, annotations_file)
config = aoef.AOEFAnnotations(
name="test_project_filter_load",
annotations_path=annotations_file,
audio_dir=rec_path.parent,
)
loaded_data = aoef.load_aoef_annotated_dataset(config)
assert isinstance(loaded_data, data.AnnotationSet)
assert loaded_data.name == project.name
assert len(loaded_data.clip_annotations) == 1
assert loaded_data.clip_annotations[0].clip.uuid == clip_completed.uuid
assert (
loaded_data.clip_annotations[0].clip.recording.path
== rec_path.resolve()
)
def test_load_aoef_annotated_dataset_project_no_filter(
tmp_path: Path,
create_recording: Callable[..., data.Recording],
create_clip: Callable[..., data.Clip],
create_clip_annotation: Callable[..., data.ClipAnnotation],
create_annotation_project: Callable[..., data.AnnotationProject],
):
"""Test loading an AnnotationProject file with filtering disabled."""
rec_path = tmp_path / "audio" / "rec.wav"
rec_path.parent.mkdir()
rec = create_recording(path=rec_path)
clip1 = create_clip(rec, 0, 1)
clip2 = create_clip(rec, 1, 2)
task1 = create_task(
clip1, [data.StatusBadge(state=AnnotationState.completed)]
)
task2 = create_task(
clip2, [data.StatusBadge(state=AnnotationState.rejected)]
)
ann1 = create_clip_annotation(clip1)
ann2 = create_clip_annotation(clip2)
original_project = create_annotation_project(
tasks=[task1, task2], annotations=[ann1, ann2]
)
annotations_file = tmp_path / "project_nofilter.json"
io.save(original_project, annotations_file)
config = aoef.AOEFAnnotations(
name="test_project_nofilter_load",
annotations_path=annotations_file,
audio_dir=rec_path.parent,
filter=None,
)
loaded_data = aoef.load_aoef_annotated_dataset(config)
assert isinstance(loaded_data, data.AnnotationProject)
assert loaded_data.uuid == original_project.uuid
assert len(loaded_data.clip_annotations) == 2
assert (
loaded_data.clip_annotations[0].clip.recording.path
== rec_path.resolve()
)
assert (
loaded_data.clip_annotations[1].clip.recording.path
== rec_path.resolve()
)
def test_load_aoef_annotated_dataset_base_dir(
tmp_path: Path,
create_recording: Callable[..., data.Recording],
create_clip: Callable[..., data.Clip],
create_clip_annotation: Callable[..., data.ClipAnnotation],
create_annotation_project: Callable[..., data.AnnotationProject],
):
"""Test loading with a base_dir specified."""
base = tmp_path / "basedir"
base.mkdir()
audio_rel = Path("audio")
ann_rel = Path("annotations/project.json")
abs_audio_dir = base / audio_rel
abs_ann_path = base / ann_rel
abs_audio_dir.mkdir(parents=True)
abs_ann_path.parent.mkdir(parents=True)
rec = create_recording(path=abs_audio_dir / "rec.wav")
rec_path = rec.path
clip = create_clip(rec)
task = create_task(
clip, [data.StatusBadge(state=AnnotationState.completed)]
)
ann = create_clip_annotation(clip)
project = create_annotation_project(tasks=[task], annotations=[ann])
io.save(project, abs_ann_path)
config = aoef.AOEFAnnotations(
name="test_base_dir_load",
annotations_path=ann_rel,
audio_dir=audio_rel,
filter=aoef.AnnotationTaskFilter(),
)
loaded_set = aoef.load_aoef_annotated_dataset(config, base_dir=base)
assert isinstance(loaded_set, data.AnnotationSet)
assert len(loaded_set.clip_annotations) == 1
assert (
loaded_set.clip_annotations[0].clip.recording.path
== rec_path.resolve()
)
def test_load_aoef_annotated_dataset_file_not_found(tmp_path):
"""Test FileNotFoundError when annotation file doesn't exist."""
config = aoef.AOEFAnnotations(
name="test_not_found",
annotations_path=tmp_path / "nonexistent.aoef",
audio_dir=tmp_path,
)
with pytest.raises(FileNotFoundError):
aoef.load_aoef_annotated_dataset(config)
def test_load_aoef_annotated_dataset_file_not_found_with_base_dir(tmp_path):
"""Test FileNotFoundError with base_dir."""
base = tmp_path / "base"
base.mkdir()
config = aoef.AOEFAnnotations(
name="test_not_found_base",
annotations_path=Path("nonexistent.aoef"),
audio_dir=Path("audio"),
)
with pytest.raises(FileNotFoundError):
aoef.load_aoef_annotated_dataset(config, base_dir=base)
def test_load_aoef_annotated_dataset_invalid_content(tmp_path):
"""Test ValueError when file contains invalid JSON or non-soundevent data."""
invalid_file = tmp_path / "invalid.json"
invalid_file.write_text("{invalid json")
config = aoef.AOEFAnnotations(
name="test_invalid_content",
annotations_path=invalid_file,
audio_dir=tmp_path,
)
with pytest.raises(ValidationError):
aoef.load_aoef_annotated_dataset(config)
def test_load_aoef_annotated_dataset_wrong_object_type(
tmp_path: Path,
create_recording: Callable[..., data.Recording],
):
"""Test ValueError when file contains correct soundevent obj but wrong type."""
rec_path = tmp_path / "audio" / "rec.wav"
rec_path.parent.mkdir()
rec = create_recording(path=rec_path)
dataset = data.Dataset(
name="test_wrong_type",
description="Test for wrong type",
recordings=[rec],
)
wrong_type_file = tmp_path / "wrong_type.json"
io.save(dataset, wrong_type_file) # type: ignore
config = aoef.AOEFAnnotations(
name="test_wrong_type",
annotations_path=wrong_type_file,
audio_dir=rec_path.parent,
)
with pytest.raises(ValueError) as excinfo:
aoef.load_aoef_annotated_dataset(config)
assert (
"does not contain a soundevent AnnotationSet or AnnotationProject"
in str(excinfo.value)
)

View File

@@ -17,10 +17,10 @@ from batdetect2.train.preprocess import (

 def test_mix_examples(
-    recording_factory: Callable[..., data.Recording],
+    create_recording: Callable[..., data.Recording],
 ):
-    recording1 = recording_factory()
-    recording2 = recording_factory()
+    recording1 = create_recording()
+    recording2 = create_recording()

     clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
     clip2 = data.Clip(recording=recording2, start_time=0.3, end_time=0.8)
@@ -54,12 +54,12 @@ def test_mix_examples(
 @pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7])
 @pytest.mark.parametrize("duration2", [0.1, 0.4, 0.7])
 def test_mix_examples_of_different_durations(
-    recording_factory: Callable[..., data.Recording],
+    create_recording: Callable[..., data.Recording],
     duration1: float,
     duration2: float,
 ):
-    recording1 = recording_factory()
-    recording2 = recording_factory()
+    recording1 = create_recording()
+    recording2 = create_recording()

     clip1 = data.Clip(recording=recording1, start_time=0, end_time=duration1)
     clip2 = data.Clip(recording=recording2, start_time=0, end_time=duration2)
@@ -92,9 +92,9 @@ def test_mix_examples_of_different_durations(

 def test_add_echo(
-    recording_factory: Callable[..., data.Recording],
+    create_recording: Callable[..., data.Recording],
 ):
-    recording1 = recording_factory()
+    recording1 = create_recording()
     clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
     config = TrainPreprocessingConfig()
@@ -113,9 +113,9 @@ def test_add_echo(

 def test_selected_random_subclip_has_the_correct_width(
-    recording_factory: Callable[..., data.Recording],
+    create_recording: Callable[..., data.Recording],
 ):
-    recording1 = recording_factory()
+    recording1 = create_recording()
     clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
     config = TrainPreprocessingConfig()
@@ -131,9 +131,9 @@ def test_selected_random_subclip_has_the_correct_width(

 def test_add_echo_after_subclip(
-    recording_factory: Callable[..., data.Recording],
+    create_recording: Callable[..., data.Recording],
 ):
-    recording1 = recording_factory(duration=2)
+    recording1 = create_recording(duration=2)
     clip1 = data.Clip(recording=recording1, start_time=0, end_time=1)
     clip_annotation_1 = data.ClipAnnotation(clip=clip1)
     config = TrainPreprocessingConfig()