From f6cdd4e87e4faa29c88ecb7b694a801371c93c0a Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Tue, 19 Nov 2024 22:54:26 +0000
Subject: [PATCH] Starting to create dataset builders

---
 batdetect2/compat/data.py             | 36 ++++++++++-
 batdetect2/configs.py                 | 21 +++++++
 batdetect2/data.py                    | 87 +++++++++++++++++++++++++++
 batdetect2/data/__init__.py           |  0
 batdetect2/data/datasets.py           | 33 ----------
 batdetect2/preprocess/__init__.py     |  3 +-
 batdetect2/preprocess/audio.py        | 21 ++++++-
 batdetect2/train/dataset.py           |  8 +--
 pyproject.toml                        |  2 +
 tests/test_migration/test_training.py |  4 +-
 uv.lock                               |  4 ++
 11 files changed, 174 insertions(+), 45 deletions(-)
 create mode 100644 batdetect2/data.py
 delete mode 100644 batdetect2/data/__init__.py
 delete mode 100644 batdetect2/data/datasets.py

diff --git a/batdetect2/compat/data.py b/batdetect2/compat/data.py
index 9415686..5139fb2 100644
--- a/batdetect2/compat/data.py
+++ b/batdetect2/compat/data.py
@@ -1,5 +1,6 @@
 """Compatibility functions between old and new data structures."""
 
+import json
 import os
 import uuid
 from pathlib import Path
@@ -17,7 +18,7 @@ PathLike = Union[Path, str, os.PathLike]
 
 __all__ = [
     "convert_to_annotation_group",
-    "load_annotation_project",
+    "load_annotation_project_from_dir",
 ]
 
 SPECIES_TAG_KEY = "species"
@@ -298,7 +299,38 @@ def list_file_annotations(path: PathLike) -> List[Path]:
     return [file for file in path.glob("*.json")]
 
 
-def load_annotation_project(
+def load_annotation_project_from_file(
+    path: PathLike,
+    name: Optional[str] = None,
+    audio_dir: Optional[PathLike] = None,
+) -> data.AnnotationProject:
+    old_annotations = json.loads(Path(path).read_text())
+
+    annotations = []
+    tasks = []
+
+    for ann in old_annotations:
+        try:
+            ann = FileAnnotation.model_validate(ann)
+        except ValueError:
+            continue
+
+        try:
+            clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
+        except FileNotFoundError:
+            continue
+
+        annotations.append(file_annotation_to_clip_annotation(ann, clip))
+        tasks.append(file_annotation_to_annotation_task(ann, clip))
+
+    return data.AnnotationProject(
+        name=name or str(path),
+        clip_annotations=annotations,
+        tasks=tasks,
+    )
+
+
+def load_annotation_project_from_dir(
     path: PathLike,
     name: Optional[str] = None,
     audio_dir: Optional[PathLike] = None,
diff --git a/batdetect2/configs.py b/batdetect2/configs.py
index ab94d82..d05a7e9 100644
--- a/batdetect2/configs.py
+++ b/batdetect2/configs.py
@@ -1,5 +1,26 @@
+from typing import Optional, Type, TypeVar
+
+import yaml
 from pydantic import BaseModel, ConfigDict
+from soundevent.data import PathLike
 
 
 class BaseConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
+
+
+T = TypeVar("T", bound=BaseModel)
+
+
+def load_config(
+    path: PathLike,
+    schema: Type[T],
+    field: Optional[str] = None,
+) -> T:
+    with open(path, "r") as file:
+        config = yaml.safe_load(file)
+
+    if field:
+        config = config[field]
+
+    return schema.model_validate(config)
diff --git a/batdetect2/data.py b/batdetect2/data.py
new file mode 100644
index 0000000..611f0a9
--- /dev/null
+++ b/batdetect2/data.py
@@ -0,0 +1,87 @@
+from pathlib import Path
+from typing import List, Literal, Tuple, Union
+
+from pydantic import Field
+from soundevent import data, io
+
+from batdetect2.compat.data import (
+    load_annotation_project_from_dir,
+    load_annotation_project_from_file,
+)
+from batdetect2.configs import BaseConfig
+
+
+class BatDetect2AnnotationFiles(BaseConfig):
+    format: Literal["batdetect2"] = "batdetect2"
+    path: Path
+
+
+class BatDetect2AnnotationFile(BaseConfig):
+    format: Literal["batdetect2_file"] = "batdetect2_file"
+    path: Path
+
+
+class AOEFAnnotationFile(BaseConfig):
+    format: Literal["aoef"] = "aoef"
+    annotations_file: Path
+
+
+AnnotationFormats = Union[
+    BatDetect2AnnotationFiles,
+    BatDetect2AnnotationFile,
+    AOEFAnnotationFile,
+]
+
+
+class DatasetInfo(BaseConfig):
+    name: str
+    audio_dir: Path
+    annotations: AnnotationFormats = Field(discriminator="format")
+
+
+class DatasetsConfig(BaseConfig):
+    train: List[DatasetInfo] = Field(default_factory=list)
+    test: List[DatasetInfo] = Field(default_factory=list)
+
+
+def load_dataset(info: DatasetInfo) -> data.AnnotationProject:
+    if info.annotations.format == "batdetect2":
+        return load_annotation_project_from_dir(
+            info.annotations.path,
+            name=info.name,
+            audio_dir=info.audio_dir,
+        )
+
+    if info.annotations.format == "batdetect2_file":
+        return load_annotation_project_from_file(
+            info.annotations.path,
+            name=info.name,
+            audio_dir=info.audio_dir,
+        )
+
+    if info.annotations.format == "aoef":
+        return io.load(  # type: ignore
+            info.annotations.annotations_file,
+            audio_dir=info.audio_dir,
+        )
+
+    raise NotImplementedError(
+        f"Unknown annotation format: {info.annotations.format}"
+    )
+
+
+def load_datasets(
+    config: DatasetsConfig,
+) -> Tuple[List[data.ClipAnnotation], List[data.ClipAnnotation]]:
+    test_annotations = []
+    train_annotations = []
+
+    for dataset in config.train:
+        project = load_dataset(dataset)
+        train_annotations.extend(project.clip_annotations)
+
+    for dataset in config.test:
+        project = load_dataset(dataset)
+        test_annotations.extend(project.clip_annotations)
+
+    return train_annotations, test_annotations
diff --git a/batdetect2/data/__init__.py b/batdetect2/data/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/batdetect2/data/datasets.py b/batdetect2/data/datasets.py
deleted file mode 100644
index b3747d3..0000000
--- a/batdetect2/data/datasets.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from typing import Callable, Generic, Iterable, List, TypeVar
-
-from soundevent import data
-from torch.utils.data import Dataset
-
-__all__ = [
-    "ClipDataset",
-]
-
-
-E = TypeVar("E")
-
-
-class ClipDataset(Dataset, Generic[E]):
-    clips: List[data.Clip]
-
-    transform: Callable[[data.Clip], E]
-
-    def __init__(
-        self,
-        clips: Iterable[data.Clip],
-        transform: Callable[[data.Clip], E],
-        name: str = "ClipDataset",
-    ):
-        self.clips = list(clips)
-        self.transform = transform
-        self.name = name
-
-    def __len__(self) -> int:
-        return len(self.clips)
-
-    def __getitem__(self, idx: int) -> E:
-        return self.transform(self.clips[idx])
diff --git a/batdetect2/preprocess/__init__.py b/batdetect2/preprocess/__init__.py
index a9c2ee5..b3b4142 100644
--- a/batdetect2/preprocess/__init__.py
+++ b/batdetect2/preprocess/__init__.py
@@ -49,6 +49,7 @@ class PreprocessingConfig(BaseModel):
 def preprocess_audio_clip(
     clip: data.Clip,
     config: Optional[PreprocessingConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
 ) -> xr.DataArray:
     """Preprocesses audio clip to generate spectrogram.
 
@@ -66,5 +67,5 @@ def preprocess_audio_clip(
     """
 
     config = config or PreprocessingConfig()
-    wav = load_clip_audio(clip, config=config.audio)
+    wav = load_clip_audio(clip, config=config.audio, audio_dir=audio_dir)
     return compute_spectrogram(wav, config=config.spectrogram)
diff --git a/batdetect2/preprocess/audio.py b/batdetect2/preprocess/audio.py
index 2a7bff1..5c613f3 100644
--- a/batdetect2/preprocess/audio.py
+++ b/batdetect2/preprocess/audio.py
@@ -30,15 +30,22 @@ class AudioConfig(BaseConfig):
 def load_file_audio(
     path: data.PathLike,
     config: Optional[AudioConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     recording = data.Recording.from_file(path)
-    return load_recording_audio(recording, config=config, dtype=dtype)
+    return load_recording_audio(
+        recording,
+        config=config,
+        dtype=dtype,
+        audio_dir=audio_dir,
+    )
 
 
 def load_recording_audio(
     recording: data.Recording,
     config: Optional[AudioConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     clip = data.Clip(
@@ -46,17 +53,25 @@ def load_recording_audio(
         start_time=0,
         end_time=recording.duration,
     )
-    return load_clip_audio(clip, config=config, dtype=dtype)
+    return load_clip_audio(
+        clip,
+        config=config,
+        dtype=dtype,
+        audio_dir=audio_dir,
+    )
 
 
 def load_clip_audio(
     clip: data.Clip,
     config: Optional[AudioConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     config = config or AudioConfig()
 
-    wav = audio.load_clip(clip).sel(channel=0).astype(dtype)
+    wav = (
+        audio.load_clip(clip, audio_dir=audio_dir).sel(channel=0).astype(dtype)
+    )
 
     if config.duration is not None:
         wav = adjust_audio_duration(wav, duration=config.duration)
diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py
index 2c7b04c..1828ef2 100644
--- a/batdetect2/train/dataset.py
+++ b/batdetect2/train/dataset.py
@@ -28,10 +28,6 @@ class TrainExample(NamedTuple):
     idx: torch.Tensor
 
 
-def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
-    return list(Path(directory).glob(f"*{extension}"))
-
-
 class LabeledDataset(Dataset):
     def __init__(
         self,
@@ -92,3 +88,7 @@ class LabeledDataset(Dataset):
         return data.ClipAnnotation.model_validate_json(
             self.get_dataset(idx).attrs["clip_annotation"]
         )
+
+
+def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
+    return list(Path(directory).glob(f"*{extension}"))
diff --git a/pyproject.toml b/pyproject.toml
index fa3a339..ce294a7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,8 @@ dependencies = [
     "onnx>=1.16.0",
     "lightning[extra]>=2.2.2",
     "tensorboard>=2.16.2",
+    "omegaconf>=2.3.0",
+    "pyyaml>=6.0.2",
 ]
 requires-python = ">=3.9,<3.13"
 readme = "README.md"
diff --git a/tests/test_migration/test_training.py b/tests/test_migration/test_training.py
index 646831e..fa46a3e 100644
--- a/tests/test_migration/test_training.py
+++ b/tests/test_migration/test_training.py
@@ -5,7 +5,7 @@ from typing import List
 import numpy as np
 import pytest
 
-from batdetect2.compat.data import load_annotation_project
+from batdetect2.compat.data import load_annotation_project_from_dir
 from batdetect2.compat.params import get_training_preprocessing_config
 from batdetect2.train.preprocess import generate_train_example
 
@@ -36,7 +36,7 @@ def test_can_generate_similar_training_inputs(
     size_mask = dataset["size_mask"]
     class_mask = dataset["class_mask"]
 
-    project = load_annotation_project(
+    project = load_annotation_project_from_dir(
         example_anns_dir,
         audio_dir=example_audio_dir,
     )
diff --git a/uv.lock b/uv.lock
index 40d2b66..4d484bc 100644
--- a/uv.lock
+++ b/uv.lock
@@ -198,9 +198,11 @@ dependencies = [
     { name = "matplotlib" },
     { name = "netcdf4" },
     { name = "numpy" },
+    { name = "omegaconf" },
     { name = "onnx" },
     { name = "pandas" },
     { name = "pytorch-lightning" },
+    { name = "pyyaml" },
     { name = "scikit-learn" },
     { name = "scipy" },
     { name = "soundevent", extra = ["audio", "geometry", "plot"] },
@@ -231,9 +233,11 @@ requires-dist = [
     { name = "matplotlib", specifier = ">=3.7.1" },
     { name = "netcdf4", specifier = ">=1.6.5" },
     { name = "numpy", specifier = ">=1.23.5" },
+    { name = "omegaconf", specifier = ">=2.3.0" },
    { name = "onnx", specifier = ">=1.16.0" },
     { name = "pandas", specifier = ">=1.5.3" },
     { name = "pytorch-lightning", specifier = ">=2.2.2" },
+    { name = "pyyaml", specifier = ">=6.0.2" },
     { name = "scikit-learn", specifier = ">=1.2.2" },
     { name = "scipy", specifier = ">=1.10.1" },
     { name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.3" },
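
Usage sketch (editor's note, not part of the patch): the pieces introduced above
are expected to combine roughly as follows. The datasets.yaml layout and the
file paths are hypothetical, inferred from the DatasetInfo/DatasetsConfig
models; load_config, DatasetsConfig, and load_datasets are the names added by
this patch.

    # Hypothetical datasets.yaml (structure inferred from DatasetsConfig):
    #
    # train:
    #   - name: example-train
    #     audio_dir: /data/audio
    #     annotations:
    #       format: batdetect2          # directory of per-recording JSON files
    #       path: /data/annotations
    # test: []

    from batdetect2.configs import load_config
    from batdetect2.data import DatasetsConfig, load_datasets

    # Parse the YAML file and validate it against the DatasetsConfig schema.
    config = load_config("datasets.yaml", schema=DatasetsConfig)

    # Collect clip annotations from every configured dataset, split into the
    # train/test lists returned by load_datasets.
    train_annotations, test_annotations = load_datasets(config)
    print(f"{len(train_annotations)} train / {len(test_annotations)} test clips")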