Starting to create dataset builders

mbsantiago 2024-11-19 22:54:26 +00:00
parent 9cf159efff
commit f6cdd4e87e
11 changed files with 174 additions and 45 deletions

View File

@@ -1,5 +1,6 @@
 """Compatibility functions between old and new data structures."""
+import json
 import os
 import uuid
 from pathlib import Path
@@ -17,7 +18,7 @@ PathLike = Union[Path, str, os.PathLike]
 __all__ = [
     "convert_to_annotation_group",
-    "load_annotation_project",
+    "load_annotation_project_from_dir",
 ]
 SPECIES_TAG_KEY = "species"
@@ -298,7 +299,38 @@ def list_file_annotations(path: PathLike) -> List[Path]:
     return [file for file in path.glob("*.json")]
-def load_annotation_project(
+def load_annotation_project_from_file(
+    path: PathLike,
+    name: Optional[str] = None,
+    audio_dir: Optional[PathLike] = None,
+) -> data.AnnotationProject:
+    old_annotations = json.loads(Path(path).read_text())
+    annotations = []
+    tasks = []
+    for ann in old_annotations:
+        try:
+            ann = FileAnnotation.model_validate(ann)
+        except ValueError:
+            continue
+        try:
+            clip = file_annotation_to_clip(ann, audio_dir=audio_dir)
+        except FileNotFoundError:
+            continue
+        annotations.append(file_annotation_to_clip_annotation(ann, clip))
+        tasks.append(file_annotation_to_annotation_task(ann, clip))
+    return data.AnnotationProject(
+        name=name or str(path),
+        clip_annotations=annotations,
+        tasks=tasks,
+    )
+def load_annotation_project_from_dir(
     path: PathLike,
     name: Optional[str] = None,
     audio_dir: Optional[PathLike] = None,
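
A minimal usage sketch for the new single-file loader; only the signature and the AnnotationProject fields come from the diff above, the paths are placeholders.

# Sketch only: annotation file and audio directory paths are placeholders.
from batdetect2.compat.data import load_annotation_project_from_file

project = load_annotation_project_from_file(
    "annotations/batdetect2_annotations.json",
    name="example-dataset",
    audio_dir="audio/",
)
print(len(project.clip_annotations), len(project.tasks))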

View File

@@ -1,5 +1,26 @@
+from typing import Optional, Type, TypeVar
+import yaml
 from pydantic import BaseModel, ConfigDict
+from soundevent.data import PathLike
 class BaseConfig(BaseModel):
     model_config = ConfigDict(extra="forbid")
+T = TypeVar("T", bound=BaseModel)
+def load_config(
+    path: PathLike,
+    schema: Type[T],
+    field: Optional[str] = None,
+) -> T:
+    with open(path, "r") as file:
+        config = yaml.safe_load(file)
+    if field:
+        config = config[field]
+    return schema.model_validate(config)
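
A short sketch of how load_config might be used; the TrainingConfig schema and the YAML layout are assumptions for illustration.

# Sketch only: TrainingConfig and the YAML structure are hypothetical.
from batdetect2.configs import BaseConfig, load_config

class TrainingConfig(BaseConfig):
    learning_rate: float = 1e-3
    batch_size: int = 32

# config.yaml (assumed contents):
# training:
#   learning_rate: 0.001
#   batch_size: 64
config = load_config("config.yaml", TrainingConfig, field="training")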

batdetect2/data.py (new file, 87 lines added)
View File

@@ -0,0 +1,87 @@
+from pathlib import Path
+from typing import List, Literal, Tuple, Union
+
+from pydantic import Field
+from soundevent import data, io
+
+from batdetect2.compat.data import (
+    load_annotation_project_from_dir,
+    load_annotation_project_from_file,
+)
+from batdetect2.configs import BaseConfig
+
+
+class BatDetect2AnnotationFiles(BaseConfig):
+    format: Literal["batdetect2"] = "batdetect2"
+    path: Path
+
+
+class BatDetect2AnnotationFile(BaseConfig):
+    format: Literal["batdetect2_file"] = "batdetect2_file"
+    path: Path
+
+
+class AOEFAnnotationFile(BaseConfig):
+    format: Literal["aoef"] = "aoef"
+    annotations_file: Path
+
+
+AnnotationFormats = Union[
+    BatDetect2AnnotationFiles,
+    BatDetect2AnnotationFile,
+    AOEFAnnotationFile,
+]
+
+
+class DatasetInfo(BaseConfig):
+    name: str
+    audio_dir: Path
+    annotations: AnnotationFormats = Field(discriminator="format")
+
+
+class DatasetsConfig(BaseConfig):
+    train: List[DatasetInfo] = Field(default_factory=list)
+    test: List[DatasetInfo] = Field(default_factory=list)
+
+
+def load_dataset(info: DatasetInfo) -> data.AnnotationProject:
+    if info.annotations.format == "batdetect2":
+        return load_annotation_project_from_dir(
+            info.annotations.path,
+            name=info.name,
+            audio_dir=info.audio_dir,
+        )
+
+    if info.annotations.format == "batdetect2_file":
+        return load_annotation_project_from_file(
+            info.annotations.path,
+            name=info.name,
+            audio_dir=info.audio_dir,
+        )
+
+    if info.annotations.format == "aoef":
+        return io.load(  # type: ignore
+            info.annotations.annotations_file,
+            audio_dir=info.audio_dir,
+        )
+
+    raise NotImplementedError(
+        f"Unknown annotation format: {info.annotations.format}"
+    )
+
+
+def load_datasets(
+    config: DatasetsConfig,
+) -> Tuple[List[data.ClipAnnotation], List[data.ClipAnnotation]]:
+    test_annotations = []
+    train_annotations = []
+
+    for dataset in config.train:
+        project = load_dataset(dataset)
+        train_annotations.extend(project.clip_annotations)
+
+    for dataset in config.test:
+        project = load_dataset(dataset)
+        test_annotations.extend(project.clip_annotations)
+
+    return train_annotations, test_annotations
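
A sketch of how the new dataset builders could fit together with the load_config helper; the file names and YAML layout are inferred from the DatasetInfo fields and are not a documented format.

# Sketch only: file names and the YAML layout below are assumptions based
# on the DatasetsConfig / DatasetInfo fields defined above.
#
# datasets.yaml:
# train:
#   - name: uk-bats
#     audio_dir: /data/uk/audio
#     annotations:
#       format: batdetect2
#       path: /data/uk/annotations
# test: []
from batdetect2.configs import load_config
from batdetect2.data import DatasetsConfig, load_datasets

config = load_config("datasets.yaml", DatasetsConfig)
train_clips, test_clips = load_datasets(config)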

View File

@@ -1,33 +0,0 @@
-from typing import Callable, Generic, Iterable, List, TypeVar
-from soundevent import data
-from torch.utils.data import Dataset
-__all__ = [
-    "ClipDataset",
-]
-E = TypeVar("E")
-class ClipDataset(Dataset, Generic[E]):
-    clips: List[data.Clip]
-    transform: Callable[[data.Clip], E]
-    def __init__(
-        self,
-        clips: Iterable[data.Clip],
-        transform: Callable[[data.Clip], E],
-        name: str = "ClipDataset",
-    ):
-        self.clips = list(clips)
-        self.transform = transform
-        self.name = name
-    def __len__(self) -> int:
-        return len(self.clips)
-    def __getitem__(self, idx: int) -> E:
-        return self.transform(self.clips[idx])

View File

@@ -49,6 +49,7 @@ class PreprocessingConfig(BaseModel):
 def preprocess_audio_clip(
     clip: data.Clip,
     config: Optional[PreprocessingConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
 ) -> xr.DataArray:
     """Preprocesses audio clip to generate spectrogram.
@@ -66,5 +67,5 @@ def preprocess_audio_clip
     """
     config = config or PreprocessingConfig()
-    wav = load_clip_audio(clip, config=config.audio)
+    wav = load_clip_audio(clip, config=config.audio, audio_dir=audio_dir)
     return compute_spectrogram(wav, config=config.spectrogram)
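
A sketch of the new audio_dir argument in use; the import path and the recording/clip values are assumptions, only the signature comes from the diff above.

# Sketch only: the module path and the recording/clip values are assumptions.
from soundevent import data
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip

recording = data.Recording.from_file("audio/night_recording.wav")
clip = data.Clip(recording=recording, start_time=0.0, end_time=3.0)

# Relative recording paths can now be resolved against a root audio directory.
spec = preprocess_audio_clip(clip, audio_dir="/data/audio")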

View File

@@ -30,15 +30,22 @@ class AudioConfig(BaseConfig):
 def load_file_audio(
     path: data.PathLike,
     config: Optional[AudioConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     recording = data.Recording.from_file(path)
-    return load_recording_audio(recording, config=config, dtype=dtype)
+    return load_recording_audio(
+        recording,
+        config=config,
+        dtype=dtype,
+        audio_dir=audio_dir,
+    )
 def load_recording_audio(
     recording: data.Recording,
     config: Optional[AudioConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     clip = data.Clip(
@@ -46,17 +53,25 @@ def load_recording_audio(
         start_time=0,
         end_time=recording.duration,
     )
-    return load_clip_audio(clip, config=config, dtype=dtype)
+    return load_clip_audio(
+        clip,
+        config=config,
+        dtype=dtype,
+        audio_dir=audio_dir,
+    )
 def load_clip_audio(
     clip: data.Clip,
     config: Optional[AudioConfig] = None,
+    audio_dir: Optional[data.PathLike] = None,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     config = config or AudioConfig()
-    wav = audio.load_clip(clip).sel(channel=0).astype(dtype)
+    wav = (
+        audio.load_clip(clip, audio_dir=audio_dir).sel(channel=0).astype(dtype)
+    )
     if config.duration is not None:
         wav = adjust_audio_duration(wav, duration=config.duration)
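
And the same parameter threaded through the lower-level loaders; a brief sketch, assuming this module's import path and a placeholder file location.

# Sketch only: the import path and file locations are placeholders.
from batdetect2.preprocess.audio import load_file_audio

# audio_dir is forwarded down to audio.load_clip when resolving the recording.
wav = load_file_audio(
    "bat_call.wav",
    audio_dir="/data/recordings",
)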

View File

@@ -28,10 +28,6 @@ class TrainExample(NamedTuple):
     idx: torch.Tensor
-def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
-    return list(Path(directory).glob(f"*{extension}"))
 class LabeledDataset(Dataset):
     def __init__(
         self,
@@ -92,3 +88,7 @@ class LabeledDataset(Dataset):
         return data.ClipAnnotation.model_validate_json(
             self.get_dataset(idx).attrs["clip_annotation"]
         )
+def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
+    return list(Path(directory).glob(f"*{extension}"))

View File

@@ -26,6 +26,8 @@ dependencies = [
     "onnx>=1.16.0",
     "lightning[extra]>=2.2.2",
     "tensorboard>=2.16.2",
+    "omegaconf>=2.3.0",
+    "pyyaml>=6.0.2",
 ]
 requires-python = ">=3.9,<3.13"
 readme = "README.md"

View File

@@ -5,7 +5,7 @@ from typing import List
 import numpy as np
 import pytest
-from batdetect2.compat.data import load_annotation_project
+from batdetect2.compat.data import load_annotation_project_from_dir
 from batdetect2.compat.params import get_training_preprocessing_config
 from batdetect2.train.preprocess import generate_train_example
@@ -36,7 +36,7 @@ def test_can_generate_similar_training_inputs(
     size_mask = dataset["size_mask"]
     class_mask = dataset["class_mask"]
-    project = load_annotation_project(
+    project = load_annotation_project_from_dir(
         example_anns_dir,
         audio_dir=example_audio_dir,
     )

uv.lock (generated, 4 lines added)
View File

@@ -198,9 +198,11 @@ dependencies = [
     { name = "matplotlib" },
     { name = "netcdf4" },
     { name = "numpy" },
+    { name = "omegaconf" },
     { name = "onnx" },
     { name = "pandas" },
     { name = "pytorch-lightning" },
+    { name = "pyyaml" },
     { name = "scikit-learn" },
     { name = "scipy" },
     { name = "soundevent", extra = ["audio", "geometry", "plot"] },
@@ -231,9 +233,11 @@ requires-dist = [
     { name = "matplotlib", specifier = ">=3.7.1" },
     { name = "netcdf4", specifier = ">=1.6.5" },
     { name = "numpy", specifier = ">=1.23.5" },
+    { name = "omegaconf", specifier = ">=2.3.0" },
     { name = "onnx", specifier = ">=1.16.0" },
     { name = "pandas", specifier = ">=1.5.3" },
     { name = "pytorch-lightning", specifier = ">=2.2.2" },
+    { name = "pyyaml", specifier = ">=6.0.2" },
     { name = "scikit-learn", specifier = ">=1.2.2" },
     { name = "scipy", specifier = ">=1.10.1" },
     { name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.3" },